[mlpack-git] master: Pre-emptive prunes. Potentially a slowdown. Not always, though, I don't think. I may revert this. (b2fc22a)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:04:48 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44
>---------------------------------------------------------------
commit b2fc22aefd46992b4aa12339322e5a37d695c80b
Author: Ryan Curtin <ryan at ratml.org>
Date: Thu Feb 19 15:05:51 2015 -0500
Pre-emptive prunes. Potentially a slowdown. Not always, though, I don't think. I may revert this.
>---------------------------------------------------------------
b2fc22aefd46992b4aa12339322e5a37d695c80b
src/mlpack/methods/kmeans/dtnn_rules.hpp | 10 +-
src/mlpack/methods/kmeans/dtnn_rules_impl.hpp | 178 ++++++++++++++++++++++----
2 files changed, 163 insertions(+), 25 deletions(-)
diff --git a/src/mlpack/methods/kmeans/dtnn_rules.hpp b/src/mlpack/methods/kmeans/dtnn_rules.hpp
index 5252050..c2f7873 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules.hpp
@@ -9,7 +9,7 @@
#ifndef __MLPACK_METHODS_KMEANS_DTNN_RULES_HPP
#define __MLPACK_METHODS_KMEANS_DTNN_RULES_HPP
-#include <mlpack/methods/neighbor_search/neighbor_search.hpp>
+#include <mlpack/methods/neighbor_search/ns_traversal_info.hpp>
namespace mlpack {
namespace kmeans {
@@ -39,7 +39,7 @@ class DTNNKMeansRules
TreeType& referenceNode,
const double oldScore);
- typedef int TraversalInfoType;
+ typedef neighbor::NeighborSearchTraversalInfo<TreeType> TraversalInfoType;
TraversalInfoType& TraversalInfo() { return traversalInfo; }
const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
@@ -67,7 +67,11 @@ class DTNNKMeansRules
size_t baseCases;
size_t scores;
- int traversalInfo;
+ TraversalInfoType traversalInfo;
+
+ size_t lastQueryIndex;
+ size_t lastReferenceIndex;
+ size_t lastBaseCase;
};
} // namespace kmeans
diff --git a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
index ca8943c..c87c1da 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
@@ -33,9 +33,15 @@ DTNNKMeansRules<MetricType, TreeType>::DTNNKMeansRules(
oldFromNewCentroids(oldFromNewCentroids),
visited(visited),
baseCases(0),
- scores(0)
+ scores(0),
+ lastQueryIndex(dataset.n_cols),
+ lastReferenceIndex(centroids.n_cols)
{
- // Nothing to do.
+ // We must set the traversal info last query and reference node pointers to
+ // something that is both invalid (i.e. not a tree node) and not NULL. We'll
+ // use the this pointer.
+ traversalInfo.LastQueryNode() = (TreeType*) this;
+ traversalInfo.LastReferenceNode() = (TreeType*) this;
}
template<typename MetricType, typename TreeType>
@@ -46,6 +52,10 @@ inline force_inline double DTNNKMeansRules<MetricType, TreeType>::BaseCase(
if (prunedPoints[queryIndex])
return 0.0; // Returning 0 shouldn't be a problem.
+ // If we have already performed this base case, then do not perform it again.
+ if ((lastQueryIndex == queryIndex) && (lastReferenceIndex == referenceIndex))
+ return lastBaseCase;
+
// Any base cases imply that we will get a result.
visited[queryIndex] = true;
@@ -66,6 +76,11 @@ inline force_inline double DTNNKMeansRules<MetricType, TreeType>::BaseCase(
lowerBounds[queryIndex] = distance;
}
+ // Cache this information for the next time BaseCase() is called.
+ lastQueryIndex = queryIndex;
+ lastReferenceIndex = referenceIndex;
+ lastBaseCase = distance;
+
return distance;
}
@@ -102,30 +117,144 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
if (queryNode.Stat().Pruned() == centroids.n_cols)
return DBL_MAX;
- // Get minimum and maximum distances.
- math::Range distances = queryNode.RangeDistance(&referenceNode);
- double score = distances.Lo();
- ++scores;
- if (distances.Lo() > queryNode.Stat().UpperBound())
+ // This looks a lot like the hackery used in NeighborSearchRules to avoid
+ // distance computations. We'll use the traversal info to see if a
+ // parent-child or parent-parent prune is possible.
+ const double queryParentDist = queryNode.ParentDistance();
+ const double queryDescDist = queryNode.FurthestDescendantDistance();
+ const double refParentDist = referenceNode.ParentDistance();
+ const double refDescDist = referenceNode.FurthestDescendantDistance();
+ const double lastScore = traversalInfo.LastScore();
+ double adjustedScore;
+ double score = 0.0;
+
+ // We want to set adjustedScore to be the distance between the centroid of the
+ // last query node and last reference node. We will do this by adjusting the
+ // last score. In some cases, we can just use the last base case.
+ if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
+ {
+ adjustedScore = traversalInfo.LastBaseCase();
+ }
+ else if (lastScore == 0.0) // Nothing we can do here.
+ {
+ adjustedScore = 0.0;
+ }
+ else
{
- // The reference node can own no points in this query node. We may improve
- // the lower bound on pruned nodes, though.
- if (distances.Lo() < queryNode.Stat().LowerBound())
- queryNode.Stat().LowerBound() = distances.Lo();
+ // The last score is equal to the distance between the centroids minus the
+ // radii of the query and reference bounds along the axis of the line
+ // between the two centroids. In the best case, these radii are the
+ // furthest descendant distances, but that is not always true. It would
+ // take too long to calculate the exact radii, so we are forced to use
+ // MinimumBoundDistance() as a lower-bound approximation.
+ const double lastQueryDescDist =
+ traversalInfo.LastQueryNode()->MinimumBoundDistance();
+ const double lastRefDescDist =
+ traversalInfo.LastReferenceNode()->MinimumBoundDistance();
+ adjustedScore = lastScore + lastQueryDescDist;
+ adjustedScore = lastScore + lastRefDescDist;
+ }
- // This assumes that reference clusters don't appear elsewhere in the tree.
- queryNode.Stat().Pruned() += referenceNode.NumDescendants();
- score = DBL_MAX;
+ // Assemble an adjusted score. For nearest neighbor search, this adjusted
+ // score is a lower bound on MinDistance(queryNode, referenceNode) that is
+ // assembled without actually calculating MinDistance(). For furthest
+ // neighbor search, it is an upper bound on
+ // MaxDistance(queryNode, referenceNode). If the traversalInfo isn't usable
+ // then the node should not be pruned by this.
+ if (traversalInfo.LastQueryNode() == queryNode.Parent())
+ {
+ const double queryAdjust = queryParentDist + queryDescDist;
+ adjustedScore -= queryAdjust;
+ }
+ else if (traversalInfo.LastQueryNode() == &queryNode)
+ {
+ adjustedScore -= queryDescDist;
+ }
+ else
+ {
+ // The last query node wasn't this query node or its parent. So we force
+ // the adjustedScore to be such that this combination can't be pruned here,
+ // because we don't really know anything about it.
+
+ // It would be possible to modify this section to try and make a prune based
+ // on the query descendant distance and the distance between the query node
+ // and last traversal query node, but this case doesn't actually happen for
+ // kd-trees or cover trees.
+ adjustedScore = 0.0;
+ }
+ if (traversalInfo.LastReferenceNode() == referenceNode.Parent())
+ {
+ const double refAdjust = refParentDist + refDescDist;
+ adjustedScore -= refAdjust;
+ }
+ else if (traversalInfo.LastReferenceNode() == &referenceNode)
+ {
+ adjustedScore -= refDescDist;
}
- else if (distances.Hi() < queryNode.Stat().UpperBound())
+ else
{
- // We can improve the best estimate.
- queryNode.Stat().UpperBound() = distances.Hi();
- // If this node has only one descendant, then it may be the owner.
- if (referenceNode.NumDescendants() == 1)
- queryNode.Stat().Owner() = (tree::TreeTraits<TreeType>::RearrangesDataset)
- ? oldFromNewCentroids[referenceNode.Descendant(0)]
- : referenceNode.Descendant(0);
+ // The last reference node wasn't this reference node or its parent. So we
+ // force the adjustedScore to be such that this combination can't be pruned
+ // here, because we don't really know anything about it.
+
+ // It would be possible to modify this section to try and make a prune based
+ // on the reference descendant distance and the distance between the
+ // reference node and last traversal reference node, but this case doesn't
+ // actually happen for kd-trees or cover trees.
+ adjustedScore = 0.0;
+ }
+
+ // Now, check if we can prune.
+ //Log::Warn << "adjusted score: " << adjustedScore << ".\n";
+ if (adjustedScore > queryNode.Stat().UpperBound())
+ {
+// Log::Warn << "Pre-emptive prune!\n";
+ if (!(tree::TreeTraits<TreeType>::FirstPointIsCentroid && score == 0.0))
+ {
+ // There isn't any need to set the traversal information because no
+ // descendant combinations will be visited, and those are the only
+ // combinations that would depend on the traversal information.
+ if (adjustedScore < queryNode.Stat().LowerBound())
+ {
+ // If this might affect the lower bound, make it more exact.
+ queryNode.Stat().LowerBound() = queryNode.MinDistance(&referenceNode);
+ ++scores;
+ }
+
+ queryNode.Stat().Pruned() += referenceNode.NumDescendants();
+ score = DBL_MAX;
+ }
+ }
+
+ if (score != DBL_MAX)
+ {
+ // Get minimum and maximum distances.
+ math::Range distances = queryNode.RangeDistance(&referenceNode);
+ score = distances.Lo();
+ ++scores;
+ if (distances.Lo() > queryNode.Stat().UpperBound())
+ {
+ // The reference node can own no points in this query node. We may
+ // improve the lower bound on pruned nodes, though.
+ if (distances.Lo() < queryNode.Stat().LowerBound())
+ queryNode.Stat().LowerBound() = distances.Lo();
+
+ // This assumes that reference clusters don't appear elsewhere in the
+ // tree.
+ queryNode.Stat().Pruned() += referenceNode.NumDescendants();
+ score = DBL_MAX;
+ }
+ else if (distances.Hi() < queryNode.Stat().UpperBound())
+ {
+ // We can improve the best estimate.
+ queryNode.Stat().UpperBound() = distances.Hi();
+ // If this node has only one descendant, then it may be the owner.
+ if (referenceNode.NumDescendants() == 1)
+ queryNode.Stat().Owner() =
+ (tree::TreeTraits<TreeType>::RearrangesDataset) ?
+ oldFromNewCentroids[referenceNode.Descendant(0)] :
+ referenceNode.Descendant(0);
+ }
}
// Is everything pruned?
@@ -135,6 +264,11 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
return DBL_MAX;
}
+ // Set traversal information.
+ traversalInfo.LastQueryNode() = &queryNode;
+ traversalInfo.LastReferenceNode() = &referenceNode;
+ traversalInfo.LastScore() = score;
+
return score;
}
More information about the mlpack-git mailing list