[mlpack-git] master: Refactoring to reduce runtime of tree update. It speeds things up in terms of distance computations too, a bit. (72f2ddd)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:03:17 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44
>---------------------------------------------------------------
commit 72f2ddd51921b3417a4c264e396438aaf8330d13
Author: Ryan Curtin <ryan at ratml.org>
Date: Tue Feb 3 20:36:29 2015 -0500
Refactoring to reduce runtime of tree update. It also slightly reduces the number of distance computations.
>---------------------------------------------------------------
72f2ddd51921b3417a4c264e396438aaf8330d13
src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 55 +++++++++++++-------------
1 file changed, 28 insertions(+), 27 deletions(-)
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 9d50113..842feae 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -250,7 +250,10 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
if (owner == centroids.n_cols + 1)
owner = c;
else if (owner != c)
+ {
singleOwner = false;
+ break;
+ }
// Update maximum cluster distance and second cluster bound.
if (!prunedPoints[node.Point(i)])
@@ -274,10 +277,12 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
{
if (owner == centroids.n_cols + 1)
owner = node.Child(i).Stat().Owner();
- else if (node.Child(i).Stat().Owner() == centroids.n_cols)
- singleOwner = false;
- else if (owner != node.Child(i).Stat().Owner())
+ else if ((node.Child(i).Stat().Owner() == centroids.n_cols) ||
+ (owner != node.Child(i).Stat().Owner()))
+ {
singleOwner = false;
+ break;
+ }
// Update maximum cluster distance and second cluster bound.
if (node.Child(i).Stat().MaxClusterDistance() > newMaxClusterDistance)
@@ -286,19 +291,19 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
newSecondClusterBound = node.Child(i).Stat().SecondClusterBound();
}
- // What do we do with the new cluster bounds?
- if (newMaxClusterDistance > 0.0 && newMaxClusterDistance <
- node.Stat().MaxClusterDistance())
- node.Stat().MaxClusterDistance() = newMaxClusterDistance;
- if (newSecondClusterBound != DBL_MAX && newSecondClusterBound >
- node.Stat().SecondClusterBound())
- node.Stat().SecondClusterBound() = newSecondClusterBound;
-
// Okay, now we know if it's owned or not, and by which cluster.
if (singleOwner)
{
node.Stat().Owner() = owner;
+ // What do we do with the new cluster bounds?
+ if (newMaxClusterDistance > 0.0 && newMaxClusterDistance <
+ node.Stat().MaxClusterDistance())
+ node.Stat().MaxClusterDistance() = newMaxClusterDistance;
+ if (newSecondClusterBound != DBL_MAX && newSecondClusterBound >
+ node.Stat().SecondClusterBound())
+ node.Stat().SecondClusterBound() = newSecondClusterBound;
+
// Sanity check: ensure the owner is right.
/*
for (size_t i = 0; i < node.NumPoints(); ++i)
@@ -394,10 +399,10 @@ assignments(0, node.Point(i - 1)) << ".\n";
//node.Child(0) << ", r\n" << node.Child(1) << ".\n";
// Adjust the bounds for next iteration.
- node.Stat().MaxClusterDistance() += clusterDistances[centroids.n_cols];
- node.Stat().SecondClusterBound() = std::max(0.0,
- node.Stat().SecondClusterBound() -
- clusterDistances[centroids.n_cols]);
+// node.Stat().MaxClusterDistance() += clusterDistances[centroids.n_cols];
+// node.Stat().SecondClusterBound() = std::max(0.0,
+// node.Stat().SecondClusterBound() -
+// clusterDistances[centroids.n_cols]);
}
}
else if (node.Stat().Pruned())
@@ -496,21 +501,17 @@ assignments(0, node.Point(i - 1)) << ".\n";
{
const size_t index = node.Point(i);
size_t owner;
- if (!prunedLastIteration && !prunedPoints[index])
- {
- owner = assignments(0, index);
- // Establish bounds, since these points were searched this iteration.
- lowerSecondBounds[index] = distances(1, index);
- }
- else if (prunedLastIteration && node.Stat().Owner() < centroids.n_cols)
- {
+ if (prunedLastIteration && node.Stat().Owner() < centroids.n_cols)
owner = node.Stat().Owner();
- }
else
- {
- owner = lastOwners[index];
- }
+ owner = assignments(0, index);
+
+ // Update lower bound, if possible.
+ if (!prunedLastIteration && !prunedPoints[index])
+ lowerSecondBounds[index] = distances(1, index);
+ const double upperPointBound = distances(0, index) +
+ clusterDistances[owner];
if (distances(0, index) + clusterDistances[owner] <
lowerSecondBounds[index] - clusterDistances[centroids.n_cols])
{
More information about the mlpack-git mailing list