[mlpack-git] master: A tighter, but random, prune. Minor speedup. Also add a commented-out check for the lower bounds, and fix a minor bug that screwed them up. (01ec8a5)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:05:06 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44
>---------------------------------------------------------------
commit 01ec8a54be0e047cfd45872d391ff59715799051
Author: Ryan Curtin <ryan at ratml.org>
Date: Tue Feb 24 17:40:27 2015 -0500
A tighter, but random, prune. Minor speedup. Also add a commented-out check for the lower bounds, and fix a minor bug that screwed them up.
>---------------------------------------------------------------
01ec8a54be0e047cfd45872d391ff59715799051
src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 62 +++++++++++++++++++++++++-
src/mlpack/methods/kmeans/dtnn_rules_impl.hpp | 10 ++++-
2 files changed, 69 insertions(+), 3 deletions(-)
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 8a76d06..ffd818f 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -132,6 +132,7 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
}
// We won't use the AllkNN class here because we have our own set of rules.
+ //lastIterationCentroids = oldCentroids;
typedef DTNNKMeansRules<MetricType, TreeType> RuleType;
RuleType rules(centroids, dataset, assignments, upperBounds, lowerBounds,
metric, prunedPoints, oldFromNewCentroids, visited);
@@ -205,11 +206,64 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
node.Parent()->Stat().Pruned() == centroids.n_cols)
{
node.Stat().UpperBound() = node.Parent()->Stat().UpperBound();
- node.Stat().LowerBound() = node.Parent()->Stat().LowerBound();
+ node.Stat().LowerBound() = node.Parent()->Stat().LowerBound() +
+ clusterDistances[centroids.n_cols];
node.Stat().Pruned() = node.Parent()->Stat().Pruned();
node.Stat().Owner() = node.Parent()->Stat().Owner();
}
+
+ // Exhaustive lower bound check. Sigh.
+/* if (!prunedLastIteration)
+ {
+ for (size_t i = 0; i < node.NumDescendants(); ++i)
+ {
+ double closest = DBL_MAX;
+ double secondClosest = DBL_MAX;
+ arma::vec distances(centroids.n_cols);
+ for (size_t j = 0; j < centroids.n_cols; ++j)
+ {
+ const double dist = metric.Evaluate(dataset.col(node.Descendant(i)),
+ lastIterationCentroids.col(j));
+ distances(j) = dist;
+
+ if (dist < closest)
+ {
+ secondClosest = closest;
+ closest = dist;
+ }
+ else if (dist < secondClosest)
+ secondClosest = dist;
+ }
+
+ if (closest - 1e-10 > node.Stat().UpperBound())
+ {
+ Log::Warn << distances.t();
+ Log::Fatal << "Point " << node.Descendant(i) << " in " << node.Point(0) <<
+"c" << node.NumDescendants() << " invalidates upper bound " <<
+node.Stat().UpperBound() << " with closest cluster distance " << closest <<
+".\n";
+ }
+
+ if (node.NumChildren() == 0)
+ {
+ if (secondClosest + 1e-10 < std::min(lowerBounds[node.Descendant(i)],
+ node.Stat().LowerBound()))
+ {
+ Log::Warn << distances.t();
+ Log::Fatal << "Point " << node.Descendant(i) << " in " << node.Point(0) <<
+"c" << node.NumDescendants() << " invalidates lower bound " <<
+std::min(lowerBounds[node.Descendant(i)], node.Stat().LowerBound()) << " (" <<
+lowerBounds[node.Descendant(i)] << ", " << node.Stat().LowerBound() << ") with "
+ << "second closest cluster distance " << secondClosest << ". cd " <<
+closest << "; pruned " << prunedPoints[node.Descendant(i)] << " visited " <<
+visited[node.Descendant(i)] << ".\n";
+ }
+ }
+ }
+ }*/
+
+
if ((node.Stat().Pruned() == centroids.n_cols) &&
(node.Stat().Owner() < centroids.n_cols))
{
@@ -307,6 +361,12 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
allChildrenPruned = false;
}
+ if (node.Stat().StaticPruned() && !allChildrenPruned)
+ {
+ Log::Warn << node;
+ Log::Fatal << "Node is statically pruned but not all its children are!\n";
+ }
+
// If all of the children and points are pruned, we may mark this node as
// pruned.
if (allChildrenPruned && allPointsPruned && !node.Stat().StaticPruned())
diff --git a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
index c87c1da..6eb714d 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
@@ -217,7 +217,8 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
if (adjustedScore < queryNode.Stat().LowerBound())
{
// If this might affect the lower bound, make it more exact.
- queryNode.Stat().LowerBound() = queryNode.MinDistance(&referenceNode);
+ queryNode.Stat().LowerBound() = std::min(queryNode.Stat().LowerBound(),
+ queryNode.MinDistance(&referenceNode));
++scores;
}
@@ -229,7 +230,12 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
if (score != DBL_MAX)
{
// Get minimum and maximum distances.
- math::Range distances = queryNode.RangeDistance(&referenceNode);
+// math::Range distances = queryNode.RangeDistance(&referenceNode);
+ math::Range distances;
+ distances.Lo() = queryNode.MinDistance(&referenceNode);
+ distances.Hi() =
+ queryNode.MaxDistance(centroids.col(referenceNode.Descendant(0)));
+
score = distances.Lo();
++scores;
if (distances.Lo() > queryNode.Stat().UpperBound())
More information about the mlpack-git mailing list