[mlpack-git] master: Add Serialize() for RASearch. (24b66d4)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Tue Nov 10 18:23:21 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/78cc694a4fd50a68a24f5ab9af7531873566b3ba...0f4e83dc9cc4dcdc315d2cceee32b23ebab114c2

>---------------------------------------------------------------

commit 24b66d49e9a0f7569d2346a31ee8c348569f1b46
Author: Ryan Curtin <ryan at ratml.org>
Date:   Tue Nov 10 17:58:35 2015 -0500

    Add Serialize() for RASearch.


>---------------------------------------------------------------

24b66d49e9a0f7569d2346a31ee8c348569f1b46
 src/mlpack/methods/rann/ra_query_stat.hpp  | 10 ++++-
 src/mlpack/methods/rann/ra_search.hpp      |  4 ++
 src/mlpack/methods/rann/ra_search_impl.hpp | 59 +++++++++++++++++++++++++++
 src/mlpack/tests/serialization_test.cpp    | 64 ++++++++++++++++++++++++++++++
 4 files changed, 135 insertions(+), 2 deletions(-)

diff --git a/src/mlpack/methods/rann/ra_query_stat.hpp b/src/mlpack/methods/rann/ra_query_stat.hpp
index d9c2610..e853bf0 100644
--- a/src/mlpack/methods/rann/ra_query_stat.hpp
+++ b/src/mlpack/methods/rann/ra_query_stat.hpp
@@ -55,13 +55,19 @@ class RAQueryStat
   //! Modify the number of samples made.
   size_t& NumSamplesMade() { return numSamplesMade; }
 
+  //! Save or load this statistic's members (bound and numSamplesMade) through
+  //! the given archive.  The version number is unused.
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */)
+  {
+    using data::CreateNVP;
+
+    ar & CreateNVP(bound, "bound");
+    ar & CreateNVP(numSamplesMade, "numSamplesMade");
+  }
+
  private:
   //! The bound on the node's neighbor distances.
   double bound;
-
   //! The minimum number of samples made by any query in this node.
   size_t numSamplesMade;
-
 };
 
 } // namespace neighbor
diff --git a/src/mlpack/methods/rann/ra_search.hpp b/src/mlpack/methods/rann/ra_search.hpp
index 3cd8714..40a4c0f 100644
--- a/src/mlpack/methods/rann/ra_search.hpp
+++ b/src/mlpack/methods/rann/ra_search.hpp
@@ -290,6 +290,10 @@ class RASearch
   //! Returns a string representation of this object.
   std::string ToString() const;
 
+  //! Serialize the object.
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */);
+
  private:
   //! Permutations of reference points during tree building.
   std::vector<size_t> oldFromNewReferences;
diff --git a/src/mlpack/methods/rann/ra_search_impl.hpp b/src/mlpack/methods/rann/ra_search_impl.hpp
index cc96166..27ec5fa 100644
--- a/src/mlpack/methods/rann/ra_search_impl.hpp
+++ b/src/mlpack/methods/rann/ra_search_impl.hpp
@@ -523,6 +523,65 @@ std::string RASearch<SortPolicy, MetricType, MatType, TreeType>::ToString()
   return convert.str();
 }
 
+//! Serialize the search object.  The search-mode flags are always stored; after
+//! that, either the reference set and metric are stored (naive mode) or the
+//! reference tree and its point permutation are stored (tree-based modes).
+//! When loading, any memory this object currently owns is released first.
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType>
+template<typename Archive>
+void RASearch<SortPolicy, MetricType, MatType, TreeType>::Serialize(
+    Archive& ar,
+    const unsigned int /* version */)
+{
+  using data::CreateNVP;
+
+  // Serialize preferences for search.
+  // NOTE(review): only these two flags are stored; confirm whether the other
+  // search parameters given at construction time should be serialized too.
+  ar & CreateNVP(naive, "naive");
+  ar & CreateNVP(singleMode, "singleMode");
+
+  // If we are doing naive search, we serialize the dataset.  Otherwise we
+  // serialize the tree.
+  if (naive)
+  {
+    // Release the dataset we currently hold (if it is ours) before loading a
+    // new one; after loading, we will own the dataset.
+    if (Archive::is_loading::value)
+    {
+      if (setOwner && referenceSet)
+        delete referenceSet;
+
+      setOwner = true;
+    }
+
+    ar & CreateNVP(referenceSet, "referenceSet");
+    ar & CreateNVP(metric, "metric");
+
+    // If we are loading, set the tree to NULL and clean up memory if necessary.
+    if (Archive::is_loading::value)
+    {
+      if (treeOwner && referenceTree)
+        delete referenceTree;
+
+      // Naive search uses no tree, so there is nothing to own or to un-permute.
+      referenceTree = NULL;
+      oldFromNewReferences.clear();
+      treeOwner = false;
+    }
+  }
+  else
+  {
+    // Delete the current reference tree, if necessary and if we are loading.
+    if (Archive::is_loading::value)
+    {
+      if (treeOwner && referenceTree)
+        delete referenceTree;
+
+      // After we load the tree, we will own it.
+      treeOwner = true;
+    }
+
+    ar & CreateNVP(referenceTree, "referenceTree");
+    ar & CreateNVP(oldFromNewReferences, "oldFromNewReferences");
+
+    // If we are loading, set the dataset accordingly and clean up memory if
+    // necessary.
+    if (Archive::is_loading::value)
+    {
+      if (setOwner && referenceSet)
+        delete referenceSet;
+
+      // The dataset and metric are held by the tree we just loaded, so we only
+      // reference them and must not free the dataset ourselves.
+      referenceSet = &referenceTree->Dataset();
+      metric = referenceTree->Metric();
+      setOwner = false;
+    }
+  }
+}
+
 } // namespace neighbor
 } // namespace mlpack
 
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index fddb368..f859236 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -30,6 +30,7 @@
 #include <mlpack/methods/softmax_regression/softmax_regression.hpp>
 #include <mlpack/methods/det/dtree.hpp>
 #include <mlpack/methods/naive_bayes/naive_bayes_classifier.hpp>
+#include <mlpack/methods/rann/ra_search.hpp>
 
 using namespace mlpack;
 using namespace mlpack::distribution;
@@ -1318,4 +1319,67 @@ BOOST_AUTO_TEST_CASE(NaiveBayesSerializationTest)
   }
 }
 
+// Serialize a rank-approximate search model into three differently-configured
+// models (XML, text, binary archives) and check that each loaded model still
+// returns neighbors that fall within the exact top 10.
+BOOST_AUTO_TEST_CASE(RASearchTest)
+{
+  using neighbor::AllkRANN;
+  using neighbor::AllkNN;
+  arma::mat dataset = arma::randu<arma::mat>(5, 200);
+  arma::mat otherDataset = arma::randu<arma::mat>(5, 100);
+
+  // Find nearest neighbors in the top 10, with accuracy 0.95.  So 95% of the
+  // results we get (at least) should fall into the top 10 of the true nearest
+  // neighbors.
+  AllkRANN allkrann(dataset, false, false, 5, 0.95);
+
+  // These are built on a different dataset and with different search modes, so
+  // matching results afterwards can only come from the serialized model.
+  AllkRANN krannXml(otherDataset, false, false);
+  AllkRANN krannText(otherDataset, true, false);
+  AllkRANN krannBinary(otherDataset, true, true);
+
+  SerializeObjectAll(allkrann, krannXml, krannText, krannBinary);
+
+  // Now run nearest neighbor and make sure the results are the same.
+  arma::mat querySet = arma::randu<arma::mat>(5, 100);
+
+  arma::mat distances, xmlDistances, textDistances, binaryDistances;
+  arma::Mat<size_t> neighbors, xmlNeighbors, textNeighbors, binaryNeighbors;
+
+  AllkNN allknn(dataset); // Exact search.
+  allknn.Search(querySet, 10, neighbors, distances);
+  krannXml.Search(querySet, 5, xmlNeighbors, xmlDistances);
+  krannText.Search(querySet, 5, textNeighbors, textDistances);
+  krannBinary.Search(querySet, 5, binaryNeighbors, binaryDistances);
+
+  BOOST_REQUIRE_EQUAL(xmlNeighbors.n_rows, 5);
+  BOOST_REQUIRE_EQUAL(xmlNeighbors.n_cols, 100);
+  BOOST_REQUIRE_EQUAL(textNeighbors.n_rows, 5);
+  BOOST_REQUIRE_EQUAL(textNeighbors.n_cols, 100);
+  BOOST_REQUIRE_EQUAL(binaryNeighbors.n_rows, 5);
+  BOOST_REQUIRE_EQUAL(binaryNeighbors.n_cols, 100);
+
+  size_t xmlCorrect = 0;
+  size_t textCorrect = 0;
+  size_t binaryCorrect = 0;
+  for (size_t i = 0; i < xmlNeighbors.n_cols; ++i)
+  {
+    // See how many are in the top 10.
+    for (size_t j = 0; j < xmlNeighbors.n_rows; ++j)
+    {
+      for (size_t k = 0; k < neighbors.n_rows; ++k)
+      {
+        // Compare the j'th approximate neighbor of query i against the k'th
+        // exact neighbor of the same query; (row, column) access is required
+        // here, since operator[] would index the matrix linearly.
+        if (neighbors(k, i) == xmlNeighbors(j, i))
+          xmlCorrect++;
+        if (neighbors(k, i) == textNeighbors(j, i))
+          textCorrect++;
+        if (neighbors(k, i) == binaryNeighbors(j, i))
+          binaryCorrect++;
+      }
+    }
+  }
+
+  // We need 95% of these to be correct (100 queries with 5 neighbors each).
+  const size_t requiredCorrect = 95 * 5;
+  BOOST_REQUIRE_GT(xmlCorrect, requiredCorrect);
+  BOOST_REQUIRE_GT(binaryCorrect, requiredCorrect);
+  BOOST_REQUIRE_GT(textCorrect, requiredCorrect);
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list