[mlpack-git] master: Refactor FastMKS to new TreeType API. (f7c1693)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Jul 29 16:41:42 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/f8ceffae0613b350f4d6bdd46c6c8633a40b4897...6ee21879488fe98612a4619b17f8b51e8da5215b

>---------------------------------------------------------------

commit f7c1693c7d6b3eeb8e55d51f80cad99d487d9311
Author: ryan <ryan at ratml.org>
Date:   Fri Jul 24 14:26:48 2015 -0400

    Refactor FastMKS to new TreeType API.


>---------------------------------------------------------------

f7c1693c7d6b3eeb8e55d51f80cad99d487d9311
 src/mlpack/methods/fastmks/fastmks.hpp      |  26 +++---
 src/mlpack/methods/fastmks/fastmks_impl.hpp | 119 ++++++++++++++++++----------
 src/mlpack/methods/fastmks/fastmks_main.cpp |  13 +--
 3 files changed, 97 insertions(+), 61 deletions(-)

diff --git a/src/mlpack/methods/fastmks/fastmks.hpp b/src/mlpack/methods/fastmks/fastmks.hpp
index 72ea34b..7b54835 100644
--- a/src/mlpack/methods/fastmks/fastmks.hpp
+++ b/src/mlpack/methods/fastmks/fastmks.hpp
@@ -48,12 +48,16 @@ namespace fastmks /** Fast max-kernel search. */ {
  */
 template<
     typename KernelType,
-    typename TreeType = tree::CoverTree<metric::IPMetric<KernelType>,
-        tree::FirstPointIsRoot, FastMKSStat>
+    typename MatType = arma::mat,
+    template<typename TreeMetricType, typename TreeStatType,
+             typename TreeMatType> class TreeType = tree::StandardCoverTree
 >
 class FastMKS
 {
  public:
+  //! The tree type: TreeType instantiated with IPMetric<KernelType>.
+  typedef TreeType<metric::IPMetric<KernelType>, FastMKSStat, MatType> Tree;
+
   /**
    * Create the FastMKS object with the given reference set (this is the set
    * that is searched).  Optionally, specify whether or not single-tree search
@@ -63,7 +67,7 @@ class FastMKS
    * @param single Whether or not to run single-tree search.
    * @param naive Whether or not to run brute-force (naive) search.
    */
-  FastMKS(const typename TreeType::Mat& referenceSet,
+  FastMKS(const MatType& referenceSet,
           const bool singleMode = false,
           const bool naive = false);
 
@@ -78,7 +82,7 @@ class FastMKS
    * @param single Whether or not to run single-tree search.
    * @param naive Whether or not to run brute-force (naive) search.
    */
-  FastMKS(const typename TreeType::Mat& referenceSet,
+  FastMKS(const MatType& referenceSet,
           KernelType& kernel,
           const bool singleMode = false,
           const bool naive = false);
@@ -94,7 +98,7 @@ class FastMKS
    * @param single Whether or not to run single-tree search.
    * @param naive Whether or not to run brute-force (naive) search.
    */
-  FastMKS(TreeType* referenceTree,
+  FastMKS(Tree* referenceTree,
           const bool singleMode = false);
 
   //! Destructor for the FastMKS object.
@@ -120,7 +124,7 @@ class FastMKS
    * @param indices Matrix to store resulting indices of max-kernel search in.
    * @param kernels Matrix to store resulting max-kernel values in.
    */
-  void Search(const typename TreeType::Mat& querySet,
+  void Search(const MatType& querySet,
               const size_t k,
               arma::Mat<size_t>& indices,
               arma::mat& kernels);
@@ -147,7 +151,7 @@ class FastMKS
    * @param indices Matrix to store resulting indices of max-kernel search in.
    * @param kernels Matrix to store resulting max-kernel values in.
    */
-  void Search(TreeType* querySet,
+  void Search(Tree* querySet,
               const size_t k,
               arma::Mat<size_t>& indices,
               arma::mat& kernels);
@@ -187,9 +191,9 @@ class FastMKS
 
  private:
   //! The reference dataset.
-  const typename TreeType::Mat& referenceSet;
+  const MatType& referenceSet;
   //! The tree built on the reference dataset.
-  TreeType* referenceTree;
+  Tree* referenceTree;
   //! If true, this object created the tree and is responsible for it.
   bool treeOwner;
 
@@ -210,8 +214,8 @@ class FastMKS
                       const double distance);
 };
 
-}; // namespace fastmks
-}; // namespace mlpack
+} // namespace fastmks
+} // namespace mlpack
 
 // Include implementation.
 #include "fastmks_impl.hpp"
diff --git a/src/mlpack/methods/fastmks/fastmks_impl.hpp b/src/mlpack/methods/fastmks/fastmks_impl.hpp
index 1835d8d..ead5d2e 100644
--- a/src/mlpack/methods/fastmks/fastmks_impl.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_impl.hpp
@@ -19,10 +19,14 @@ namespace mlpack {
 namespace fastmks {
 
 // No instantiated kernel.
-template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::FastMKS(const typename TreeType::Mat& referenceSet,
-                                       const bool singleMode,
-                                       const bool naive) :
+template<typename KernelType,
+         typename MatType,
+         template<typename TreeMetricType, typename TreeStatType,
+                  typename TreeMatType> class TreeType>
+FastMKS<KernelType, MatType, TreeType>::FastMKS(
+    const MatType& referenceSet,
+    const bool singleMode,
+    const bool naive) :
     referenceSet(referenceSet),
     referenceTree(NULL),
     treeOwner(true),
@@ -32,17 +36,20 @@ FastMKS<KernelType, TreeType>::FastMKS(const typename TreeType::Mat& referenceSe
   Timer::Start("tree_building");
 
   if (!naive)
-    referenceTree = new TreeType(referenceSet);
+    referenceTree = new Tree(referenceSet);
 
   Timer::Stop("tree_building");
 }
 
 // Instantiated kernel.
-template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::FastMKS(const typename TreeType::Mat& referenceSet,
-                                       KernelType& kernel,
-                                       const bool singleMode,
-                                       const bool naive) :
+template<typename KernelType,
+         typename MatType,
+         template<typename TreeMetricType, typename TreeStatType,
+                  typename TreeMatType> class TreeType>
+FastMKS<KernelType, MatType, TreeType>::FastMKS(const MatType& referenceSet,
+                                                KernelType& kernel,
+                                                const bool singleMode,
+                                                const bool naive) :
     referenceSet(referenceSet),
     referenceTree(NULL),
     treeOwner(true),
@@ -54,15 +61,18 @@ FastMKS<KernelType, TreeType>::FastMKS(const typename TreeType::Mat& referenceSe
 
   // If necessary, the reference tree should be built.  There is no query tree.
   if (!naive)
-    referenceTree = new TreeType(referenceSet, metric);
+    referenceTree = new Tree(referenceSet, metric);
 
   Timer::Stop("tree_building");
 }
 
 // One dataset, pre-built tree.
-template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::FastMKS(TreeType* referenceTree,
-                                       const bool singleMode) :
+template<typename KernelType,
+         typename MatType,
+         template<typename TreeMetricType, typename TreeStatType,
+                  typename TreeMatType> class TreeType>
+FastMKS<KernelType, MatType, TreeType>::FastMKS(Tree* referenceTree,
+                                                const bool singleMode) :
     referenceSet(referenceTree->Dataset()),
     referenceTree(referenceTree),
     treeOwner(false),
@@ -73,17 +83,23 @@ FastMKS<KernelType, TreeType>::FastMKS(TreeType* referenceTree,
   // Nothing to do.
 }
 
-template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::~FastMKS()
+template<typename KernelType,
+         typename MatType,
+         template<typename TreeMetricType, typename TreeStatType,
+                  typename TreeMatType> class TreeType>
+FastMKS<KernelType, MatType, TreeType>::~FastMKS()
 {
   // If we created the trees, we must delete them.
   if (treeOwner && referenceTree)
     delete referenceTree;
 }
 
-template<typename KernelType, typename TreeType>
-void FastMKS<KernelType, TreeType>::Search(
-    const typename TreeType::Mat& querySet,
+template<typename KernelType,
+         typename MatType,
+         template<typename TreeMetricType, typename TreeStatType,
+                  typename TreeMatType> class TreeType>
+void FastMKS<KernelType, MatType, TreeType>::Search(
+    const MatType& querySet,
     const size_t k,
     arma::Mat<size_t>& indices,
     arma::mat& kernels)
@@ -132,10 +148,10 @@ void FastMKS<KernelType, TreeType>::Search(
 
     // Create rules object (this will store the results).  This constructor
     // precalculates each self-kernel value.
-    typedef FastMKSRules<KernelType, TreeType> RuleType;
+    typedef FastMKSRules<KernelType, Tree> RuleType;
     RuleType rules(referenceSet, querySet, indices, kernels, metric.Kernel());
 
-    typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
+    typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
 
     for (size_t i = 0; i < querySet.n_cols; ++i)
       traverser.Traverse(i, *referenceTree);
@@ -151,17 +167,21 @@ void FastMKS<KernelType, TreeType>::Search(
   // assuming it doesn't map anything...
   Timer::Stop("computing_products");
   Timer::Start("tree_building");
-  TreeType queryTree(querySet);
+  Tree queryTree(querySet);
   Timer::Stop("tree_building");
 
   Search(&queryTree, k, indices, kernels);
 }
 
-template<typename KernelType, typename TreeType>
-void FastMKS<KernelType, TreeType>::Search(TreeType* queryTree,
-                                           const size_t k,
-                                           arma::Mat<size_t>& indices,
-                                           arma::mat& kernels)
+template<typename KernelType,
+         typename MatType,
+         template<typename TreeMetricType, typename TreeStatType,
+                  typename TreeMatType> class TreeType>
+void FastMKS<KernelType, MatType, TreeType>::Search(
+    Tree* queryTree,
+    const size_t k,
+    arma::Mat<size_t>& indices,
+    arma::mat& kernels)
 {
   // If either naive mode or single mode is specified, this must fail.
   if (naive || singleMode)
@@ -176,11 +196,11 @@ void FastMKS<KernelType, TreeType>::Search(TreeType* queryTree,
   kernels.fill(-DBL_MAX);
 
   Timer::Start("computing_products");
-  typedef FastMKSRules<KernelType, TreeType> RuleType;
+  typedef FastMKSRules<KernelType, Tree> RuleType;
   RuleType rules(referenceSet, queryTree->Dataset(), indices, kernels,
       metric.Kernel());
 
-  typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+  typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
 
   traverser.Traverse(*queryTree, *referenceTree);
 
@@ -190,10 +210,14 @@ void FastMKS<KernelType, TreeType>::Search(TreeType* queryTree,
   Timer::Stop("computing_products");
 }
 
-template<typename KernelType, typename TreeType>
-void FastMKS<KernelType, TreeType>::Search(const size_t k,
-                                           arma::Mat<size_t>& indices,
-                                           arma::mat& kernels)
+template<typename KernelType,
+         typename MatType,
+         template<typename TreeMetricType, typename TreeStatType,
+                  typename TreeMatType> class TreeType>
+void FastMKS<KernelType, MatType, TreeType>::Search(
+    const size_t k,
+    arma::Mat<size_t>& indices,
+    arma::mat& kernels)
 {
   // No remapping will be necessary because we are using the cover tree.
   Timer::Start("computing_products");
@@ -236,11 +260,11 @@ void FastMKS<KernelType, TreeType>::Search(const size_t k,
   {
     // Create rules object (this will store the results).  This constructor
     // precalculates each self-kernel value.
-    typedef FastMKSRules<KernelType, TreeType> RuleType;
+    typedef FastMKSRules<KernelType, Tree> RuleType;
     RuleType rules(referenceSet, referenceSet, indices, kernels,
         metric.Kernel());
 
-    typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
+    typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
 
     for (size_t i = 0; i < referenceSet.n_cols; ++i)
       traverser.Traverse(i, *referenceTree);
@@ -271,13 +295,17 @@ void FastMKS<KernelType, TreeType>::Search(const size_t k,
  * @param neighbor Index of reference point which is being inserted.
  * @param distance Distance from query point to reference point.
  */
-template<typename KernelType, typename TreeType>
-void FastMKS<KernelType, TreeType>::InsertNeighbor(arma::Mat<size_t>& indices,
-                                                   arma::mat& products,
-                                                   const size_t queryIndex,
-                                                   const size_t pos,
-                                                   const size_t neighbor,
-                                                   const double distance)
+template<typename KernelType,
+         typename MatType,
+         template<typename TreeMetricType, typename TreeStatType,
+                  typename TreeMatType> class TreeType>
+void FastMKS<KernelType, MatType, TreeType>::InsertNeighbor(
+    arma::Mat<size_t>& indices,
+    arma::mat& products,
+    const size_t queryIndex,
+    const size_t pos,
+    const size_t neighbor,
+    const double distance)
 {
   // We only memmove() if there is actually a need to shift something.
   if (pos < (products.n_rows - 1))
@@ -297,8 +325,11 @@ void FastMKS<KernelType, TreeType>::InsertNeighbor(arma::Mat<size_t>& indices,
 }
 
 // Return string of object.
-template<typename KernelType, typename TreeType>
-std::string FastMKS<KernelType, TreeType>::ToString() const
+template<typename KernelType,
+         typename MatType,
+         template<typename TreeMetricType, typename TreeStatType,
+                  typename TreeMatType> class TreeType>
+std::string FastMKS<KernelType, MatType, TreeType>::ToString() const
 {
   std::ostringstream convert;
   convert << "FastMKS [" << this << "]" << std::endl;
diff --git a/src/mlpack/methods/fastmks/fastmks_main.cpp b/src/mlpack/methods/fastmks/fastmks_main.cpp
index 1e0dc94..dc1cdf8 100644
--- a/src/mlpack/methods/fastmks/fastmks_main.cpp
+++ b/src/mlpack/methods/fastmks/fastmks_main.cpp
@@ -88,13 +88,13 @@ void RunFastMKS(const arma::mat& referenceData,
   else
   {
     // Create the tree with the specified base.
-    typedef CoverTree<IPMetric<KernelType>, FirstPointIsRoot, FastMKSStat>
-        TreeType;
+    typedef FastMKS<KernelType, arma::mat, StandardCoverTree> FastMKSType;
+    typedef typename FastMKSType::Tree TreeType;
     IPMetric<KernelType> metric(kernel);
     TreeType tree(referenceData, metric, base);
 
     // Create FastMKS object.
-    FastMKS<KernelType> fastmks(&tree, single);
+    FastMKSType fastmks(&tree, single);
 
     // Now search with it.
     fastmks.Search(k, indices, kernels);
@@ -122,14 +122,15 @@ void RunFastMKS(const arma::mat& referenceData,
   else
   {
     // Create the tree with the specified base.
-    typedef CoverTree<IPMetric<KernelType>, FirstPointIsRoot, FastMKSStat>
-        TreeType;
+    typedef FastMKS<KernelType, arma::mat, StandardCoverTree> FastMKSType;
+    typedef typename FastMKSType::Tree TreeType;
     IPMetric<KernelType> metric(kernel);
     TreeType referenceTree(referenceData, metric, base);
     TreeType queryTree(queryData, metric, base);
 
     // Create FastMKS object.
-    FastMKS<KernelType> fastmks(&referenceTree, single);
+    // TreeType is FastMKSType::Tree, so both trees match what FastMKS expects.
+    FastMKSType fastmks(&referenceTree, single);
 
     // Now search with it.
     fastmks.Search(&queryTree, k, indices, kernels);



More information about the mlpack-git mailing list