[mlpack-svn] r12665 - mlpack/trunk/src/mlpack/methods/neighbor_search
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed May 9 16:26:37 EDT 2012
Author: rcurtin
Date: 2012-05-09 16:26:37 -0400 (Wed, 09 May 2012)
New Revision: 12665
Modified:
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
Log:
Use DualTreeDepthFirstTraverser for naive and dual-tree calculations, and remove
all of the helper methods we previously needed but no longer do (ComputeBaseCase,
ComputeDualNeighborsRecursion, ComputeSingleNeighborsRecursion, InsertNeighbor).
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp 2012-05-09 20:26:02 UTC (rev 12664)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp 2012-05-09 20:26:37 UTC (rev 12665)
@@ -235,77 +235,6 @@
arma::mat& distances);
private:
- /**
- * Perform exhaustive computation between two leaves, comparing every node in
- * the leaf to the other leaf to find the furthest neighbor. The
- * neighbors and distances matrices will be updated with the changed
- * information.
- *
- * @param queryNode Node in query tree. This should be a leaf
- * (bottom-level).
- * @param referenceNode Node in reference tree. This should be a leaf
- * (bottom-level).
- * @param neighbors List of neighbors for each point.
- * @param distances List of distances for each point.
- */
- void ComputeBaseCase(TreeType* queryNode,
- TreeType* referenceNode,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances);
-
- /**
- * Recurse down the trees, computing base case computations when the leaves
- * are reached.
- *
- * @param queryNode Node in query tree.
- * @param referenceNode Node in reference tree.
- * @param lowerBound The lower bound; if above this, we can prune.
- * @param neighbors List of neighbors for each point.
- * @param distances List of distances for each point.
- */
- void ComputeDualNeighborsRecursion(TreeType* queryNode,
- TreeType* referenceNode,
- const double lowerBound,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances);
-
- /**
- * Perform a recursion only on the reference tree; the query point is given.
- * This method is similar to ComputeBaseCase().
- *
- * @param pointId Index of query point.
- * @param point The query point.
- * @param referenceNode Reference node.
- * @param bestDistSoFar Best distance to a node so far -- used for pruning.
- * @param neighbors List of neighbors for each point.
- * @param distances List of distances for each point.
- */
- template<typename VecType>
- void ComputeSingleNeighborsRecursion(const size_t pointId,
- const VecType& point,
- TreeType* referenceNode,
- double& bestDistSoFar,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances);
-
- /**
- * Insert a point into the neighbors and distances matrices; this is a helper
- * function.
- *
- * @param queryIndex Index of point whose neighbors we are inserting into.
- * @param pos Position in list to insert into.
- * @param neighbor Index of reference point which is being inserted.
- * @param distance Distance from query point to reference point.
- * @param neighbors List of neighbors for each point.
- * @param distances List of distances for each point.
- */
- void InsertNeighbor(const size_t queryIndex,
- const size_t pos,
- const size_t neighbor,
- const double distance,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances);
-
//! Copy of reference dataset (if we need it, because tree building modifies
//! it).
arma::mat referenceCopy;
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp 2012-05-09 20:26:02 UTC (rev 12664)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp 2012-05-09 20:26:37 UTC (rev 12665)
@@ -12,6 +12,7 @@
#include <mlpack/core/tree/traversers/single_tree_depth_first_traverser.hpp>
#include <mlpack/core/tree/traversers/single_tree_breadth_first_traverser.hpp>
+#include <mlpack/core/tree/traversers/dual_tree_depth_first_traverser.hpp>
#include "neighbor_search_rules.hpp"
using namespace mlpack::neighbor;
@@ -174,51 +175,48 @@
distancePtr->set_size(k, querySet.n_cols);
distancePtr->fill(SortPolicy::WorstDistance());
- if (naive)
+ size_t numPrunes = 0;
+
+ if (singleMode)
{
- // Run the base case computation on all nodes
- if (queryTree)
- ComputeBaseCase(queryTree, referenceTree, *neighborPtr, *distancePtr);
- else
- ComputeBaseCase(referenceTree, referenceTree, *neighborPtr, *distancePtr);
+ // Create the helper object for the tree traversal.
+ NeighborSearchRules<SortPolicy, MetricType, TreeType> rules(referenceSet,
+ querySet, *neighborPtr, *distancePtr, metric);
+
+ // Create the traverser.
+ typename TreeType::template PreferredTraverser<
+ NeighborSearchRules<SortPolicy, MetricType, TreeType> >::Type
+ traverser(rules);
+
+ // Now have it traverse for each point.
+ for (size_t i = 0; i < querySet.n_cols; ++i)
+ traverser.Traverse(i, *referenceTree);
+
+ numPrunes = traverser.NumPrunes();
}
- else
+ else // Dual-tree recursion.
{
- if (singleMode)
- {
- // Create the helper object for the tree traversal.
- NeighborSearchRules<SortPolicy, MetricType, TreeType> rules(referenceSet,
- querySet, *neighborPtr, *distancePtr, metric);
+ // Use crazy dual-tree traverser.
+ typedef NeighborSearchRules<SortPolicy, MetricType, TreeType> RuleType;
- // Create the traverser.
- typename TreeType::template PreferredTraverser<
- NeighborSearchRules<SortPolicy, MetricType, TreeType> >::Type
- traverser(rules);
+ RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr,
+ metric);
- // Now have it traverse for each point.
- for (size_t i = 0; i < querySet.n_cols; ++i)
- traverser.Traverse(i, *referenceTree);
+ typedef tree::DualTreeDepthFirstTraverser<TreeType, RuleType>
+ TraverserType;
- Log::Info << "Pruned " << traverser.NumPrunes() << " nodes." << std::endl;
- }
- else // Dual-tree recursion.
- {
- // Start on the root of each tree.
- if (queryTree)
- {
- ComputeDualNeighborsRecursion(queryTree, referenceTree,
- SortPolicy::BestNodeToNodeDistance(queryTree, referenceTree),
- *neighborPtr, *distancePtr);
- }
- else
- {
- ComputeDualNeighborsRecursion(referenceTree, referenceTree,
- SortPolicy::BestNodeToNodeDistance(referenceTree, referenceTree),
- *neighborPtr, *distancePtr);
- }
- }
+ TraverserType traverser(rules);
+
+ if (queryTree)
+ traverser.Traverse(*referenceTree, *queryTree);
+ else
+ traverser.Traverse(*referenceTree, *referenceTree);
+
+ numPrunes = traverser.NumPrunes();
}
+ Log::Info << "Pruned " << numPrunes << " nodes." << std::endl;
+
Timer::Stop("computing_neighbors");
// Now, do we need to do mapping of indices?
@@ -309,314 +307,4 @@
}
} // Search
-/**
- * Performs exhaustive computation between two leaves.
- */
-template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::ComputeBaseCase(
- TreeType* queryNode,
- TreeType* referenceNode,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances)
-{
- // Used to find the query node's new upper bound.
- double queryWorstDistance = SortPolicy::BestDistance();
-
- // node->Begin() is the index of the first point in the node,
- // node->End() is one past the last index.
- for (size_t queryIndex = queryNode->Begin(); queryIndex < queryNode->End();
- queryIndex++)
- {
- // Get the query point from the matrix.
- arma::vec queryPoint = querySet.unsafe_col(queryIndex);
-
- double queryToNodeDistance =
- SortPolicy::BestPointToNodeDistance(queryPoint, referenceNode);
-
- if (SortPolicy::IsBetter(queryToNodeDistance,
- distances(distances.n_rows - 1, queryIndex)))
- {
- // We'll do the same for the references.
- for (size_t referenceIndex = referenceNode->Begin();
- referenceIndex < referenceNode->End(); referenceIndex++)
- {
- // Confirm that points do not identify themselves as neighbors
- // in the monochromatic case.
- if (referenceNode != queryNode || referenceIndex != queryIndex)
- {
- arma::vec referencePoint = referenceSet.unsafe_col(referenceIndex);
-
- double distance = metric.Evaluate(queryPoint, referencePoint);
-
- // If the reference point is closer than any of the current
- // candidates, add it to the list.
- arma::vec queryDist = distances.unsafe_col(queryIndex);
- size_t insertPosition = SortPolicy::SortDistance(queryDist,
- distance);
-
- if (insertPosition != (size_t() - 1))
- InsertNeighbor(queryIndex, insertPosition, referenceIndex,
- distance, neighbors, distances);
- }
- }
- }
-
- // We need to find the upper bound distance for this query node
- if (SortPolicy::IsBetter(queryWorstDistance,
- distances(distances.n_rows - 1, queryIndex)))
- queryWorstDistance = distances(distances.n_rows - 1, queryIndex);
- }
-
- // Update the upper bound for the queryNode
- queryNode->Stat().Bound() = queryWorstDistance;
-
-} // ComputeBaseCase()
-
-/**
- * The recursive function for dual tree.
- */
-template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::
-ComputeDualNeighborsRecursion(
- TreeType* queryNode,
- TreeType* referenceNode,
- const double lowerBound,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances)
-{
- if (SortPolicy::IsBetter(queryNode->Stat().Bound(), lowerBound))
- {
- numberOfPrunes++; // Pruned by distance; the nodes cannot be any closer
- return; // than the already established lower bound.
- }
-
- if (queryNode->IsLeaf() && referenceNode->IsLeaf())
- {
- // Base case: both are leaves.
- ComputeBaseCase(queryNode, referenceNode, neighbors, distances);
- return;
- }
-
- if (queryNode->IsLeaf())
- {
- // We must keep descending down the reference node to get to a leaf.
-
- // We'll order the computation by distance; descend in the direction of less
- // distance first.
- double leftDistance = SortPolicy::BestNodeToNodeDistance(queryNode,
- referenceNode->Left());
- double rightDistance = SortPolicy::BestNodeToNodeDistance(queryNode,
- referenceNode->Right());
-
- if (SortPolicy::IsBetter(leftDistance, rightDistance))
- {
- ComputeDualNeighborsRecursion(queryNode, referenceNode->Left(),
- leftDistance, neighbors, distances);
- ComputeDualNeighborsRecursion(queryNode, referenceNode->Right(),
- rightDistance, neighbors, distances);
- }
- else
- {
- ComputeDualNeighborsRecursion(queryNode, referenceNode->Right(),
- rightDistance, neighbors, distances);
- ComputeDualNeighborsRecursion(queryNode, referenceNode->Left(),
- leftDistance, neighbors, distances);
- }
- return;
- }
-
- if (referenceNode->IsLeaf())
- {
- // We must descend down the query node to get to a leaf.
- double leftDistance = SortPolicy::BestNodeToNodeDistance(
- queryNode->Left(), referenceNode);
- double rightDistance = SortPolicy::BestNodeToNodeDistance(
- queryNode->Right(), referenceNode);
-
- ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode,
- leftDistance, neighbors, distances);
- ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode,
- rightDistance, neighbors, distances);
-
- // We need to update the upper bound based on the new upper bounds of the
- // children.
- double leftBound = queryNode->Left()->Stat().Bound();
- double rightBound = queryNode->Right()->Stat().Bound();
-
- if (SortPolicy::IsBetter(leftBound, rightBound))
- queryNode->Stat().Bound() = rightBound;
- else
- queryNode->Stat().Bound() = leftBound;
-
- return;
- }
-
- // Neither side is a leaf; so we recurse on all combinations of both. The
- // calculations are ordered by distance.
- double leftDistance = SortPolicy::BestNodeToNodeDistance(queryNode->Left(),
- referenceNode->Left());
- double rightDistance = SortPolicy::BestNodeToNodeDistance(queryNode->Left(),
- referenceNode->Right());
-
- // Recurse on queryNode->left() first.
- if (SortPolicy::IsBetter(leftDistance, rightDistance))
- {
- ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode->Left(),
- leftDistance, neighbors, distances);
- ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode->Right(),
- rightDistance, neighbors, distances);
- }
- else
- {
- ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode->Right(),
- rightDistance, neighbors, distances);
- ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode->Left(),
- leftDistance, neighbors, distances);
- }
-
- leftDistance = SortPolicy::BestNodeToNodeDistance(queryNode->Right(),
- referenceNode->Left());
- rightDistance = SortPolicy::BestNodeToNodeDistance(queryNode->Right(),
- referenceNode->Right());
-
- // Now recurse on queryNode->right().
- if (SortPolicy::IsBetter(leftDistance, rightDistance))
- {
- ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode->Left(),
- leftDistance, neighbors, distances);
- ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode->Right(),
- rightDistance, neighbors, distances);
- }
- else
- {
- ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode->Right(),
- rightDistance, neighbors, distances);
- ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode->Left(),
- leftDistance, neighbors, distances);
- }
-
- // Update the upper bound as above
- double leftBound = queryNode->Left()->Stat().Bound();
- double rightBound = queryNode->Right()->Stat().Bound();
-
- if (SortPolicy::IsBetter(leftBound, rightBound))
- queryNode->Stat().Bound() = rightBound;
- else
- queryNode->Stat().Bound() = leftBound;
-
-} // ComputeDualNeighborsRecursion()
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-template<typename VecType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::
-ComputeSingleNeighborsRecursion(const size_t pointId,
- const VecType& point,
- TreeType* referenceNode,
- double& bestDistSoFar,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances)
-{
- if (referenceNode->IsLeaf())
- {
- // Base case: reference node is a leaf.
- for (size_t referenceIndex = referenceNode->Begin();
- referenceIndex < referenceNode->End(); referenceIndex++)
- {
- // Confirm that points do not identify themselves as neighbors
- // in the monochromatic case
- if (!(referenceSet.memptr() == querySet.memptr() &&
- referenceIndex == pointId))
- {
- arma::vec referencePoint = referenceSet.unsafe_col(referenceIndex);
-
- double distance = metric.Evaluate(point, referencePoint);
-
- // If the reference point is better than any of the current candidates,
- // insert it into the list correctly.
- arma::vec queryDist = distances.unsafe_col(pointId);
- size_t insertPosition = SortPolicy::SortDistance(queryDist, distance);
-
- if (insertPosition != (size_t() - 1))
- InsertNeighbor(pointId, insertPosition, referenceIndex, distance,
- neighbors, distances);
- }
- } // for referenceIndex
-
- bestDistSoFar = distances(distances.n_rows - 1, pointId);
- }
- else
- {
- // We'll order the computation by distance.
- double leftDistance = SortPolicy::BestPointToNodeDistance(point,
- referenceNode->Left());
- double rightDistance = SortPolicy::BestPointToNodeDistance(point,
- referenceNode->Right());
-
- // Recurse in the best direction first.
- if (SortPolicy::IsBetter(leftDistance, rightDistance))
- {
- if (SortPolicy::IsBetter(bestDistSoFar, leftDistance))
- numberOfPrunes++; // Prune; no possibility of finding a better point.
- else
- ComputeSingleNeighborsRecursion(pointId, point, referenceNode->Left(),
- bestDistSoFar, neighbors, distances);
-
- if (SortPolicy::IsBetter(bestDistSoFar, rightDistance))
- numberOfPrunes++; // Prune; no possibility of finding a better point.
- else
- ComputeSingleNeighborsRecursion(pointId, point, referenceNode->Right(),
- bestDistSoFar, neighbors, distances);
-
- }
- else
- {
- if (SortPolicy::IsBetter(bestDistSoFar, rightDistance))
- numberOfPrunes++; // Prune; no possibility of finding a better point.
- else
- ComputeSingleNeighborsRecursion(pointId, point, referenceNode->Right(),
- bestDistSoFar, neighbors, distances);
-
- if (SortPolicy::IsBetter(bestDistSoFar, leftDistance))
- numberOfPrunes++; // Prune; no possibility of finding a better point.
- else
- ComputeSingleNeighborsRecursion(pointId, point, referenceNode->Left(),
- bestDistSoFar, neighbors, distances);
- }
- }
-}
-
-/**
- * Helper function to insert a point into the neighbors and distances matrices.
- *
- * @param queryIndex Index of point whose neighbors we are inserting into.
- * @param pos Position in list to insert into.
- * @param neighbor Index of reference point which is being inserted.
- * @param distance Distance from query point to reference point.
- */
-template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::InsertNeighbor(
- const size_t queryIndex,
- const size_t pos,
- const size_t neighbor,
- const double distance,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances)
-{
- // We only memmove() if there is actually a need to shift something.
- if (pos < (distances.n_rows - 1))
- {
- int len = (distances.n_rows - 1) - pos;
- memmove(distances.colptr(queryIndex) + (pos + 1),
- distances.colptr(queryIndex) + pos,
- sizeof(double) * len);
- memmove(neighbors.colptr(queryIndex) + (pos + 1),
- neighbors.colptr(queryIndex) + pos,
- sizeof(size_t) * len);
- }
-
- // Now put the new information in the right index.
- distances(pos, queryIndex) = distance;
- neighbors(pos, queryIndex) = neighbor;
-}
-
#endif
More information about the mlpack-svn mailing list