[mlpack-git] master: Add Serialize() and test for LSH. (05cb18d)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Fri Nov 20 17:33:38 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/962a37fe8374913c435054aa50e12d912bdfa01c...a7d8231fe7526dcfaadae0bf37d67b50d286e45d
>---------------------------------------------------------------
commit 05cb18dc6a0bc815aa4a92d4d90eee5c6319e1a7
Author: Ryan Curtin <ryan at ratml.org>
Date: Fri Nov 20 22:32:41 2015 +0000
Add Serialize() and test for LSH.
>---------------------------------------------------------------
05cb18dc6a0bc815aa4a92d4d90eee5c6319e1a7
src/mlpack/methods/lsh/lsh_search.hpp | 8 +++++
src/mlpack/methods/lsh/lsh_search_impl.hpp | 35 ++++++++++++++++++++++
src/mlpack/tests/serialization_test.cpp | 47 ++++++++++++++++++++++++++++++
3 files changed, 90 insertions(+)
diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index e14b142..01f7744 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -141,6 +141,14 @@ class LSHSearch
arma::mat& distances,
const size_t numTablesToSearch = 0);
+ /**
+ * Serialize the LSH model.
+ *
+ * @param ar Archive to serialize to.
+ */
+ template<typename Archive>
+ void Serialize(Archive& ar, const unsigned int /* version */);
+
//! Returns a string representation of this object.
std::string ToString() const;
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index eeeb504..f74c11f 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -498,6 +498,41 @@ void LSHSearch<SortPolicy>::BuildHash()
}
template<typename SortPolicy>
+template<typename Archive>
+void LSHSearch<SortPolicy>::Serialize(Archive& ar,
+ const unsigned int /* version */)
+{
+ using data::CreateNVP;
+
+ // If we are loading, we are going to own the reference set.
+ if (Archive::is_loading::value)
+ {
+ if (ownsSet)
+ delete referenceSet;
+ ownsSet = true;
+ }
+ ar & CreateNVP(referenceSet, "referenceSet");
+
+ ar & CreateNVP(numProj, "numProj");
+ ar & CreateNVP(numTables, "numTables");
+
+ // Delete existing projections, if necessary.
+ if (Archive::is_loading::value)
+ projections.clear();
+
+ ar & CreateNVP(projections, "projections");
+ ar & CreateNVP(offsets, "offsets");
+ ar & CreateNVP(hashWidth, "hashWidth");
+ ar & CreateNVP(secondHashSize, "secondHashSize");
+ ar & CreateNVP(secondHashWeights, "secondHashWeights");
+ ar & CreateNVP(bucketSize, "bucketSize");
+ ar & CreateNVP(secondHashTable, "secondHashTable");
+ ar & CreateNVP(bucketContentSize, "bucketContentSize");
+ ar & CreateNVP(bucketRowInHashTable, "bucketRowInHashTable");
+ ar & CreateNVP(distanceEvaluations, "distanceEvaluations");
+}
+
+template<typename SortPolicy>
std::string LSHSearch<SortPolicy>::ToString() const
{
std::ostringstream convert;
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index b89f8ce..d42e6d5 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -31,6 +31,7 @@
#include <mlpack/methods/det/dtree.hpp>
#include <mlpack/methods/naive_bayes/naive_bayes_classifier.hpp>
#include <mlpack/methods/rann/ra_search.hpp>
+#include <mlpack/methods/lsh/lsh_search.hpp>
using namespace mlpack;
using namespace mlpack::distribution;
@@ -41,6 +42,7 @@ using namespace mlpack::tree;
using namespace mlpack::perceptron;
using namespace mlpack::regression;
using namespace mlpack::naive_bayes;
+using namespace mlpack::neighbor;
using namespace arma;
using namespace boost;
@@ -1382,4 +1384,49 @@ BOOST_AUTO_TEST_CASE(RASearchTest)
BOOST_REQUIRE_GT(textCorrect, 95 * 5);
}
+/**
+ * Test that an LSH model can be serialized and deserialized.
+ */
+BOOST_AUTO_TEST_CASE(LSHTest)
+{
+ // Since we still don't have good tests for LSH, basically what we're going to
+ // do is serialize an LSH model, and make sure we can deserialize it and that
+ // we still get results when we call Search().
+ arma::mat referenceData = arma::randu<arma::mat>(10, 100);
+
+ LSHSearch<> lsh(referenceData, 5, 10); // Arbitrary chosen parameters.
+
+ LSHSearch<> xmlLsh;
+ arma::mat textData = arma::randu<arma::mat>(5, 50);
+ LSHSearch<> textLsh(textData, 4, 5);
+ LSHSearch<> binaryLsh(referenceData, 15, 2);
+
+ // Now serialize.
+ SerializeObjectAll(lsh, xmlLsh, textLsh, binaryLsh);
+
+ // Check what we can about the serialized objects.
+ BOOST_REQUIRE_EQUAL(lsh.NumProjections(), xmlLsh.NumProjections());
+ BOOST_REQUIRE_EQUAL(lsh.NumProjections(), textLsh.NumProjections());
+ BOOST_REQUIRE_EQUAL(lsh.NumProjections(), binaryLsh.NumProjections());
+ for (size_t i = 0; i < lsh.NumProjections(); ++i)
+ {
+ CheckMatrices(lsh.Projection(i), xmlLsh.Projection(i),
+ textLsh.Projection(i), binaryLsh.Projection(i));
+ }
+
+ CheckMatrices(lsh.ReferenceSet(), xmlLsh.ReferenceSet(),
+ textLsh.ReferenceSet(), binaryLsh.ReferenceSet());
+ CheckMatrices(lsh.Offsets(), xmlLsh.Offsets(), textLsh.Offsets(),
+ binaryLsh.Offsets());
+ CheckMatrices(lsh.SecondHashWeights(), xmlLsh.SecondHashWeights(),
+ textLsh.SecondHashWeights(), binaryLsh.SecondHashWeights());
+
+ BOOST_REQUIRE_EQUAL(lsh.BucketSize(), xmlLsh.BucketSize());
+ BOOST_REQUIRE_EQUAL(lsh.BucketSize(), textLsh.BucketSize());
+ BOOST_REQUIRE_EQUAL(lsh.BucketSize(), binaryLsh.BucketSize());
+
+ CheckMatrices(lsh.SecondHashTable(), xmlLsh.SecondHashTable(),
+ textLsh.SecondHashTable(), binaryLsh.SecondHashTable());
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git mailing list