[mlpack-svn] r10735 - mlpack/trunk/src/mlpack/methods/neighbor_search
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Dec 12 17:54:17 EST 2011
Author: rcurtin
Date: 2011-12-12 17:54:17 -0500 (Mon, 12 Dec 2011)
New Revision: 10735
Modified:
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
Log:
Retry the API change. This time it works...
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp 2011-12-12 21:48:51 UTC (rev 10734)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp 2011-12-12 22:54:17 UTC (rev 10735)
@@ -97,8 +97,8 @@
* @param queryTree Optionally pass a pre-built tree for the query set.
* @param metric An optional instance of the MetricType class.
*/
- NeighborSearch(const typename TreeType::Mat& referenceSet,
- const typename TreeType::Mat& querySet,
+ NeighborSearch(const arma::mat& referenceSet,
+ const arma::mat& querySet,
const bool naive = false,
const bool singleMode = false,
const size_t leafSize = 20,
@@ -131,7 +131,7 @@
* set.
* @param metric An optional instance of the MetricType class.
*/
- NeighborSearch(const typename TreeType::Mat& referenceSet,
+ NeighborSearch(const arma::mat& referenceSet,
const bool naive = false,
const bool singleMode = false,
const size_t leafSize = 20,
@@ -183,8 +183,8 @@
* Recurse down the trees, computing base case computations when the leaves
* are reached.
*
- * @param queryNode Node in query tree.
* @param referenceNode Node in reference tree.
+ * @param queryNode Node in query tree.
* @param lowerBound The lower bound; if above this, we can prune.
* @param neighbors List of neighbors for each point.
* @param distances List of distances for each point.
@@ -197,18 +197,17 @@
/**
* Perform a recursion only on the reference tree; the query point is given.
- * This method is similar to ComputeBaseCase().
+ * This method is similar to BaseCase().
*
- * @param pointId Index of query point.
- * @param point The query point.
* @param referenceNode Reference node.
+ * @param queryPoint The query point.
+ * @param queryIndex Index of query point.
* @param bestDistSoFar Best distance to a node so far -- used for pruning.
* @param neighbors List of neighbors for each point.
* @param distances List of distances for each point.
*/
- template<typename VecType>
void SingleTreeRecursion(TreeType* referenceNode,
- const VecType& queryPoint,
+ const arma::vec& queryPoint,
const size_t queryIndex,
double& bestDistSoFar,
arma::Mat<size_t>& neighbors,
@@ -234,14 +233,14 @@
//! Copy of reference dataset (if we need it, because tree building modifies
//! it).
- typename TreeType::Mat referenceCopy;
+ arma::mat referenceCopy;
//! Copy of query dataset (if we need it, because tree building modifies it).
- typename TreeType::Mat queryCopy;
+ arma::mat queryCopy;
- //! Reference dataset (data should be accessed using this).
- const typename TreeType::Mat& referenceSet;
- //! Query dataset (data should be accessed using this).
- const typename TreeType::Mat& querySet;
+ //! Reference dataset.
+ const arma::mat& referenceSet;
+ //! Query dataset (may not be given).
+ const arma::mat& querySet;
//! Indicates if O(n^2) naive search is being used.
bool naive;
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp 2011-12-12 21:48:51 UTC (rev 10734)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp 2011-12-12 22:54:17 UTC (rev 10735)
@@ -15,8 +15,8 @@
// Construct the object.
template<typename SortPolicy, typename MetricType, typename TreeType>
NeighborSearch<SortPolicy, MetricType, TreeType>::
-NeighborSearch(const typename TreeType::Mat& referenceSet,
- const typename TreeType::Mat& querySet,
+NeighborSearch(const arma::mat& referenceSet,
+ const arma::mat& querySet,
const bool naive,
const bool singleMode,
const size_t leafSize,
@@ -72,7 +72,7 @@
// Construct the object.
template<typename SortPolicy, typename MetricType, typename TreeType>
NeighborSearch<SortPolicy, MetricType, TreeType>::
-NeighborSearch(const typename TreeType::Mat& referenceSet,
+NeighborSearch(const arma::mat& referenceSet,
const bool naive,
const bool singleMode,
const size_t leafSize,
@@ -162,12 +162,11 @@
{
if (singleMode)
{
- // Loop over each point in the query set.
for (size_t i = 0; i < querySet.n_cols; i++)
{
double worstDistance = SortPolicy::WorstDistance();
- SingleTreeRecursion(referenceTree, querySet.col(i), i, worstDistance,
- *neighborPtr, *distancePtr);
+ SingleTreeRecursion(referenceTree, querySet.unsafe_col(i), i,
+ worstDistance, *neighborPtr, *distancePtr);
}
}
else // Dual-tree recursion.
@@ -175,7 +174,7 @@
// Start on the root of each tree.
if (queryTree)
{
- DualTreeRecursion(queryTree, referenceTree,
+ DualTreeRecursion(referenceTree, queryTree,
SortPolicy::BestNodeToNodeDistance(queryTree, referenceTree),
*neighborPtr, *distancePtr);
}
@@ -283,8 +282,8 @@
*/
template<typename SortPolicy, typename MetricType, typename TreeType>
void NeighborSearch<SortPolicy, MetricType, TreeType>::BaseCase(
+ TreeType* queryNode,
TreeType* referenceNode,
- TreeType* queryNode,
arma::Mat<size_t>& neighbors,
arma::mat& distances)
{
@@ -296,14 +295,16 @@
for (size_t queryIndex = queryNode->Begin(); queryIndex < queryNode->End();
queryIndex++)
{
- // Get the best possible distance from the query point to the node.
+ // Get the query point from the matrix.
+ arma::vec queryPoint = querySet.unsafe_col(queryIndex);
+
double queryToNodeDistance =
- SortPolicy::BestPointToNodeDistance(querySet.col(queryIndex),
- referenceNode);
+ SortPolicy::BestPointToNodeDistance(queryPoint, referenceNode);
if (SortPolicy::IsBetter(queryToNodeDistance,
distances(distances.n_rows - 1, queryIndex)))
{
+ // We'll do the same for the references.
for (size_t referenceIndex = referenceNode->Begin();
referenceIndex < referenceNode->End(); referenceIndex++)
{
@@ -311,9 +312,10 @@
// in the monochromatic case.
if (referenceNode != queryNode || referenceIndex != queryIndex)
{
- double distance = metric.Evaluate(querySet.col(queryIndex),
- referenceSet.col(referenceIndex));
+ arma::vec referencePoint = referenceSet.unsafe_col(referenceIndex);
+ double distance = metric.Evaluate(queryPoint, referencePoint);
+
// If the reference point is closer than any of the current
// candidates, add it to the list.
arma::vec queryDist = distances.unsafe_col(queryIndex);
@@ -336,7 +338,7 @@
// Update the upper bound for the queryNode
queryNode->Stat().Bound() = queryWorstDistance;
-} // ComputeBaseCase()
+} // BaseCase()
/**
* The recursive function for dual tree.
@@ -375,17 +377,17 @@
if (SortPolicy::IsBetter(leftDistance, rightDistance))
{
- DualTreeRecursion(referenceNode->Left(), queryNode, leftDistance,
- neighbors, distances);
- DualTreeRecursion(referenceNode->Right(), queryNode, rightDistance,
- neighbors, distances);
+ DualTreeRecursion(referenceNode->Left(), queryNode,
+ leftDistance, neighbors, distances);
+ DualTreeRecursion(referenceNode->Right(), queryNode,
+ rightDistance, neighbors, distances);
}
else
{
- DualTreeRecursion(referenceNode->Right(), queryNode, rightDistance,
- neighbors, distances);
- DualTreeRecursion(referenceNode->Left(), queryNode, leftDistance,
- neighbors, distances);
+ DualTreeRecursion(referenceNode->Right(), queryNode,
+ rightDistance, neighbors, distances);
+ DualTreeRecursion(referenceNode->Left(), queryNode,
+ leftDistance, neighbors, distances);
}
return;
}
@@ -398,10 +400,10 @@
double rightDistance = SortPolicy::BestNodeToNodeDistance(
queryNode->Right(), referenceNode);
- DualTreeRecursion(referenceNode, queryNode->Left(), leftDistance,
- neighbors, distances);
- DualTreeRecursion(referenceNode, queryNode->Right(), rightDistance,
- neighbors, distances);
+ DualTreeRecursion(referenceNode, queryNode->Left(),
+ leftDistance, neighbors, distances);
+ DualTreeRecursion(referenceNode, queryNode->Right(),
+ rightDistance, neighbors, distances);
// We need to update the upper bound based on the new upper bounds of the
// children.
@@ -426,17 +428,17 @@
// Recurse on queryNode->left() first.
if (SortPolicy::IsBetter(leftDistance, rightDistance))
{
- DualTreeRecursion(referenceNode->Left(), queryNode->Left(), leftDistance,
- neighbors, distances);
- DualTreeRecursion(referenceNode->Left(), queryNode->Right(), rightDistance,
- neighbors, distances);
+ DualTreeRecursion(referenceNode->Left(), queryNode->Left(),
+ leftDistance, neighbors, distances);
+ DualTreeRecursion(referenceNode->Right(), queryNode->Left(),
+ rightDistance, neighbors, distances);
}
else
{
- DualTreeRecursion(referenceNode->Left(), queryNode->Right(), rightDistance,
- neighbors, distances);
- DualTreeRecursion(referenceNode->Left(), queryNode->Left(), leftDistance,
- neighbors, distances);
+ DualTreeRecursion(referenceNode->Right(), queryNode->Left(),
+ rightDistance, neighbors, distances);
+ DualTreeRecursion(referenceNode->Left(), queryNode->Left(),
+ leftDistance, neighbors, distances);
}
leftDistance = SortPolicy::BestNodeToNodeDistance(queryNode->Right(),
@@ -447,17 +449,17 @@
// Now recurse on queryNode->right().
if (SortPolicy::IsBetter(leftDistance, rightDistance))
{
- DualTreeRecursion(referenceNode->Right(), queryNode->Left(), leftDistance,
- neighbors, distances);
- DualTreeRecursion(referenceNode->Right(), queryNode->Right(), rightDistance,
- neighbors, distances);
+ DualTreeRecursion(referenceNode->Left(), queryNode->Right(),
+ leftDistance, neighbors, distances);
+ DualTreeRecursion(referenceNode->Right(), queryNode->Right(),
+ rightDistance, neighbors, distances);
}
else
{
- DualTreeRecursion(referenceNode->Right(), queryNode->Right(), rightDistance,
- neighbors, distances);
- DualTreeRecursion(referenceNode->Right(), queryNode->Left(), leftDistance,
- neighbors, distances);
+ DualTreeRecursion(referenceNode->Right(), queryNode->Right(),
+ rightDistance, neighbors, distances);
+ DualTreeRecursion(referenceNode->Left(), queryNode->Right(),
+ leftDistance, neighbors, distances);
}
// Update the upper bound as above
@@ -472,10 +474,9 @@
} // ComputeDualNeighborsRecursion()
template<typename SortPolicy, typename MetricType, typename TreeType>
-template<typename VecType>
void NeighborSearch<SortPolicy, MetricType, TreeType>::SingleTreeRecursion(
TreeType* referenceNode,
- const VecType& queryPoint,
+ const arma::vec& queryPoint,
const size_t queryIndex,
double& bestDistSoFar,
arma::Mat<size_t>& neighbors,
@@ -488,12 +489,13 @@
referenceIndex < referenceNode->End(); referenceIndex++)
{
// Confirm that points do not identify themselves as neighbors
- // in the monochromatic case.
- if (!queryTree && !(referenceIndex == queryIndex))
+ // in the monochromatic case
+ if (queryTree || (referenceIndex != queryIndex))
{
- double distance = metric.Evaluate(queryPoint,
- referenceSet.col(referenceIndex));
+ arma::vec referencePoint = referenceSet.unsafe_col(referenceIndex);
+ double distance = metric.Evaluate(queryPoint, referencePoint);
+
// If the reference point is better than any of the current candidates,
// insert it into the list correctly.
arma::vec queryDist = distances.unsafe_col(queryIndex);
More information about the mlpack-svn mailing list