[mlpack-git] master: Add static point prunes. Fairly significant runtime improvement. (99b8d56)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:04:31 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44
>---------------------------------------------------------------
commit 99b8d56dc1d29bea6096648e92e88782d002eb4d
Author: Ryan Curtin <ryan at ratml.org>
Date: Tue Feb 17 18:34:02 2015 -0500
Add static point prunes. Fairly significant runtime improvement.
>---------------------------------------------------------------
99b8d56dc1d29bea6096648e92e88782d002eb4d
src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 109 +++++++++++++++++++++++--
src/mlpack/tests/kmeans_test.cpp | 12 ++-
2 files changed, 109 insertions(+), 12 deletions(-)
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 4f5377c..5b43bc6 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -69,6 +69,12 @@ DTNNKMeans<MetricType, MatType, TreeType>::DTNNKMeans(const MatType& dataset,
tree = new TreeType(const_cast<typename TreeType::Mat&>(this->dataset));
Timer::Stop("tree_building");
+
+ for (size_t i = 0; i < dataset.n_cols; ++i)
+ prunedPoints[i] = false;
+ assignments.fill(size_t(-1));
+ upperBounds.fill(DBL_MAX);
+ lowerBounds.fill(DBL_MAX);
}
template<typename MetricType, typename MatType, typename TreeType>
@@ -86,13 +92,8 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
arma::Col<size_t>& counts)
{
// Reset information.
- upperBounds.fill(DBL_MAX);
- lowerBounds.fill(DBL_MAX);
for (size_t i = 0; i < dataset.n_cols; ++i)
- {
- prunedPoints[i] = false;
visited[i] = false;
- }
// Build a tree on the centroids.
arma::mat oldCentroids(centroids); // Slow. :(
@@ -177,6 +178,16 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
{
node.Stat().StaticPruned() = false;
+ // Grab information from the parent, if we can.
+ if (node.Parent() != NULL &&
+ node.Parent()->Stat().Pruned() == clusterDistances.n_elem - 1)
+ {
+ node.Stat().UpperBound() = node.Parent()->Stat().UpperBound();
+ node.Stat().LowerBound() = node.Parent()->Stat().LowerBound();
+ node.Stat().Pruned() = node.Parent()->Stat().Pruned();
+ node.Stat().Owner() = node.Parent()->Stat().Owner();
+ }
+
if ((node.Stat().Pruned() == clusterDistances.n_elem - 1) &&
(node.Stat().Owner() < clusterDistances.n_elem - 1))
{
@@ -199,6 +210,65 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
}
}
}
+ else
+ {
+ node.Stat().LowerBound() -= clusterDistances[clusterDistances.n_elem - 1];
+ }
+
+ if (!node.Stat().StaticPruned())
+ {
+ // Try to prune individual points.
+ for (size_t i = 0; i < node.NumPoints(); ++i)
+ {
+ const size_t index = node.Point(i);
+ if (!visited[index] && !prunedPoints[index])
+ continue; // We didn't visit it and we don't have valid bounds -- so we
+ // can't prune it.
+
+ prunedPoints[index] = false;
+ const size_t owner = assignments[node.Point(i)];
+ const double lowerBound = std::min(lowerBounds[index] -
+ clusterDistances[newCentroids.n_cols], node.Stat().LowerBound());
+ if (upperBounds[index] + clusterDistances[owner] < lowerBound)
+ {
+ prunedPoints[index] = true;
+ upperBounds[index] += clusterDistances[owner];
+ lowerBounds[index] = lowerBound;
+ }
+ else
+ {
+ // Attempt to tighten the bound.
+ upperBounds[index] = metric.Evaluate(dataset.col(index),
+ newCentroids.col(owner));
+ ++distanceCalculations;
+ if (upperBounds[index] < lowerBound)
+ {
+ prunedPoints[index] = true;
+ lowerBounds[index] = lowerBound;
+ }
+ else
+ {
+ // Point cannot be pruned.
+ upperBounds[index] = DBL_MAX;
+ lowerBounds[index] = DBL_MAX;
+ }
+ }
+ }
+ }
+ else
+ {
+ // Adjust bounds for individual points.
+ for (size_t i = 0; i < node.NumDescendants(); ++i)
+ {
+ upperBounds[node.Descendant(i)] += clusterDistances[node.Stat().Owner()];
+ lowerBounds[node.Descendant(i)] -=
+ clusterDistances[newCentroids.n_cols - 1];
+ }
+ }
+
+ for (size_t i = 0; i < node.NumChildren(); ++i)
+ UpdateTree(node.Child(i), clusterDistances, oldFromNewCentroids,
+ newCentroids);
if (!node.Stat().StaticPruned())
{
@@ -208,10 +278,6 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
node.Stat().Owner() = clusterDistances.n_elem - 1;
node.Stat().StaticPruned() = false;
}
-
- for (size_t i = 0; i < node.NumChildren(); ++i)
- UpdateTree(node.Child(i), clusterDistances, oldFromNewCentroids,
- newCentroids);
}
template<typename MetricType, typename MatType, typename TreeType>
@@ -267,6 +333,31 @@ node.Stat().UpperBound() << " and owner " << node.Stat().Owner() << ".\n";
const size_t owner = assignments[node.Point(i)];
newCentroids.col(owner) += dataset.col(node.Point(i));
++newCounts[owner];
+
+/*
+ const size_t index = node.Point(i);
+ arma::vec trueDistances(centroids.n_cols);
+ for (size_t j = 0; j < centroids.n_cols; ++j)
+ {
+ const double dist = metric.Evaluate(dataset.col(index),
+ centroids.col(j));
+ trueDistances[j] = dist;
+ }
+
+ arma::uword minIndex;
+ const double minDist = trueDistances.min(minIndex);
+ if (size_t(minIndex) != owner)
+ {
+ Log::Warn << trueDistances.t();
+ Log::Fatal << "Point " << index << " of node " << node.Point(0) << "c"
+ << node.NumDescendants() << " has true minimum cluster " << minIndex << " with "
+ << "distance " << minDist << " but was assigned to cluster " <<
+assignments[node.Point(0)] << " with ub " << upperBounds[node.Point(0)] <<
+" and lb " << lowerBounds[node.Point(0)] << "; pp " <<
+(prunedPoints[node.Point(0)] ? "true" : "false") << ", visited " << (visited[node.Point(0)] ? "true"
+: "false") << ".\n";
+ }
+*/
}
}
diff --git a/src/mlpack/tests/kmeans_test.cpp b/src/mlpack/tests/kmeans_test.cpp
index d274e49..562589d 100644
--- a/src/mlpack/tests/kmeans_test.cpp
+++ b/src/mlpack/tests/kmeans_test.cpp
@@ -711,7 +711,9 @@ BOOST_AUTO_TEST_CASE(DualTreeKMeansBaseCaseTest)
upperBounds.fill(DBL_MAX);
lowerBounds.fill(DBL_MAX);
std::vector<bool> visited(points, false); // Fill with false.
- std::vector<size_t> oldFromNewCentroids; // Not used.
+ std::vector<size_t> oldFromNewCentroids(clusters);
+ for (size_t i = 0; i < clusters; ++i)
+ oldFromNewCentroids[i] = i;
std::vector<bool> prunedPoints(points, false); // Fill with false.
EuclideanDistance e;
@@ -768,7 +770,9 @@ BOOST_AUTO_TEST_CASE(DualTreeKMeansScoreKDTreeOneLeafTest)
upperBounds.fill(DBL_MAX);
lowerBounds.fill(DBL_MAX);
std::vector<bool> visited(points, false); // Fill with false.
- std::vector<size_t> oldFromNewCentroids; // Not used.
+ std::vector<size_t> oldFromNewCentroids(clusters);
+ for (size_t i = 0; i < clusters; ++i)
+ oldFromNewCentroids[i] = i;
std::vector<bool> prunedPoints(points, false); // Fill with false.
EuclideanDistance e;
@@ -876,7 +880,9 @@ BOOST_AUTO_TEST_CASE(DualTreeKMeansScoreKDTreeTest)
upperBounds.fill(DBL_MAX);
lowerBounds.fill(DBL_MAX);
std::vector<bool> visited(points, false); // Fill with false.
- std::vector<size_t> oldFromNewCentroids; // Not used.
+ std::vector<size_t> oldFromNewCentroids(clusters);
+ for (size_t i = 0; i < clusters; ++i)
+ oldFromNewCentroids[i] = i;
std::vector<bool> prunedPoints(points, false); // Fill with false.
EuclideanDistance e;
More information about the mlpack-git mailing list