[mlpack-svn] r15802 - mlpack/trunk/src/mlpack/methods/range_search
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Sep 17 22:05:42 EDT 2013
Author: rcurtin
Date: Tue Sep 17 22:05:42 2013
New Revision: 15802
Log:
Refactor RangeSearch to work properly for both cover trees and kd-trees with the
new cover tree traversal.
Modified:
mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp
mlpack/trunk/src/mlpack/methods/range_search/range_search_impl.hpp
mlpack/trunk/src/mlpack/methods/range_search/range_search_rules.hpp
mlpack/trunk/src/mlpack/methods/range_search/range_search_rules_impl.hpp
Modified: mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp Tue Sep 17 22:05:42 2013
@@ -207,10 +207,11 @@
//! Mappings to old query indices (used when this object builds trees).
std::vector<size_t> oldFromNewQueries;
- //! Indicates ownership of the reference tree (meaning we need to delete it).
- bool ownReferenceTree;
- //! Indicates ownership of the query tree (meaning we need to delete it).
- bool ownQueryTree;
+ //! If true, this object is responsible for deleting the trees.
+ bool treeOwner;
+ //! If true, a query set was passed; if false, the query set is the reference
+ //! set.
+ bool hasQuerySet;
//! If true, O(n^2) naive computation is used.
bool naive;
Modified: mlpack/trunk/src/mlpack/methods/range_search/range_search_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/range_search_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/range_search/range_search_impl.hpp Tue Sep 17 22:05:42 2013
@@ -28,8 +28,8 @@
queryCopy(querySet),
referenceSet(referenceCopy),
querySet(queryCopy),
- ownReferenceTree(true),
- ownQueryTree(true),
+ treeOwner(true),
+ hasQuerySet(true),
naive(naive),
singleMode(!naive && singleMode), // Naive overrides single mode.
metric(metric),
@@ -59,8 +59,8 @@
referenceSet(referenceCopy),
querySet(referenceCopy),
queryTree(NULL),
- ownReferenceTree(true),
- ownQueryTree(false),
+ treeOwner(true),
+ hasQuerySet(false),
naive(naive),
singleMode(!naive && singleMode), // Naive overrides single mode.
metric(metric),
@@ -73,6 +73,10 @@
referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
(naive ? referenceCopy.n_cols : leafSize));
+ // If using dual-tree mode, then we need a second tree.
+ if (!singleMode)
+ queryTree = new TreeType(*referenceTree);
+
Timer::Stop("range_search/tree_building");
}
@@ -88,8 +92,8 @@
querySet(querySet),
referenceTree(referenceTree),
queryTree(queryTree),
- ownReferenceTree(false),
- ownQueryTree(false),
+ treeOwner(false),
+ hasQuerySet(true),
naive(false),
singleMode(singleMode),
metric(metric),
@@ -108,22 +112,31 @@
querySet(referenceSet),
referenceTree(referenceTree),
queryTree(NULL),
- ownReferenceTree(false),
- ownQueryTree(false),
+ treeOwner(false),
+ hasQuerySet(false),
naive(false),
singleMode(singleMode),
metric(metric),
numPrunes(0)
{
- // Nothing else to initialize.
+ // If doing dual-tree range search, we must clone the reference tree.
+ if (!singleMode)
+ queryTree = new TreeType(*referenceTree);
}
template<typename MetricType, typename TreeType>
RangeSearch<MetricType, TreeType>::~RangeSearch()
{
- if (ownReferenceTree)
- delete referenceTree;
- if (ownQueryTree)
+ if (treeOwner)
+ {
+ if (referenceTree)
+ delete referenceTree;
+ if (queryTree)
+ delete queryTree;
+ }
+
+ // If doing dual-tree search with one dataset, we cloned the reference tree.
+ if (!treeOwner && !hasQuerySet && !singleMode)
delete queryTree;
}
@@ -145,9 +158,9 @@
std::vector<std::vector<size_t> >* neighborPtr = &neighbors;
std::vector<std::vector<double> >* distancePtr = &distances;
- if (ownQueryTree || (ownReferenceTree && !queryTree))
+ if (treeOwner && !(singleMode && hasQuerySet))
distancePtr = new std::vector<std::vector<double> >;
- if (ownReferenceTree || ownQueryTree)
+ if (treeOwner)
neighborPtr = new std::vector<std::vector<size_t> >;
// Resize each vector.
@@ -177,10 +190,7 @@
// Create the traverser.
typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
- if (queryTree)
- traverser.Traverse(*queryTree, *referenceTree);
- else
- traverser.Traverse(*referenceTree, *referenceTree);
+ traverser.Traverse(*queryTree, *referenceTree);
numPrunes = traverser.NumPrunes();
}
@@ -192,12 +202,12 @@
<< "." << std::endl;
// Map points back to original indices, if necessary.
- if (!ownReferenceTree && !ownQueryTree)
+ if (!treeOwner)
{
// No mapping needed. We are done.
return;
}
- else if (ownReferenceTree && ownQueryTree) // Map references and queries.
+ else if (treeOwner && hasQuerySet && !singleMode) // Map both sets.
{
neighbors.clear();
neighbors.resize(querySet.n_cols);
@@ -222,71 +232,48 @@
delete neighborPtr;
delete distancePtr;
}
- else if (ownReferenceTree)
+ else if (treeOwner && !hasQuerySet)
{
- if (!queryTree) // No query tree -- map both references and queries.
- {
- neighbors.clear();
- neighbors.resize(querySet.n_cols);
- distances.clear();
- distances.resize(querySet.n_cols);
-
- for (size_t i = 0; i < distances.size(); i++)
- {
- // Map distances (copy a column).
- size_t refMapping = oldFromNewReferences[i];
- distances[refMapping] = (*distancePtr)[i];
-
- // Copy each neighbor individually, because we need to map it.
- neighbors[refMapping].resize(distances[refMapping].size());
- for (size_t j = 0; j < distances[refMapping].size(); j++)
- {
- neighbors[refMapping][j] = oldFromNewReferences[(*neighborPtr)[i][j]];
- }
- }
+ neighbors.clear();
+ neighbors.resize(querySet.n_cols);
+ distances.clear();
+ distances.resize(querySet.n_cols);
- // Finished with temporary objects.
- delete neighborPtr;
- delete distancePtr;
- }
- else // Map only references.
+ for (size_t i = 0; i < distances.size(); i++)
{
- neighbors.clear();
- neighbors.resize(querySet.n_cols);
+ // Map distances (copy a column).
+ size_t refMapping = oldFromNewReferences[i];
+ distances[refMapping] = (*distancePtr)[i];
- // Map indices of neighbors.
- for (size_t i = 0; i < neighbors.size(); i++)
+ // Copy each neighbor individually, because we need to map it.
+ neighbors[refMapping].resize(distances[refMapping].size());
+ for (size_t j = 0; j < distances[refMapping].size(); j++)
{
- neighbors[i].resize((*neighborPtr)[i].size());
- for (size_t j = 0; j < neighbors[i].size(); j++)
- {
- neighbors[i][j] = oldFromNewReferences[(*neighborPtr)[i][j]];
- }
+ neighbors[refMapping][j] = oldFromNewReferences[(*neighborPtr)[i][j]];
}
-
- // Finished with temporary object.
- delete neighborPtr;
}
+
+ // Finished with temporary objects.
+ delete neighborPtr;
+ delete distancePtr;
}
- else if (ownQueryTree)
+ else if (treeOwner && hasQuerySet && singleMode) // Map only references.
{
neighbors.clear();
neighbors.resize(querySet.n_cols);
- distances.clear();
- distances.resize(querySet.n_cols);
- for (size_t i = 0; i < distances.size(); i++)
+ // Map indices of neighbors.
+ for (size_t i = 0; i < neighbors.size(); i++)
{
- // Map distances (copy a column).
- distances[oldFromNewQueries[i]] = (*distancePtr)[i];
-
- // Map neighbors.
- neighbors[oldFromNewQueries[i]] = (*neighborPtr)[i];
+ neighbors[i].resize((*neighborPtr)[i].size());
+ for (size_t j = 0; j < neighbors[i].size(); j++)
+ {
+ neighbors[i][j] = oldFromNewReferences[(*neighborPtr)[i][j]];
+ }
}
- // Finished with temporary objects.
+ // Finished with temporary object.
delete neighborPtr;
- delete distancePtr;
}
}
Modified: mlpack/trunk/src/mlpack/methods/range_search/range_search_rules.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/range_search_rules.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/range_search/range_search_rules.hpp Tue Sep 17 22:05:42 2013
@@ -52,20 +52,6 @@
double Score(const size_t queryIndex, TreeType& referenceNode);
/**
- * Get the score for recursion order, passing the base case result (in the
- * situation where it may be needed to calculate the recursion order). A low
- * score indicates priority for recursion, while DBL_MAX indicates that the
- * node should not be recursed into at all (it should be pruned).
- *
- * @param queryIndex Index of query point.
- * @param referenceNode Candidate node to be recursed into.
- * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode).
- */
- double Score(const size_t queryIndex,
- TreeType& referenceNode,
- const double baseCaseResult);
-
- /**
* Re-evaluate the score for recursion order. A low score indicates priority
* for recursion, while DBL_MAX indicates that the node should not be recursed
* into at all (it should be pruned). This is used when the score has already
@@ -91,20 +77,6 @@
double Score(TreeType& queryNode, TreeType& referenceNode);
/**
- * Get the score for recursion order, passing the base case result (in the
- * situation where it may be needed to calculate the recursion order). A low
- * score indicates priority for recursion, while DBL_MAX indicates that the
- * node should not be recursed into at all (it should be pruned).
- *
- * @param queryNode Candidate query node to recurse into.
- * @param referenceNode Candidate reference node to recurse into.
- * @param baseCaseResult Result of BaseCase(queryNode, referenceNode).
- */
- double Score(TreeType& queryNode,
- TreeType& referenceNode,
- const double baseCaseResult);
-
- /**
* Re-evaluate the score for recursion order. A low score indicates priority
* for recursion, while DBL_MAX indicates that the node should not be recursed
* into at all (it should be pruned). This is used when the score has already
@@ -138,12 +110,16 @@
//! The instantiated metric.
MetricType& metric;
+ //! The last query index.
+ size_t lastQueryIndex;
+ //! The last reference index.
+ size_t lastReferenceIndex;
+
//! Add all the points in the given node to the results for the given query
//! point. If the base case has already been calculated, we make sure to not
//! add that to the results twice.
void AddResult(const size_t queryIndex,
- TreeType& referenceNode,
- const bool hasBaseCase);
+ TreeType& referenceNode);
};
}; // namespace range
Modified: mlpack/trunk/src/mlpack/methods/range_search/range_search_rules_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/range_search_rules_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/range_search/range_search_rules_impl.hpp Tue Sep 17 22:05:42 2013
@@ -26,7 +26,9 @@
range(range),
neighbors(neighbors),
distances(distances),
- metric(metric)
+ metric(metric),
+ lastQueryIndex(querySet.n_cols),
+ lastReferenceIndex(referenceSet.n_cols)
{
// Nothing to do.
}
@@ -42,9 +44,20 @@
if ((&referenceSet == &querySet) && (queryIndex == referenceIndex))
return 0.0;
+ // If we have just performed this base case, don't do it again.
+ if ((lastQueryIndex == queryIndex) && (lastReferenceIndex == referenceIndex))
+ return 0.0; // No value to return... this shouldn't do anything bad.
+
+// if (queryIndex == 0 && referenceIndex == 0)
+// Log::Warn << "base case 0 0 called!\n";
+
const double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
referenceSet.unsafe_col(referenceIndex));
+ // Update last indices, so we don't accidentally perform a base case twice.
+ lastQueryIndex = queryIndex;
+ lastReferenceIndex = referenceIndex;
+
if (range.Contains(distance))
{
neighbors[queryIndex].push_back(referenceIndex);
@@ -59,36 +72,43 @@
double RangeSearchRules<MetricType, TreeType>::Score(const size_t queryIndex,
TreeType& referenceNode)
{
- const math::Range distances =
- referenceNode.RangeDistance(querySet.unsafe_col(queryIndex));
+ // We must get the minimum and maximum distances and store them in this
+ // object.
+ math::Range distances;
- // If the ranges do not overlap, prune this node.
- if (!distances.Contains(range))
- return DBL_MAX;
+ if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
+ {
+ // In this situation, we calculate the base case. So we should check to be
+ // sure we haven't already done that.
+ double baseCase;
+ if (tree::TreeTraits<TreeType>::HasSelfChildren &&
+ (referenceNode.Parent() != NULL) &&
+ (referenceNode.Point(0) == referenceNode.Parent()->Point(0)))
+ {
+ // If the tree has self-children and this is a self-child, the base case
+ // was already calculated.
+ baseCase = referenceNode.Parent()->Stat().LastDistance();
+ lastQueryIndex = queryIndex;
+ lastReferenceIndex = referenceNode.Point(0);
+ }
+ else
+ {
+ // We must calculate the base case by hand.
+ baseCase = BaseCase(queryIndex, referenceNode.Point(0));
+ }
+
+ // This may be possibly loose for non-ball bound trees.
+ distances.Lo() = baseCase - referenceNode.FurthestDescendantDistance();
+ distances.Hi() = baseCase + referenceNode.FurthestDescendantDistance();
- // In this case, all of the points in the reference node will be part of the
- // results.
- if ((distances.Lo() >= range.Lo()) && (distances.Hi() <= range.Hi()))
+ // Update last distance calculation.
+ referenceNode.Stat().LastDistance() = baseCase;
+ }
+ else
{
- AddResult(queryIndex, referenceNode, false);
- return DBL_MAX; // We don't need to go any deeper.
+ distances = referenceNode.RangeDistance(querySet.unsafe_col(queryIndex));
}
- // Otherwise the score doesn't matter. Recursion order is irrelevant in range
- // search.
- return 0.0;
-}
-
-//! Single-tree scoring function.
-template<typename MetricType, typename TreeType>
-double RangeSearchRules<MetricType, TreeType>::Score(
- const size_t queryIndex,
- TreeType& referenceNode,
- const double baseCaseResult)
-{
- const math::Range distances = referenceNode.RangeDistance(
- querySet.unsafe_col(queryIndex), baseCaseResult);
-
// If the ranges do not overlap, prune this node.
if (!distances.Contains(range))
return DBL_MAX;
@@ -97,12 +117,12 @@
// results.
if ((distances.Lo() >= range.Lo()) && (distances.Hi() <= range.Hi()))
{
- AddResult(queryIndex, referenceNode, true);
+ AddResult(queryIndex, referenceNode);
return DBL_MAX; // We don't need to go any deeper.
}
- // Otherwise the score doesn't matter. Recursion order is irrelevant in range
- // search.
+ // Otherwise the score doesn't matter. Recursion order is irrelevant in
+ // range search.
return 0.0;
}
@@ -122,35 +142,90 @@
double RangeSearchRules<MetricType, TreeType>::Score(TreeType& queryNode,
TreeType& referenceNode)
{
- const math::Range distances = referenceNode.RangeDistance(&queryNode);
-
- // If the ranges do not overlap, prune this node.
- if (!distances.Contains(range))
- return DBL_MAX;
-
- // In this case, all of the points in the reference node will be part of all
- // the results for each point in the query node.
- if ((distances.Lo() >= range.Lo()) && (distances.Hi() <= range.Hi()))
+ math::Range distances;
+ if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
{
- for (size_t i = 0; i < queryNode.NumDescendants(); ++i)
- AddResult(queryNode.Descendant(i), referenceNode, false);
- return DBL_MAX; // We don't need to go any deeper.
+ // It is possible that the base case has already been calculated.
+ double baseCase = 0.0;
+ bool alreadyDone = false;
+ if (tree::TreeTraits<TreeType>::HasSelfChildren)
+ {
+ TreeType* lastQuery = (TreeType*) referenceNode.Stat().LastDistanceNode();
+ TreeType* lastRef = (TreeType*) queryNode.Stat().LastDistanceNode();
+
+ // Did the query node's last combination do the base case?
+ if ((lastRef != NULL) && (referenceNode.Point(0) == lastRef->Point(0)))
+ {
+ baseCase = queryNode.Stat().LastDistance();
+ alreadyDone = true;
+ }
+
+ // Did the reference node's last combination do the base case?
+ if ((lastQuery != NULL) && (queryNode.Point(0) == lastQuery->Point(0)))
+ {
+ baseCase = referenceNode.Stat().LastDistance();
+ alreadyDone = true;
+ }
+
+ // If the query node is a self-child, did the query parent's last
+ // combination do the base case?
+ if ((queryNode.Parent() != NULL) &&
+ (queryNode.Point(0) == queryNode.Parent()->Point(0)))
+ {
+ TreeType* lastParentRef = (TreeType*)
+ queryNode.Parent()->Stat().LastDistanceNode();
+ if ((lastParentRef != NULL) &&
+ (referenceNode.Point(0) == lastParentRef->Point(0)))
+ {
+ baseCase = queryNode.Parent()->Stat().LastDistance();
+ alreadyDone = true;
+ }
+ }
+
+ // If the reference node is a self-child, did the reference parent's last
+ // combination do the base case?
+ if ((referenceNode.Parent() != NULL) &&
+ (referenceNode.Point(0) == referenceNode.Parent()->Point(0)))
+ {
+ TreeType* lastQueryRef = (TreeType*)
+ referenceNode.Parent()->Stat().LastDistanceNode();
+ if ((lastQueryRef != NULL) &&
+ (queryNode.Point(0) == lastQueryRef->Point(0)))
+ {
+ baseCase = referenceNode.Parent()->Stat().LastDistance();
+ alreadyDone = true;
+ }
+ }
+ }
+
+ if (!alreadyDone)
+ {
+ // We must calculate the base case.
+ baseCase = BaseCase(queryNode.Point(0), referenceNode.Point(0));
+ }
+ else
+ {
+ // Make sure that if BaseCase() is called, we don't duplicate results.
+ lastQueryIndex = queryNode.Point(0);
+ lastReferenceIndex = referenceNode.Point(0);
+ }
+
+ distances.Lo() = baseCase - queryNode.FurthestDescendantDistance()
+ - referenceNode.FurthestDescendantDistance();
+ distances.Hi() = baseCase + queryNode.FurthestDescendantDistance()
+ + referenceNode.FurthestDescendantDistance();
+
+ // Update the last distances performed for the query and reference node.
+ queryNode.Stat().LastDistanceNode() = (void*) &referenceNode;
+ queryNode.Stat().LastDistance() = baseCase;
+ referenceNode.Stat().LastDistanceNode() = (void*) &queryNode;
+ referenceNode.Stat().LastDistance() = baseCase;
+ }
+ else
+ {
+ // Just perform the calculation.
+ distances = referenceNode.RangeDistance(&queryNode);
}
-
- // Otherwise the score doesn't matter. Recursion order is irrelevant in range
- // search.
- return 0.0;
-}
-
-//! Dual-tree scoring function.
-template<typename MetricType, typename TreeType>
-double RangeSearchRules<MetricType, TreeType>::Score(
- TreeType& queryNode,
- TreeType& referenceNode,
- const double baseCaseResult)
-{
- const math::Range distances = referenceNode.RangeDistance(&queryNode,
- baseCaseResult);
// If the ranges do not overlap, prune this node.
if (!distances.Contains(range))
@@ -160,11 +235,8 @@
// the results for each point in the query node.
if ((distances.Lo() >= range.Lo()) && (distances.Hi() <= range.Hi()))
{
- AddResult(queryNode.Descendant(0), referenceNode, true);
- // We have not calculated the base case for any descendants other than the
- // first point.
- for (size_t i = 1; i < queryNode.NumDescendants(); ++i)
- AddResult(queryNode.Descendant(i), referenceNode, false);
+ for (size_t i = 0; i < queryNode.NumDescendants(); ++i)
+ AddResult(queryNode.Descendant(i), referenceNode);
return DBL_MAX; // We don't need to go any deeper.
}
@@ -188,14 +260,15 @@
//! point.
template<typename MetricType, typename TreeType>
void RangeSearchRules<MetricType, TreeType>::AddResult(const size_t queryIndex,
- TreeType& referenceNode,
- const bool hasBaseCase)
+ TreeType& referenceNode)
{
// Some types of trees calculate the base case evaluation before Score() is
// called, so if the base case has already been calculated, then we must avoid
// adding that point to the results again.
size_t baseCaseMod = 0;
- if (tree::TreeTraits<TreeType>::FirstPointIsCentroid && hasBaseCase)
+ if (tree::TreeTraits<TreeType>::FirstPointIsCentroid &&
+ (queryIndex == lastQueryIndex) &&
+ (referenceNode.Point(0) == lastReferenceIndex))
{
baseCaseMod = 1;
}
More information about the mlpack-svn mailing list