[mlpack-svn] r12665 - mlpack/trunk/src/mlpack/methods/neighbor_search

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed May 9 16:26:37 EDT 2012


Author: rcurtin
Date: 2012-05-09 16:26:37 -0400 (Wed, 09 May 2012)
New Revision: 12665

Modified:
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
Log:
Use DualTreeDepthFirstTraverser for naive and dual-tree calculations, and remove
all of the guts we previously needed but no longer do (ComputeBaseCase,
ComputeDualNeighborsRecursion, ComputeSingleNeighborsRecursion, InsertNeighbor).


Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp	2012-05-09 20:26:02 UTC (rev 12664)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp	2012-05-09 20:26:37 UTC (rev 12665)
@@ -235,77 +235,6 @@
               arma::mat& distances);
 
  private:
-  /**
-   * Perform exhaustive computation between two leaves, comparing every node in
-   * the leaf to the other leaf to find the furthest neighbor.  The
-   * neighbors and distances matrices will be updated with the changed
-   * information.
-   *
-   * @param queryNode Node in query tree.  This should be a leaf
-   *     (bottom-level).
-   * @param referenceNode Node in reference tree.  This should be a leaf
-   *     (bottom-level).
-   * @param neighbors List of neighbors for each point.
-   * @param distances List of distances for each point.
-   */
-  void ComputeBaseCase(TreeType* queryNode,
-                       TreeType* referenceNode,
-                       arma::Mat<size_t>& neighbors,
-                       arma::mat& distances);
-
-  /**
-   * Recurse down the trees, computing base case computations when the leaves
-   * are reached.
-   *
-   * @param queryNode Node in query tree.
-   * @param referenceNode Node in reference tree.
-   * @param lowerBound The lower bound; if above this, we can prune.
-   * @param neighbors List of neighbors for each point.
-   * @param distances List of distances for each point.
-   */
-  void ComputeDualNeighborsRecursion(TreeType* queryNode,
-                                     TreeType* referenceNode,
-                                     const double lowerBound,
-                                     arma::Mat<size_t>& neighbors,
-                                     arma::mat& distances);
-
-  /**
-   * Perform a recursion only on the reference tree; the query point is given.
-   * This method is similar to ComputeBaseCase().
-   *
-   * @param pointId Index of query point.
-   * @param point The query point.
-   * @param referenceNode Reference node.
-   * @param bestDistSoFar Best distance to a node so far -- used for pruning.
-   * @param neighbors List of neighbors for each point.
-   * @param distances List of distances for each point.
-   */
-  template<typename VecType>
-  void ComputeSingleNeighborsRecursion(const size_t pointId,
-                                       const VecType& point,
-                                       TreeType* referenceNode,
-                                       double& bestDistSoFar,
-                                       arma::Mat<size_t>& neighbors,
-                                       arma::mat& distances);
-
-  /**
-   * Insert a point into the neighbors and distances matrices; this is a helper
-   * function.
-   *
-   * @param queryIndex Index of point whose neighbors we are inserting into.
-   * @param pos Position in list to insert into.
-   * @param neighbor Index of reference point which is being inserted.
-   * @param distance Distance from query point to reference point.
-   * @param neighbors List of neighbors for each point.
-   * @param distances List of distances for each point.
-   */
-  void InsertNeighbor(const size_t queryIndex,
-                      const size_t pos,
-                      const size_t neighbor,
-                      const double distance,
-                      arma::Mat<size_t>& neighbors,
-                      arma::mat& distances);
-
   //! Copy of reference dataset (if we need it, because tree building modifies
   //! it).
   arma::mat referenceCopy;

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp	2012-05-09 20:26:02 UTC (rev 12664)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp	2012-05-09 20:26:37 UTC (rev 12665)
@@ -12,6 +12,7 @@
 
 #include <mlpack/core/tree/traversers/single_tree_depth_first_traverser.hpp>
 #include <mlpack/core/tree/traversers/single_tree_breadth_first_traverser.hpp>
+#include <mlpack/core/tree/traversers/dual_tree_depth_first_traverser.hpp>
 #include "neighbor_search_rules.hpp"
 
 using namespace mlpack::neighbor;
@@ -174,51 +175,48 @@
   distancePtr->set_size(k, querySet.n_cols);
   distancePtr->fill(SortPolicy::WorstDistance());
 
-  if (naive)
+  size_t numPrunes = 0;
+
+  if (singleMode)
   {
-    // Run the base case computation on all nodes
-    if (queryTree)
-      ComputeBaseCase(queryTree, referenceTree, *neighborPtr, *distancePtr);
-    else
-      ComputeBaseCase(referenceTree, referenceTree, *neighborPtr, *distancePtr);
+    // Create the helper object for the tree traversal.
+    NeighborSearchRules<SortPolicy, MetricType, TreeType> rules(referenceSet,
+        querySet, *neighborPtr, *distancePtr, metric);
+
+    // Create the traverser.
+    typename TreeType::template PreferredTraverser<
+      NeighborSearchRules<SortPolicy, MetricType, TreeType> >::Type
+      traverser(rules);
+
+    // Now have it traverse for each point.
+    for (size_t i = 0; i < querySet.n_cols; ++i)
+      traverser.Traverse(i, *referenceTree);
+
+    numPrunes = traverser.NumPrunes();
   }
-  else
+  else // Dual-tree recursion.
   {
-    if (singleMode)
-    {
-      // Create the helper object for the tree traversal.
-      NeighborSearchRules<SortPolicy, MetricType, TreeType> rules(referenceSet,
-          querySet, *neighborPtr, *distancePtr, metric);
+    // Use crazy dual-tree traverser.
+    typedef NeighborSearchRules<SortPolicy, MetricType, TreeType> RuleType;
 
-      // Create the traverser.
-      typename TreeType::template PreferredTraverser<
-          NeighborSearchRules<SortPolicy, MetricType, TreeType> >::Type
-          traverser(rules);
+    RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr,
+        metric);
 
-      // Now have it traverse for each point.
-      for (size_t i = 0; i < querySet.n_cols; ++i)
-        traverser.Traverse(i, *referenceTree);
+    typedef tree::DualTreeDepthFirstTraverser<TreeType, RuleType>
+      TraverserType;
 
-      Log::Info << "Pruned " << traverser.NumPrunes() << " nodes." << std::endl;
-    }
-    else // Dual-tree recursion.
-    {
-      // Start on the root of each tree.
-      if (queryTree)
-      {
-        ComputeDualNeighborsRecursion(queryTree, referenceTree,
-            SortPolicy::BestNodeToNodeDistance(queryTree, referenceTree),
-            *neighborPtr, *distancePtr);
-      }
-      else
-      {
-        ComputeDualNeighborsRecursion(referenceTree, referenceTree,
-            SortPolicy::BestNodeToNodeDistance(referenceTree, referenceTree),
-            *neighborPtr, *distancePtr);
-      }
-    }
+    TraverserType traverser(rules);
+
+    if (queryTree)
+      traverser.Traverse(*referenceTree, *queryTree);
+    else
+      traverser.Traverse(*referenceTree, *referenceTree);
+
+    numPrunes = traverser.NumPrunes();
   }
 
+  Log::Info << "Pruned " << numPrunes << " nodes." << std::endl;
+
   Timer::Stop("computing_neighbors");
 
   // Now, do we need to do mapping of indices?
@@ -309,314 +307,4 @@
   }
 } // Search
 
-/**
- * Performs exhaustive computation between two leaves.
- */
-template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::ComputeBaseCase(
-      TreeType* queryNode,
-      TreeType* referenceNode,
-      arma::Mat<size_t>& neighbors,
-      arma::mat& distances)
-{
-  // Used to find the query node's new upper bound.
-  double queryWorstDistance = SortPolicy::BestDistance();
-
-  // node->Begin() is the index of the first point in the node,
-  // node->End() is one past the last index.
-  for (size_t queryIndex = queryNode->Begin(); queryIndex < queryNode->End();
-       queryIndex++)
-  {
-    // Get the query point from the matrix.
-    arma::vec queryPoint = querySet.unsafe_col(queryIndex);
-
-    double queryToNodeDistance =
-        SortPolicy::BestPointToNodeDistance(queryPoint, referenceNode);
-
-    if (SortPolicy::IsBetter(queryToNodeDistance,
-        distances(distances.n_rows - 1, queryIndex)))
-    {
-      // We'll do the same for the references.
-      for (size_t referenceIndex = referenceNode->Begin();
-          referenceIndex < referenceNode->End(); referenceIndex++)
-      {
-        // Confirm that points do not identify themselves as neighbors
-        // in the monochromatic case.
-        if (referenceNode != queryNode || referenceIndex != queryIndex)
-        {
-          arma::vec referencePoint = referenceSet.unsafe_col(referenceIndex);
-
-          double distance = metric.Evaluate(queryPoint, referencePoint);
-
-          // If the reference point is closer than any of the current
-          // candidates, add it to the list.
-          arma::vec queryDist = distances.unsafe_col(queryIndex);
-          size_t insertPosition = SortPolicy::SortDistance(queryDist,
-              distance);
-
-          if (insertPosition != (size_t() - 1))
-            InsertNeighbor(queryIndex, insertPosition, referenceIndex,
-                distance, neighbors, distances);
-        }
-      }
-    }
-
-    // We need to find the upper bound distance for this query node
-    if (SortPolicy::IsBetter(queryWorstDistance,
-        distances(distances.n_rows - 1, queryIndex)))
-      queryWorstDistance = distances(distances.n_rows - 1, queryIndex);
-  }
-
-  // Update the upper bound for the queryNode
-  queryNode->Stat().Bound() = queryWorstDistance;
-
-} // ComputeBaseCase()
-
-/**
- * The recursive function for dual tree.
- */
-template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::
-ComputeDualNeighborsRecursion(
-    TreeType* queryNode,
-    TreeType* referenceNode,
-    const double lowerBound,
-    arma::Mat<size_t>& neighbors,
-    arma::mat& distances)
-{
-  if (SortPolicy::IsBetter(queryNode->Stat().Bound(), lowerBound))
-  {
-    numberOfPrunes++; // Pruned by distance; the nodes cannot be any closer
-    return;           // than the already established lower bound.
-  }
-
-  if (queryNode->IsLeaf() && referenceNode->IsLeaf())
-  {
-    // Base case: both are leaves.
-    ComputeBaseCase(queryNode, referenceNode, neighbors, distances);
-    return;
-  }
-
-  if (queryNode->IsLeaf())
-  {
-    // We must keep descending down the reference node to get to a leaf.
-
-    // We'll order the computation by distance; descend in the direction of less
-    // distance first.
-    double leftDistance = SortPolicy::BestNodeToNodeDistance(queryNode,
-        referenceNode->Left());
-    double rightDistance = SortPolicy::BestNodeToNodeDistance(queryNode,
-        referenceNode->Right());
-
-    if (SortPolicy::IsBetter(leftDistance, rightDistance))
-    {
-      ComputeDualNeighborsRecursion(queryNode, referenceNode->Left(),
-          leftDistance, neighbors, distances);
-      ComputeDualNeighborsRecursion(queryNode, referenceNode->Right(),
-          rightDistance, neighbors, distances);
-    }
-    else
-    {
-      ComputeDualNeighborsRecursion(queryNode, referenceNode->Right(),
-          rightDistance, neighbors, distances);
-      ComputeDualNeighborsRecursion(queryNode, referenceNode->Left(),
-          leftDistance, neighbors, distances);
-    }
-    return;
-  }
-
-  if (referenceNode->IsLeaf())
-  {
-    // We must descend down the query node to get to a leaf.
-    double leftDistance = SortPolicy::BestNodeToNodeDistance(
-        queryNode->Left(), referenceNode);
-    double rightDistance = SortPolicy::BestNodeToNodeDistance(
-        queryNode->Right(), referenceNode);
-
-    ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode,
-        leftDistance, neighbors, distances);
-    ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode,
-        rightDistance, neighbors, distances);
-
-    // We need to update the upper bound based on the new upper bounds of the
-    // children.
-    double leftBound = queryNode->Left()->Stat().Bound();
-    double rightBound = queryNode->Right()->Stat().Bound();
-
-    if (SortPolicy::IsBetter(leftBound, rightBound))
-      queryNode->Stat().Bound() = rightBound;
-    else
-      queryNode->Stat().Bound() = leftBound;
-
-    return;
-  }
-
-  // Neither side is a leaf; so we recurse on all combinations of both.  The
-  // calculations are ordered by distance.
-  double leftDistance = SortPolicy::BestNodeToNodeDistance(queryNode->Left(),
-      referenceNode->Left());
-  double rightDistance = SortPolicy::BestNodeToNodeDistance(queryNode->Left(),
-      referenceNode->Right());
-
-  // Recurse on queryNode->left() first.
-  if (SortPolicy::IsBetter(leftDistance, rightDistance))
-  {
-    ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode->Left(),
-        leftDistance, neighbors, distances);
-    ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode->Right(),
-        rightDistance, neighbors, distances);
-  }
-  else
-  {
-    ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode->Right(),
-        rightDistance, neighbors, distances);
-    ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode->Left(),
-        leftDistance, neighbors, distances);
-  }
-
-  leftDistance = SortPolicy::BestNodeToNodeDistance(queryNode->Right(),
-      referenceNode->Left());
-  rightDistance = SortPolicy::BestNodeToNodeDistance(queryNode->Right(),
-      referenceNode->Right());
-
-  // Now recurse on queryNode->right().
-  if (SortPolicy::IsBetter(leftDistance, rightDistance))
-  {
-    ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode->Left(),
-        leftDistance, neighbors, distances);
-    ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode->Right(),
-        rightDistance, neighbors, distances);
-  }
-  else
-  {
-    ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode->Right(),
-        rightDistance, neighbors, distances);
-    ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode->Left(),
-        leftDistance, neighbors, distances);
-  }
-
-  // Update the upper bound as above
-  double leftBound = queryNode->Left()->Stat().Bound();
-  double rightBound = queryNode->Right()->Stat().Bound();
-
-  if (SortPolicy::IsBetter(leftBound, rightBound))
-    queryNode->Stat().Bound() = rightBound;
-  else
-    queryNode->Stat().Bound() = leftBound;
-
-} // ComputeDualNeighborsRecursion()
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-template<typename VecType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::
-ComputeSingleNeighborsRecursion(const size_t pointId,
-                                const VecType& point,
-                                TreeType* referenceNode,
-                                double& bestDistSoFar,
-                                arma::Mat<size_t>& neighbors,
-                                arma::mat& distances)
-{
-  if (referenceNode->IsLeaf())
-  {
-    // Base case: reference node is a leaf.
-    for (size_t referenceIndex = referenceNode->Begin();
-        referenceIndex < referenceNode->End(); referenceIndex++)
-    {
-      // Confirm that points do not identify themselves as neighbors
-      // in the monochromatic case
-      if (!(referenceSet.memptr() == querySet.memptr() &&
-            referenceIndex == pointId))
-      {
-        arma::vec referencePoint = referenceSet.unsafe_col(referenceIndex);
-
-        double distance = metric.Evaluate(point, referencePoint);
-
-        // If the reference point is better than any of the current candidates,
-        // insert it into the list correctly.
-        arma::vec queryDist = distances.unsafe_col(pointId);
-        size_t insertPosition = SortPolicy::SortDistance(queryDist, distance);
-
-        if (insertPosition != (size_t() - 1))
-          InsertNeighbor(pointId, insertPosition, referenceIndex, distance,
-              neighbors, distances);
-      }
-    } // for referenceIndex
-
-    bestDistSoFar = distances(distances.n_rows - 1, pointId);
-  }
-  else
-  {
-    // We'll order the computation by distance.
-    double leftDistance = SortPolicy::BestPointToNodeDistance(point,
-        referenceNode->Left());
-    double rightDistance = SortPolicy::BestPointToNodeDistance(point,
-        referenceNode->Right());
-
-    // Recurse in the best direction first.
-    if (SortPolicy::IsBetter(leftDistance, rightDistance))
-    {
-      if (SortPolicy::IsBetter(bestDistSoFar, leftDistance))
-        numberOfPrunes++; // Prune; no possibility of finding a better point.
-      else
-        ComputeSingleNeighborsRecursion(pointId, point, referenceNode->Left(),
-            bestDistSoFar, neighbors, distances);
-
-      if (SortPolicy::IsBetter(bestDistSoFar, rightDistance))
-        numberOfPrunes++; // Prune; no possibility of finding a better point.
-      else
-        ComputeSingleNeighborsRecursion(pointId, point, referenceNode->Right(),
-            bestDistSoFar, neighbors, distances);
-
-    }
-    else
-    {
-      if (SortPolicy::IsBetter(bestDistSoFar, rightDistance))
-        numberOfPrunes++; // Prune; no possibility of finding a better point.
-      else
-        ComputeSingleNeighborsRecursion(pointId, point, referenceNode->Right(),
-            bestDistSoFar, neighbors, distances);
-
-      if (SortPolicy::IsBetter(bestDistSoFar, leftDistance))
-        numberOfPrunes++; // Prune; no possibility of finding a better point.
-      else
-        ComputeSingleNeighborsRecursion(pointId, point, referenceNode->Left(),
-            bestDistSoFar, neighbors, distances);
-    }
-  }
-}
-
-/**
- * Helper function to insert a point into the neighbors and distances matrices.
- *
- * @param queryIndex Index of point whose neighbors we are inserting into.
- * @param pos Position in list to insert into.
- * @param neighbor Index of reference point which is being inserted.
- * @param distance Distance from query point to reference point.
- */
-template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::InsertNeighbor(
-    const size_t queryIndex,
-    const size_t pos,
-    const size_t neighbor,
-    const double distance,
-    arma::Mat<size_t>& neighbors,
-    arma::mat& distances)
-{
-  // We only memmove() if there is actually a need to shift something.
-  if (pos < (distances.n_rows - 1))
-  {
-    int len = (distances.n_rows - 1) - pos;
-    memmove(distances.colptr(queryIndex) + (pos + 1),
-        distances.colptr(queryIndex) + pos,
-        sizeof(double) * len);
-    memmove(neighbors.colptr(queryIndex) + (pos + 1),
-        neighbors.colptr(queryIndex) + pos,
-        sizeof(size_t) * len);
-  }
-
-  // Now put the new information in the right index.
-  distances(pos, queryIndex) = distance;
-  neighbors(pos, queryIndex) = neighbor;
-}
-
 #endif




More information about the mlpack-svn mailing list