[mlpack-svn] r16295 - mlpack/trunk/src/mlpack/methods/rann

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Feb 13 14:57:13 EST 2014


Author: rcurtin
Date: Thu Feb 13 14:57:12 2014
New Revision: 16295

Log:
Refactoring from Saheb: don't build trees for naive search; instead, sample the reference set directly.


Modified:
   mlpack/trunk/src/mlpack/methods/rann/ra_search.hpp
   mlpack/trunk/src/mlpack/methods/rann/ra_search_impl.hpp
   mlpack/trunk/src/mlpack/methods/rann/ra_search_rules.hpp

Modified: mlpack/trunk/src/mlpack/methods/rann/ra_search.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/rann/ra_search.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/rann/ra_search.hpp	Thu Feb 13 14:57:12 2014
@@ -287,7 +287,7 @@
    */
   void ResetQueryTree();
 
-  // Returns a string representation of this object. 
+  // Returns a string representation of this object.
   std::string ToString() const;
 
  private:

Modified: mlpack/trunk/src/mlpack/methods/rann/ra_search_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/rann/ra_search_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/rann/ra_search_impl.hpp	Thu Feb 13 14:57:12 2014
@@ -41,11 +41,12 @@
   Timer::Start("tree_building");
 
   // Construct as a naive object if we need to.
-  referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
-      (naive ? referenceCopy.n_cols : leafSize));
+  if (!naive)
+  {
+    referenceTree = new TreeType(referenceCopy, oldFromNewReferences, leafSize);
 
-  queryTree = new TreeType(queryCopy, oldFromNewQueries,
-      (naive ? querySet.n_cols : leafSize));
+    queryTree = new TreeType(queryCopy, oldFromNewQueries, leafSize);
+  }
 
   // Stop the timer we started above.
   Timer::Stop("tree_building");
@@ -75,8 +76,10 @@
   Timer::Start("tree_building");
 
   // Construct as a naive object if we need to.
-  referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
-      (naive ? referenceSet.n_cols : leafSize));
+  if (!naive)
+  {
+    referenceTree = new TreeType(referenceCopy, oldFromNewReferences, leafSize);
+  }
 
   // Stop the timer we started above.
   Timer::Stop("tree_building");
@@ -177,7 +180,35 @@
 
   size_t numPrunes = 0;
 
-  if (singleMode || naive)
+  if (naive)
+  {
+    // We don't need to run the base case on every possible combination of
+    // points; we can achieve the rank approximation guarantee with probability
+    // alpha by sampling the reference set.
+    typedef RASearchRules<SortPolicy, MetricType, TreeType> RuleType;
+    RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr,
+                   metric, tau, alpha, naive, sampleAtLeaves, firstLeafExact,
+                   singleSampleLimit);
+
+    // Find how many samples from the reference set we need and sample uniformly
+    // from the reference set without replacement.
+    const size_t numSamples = rules.MinimumSamplesReqd(referenceSet.n_cols, k,
+        tau, alpha);
+    arma::uvec distinctSamples;
+    rules.ObtainDistinctSamples(numSamples, referenceSet.n_cols,
+        distinctSamples);
+
+    // Run the base case on each combination of query point and sampled
+    // reference point.
+    for (size_t i = 0; i < querySet.n_cols; ++i)
+    {
+      for (size_t j = 0; j < distinctSamples.n_elem; ++j)
+      {
+        rules.BaseCase(i, (size_t) distinctSamples[j]);
+      }
+    }
+  }
+  else if (singleMode)
   {
     // Create the helper object for the tree traversal.  Initialization of
     // RASearchRules already implicitly performs the naive tree traversal.
@@ -363,13 +394,13 @@
   convert << "  Reference Set: " << referenceSet.n_rows << "x" ;
   convert <<  referenceSet.n_cols << std::endl;
   if (&referenceSet != &querySet)
-    convert << "  QuerySet: " << querySet.n_rows << "x" << querySet.n_cols 
+    convert << "  QuerySet: " << querySet.n_rows << "x" << querySet.n_cols
         << std::endl;
-  if (naive)  
+  if (naive)
     convert << "  Naive: TRUE" << std::endl;
   if (singleMode)
     convert << "  Single Node: TRUE" << std::endl;
-  convert << "  Metric: " << std::endl << 
+  convert << "  Metric: " << std::endl <<
       mlpack::util::Indent(metric.ToString(),2);
   return convert.str();
 }

Modified: mlpack/trunk/src/mlpack/methods/rann/ra_search_rules.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/rann/ra_search_rules.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/rann/ra_search_rules.hpp	Thu Feb 13 14:57:12 2014
@@ -10,6 +10,7 @@
 #define __MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP
 
 #include "../neighbor_search/ns_traversal_info.hpp"
+#include "ra_search.hpp" // For friend declaration.
 
 namespace mlpack {
 namespace neighbor {
@@ -323,6 +324,10 @@
                const double distance,
                const double bestDistance);
 
+  // So that RASearch can access ObtainDistinctSamples() and
+  // MinimumSamplesReqd().  Maybe refactoring is a better solution but this is
+  // okay for now.
+  friend class RASearch<SortPolicy, MetricType, TreeType>;
 }; // class RASearchRules
 
 }; // namespace neighbor



More information about the mlpack-svn mailing list