[mlpack-git] master: Refactor Serialize(), add backwards compatibility, and update tests. (f989f1f)

gitdub at mlpack.org gitdub at mlpack.org
Wed Jun 1 13:54:29 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/1f562a1aba7ae55475afcc95659511c2b7f694e5...5b8fdce471328f722fcd8c0f22a6d995ce22c98b

>---------------------------------------------------------------

commit f989f1f9ddcba8d30e36f43d407768ee6cd78623
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Jun 1 08:48:30 2016 -0700

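    Refactor LSHSearch::Serialize() and add backward compatibility: bump the class serialization version to 1, and keep loading version-0 archives (which stored projections as a std::vector<arma::mat>). Update the serialization tests to compare Projections().slice(i).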


>---------------------------------------------------------------

f989f1f9ddcba8d30e36f43d407768ee6cd78623
 src/mlpack/methods/lsh/lsh_search.hpp      |  6 +++++-
 src/mlpack/methods/lsh/lsh_search_impl.hpp | 21 ++++++++++++++++++---
 src/mlpack/tests/serialization_test.cpp    |  4 ++--
 3 files changed, 25 insertions(+), 6 deletions(-)

diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index 89f0f92..d3bc2f9 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -151,7 +151,7 @@ class LSHSearch
    * @param ar Archive to serialize to.
    */
   template<typename Archive>
-  void Serialize(Archive& ar, const unsigned int /* version */);
+  void Serialize(Archive& ar, const unsigned int version);
 
   //! Return the number of distance evaluations performed.
   size_t DistanceEvaluations() const { return distanceEvaluations; }
@@ -327,6 +327,10 @@ class LSHSearch
 } // namespace neighbor
 } // namespace mlpack
 
+//! Set the serialization version of the LSHSearch class.
+BOOST_TEMPLATE_CLASS_VERSION(template<typename SortPolicy>,
+    mlpack::neighbor::LSHSearch<SortPolicy>, 1);
+
 // Include implementation.
 #include "lsh_search_impl.hpp"
 
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index 14a86e1..afeaf05 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -524,7 +524,7 @@ void LSHSearch<SortPolicy>::BuildHash(const arma::cube &projection)
 template<typename SortPolicy>
 template<typename Archive>
 void LSHSearch<SortPolicy>::Serialize(Archive& ar,
-                                      const unsigned int /* version */)
+                                      const unsigned int version)
 {
   using data::CreateNVP;
 
@@ -542,9 +542,24 @@ void LSHSearch<SortPolicy>::Serialize(Archive& ar,
 
   // Delete existing projections, if necessary.
   if (Archive::is_loading::value)
-    projections.zeros(0, 0, 0); // TODO: correct way to clear this?
+    projections.reset();
+
+  // Backward compatibility: older version of LSHSearch stored the projection
+  // tables in a std::vector<arma::mat>.
+  if (version == 0)
+  {
+    std::vector<arma::mat> tmpProj;
+    ar & CreateNVP(tmpProj, "projections");
+
+    projections.set_size(tmpProj[0].n_rows, tmpProj[0].n_cols, tmpProj.size());
+    for (size_t i = 0; i < tmpProj.size(); ++i)
+      projections.slice(i) = tmpProj[i];
+  }
+  else
+  {
+    ar & CreateNVP(projections, "projections");
+  }
 
-  ar & CreateNVP(projections, "projections");
   ar & CreateNVP(offsets, "offsets");
   ar & CreateNVP(hashWidth, "hashWidth");
   ar & CreateNVP(secondHashSize, "secondHashSize");
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index 9bddbc2..1753e67 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -1210,8 +1210,8 @@ BOOST_AUTO_TEST_CASE(LSHTest)
   BOOST_REQUIRE_EQUAL(lsh.NumProjections(), binaryLsh.NumProjections());
   for (size_t i = 0; i < lsh.NumProjections(); ++i)
   {
-    CheckMatrices(lsh.Projection(i), xmlLsh.Projection(i),
-        textLsh.Projection(i), binaryLsh.Projection(i));
+    CheckMatrices(lsh.Projections().slice(i), xmlLsh.Projections().slice(i),
+        textLsh.Projections().slice(i), binaryLsh.Projections().slice(i));
   }
 
   CheckMatrices(lsh.ReferenceSet(), xmlLsh.ReferenceSet(),




More information about the mlpack-git mailing list