[mlpack-git] master: Add Train() with a tree. (da8ed43)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Tue Dec 22 18:34:43 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eb41f4bc27b484c347acc006255104e2f8cc4eef...977afbec0648056124dcb206e0bf972a161d9b51

>---------------------------------------------------------------

commit da8ed43691ea7a178f7349a57e203f5fb529e63f
Author: ryan <ryan at ratml.org>
Date:   Tue Dec 22 18:18:58 2015 -0500

    Add Train() with a tree.


>---------------------------------------------------------------

da8ed43691ea7a178f7349a57e203f5fb529e63f
 src/mlpack/methods/fastmks/fastmks.hpp      |  9 +++++++++
 src/mlpack/methods/fastmks/fastmks_impl.hpp | 25 +++++++++++++++++++++++++
 2 files changed, 34 insertions(+)

diff --git a/src/mlpack/methods/fastmks/fastmks.hpp b/src/mlpack/methods/fastmks/fastmks.hpp
index 3a73efc..35cdacb 100644
--- a/src/mlpack/methods/fastmks/fastmks.hpp
+++ b/src/mlpack/methods/fastmks/fastmks.hpp
@@ -134,6 +134,15 @@ class FastMKS
   void Train(const MatType& referenceSet, KernelType& kernel);
 
   /**
+   * Train the FastMKS model on the given reference tree.  This takes ownership
+   * of the tree, so you do not need to delete it!  This will throw an exception
+   * if the model is searching in naive mode (i.e. if Naive() == true).
+   *
+   * @param referenceTree Tree to use as reference data.
+   */
+  void Train(Tree* referenceTree);
+
+  /**
    * Search for the points in the reference set with maximum kernel evaluation
    * to each point in the given query set.  The resulting kernel evaluations are
    * stored in the kernels matrix, and the corresponding point indices are
diff --git a/src/mlpack/methods/fastmks/fastmks_impl.hpp b/src/mlpack/methods/fastmks/fastmks_impl.hpp
index c6bd65f..84c4fd4 100644
--- a/src/mlpack/methods/fastmks/fastmks_impl.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_impl.hpp
@@ -173,6 +173,31 @@ template<typename KernelType,
          template<typename TreeMetricType,
                   typename TreeStatType,
                   typename TreeMatType> class TreeType>
+void FastMKS<KernelType, MatType, TreeType>::Train(Tree* tree)
+{
+  if (naive)
+    throw std::invalid_argument("cannot call FastMKS::Train() with a tree when "
+        "in naive search mode");
+
+  // Free any reference set we own; the tree's dataset replaces it.
+  if (setOwner)
+    delete this->referenceSet;
+  this->referenceSet = &tree->Dataset();
+  this->metric = metric::IPMetric<KernelType>(tree->Metric().Kernel());
+  this->setOwner = false;
+  // Take ownership of `tree`, freeing any tree we previously owned -- unless
+  // `tree` is that very tree: deleting it would leave both pointers dangling.
+  if (treeOwner && referenceTree && (referenceTree != tree))
+    delete referenceTree;
+  this->referenceTree = tree;
+  this->treeOwner = true;
+}
+
+template<typename KernelType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType>
 void FastMKS<KernelType, MatType, TreeType>::Search(
     const MatType& querySet,
     const size_t k,



More information about the mlpack-git mailing list