[mlpack-git] master: Add Serialize() and tests. (eb41f4b)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Tue Dec 22 17:02:28 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/6ab20afd8adaf9dcb86bc9a8ea98a24dd8b18743...eb41f4bc27b484c347acc006255104e2f8cc4eef

>---------------------------------------------------------------

commit eb41f4bc27b484c347acc006255104e2f8cc4eef
Author: ryan <ryan at ratml.org>
Date:   Tue Dec 22 17:02:00 2015 -0500

    Add Serialize() and tests.


>---------------------------------------------------------------

eb41f4bc27b484c347acc006255104e2f8cc4eef
 src/mlpack/methods/fastmks/fastmks.hpp      |  4 ++
 src/mlpack/methods/fastmks/fastmks_impl.hpp | 57 +++++++++++++++++++++++++++++
 src/mlpack/methods/fastmks/fastmks_stat.hpp | 15 ++++++++
 src/mlpack/tests/fastmks_test.cpp           | 27 ++++++++++++++
 4 files changed, 103 insertions(+)

diff --git a/src/mlpack/methods/fastmks/fastmks.hpp b/src/mlpack/methods/fastmks/fastmks.hpp
index 755b7e9..cf87c8c 100644
--- a/src/mlpack/methods/fastmks/fastmks.hpp
+++ b/src/mlpack/methods/fastmks/fastmks.hpp
@@ -213,6 +213,10 @@ class FastMKS
   //! Modify whether or not single-tree search is used.
   bool& SingleMode() { return singleMode; }
 
+  //! Serialize the model to or from the given archive (version is unused).
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */);
+
  private:
   //! The reference dataset.  We never own this; only the tree or a higher level
   //! does.
diff --git a/src/mlpack/methods/fastmks/fastmks_impl.hpp b/src/mlpack/methods/fastmks/fastmks_impl.hpp
index df28e57..c6bd65f 100644
--- a/src/mlpack/methods/fastmks/fastmks_impl.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_impl.hpp
@@ -418,6 +418,63 @@ void FastMKS<KernelType, MatType, TreeType>::InsertNeighbor(
   indices(pos, queryIndex) = neighbor;
 }
 
+//! Serialize the model: search flags, then either the dataset or the tree.
+template<typename KernelType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType>
+template<typename Archive>
+void FastMKS<KernelType, MatType, TreeType>::Serialize(
+    Archive& ar,
+    const unsigned int /* version */)
+{
+  using data::CreateNVP;
+
+  // Serialize the search flags first; 'naive' selects the branch taken below.
+  ar & CreateNVP(naive, "naive");
+  ar & CreateNVP(singleMode, "singleMode");
+
+  // Naive search: serialize the dataset and metric directly.  Otherwise
+  // serialize the tree; on load, dataset and metric are taken from the tree.
+  if (naive)
+  {
+    if (Archive::is_loading::value)
+    {
+      if (setOwner && referenceSet)
+        delete referenceSet;
+
+      setOwner = true;
+    }
+    // NOTE(review): any existing referenceTree is left as-is here; confirm.
+    ar & CreateNVP(referenceSet, "referenceSet");
+    ar & CreateNVP(metric, "metric");
+  }
+  else
+  {
+    // When loading, free any tree we currently own before it is replaced.
+    if (Archive::is_loading::value)
+    {
+      if (treeOwner && referenceTree)
+        delete referenceTree;
+
+      treeOwner = true;
+    }
+
+    ar & CreateNVP(referenceTree, "referenceTree");
+    // After a load, rebind the dataset and metric to the tree just read.
+    if (Archive::is_loading::value)
+    {
+      if (setOwner && referenceSet)
+        delete referenceSet;
+
+      referenceSet = &referenceTree->Dataset();
+      metric = metric::IPMetric<KernelType>(referenceTree->Metric().Kernel());
+      setOwner = false;
+    }
+  }
+}
+
 } // namespace fastmks
 } // namespace mlpack
 
diff --git a/src/mlpack/methods/fastmks/fastmks_stat.hpp b/src/mlpack/methods/fastmks/fastmks_stat.hpp
index 33d6e42..c6273b3 100644
--- a/src/mlpack/methods/fastmks/fastmks_stat.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_stat.hpp
@@ -93,6 +93,21 @@ class FastMKSStat
   //! evaluation.
   void*& LastKernelNode() { return lastKernelNode; }
 
+  //! Serialize bound and selfKernel; the last-kernel fields are not stored.
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */)
+  {
+    ar & data::CreateNVP(bound, "bound");
+    ar & data::CreateNVP(selfKernel, "selfKernel");
+
+    // On load, reset the last kernel value and node, which are not archived.
+    if (Archive::is_loading::value)
+    {
+      lastKernel = 0.0;
+      lastKernelNode = NULL;
+    }
+  }
+
  private:
   //! The bound for pruning.
   double bound;
diff --git a/src/mlpack/tests/fastmks_test.cpp b/src/mlpack/tests/fastmks_test.cpp
index fba3b66..0a659a9 100644
--- a/src/mlpack/tests/fastmks_test.cpp
+++ b/src/mlpack/tests/fastmks_test.cpp
@@ -11,6 +11,7 @@
 
 #include <boost/test/unit_test.hpp>
 #include "old_boost_test_definitions.hpp"
+#include "serialization.hpp"
 
 using namespace mlpack;
 using namespace mlpack::tree;
@@ -280,4 +281,30 @@ BOOST_AUTO_TEST_CASE(SimpleTrainKernelTest)
   }
 }
 
+BOOST_AUTO_TEST_CASE(SerializationTest)
+{
+  arma::mat dataset = arma::randu<arma::mat>(5, 200);
+  // Model under test, built on the 5 x 200 random dataset.
+  FastMKS<LinearKernel> f(dataset);
+  // Load targets; fBinary starts out trained on a different dataset.
+  FastMKS<LinearKernel> fXml, fText, fBinary;
+  arma::mat otherDataset = arma::randu<arma::mat>(3, 10);
+  fBinary.Train(otherDataset);
+  // Presumably copies f into each target through XML/text/binary archives.
+  SerializeObjectAll(f, fXml, fText, fBinary);
+
+  arma::mat kernels, xmlKernels, textKernels, binaryKernels;
+  arma::Mat<size_t> indices, xmlIndices, textIndices, binaryIndices;
+
+  arma::mat querySet = arma::randu<arma::mat>(5, 100);
+  // Run the identical search on the original and on every loaded model.
+  f.Search(querySet, 5, indices, kernels);
+  fXml.Search(querySet, 5, xmlIndices, xmlKernels);
+  fText.Search(querySet, 5, textIndices, textKernels);
+  fBinary.Search(querySet, 5, binaryIndices, binaryKernels);
+  // Compare the results of the loaded models against the original's.
+  CheckMatrices(indices, xmlIndices, textIndices, binaryIndices);
+  CheckMatrices(kernels, xmlKernels, textKernels, binaryKernels);
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list