[mlpack-git] master: Refactor Serialize(), add backwards compatibility, and update tests. (f989f1f)
gitdub at mlpack.org
gitdub at mlpack.org
Wed Jun 1 13:54:29 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eba4f9924694bc10daec74ff5059dbb8af001416...e3a23c256f017ebb8185b15847c82f51d359cdfd
>---------------------------------------------------------------
commit f989f1f9ddcba8d30e36f43d407768ee6cd78623
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Jun 1 08:48:30 2016 -0700
Refactor Serialize(), add backwards compatibility, and update tests.
>---------------------------------------------------------------
f989f1f9ddcba8d30e36f43d407768ee6cd78623
src/mlpack/methods/lsh/lsh_search.hpp | 6 +++++-
src/mlpack/methods/lsh/lsh_search_impl.hpp | 21 ++++++++++++++++++---
src/mlpack/tests/serialization_test.cpp | 4 ++--
3 files changed, 25 insertions(+), 6 deletions(-)
diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index 89f0f92..d3bc2f9 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -151,7 +151,7 @@ class LSHSearch
* @param ar Archive to serialize to.
*/
template<typename Archive>
- void Serialize(Archive& ar, const unsigned int /* version */);
+ void Serialize(Archive& ar, const unsigned int version);
//! Return the number of distance evaluations performed.
size_t DistanceEvaluations() const { return distanceEvaluations; }
@@ -327,6 +327,10 @@ class LSHSearch
} // namespace neighbor
} // namespace mlpack
+//! Set the serialization version of the LSHSearch class.
+BOOST_TEMPLATE_CLASS_VERSION(template<typename SortPolicy>,
+ mlpack::neighbor::LSHSearch<SortPolicy>, 1);
+
// Include implementation.
#include "lsh_search_impl.hpp"
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index 14a86e1..afeaf05 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -524,7 +524,7 @@ void LSHSearch<SortPolicy>::BuildHash(const arma::cube &projection)
template<typename SortPolicy>
template<typename Archive>
void LSHSearch<SortPolicy>::Serialize(Archive& ar,
- const unsigned int /* version */)
+ const unsigned int version)
{
using data::CreateNVP;
@@ -542,9 +542,24 @@ void LSHSearch<SortPolicy>::Serialize(Archive& ar,
// Delete existing projections, if necessary.
if (Archive::is_loading::value)
- projections.zeros(0, 0, 0); // TODO: correct way to clear this?
+ projections.reset();
+
+ // Backward compatibility: older version of LSHSearch stored the projection
+ // tables in a std::vector<arma::mat>.
+ if (version == 0)
+ {
+ std::vector<arma::mat> tmpProj;
+ ar & CreateNVP(tmpProj, "projections");
+
+ projections.set_size(tmpProj[0].n_rows, tmpProj[0].n_cols, tmpProj.size());
+ for (size_t i = 0; i < tmpProj.size(); ++i)
+ projections.slice(i) = tmpProj[i];
+ }
+ else
+ {
+ ar & CreateNVP(projections, "projections");
+ }
- ar & CreateNVP(projections, "projections");
ar & CreateNVP(offsets, "offsets");
ar & CreateNVP(hashWidth, "hashWidth");
ar & CreateNVP(secondHashSize, "secondHashSize");
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index 9bddbc2..1753e67 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -1210,8 +1210,8 @@ BOOST_AUTO_TEST_CASE(LSHTest)
BOOST_REQUIRE_EQUAL(lsh.NumProjections(), binaryLsh.NumProjections());
for (size_t i = 0; i < lsh.NumProjections(); ++i)
{
- CheckMatrices(lsh.Projection(i), xmlLsh.Projection(i),
- textLsh.Projection(i), binaryLsh.Projection(i));
+ CheckMatrices(lsh.Projections().slice(i), xmlLsh.Projections().slice(i),
+ textLsh.Projections().slice(i), binaryLsh.Projections().slice(i));
}
CheckMatrices(lsh.ReferenceSet(), xmlLsh.ReferenceSet(),
More information about the mlpack-git mailing list