[mlpack-git] master, mlpack-1.0.x: Refactoring from Saheb: don't do naive search with trees. (1e92e90)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:43:55 EST 2015


Repository : https://github.com/mlpack/mlpack

On branches: master,mlpack-1.0.x
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit 1e92e906755c693063b8db9971f86a837196bb30
Author: Ryan Curtin <ryan at ratml.org>
Date:   Thu Feb 13 19:57:12 2014 +0000

    Refactoring from Saheb: don't do naive search with trees.


>---------------------------------------------------------------

1e92e906755c693063b8db9971f86a837196bb30
 src/mlpack/methods/rann/ra_search.hpp       |  0
 src/mlpack/methods/rann/ra_search_impl.hpp  | 45 ++++++++++++++++++++++++-----
 src/mlpack/methods/rann/ra_search_rules.hpp |  5 ++++
 3 files changed, 43 insertions(+), 7 deletions(-)

diff --git a/src/mlpack/methods/rann/ra_search_impl.hpp b/src/mlpack/methods/rann/ra_search_impl.hpp
index 4389650..dd46d18 100644
--- a/src/mlpack/methods/rann/ra_search_impl.hpp
+++ b/src/mlpack/methods/rann/ra_search_impl.hpp
@@ -41,11 +41,12 @@ RASearch(const typename TreeType::Mat& referenceSet,
   Timer::Start("tree_building");
 
   // Construct as a naive object if we need to.
-  referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
-      (naive ? referenceCopy.n_cols : leafSize));
+  if (!naive)
+  {
+    referenceTree = new TreeType(referenceCopy, oldFromNewReferences, leafSize);
 
-  queryTree = new TreeType(queryCopy, oldFromNewQueries,
-      (naive ? querySet.n_cols : leafSize));
+    queryTree = new TreeType(queryCopy, oldFromNewQueries, leafSize);
+  }
 
   // Stop the timer we started above.
   Timer::Stop("tree_building");
@@ -75,8 +76,10 @@ RASearch(const typename TreeType::Mat& referenceSet,
   Timer::Start("tree_building");
 
   // Construct as a naive object if we need to.
-  referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
-      (naive ? referenceSet.n_cols : leafSize));
+  if (!naive)
+  {
+    referenceTree = new TreeType(referenceCopy, oldFromNewReferences, leafSize);
+  }
 
   // Stop the timer we started above.
   Timer::Stop("tree_building");
@@ -177,7 +180,35 @@ Search(const size_t k,
 
   size_t numPrunes = 0;
 
-  if (singleMode || naive)
+  if (naive)
+  {
+    // We don't need to run the base case on every possible combination of
+    // points; we can achieve the rank approximation guarantee with probability
+    // alpha by sampling the reference set.
+    typedef RASearchRules<SortPolicy, MetricType, TreeType> RuleType;
+    RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr,
+                   metric, tau, alpha, naive, sampleAtLeaves, firstLeafExact,
+                   singleSampleLimit);
+
+    // Find how many samples from the reference set we need and sample uniformly
+    // from the reference set without replacement.
+    const size_t numSamples = rules.MinimumSamplesReqd(referenceSet.n_cols, k,
+        tau, alpha);
+    arma::uvec distinctSamples;
+    rules.ObtainDistinctSamples(numSamples, referenceSet.n_cols,
+        distinctSamples);
+
+    // Run the base case on each combination of query point and sampled
+    // reference point.
+    for (size_t i = 0; i < querySet.n_cols; ++i)
+    {
+      for (size_t j = 0; j < distinctSamples.n_elem; ++j)
+      {
+        rules.BaseCase(i, (size_t) distinctSamples[j]);
+      }
+    }
+  }
+  else if (singleMode)
   {
     // Create the helper object for the tree traversal.  Initialization of
     // RASearchRules already implicitly performs the naive tree traversal.
diff --git a/src/mlpack/methods/rann/ra_search_rules.hpp b/src/mlpack/methods/rann/ra_search_rules.hpp
index 453a9bf..e53be47 100644
--- a/src/mlpack/methods/rann/ra_search_rules.hpp
+++ b/src/mlpack/methods/rann/ra_search_rules.hpp
@@ -10,6 +10,7 @@
 #define __MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP
 
 #include "../neighbor_search/ns_traversal_info.hpp"
+#include "ra_search.hpp" // For friend declaration.
 
 namespace mlpack {
 namespace neighbor {
@@ -323,6 +324,10 @@ class RASearchRules
                const double distance,
                const double bestDistance);
 
+  // So that RASearch can access ObtainDistinctSamples() and
+  // MinimumSamplesReqd().  Maybe refactoring is a better solution but this is
+  // okay for now.
+  friend class RASearch<SortPolicy, MetricType, TreeType>;
 }; // class RASearchRules
 
 }; // namespace neighbor



More information about the mlpack-git mailing list