[mlpack-git] master: Add Serialize() and tests. (eb41f4b)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Tue Dec 22 17:02:28 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/6ab20afd8adaf9dcb86bc9a8ea98a24dd8b18743...eb41f4bc27b484c347acc006255104e2f8cc4eef
>---------------------------------------------------------------
commit eb41f4bc27b484c347acc006255104e2f8cc4eef
Author: ryan <ryan at ratml.org>
Date: Tue Dec 22 17:02:00 2015 -0500
Add Serialize() and tests.
>---------------------------------------------------------------
eb41f4bc27b484c347acc006255104e2f8cc4eef
src/mlpack/methods/fastmks/fastmks.hpp | 4 ++
src/mlpack/methods/fastmks/fastmks_impl.hpp | 57 +++++++++++++++++++++++++++++
src/mlpack/methods/fastmks/fastmks_stat.hpp | 15 ++++++++
src/mlpack/tests/fastmks_test.cpp | 27 ++++++++++++++
4 files changed, 103 insertions(+)
diff --git a/src/mlpack/methods/fastmks/fastmks.hpp b/src/mlpack/methods/fastmks/fastmks.hpp
index 755b7e9..cf87c8c 100644
--- a/src/mlpack/methods/fastmks/fastmks.hpp
+++ b/src/mlpack/methods/fastmks/fastmks.hpp
@@ -213,6 +213,10 @@ class FastMKS
//! Modify whether single-tree search is used (reference to singleMode).
bool& SingleMode() { return singleMode; }
+ //! Serialize the model to or from the given archive (version is unused).
+ template<typename Archive>
+ void Serialize(Archive& ar, const unsigned int /* version */);
+
private:
//! The reference dataset. We never own this; only the tree or a higher level
//! does.
diff --git a/src/mlpack/methods/fastmks/fastmks_impl.hpp b/src/mlpack/methods/fastmks/fastmks_impl.hpp
index df28e57..c6bd65f 100644
--- a/src/mlpack/methods/fastmks/fastmks_impl.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_impl.hpp
@@ -418,6 +418,63 @@ void FastMKS<KernelType, MatType, TreeType>::InsertNeighbor(
indices(pos, queryIndex) = neighbor;
}
+//! Save or load the model: search flags, then the dataset or the tree.
+template<typename KernelType,
+ typename MatType,
+ template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType>
+template<typename Archive>
+void FastMKS<KernelType, MatType, TreeType>::Serialize(
+ Archive& ar,
+ const unsigned int /* version */)
+{
+ using data::CreateNVP;
+
+ // The two search-mode flags go first; naive selects the branch below.
+ ar & CreateNVP(naive, "naive");
+ ar & CreateNVP(singleMode, "singleMode");
+
+ // Naive search stores the dataset and metric; otherwise the tree is stored.
+ // NOTE(review): the naive load path never touches referenceTree -- confirm.
+ if (naive)
+ {
+ if (Archive::is_loading::value)
+ {
+ if (setOwner && referenceSet)
+ delete referenceSet;
+
+ setOwner = true; // We will own the dataset loaded below.
+ }
+
+ ar & CreateNVP(referenceSet, "referenceSet");
+ ar & CreateNVP(metric, "metric");
+ }
+ else
+ {
+ // When loading, free any tree we own before it is replaced.
+ if (Archive::is_loading::value)
+ {
+ if (treeOwner && referenceTree)
+ delete referenceTree;
+
+ treeOwner = true;
+ }
+
+ ar & CreateNVP(referenceTree, "referenceTree");
+
+ if (Archive::is_loading::value)
+ {
+ if (setOwner && referenceSet)
+ delete referenceSet;
+
+ referenceSet = &referenceTree->Dataset();
+ metric = metric::IPMetric<KernelType>(referenceTree->Metric().Kernel());
+ setOwner = false; // The dataset now points into the loaded tree.
+ }
+ }
+}
+
} // namespace fastmks
} // namespace mlpack
diff --git a/src/mlpack/methods/fastmks/fastmks_stat.hpp b/src/mlpack/methods/fastmks/fastmks_stat.hpp
index 33d6e42..c6273b3 100644
--- a/src/mlpack/methods/fastmks/fastmks_stat.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_stat.hpp
@@ -93,6 +93,21 @@ class FastMKSStat
//! evaluation.
void*& LastKernelNode() { return lastKernelNode; }
+ //! Serialize bound and selfKernel; last-kernel state is not stored.
+ template<typename Archive>
+ void Serialize(Archive& ar, const unsigned int /* version */)
+ {
+ ar & data::CreateNVP(bound, "bound");
+ ar & data::CreateNVP(selfKernel, "selfKernel");
+
+ // On load, reset lastKernel to 0.0 and lastKernelNode to NULL.
+ if (Archive::is_loading::value)
+ {
+ lastKernel = 0.0;
+ lastKernelNode = NULL;
+ }
+ }
+
private:
//! The bound for pruning.
double bound;
diff --git a/src/mlpack/tests/fastmks_test.cpp b/src/mlpack/tests/fastmks_test.cpp
index fba3b66..0a659a9 100644
--- a/src/mlpack/tests/fastmks_test.cpp
+++ b/src/mlpack/tests/fastmks_test.cpp
@@ -11,6 +11,7 @@
#include <boost/test/unit_test.hpp>
#include "old_boost_test_definitions.hpp"
+#include "serialization.hpp"
using namespace mlpack;
using namespace mlpack::tree;
@@ -280,4 +281,30 @@ BOOST_AUTO_TEST_CASE(SimpleTrainKernelTest)
}
}
+BOOST_AUTO_TEST_CASE(SerializationTest)
+{
+ arma::mat dataset = arma::randu<arma::mat>(5, 200);
+
+ FastMKS<LinearKernel> f(dataset);
+
+ FastMKS<LinearKernel> fXml, fText, fBinary;
+ arma::mat otherDataset = arma::randu<arma::mat>(3, 10);
+ fBinary.Train(otherDataset); // fBinary starts out trained on other data.
+
+ SerializeObjectAll(f, fXml, fText, fBinary); // Presumably copies f into each.
+
+ arma::mat kernels, xmlKernels, textKernels, binaryKernels;
+ arma::Mat<size_t> indices, xmlIndices, textIndices, binaryIndices;
+
+ arma::mat querySet = arma::randu<arma::mat>(5, 100);
+
+ f.Search(querySet, 5, indices, kernels); // Same query on all four models.
+ fXml.Search(querySet, 5, xmlIndices, xmlKernels);
+ fText.Search(querySet, 5, textIndices, textKernels);
+ fBinary.Search(querySet, 5, binaryIndices, binaryKernels);
+
+ CheckMatrices(indices, xmlIndices, textIndices, binaryIndices);
+ CheckMatrices(kernels, xmlKernels, textKernels, binaryKernels);
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git
mailing list