[mlpack-svn] r10720 - mlpack/trunk/src/mlpack/methods/neighbor_search

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Dec 12 11:43:39 EST 2011


Author: rcurtin
Date: 2011-12-12 11:43:39 -0500 (Mon, 12 Dec 2011)
New Revision: 10720

Modified:
   mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
Log:
Refactor API to be consistent with RangeSearch.


Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp	2011-12-12 16:12:00 UTC (rev 10719)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp	2011-12-12 16:43:39 UTC (rev 10720)
@@ -151,7 +151,7 @@
   }
 
   Log::Info << "Computing " << k << " nearest neighbors..." << endl;
-  allkfn->ComputeNeighbors(k, neighbors, distances);
+  allkfn->Search(k, neighbors, distances);
 
   Log::Info << "Neighbors computed." << endl;
 

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp	2011-12-12 16:12:00 UTC (rev 10719)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp	2011-12-12 16:43:39 UTC (rev 10720)
@@ -153,7 +153,7 @@
   }
 
   Log::Info << "Computing " << k << " nearest neighbors..." << endl;
-  allknn->ComputeNeighbors(k, neighbors, distances);
+  allknn->Search(k, neighbors, distances);
 
   Log::Info << "Neighbors computed." << endl;
 

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp	2011-12-12 16:12:00 UTC (rev 10719)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp	2011-12-12 16:43:39 UTC (rev 10720)
@@ -97,8 +97,8 @@
    * @param queryTree Optionally pass a pre-built tree for the query set.
    * @param metric An optional instance of the MetricType class.
    */
-  NeighborSearch(const arma::mat& referenceSet,
-                 const arma::mat& querySet,
+  NeighborSearch(const typename TreeType::Mat& referenceSet,
+                 const typename TreeType::Mat& querySet,
                  const bool naive = false,
                  const bool singleMode = false,
                  const size_t leafSize = 20,
@@ -131,7 +131,7 @@
    *      set.
    * @param metric An optional instance of the MetricType class.
    */
-  NeighborSearch(const arma::mat& referenceSet,
+  NeighborSearch(const typename TreeType::Mat& referenceSet,
                  const bool naive = false,
                  const bool singleMode = false,
                  const size_t leafSize = 20,
@@ -156,9 +156,9 @@
    * @param distances Matrix storing distances of neighbors for each query
    *     point.
    */
-  void ComputeNeighbors(const size_t k,
-                        arma::Mat<size_t>& resultingNeighbors,
-                        arma::mat& distances);
+  void Search(const size_t k,
+              arma::Mat<size_t>& resultingNeighbors,
+              arma::mat& distances);
 
  private:
   /**
@@ -174,10 +174,10 @@
    * @param neighbors List of neighbors for each point.
    * @param distances List of distances for each point.
    */
-  void ComputeBaseCase(TreeType* queryNode,
-                       TreeType* referenceNode,
-                       arma::Mat<size_t>& neighbors,
-                       arma::mat& distances);
+  void BaseCase(TreeType* referenceNode,
+                TreeType* queryNode,
+                arma::Mat<size_t>& neighbors,
+                arma::mat& distances);
 
   /**
    * Recurse down the trees, computing base case computations when the leaves
@@ -189,11 +189,11 @@
    * @param neighbors List of neighbors for each point.
    * @param distances List of distances for each point.
    */
-  void ComputeDualNeighborsRecursion(TreeType* queryNode,
-                                     TreeType* referenceNode,
-                                     const double lowerBound,
-                                     arma::Mat<size_t>& neighbors,
-                                     arma::mat& distances);
+  void DualTreeRecursion(TreeType* referenceNode,
+                         TreeType* queryNode,
+                         const double lowerBound,
+                         arma::Mat<size_t>& neighbors,
+                         arma::mat& distances);
 
   /**
    * Perform a recursion only on the reference tree; the query point is given.
@@ -206,12 +206,13 @@
    * @param neighbors List of neighbors for each point.
    * @param distances List of distances for each point.
    */
-  void ComputeSingleNeighborsRecursion(const size_t pointId,
-                                       const arma::vec& point,
-                                       TreeType* referenceNode,
-                                       double& bestDistSoFar,
-                                       arma::Mat<size_t>& neighbors,
-                                       arma::mat& distances);
+  template<typename VecType>
+  void SingleTreeRecursion(TreeType* referenceNode,
+                           const VecType& queryPoint,
+                           const size_t queryIndex,
+                           double& bestDistSoFar,
+                           arma::Mat<size_t>& neighbors,
+                           arma::mat& distances);
 
   /**
    * Insert a point into the neighbors and distances matrices; this is a helper
@@ -233,14 +234,14 @@
 
   //! Copy of reference dataset (if we need it, because tree building modifies
   //! it).
-  arma::mat referenceCopy;
+  typename TreeType::Mat referenceCopy;
   //! Copy of query dataset (if we need it, because tree building modifies it).
-  arma::mat queryCopy;
+  typename TreeType::Mat queryCopy;
 
-  //! Reference dataset.
-  const arma::mat& referenceSet;
-  //! Query dataset (may not be given).
-  const arma::mat& querySet;
+  //! Reference dataset (data should be accessed using this).
+  const typename TreeType::Mat& referenceSet;
+  //! Query dataset (data should be accessed using this).
+  const typename TreeType::Mat& querySet;
 
   //! Indicates if O(n^2) naive search is being used.
   bool naive;

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp	2011-12-12 16:12:00 UTC (rev 10719)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp	2011-12-12 16:43:39 UTC (rev 10720)
@@ -15,8 +15,8 @@
 // Construct the object.
 template<typename SortPolicy, typename MetricType, typename TreeType>
 NeighborSearch<SortPolicy, MetricType, TreeType>::
-NeighborSearch(const arma::mat& referenceSet,
-               const arma::mat& querySet,
+NeighborSearch(const typename TreeType::Mat& referenceSet,
+               const typename TreeType::Mat& querySet,
                const bool naive,
                const bool singleMode,
                const size_t leafSize,
@@ -72,7 +72,7 @@
 // Construct the object.
 template<typename SortPolicy, typename MetricType, typename TreeType>
 NeighborSearch<SortPolicy, MetricType, TreeType>::
-NeighborSearch(const arma::mat& referenceSet,
+NeighborSearch(const typename TreeType::Mat& referenceSet,
                const bool naive,
                const bool singleMode,
                const size_t leafSize,
@@ -126,7 +126,7 @@
  * distances.
  */
 template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::ComputeNeighbors(
+void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
     const size_t k,
     arma::Mat<size_t>& resultingNeighbors,
     arma::mat& distances)
@@ -154,35 +154,20 @@
   {
     // Run the base case computation on all nodes
     if (queryTree)
-      ComputeBaseCase(queryTree, referenceTree, *neighborPtr, *distancePtr);
+      BaseCase(referenceTree, queryTree, *neighborPtr, *distancePtr);
     else
-      ComputeBaseCase(referenceTree, referenceTree, *neighborPtr, *distancePtr);
+      BaseCase(referenceTree, referenceTree, *neighborPtr, *distancePtr);
   }
   else
   {
     if (singleMode)
     {
-      // Do one tenth of the query set at a time.
-      size_t chunk = querySet.n_cols / 10;
-
-      for (size_t i = 0; i < 10; i++)
+      // Loop over each point in the query set.
+      for (size_t i = 0; i < querySet.n_cols; i++)
       {
-        for (size_t j = 0; j < chunk; j++)
-        {
-          double worstDistance = SortPolicy::WorstDistance();
-          ComputeSingleNeighborsRecursion(i * chunk + j,
-              querySet.unsafe_col(i * chunk + j), referenceTree, worstDistance,
-              *neighborPtr, *distancePtr);
-        }
-      }
-
-      // The last tenth is differently sized...
-      for (size_t i = 0; i < querySet.n_cols % 10; i++)
-      {
-        size_t ind = (querySet.n_cols / 10) * 10 + i;
         double worstDistance = SortPolicy::WorstDistance();
-        ComputeSingleNeighborsRecursion(ind, querySet.unsafe_col(ind),
-            referenceTree, worstDistance, *neighborPtr, *distancePtr);
+        SingleTreeRecursion(referenceTree, querySet.col(i), i, worstDistance,
+            *neighborPtr, *distancePtr);
       }
     }
     else // Dual-tree recursion.
@@ -190,13 +175,13 @@
       // Start on the root of each tree.
       if (queryTree)
       {
-        ComputeDualNeighborsRecursion(queryTree, referenceTree,
+        DualTreeRecursion(queryTree, referenceTree,
             SortPolicy::BestNodeToNodeDistance(queryTree, referenceTree),
             *neighborPtr, *distancePtr);
       }
       else
       {
-        ComputeDualNeighborsRecursion(referenceTree, referenceTree,
+        DualTreeRecursion(referenceTree, referenceTree,
             SortPolicy::BestNodeToNodeDistance(referenceTree, referenceTree),
             *neighborPtr, *distancePtr);
       }
@@ -297,9 +282,9 @@
  * Performs exhaustive computation between two leaves.
  */
 template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::ComputeBaseCase(
-      TreeType* queryNode,
+void NeighborSearch<SortPolicy, MetricType, TreeType>::BaseCase(
       TreeType* referenceNode,
+      TreeType* queryNode,
       arma::Mat<size_t>& neighbors,
       arma::mat& distances)
 {
@@ -311,16 +296,14 @@
   for (size_t queryIndex = queryNode->Begin(); queryIndex < queryNode->End();
        queryIndex++)
   {
-    // Get the query point from the matrix.
-    arma::vec queryPoint = querySet.unsafe_col(queryIndex);
-
+    // Get the best possible distance from the query point to the node.
     double queryToNodeDistance =
-        SortPolicy::BestPointToNodeDistance(queryPoint, referenceNode);
+        SortPolicy::BestPointToNodeDistance(querySet.col(queryIndex),
+                                            referenceNode);
 
     if (SortPolicy::IsBetter(queryToNodeDistance,
         distances(distances.n_rows - 1, queryIndex)))
     {
-      // We'll do the same for the references.
       for (size_t referenceIndex = referenceNode->Begin();
           referenceIndex < referenceNode->End(); referenceIndex++)
       {
@@ -328,10 +311,9 @@
         // in the monochromatic case.
         if (referenceNode != queryNode || referenceIndex != queryIndex)
         {
-          arma::vec referencePoint = referenceSet.unsafe_col(referenceIndex);
+          double distance = metric.Evaluate(querySet.col(queryIndex),
+                                            referenceSet.col(referenceIndex));
 
-          double distance = metric.Evaluate(queryPoint, referencePoint);
-
           // If the reference point is closer than any of the current
           // candidates, add it to the list.
           arma::vec queryDist = distances.unsafe_col(queryIndex);
@@ -360,10 +342,9 @@
  * The recursive function for dual tree.
  */
 template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::
-ComputeDualNeighborsRecursion(
-    TreeType* queryNode,
+void NeighborSearch<SortPolicy, MetricType, TreeType>::DualTreeRecursion(
     TreeType* referenceNode,
+    TreeType* queryNode,
     const double lowerBound,
     arma::Mat<size_t>& neighbors,
     arma::mat& distances)
@@ -377,7 +358,7 @@
   if (queryNode->IsLeaf() && referenceNode->IsLeaf())
   {
     // Base case: both are leaves.
-    ComputeBaseCase(queryNode, referenceNode, neighbors, distances);
+    BaseCase(referenceNode, queryNode, neighbors, distances);
     return;
   }
 
@@ -394,17 +375,17 @@
 
     if (SortPolicy::IsBetter(leftDistance, rightDistance))
     {
-      ComputeDualNeighborsRecursion(queryNode, referenceNode->Left(),
-          leftDistance, neighbors, distances);
-      ComputeDualNeighborsRecursion(queryNode, referenceNode->Right(),
-          rightDistance, neighbors, distances);
+      DualTreeRecursion(referenceNode->Left(), queryNode, leftDistance,
+          neighbors, distances);
+      DualTreeRecursion(referenceNode->Right(), queryNode, rightDistance,
+          neighbors, distances);
     }
     else
     {
-      ComputeDualNeighborsRecursion(queryNode, referenceNode->Right(),
-          rightDistance, neighbors, distances);
-      ComputeDualNeighborsRecursion(queryNode, referenceNode->Left(),
-          leftDistance, neighbors, distances);
+      DualTreeRecursion(referenceNode->Right(), queryNode, rightDistance,
+          neighbors, distances);
+      DualTreeRecursion(referenceNode->Left(), queryNode, leftDistance,
+          neighbors, distances);
     }
     return;
   }
@@ -417,10 +398,10 @@
     double rightDistance = SortPolicy::BestNodeToNodeDistance(
         queryNode->Right(), referenceNode);
 
-    ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode,
-        leftDistance, neighbors, distances);
-    ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode,
-        rightDistance, neighbors, distances);
+    DualTreeRecursion(referenceNode, queryNode->Left(), leftDistance,
+        neighbors, distances);
+    DualTreeRecursion(referenceNode, queryNode->Right(), rightDistance,
+        neighbors, distances);
 
     // We need to update the upper bound based on the new upper bounds of the
     // children.
@@ -445,17 +426,17 @@
   // Recurse on queryNode->left() first.
   if (SortPolicy::IsBetter(leftDistance, rightDistance))
   {
-    ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode->Left(),
-        leftDistance, neighbors, distances);
-    ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode->Right(),
-        rightDistance, neighbors, distances);
+    DualTreeRecursion(referenceNode->Left(), queryNode->Left(), leftDistance,
+        neighbors, distances);
+    DualTreeRecursion(referenceNode->Left(), queryNode->Right(), rightDistance,
+        neighbors, distances);
   }
   else
   {
-    ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode->Right(),
-        rightDistance, neighbors, distances);
-    ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode->Left(),
-        leftDistance, neighbors, distances);
+    DualTreeRecursion(referenceNode->Left(), queryNode->Right(), rightDistance,
+        neighbors, distances);
+    DualTreeRecursion(referenceNode->Left(), queryNode->Left(), leftDistance,
+        neighbors, distances);
   }
 
   leftDistance = SortPolicy::BestNodeToNodeDistance(queryNode->Right(),
@@ -466,17 +447,17 @@
   // Now recurse on queryNode->right().
   if (SortPolicy::IsBetter(leftDistance, rightDistance))
   {
-    ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode->Left(),
-        leftDistance, neighbors, distances);
-    ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode->Right(),
-        rightDistance, neighbors, distances);
+    DualTreeRecursion(referenceNode->Right(), queryNode->Left(), leftDistance,
+        neighbors, distances);
+    DualTreeRecursion(referenceNode->Right(), queryNode->Right(), rightDistance,
+        neighbors, distances);
   }
   else
   {
-    ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode->Right(),
-        rightDistance, neighbors, distances);
-    ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode->Left(),
-        leftDistance, neighbors, distances);
+    DualTreeRecursion(referenceNode->Right(), queryNode->Right(), rightDistance,
+        neighbors, distances);
+    DualTreeRecursion(referenceNode->Right(), queryNode->Left(), leftDistance,
+        neighbors, distances);
   }
 
   // Update the upper bound as above
@@ -491,13 +472,14 @@
 } // ComputeDualNeighborsRecursion()
 
 template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::
-ComputeSingleNeighborsRecursion(const size_t pointId,
-                                const arma::vec& point,
-                                TreeType* referenceNode,
-                                double& bestDistSoFar,
-                                arma::Mat<size_t>& neighbors,
-                                arma::mat& distances)
+template<typename VecType>
+void NeighborSearch<SortPolicy, MetricType, TreeType>::SingleTreeRecursion(
+    TreeType* referenceNode,
+    const VecType& queryPoint,
+    const size_t queryIndex,
+    double& bestDistSoFar,
+    arma::Mat<size_t>& neighbors,
+    arma::mat& distances)
 {
   if (referenceNode->IsLeaf())
   {
@@ -506,33 +488,31 @@
         referenceIndex < referenceNode->End(); referenceIndex++)
     {
       // Confirm that points do not identify themselves as neighbors
-      // in the monochromatic case
-      if (!(referenceSet.memptr() == querySet.memptr() &&
-            referenceIndex == pointId))
+      // in the monochromatic case.
+      if (!queryTree && !(referenceIndex == queryIndex))
       {
-        arma::vec referencePoint = referenceSet.unsafe_col(referenceIndex);
+        double distance = metric.Evaluate(queryPoint,
+                                          referenceSet.col(referenceIndex));
 
-        double distance = metric.Evaluate(point, referencePoint);
-
         // If the reference point is better than any of the current candidates,
         // insert it into the list correctly.
-        arma::vec queryDist = distances.unsafe_col(pointId);
+        arma::vec queryDist = distances.unsafe_col(queryIndex);
         size_t insertPosition = SortPolicy::SortDistance(queryDist, distance);
 
         if (insertPosition != (size_t() - 1))
-          InsertNeighbor(pointId, insertPosition, referenceIndex, distance,
+          InsertNeighbor(queryIndex, insertPosition, referenceIndex, distance,
               neighbors, distances);
       }
     } // for referenceIndex
 
-    bestDistSoFar = distances(distances.n_rows - 1, pointId);
+    bestDistSoFar = distances(distances.n_rows - 1, queryIndex);
   }
   else
   {
     // We'll order the computation by distance.
-    double leftDistance = SortPolicy::BestPointToNodeDistance(point,
+    double leftDistance = SortPolicy::BestPointToNodeDistance(queryPoint,
         referenceNode->Left());
-    double rightDistance = SortPolicy::BestPointToNodeDistance(point,
+    double rightDistance = SortPolicy::BestPointToNodeDistance(queryPoint,
         referenceNode->Right());
 
     // Recurse in the best direction first.
@@ -541,13 +521,13 @@
       if (SortPolicy::IsBetter(bestDistSoFar, leftDistance))
         numberOfPrunes++; // Prune; no possibility of finding a better point.
       else
-        ComputeSingleNeighborsRecursion(pointId, point, referenceNode->Left(),
+        SingleTreeRecursion(referenceNode->Left(), queryPoint, queryIndex,
             bestDistSoFar, neighbors, distances);
 
       if (SortPolicy::IsBetter(bestDistSoFar, rightDistance))
         numberOfPrunes++; // Prune; no possibility of finding a better point.
       else
-        ComputeSingleNeighborsRecursion(pointId, point, referenceNode->Right(),
+        SingleTreeRecursion(referenceNode->Right(), queryPoint, queryIndex,
             bestDistSoFar, neighbors, distances);
 
     }
@@ -556,13 +536,13 @@
       if (SortPolicy::IsBetter(bestDistSoFar, rightDistance))
         numberOfPrunes++; // Prune; no possibility of finding a better point.
       else
-        ComputeSingleNeighborsRecursion(pointId, point, referenceNode->Right(),
+        SingleTreeRecursion(referenceNode->Right(), queryPoint, queryIndex,
             bestDistSoFar, neighbors, distances);
 
       if (SortPolicy::IsBetter(bestDistSoFar, leftDistance))
         numberOfPrunes++; // Prune; no possibility of finding a better point.
       else
-        ComputeSingleNeighborsRecursion(pointId, point, referenceNode->Left(),
+        SingleTreeRecursion(referenceNode->Left(), queryPoint, queryIndex,
             bestDistSoFar, neighbors, distances);
     }
   }
@@ -577,13 +557,13 @@
  * @param distance Distance from query point to reference point.
  */
 template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::
-InsertNeighbor(const size_t queryIndex,
-               const size_t pos,
-               const size_t neighbor,
-               const double distance,
-               arma::Mat<size_t>& neighbors,
-               arma::mat& distances)
+void NeighborSearch<SortPolicy, MetricType, TreeType>::InsertNeighbor(
+    const size_t queryIndex,
+    const size_t pos,
+    const size_t neighbor,
+    const double distance,
+    arma::Mat<size_t>& neighbors,
+    arma::mat& distances)
 {
   // We only memmove() if there is actually a need to shift something.
   if (pos < (distances.n_rows - 1))




More information about the mlpack-svn mailing list