[mlpack-git] master: Add support for Greedy Single Tree Search, inside NeighborSearch class. (1c855b0)

gitdub at mlpack.org gitdub at mlpack.org
Sat Aug 20 14:56:07 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/3274b05fcc545c3b36f783316fea2e22f79c3d03...1c77230c7d3b9c45fb102cd3c632d9c7248e085e

>---------------------------------------------------------------

commit 1c855b0323c7f736daa32f766871b83809f78678
Author: MarcosPividori <marcos.pividori at gmail.com>
Date:   Tue Aug 16 03:58:55 2016 -0300

    Add support for Greedy Single Tree Search, inside NeighborSearch class.


>---------------------------------------------------------------

1c855b0323c7f736daa32f766871b83809f78678
 .../methods/neighbor_search/neighbor_search.hpp    |  3 +-
 .../neighbor_search/neighbor_search_impl.hpp       | 71 ++++++++++++++++------
 .../neighbor_search/neighbor_search_rules.hpp      |  8 +++
 .../neighbor_search/neighbor_search_rules_impl.hpp |  8 +++
 4 files changed, 72 insertions(+), 18 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index 88daf5f..beac1bf 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -78,7 +78,8 @@ class NeighborSearch
   {
     NAIVE_MODE,
     SINGLE_TREE_MODE,
-    DUAL_TREE_MODE
+    DUAL_TREE_MODE,
+    GREEDY_SINGLE_TREE_MODE
   };
 
   /**
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index f769de1..7bb1589 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -9,7 +9,7 @@
 #define MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_IMPL_HPP
 
 #include <mlpack/core.hpp>
-
+#include <mlpack/core/tree/greedy_single_tree_traverser.hpp>
 #include "neighbor_search_rules.hpp"
 #include <mlpack/core/tree/spill_tree/is_spill_tree.hpp>
 
@@ -84,15 +84,13 @@ SingleTreeTraversalType>::NeighborSearch(const MatType& referenceSetIn,
         &referenceTree->Dataset()),
     treeOwner(mode != NAIVE_MODE),
     setOwner(false),
-    searchMode(mode),
-    naive(mode == NAIVE_MODE),
-    singleMode(mode == SINGLE_TREE_MODE),
     epsilon(epsilon),
     metric(metric),
     baseCases(0),
     scores(0),
     treeNeedsReset(false)
 {
+  SetSearchMode(mode);
   if (epsilon < 0)
     throw std::invalid_argument("epsilon must be non-negative");
 }
@@ -118,15 +116,13 @@ SingleTreeTraversalType>::NeighborSearch(MatType&& referenceSetIn,
         &referenceTree->Dataset()),
     treeOwner(mode != NAIVE_MODE),
     setOwner(mode == NAIVE_MODE),
-    searchMode(mode),
-    naive(mode == NAIVE_MODE),
-    singleMode(mode == SINGLE_TREE_MODE),
     epsilon(epsilon),
     metric(metric),
     baseCases(0),
     scores(0),
     treeNeedsReset(false)
 {
+  SetSearchMode(mode);
   if (epsilon < 0)
     throw std::invalid_argument("epsilon must be non-negative");
 }
@@ -149,15 +145,13 @@ SingleTreeTraversalType>::NeighborSearch(Tree* referenceTree,
     referenceSet(&referenceTree->Dataset()),
     treeOwner(false),
     setOwner(false),
-    searchMode(mode),
-    naive(mode == NAIVE_MODE),
-    singleMode(mode == SINGLE_TREE_MODE),
     epsilon(epsilon),
     metric(metric),
     baseCases(0),
     scores(0),
     treeNeedsReset(false)
 {
+  SetSearchMode(mode);
   if (mode == NAIVE_MODE)
     throw std::invalid_argument("invalid constructor for naive mode");
   if (epsilon < 0)
@@ -181,15 +175,13 @@ SingleTreeTraversalType>::NeighborSearch(const SearchMode mode,
     referenceSet(new MatType()), // Empty matrix.
     treeOwner(false),
     setOwner(true),
-    searchMode(mode),
-    naive(mode == NAIVE_MODE),
-    singleMode(mode == SINGLE_TREE_MODE),
     epsilon(epsilon),
     metric(metric),
     baseCases(0),
     scores(0),
     treeNeedsReset(false)
 {
+  SetSearchMode(mode);
   if (epsilon < 0)
     throw std::invalid_argument("epsilon must be non-negative");
   // Build the tree on the empty dataset, if necessary.
@@ -618,6 +610,29 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
       delete queryTree;
       break;
     }
+    case GREEDY_SINGLE_TREE_MODE:
+    {
+      // Create the helper object for the tree traversal.
+      RuleType rules(*referenceSet, querySet, k, metric);
+
+      // Create the traverser.
+      tree::GreedySingleTreeTraverser<Tree, RuleType> traverser(rules);
+
+      // Now have it traverse for each point.
+      for (size_t i = 0; i < querySet.n_cols; ++i)
+        traverser.Traverse(i, *referenceTree);
+
+      scores += rules.Scores();
+      baseCases += rules.BaseCases();
+
+      Log::Info << rules.Scores() << " node combinations were scored."
+          << std::endl;
+      Log::Info << rules.BaseCases() << " base cases were calculated."
+          << std::endl;
+
+      rules.GetResults(*neighborPtr, *distancePtr);
+      break;
+    }
   }
 
   Timer::Stop("computing_neighbors");
@@ -814,8 +829,8 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(

   // Create the helper object for the traversal.
   typedef NeighborSearchRules<SortPolicy, MetricType, Tree> RuleType;
-  RuleType rules(*referenceSet, *referenceSet, k, metric, epsilon,
-      true /* don't return the same point as nearest neighbor */);
+  RuleType rules(*referenceSet, *referenceSet, k, metric, epsilon, true
+      /* don't return the same point as nearest neighbor */);

   switch (searchMode)
   {
@@ -898,6 +913,24 @@ DualTreeTraversalType, SingleTreeTraversalType>::Search(
       treeNeedsReset = true;
       break;
     }
+    case GREEDY_SINGLE_TREE_MODE:
+    {
+      // Create the traverser.
+      tree::GreedySingleTreeTraverser<Tree, RuleType> traverser(rules);
+
+      // Now have it traverse for each point.
+      for (size_t i = 0; i < referenceSet->n_cols; ++i)
+        traverser.Traverse(i, *referenceTree);
+
+      scores += rules.Scores();
+      baseCases += rules.BaseCases();
+
+      Log::Info << rules.Scores() << " node combinations were scored."
+          << std::endl;
+      Log::Info << rules.BaseCases() << " base cases were calculated."
+          << std::endl;
+      break;
+    }
   }
 
   rules.GetResults(*neighborPtr, *distancePtr);
@@ -1112,6 +1145,10 @@ DualTreeTraversalType, SingleTreeTraversalType>::SetSearchMode(
       naive = false;
       singleMode = false;
       break;
+    case GREEDY_SINGLE_TREE_MODE:
+      naive = false;
+      singleMode = true;
+      break;
   }
 }
 
@@ -1131,9 +1168,9 @@ DualTreeTraversalType, SingleTreeTraversalType>::UpdateSearchMode()
 {
   if (naive)
     searchMode = NAIVE_MODE;
-  else if (singleMode)
+  else if (singleMode && (searchMode != GREEDY_SINGLE_TREE_MODE))
     searchMode = SINGLE_TREE_MODE;
-  else
+  else if (!singleMode)
     searchMode = DUAL_TREE_MODE;
 }
 
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
index e7a7ce1..55d2163 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
@@ -77,6 +77,14 @@ class NeighborSearchRules
   double Score(const size_t queryIndex, TreeType& referenceNode);
 
   /**
+   * Get the child node with the best score.
+   *
+   * @param queryIndex Index of query point.
+   * @param referenceNode Candidate node to be recursed into.
+   */
+  TreeType& GetBestChild(const size_t queryIndex, TreeType& referenceNode);
+
+  /**
    * Re-evaluate the score for recursion order.  A low score indicates priority
    * for recursion, while DBL_MAX indicates that the node should not be recursed
    * into at all (it should be pruned).  This is used when the score has already
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
index 70f0bc0..e06683a 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
@@ -146,6 +146,14 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
 }
 
 template<typename SortPolicy, typename MetricType, typename TreeType>
+inline TreeType& NeighborSearchRules<SortPolicy, MetricType, TreeType>::
+GetBestChild(const size_t queryIndex, TreeType& referenceNode)
+{
+  ++scores;
+  return SortPolicy::GetBestChild(querySet.col(queryIndex), referenceNode);
+}
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
 inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
     const size_t queryIndex,
     TreeType& /* referenceNode */,




More information about the mlpack-git mailing list