[mlpack-svn] r10720 - mlpack/trunk/src/mlpack/methods/neighbor_search
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Dec 12 11:43:39 EST 2011
Author: rcurtin
Date: 2011-12-12 11:43:39 -0500 (Mon, 12 Dec 2011)
New Revision: 10720
Modified:
mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp
mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
Log:
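Refactor API to be consistent with RangeSearch: rename ComputeNeighbors() to Search(); rename the internal helpers to BaseCase(), DualTreeRecursion(), and SingleTreeRecursion(), which now take the reference node before the query node; and accept datasets as TreeType::Mat.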
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp 2011-12-12 16:12:00 UTC (rev 10719)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp 2011-12-12 16:43:39 UTC (rev 10720)
@@ -151,7 +151,7 @@
}
Log::Info << "Computing " << k << " nearest neighbors..." << endl;
- allkfn->ComputeNeighbors(k, neighbors, distances);
+ allkfn->Search(k, neighbors, distances);
Log::Info << "Neighbors computed." << endl;
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp 2011-12-12 16:12:00 UTC (rev 10719)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp 2011-12-12 16:43:39 UTC (rev 10720)
@@ -153,7 +153,7 @@
}
Log::Info << "Computing " << k << " nearest neighbors..." << endl;
- allknn->ComputeNeighbors(k, neighbors, distances);
+ allknn->Search(k, neighbors, distances);
Log::Info << "Neighbors computed." << endl;
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp 2011-12-12 16:12:00 UTC (rev 10719)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp 2011-12-12 16:43:39 UTC (rev 10720)
@@ -97,8 +97,8 @@
* @param queryTree Optionally pass a pre-built tree for the query set.
* @param metric An optional instance of the MetricType class.
*/
- NeighborSearch(const arma::mat& referenceSet,
- const arma::mat& querySet,
+ NeighborSearch(const typename TreeType::Mat& referenceSet,
+ const typename TreeType::Mat& querySet,
const bool naive = false,
const bool singleMode = false,
const size_t leafSize = 20,
@@ -131,7 +131,7 @@
* set.
* @param metric An optional instance of the MetricType class.
*/
- NeighborSearch(const arma::mat& referenceSet,
+ NeighborSearch(const typename TreeType::Mat& referenceSet,
const bool naive = false,
const bool singleMode = false,
const size_t leafSize = 20,
@@ -156,9 +156,9 @@
* @param distances Matrix storing distances of neighbors for each query
* point.
*/
- void ComputeNeighbors(const size_t k,
- arma::Mat<size_t>& resultingNeighbors,
- arma::mat& distances);
+ void Search(const size_t k,
+ arma::Mat<size_t>& resultingNeighbors,
+ arma::mat& distances);
private:
/**
@@ -174,10 +174,10 @@
* @param neighbors List of neighbors for each point.
* @param distances List of distances for each point.
*/
- void ComputeBaseCase(TreeType* queryNode,
- TreeType* referenceNode,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances);
+ void BaseCase(TreeType* referenceNode,
+ TreeType* queryNode,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances);
/**
* Recurse down the trees, computing base case computations when the leaves
@@ -189,11 +189,11 @@
* @param neighbors List of neighbors for each point.
* @param distances List of distances for each point.
*/
- void ComputeDualNeighborsRecursion(TreeType* queryNode,
- TreeType* referenceNode,
- const double lowerBound,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances);
+ void DualTreeRecursion(TreeType* referenceNode,
+ TreeType* queryNode,
+ const double lowerBound,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances);
/**
* Perform a recursion only on the reference tree; the query point is given.
@@ -206,12 +206,13 @@
* @param neighbors List of neighbors for each point.
* @param distances List of distances for each point.
*/
- void ComputeSingleNeighborsRecursion(const size_t pointId,
- const arma::vec& point,
- TreeType* referenceNode,
- double& bestDistSoFar,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances);
+ template<typename VecType>
+ void SingleTreeRecursion(TreeType* referenceNode,
+ const VecType& queryPoint,
+ const size_t queryIndex,
+ double& bestDistSoFar,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances);
/**
* Insert a point into the neighbors and distances matrices; this is a helper
@@ -233,14 +234,14 @@
//! Copy of reference dataset (if we need it, because tree building modifies
//! it).
- arma::mat referenceCopy;
+ typename TreeType::Mat referenceCopy;
//! Copy of query dataset (if we need it, because tree building modifies it).
- arma::mat queryCopy;
+ typename TreeType::Mat queryCopy;
- //! Reference dataset.
- const arma::mat& referenceSet;
- //! Query dataset (may not be given).
- const arma::mat& querySet;
+ //! Reference dataset (data should be accessed using this).
+ const typename TreeType::Mat& referenceSet;
+ //! Query dataset (data should be accessed using this).
+ const typename TreeType::Mat& querySet;
//! Indicates if O(n^2) naive search is being used.
bool naive;
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp 2011-12-12 16:12:00 UTC (rev 10719)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp 2011-12-12 16:43:39 UTC (rev 10720)
@@ -15,8 +15,8 @@
// Construct the object.
template<typename SortPolicy, typename MetricType, typename TreeType>
NeighborSearch<SortPolicy, MetricType, TreeType>::
-NeighborSearch(const arma::mat& referenceSet,
- const arma::mat& querySet,
+NeighborSearch(const typename TreeType::Mat& referenceSet,
+ const typename TreeType::Mat& querySet,
const bool naive,
const bool singleMode,
const size_t leafSize,
@@ -72,7 +72,7 @@
// Construct the object.
template<typename SortPolicy, typename MetricType, typename TreeType>
NeighborSearch<SortPolicy, MetricType, TreeType>::
-NeighborSearch(const arma::mat& referenceSet,
+NeighborSearch(const typename TreeType::Mat& referenceSet,
const bool naive,
const bool singleMode,
const size_t leafSize,
@@ -126,7 +126,7 @@
* distances.
*/
template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::ComputeNeighbors(
+void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
const size_t k,
arma::Mat<size_t>& resultingNeighbors,
arma::mat& distances)
@@ -154,35 +154,20 @@
{
// Run the base case computation on all nodes
if (queryTree)
- ComputeBaseCase(queryTree, referenceTree, *neighborPtr, *distancePtr);
+ BaseCase(referenceTree, queryTree, *neighborPtr, *distancePtr);
else
- ComputeBaseCase(referenceTree, referenceTree, *neighborPtr, *distancePtr);
+ BaseCase(referenceTree, referenceTree, *neighborPtr, *distancePtr);
}
else
{
if (singleMode)
{
- // Do one tenth of the query set at a time.
- size_t chunk = querySet.n_cols / 10;
-
- for (size_t i = 0; i < 10; i++)
+ // Loop over each point in the query set.
+ for (size_t i = 0; i < querySet.n_cols; i++)
{
- for (size_t j = 0; j < chunk; j++)
- {
- double worstDistance = SortPolicy::WorstDistance();
- ComputeSingleNeighborsRecursion(i * chunk + j,
- querySet.unsafe_col(i * chunk + j), referenceTree, worstDistance,
- *neighborPtr, *distancePtr);
- }
- }
-
- // The last tenth is differently sized...
- for (size_t i = 0; i < querySet.n_cols % 10; i++)
- {
- size_t ind = (querySet.n_cols / 10) * 10 + i;
double worstDistance = SortPolicy::WorstDistance();
- ComputeSingleNeighborsRecursion(ind, querySet.unsafe_col(ind),
- referenceTree, worstDistance, *neighborPtr, *distancePtr);
+ SingleTreeRecursion(referenceTree, querySet.col(i), i, worstDistance,
+ *neighborPtr, *distancePtr);
}
}
else // Dual-tree recursion.
@@ -190,13 +175,13 @@
// Start on the root of each tree.
if (queryTree)
{
- ComputeDualNeighborsRecursion(queryTree, referenceTree,
+ DualTreeRecursion(referenceTree, queryTree,
SortPolicy::BestNodeToNodeDistance(queryTree, referenceTree),
*neighborPtr, *distancePtr);
}
else
{
- ComputeDualNeighborsRecursion(referenceTree, referenceTree,
+ DualTreeRecursion(referenceTree, referenceTree,
SortPolicy::BestNodeToNodeDistance(referenceTree, referenceTree),
*neighborPtr, *distancePtr);
}
@@ -297,9 +282,9 @@
* Performs exhaustive computation between two leaves.
*/
template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::ComputeBaseCase(
- TreeType* queryNode,
+void NeighborSearch<SortPolicy, MetricType, TreeType>::BaseCase(
TreeType* referenceNode,
+ TreeType* queryNode,
arma::Mat<size_t>& neighbors,
arma::mat& distances)
{
@@ -311,16 +296,14 @@
for (size_t queryIndex = queryNode->Begin(); queryIndex < queryNode->End();
queryIndex++)
{
- // Get the query point from the matrix.
- arma::vec queryPoint = querySet.unsafe_col(queryIndex);
-
+ // Get the best possible distance from the query point to the node.
double queryToNodeDistance =
- SortPolicy::BestPointToNodeDistance(queryPoint, referenceNode);
+ SortPolicy::BestPointToNodeDistance(querySet.col(queryIndex),
+ referenceNode);
if (SortPolicy::IsBetter(queryToNodeDistance,
distances(distances.n_rows - 1, queryIndex)))
{
- // We'll do the same for the references.
for (size_t referenceIndex = referenceNode->Begin();
referenceIndex < referenceNode->End(); referenceIndex++)
{
@@ -328,10 +311,9 @@
// in the monochromatic case.
if (referenceNode != queryNode || referenceIndex != queryIndex)
{
- arma::vec referencePoint = referenceSet.unsafe_col(referenceIndex);
+ double distance = metric.Evaluate(querySet.col(queryIndex),
+ referenceSet.col(referenceIndex));
- double distance = metric.Evaluate(queryPoint, referencePoint);
-
// If the reference point is closer than any of the current
// candidates, add it to the list.
arma::vec queryDist = distances.unsafe_col(queryIndex);
@@ -360,10 +342,9 @@
* The recursive function for dual tree.
*/
template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::
-ComputeDualNeighborsRecursion(
- TreeType* queryNode,
+void NeighborSearch<SortPolicy, MetricType, TreeType>::DualTreeRecursion(
TreeType* referenceNode,
+ TreeType* queryNode,
const double lowerBound,
arma::Mat<size_t>& neighbors,
arma::mat& distances)
@@ -377,7 +358,7 @@
if (queryNode->IsLeaf() && referenceNode->IsLeaf())
{
// Base case: both are leaves.
- ComputeBaseCase(queryNode, referenceNode, neighbors, distances);
+ BaseCase(referenceNode, queryNode, neighbors, distances);
return;
}
@@ -394,17 +375,17 @@
if (SortPolicy::IsBetter(leftDistance, rightDistance))
{
- ComputeDualNeighborsRecursion(queryNode, referenceNode->Left(),
- leftDistance, neighbors, distances);
- ComputeDualNeighborsRecursion(queryNode, referenceNode->Right(),
- rightDistance, neighbors, distances);
+ DualTreeRecursion(referenceNode->Left(), queryNode, leftDistance,
+ neighbors, distances);
+ DualTreeRecursion(referenceNode->Right(), queryNode, rightDistance,
+ neighbors, distances);
}
else
{
- ComputeDualNeighborsRecursion(queryNode, referenceNode->Right(),
- rightDistance, neighbors, distances);
- ComputeDualNeighborsRecursion(queryNode, referenceNode->Left(),
- leftDistance, neighbors, distances);
+ DualTreeRecursion(referenceNode->Right(), queryNode, rightDistance,
+ neighbors, distances);
+ DualTreeRecursion(referenceNode->Left(), queryNode, leftDistance,
+ neighbors, distances);
}
return;
}
@@ -417,10 +398,10 @@
double rightDistance = SortPolicy::BestNodeToNodeDistance(
queryNode->Right(), referenceNode);
- ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode,
- leftDistance, neighbors, distances);
- ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode,
- rightDistance, neighbors, distances);
+ DualTreeRecursion(referenceNode, queryNode->Left(), leftDistance,
+ neighbors, distances);
+ DualTreeRecursion(referenceNode, queryNode->Right(), rightDistance,
+ neighbors, distances);
// We need to update the upper bound based on the new upper bounds of the
// children.
@@ -445,17 +426,17 @@
// Recurse on queryNode->left() first.
if (SortPolicy::IsBetter(leftDistance, rightDistance))
{
- ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode->Left(),
- leftDistance, neighbors, distances);
- ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode->Right(),
- rightDistance, neighbors, distances);
+ DualTreeRecursion(referenceNode->Left(), queryNode->Left(), leftDistance,
+ neighbors, distances);
+ DualTreeRecursion(referenceNode->Right(), queryNode->Left(), rightDistance,
+ neighbors, distances);
}
else
{
- ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode->Right(),
- rightDistance, neighbors, distances);
- ComputeDualNeighborsRecursion(queryNode->Left(), referenceNode->Left(),
- leftDistance, neighbors, distances);
+ DualTreeRecursion(referenceNode->Right(), queryNode->Left(), rightDistance,
+ neighbors, distances);
+ DualTreeRecursion(referenceNode->Left(), queryNode->Left(), leftDistance,
+ neighbors, distances);
}
leftDistance = SortPolicy::BestNodeToNodeDistance(queryNode->Right(),
@@ -466,17 +447,17 @@
// Now recurse on queryNode->right().
if (SortPolicy::IsBetter(leftDistance, rightDistance))
{
- ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode->Left(),
- leftDistance, neighbors, distances);
- ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode->Right(),
- rightDistance, neighbors, distances);
+ DualTreeRecursion(referenceNode->Left(), queryNode->Right(), leftDistance,
+ neighbors, distances);
+ DualTreeRecursion(referenceNode->Right(), queryNode->Right(), rightDistance,
+ neighbors, distances);
}
else
{
- ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode->Right(),
- rightDistance, neighbors, distances);
- ComputeDualNeighborsRecursion(queryNode->Right(), referenceNode->Left(),
- leftDistance, neighbors, distances);
+ DualTreeRecursion(referenceNode->Right(), queryNode->Right(), rightDistance,
+ neighbors, distances);
+ DualTreeRecursion(referenceNode->Left(), queryNode->Right(), leftDistance,
+ neighbors, distances);
}
// Update the upper bound as above
@@ -491,13 +472,14 @@
} // ComputeDualNeighborsRecursion()
template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::
-ComputeSingleNeighborsRecursion(const size_t pointId,
- const arma::vec& point,
- TreeType* referenceNode,
- double& bestDistSoFar,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances)
+template<typename VecType>
+void NeighborSearch<SortPolicy, MetricType, TreeType>::SingleTreeRecursion(
+ TreeType* referenceNode,
+ const VecType& queryPoint,
+ const size_t queryIndex,
+ double& bestDistSoFar,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances)
{
if (referenceNode->IsLeaf())
{
@@ -506,33 +488,31 @@
referenceIndex < referenceNode->End(); referenceIndex++)
{
// Confirm that points do not identify themselves as neighbors
- // in the monochromatic case
- if (!(referenceSet.memptr() == querySet.memptr() &&
- referenceIndex == pointId))
+ // in the monochromatic case (no query tree); bichromatic always evaluates.
+ if (queryTree || referenceIndex != queryIndex)
{
- arma::vec referencePoint = referenceSet.unsafe_col(referenceIndex);
+ double distance = metric.Evaluate(queryPoint,
+ referenceSet.col(referenceIndex));
- double distance = metric.Evaluate(point, referencePoint);
-
// If the reference point is better than any of the current candidates,
// insert it into the list correctly.
- arma::vec queryDist = distances.unsafe_col(pointId);
+ arma::vec queryDist = distances.unsafe_col(queryIndex);
size_t insertPosition = SortPolicy::SortDistance(queryDist, distance);
if (insertPosition != (size_t() - 1))
- InsertNeighbor(pointId, insertPosition, referenceIndex, distance,
+ InsertNeighbor(queryIndex, insertPosition, referenceIndex, distance,
neighbors, distances);
}
} // for referenceIndex
- bestDistSoFar = distances(distances.n_rows - 1, pointId);
+ bestDistSoFar = distances(distances.n_rows - 1, queryIndex);
}
else
{
// We'll order the computation by distance.
- double leftDistance = SortPolicy::BestPointToNodeDistance(point,
+ double leftDistance = SortPolicy::BestPointToNodeDistance(queryPoint,
referenceNode->Left());
- double rightDistance = SortPolicy::BestPointToNodeDistance(point,
+ double rightDistance = SortPolicy::BestPointToNodeDistance(queryPoint,
referenceNode->Right());
// Recurse in the best direction first.
@@ -541,13 +521,13 @@
if (SortPolicy::IsBetter(bestDistSoFar, leftDistance))
numberOfPrunes++; // Prune; no possibility of finding a better point.
else
- ComputeSingleNeighborsRecursion(pointId, point, referenceNode->Left(),
+ SingleTreeRecursion(referenceNode->Left(), queryPoint, queryIndex,
bestDistSoFar, neighbors, distances);
if (SortPolicy::IsBetter(bestDistSoFar, rightDistance))
numberOfPrunes++; // Prune; no possibility of finding a better point.
else
- ComputeSingleNeighborsRecursion(pointId, point, referenceNode->Right(),
+ SingleTreeRecursion(referenceNode->Right(), queryPoint, queryIndex,
bestDistSoFar, neighbors, distances);
}
@@ -556,13 +536,13 @@
if (SortPolicy::IsBetter(bestDistSoFar, rightDistance))
numberOfPrunes++; // Prune; no possibility of finding a better point.
else
- ComputeSingleNeighborsRecursion(pointId, point, referenceNode->Right(),
+ SingleTreeRecursion(referenceNode->Right(), queryPoint, queryIndex,
bestDistSoFar, neighbors, distances);
if (SortPolicy::IsBetter(bestDistSoFar, leftDistance))
numberOfPrunes++; // Prune; no possibility of finding a better point.
else
- ComputeSingleNeighborsRecursion(pointId, point, referenceNode->Left(),
+ SingleTreeRecursion(referenceNode->Left(), queryPoint, queryIndex,
bestDistSoFar, neighbors, distances);
}
}
@@ -577,13 +557,13 @@
* @param distance Distance from query point to reference point.
*/
template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearch<SortPolicy, MetricType, TreeType>::
-InsertNeighbor(const size_t queryIndex,
- const size_t pos,
- const size_t neighbor,
- const double distance,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances)
+void NeighborSearch<SortPolicy, MetricType, TreeType>::InsertNeighbor(
+ const size_t queryIndex,
+ const size_t pos,
+ const size_t neighbor,
+ const double distance,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances)
{
// We only memmove() if there is actually a need to shift something.
if (pos < (distances.n_rows - 1))
More information about the mlpack-svn
mailing list