[mlpack-svn] r10735 - mlpack/trunk/src/mlpack/methods/neighbor_search

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Dec 12 17:54:17 EST 2011


Author: rcurtin
Date: 2011-12-12 17:54:17 -0500 (Mon, 12 Dec 2011)
New Revision: 10735

Modified:
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
Log:
Retry the API change.  This time it works...


Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp	2011-12-12 21:48:51 UTC (rev 10734)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp	2011-12-12 22:54:17 UTC (rev 10735)
@@ -97,8 +97,8 @@
    * @param queryTree Optionally pass a pre-built tree for the query set.
    * @param metric An optional instance of the MetricType class.
    */
-  NeighborSearch(const typename TreeType::Mat& referenceSet,
-                 const typename TreeType::Mat& querySet,
+  NeighborSearch(const arma::mat& referenceSet,
+                 const arma::mat& querySet,
                  const bool naive = false,
                  const bool singleMode = false,
                  const size_t leafSize = 20,
@@ -131,7 +131,7 @@
    *      set.
    * @param metric An optional instance of the MetricType class.
    */
-  NeighborSearch(const typename TreeType::Mat& referenceSet,
+  NeighborSearch(const arma::mat& referenceSet,
                  const bool naive = false,
                  const bool singleMode = false,
                  const size_t leafSize = 20,
@@ -183,8 +183,8 @@
    * Recurse down the trees, computing base case computations when the leaves
    * are reached.
    *
-   * @param queryNode Node in query tree.
    * @param referenceNode Node in reference tree.
+   * @param queryNode Node in query tree.
    * @param lowerBound The lower bound; if above this, we can prune.
    * @param neighbors List of neighbors for each point.
    * @param distances List of distances for each point.
@@ -197,18 +197,17 @@
 
   /**
    * Perform a recursion only on the reference tree; the query point is given.
-   * This method is similar to ComputeBaseCase().
+   * This method is similar to BaseCase().
    *
-   * @param pointId Index of query point.
-   * @param point The query point.
    * @param referenceNode Reference node.
+   * @param queryPoint The query point.
+   * @param queryIndex Index of query point.
    * @param bestDistSoFar Best distance to a node so far -- used for pruning.
    * @param neighbors List of neighbors for each point.
    * @param distances List of distances for each point.
    */
-  template<typename VecType>
   void SingleTreeRecursion(TreeType* referenceNode,
-                           const VecType& queryPoint,
+                           const arma::vec& queryPoint,
                            const size_t queryIndex,
                            double& bestDistSoFar,
                            arma::Mat<size_t>& neighbors,
@@ -234,14 +233,14 @@
 
   //! Copy of reference dataset (if we need it, because tree building modifies
   //! it).
-  typename TreeType::Mat referenceCopy;
+  arma::mat referenceCopy;
   //! Copy of query dataset (if we need it, because tree building modifies it).
-  typename TreeType::Mat queryCopy;
+  arma::mat queryCopy;
 
-  //! Reference dataset (data should be accessed using this).
-  const typename TreeType::Mat& referenceSet;
-  //! Query dataset (data should be accessed using this).
-  const typename TreeType::Mat& querySet;
+  //! Reference dataset.
+  const arma::mat& referenceSet;
+  //! Query dataset (may not be given).
+  const arma::mat& querySet;
 
   //! Indicates if O(n^2) naive search is being used.
   bool naive;

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp	2011-12-12 21:48:51 UTC (rev 10734)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp	2011-12-12 22:54:17 UTC (rev 10735)
@@ -15,8 +15,8 @@
 // Construct the object.
 template<typename SortPolicy, typename MetricType, typename TreeType>
 NeighborSearch<SortPolicy, MetricType, TreeType>::
-NeighborSearch(const typename TreeType::Mat& referenceSet,
-               const typename TreeType::Mat& querySet,
+NeighborSearch(const arma::mat& referenceSet,
+               const arma::mat& querySet,
                const bool naive,
                const bool singleMode,
                const size_t leafSize,
@@ -72,7 +72,7 @@
 // Construct the object.
 template<typename SortPolicy, typename MetricType, typename TreeType>
 NeighborSearch<SortPolicy, MetricType, TreeType>::
-NeighborSearch(const typename TreeType::Mat& referenceSet,
+NeighborSearch(const arma::mat& referenceSet,
                const bool naive,
                const bool singleMode,
                const size_t leafSize,
@@ -162,12 +162,11 @@
   {
     if (singleMode)
     {
-      // Loop over each point in the query set.
       for (size_t i = 0; i < querySet.n_cols; i++)
       {
         double worstDistance = SortPolicy::WorstDistance();
-        SingleTreeRecursion(referenceTree, querySet.col(i), i, worstDistance,
-            *neighborPtr, *distancePtr);
+        SingleTreeRecursion(referenceTree, querySet.unsafe_col(i), i,
+            worstDistance, *neighborPtr, *distancePtr);
       }
     }
     else // Dual-tree recursion.
@@ -175,7 +174,7 @@
       // Start on the root of each tree.
       if (queryTree)
       {
-        DualTreeRecursion(queryTree, referenceTree,
+        DualTreeRecursion(referenceTree, queryTree,
             SortPolicy::BestNodeToNodeDistance(queryTree, referenceTree),
             *neighborPtr, *distancePtr);
       }
@@ -283,8 +282,8 @@
  */
 template<typename SortPolicy, typename MetricType, typename TreeType>
 void NeighborSearch<SortPolicy, MetricType, TreeType>::BaseCase(
+      TreeType* queryNode,
       TreeType* referenceNode,
-      TreeType* queryNode,
       arma::Mat<size_t>& neighbors,
       arma::mat& distances)
 {
@@ -296,14 +295,16 @@
   for (size_t queryIndex = queryNode->Begin(); queryIndex < queryNode->End();
        queryIndex++)
   {
-    // Get the best possible distance from the query point to the node.
+    // Get the query point from the matrix.
+    arma::vec queryPoint = querySet.unsafe_col(queryIndex);
+
     double queryToNodeDistance =
-        SortPolicy::BestPointToNodeDistance(querySet.col(queryIndex),
-                                            referenceNode);
+        SortPolicy::BestPointToNodeDistance(queryPoint, referenceNode);
 
     if (SortPolicy::IsBetter(queryToNodeDistance,
         distances(distances.n_rows - 1, queryIndex)))
     {
+      // We'll do the same for the references.
       for (size_t referenceIndex = referenceNode->Begin();
           referenceIndex < referenceNode->End(); referenceIndex++)
       {
@@ -311,9 +312,10 @@
         // in the monochromatic case.
         if (referenceNode != queryNode || referenceIndex != queryIndex)
         {
-          double distance = metric.Evaluate(querySet.col(queryIndex),
-                                            referenceSet.col(referenceIndex));
+          arma::vec referencePoint = referenceSet.unsafe_col(referenceIndex);
 
+          double distance = metric.Evaluate(queryPoint, referencePoint);
+
           // If the reference point is closer than any of the current
           // candidates, add it to the list.
           arma::vec queryDist = distances.unsafe_col(queryIndex);
@@ -336,7 +338,7 @@
   // Update the upper bound for the queryNode
   queryNode->Stat().Bound() = queryWorstDistance;
 
-} // ComputeBaseCase()
+} // BaseCase()
 
 /**
  * The recursive function for dual tree.
@@ -375,17 +377,17 @@
 
     if (SortPolicy::IsBetter(leftDistance, rightDistance))
     {
-      DualTreeRecursion(referenceNode->Left(), queryNode, leftDistance,
-          neighbors, distances);
-      DualTreeRecursion(referenceNode->Right(), queryNode, rightDistance,
-          neighbors, distances);
+      DualTreeRecursion(referenceNode->Left(), queryNode,
+          leftDistance, neighbors, distances);
+      DualTreeRecursion(referenceNode->Right(), queryNode,
+          rightDistance, neighbors, distances);
     }
     else
     {
-      DualTreeRecursion(referenceNode->Right(), queryNode, rightDistance,
-          neighbors, distances);
-      DualTreeRecursion(referenceNode->Left(), queryNode, leftDistance,
-          neighbors, distances);
+      DualTreeRecursion(referenceNode->Right(), queryNode,
+          rightDistance, neighbors, distances);
+      DualTreeRecursion(referenceNode->Left(), queryNode,
+          leftDistance, neighbors, distances);
     }
     return;
   }
@@ -398,10 +400,10 @@
     double rightDistance = SortPolicy::BestNodeToNodeDistance(
         queryNode->Right(), referenceNode);
 
-    DualTreeRecursion(referenceNode, queryNode->Left(), leftDistance,
-        neighbors, distances);
-    DualTreeRecursion(referenceNode, queryNode->Right(), rightDistance,
-        neighbors, distances);
+    DualTreeRecursion(referenceNode, queryNode->Left(),
+        leftDistance, neighbors, distances);
+    DualTreeRecursion(referenceNode, queryNode->Right(),
+        rightDistance, neighbors, distances);
 
     // We need to update the upper bound based on the new upper bounds of the
     // children.
@@ -426,17 +428,17 @@
   // Recurse on queryNode->left() first.
   if (SortPolicy::IsBetter(leftDistance, rightDistance))
   {
-    DualTreeRecursion(referenceNode->Left(), queryNode->Left(), leftDistance,
-        neighbors, distances);
-    DualTreeRecursion(referenceNode->Left(), queryNode->Right(), rightDistance,
-        neighbors, distances);
+    DualTreeRecursion(referenceNode->Left(), queryNode->Left(),
+        leftDistance, neighbors, distances);
+    DualTreeRecursion(referenceNode->Right(), queryNode->Left(),
+        rightDistance, neighbors, distances);
   }
   else
   {
-    DualTreeRecursion(referenceNode->Left(), queryNode->Right(), rightDistance,
-        neighbors, distances);
-    DualTreeRecursion(referenceNode->Left(), queryNode->Left(), leftDistance,
-        neighbors, distances);
+    DualTreeRecursion(referenceNode->Right(), queryNode->Left(),
+        rightDistance, neighbors, distances);
+    DualTreeRecursion(referenceNode->Left(), queryNode->Left(),
+        leftDistance, neighbors, distances);
   }
 
   leftDistance = SortPolicy::BestNodeToNodeDistance(queryNode->Right(),
@@ -447,17 +449,17 @@
   // Now recurse on queryNode->right().
   if (SortPolicy::IsBetter(leftDistance, rightDistance))
   {
-    DualTreeRecursion(referenceNode->Right(), queryNode->Left(), leftDistance,
-        neighbors, distances);
-    DualTreeRecursion(referenceNode->Right(), queryNode->Right(), rightDistance,
-        neighbors, distances);
+    DualTreeRecursion(referenceNode->Left(), queryNode->Right(),
+        leftDistance, neighbors, distances);
+    DualTreeRecursion(referenceNode->Right(), queryNode->Right(),
+        rightDistance, neighbors, distances);
   }
   else
   {
-    DualTreeRecursion(referenceNode->Right(), queryNode->Right(), rightDistance,
-        neighbors, distances);
-    DualTreeRecursion(referenceNode->Right(), queryNode->Left(), leftDistance,
-        neighbors, distances);
+    DualTreeRecursion(referenceNode->Right(), queryNode->Right(),
+        rightDistance, neighbors, distances);
+    DualTreeRecursion(referenceNode->Left(), queryNode->Right(),
+        leftDistance, neighbors, distances);
   }
 
   // Update the upper bound as above
@@ -472,10 +474,9 @@
 } // ComputeDualNeighborsRecursion()
 
 template<typename SortPolicy, typename MetricType, typename TreeType>
-template<typename VecType>
 void NeighborSearch<SortPolicy, MetricType, TreeType>::SingleTreeRecursion(
     TreeType* referenceNode,
-    const VecType& queryPoint,
+    const arma::vec& queryPoint,
     const size_t queryIndex,
     double& bestDistSoFar,
     arma::Mat<size_t>& neighbors,
@@ -488,12 +489,13 @@
         referenceIndex < referenceNode->End(); referenceIndex++)
     {
       // Confirm that points do not identify themselves as neighbors
-      // in the monochromatic case.
-      if (!queryTree && !(referenceIndex == queryIndex))
+      // in the monochromatic case
+      if (queryTree || (referenceIndex != queryIndex))
       {
-        double distance = metric.Evaluate(queryPoint,
-                                          referenceSet.col(referenceIndex));
+        arma::vec referencePoint = referenceSet.unsafe_col(referenceIndex);
 
+        double distance = metric.Evaluate(queryPoint, referencePoint);
+
         // If the reference point is better than any of the current candidates,
         // insert it into the list correctly.
         arma::vec queryDist = distances.unsafe_col(queryIndex);




More information about the mlpack-svn mailing list