[mlpack-git] master: An attempt at making things faster. Might work? (8c7c819)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:04:23 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44
>---------------------------------------------------------------
commit 8c7c81981ee3e89cfbf29f13a0f8e9ee4ca4cd50
Author: Ryan Curtin <ryan at ratml.org>
Date: Mon Mar 2 17:32:42 2015 -0500
An attempt at making things faster. Might work?
>---------------------------------------------------------------
8c7c81981ee3e89cfbf29f13a0f8e9ee4ca4cd50
.../breadth_first_dual_tree_traverser_impl.hpp | 52 ++++++++++++++--------
src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 16 ++++---
src/mlpack/methods/kmeans/dtnn_rules_impl.hpp | 14 ++++--
3 files changed, 53 insertions(+), 29 deletions(-)
diff --git a/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp
index f48bf95..4513db6 100644
--- a/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp
@@ -160,25 +160,39 @@ BreadthFirstDualTreeTraverser<RuleType>::Traverse(
}
else
{
- // We have to recurse down both query and reference nodes. Because the
- // query descent order does not matter, we will go to the left query child
- // first. Before recursing, we have to set the traversal information
- // correctly.
- QueueFrameType fll = { queryNode.Left(), referenceNode.Left(),
- queryDepth + 1, score, rule.TraversalInfo() };
- leftChildQueue.push(fll);
-
- QueueFrameType flr = { queryNode.Left(), referenceNode.Right(),
- queryDepth + 1, score, rule.TraversalInfo() };
- leftChildQueue.push(flr);
-
- QueueFrameType frl = { queryNode.Right(), referenceNode.Left(),
- queryDepth + 1, score, rule.TraversalInfo() };
- rightChildQueue.push(frl);
-
- QueueFrameType frr = { queryNode.Right(), referenceNode.Right(),
- queryDepth + 1, score, rule.TraversalInfo() };
- rightChildQueue.push(frr);
+ if (score >= 0.0)
+ {
+ // We have to recurse down both query and reference nodes. Because the
+ // query descent order does not matter, we will go to the left query
+ // child first. Before recursing, we have to set the traversal
+ // information correctly.
+ QueueFrameType fll = { queryNode.Left(), referenceNode.Left(),
+ queryDepth + 1, score, rule.TraversalInfo() };
+ leftChildQueue.push(fll);
+
+ QueueFrameType flr = { queryNode.Left(), referenceNode.Right(),
+ queryDepth + 1, score, rule.TraversalInfo() };
+ leftChildQueue.push(flr);
+
+ QueueFrameType frl = { queryNode.Right(), referenceNode.Left(),
+ queryDepth + 1, score, rule.TraversalInfo() };
+ rightChildQueue.push(frl);
+
+ QueueFrameType frr = { queryNode.Right(), referenceNode.Right(),
+ queryDepth + 1, score, rule.TraversalInfo() };
+ rightChildQueue.push(frr);
+ }
+ else
+ {
+ // Only recurse down the references.
+ QueueFrameType fl = { &queryNode, referenceNode.Left(), queryDepth,
+ score, rule.TraversalInfo() };
+ referenceQueue.push(fl);
+
+ QueueFrameType fr = { &queryNode, referenceNode.Right(), queryDepth,
+ score, ti };
+ referenceQueue.push(fr);
+ }
}
}
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index f5eaf12..c788c34 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -132,7 +132,7 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
}
// We won't use the AllkNN class here because we have our own set of rules.
- //lastIterationCentroids = oldCentroids;
+ lastIterationCentroids = oldCentroids;
typedef DTNNKMeansRules<MetricType, TreeType> RuleType;
RuleType rules(centroids, dataset, assignments, upperBounds, lowerBounds,
metric, prunedPoints, oldFromNewCentroids, visited);
@@ -212,9 +212,9 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
node.Stat().Owner() = node.Parent()->Stat().Owner();
}
-
// Exhaustive lower bound check. Sigh.
-/* if (!prunedLastIteration)
+/*
+ if (!prunedLastIteration)
{
for (size_t i = 0; i < node.NumDescendants(); ++i)
{
@@ -238,11 +238,14 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
if (closest - 1e-10 > node.Stat().UpperBound())
{
+ Log::Warn << node;
Log::Warn << distances.t();
Log::Fatal << "Point " << node.Descendant(i) << " in " << node.Point(0) <<
"c" << node.NumDescendants() << " invalidates upper bound " <<
node.Stat().UpperBound() << " with closest cluster distance " << closest <<
-".\n";
+", lb " << lowerBounds[node.Descendant(i)] << " n " << node.Stat().LowerBound()
+<< " pp " << prunedPoints[node.Descendant(i)] << " visited " <<
+visited[node.Descendant(i)] << ".\n";
}
if (node.NumChildren() == 0)
@@ -250,6 +253,7 @@ node.Stat().UpperBound() << " with closest cluster distance " << closest <<
if (secondClosest + 1e-10 < std::min(lowerBounds[node.Descendant(i)],
node.Stat().LowerBound()))
{
+ Log::Warn << node;
Log::Warn << distances.t();
Log::Fatal << "Point " << node.Descendant(i) << " in " << node.Point(0) <<
"c" << node.NumDescendants() << " invalidates lower bound " <<
@@ -261,8 +265,8 @@ visited[node.Descendant(i)] << ".\n";
}
}
}
- }*/
-
+ }
+*/
if ((node.Stat().Pruned() == centroids.n_cols) &&
(node.Stat().Owner() < centroids.n_cols))
diff --git a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
index d7890aa..4c0cfc2 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
@@ -124,7 +124,7 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
const double queryDescDist = queryNode.FurthestDescendantDistance();
const double refParentDist = referenceNode.ParentDistance();
const double refDescDist = referenceNode.FurthestDescendantDistance();
- const double lastScore = traversalInfo.LastScore();
+ const double lastScore = std::abs(traversalInfo.LastScore());
double adjustedScore;
double score = 0.0;
@@ -266,6 +266,10 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
referenceNode.Descendant(0);
}
}
+
+ if (distances.Hi() > queryNode.Stat().UpperBound() &&
+ referenceNode.NumDescendants() > 1 && score != DBL_MAX)
+ score = -score; // Invert score for smarter strategy.
}
// Is everything pruned?
@@ -302,14 +306,16 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Rescore(
if (oldScore == DBL_MAX)
return DBL_MAX; // It's already pruned.
+ const double realScore = std::abs(oldScore);
+
// oldScore contains the minimum distance between queryNode and referenceNode.
// In the time since Score() has been called, the upper bound *may* have
// tightened. If it has tightened enough, we may prune this node now.
- if (oldScore > queryNode.Stat().UpperBound())
+ if (realScore > queryNode.Stat().UpperBound())
{
// We may still be able to improve the lower bound on pruned nodes.
- if (oldScore < queryNode.Stat().LowerBound())
- queryNode.Stat().LowerBound() = oldScore;
+ if (realScore < queryNode.Stat().LowerBound())
+ queryNode.Stat().LowerBound() = realScore;
// This assumes that reference clusters don't appear elsewhere in the tree.
queryNode.Stat().Pruned() += referenceNode.NumDescendants();
More information about the mlpack-git mailing list