[mlpack-git] master: Prune nodes whose points and children are pruned. This gives significant speedup, and on my little test dataset, this is the fastest algorithm yet that I have created. (536c639)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:04:27 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44
>---------------------------------------------------------------
commit 536c639f8872a53c856c5636b7795edac6e53c44
Author: Ryan Curtin <ryan at ratml.org>
Date: Thu Feb 19 13:45:30 2015 -0500
Prune nodes whose points and children are pruned. This gives significant speedup, and on my little test dataset, this is the fastest algorithm yet that I have created.
>---------------------------------------------------------------
536c639f8872a53c856c5636b7795edac6e53c44
src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 47 +++++++++++++++++++-------
1 file changed, 34 insertions(+), 13 deletions(-)
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 9b70887..8a76d06 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -241,6 +241,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
node.Stat().LowerBound() -= clusterDistances[centroids.n_cols];
}
+ bool allPointsPruned = true;
if (!node.Stat().StaticPruned())
{
// Try to prune individual points.
@@ -248,8 +249,11 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
{
const size_t index = node.Point(i);
if (!visited[index] && !prunedPoints[index])
+ {
+ allPointsPruned = false;
continue; // We didn't visit it and we don't have valid bounds -- so we
// can't prune it.
+ }
if (prunedLastIteration)
{
@@ -287,11 +291,40 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
// Point cannot be pruned.
upperBounds[index] = DBL_MAX;
lowerBounds[index] = DBL_MAX;
+ allPointsPruned = false;
}
}
}
}
- else
+
+ // Recurse into children, and if all the children (and all the points) are
+ // pruned, then we can mark this as statically pruned.
+ bool allChildrenPruned = true;
+ for (size_t i = 0; i < node.NumChildren(); ++i)
+ {
+ UpdateTree(node.Child(i), centroids, interclusterDistances);
+ if (!node.Child(i).Stat().StaticPruned())
+ allChildrenPruned = false;
+ }
+
+ // If all of the children and points are pruned, we may mark this node as
+ // pruned.
+ if (allChildrenPruned && allPointsPruned && !node.Stat().StaticPruned())
+ {
+ node.Stat().StaticPruned() = true;
+ node.Stat().Owner() = centroids.n_cols; // Invalid owner.
+ node.Stat().Pruned() = size_t(-1);
+ }
+
+ if (!node.Stat().StaticPruned())
+ {
+ node.Stat().UpperBound() = DBL_MAX;
+ node.Stat().LowerBound() = DBL_MAX;
+ node.Stat().Pruned() = size_t(-1);
+ node.Stat().Owner() = centroids.n_cols;
+ node.Stat().StaticPruned() = false;
+ }
+ else // The node is now pruned.
{
if (prunedLastIteration)
{
@@ -307,18 +340,6 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
node.Stat().StaticLowerBoundMovement() = 0.0;
}
}
-
- for (size_t i = 0; i < node.NumChildren(); ++i)
- UpdateTree(node.Child(i), centroids, interclusterDistances);
-
- if (!node.Stat().StaticPruned())
- {
- node.Stat().UpperBound() = DBL_MAX;
- node.Stat().LowerBound() = DBL_MAX;
- node.Stat().Pruned() = size_t(-1);
- node.Stat().Owner() = centroids.n_cols;
- node.Stat().StaticPruned() = false;
- }
}
template<typename MetricType, typename MatType, typename TreeType>
More information about the mlpack-git mailing list