[mlpack-git] master, mlpack-1.0.x: Overhaul NeighborSearchRules to work correctly with TraversalInfo objects. This is related to #243. (e57ad28)
gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:42:25 EST 2015
Repository : https://github.com/mlpack/mlpack
On branches: master,mlpack-1.0.x
Link : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40
>---------------------------------------------------------------
commit e57ad282e684ae2802faedaa0e7ca13ffc1d3d61
Author: Ryan Curtin <ryan at ratml.org>
Date: Thu Feb 6 20:17:50 2014 +0000
Overhaul NeighborSearchRules to work correctly with TraversalInfo objects. This
is related to #243.
>---------------------------------------------------------------
e57ad282e684ae2802faedaa0e7ca13ffc1d3d61
.../neighbor_search/neighbor_search_impl.hpp | 6 +-
.../neighbor_search/neighbor_search_rules.hpp | 29 +++
.../neighbor_search/neighbor_search_rules_impl.hpp | 235 ++++++++++++++-------
3 files changed, 186 insertions(+), 84 deletions(-)
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index 0ff34fc..0905d81 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -184,6 +184,7 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
// Set the size of the neighbor and distance matrices.
neighborPtr->set_size(k, querySet.n_cols);
+ neighborPtr->fill(size_t() - 1);
distancePtr->set_size(k, querySet.n_cols);
distancePtr->fill(SortPolicy::WorstDistance());
@@ -210,9 +211,8 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
traverser.Traverse(*queryTree, *referenceTree);
- Log::Info << traverser.NumVisited() << " node combinations were visited.\n";
- Log::Info << traverser.NumScores() << " node combinations were scored.\n";
- Log::Info << traverser.NumBaseCases() << " base cases were calculated.\n";
+ Log::Info << rules.Scores() << " node combinations were scored.\n";
+ Log::Info << rules.BaseCases() << " base cases were calculated.\n";
}
Timer::Stop("computing_neighbors");
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
index 6d24a75..3f46b0d 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
@@ -8,6 +8,8 @@
#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
#define __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
+#include "ns_traversal_info.hpp"
+
namespace mlpack {
namespace neighbor {
@@ -73,6 +75,24 @@ class NeighborSearchRules
TreeType& referenceNode,
const double oldScore) const;
+ //! Get the number of base cases that have been performed.
+ size_t BaseCases() const { return baseCases; }
+ //! Modify the number of base cases that have been performed.
+ size_t& BaseCases() { return baseCases; }
+
+ //! Get the number of scores that have been performed.
+ size_t Scores() const { return scores; }
+ //! Modify the number of scores that have been performed.
+ size_t& Scores() { return scores; }
+
+ //! Convenience typedef.
+ typedef NeighborSearchTraversalInfo<TreeType> TraversalInfoType;
+
+ //! Get the traversal info.
+ const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
+ //! Modify the traversal info.
+ TraversalInfoType& TraversalInfo() { return traversalInfo; }
+
private:
//! The reference set.
const arma::mat& referenceSet;
@@ -96,6 +116,15 @@ class NeighborSearchRules
//! The last base case result.
double lastBaseCase;
+ //! The number of base cases that have been performed.
+ size_t baseCases;
+ //! The number of scores that have been performed.
+ size_t scores;
+
+ //! Traversal info for the parent combination; this is updated by the
+ //! traversal before each call to Score().
+ TraversalInfoType traversalInfo;
+
/**
* Recalculate the bound for a given query node.
*/
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
index fe103ec..031965c 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
@@ -26,8 +26,16 @@ NeighborSearchRules<SortPolicy, MetricType, TreeType>::NeighborSearchRules(
distances(distances),
metric(metric),
lastQueryIndex(querySet.n_cols),
- lastReferenceIndex(referenceSet.n_cols)
-{ /* Nothing left to do. */ }
+ lastReferenceIndex(referenceSet.n_cols),
+ baseCases(0),
+ scores(0)
+{
+ // We must set the traversal info last query and reference node pointers to
+ // something that is both invalid (i.e. not a tree node) and not NULL. We'll
+ // use the this pointer.
+ traversalInfo.LastQueryNode() = (TreeType*) this;
+ traversalInfo.LastReferenceNode() = (TreeType*) this;
+}
template<typename SortPolicy, typename MetricType, typename TreeType>
inline force_inline // Absolutely MUST be inline so optimizations can happen.
@@ -45,11 +53,14 @@ BaseCase(const size_t queryIndex, const size_t referenceIndex)
double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
referenceSet.unsafe_col(referenceIndex));
+ ++baseCases;
// If this distance is better than any of the current candidates, the
// SortDistance() function will give us the position to insert it into.
arma::vec queryDist = distances.unsafe_col(queryIndex);
- const size_t insertPosition = SortPolicy::SortDistance(queryDist, distance);
+ arma::Col<size_t> queryIndices = neighbors.unsafe_col(queryIndex);
+ const size_t insertPosition = SortPolicy::SortDistance(queryDist,
+ queryIndices, distance);
// SortDistance() returns (size_t() - 1) if we shouldn't add it.
if (insertPosition != (size_t() - 1))
@@ -68,6 +79,7 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
const size_t queryIndex,
TreeType& referenceNode)
{
+ ++scores; // Count number of Score() calls.
double distance;
if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
{
@@ -124,97 +136,128 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
TreeType& queryNode,
TreeType& referenceNode)
{
+ ++scores; // Count number of Score() calls.
+
+ // Update our bound.
+ const double bestDistance = CalculateBound(queryNode);
+
+ // Use the traversal info to see if a parent-child or parent-parent prune is
+ // possible. This is a looser bound than we could make, but it might be
+ // sufficient.
+ const double queryParentDist = queryNode.ParentDistance();
+ const double queryDescDist = queryNode.FurthestDescendantDistance();
+ const double refParentDist = referenceNode.ParentDistance();
+ const double refDescDist = referenceNode.FurthestDescendantDistance();
+ const double score = traversalInfo.LastScore();
+ double adjustedScore;
+
+ // In some cases we can just use the last base case.
+ if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
+ {
+ adjustedScore = traversalInfo.LastBaseCase();
+ }
+ else if (score == 0.0) // Nothing we can do here.
+ {
+ adjustedScore = 0.0;
+ }
+ else
+ {
+ const double lastQueryDescDist =
+ traversalInfo.LastQueryNode()->FurthestDescendantDistance();
+ const double lastRefDescDist =
+ traversalInfo.LastReferenceNode()->FurthestDescendantDistance();
+ adjustedScore = SortPolicy::CombineWorst(score, lastQueryDescDist);
+ adjustedScore = SortPolicy::CombineWorst(score, lastRefDescDist);
+ }
+
+ // Assemble an adjusted score. For nearest neighbor search, this adjusted
+ // score is a lower bound on MinDistance(queryNode, referenceNode) that is
+ // assembled without actually calculating MinDistance(). For furthest
+ // neighbor search, it is an upper bound on
+ // MaxDistance(queryNode, referenceNode). If the traversalInfo isn't usable
+ // then the node should not be pruned by this.
+ if (traversalInfo.LastQueryNode() == queryNode.Parent())
+ {
+ const double queryAdjust = queryParentDist + queryDescDist;
+ adjustedScore = SortPolicy::CombineBest(adjustedScore, queryAdjust);
+ }
+ else
+ {
+ adjustedScore = SortPolicy::CombineBest(adjustedScore, queryDescDist);
+ }
+
+ if (traversalInfo.LastReferenceNode() == referenceNode.Parent())
+ {
+ const double refAdjust = refParentDist + refDescDist;
+ adjustedScore = SortPolicy::CombineBest(adjustedScore, refAdjust);
+ }
+ else
+ {
+ adjustedScore = SortPolicy::CombineBest(adjustedScore, refDescDist);
+ }
+
+ // Can we prune?
+ if (SortPolicy::IsBetter(bestDistance, adjustedScore))
+ {
+ if (!(tree::TreeTraits<TreeType>::FirstPointIsCentroid && score == 0.0))
+ {
+ // There isn't any need to set the traversal information because no
+ // descendant combinations will be visited, and those are the only
+ // combinations that would depend on the traversal information.
+ return DBL_MAX;
+ }
+ }
+
double distance;
if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
{
// The first point in the node is the centroid, so we can calculate the
// distance between the two points using BaseCase() and then find the
// bounds. This is potentially loose for non-ball bounds.
- bool alreadyDone = false;
- double baseCase;
- if (tree::TreeTraits<TreeType>::HasSelfChildren)
+ double baseCase = -1.0;
+ if (tree::TreeTraits<TreeType>::HasSelfChildren &&
+ (traversalInfo.LastQueryNode()->Point(0) == queryNode.Point(0)) &&
+ (traversalInfo.LastReferenceNode()->Point(0) == referenceNode.Point(0)))
{
- // In this case, we may have already calculated the base case.
- TreeType* lastRef = (TreeType*) queryNode.Stat().LastDistanceNode();
- TreeType* lastQuery = (TreeType*) referenceNode.Stat().LastDistanceNode();
-
- // Does the query node have the base case cached?
- if ((lastRef != NULL) && (referenceNode.Point(0) == lastRef->Point(0)))
- {
- baseCase = queryNode.Stat().LastDistance();
- alreadyDone = true;
- }
-
- // Does the reference node have the base case cached?
- if ((lastQuery != NULL) &&
- (queryNode.Point(0) == lastQuery->Point(0)))
- {
- baseCase = referenceNode.Stat().LastDistance();
- alreadyDone = true;
- }
-
- // Is the query node a self-child, and if so, does the query node's parent
- // have the base case cached?
- if ((queryNode.Parent() != NULL) &&
- (queryNode.Parent()->Point(0) == queryNode.Point(0)))
- {
- TreeType* lastParentRef = (TreeType*)
- queryNode.Parent()->Stat().LastDistanceNode();
- if (lastParentRef->Point(0) == referenceNode.Point(0))
- {
- baseCase = queryNode.Parent()->Stat().LastDistance();
- alreadyDone = true;
- }
- }
-
- // Is the reference node a self-child, and if so, does the reference
- // node's parent have the base case cached?
- if ((referenceNode.Parent() != NULL) &&
- (referenceNode.Parent()->Point(0) == referenceNode.Point(0)))
- {
- TreeType* lastParentRef = (TreeType*)
- referenceNode.Parent()->Stat().LastDistanceNode();
- if (lastParentRef->Point(0) == queryNode.Point(0))
- {
- baseCase = referenceNode.Parent()->Stat().LastDistance();
- alreadyDone = true;
- }
- }
- }
-
- // If we did not find a cached base case, then recalculate it.
- if (!alreadyDone)
- {
- baseCase = BaseCase(queryNode.Point(0), referenceNode.Point(0));
+ // We already calculated it.
+ baseCase = traversalInfo.LastBaseCase();
}
else
{
- // Set lastQueryIndex and lastReferenceIndex, so that BaseCase() does not
- // duplicate work.
- lastQueryIndex = queryNode.Point(0);
- lastReferenceIndex = referenceNode.Point(0);
- lastBaseCase = baseCase;
+ baseCase = BaseCase(queryNode.Point(0), referenceNode.Point(0));
}
distance = SortPolicy::CombineBest(baseCase,
queryNode.FurthestDescendantDistance() +
referenceNode.FurthestDescendantDistance());
- // Update the last distance calculation for the query and reference nodes.
- queryNode.Stat().LastDistanceNode() = (void*) &referenceNode;
- queryNode.Stat().LastDistance() = baseCase;
- referenceNode.Stat().LastDistanceNode() = (void*) &queryNode;
- referenceNode.Stat().LastDistance() = baseCase;
+ lastQueryIndex = queryNode.Point(0);
+ lastReferenceIndex = referenceNode.Point(0);
+ lastBaseCase = baseCase;
+
+ traversalInfo.LastBaseCase() = baseCase;
}
else
{
distance = SortPolicy::BestNodeToNodeDistance(&queryNode, &referenceNode);
}
- // Update our bound.
- const double bestDistance = CalculateBound(queryNode);
+ if (SortPolicy::IsBetter(distance, bestDistance))
+ {
+ // Set traversal information.
+ traversalInfo.LastQueryNode() = &queryNode;
+ traversalInfo.LastReferenceNode() = &referenceNode;
+ traversalInfo.LastScore() = distance;
- return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
+ return distance;
+ }
+ else
+ {
+ // There isn't any need to set the traversal information because no
+ // descendant combinations will be visited, and those are the only
+ // combinations that would depend on the traversal information.
+ return DBL_MAX;
+ }
}
template<typename SortPolicy, typename MetricType, typename TreeType>
@@ -257,6 +300,30 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
// So we will loop over the points in queryNode and the children in queryNode
// to calculate all five of these quantities.
+ // Hm, can we populate our distances vector with estimates from the parent?
+ // This is written specifically for the cover tree and assumes only one point
+ // in a node.
+// if (queryNode.Parent() != NULL && queryNode.NumPoints() > 0)
+// {
+// size_t parentIndexStart = 0;
+// for (size_t i = 0; i < neighbors.n_rows; ++i)
+// {
+// const double pointDistance = distances(i, queryNode.Point(0));
+// if (pointDistance == DBL_MAX)
+// {
+// // Cool, can we take an estimate from the parent?
+// const double parentWorstBound = distances(distances.n_rows - 1,
+// queryNode.Parent()->Point(0));
+// if (parentWorstBound != DBL_MAX)
+// {
+// const double parentAdjustedDistance = parentWorstBound +
+// queryNode.ParentDistance();
+// distances(i, queryNode.Point(0)) = parentAdjustedDistance;
+// }
+// }
+// }
+// }
+
double worstPointDistance = SortPolicy::BestDistance();
double bestPointDistance = SortPolicy::WorstDistance();
@@ -264,7 +331,8 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
// candidates (for (1) and (2)).
for (size_t i = 0; i < queryNode.NumPoints(); ++i)
{
- const double distance = distances(distances.n_rows - 1, queryNode.Point(i));
+ const double distance = distances(distances.n_rows - 1,
+ queryNode.Point(i));
if (SortPolicy::IsBetter(distance, bestPointDistance))
bestPointDistance = distance;
if (SortPolicy::IsBetter(worstPointDistance, distance))
@@ -302,16 +370,21 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
worstChildBound : worstPointDistance;
// This is bound (2).
- const double secondBound = SortPolicy::CombineWorst(bestPointDistance,
- 2 * queryMaxDescendantDistance);
+ const double secondBound = SortPolicy::CombineWorst(
+ SortPolicy::CombineWorst(bestPointDistance, queryMaxDescendantDistance),
+ queryNode.FurthestPointDistance());
// Bound (3) is bestAdjustedChildBound.
// Bounds (4) and (5) are the parent bounds.
const double fourthBound = (queryNode.Parent() != NULL) ?
queryNode.Parent()->Stat().FirstBound() : SortPolicy::WorstDistance();
- const double fifthBound = (queryNode.Parent() != NULL) ?
- queryNode.Parent()->Stat().SecondBound() : SortPolicy::WorstDistance();
+// const double fifthBound = (queryNode.Parent() != NULL) ?
+// queryNode.Parent()->Stat().SecondBound() -
+// queryNode.Parent()->FurthestDescendantDistance() -
+// queryNode.Parent()->FurthestPointDistance() + queryMaxDescendantDistance +
+// queryNode.FurthestPointDistance() + queryNode.ParentDistance() :
+// SortPolicy::WorstDistance();
// Now, we will take the best of these. Unfortunately due to the way
// IsBetter() is defined, this sort of has to be a little ugly.
@@ -326,16 +399,16 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
const double interB =
(SortPolicy::IsBetter(bestAdjustedChildBound, secondBound)) ?
bestAdjustedChildBound : secondBound;
- const double interC = (SortPolicy::IsBetter(interB, fifthBound)) ? interB :
- fifthBound;
+// const double interC = (SortPolicy::IsBetter(interB, fifthBound)) ? interB :
+// fifthBound;
// Update the first and second bounds of the node.
queryNode.Stat().FirstBound() = interA;
- queryNode.Stat().SecondBound() = interC;
+ queryNode.Stat().SecondBound() = interB;
// Update the actual bound of the node.
- queryNode.Stat().Bound() = (SortPolicy::IsBetter(interA, interC)) ? interA :
- interC;
+ queryNode.Stat().Bound() = (SortPolicy::IsBetter(interA, interB)) ? interB :
+ interB;
return queryNode.Stat().Bound();
}
More information about the mlpack-git mailing list