[mlpack-git] master: Add Train() and tests. (7d68833)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Tue Dec 22 17:02:16 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/6ab20afd8adaf9dcb86bc9a8ea98a24dd8b18743...eb41f4bc27b484c347acc006255104e2f8cc4eef
>---------------------------------------------------------------
commit 7d6883306c83303e7a1bb66939d0213ac77d7466
Author: ryan <ryan at ratml.org>
Date: Tue Dec 22 15:44:56 2015 -0500
Add Train() and tests.
>---------------------------------------------------------------
7d6883306c83303e7a1bb66939d0213ac77d7466
src/mlpack/methods/fastmks/fastmks.hpp | 18 ++++++++
src/mlpack/methods/fastmks/fastmks_impl.hpp | 46 ++++++++++++++++++++
src/mlpack/tests/fastmks_test.cpp | 67 +++++++++++++++++++++++++++++
3 files changed, 131 insertions(+)
diff --git a/src/mlpack/methods/fastmks/fastmks.hpp b/src/mlpack/methods/fastmks/fastmks.hpp
index 7e1da5c..755b7e9 100644
--- a/src/mlpack/methods/fastmks/fastmks.hpp
+++ b/src/mlpack/methods/fastmks/fastmks.hpp
@@ -116,6 +116,24 @@ class FastMKS
~FastMKS();
/**
+ * "Train" the FastMKS model on the given reference set (this will just build
+ * a tree, if the current search mode is not naive mode).
+ *
+ * The reference set is not copied: only its address is stored, so the matrix
+ * must remain valid for as long as this object searches it.
+ *
+ * @param referenceSet Set of reference points.
+ */
+ void Train(const MatType& referenceSet);
+
+ /**
+ * "Train" the FastMKS model on the given reference set and use the given
+ * kernel. The kernel is always replaced (in naive mode too); a tree is
+ * additionally built if the current search mode is not naive mode.
+ *
+ * As with the single-argument overload, the reference set is not copied and
+ * must remain valid for as long as this object searches it.
+ *
+ * @param referenceSet Set of reference points.
+ * @param kernel Kernel to use for search.
+ */
+ void Train(const MatType& referenceSet, KernelType& kernel);
+
+ /**
* Search for the points in the reference set with maximum kernel evaluation
* to each point in the given query set. The resulting kernel evaluations are
* stored in the kernels matrix, and the corresponding point indices are
diff --git a/src/mlpack/methods/fastmks/fastmks_impl.hpp b/src/mlpack/methods/fastmks/fastmks_impl.hpp
index 1d7630f..df28e57 100644
--- a/src/mlpack/methods/fastmks/fastmks_impl.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_impl.hpp
@@ -127,6 +127,52 @@ template<typename KernelType,
template<typename TreeMetricType,
typename TreeStatType,
typename TreeMatType> class TreeType>
+void FastMKS<KernelType, MatType, TreeType>::Train(const MatType& referenceSet)
+{
+  // The reference set is stored by address only (never copied), so the caller
+  // must keep it alive while this object uses it.
+
+  // Build the new tree (if any) before releasing anything. If the tree
+  // constructor throws, this object is left exactly as it was, instead of
+  // holding dangling pointers to an already-freed set or tree.
+  Tree* newTree = naive ? NULL : new Tree(referenceSet, metric);
+
+  if (!naive)
+  {
+    if (treeOwner && referenceTree)
+      delete referenceTree;
+    referenceTree = newTree;
+    treeOwner = true;
+  }
+
+  // Retraining on the set this object already owns must not free the very
+  // matrix we were just handed (the new tree references it); in that case we
+  // simply keep ownership of it.
+  const bool sameSet = (this->referenceSet == &referenceSet);
+  if (setOwner && !sameSet)
+    delete this->referenceSet;
+
+  this->referenceSet = &referenceSet;
+  this->setOwner = (setOwner && sameSet);
+}
+
+template<typename KernelType,
+ typename MatType,
+ template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType>
+void FastMKS<KernelType, MatType, TreeType>::Train(const MatType& referenceSet,
+ KernelType& kernel)
+{
+  // Install the new kernel first so that any tree built below uses it. Unlike
+  // the tree, the kernel is replaced in naive mode too.
+  // NOTE(review): if the previous IPMetric owned a default-constructed kernel,
+  // confirm that IPMetric's assignment releases it.
+  this->metric = metric::IPMetric<KernelType>(kernel);
+
+  // Reference-set and tree bookkeeping (ownership, rebuild, aliasing and
+  // exception-safety handling) is identical to the single-argument overload,
+  // so delegate rather than duplicate it.
+  Train(referenceSet);
+}
+
+template<typename KernelType,
+ typename MatType,
+ template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType>
void FastMKS<KernelType, MatType, TreeType>::Search(
const MatType& querySet,
const size_t k,
diff --git a/src/mlpack/tests/fastmks_test.cpp b/src/mlpack/tests/fastmks_test.cpp
index 388d2b1..fba3b66 100644
--- a/src/mlpack/tests/fastmks_test.cpp
+++ b/src/mlpack/tests/fastmks_test.cpp
@@ -213,4 +213,71 @@ BOOST_AUTO_TEST_CASE(EmptyConstructorTest)
std::invalid_argument);
}
+// Make sure the simplest overload of Train() works: a default-constructed
+// model trained with Train() must return the same search results as a model
+// built directly from the reference set.
+BOOST_AUTO_TEST_CASE(SimpleTrainTest)
+{
+  arma::mat referenceSet = arma::randu<arma::mat>(5, 100);
+
+  // Reference model (constructor) and model under test (Train()).
+  FastMKS<LinearKernel> constructed(referenceSet);
+  FastMKS<LinearKernel> trained;
+  trained.Train(referenceSet);
+
+  arma::mat querySet = arma::randu<arma::mat>(5, 20);
+
+  arma::Mat<size_t> expectedIndices;
+  arma::Mat<size_t> actualIndices;
+  arma::mat expectedProducts;
+  arma::mat actualProducts;
+
+  constructed.Search(querySet, 3, expectedIndices, expectedProducts);
+  trained.Search(querySet, 3, actualIndices, actualProducts);
+
+  BOOST_REQUIRE_EQUAL(expectedIndices.n_rows, actualIndices.n_rows);
+  BOOST_REQUIRE_EQUAL(expectedProducts.n_rows, actualProducts.n_rows);
+  BOOST_REQUIRE_EQUAL(expectedIndices.n_cols, actualIndices.n_cols);
+  BOOST_REQUIRE_EQUAL(expectedProducts.n_cols, actualProducts.n_cols);
+
+  const size_t count = expectedProducts.n_elem;
+  for (size_t e = 0; e < count; ++e)
+  {
+    const double expected = expectedProducts[e];
+
+    // Relative comparison is meaningless near zero, so fall back to an
+    // absolute check there.
+    if (!(std::abs(expected) < 1e-5))
+      BOOST_REQUIRE_CLOSE(expected, actualProducts[e], 1e-5);
+    else
+      BOOST_REQUIRE_SMALL(actualProducts[e], 1e-5);
+
+    BOOST_REQUIRE_EQUAL(expectedIndices[e], actualIndices[e]);
+  }
+}
+
+
+// Test the Train() overload that takes a kernel too: training a default
+// model with (referenceSet, kernel) must match constructing one with the
+// same arguments.
+BOOST_AUTO_TEST_CASE(SimpleTrainKernelTest)
+{
+  arma::mat referenceSet = arma::randu<arma::mat>(5, 100);
+  GaussianKernel gk(2.0);
+
+  // Reference model (constructor) and model under test (Train()).
+  FastMKS<GaussianKernel> constructed(referenceSet, gk);
+  FastMKS<GaussianKernel> trained;
+  trained.Train(referenceSet, gk);
+
+  arma::mat querySet = arma::randu<arma::mat>(5, 20);
+
+  arma::Mat<size_t> expectedIndices;
+  arma::Mat<size_t> actualIndices;
+  arma::mat expectedProducts;
+  arma::mat actualProducts;
+
+  constructed.Search(querySet, 3, expectedIndices, expectedProducts);
+  trained.Search(querySet, 3, actualIndices, actualProducts);
+
+  BOOST_REQUIRE_EQUAL(expectedIndices.n_rows, actualIndices.n_rows);
+  BOOST_REQUIRE_EQUAL(expectedProducts.n_rows, actualProducts.n_rows);
+  BOOST_REQUIRE_EQUAL(expectedIndices.n_cols, actualIndices.n_cols);
+  BOOST_REQUIRE_EQUAL(expectedProducts.n_cols, actualProducts.n_cols);
+
+  const size_t count = expectedProducts.n_elem;
+  for (size_t e = 0; e < count; ++e)
+  {
+    const double expected = expectedProducts[e];
+
+    // Relative comparison is meaningless near zero, so fall back to an
+    // absolute check there.
+    if (!(std::abs(expected) < 1e-5))
+      BOOST_REQUIRE_CLOSE(expected, actualProducts[e], 1e-5);
+    else
+      BOOST_REQUIRE_SMALL(actualProducts[e], 1e-5);
+
+    BOOST_REQUIRE_EQUAL(expectedIndices[e], actualIndices[e]);
+  }
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git
mailing list