[mlpack-git] master: Add Serialize() for RASearch. (24b66d4)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Tue Nov 10 18:23:21 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/78cc694a4fd50a68a24f5ab9af7531873566b3ba...0f4e83dc9cc4dcdc315d2cceee32b23ebab114c2
>---------------------------------------------------------------
commit 24b66d49e9a0f7569d2346a31ee8c348569f1b46
Author: Ryan Curtin <ryan at ratml.org>
Date: Tue Nov 10 17:58:35 2015 -0500
Add Serialize() for RASearch.
>---------------------------------------------------------------
24b66d49e9a0f7569d2346a31ee8c348569f1b46
src/mlpack/methods/rann/ra_query_stat.hpp | 10 ++++-
src/mlpack/methods/rann/ra_search.hpp | 4 ++
src/mlpack/methods/rann/ra_search_impl.hpp | 59 +++++++++++++++++++++++++++
src/mlpack/tests/serialization_test.cpp | 64 ++++++++++++++++++++++++++++++
4 files changed, 135 insertions(+), 2 deletions(-)
diff --git a/src/mlpack/methods/rann/ra_query_stat.hpp b/src/mlpack/methods/rann/ra_query_stat.hpp
index d9c2610..e853bf0 100644
--- a/src/mlpack/methods/rann/ra_query_stat.hpp
+++ b/src/mlpack/methods/rann/ra_query_stat.hpp
@@ -55,13 +55,19 @@ class RAQueryStat
//! Modify the number of samples made.
size_t& NumSamplesMade() { return numSamplesMade; }
+  /**
+   * Save this statistic to, or load it from, the given archive.  The bound and
+   * the number of samples made are each stored under their member name.
+   *
+   * @param ar Archive to operate on.
+   * @param version Serialization version (unused).
+   */
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */)
+  {
+    using data::CreateNVP;
+
+    ar & CreateNVP(bound, "bound");
+    ar & CreateNVP(numSamplesMade, "numSamplesMade");
+  }
+
private:
//! The bound on the node's neighbor distances.
double bound;
-
//! The minimum number of samples made by any query in this node.
size_t numSamplesMade;
-
};
} // namespace neighbor
diff --git a/src/mlpack/methods/rann/ra_search.hpp b/src/mlpack/methods/rann/ra_search.hpp
index 3cd8714..40a4c0f 100644
--- a/src/mlpack/methods/rann/ra_search.hpp
+++ b/src/mlpack/methods/rann/ra_search.hpp
@@ -290,6 +290,10 @@ class RASearch
//! Returns a string representation of this object.
std::string ToString() const;
+ //! Serialize the RASearch object: save it to or load it from the archive.
+ template<typename Archive>
+ void Serialize(Archive& ar, const unsigned int /* version */);
+
private:
//! Permutations of reference points during tree building.
std::vector<size_t> oldFromNewReferences;
diff --git a/src/mlpack/methods/rann/ra_search_impl.hpp b/src/mlpack/methods/rann/ra_search_impl.hpp
index cc96166..27ec5fa 100644
--- a/src/mlpack/methods/rann/ra_search_impl.hpp
+++ b/src/mlpack/methods/rann/ra_search_impl.hpp
@@ -523,6 +523,65 @@ std::string RASearch<SortPolicy, MetricType, MatType, TreeType>::ToString()
return convert.str();
}
+/**
+ * Serialize the RASearch object: save it to, or load it from, the archive.
+ *
+ * The naive and singleMode flags are always serialized.  What follows depends
+ * on naive: in naive mode only the reference dataset and the metric are
+ * stored; otherwise the reference tree and the point permutation are stored,
+ * and on load the dataset and metric are taken from the loaded tree.
+ *
+ * When loading, any dataset or tree this object currently owns is freed before
+ * it is replaced, and the ownership flags are updated to match what was loaded.
+ *
+ * NOTE(review): only the members handled below are serialized; any other
+ * search parameters of the class (not visible here) keep their existing values
+ * on load -- confirm that this is intended.
+ *
+ * @param ar Archive to save to or load from.
+ * @param version Serialization version (unused).
+ */
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType>
+template<typename Archive>
+void RASearch<SortPolicy, MetricType, MatType, TreeType>::Serialize(
+    Archive& ar,
+    const unsigned int /* version */)
+{
+  using data::CreateNVP;
+
+  // Serialize preferences for search.
+  ar & CreateNVP(naive, "naive");
+  ar & CreateNVP(singleMode, "singleMode");
+
+  // If we are doing naive search, we serialize the dataset.  Otherwise we
+  // serialize the tree.
+  if (naive)
+  {
+    // Free the dataset we currently own before it is overwritten by the load.
+    if (Archive::is_loading::value)
+    {
+      if (setOwner && referenceSet)
+        delete referenceSet;
+
+      // After we load the dataset, we will own it.
+      setOwner = true;
+    }
+
+    ar & CreateNVP(referenceSet, "referenceSet");
+    ar & CreateNVP(metric, "metric");
+
+    // If we are loading, set the tree to NULL and clean up memory if necessary.
+    // Naive mode serializes no tree, so nothing tree-related may linger.
+    if (Archive::is_loading::value)
+    {
+      if (treeOwner && referenceTree)
+        delete referenceTree;
+
+      referenceTree = NULL;
+      oldFromNewReferences.clear();
+      treeOwner = false;
+    }
+  }
+  else
+  {
+    // Delete the current reference tree, if necessary and if we are loading.
+    if (Archive::is_loading::value)
+    {
+      if (treeOwner && referenceTree)
+        delete referenceTree;
+
+      // After we load the tree, we will own it.
+      treeOwner = true;
+    }
+
+    ar & CreateNVP(referenceTree, "referenceTree");
+    ar & CreateNVP(oldFromNewReferences, "oldFromNewReferences");
+
+    // If we are loading, set the dataset accordingly and clean up memory if
+    // necessary.
+    if (Archive::is_loading::value)
+    {
+      if (setOwner && referenceSet)
+        delete referenceSet;
+
+      // The dataset now points into the loaded tree, so we must not free it
+      // separately.
+      referenceSet = &referenceTree->Dataset();
+      metric = referenceTree->Metric();
+      setOwner = false;
+    }
+  }
+}
+
} // namespace neighbor
} // namespace mlpack
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index fddb368..f859236 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -30,6 +30,7 @@
#include <mlpack/methods/softmax_regression/softmax_regression.hpp>
#include <mlpack/methods/det/dtree.hpp>
#include <mlpack/methods/naive_bayes/naive_bayes_classifier.hpp>
+#include <mlpack/methods/rann/ra_search.hpp>
using namespace mlpack;
using namespace mlpack::distribution;
@@ -1318,4 +1319,67 @@ BOOST_AUTO_TEST_CASE(NaiveBayesSerializationTest)
}
}
+/**
+ * Serialize an AllkRANN object to XML, text and binary archives, load it back
+ * into three differently-constructed objects, and check that each loaded
+ * object returns rank-approximate neighbors that agree with an exact search.
+ */
+BOOST_AUTO_TEST_CASE(RASearchTest)
+{
+  using neighbor::AllkRANN;
+  using neighbor::AllkNN;
+  arma::mat dataset = arma::randu<arma::mat>(5, 200);
+  arma::mat otherDataset = arma::randu<arma::mat>(5, 100);
+
+  // Find nearest neighbors in the top 10, with accuracy 0.95.  So 95% of the
+  // results we get (at least) should fall into the top 10 of the true nearest
+  // neighbors.
+  AllkRANN allkrann(dataset, false, false, 5, 0.95);
+
+  // The objects to load into are built on a different dataset and with
+  // different options, so matching results must come from the loaded state.
+  AllkRANN krannXml(otherDataset, false, false);
+  AllkRANN krannText(otherDataset, true, false);
+  AllkRANN krannBinary(otherDataset, true, true);
+
+  SerializeObjectAll(allkrann, krannXml, krannText, krannBinary);
+
+  // Now run nearest neighbor and make sure the results are the same.
+  arma::mat querySet = arma::randu<arma::mat>(5, 100);
+
+  arma::mat distances, xmlDistances, textDistances, binaryDistances;
+  arma::Mat<size_t> neighbors, xmlNeighbors, textNeighbors, binaryNeighbors;
+
+  AllkNN allknn(dataset); // Exact search.
+  allknn.Search(querySet, 10, neighbors, distances);
+  krannXml.Search(querySet, 5, xmlNeighbors, xmlDistances);
+  krannText.Search(querySet, 5, textNeighbors, textDistances);
+  krannBinary.Search(querySet, 5, binaryNeighbors, binaryDistances);
+
+  BOOST_REQUIRE_EQUAL(xmlNeighbors.n_rows, 5);
+  BOOST_REQUIRE_EQUAL(xmlNeighbors.n_cols, 100);
+  BOOST_REQUIRE_EQUAL(textNeighbors.n_rows, 5);
+  BOOST_REQUIRE_EQUAL(textNeighbors.n_cols, 100);
+  BOOST_REQUIRE_EQUAL(binaryNeighbors.n_rows, 5);
+  BOOST_REQUIRE_EQUAL(binaryNeighbors.n_cols, 100);
+
+  size_t xmlCorrect = 0;
+  size_t textCorrect = 0;
+  size_t binaryCorrect = 0;
+  for (size_t i = 0; i < xmlNeighbors.n_cols; ++i)
+  {
+    // See how many are in the top 10.
+    for (size_t j = 0; j < xmlNeighbors.n_rows; ++j)
+    {
+      for (size_t k = 0; k < neighbors.n_rows; ++k)
+      {
+        // Compare the j-th approximate neighbor of query i against the k-th
+        // exact neighbor of the same query.  (row, col) access is required
+        // here; operator[] would be a flat element index that ignores j and k.
+        if (neighbors(k, i) == xmlNeighbors(j, i))
+          xmlCorrect++;
+        if (neighbors(k, i) == textNeighbors(j, i))
+          textCorrect++;
+        if (neighbors(k, i) == binaryNeighbors(j, i))
+          binaryCorrect++;
+      }
+    }
+  }
+
+  // We need 95% of these to be correct: there are 5 results for each of the
+  // 100 query points.
+  BOOST_REQUIRE_GT(xmlCorrect, 95 * 5);
+  BOOST_REQUIRE_GT(binaryCorrect, 95 * 5);
+  BOOST_REQUIRE_GT(textCorrect, 95 * 5);
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git
mailing list