[mlpack-svn] r10738 - mlpack/trunk/src/mlpack/methods/neighbor_search

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Dec 13 03:35:28 EST 2011


Author: rcurtin
Date: 2011-12-13 03:35:27 -0500 (Tue, 13 Dec 2011)
New Revision: 10738

Modified:
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
Log:
Why does this fix the test?  I don't know, but there isn't enough time to fix
it properly right now.


Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp	2011-12-13 07:23:09 UTC (rev 10737)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp	2011-12-13 08:35:27 UTC (rev 10738)
@@ -174,44 +174,45 @@
    * @param neighbors List of neighbors for each point.
    * @param distances List of distances for each point.
    */
-  void BaseCase(TreeType* referenceNode,
-                TreeType* queryNode,
-                arma::Mat<size_t>& neighbors,
-                arma::mat& distances);
+  void ComputeBaseCase(TreeType* queryNode,
+                       TreeType* referenceNode,
+                       arma::Mat<size_t>& neighbors,
+                       arma::mat& distances);
 
   /**
    * Recurse down the trees, computing base case computations when the leaves
    * are reached.
    *
+   * @param queryNode Node in query tree.
    * @param referenceNode Node in reference tree.
-   * @param queryNode Node in query tree.
    * @param lowerBound The lower bound; if above this, we can prune.
    * @param neighbors List of neighbors for each point.
    * @param distances List of distances for each point.
    */
-  void DualTreeRecursion(TreeType* referenceNode,
-                         TreeType* queryNode,
-                         const double lowerBound,
-                         arma::Mat<size_t>& neighbors,
-                         arma::mat& distances);
+  void ComputeDualNeighborsRecursion(TreeType* queryNode,
+                                     TreeType* referenceNode,
+                                     const double lowerBound,
+                                     arma::Mat<size_t>& neighbors,
+                                     arma::mat& distances);
 
   /**
    * Perform a recursion only on the reference tree; the query point is given.
-   * This method is similar to BaseCase().
+   * This method is similar to ComputeBaseCase().
    *
+   * @param pointId Index of query point.
+   * @param point The query point.
    * @param referenceNode Reference node.
-   * @param queryPoint The query point.
-   * @param queryIndex Index of query point.
    * @param bestDistSoFar Best distance to a node so far -- used for pruning.
    * @param neighbors List of neighbors for each point.
    * @param distances List of distances for each point.
    */
-  void SingleTreeRecursion(TreeType* referenceNode,
-                           const arma::vec& queryPoint,
-                           const size_t queryIndex,
-                           double& bestDistSoFar,
-                           arma::Mat<size_t>& neighbors,
-                           arma::mat& distances);
+  template<typename VecType>
+  void ComputeSingleNeighborsRecursion(const size_t pointId,
+                                       const VecType& point,
+                                       TreeType* referenceNode,
+                                       double& bestDistSoFar,
+                                       arma::Mat<size_t>& neighbors,
+                                       arma::mat& distances);
 
   /**
    * Insert a point into the neighbors and distances matrices; this is a helper

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp	2011-12-13 07:23:09 UTC (rev 10737)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp	2011-12-13 08:35:27 UTC (rev 10738)
@@ -154,19 +154,35 @@
   {
     // Run the base case computation on all nodes
     if (queryTree)
-      BaseCase(referenceTree, queryTree, *neighborPtr, *distancePtr);
+      ComputeBaseCase(queryTree, referenceTree, *neighborPtr, *distancePtr);
     else
-      BaseCase(referenceTree, referenceTree, *neighborPtr, *distancePtr);
+      ComputeBaseCase(referenceTree, referenceTree, *neighborPtr, *distancePtr);
   }
   else
   {
     if (singleMode)
     {
-      for (size_t i = 0; i < querySet.n_cols; i++)
+      // Do one tenth of the query set at a time.
+      size_t chunk = querySet.n_cols / 10;
+
+      for (size_t i = 0; i < 10; i++)
       {
+        for (size_t j = 0; j < chunk; j++)
+        {
+          double worstDistance = SortPolicy::WorstDistance();
+          ComputeSingleNeighborsRecursion(i * chunk + j,
+              querySet.unsafe_col(i * chunk + j), referenceTree, worstDistance,
+              *neighborPtr, *distancePtr);
+        }
+      }
+
+      // The last tenth is differently sized...
+      for (size_t i = 0; i < querySet.n_cols % 10; i++)
+      {
+        size_t ind = (querySet.n_cols / 10) * 10 + i;
         double worstDistance = SortPolicy::WorstDistance();
-        SingleTreeRecursion(referenceTree, querySet.unsafe_col(i), i,
-            worstDistance, *neighborPtr, *distancePtr);
+        ComputeSingleNeighborsRecursion(ind, querySet.unsafe_col(ind),
+            referenceTree, worstDistance, *neighborPtr, *distancePtr);
       }
     }
     else // Dual-tree recursion.
@@ -174,13 +190,13 @@
       // Start on the root of each tree.
       if (queryTree)
       {
-        DualTreeRecursion(referenceTree, queryTree,
+        ComputeDualNeighborsRecursion(queryTree, referenceTree,
             SortPolicy::BestNodeToNodeDistance(queryTree, referenceTree),
             *neighborPtr, *distancePtr);
       }
       else
       {
-        DualTreeRecursion(referenceTree, referenceTree,
+        ComputeDualNeighborsRecursion(referenceTree, referenceTree,
             SortPolicy::BestNodeToNodeDistance(referenceTree, referenceTree),
             *neighborPtr, *distancePtr);
       }
@@ -275,13 +291,13 @@
     delete neighborPtr;
     delete distancePtr;
   }
-} // ComputeNeighbors
+} // Search
 
 /**
  * Performs exhaustive computation between two leaves.
  */
 template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::BaseCase(
+void NeighborSearch<SortPolicy, MetricType, TreeType>::ComputeBaseCase(
       TreeType* queryNode,
       TreeType* referenceNode,
       arma::Mat<size_t>& neighbors,
@@ -338,15 +354,16 @@
   // Update the upper bound for the queryNode
   queryNode->Stat().Bound() = queryWorstDistance;
 
-} // BaseCase()
+} // ComputeBaseCase()
 
 /**
  * The recursive function for dual tree.
  */
 template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::DualTreeRecursion(
+void NeighborSearch<SortPolicy, MetricType, TreeType>::
+ComputeDualNeighborsRecursion(
+    TreeType* queryNode,
     TreeType* referenceNode,
-    TreeType* queryNode,
     const double lowerBound,
     arma::Mat<size_t>& neighbors,
     arma::mat& distances)
@@ -360,7 +377,7 @@
   if (queryNode->IsLeaf() && referenceNode->IsLeaf())
   {
     // Base case: both are leaves.
-    BaseCase(referenceNode, queryNode, neighbors, distances);
+    ComputeBaseCase(queryNode, referenceNode, neighbors, distances);
     return;
   }
 
@@ -377,16 +394,16 @@
 
     if (SortPolicy::IsBetter(leftDistance, rightDistance))
     {
-      DualTreeRecursion(referenceNode->Left(), queryNode,
+      ComputeDualNeighborsRecursion(queryNode, referenceNode->Left(),
           leftDistance, neighbors, distances);
-      DualTreeRecursion(referenceNode->Right(), queryNode,
+      ComputeDualNeighborsRecursion(queryNode, referenceNode->Right(),
           rightDistance, neighbors, distances);
     }
     else
     {
-      DualTreeRecursion(referenceNode->Right(), queryNode,
+      ComputeDualNeighborsRecursion(queryNode, referenceNode->Right(),
           rightDistance, neighbors, distances);
-      DualTreeRecursion(referenceNode->Left(), queryNode,
+      ComputeDualNeighborsRecursion(queryNode, referenceNode->Left(),
           leftDistance, neighbors, distances);
     }
     return;
@@ -400,9 +417,9 @@
     double rightDistance = SortPolicy::BestNodeToNodeDistance(
         queryNode->Right(), referenceNode);
 
-    DualTreeRecursion(referenceNode, queryNode->Left(),
+    ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode,
         leftDistance, neighbors, distances);
-    DualTreeRecursion(referenceNode, queryNode->Right(),
+    ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode,
         rightDistance, neighbors, distances);
 
     // We need to update the upper bound based on the new upper bounds of the
@@ -428,16 +445,16 @@
   // Recurse on queryNode->left() first.
   if (SortPolicy::IsBetter(leftDistance, rightDistance))
   {
-    DualTreeRecursion(referenceNode->Left(), queryNode->Left(),
+    ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode->Left(),
         leftDistance, neighbors, distances);
-    DualTreeRecursion(referenceNode->Right(), queryNode->Left(),
+    ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode->Right(),
         rightDistance, neighbors, distances);
   }
   else
   {
-    DualTreeRecursion(referenceNode->Right(), queryNode->Left(),
+    ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode->Right(),
         rightDistance, neighbors, distances);
-    DualTreeRecursion(referenceNode->Left(), queryNode->Left(),
+    ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode->Left(),
         leftDistance, neighbors, distances);
   }
 
@@ -449,16 +466,16 @@
   // Now recurse on queryNode->right().
   if (SortPolicy::IsBetter(leftDistance, rightDistance))
   {
-    DualTreeRecursion(referenceNode->Left(), queryNode->Right(),
+    ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode->Left(),
         leftDistance, neighbors, distances);
-    DualTreeRecursion(referenceNode->Right(), queryNode->Right(),
+    ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode->Right(),
         rightDistance, neighbors, distances);
   }
   else
   {
-    DualTreeRecursion(referenceNode->Right(), queryNode->Right(),
+    ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode->Right(),
         rightDistance, neighbors, distances);
-    DualTreeRecursion(referenceNode->Left(), queryNode->Right(),
+    ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode->Left(),
         leftDistance, neighbors, distances);
   }
 
@@ -474,13 +491,14 @@
 } // ComputeDualNeighborsRecursion()
 
 template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::SingleTreeRecursion(
-    TreeType* referenceNode,
-    const arma::vec& queryPoint,
-    const size_t queryIndex,
-    double& bestDistSoFar,
-    arma::Mat<size_t>& neighbors,
-    arma::mat& distances)
+template<typename VecType>
+void NeighborSearch<SortPolicy, MetricType, TreeType>::
+ComputeSingleNeighborsRecursion(const size_t pointId,
+                                const VecType& point,
+                                TreeType* referenceNode,
+                                double& bestDistSoFar,
+                                arma::Mat<size_t>& neighbors,
+                                arma::mat& distances)
 {
   if (referenceNode->IsLeaf())
   {
@@ -490,31 +508,32 @@
     {
       // Confirm that points do not identify themselves as neighbors
       // in the monochromatic case
-      if (queryTree || (referenceIndex != queryIndex))
+      if (!(referenceSet.memptr() == querySet.memptr() &&
+            referenceIndex == pointId))
       {
         arma::vec referencePoint = referenceSet.unsafe_col(referenceIndex);
 
-        double distance = metric.Evaluate(queryPoint, referencePoint);
+        double distance = metric.Evaluate(point, referencePoint);
 
         // If the reference point is better than any of the current candidates,
         // insert it into the list correctly.
-        arma::vec queryDist = distances.unsafe_col(queryIndex);
+        arma::vec queryDist = distances.unsafe_col(pointId);
         size_t insertPosition = SortPolicy::SortDistance(queryDist, distance);
 
         if (insertPosition != (size_t() - 1))
-          InsertNeighbor(queryIndex, insertPosition, referenceIndex, distance,
+          InsertNeighbor(pointId, insertPosition, referenceIndex, distance,
               neighbors, distances);
       }
     } // for referenceIndex
 
-    bestDistSoFar = distances(distances.n_rows - 1, queryIndex);
+    bestDistSoFar = distances(distances.n_rows - 1, pointId);
   }
   else
   {
     // We'll order the computation by distance.
-    double leftDistance = SortPolicy::BestPointToNodeDistance(queryPoint,
+    double leftDistance = SortPolicy::BestPointToNodeDistance(point,
         referenceNode->Left());
-    double rightDistance = SortPolicy::BestPointToNodeDistance(queryPoint,
+    double rightDistance = SortPolicy::BestPointToNodeDistance(point,
         referenceNode->Right());
 
     // Recurse in the best direction first.
@@ -523,13 +542,13 @@
       if (SortPolicy::IsBetter(bestDistSoFar, leftDistance))
         numberOfPrunes++; // Prune; no possibility of finding a better point.
       else
-        SingleTreeRecursion(referenceNode->Left(), queryPoint, queryIndex,
+        ComputeSingleNeighborsRecursion(pointId, point, referenceNode->Left(),
             bestDistSoFar, neighbors, distances);
 
       if (SortPolicy::IsBetter(bestDistSoFar, rightDistance))
         numberOfPrunes++; // Prune; no possibility of finding a better point.
       else
-        SingleTreeRecursion(referenceNode->Right(), queryPoint, queryIndex,
+        ComputeSingleNeighborsRecursion(pointId, point, referenceNode->Right(),
             bestDistSoFar, neighbors, distances);
 
     }
@@ -538,13 +557,13 @@
       if (SortPolicy::IsBetter(bestDistSoFar, rightDistance))
         numberOfPrunes++; // Prune; no possibility of finding a better point.
       else
-        SingleTreeRecursion(referenceNode->Right(), queryPoint, queryIndex,
+        ComputeSingleNeighborsRecursion(pointId, point, referenceNode->Right(),
             bestDistSoFar, neighbors, distances);
 
       if (SortPolicy::IsBetter(bestDistSoFar, leftDistance))
         numberOfPrunes++; // Prune; no possibility of finding a better point.
       else
-        SingleTreeRecursion(referenceNode->Left(), queryPoint, queryIndex,
+        ComputeSingleNeighborsRecursion(pointId, point, referenceNode->Left(),
             bestDistSoFar, neighbors, distances);
     }
   }
@@ -559,13 +578,13 @@
  * @param distance Distance from query point to reference point.
  */
 template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::InsertNeighbor(
-    const size_t queryIndex,
-    const size_t pos,
-    const size_t neighbor,
-    const double distance,
-    arma::Mat<size_t>& neighbors,
-    arma::mat& distances)
+void NeighborSearch<SortPolicy, MetricType, TreeType>::
+InsertNeighbor(const size_t queryIndex,
+               const size_t pos,
+               const size_t neighbor,
+               const double distance,
+               arma::Mat<size_t>& neighbors,
+               arma::mat& distances)
 {
   // We only memmove() if there is actually a need to shift something.
   if (pos < (distances.n_rows - 1))




More information about the mlpack-svn mailing list