[mlpack-git] master: Add Serialize() and test for LSH. (05cb18d)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Nov 20 17:33:38 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/962a37fe8374913c435054aa50e12d912bdfa01c...a7d8231fe7526dcfaadae0bf37d67b50d286e45d

>---------------------------------------------------------------

commit 05cb18dc6a0bc815aa4a92d4d90eee5c6319e1a7
Author: Ryan Curtin <ryan at ratml.org>
Date:   Fri Nov 20 22:32:41 2015 +0000

    Add Serialize() and test for LSH.


>---------------------------------------------------------------

05cb18dc6a0bc815aa4a92d4d90eee5c6319e1a7
 src/mlpack/methods/lsh/lsh_search.hpp      |  8 +++++
 src/mlpack/methods/lsh/lsh_search_impl.hpp | 35 ++++++++++++++++++++++
 src/mlpack/tests/serialization_test.cpp    | 47 ++++++++++++++++++++++++++++++
 3 files changed, 90 insertions(+)

diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index e14b142..01f7744 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -141,6 +141,14 @@ class LSHSearch
               arma::mat& distances,
               const size_t numTablesToSearch = 0);
 
+  /**
+   * Save the LSH model to, or load it from, the given archive.  After a load,
+   * the model owns its (deserialized) reference set.
+   * @param ar Archive to save to or load from (the version argument is unused).
+   */
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */);
+
   //! Returns a string representation of this object.
   std::string ToString() const;
 
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index eeeb504..f74c11f 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -498,6 +498,41 @@ void LSHSearch<SortPolicy>::BuildHash()
 }
 
 template<typename SortPolicy>
+template<typename Archive>
+void LSHSearch<SortPolicy>::Serialize(Archive& ar,
+                                      const unsigned int /* version */)
+{
+  using data::CreateNVP;
+
+  // When loading, we take ownership of the deserialized reference set.  Free
+  // any set we owned and null the pointer in case deserialization throws.
+  if (Archive::is_loading::value)
+  {
+    if (ownsSet)
+      delete referenceSet;
+    referenceSet = NULL;
+    ownsSet = true;
+  }
+  ar & CreateNVP(referenceSet, "referenceSet");
+  ar & CreateNVP(numProj, "numProj");
+  ar & CreateNVP(numTables, "numTables");
+
+  // Discard any existing projections before loading the stored ones.
+  if (Archive::is_loading::value)
+    projections.clear();
+  ar & CreateNVP(projections, "projections");
+  ar & CreateNVP(offsets, "offsets");
+  ar & CreateNVP(hashWidth, "hashWidth");
+  ar & CreateNVP(secondHashSize, "secondHashSize");
+  ar & CreateNVP(secondHashWeights, "secondHashWeights");
+  ar & CreateNVP(bucketSize, "bucketSize");
+  ar & CreateNVP(secondHashTable, "secondHashTable");
+  ar & CreateNVP(bucketContentSize, "bucketContentSize");
+  ar & CreateNVP(bucketRowInHashTable, "bucketRowInHashTable");
+  ar & CreateNVP(distanceEvaluations, "distanceEvaluations");
+}
+
+template<typename SortPolicy>
 std::string LSHSearch<SortPolicy>::ToString() const
 {
   std::ostringstream convert;
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index b89f8ce..d42e6d5 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -31,6 +31,7 @@
 #include <mlpack/methods/det/dtree.hpp>
 #include <mlpack/methods/naive_bayes/naive_bayes_classifier.hpp>
 #include <mlpack/methods/rann/ra_search.hpp>
+#include <mlpack/methods/lsh/lsh_search.hpp>
 
 using namespace mlpack;
 using namespace mlpack::distribution;
@@ -41,6 +42,7 @@ using namespace mlpack::tree;
 using namespace mlpack::perceptron;
 using namespace mlpack::regression;
 using namespace mlpack::naive_bayes;
+using namespace mlpack::neighbor;
 
 using namespace arma;
 using namespace boost;
@@ -1382,4 +1384,49 @@ BOOST_AUTO_TEST_CASE(RASearchTest)
   BOOST_REQUIRE_GT(textCorrect, 95 * 5);
 }
 
+/**
+ * Test that an LSH model can be serialized and deserialized.
+ */
+BOOST_AUTO_TEST_CASE(LSHTest)
+{
+  // We still lack good tests for LSH results, so instead of calling Search()
+  // we build a model, round-trip it into three others (default-constructed or
+  // built from other data/parameters), and check their internal state matches.
+  arma::mat referenceData = arma::randu<arma::mat>(10, 100);
+
+  LSHSearch<> lsh(referenceData, 5, 10); // Arbitrarily chosen parameters.
+
+  LSHSearch<> xmlLsh;
+  arma::mat textData = arma::randu<arma::mat>(5, 50);
+  LSHSearch<> textLsh(textData, 4, 5);
+  LSHSearch<> binaryLsh(referenceData, 15, 2);
+
+  // Serialize lsh, then deserialize it into each of the other three objects.
+  SerializeObjectAll(lsh, xmlLsh, textLsh, binaryLsh);
+
+  // Every piece of state we can access must match the original model.
+  BOOST_REQUIRE_EQUAL(lsh.NumProjections(), xmlLsh.NumProjections());
+  BOOST_REQUIRE_EQUAL(lsh.NumProjections(), textLsh.NumProjections());
+  BOOST_REQUIRE_EQUAL(lsh.NumProjections(), binaryLsh.NumProjections());
+  for (size_t i = 0; i < lsh.NumProjections(); ++i)
+  {
+    CheckMatrices(lsh.Projection(i), xmlLsh.Projection(i),
+        textLsh.Projection(i), binaryLsh.Projection(i));
+  }
+
+  CheckMatrices(lsh.ReferenceSet(), xmlLsh.ReferenceSet(),
+      textLsh.ReferenceSet(), binaryLsh.ReferenceSet());
+  CheckMatrices(lsh.Offsets(), xmlLsh.Offsets(), textLsh.Offsets(),
+      binaryLsh.Offsets());
+  CheckMatrices(lsh.SecondHashWeights(), xmlLsh.SecondHashWeights(),
+      textLsh.SecondHashWeights(), binaryLsh.SecondHashWeights());
+
+  BOOST_REQUIRE_EQUAL(lsh.BucketSize(), xmlLsh.BucketSize());
+  BOOST_REQUIRE_EQUAL(lsh.BucketSize(), textLsh.BucketSize());
+  BOOST_REQUIRE_EQUAL(lsh.BucketSize(), binaryLsh.BucketSize());
+
+  CheckMatrices(lsh.SecondHashTable(), xmlLsh.SecondHashTable(),
+      textLsh.SecondHashTable(), binaryLsh.SecondHashTable());
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list