[mlpack-svn] r10738 - mlpack/trunk/src/mlpack/methods/neighbor_search
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Dec 13 03:35:28 EST 2011
Author: rcurtin
Date: 2011-12-13 03:35:27 -0500 (Tue, 13 Dec 2011)
New Revision: 10738
Modified:
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
Log:
Why does this fix the test? I don't know, but there isn't enough time to
fix it properly right now.
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp 2011-12-13 07:23:09 UTC (rev 10737)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp 2011-12-13 08:35:27 UTC (rev 10738)
@@ -174,44 +174,45 @@
* @param neighbors List of neighbors for each point.
* @param distances List of distances for each point.
*/
- void BaseCase(TreeType* referenceNode,
- TreeType* queryNode,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances);
+ void ComputeBaseCase(TreeType* queryNode,
+ TreeType* referenceNode,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances);
/**
* Recurse down the trees, computing base case computations when the leaves
* are reached.
*
+ * @param queryNode Node in query tree.
* @param referenceNode Node in reference tree.
- * @param queryNode Node in query tree.
* @param lowerBound The lower bound; if above this, we can prune.
* @param neighbors List of neighbors for each point.
* @param distances List of distances for each point.
*/
- void DualTreeRecursion(TreeType* referenceNode,
- TreeType* queryNode,
- const double lowerBound,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances);
+ void ComputeDualNeighborsRecursion(TreeType* queryNode,
+ TreeType* referenceNode,
+ const double lowerBound,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances);
/**
* Perform a recursion only on the reference tree; the query point is given.
- * This method is similar to BaseCase().
+ * This method is similar to ComputeBaseCase().
*
+ * @param pointId Index of query point.
+ * @param point The query point.
* @param referenceNode Reference node.
- * @param queryPoint The query point.
- * @param queryIndex Index of query point.
* @param bestDistSoFar Best distance to a node so far -- used for pruning.
* @param neighbors List of neighbors for each point.
* @param distances List of distances for each point.
*/
- void SingleTreeRecursion(TreeType* referenceNode,
- const arma::vec& queryPoint,
- const size_t queryIndex,
- double& bestDistSoFar,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances);
+ template<typename VecType>
+ void ComputeSingleNeighborsRecursion(const size_t pointId,
+ const VecType& point,
+ TreeType* referenceNode,
+ double& bestDistSoFar,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances);
/**
* Insert a point into the neighbors and distances matrices; this is a helper
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp 2011-12-13 07:23:09 UTC (rev 10737)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp 2011-12-13 08:35:27 UTC (rev 10738)
@@ -154,19 +154,35 @@
{
// Run the base case computation on all nodes
if (queryTree)
- BaseCase(referenceTree, queryTree, *neighborPtr, *distancePtr);
+ ComputeBaseCase(queryTree, referenceTree, *neighborPtr, *distancePtr);
else
- BaseCase(referenceTree, referenceTree, *neighborPtr, *distancePtr);
+ ComputeBaseCase(referenceTree, referenceTree, *neighborPtr, *distancePtr);
}
else
{
if (singleMode)
{
- for (size_t i = 0; i < querySet.n_cols; i++)
+ // Do one tenth of the query set at a time.
+ size_t chunk = querySet.n_cols / 10;
+
+ for (size_t i = 0; i < 10; i++)
{
+ for (size_t j = 0; j < chunk; j++)
+ {
+ double worstDistance = SortPolicy::WorstDistance();
+ ComputeSingleNeighborsRecursion(i * chunk + j,
+ querySet.unsafe_col(i * chunk + j), referenceTree, worstDistance,
+ *neighborPtr, *distancePtr);
+ }
+ }
+
+ // The last tenth is differently sized...
+ for (size_t i = 0; i < querySet.n_cols % 10; i++)
+ {
+ size_t ind = (querySet.n_cols / 10) * 10 + i;
double worstDistance = SortPolicy::WorstDistance();
- SingleTreeRecursion(referenceTree, querySet.unsafe_col(i), i,
- worstDistance, *neighborPtr, *distancePtr);
+ ComputeSingleNeighborsRecursion(ind, querySet.unsafe_col(ind),
+ referenceTree, worstDistance, *neighborPtr, *distancePtr);
}
}
else // Dual-tree recursion.
@@ -174,13 +190,13 @@
// Start on the root of each tree.
if (queryTree)
{
- DualTreeRecursion(referenceTree, queryTree,
+ ComputeDualNeighborsRecursion(queryTree, referenceTree,
SortPolicy::BestNodeToNodeDistance(queryTree, referenceTree),
*neighborPtr, *distancePtr);
}
else
{
- DualTreeRecursion(referenceTree, referenceTree,
+ ComputeDualNeighborsRecursion(referenceTree, referenceTree,
SortPolicy::BestNodeToNodeDistance(referenceTree, referenceTree),
*neighborPtr, *distancePtr);
}
@@ -275,13 +291,13 @@
delete neighborPtr;
delete distancePtr;
}
-} // ComputeNeighbors
+} // Search
/**
* Performs exhaustive computation between two leaves.
*/
template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::BaseCase(
+void NeighborSearch<SortPolicy, MetricType, TreeType>::ComputeBaseCase(
TreeType* queryNode,
TreeType* referenceNode,
arma::Mat<size_t>& neighbors,
@@ -338,15 +354,16 @@
// Update the upper bound for the queryNode
queryNode->Stat().Bound() = queryWorstDistance;
-} // BaseCase()
+} // ComputeBaseCase()
/**
* The recursive function for dual tree.
*/
template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::DualTreeRecursion(
+void NeighborSearch<SortPolicy, MetricType, TreeType>::
+ComputeDualNeighborsRecursion(
+ TreeType* queryNode,
TreeType* referenceNode,
- TreeType* queryNode,
const double lowerBound,
arma::Mat<size_t>& neighbors,
arma::mat& distances)
@@ -360,7 +377,7 @@
if (queryNode->IsLeaf() && referenceNode->IsLeaf())
{
// Base case: both are leaves.
- BaseCase(referenceNode, queryNode, neighbors, distances);
+ ComputeBaseCase(queryNode, referenceNode, neighbors, distances);
return;
}
@@ -377,16 +394,16 @@
if (SortPolicy::IsBetter(leftDistance, rightDistance))
{
- DualTreeRecursion(referenceNode->Left(), queryNode,
+ ComputeDualNeighborsRecursion(queryNode, referenceNode->Left(),
leftDistance, neighbors, distances);
- DualTreeRecursion(referenceNode->Right(), queryNode,
+ ComputeDualNeighborsRecursion(queryNode, referenceNode->Right(),
rightDistance, neighbors, distances);
}
else
{
- DualTreeRecursion(referenceNode->Right(), queryNode,
+ ComputeDualNeighborsRecursion(queryNode, referenceNode->Right(),
rightDistance, neighbors, distances);
- DualTreeRecursion(referenceNode->Left(), queryNode,
+ ComputeDualNeighborsRecursion(queryNode, referenceNode->Left(),
leftDistance, neighbors, distances);
}
return;
@@ -400,9 +417,9 @@
double rightDistance = SortPolicy::BestNodeToNodeDistance(
queryNode->Right(), referenceNode);
- DualTreeRecursion(referenceNode, queryNode->Left(),
+ ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode,
leftDistance, neighbors, distances);
- DualTreeRecursion(referenceNode, queryNode->Right(),
+ ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode,
rightDistance, neighbors, distances);
// We need to update the upper bound based on the new upper bounds of the
@@ -428,16 +445,16 @@
// Recurse on queryNode->left() first.
if (SortPolicy::IsBetter(leftDistance, rightDistance))
{
- DualTreeRecursion(referenceNode->Left(), queryNode->Left(),
+ ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode->Left(),
leftDistance, neighbors, distances);
- DualTreeRecursion(referenceNode->Right(), queryNode->Left(),
+ ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode->Right(),
rightDistance, neighbors, distances);
}
else
{
- DualTreeRecursion(referenceNode->Right(), queryNode->Left(),
+ ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode->Right(),
rightDistance, neighbors, distances);
- DualTreeRecursion(referenceNode->Left(), queryNode->Left(),
+ ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode->Left(),
leftDistance, neighbors, distances);
}
@@ -449,16 +466,16 @@
// Now recurse on queryNode->right().
if (SortPolicy::IsBetter(leftDistance, rightDistance))
{
- DualTreeRecursion(referenceNode->Left(), queryNode->Right(),
+ ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode->Left(),
leftDistance, neighbors, distances);
- DualTreeRecursion(referenceNode->Right(), queryNode->Right(),
+ ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode->Right(),
rightDistance, neighbors, distances);
}
else
{
- DualTreeRecursion(referenceNode->Right(), queryNode->Right(),
+ ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode->Right(),
rightDistance, neighbors, distances);
- DualTreeRecursion(referenceNode->Left(), queryNode->Right(),
+ ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode->Left(),
leftDistance, neighbors, distances);
}
@@ -474,13 +491,14 @@
} // ComputeDualNeighborsRecursion()
template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::SingleTreeRecursion(
- TreeType* referenceNode,
- const arma::vec& queryPoint,
- const size_t queryIndex,
- double& bestDistSoFar,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances)
+template<typename VecType>
+void NeighborSearch<SortPolicy, MetricType, TreeType>::
+ComputeSingleNeighborsRecursion(const size_t pointId,
+ const VecType& point,
+ TreeType* referenceNode,
+ double& bestDistSoFar,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances)
{
if (referenceNode->IsLeaf())
{
@@ -490,31 +508,32 @@
{
// Confirm that points do not identify themselves as neighbors
// in the monochromatic case
- if (queryTree || (referenceIndex != queryIndex))
+ if (!(referenceSet.memptr() == querySet.memptr() &&
+ referenceIndex == pointId))
{
arma::vec referencePoint = referenceSet.unsafe_col(referenceIndex);
- double distance = metric.Evaluate(queryPoint, referencePoint);
+ double distance = metric.Evaluate(point, referencePoint);
// If the reference point is better than any of the current candidates,
// insert it into the list correctly.
- arma::vec queryDist = distances.unsafe_col(queryIndex);
+ arma::vec queryDist = distances.unsafe_col(pointId);
size_t insertPosition = SortPolicy::SortDistance(queryDist, distance);
if (insertPosition != (size_t() - 1))
- InsertNeighbor(queryIndex, insertPosition, referenceIndex, distance,
+ InsertNeighbor(pointId, insertPosition, referenceIndex, distance,
neighbors, distances);
}
} // for referenceIndex
- bestDistSoFar = distances(distances.n_rows - 1, queryIndex);
+ bestDistSoFar = distances(distances.n_rows - 1, pointId);
}
else
{
// We'll order the computation by distance.
- double leftDistance = SortPolicy::BestPointToNodeDistance(queryPoint,
+ double leftDistance = SortPolicy::BestPointToNodeDistance(point,
referenceNode->Left());
- double rightDistance = SortPolicy::BestPointToNodeDistance(queryPoint,
+ double rightDistance = SortPolicy::BestPointToNodeDistance(point,
referenceNode->Right());
// Recurse in the best direction first.
@@ -523,13 +542,13 @@
if (SortPolicy::IsBetter(bestDistSoFar, leftDistance))
numberOfPrunes++; // Prune; no possibility of finding a better point.
else
- SingleTreeRecursion(referenceNode->Left(), queryPoint, queryIndex,
+ ComputeSingleNeighborsRecursion(pointId, point, referenceNode->Left(),
bestDistSoFar, neighbors, distances);
if (SortPolicy::IsBetter(bestDistSoFar, rightDistance))
numberOfPrunes++; // Prune; no possibility of finding a better point.
else
- SingleTreeRecursion(referenceNode->Right(), queryPoint, queryIndex,
+ ComputeSingleNeighborsRecursion(pointId, point, referenceNode->Right(),
bestDistSoFar, neighbors, distances);
}
@@ -538,13 +557,13 @@
if (SortPolicy::IsBetter(bestDistSoFar, rightDistance))
numberOfPrunes++; // Prune; no possibility of finding a better point.
else
- SingleTreeRecursion(referenceNode->Right(), queryPoint, queryIndex,
+ ComputeSingleNeighborsRecursion(pointId, point, referenceNode->Right(),
bestDistSoFar, neighbors, distances);
if (SortPolicy::IsBetter(bestDistSoFar, leftDistance))
numberOfPrunes++; // Prune; no possibility of finding a better point.
else
- SingleTreeRecursion(referenceNode->Left(), queryPoint, queryIndex,
+ ComputeSingleNeighborsRecursion(pointId, point, referenceNode->Left(),
bestDistSoFar, neighbors, distances);
}
}
@@ -559,13 +578,13 @@
* @param distance Distance from query point to reference point.
*/
template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::InsertNeighbor(
- const size_t queryIndex,
- const size_t pos,
- const size_t neighbor,
- const double distance,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances)
+void NeighborSearch<SortPolicy, MetricType, TreeType>::
+InsertNeighbor(const size_t queryIndex,
+ const size_t pos,
+ const size_t neighbor,
+ const double distance,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances)
{
// We only memmove() if there is actually a need to shift something.
if (pos < (distances.n_rows - 1))
More information about the mlpack-svn
mailing list