[mlpack-git] master: Refactoring, and tighten a bound for minor speedup. (820ba74)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:03:36 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44
>---------------------------------------------------------------
commit 820ba7490a2a244aae6b9e27b34492bf636b1a7a
Author: Ryan Curtin <ryan at ratml.org>
Date: Tue Feb 3 21:22:54 2015 -0500
Refactoring, and tighten a bound for minor speedup.
>---------------------------------------------------------------
820ba7490a2a244aae6b9e27b34492bf636b1a7a
src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 204 ++++---------------------
1 file changed, 26 insertions(+), 178 deletions(-)
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 842feae..40dec27 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -304,105 +304,29 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
node.Stat().SecondClusterBound())
node.Stat().SecondClusterBound() = newSecondClusterBound;
- // Sanity check: ensure the owner is right.
-/*
- for (size_t i = 0; i < node.NumPoints(); ++i)
- {
- arma::vec dists(centroids.n_cols);
- size_t trueOwner = centroids.n_cols;
- double trueDist = DBL_MAX;
- for (size_t j = 0; j < centroids.n_cols; ++j)
- {
- const double dist = metric.Evaluate(dataset.col(node.Point(i)),
- lastIterationCentroids.col(j));
- dists(j) = dist;
- if (dist < trueDist)
- {
- trueDist = dist;
- trueOwner = j;
- }
- }
-
- if (trueOwner != owner)
- {
- Log::Warn << node << "...\n" << *node.Parent();
- Log::Warn << dists.t();
- Log::Warn << "Assignment: " << assignments(0, node.Point(i)) << ".\n";
- Log::Warn << "Dists: " << distances(0, node.Point(i)) << ", " <<
-distances(1, node.Point(i)) << ".\n";
-// TreeType* n = node.Parent()->Parent();
-// while (n != NULL)
-// {
-// Log::Warn << "...\n" << *n;
-// n = n->Parent();
-// }
- Log::Fatal << "Point " << node.Point(i) << " was assigned to owner "
- << owner << " but has true owner " << trueOwner << "! [" <<
-assignments(0, node.Point(i)) << " -- " <<
-metric.Evaluate(dataset.col(node.Point(i)),
-centroids.col(assignments(0, node.Point(i)))) << "] " <<
-distances(0, node.Point(i)) << " " <<
-assignments(0, node.Point(i)) << " " <<
-assignments(0, node.Point(i - 1)) << ".\n";
- }
- }
-*/
-
- if (node.NumPoints() == 0 && childrenPruned)
- {
- // Pruned because its children are all pruned.
+ // Convenience variables to clean up the expressions.
+ const double mcd = node.Stat().MaxClusterDistance();
+ const double scb = node.Stat().SecondClusterBound();
+ const double ownerMovement = clusterDistances[owner];
+ const double maxMovement = clusterDistances[centroids.n_cols];
+ const double closestClusterDistance =
+ interclusterDistances[newFromOldCentroids[owner]];
+ if ((node.NumPoints() == 0 && childrenPruned) ||
+ (mcd + ownerMovement < scb - maxMovement) ||
+ (mcd < 0.5 * closestClusterDistance))
node.Stat().Pruned() = true;
- }
- // What is the maximum distance to the closest cluster in the node?
- else if (node.Stat().MaxClusterDistance() +
- clusterDistances[node.Stat().Owner()] <
- node.Stat().SecondClusterBound() - clusterDistances[centroids.n_cols])
- {
- node.Stat().Pruned() = true;
- }
- else
- {
- // Also do between-cluster prune.
- if (node.Stat().MaxClusterDistance() < 0.5 *
- interclusterDistances[newFromOldCentroids[owner]])
- {
- node.Stat().Pruned() = true;
- }
- }
- // Adjust for next iteration.
- node.Stat().MaxClusterDistance() +=
- clusterDistances[node.Stat().Owner()];
- node.Stat().SecondClusterBound() -= clusterDistances[centroids.n_cols];
+ // Adjust bounds for next iteration, regardless of whether or not the node
+ // was pruned. (Does this adjustment need to happen if there is no prune?
+ node.Stat().MaxClusterDistance() += ownerMovement;
+ node.Stat().SecondClusterBound() -= maxMovement;
}
- else
+ else if (childrenPruned && node.NumChildren() > 0 && node.NumPoints() == 0)
{
// The node isn't owned by a single cluster. But if it has no points and
// its children are all pruned, we may prune it too.
- if (childrenPruned && node.NumChildren() > 0)
- {
-// Log::Warn << "Prune parent node " << node.Point(0) << "c" <<
-//node.NumDescendants() << ".\n";
- node.Stat().Pruned() = true;
- node.Stat().Owner() = centroids.n_cols;
- }
-// if (node.NumChildren() > 0)
-// if (node.Child(0).Stat().Pruned() && !node.Child(1).Stat().Pruned())
-// Log::Warn << "Node left child pruned but right child not:\n" <<
-//node.Child(0) << ", r\n" << node.Child(1) << ", this:\n" << node;
-// if (node.NumChildren() > 0)
-// if (node.Child(1).Stat().Pruned() && !node.Child(0).Stat().Pruned())
-// Log::Warn << "Node right child pruned but left child not:\n" <<
-//node.Child(0) << ", r\n" << node.Child(1) << ", this:\n" << node;
-// if (node.NumChildren() > 0)
-// Log::Warn << "Node has more than 0 children: " << node << ".l\n" <<
-//node.Child(0) << ", r\n" << node.Child(1) << ".\n";
-
- // Adjust the bounds for next iteration.
-// node.Stat().MaxClusterDistance() += clusterDistances[centroids.n_cols];
-// node.Stat().SecondClusterBound() = std::max(0.0,
-// node.Stat().SecondClusterBound() -
-// clusterDistances[centroids.n_cols]);
+ node.Stat().Pruned() = true;
+ node.Stat().Owner() = centroids.n_cols;
}
}
else if (node.Stat().Pruned())
@@ -447,37 +371,6 @@ assignments(0, node.Point(i - 1)) << ".\n";
}
}
}
-/*
- if (node.Stat().Pruned() && node.Stat().Owner() != centroids.n_cols)
- {
- for (size_t i = 0; i < node.NumPoints(); ++i)
- {
- size_t trueOwner = 0;
- double ownerDist = DBL_MAX;
- arma::vec distances(centroids.n_cols);
- for (size_t j = 0; j < centroids.n_cols; ++j)
- {
- const double dist = metric.Evaluate(dataset.col(node.Point(i)),
- lastIterationCentroids.col(j));
- distances(j) = dist;
- if (dist < ownerDist)
- {
- trueOwner = j;
- ownerDist = dist;
- }
- }
-
- if (trueOwner != node.Stat().Owner())
- {
- Log::Warn << node << "...\n" << *node.Parent();
- Log::Warn << distances.t();
- Log::Fatal << "Point " << node.Point(i) << " was assigned to owner "
- << node.Stat().Owner() << " but has true owner " << trueOwner <<
-"!\n";
- }
- }
- }
-*/
}
else
{
@@ -512,47 +405,12 @@ assignments(0, node.Point(i - 1)) << ".\n";
const double upperPointBound = distances(0, index) +
clusterDistances[owner];
- if (distances(0, index) + clusterDistances[owner] <
- lowerSecondBounds[index] - clusterDistances[centroids.n_cols])
- {
-/*
- // Sanity check.
- size_t trueOwner;
- double trueDist = DBL_MAX;
- arma::vec distances(centroids.n_cols);
- for (size_t j = 0; j < centroids.n_cols; ++j)
- {
- const double dist = metric.Evaluate(lastIterationCentroids.col(j),
- dataset.col(index));
- distances(j) = dist;
- if (dist < trueDist)
- {
- trueOwner = j;
- trueDist = dist;
- }
- }
-
- if (trueOwner != owner)
- {
- Log::Warn << "Point " << index << ", ub " << distances(0, index) << ","
- << " lb " << lowerSecondBounds[index] << ", pruned " <<
-prunedPoints[index] << ", lastOwner " << lastOwners[index] << ": invalid "
-"owner!\n";
- Log::Warn << distances.t();
- Log::Fatal << "Assigned owner " << owner << " but true owner is "
- << trueOwner << "!\n";
- }
-*/
- prunedPoints[index] = true;
- distances(0, index) += clusterDistances[owner];
- lastOwners[index] = owner;
- distances(1, index) += clusterDistances[centroids.n_cols];
- lowerSecondBounds[index] -= clusterDistances[centroids.n_cols];
- prunedCentroids.col(owner) += dataset.col(index);
- prunedCounts(owner)++;
- }
- else if (distances(0, index) + clusterDistances[owner] < 0.5 *
- interclusterDistances[newFromOldCentroids[owner]])
+ const double lowerSecondBound = lowerSecondBounds[index] -
+ clusterDistances[centroids.n_cols];
+ const double closestClusterDistance =
+ interclusterDistances[newFromOldCentroids[owner]];
+ if ((upperPointBound < lowerSecondBound) ||
+ (upperPointBound < 0.5 * closestClusterDistance))
{
prunedPoints[index] = true;
distances(0, index) += clusterDistances[owner];
@@ -568,18 +426,9 @@ prunedPoints[index] << ", lastOwner " << lastOwners[index] << ": invalid "
distances(0, index) = metric.Evaluate(centroids.col(owner),
dataset.col(index));
++distanceCalculations;
- if (distances(0, index) < lowerSecondBounds[index] -
- clusterDistances[centroids.n_cols])
- {
- prunedPoints[index] = true;
- lastOwners[index] = owner;
- lowerSecondBounds[index] -= clusterDistances[centroids.n_cols];
- distances(1, index) += clusterDistances[centroids.n_cols];
- prunedCentroids.col(owner) += dataset.col(index);
- prunedCounts(owner)++;
- }
- else if (distances(0, index) < 0.5 *
- interclusterDistances[newFromOldCentroids[owner]])
+
+ if ((distances(0, index) < lowerSecondBound) ||
+ (distances(0, index) < 0.5 * closestClusterDistance))
{
prunedPoints[index] = true;
lastOwners[index] = owner;
@@ -593,7 +442,6 @@ prunedPoints[index] << ", lastOwner " << lastOwners[index] << ": invalid "
prunedPoints[index] = false;
allPruned = false;
// Still update these anyway.
- distances(0, index) += clusterDistances[owner];
distances(1, index) += clusterDistances[centroids.n_cols];
}
}
More information about the mlpack-git mailing list