[mlpack-svn] r15919 - mlpack/trunk/src/mlpack/methods/rann

fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Oct 4 00:30:09 EDT 2013


Author: rcurtin
Date: Fri Oct  4 00:30:09 2013
New Revision: 15919

Log:
Overhaul RASearchRules so that they will work correctly with cover trees.


Modified:
   mlpack/trunk/src/mlpack/methods/rann/ra_search_rules_impl.hpp

Modified: mlpack/trunk/src/mlpack/methods/rann/ra_search_rules_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/rann/ra_search_rules_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/rann/ra_search_rules_impl.hpp	Fri Oct  4 00:30:09 2013
@@ -165,11 +165,11 @@
 
 
 template<typename SortPolicy, typename MetricType, typename TreeType>
-double RASearchRules<SortPolicy, MetricType, TreeType>::
-SuccessProbability(const size_t n,
-                   const size_t k,
-                   const size_t m,
-                   const size_t t) const
+double RASearchRules<SortPolicy, MetricType, TreeType>::SuccessProbability(
+    const size_t n,
+    const size_t k,
+    const size_t m,
+    const size_t t) const
 {
   if (k == 1)
   {
@@ -262,8 +262,9 @@
 
 template<typename SortPolicy, typename MetricType, typename TreeType>
 inline force_inline
-double RASearchRules<SortPolicy, MetricType, TreeType>::
-BaseCase(const size_t queryIndex, const size_t referenceIndex)
+double RASearchRules<SortPolicy, MetricType, TreeType>::BaseCase(
+    const size_t queryIndex,
+    const size_t referenceIndex)
 {
   // If the datasets are the same, then this search is only using one dataset
   // and we should not return identical points.
@@ -290,32 +291,10 @@
   return distance;
 }
 
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double RASearchRules<SortPolicy, MetricType, TreeType>::
-Prescore(TreeType& queryNode,
-         TreeType& referenceNode,
-         TreeType& referenceChildNode,
-         const double baseCaseResult) const
-{
-  return 0.0;
-}
-
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double RASearchRules<SortPolicy, MetricType, TreeType>::
-PrescoreQ(TreeType& queryNode,
-          TreeType& queryChildNode,
-          TreeType& referenceNode,
-          const double baseCaseResult) const
-{
-  return 0.0;
-}
-
-
 template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double RASearchRules<SortPolicy, MetricType, TreeType>::
-Score(const size_t queryIndex, TreeType& referenceNode)
+inline double RASearchRules<SortPolicy, MetricType, TreeType>::Score(
+    const size_t queryIndex,
+    TreeType& referenceNode)
 {
   const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
   const double distance = SortPolicy::BestPointToNodeDistance(queryPoint,
@@ -326,10 +305,10 @@
 }
 
 template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double RASearchRules<SortPolicy, MetricType, TreeType>::
-Score(const size_t queryIndex,
-      TreeType& referenceNode,
-      const double baseCaseResult)
+inline double RASearchRules<SortPolicy, MetricType, TreeType>::Score(
+    const size_t queryIndex,
+    TreeType& referenceNode,
+    const double baseCaseResult)
 {
   const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
   const double distance = SortPolicy::BestPointToNodeDistance(queryPoint,
@@ -340,11 +319,11 @@
 }
 
 template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double RASearchRules<SortPolicy, MetricType, TreeType>::
-Score(const size_t queryIndex,
-      TreeType& referenceNode,
-      const double distance,
-      const double bestDistance)
+inline double RASearchRules<SortPolicy, MetricType, TreeType>::Score(
+    const size_t queryIndex,
+    TreeType& referenceNode,
+    const double distance,
+    const double bestDistance)
 {
   // If this is better than the best distance we've seen so far, maybe there
   // will be something down this node.  Also check if enough samples are already
@@ -360,7 +339,7 @@
     {
       // Check if this node can be approximated by sampling.
       size_t samplesReqd = (size_t) std::ceil(samplingRatio *
-          (double) referenceNode.Count());
+          (double) referenceNode.NumDescendants());
       samplesReqd = std::min(samplesReqd,
           numSamplesReqd - numSamplesMade[queryIndex]);
 
@@ -376,13 +355,12 @@
           // Then samplesReqd <= singleSampleLimit.
           // Hence, approximate the node by sampling enough number of points.
           arma::uvec distinctSamples;
-          ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
+          ObtainDistinctSamples(samplesReqd, referenceNode.NumDescendants(),
                                 distinctSamples);
           for (size_t i = 0; i < distinctSamples.n_elem; i++)
             // The counting of the samples are done in the 'BaseCase' function
             // so no book-keeping is required here.
-            BaseCase(queryIndex,
-                     referenceNode.Begin() + (size_t) distinctSamples[i]);
+            BaseCase(queryIndex, referenceNode.Descendant(distinctSamples[i]));
 
           // Node approximated, so we can prune it.
           return DBL_MAX;
@@ -393,15 +371,15 @@
           {
             // Approximate node by sampling enough number of points.
             arma::uvec distinctSamples;
-            ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
+            ObtainDistinctSamples(samplesReqd, referenceNode.NumDescendants(),
                                   distinctSamples);
             for (size_t i = 0; i < distinctSamples.n_elem; i++)
               // The counting of the samples are done in the 'BaseCase' function
               // so no book-keeping is required here.
               BaseCase(queryIndex,
-                       referenceNode.Begin() + (size_t) distinctSamples[i]);
+                  referenceNode.Descendant(distinctSamples[i]));
 
-            // (Leaf) node approximated so can prune it.
+            // (Leaf) node approximated, so we can prune it.
             return DBL_MAX;
           }
           else
@@ -429,8 +407,8 @@
 
     // If enough samples are already made, this step does not change the result
     // of the search.
-    numSamplesMade[queryIndex] +=
-      (size_t) std::floor(samplingRatio * (double) referenceNode.Count());
+    numSamplesMade[queryIndex] += (size_t) std::floor(
+        samplingRatio * (double) referenceNode.NumDescendants());
 
     return DBL_MAX;
   }
@@ -464,7 +442,7 @@
 
     // Check if this node can be approximated by sampling.
     size_t samplesReqd = (size_t) std::ceil(samplingRatio *
-        (double) referenceNode.Count());
+        (double) referenceNode.NumDescendants());
     samplesReqd = std::min(samplesReqd, numSamplesReqd -
         numSamplesMade[queryIndex]);
 
@@ -481,13 +459,12 @@
         // Then, samplesReqd <= singleSampleLimit.  Hence, approximate the node
         // by sampling enough number of points.
         arma::uvec distinctSamples;
-        ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
+        ObtainDistinctSamples(samplesReqd, referenceNode.NumDescendants(),
             distinctSamples);
         for (size_t i = 0; i < distinctSamples.n_elem; i++)
           // The counting of the samples are done in the 'BaseCase' function so
           // no book-keeping is required here.
-          BaseCase(queryIndex, referenceNode.Begin() + (size_t)
-              distinctSamples[i]);
+          BaseCase(queryIndex, referenceNode.Descendant(distinctSamples[i]));
 
         // Node approximated, so we can prune it.
         return DBL_MAX;
@@ -498,13 +475,12 @@
         {
           // Approximate node by sampling enough points.
           arma::uvec distinctSamples;
-          ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
+          ObtainDistinctSamples(samplesReqd, referenceNode.NumDescendants(),
                                 distinctSamples);
           for (size_t i = 0; i < distinctSamples.n_elem; i++)
             // The counting of the samples are done in the 'BaseCase' function
             // so no book-keeping is required here.
-            BaseCase(queryIndex,
-                     referenceNode.Begin() + (size_t) distinctSamples[i]);
+            BaseCase(queryIndex, referenceNode.Descendant(distinctSamples[i]));
 
           // (Leaf) node approximated, so we can prune it.
           return DBL_MAX;
@@ -526,7 +502,7 @@
     // these samples need not be computed.  If enough samples are already made,
     // this step does not change the result of the search.
     numSamplesMade[queryIndex] += (size_t) std::floor(samplingRatio *
-        (double) referenceNode.Count());
+        (double) referenceNode.NumDescendants());
 
     return DBL_MAX;
   }
@@ -589,7 +565,7 @@
   for (size_t i = 0; i < queryNode.NumPoints(); i++)
   {
     const double bound = distances(distances.n_rows - 1, queryNode.Point(i))
-      + maxDescendantDistance;
+        + maxDescendantDistance;
     if (bound < pointBound)
       pointBound = bound;
   }
@@ -656,7 +632,7 @@
     {
       // Check if this node can be approximated by sampling.
       size_t samplesReqd = (size_t) std::ceil(samplingRatio
-          * (double) referenceNode.Count());
+          * (double) referenceNode.NumDescendants());
       samplesReqd = std::min(samplesReqd, numSamplesReqd -
           queryNode.Stat().NumSamplesMade());
 
@@ -682,17 +658,17 @@
         {
           // Then samplesReqd <= singleSampleLimit.  Hence, approximate node by
           // sampling enough number of points for every query in the query node.
-          for (size_t queryIndex = queryNode.Begin();
-              queryIndex < queryNode.End(); queryIndex++)
+          for (size_t i = 0; i < queryNode.NumDescendants(); ++i)
           {
+            const size_t queryIndex = queryNode.Descendant(i);
             arma::uvec distinctSamples;
-            ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
+            ObtainDistinctSamples(samplesReqd, referenceNode.NumDescendants(),
                                   distinctSamples);
-            for (size_t i = 0; i < distinctSamples.n_elem; i++)
+            for (size_t j = 0; j < distinctSamples.n_elem; j++)
               // The counting of the samples are done in the 'BaseCase' function
               // so no book-keeping is required here.
-              BaseCase(queryIndex, referenceNode.Begin() + (size_t)
-                  distinctSamples[i]);
+              BaseCase(queryIndex,
+                  referenceNode.Descendant(distinctSamples[j]));
           }
 
           // Update the number of samples made for the queryNode and also update
@@ -712,17 +688,17 @@
           {
             // Approximate node by sampling enough number of points for every
             // query in the query node.
-            for (size_t queryIndex = queryNode.Begin();
-                 queryIndex < queryNode.End(); queryIndex++)
+            for (size_t i = 0; i < queryNode.NumDescendants(); ++i)
             {
+              const size_t queryIndex = queryNode.Descendant(i);
               arma::uvec distinctSamples;
-              ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
+              ObtainDistinctSamples(samplesReqd, referenceNode.NumDescendants(),
                                     distinctSamples);
-              for (size_t i = 0; i < distinctSamples.n_elem; i++)
+              for (size_t j = 0; j < distinctSamples.n_elem; j++)
                 // The counting of the samples are done in the 'BaseCase'
                 // function so no book-keeping is required here.
-                BaseCase(queryIndex, referenceNode.Begin() +
-                    (size_t) distinctSamples[i]);
+                BaseCase(queryIndex,
+                    referenceNode.Descendant(distinctSamples[j]));
             }
 
             // Update the number of samples made for the queryNode and also
@@ -776,7 +752,7 @@
     // this step does not change the result of the search since this queryNode
     // will never be descended anymore.
     queryNode.Stat().NumSamplesMade() += (size_t) std::floor(samplingRatio *
-        (double) referenceNode.Count());
+        (double) referenceNode.NumDescendants());
 
     // Since we are not going to descend down the query tree for this reference
     // node, there is no point updating the number of samples made for the child
@@ -803,7 +779,7 @@
   for (size_t i = 0; i < queryNode.NumPoints(); i++)
   {
     const double bound = distances(distances.n_rows - 1, queryNode.Point(i))
-      + maxDescendantDistance;
+        + maxDescendantDistance;
     if (bound < pointBound)
       pointBound = bound;
   }
@@ -841,9 +817,8 @@
     // The number of samples made for a node is propagated up from the
     // child nodes if the child nodes have made samples that the parent
     // (which is the current 'queryNode') is not aware of.
-    queryNode.Stat().NumSamplesMade()
-      = std::max(queryNode.Stat().NumSamplesMade(),
-                 numSamplesMadeInChildNodes);
+    queryNode.Stat().NumSamplesMade() = std::max(
+        queryNode.Stat().NumSamplesMade(), numSamplesMadeInChildNodes);
   }
 
   // Now check if the node-pair interaction can be pruned by sampling
@@ -851,8 +826,8 @@
   // If this is better than the best distance we've seen so far,
   // maybe there will be something down this node.
   // Also check if enough samples are already made for this query.
-  if (SortPolicy::IsBetter(oldScore, bestDistance)
-      && queryNode.Stat().NumSamplesMade() < numSamplesReqd)
+  if (SortPolicy::IsBetter(oldScore, bestDistance) &&
+      queryNode.Stat().NumSamplesMade() < numSamplesReqd)
   {
     // We cannot prune this node
     // Try approximating this node by sampling
@@ -863,27 +838,27 @@
     // So no checks regarding that is made any more.
     //
     // check if this node can be approximated by sampling
-    size_t samplesReqd =
-      (size_t) std::ceil(samplingRatio * (double) referenceNode.Count());
-    samplesReqd
-      = std::min(samplesReqd,
-                 numSamplesReqd - queryNode.Stat().NumSamplesMade());
+    size_t samplesReqd = (size_t) std::ceil(
+        samplingRatio * (double) referenceNode.NumDescendants());
+    samplesReqd  = std::min(samplesReqd,
+        numSamplesReqd - queryNode.Stat().NumSamplesMade());
 
     if (samplesReqd > singleSampleLimit && !referenceNode.IsLeaf())
     {
-      // if too many samples required and not at a leaf, then can't prune
+      // If too many samples are required and we are not at a leaf, then we
+      // can't prune.
 
-      // Since query tree descend is necessary now,
-      // propagate the number of samples made down to the children
+      // Since query tree descent is necessary now, propagate the number of
+      // samples made down to the children.
 
       // Go through all children and propagate the number of
       // samples made to the children.
       // Only update if the parent node has made samples the children
       // have not seen
       for (size_t i = 0; i < queryNode.NumChildren(); i++)
-        queryNode.Child(i).Stat().NumSamplesMade()
-          = std::max(queryNode.Stat().NumSamplesMade(),
-                     queryNode.Child(i).Stat().NumSamplesMade());
+        queryNode.Child(i).Stat().NumSamplesMade() = std::max(
+            queryNode.Stat().NumSamplesMade(),
+            queryNode.Child(i).Stat().NumSamplesMade());
 
       return oldScore;
     }
@@ -894,17 +869,16 @@
         // Then samplesReqd <= singleSampleLimit.
         // Hence approximate node by sampling enough number of points
         // for every query in the 'queryNode'
-        for (size_t queryIndex = queryNode.Begin();
-             queryIndex < queryNode.End(); queryIndex++)
+        for (size_t i = 0; i < queryNode.NumDescendants(); ++i)
         {
+          const size_t queryIndex = queryNode.Descendant(i);
           arma::uvec distinctSamples;
-          ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
-                                distinctSamples);
-          for (size_t i = 0; i < distinctSamples.n_elem; i++)
+          ObtainDistinctSamples(samplesReqd, referenceNode.NumDescendants(),
+              distinctSamples);
+          for (size_t j = 0; j < distinctSamples.n_elem; j++)
             // The counting of the samples are done in the 'BaseCase'
             // function so no book-keeping is required here.
-            BaseCase(queryIndex,
-                     referenceNode.Begin() + (size_t) distinctSamples[i]);
+            BaseCase(queryIndex, referenceNode.Descendant(distinctSamples[j]));
         }
 
         // update the number of samples made for the queryNode and
@@ -924,17 +898,17 @@
         {
           // Approximate node by sampling enough number of points
           // for every query in the 'queryNode'.
-          for (size_t queryIndex = queryNode.Begin();
-               queryIndex < queryNode.End(); queryIndex++)
+          for (size_t i = 0; i < queryNode.NumDescendants(); ++i)
           {
+            const size_t queryIndex = queryNode.Descendant(i);
             arma::uvec distinctSamples;
-            ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
+            ObtainDistinctSamples(samplesReqd, referenceNode.NumDescendants(),
                                   distinctSamples);
-            for (size_t i = 0; i < distinctSamples.n_elem; i++)
+            for (size_t j = 0; j < distinctSamples.n_elem; j++)
               // The counting of the samples are done in the 'BaseCase'
               // function so no book-keeping is required here.
               BaseCase(queryIndex,
-                       referenceNode.Begin() + (size_t) distinctSamples[i]);
+                  referenceNode.Descendant(distinctSamples[j]));
           }
 
           // Update the number of samples made for the query node and also
@@ -972,7 +946,7 @@
     // this step does not change the result of the search since this query node
     // will never be descended anymore.
     queryNode.Stat().NumSamplesMade() += (size_t) std::floor(samplingRatio *
-        (double) referenceNode.Count());
+        (double) referenceNode.NumDescendants());
 
     // Since we are not going to descend down the query tree for this reference
     // node, there is no point updating the number of samples made for the child



More information about the mlpack-svn mailing list