[mlpack-git] master: Add Train() and tests. (7d68833)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Tue Dec 22 17:02:16 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/6ab20afd8adaf9dcb86bc9a8ea98a24dd8b18743...eb41f4bc27b484c347acc006255104e2f8cc4eef

>---------------------------------------------------------------

commit 7d6883306c83303e7a1bb66939d0213ac77d7466
Author: ryan <ryan at ratml.org>
Date:   Tue Dec 22 15:44:56 2015 -0500

    Add Train() and tests.


>---------------------------------------------------------------

7d6883306c83303e7a1bb66939d0213ac77d7466
 src/mlpack/methods/fastmks/fastmks.hpp      | 18 ++++++++
 src/mlpack/methods/fastmks/fastmks_impl.hpp | 46 ++++++++++++++++++++
 src/mlpack/tests/fastmks_test.cpp           | 67 +++++++++++++++++++++++++++++
 3 files changed, 131 insertions(+)

diff --git a/src/mlpack/methods/fastmks/fastmks.hpp b/src/mlpack/methods/fastmks/fastmks.hpp
index 7e1da5c..755b7e9 100644
--- a/src/mlpack/methods/fastmks/fastmks.hpp
+++ b/src/mlpack/methods/fastmks/fastmks.hpp
@@ -116,6 +116,24 @@ class FastMKS
   ~FastMKS();
 
   /**
+   * "Train" the FastMKS model on the given reference set (this will just build
+   * a tree, if the current search mode is not naive mode).
+   *
+   * @param referenceSet Set of reference points.
+   */
+  void Train(const MatType& referenceSet);
+
+  /**
+   * "Train" the FastMKS model on the given reference set and use the given
+   * kernel.  The metric is always replaced with one built from the given
+   * kernel; a new tree is built only if the search mode is not naive mode.
+   *
+   * @param referenceSet Set of reference points.
+   * @param kernel Kernel to use for search.
+   */
+  void Train(const MatType& referenceSet, KernelType& kernel);
+
+  /**
    * Search for the points in the reference set with maximum kernel evaluation
    * to each point in the given query set.  The resulting kernel evaluations are
    * stored in the kernels matrix, and the corresponding point indices are
diff --git a/src/mlpack/methods/fastmks/fastmks_impl.hpp b/src/mlpack/methods/fastmks/fastmks_impl.hpp
index 1d7630f..df28e57 100644
--- a/src/mlpack/methods/fastmks/fastmks_impl.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_impl.hpp
@@ -127,6 +127,52 @@ template<typename KernelType,
          template<typename TreeMetricType,
                   typename TreeStatType,
                   typename TreeMatType> class TreeType>
+void FastMKS<KernelType, MatType, TreeType>::Train(const MatType& referenceSet)
+{
+  if (setOwner)
+    delete this->referenceSet;
+
+  this->referenceSet = &referenceSet;
+  this->setOwner = false;
+
+  if (!naive)
+  {
+    if (treeOwner && referenceTree)
+      delete referenceTree;
+    referenceTree = new Tree(referenceSet, metric);
+    treeOwner = true;
+  }
+}
+
+template<typename KernelType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType>
+void FastMKS<KernelType, MatType, TreeType>::Train(const MatType& referenceSet,
+                                                   KernelType& kernel)
+{
+  if (setOwner)
+    delete this->referenceSet;
+
+  this->referenceSet = &referenceSet;
+  this->metric = metric::IPMetric<KernelType>(kernel);
+  this->setOwner = false;
+
+  if (!naive)
+  {
+    if (treeOwner && referenceTree)
+      delete referenceTree;
+    referenceTree = new Tree(referenceSet, metric);
+    treeOwner = true;
+  }
+}
+
+template<typename KernelType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType>
 void FastMKS<KernelType, MatType, TreeType>::Search(
     const MatType& querySet,
     const size_t k,
diff --git a/src/mlpack/tests/fastmks_test.cpp b/src/mlpack/tests/fastmks_test.cpp
index 388d2b1..fba3b66 100644
--- a/src/mlpack/tests/fastmks_test.cpp
+++ b/src/mlpack/tests/fastmks_test.cpp
@@ -213,4 +213,71 @@ BOOST_AUTO_TEST_CASE(EmptyConstructorTest)
       std::invalid_argument);
 }
 
+// Make sure the simplest overload of Train() works.
+BOOST_AUTO_TEST_CASE(SimpleTrainTest)
+{
+  arma::mat referenceSet = arma::randu<arma::mat>(5, 100);
+
+  FastMKS<LinearKernel> f(referenceSet);
+  FastMKS<LinearKernel> f2;
+  f2.Train(referenceSet);
+
+  arma::Mat<size_t> indices, indices2;
+  arma::mat products, products2;
+
+  arma::mat querySet = arma::randu<arma::mat>(5, 20);
+
+  f.Search(querySet, 3, indices, products);
+  f2.Search(querySet, 3, indices2, products2);
+
+  BOOST_REQUIRE_EQUAL(indices.n_rows, indices2.n_rows);
+  BOOST_REQUIRE_EQUAL(products.n_rows, products2.n_rows);
+  BOOST_REQUIRE_EQUAL(indices.n_cols, indices2.n_cols);
+  BOOST_REQUIRE_EQUAL(products.n_cols, products2.n_cols);
+
+  for (size_t i = 0; i < products.n_elem; ++i)
+  {
+    if (std::abs(products[i]) < 1e-5)
+      BOOST_REQUIRE_SMALL(products2[i], 1e-5);
+    else
+      BOOST_REQUIRE_CLOSE(products[i], products2[i], 1e-5);
+
+    BOOST_REQUIRE_EQUAL(indices[i], indices2[i]);
+  }
+}
+
+// Test the Train() overload that takes a kernel too.
+BOOST_AUTO_TEST_CASE(SimpleTrainKernelTest)
+{
+  arma::mat referenceSet = arma::randu<arma::mat>(5, 100);
+  GaussianKernel gk(2.0);
+
+  FastMKS<GaussianKernel> f(referenceSet, gk);
+  FastMKS<GaussianKernel> f2;
+  f2.Train(referenceSet, gk);
+
+  arma::Mat<size_t> indices, indices2;
+  arma::mat products, products2;
+
+  arma::mat querySet = arma::randu<arma::mat>(5, 20);
+
+  f.Search(querySet, 3, indices, products);
+  f2.Search(querySet, 3, indices2, products2);
+
+  BOOST_REQUIRE_EQUAL(indices.n_rows, indices2.n_rows);
+  BOOST_REQUIRE_EQUAL(products.n_rows, products2.n_rows);
+  BOOST_REQUIRE_EQUAL(indices.n_cols, indices2.n_cols);
+  BOOST_REQUIRE_EQUAL(products.n_cols, products2.n_cols);
+
+  for (size_t i = 0; i < products.n_elem; ++i)
+  {
+    if (std::abs(products[i]) < 1e-5)
+      BOOST_REQUIRE_SMALL(products2[i], 1e-5);
+    else
+      BOOST_REQUIRE_CLOSE(products[i], products2[i], 1e-5);
+
+    BOOST_REQUIRE_EQUAL(indices[i], indices2[i]);
+  }
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list