[mlpack-git] master: Basic static pruning. Minor speedup. (853b4bf)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:04:52 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44
>---------------------------------------------------------------
commit 853b4bf6e6e44724717a6c83917df4feb6fc53d4
Author: Ryan Curtin <ryan at ratml.org>
Date: Tue Feb 17 16:38:35 2015 -0500
Basic static pruning. Minor speedup.
>---------------------------------------------------------------
853b4bf6e6e44724717a6c83917df4feb6fc53d4
src/mlpack/methods/kmeans/dtnn_kmeans.hpp | 6 +-
src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 98 +++++++++++++++++---------
src/mlpack/methods/kmeans/dtnn_rules_impl.hpp | 3 +
src/mlpack/methods/kmeans/dtnn_statistic.hpp | 9 ++-
4 files changed, 78 insertions(+), 38 deletions(-)
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
index e756669..26c90df 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
@@ -95,13 +95,15 @@ class DTNNKMeans
//! Update the bounds in the tree before the next iteration.
void UpdateTree(TreeType& node,
arma::vec& clusterDistances,
- std::vector<size_t>& oldFromNewCentroids);
+ std::vector<size_t>& oldFromNewCentroids,
+ arma::mat& newCentroids);
//! Extract the centroids of the clusters.
void ExtractCentroids(TreeType& node,
arma::mat& newCentroids,
arma::Col<size_t>& newCounts,
- std::vector<size_t>& oldFromNewCentroids);
+ std::vector<size_t>& oldFromNewCentroids,
+ arma::mat& centroids);
};
//! A template typedef for the DTNNKMeans algorithm with the default tree type
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 08482f8..ddf7c43 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -128,9 +128,8 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
// Now we need to extract the clusters.
newCentroids.zeros(centroids.n_rows, centroids.n_cols);
counts.zeros(centroids.n_cols);
- ExtractCentroids(*tree, newCentroids, counts, oldFromNewCentroids);
- Log::Warn << "New counts: " << counts.t();
- Log::Warn << accu(counts) << ".\n";
+ ExtractCentroids(*tree, newCentroids, counts, oldFromNewCentroids,
+ oldCentroids);
// Now, calculate how far the clusters moved, after normalizing them.
double residual = 0.0;
@@ -160,7 +159,7 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
}
distanceCalculations += centroids.n_cols;
- UpdateTree(*tree, clusterDistances, oldFromNewCentroids);
+ UpdateTree(*tree, clusterDistances, oldFromNewCentroids, newCentroids);
delete centroidTree;
@@ -173,45 +172,46 @@ template<typename MetricType, typename MatType, typename TreeType>
void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
TreeType& node,
arma::vec& clusterDistances,
- std::vector<size_t>& oldFromNewCentroids)
+ std::vector<size_t>& oldFromNewCentroids,
+ arma::mat& newCentroids)
{
- // Simply reset the bounds.
- node.Stat().UpperBound() = DBL_MAX;
- node.Stat().LowerBound() = DBL_MAX;
+ node.Stat().StaticPruned() = false;
+
if ((node.Stat().Pruned() == clusterDistances.n_elem - 1) &&
(node.Stat().Owner() < clusterDistances.n_elem - 1))
{
- const size_t owner = node.Stat().Owner();
-
- node.Stat().LastUpperBound() = node.Stat().UpperBound() +
- clusterDistances[owner];
-
- // Update child bounds, at least a little.
- for (size_t i = 0; i < node.NumChildren(); ++i)
+ // Adjust bounds.
+ node.Stat().UpperBound() += clusterDistances[node.Stat().Owner()];
+ node.Stat().LowerBound() -= clusterDistances[clusterDistances.n_elem - 1];
+ if (node.Stat().UpperBound() < node.Stat().LowerBound())
{
- node.Child(i).Stat().UpperBound() = node.Stat().UpperBound();
- node.Child(i).Stat().LowerBound() = node.Stat().LowerBound();
- node.Child(i).Stat().Owner() = node.Stat().Owner();
- node.Child(i).Stat().Pruned() = node.Stat().Pruned();
+ node.Stat().StaticPruned() = true;
+ }
+ else
+ {
+ // Tighten bound.
+ node.Stat().UpperBound() =
+ node.MaxDistance(newCentroids.col(node.Stat().Owner()));
+ ++distanceCalculations;
+ if (node.Stat().UpperBound() < node.Stat().LowerBound())
+ {
+ node.Stat().StaticPruned() = true;
+ }
}
}
- else if ((node.Stat().Pruned() == clusterDistances.n_elem - 1) &&
- (node.Stat().Owner() >= clusterDistances.n_elem - 1))
- {
- Log::Warn << clusterDistances.n_cols - 1 << ".\n";
- Log::Warn << node;
- Log::Fatal << "Node is pruned, but has no owner!\n";
- }
- else
+
+ if (!node.Stat().StaticPruned())
{
- node.Stat().LastUpperBound() = node.Stat().UpperBound() +
- clusterDistances[clusterDistances.n_elem - 1];
+ node.Stat().UpperBound() = DBL_MAX;
+ node.Stat().LowerBound() = DBL_MAX;
+ node.Stat().Pruned() = 0;
+ node.Stat().Owner() = clusterDistances.n_elem - 1;
+ node.Stat().StaticPruned() = false;
}
- node.Stat().Pruned() = size_t(-1);
- node.Stat().Owner() = size_t(-1);
for (size_t i = 0; i < node.NumChildren(); ++i)
- UpdateTree(node.Child(i), clusterDistances, oldFromNewCentroids);
+ UpdateTree(node.Child(i), clusterDistances, oldFromNewCentroids,
+ newCentroids);
}
template<typename MetricType, typename MatType, typename TreeType>
@@ -219,14 +219,42 @@ void DTNNKMeans<MetricType, MatType, TreeType>::ExtractCentroids(
TreeType& node,
arma::mat& newCentroids,
arma::Col<size_t>& newCounts,
- std::vector<size_t>& oldFromNewCentroids)
+ std::vector<size_t>& oldFromNewCentroids,
+ arma::mat& centroids)
{
// Does this node own points?
- if (node.Stat().Pruned() == newCentroids.n_cols)
+ if ((node.Stat().Pruned() == newCentroids.n_cols) ||
+ (node.Stat().StaticPruned() && node.Stat().Owner() < newCentroids.n_cols))
{
const size_t owner = node.Stat().Owner();
newCentroids.col(owner) += node.Stat().Centroid() * node.NumDescendants();
newCounts[owner] += node.NumDescendants();
+
+ // Perform the sanity check here.
+/*
+ for (size_t i = 0; i < node.NumDescendants(); ++i)
+ {
+ const size_t index = node.Descendant(i);
+ arma::vec trueDistances(centroids.n_cols);
+ for (size_t j = 0; j < centroids.n_cols; ++j)
+ {
+ const double dist = metric.Evaluate(dataset.col(index),
+ centroids.col(j));
+ trueDistances[j] = dist;
+ }
+
+ arma::uword minIndex;
+ const double minDist = trueDistances.min(minIndex);
+ if (size_t(minIndex) != owner)
+ {
+ Log::Warn << trueDistances.t();
+ Log::Fatal << "Point " << index << " of node " << node.Point(0) << "c"
+<< node.NumDescendants() << " has true minimum cluster " << minIndex << " with "
+ << "distance " << minDist << " but node is pruned with upper bound " <<
+node.Stat().UpperBound() << " and owner " << node.Stat().Owner() << ".\n";
+ }
+ }
+*/
}
else
{
@@ -245,7 +273,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::ExtractCentroids(
// The node is not entirely owned by a cluster. Recurse.
for (size_t i = 0; i < node.NumChildren(); ++i)
ExtractCentroids(node.Child(i), newCentroids, newCounts,
- oldFromNewCentroids);
+ oldFromNewCentroids, centroids);
}
}
diff --git a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
index d5268f4..26e549a 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
@@ -87,6 +87,9 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
TreeType& queryNode,
TreeType& referenceNode)
{
+ if (queryNode.Stat().StaticPruned() == true)
+ return DBL_MAX;
+
// Pruned() for the root node must never be set to size_t(-1).
if (queryNode.Stat().Pruned() == size_t(-1))
{
diff --git a/src/mlpack/methods/kmeans/dtnn_statistic.hpp b/src/mlpack/methods/kmeans/dtnn_statistic.hpp
index ee24b23..e549afe 100644
--- a/src/mlpack/methods/kmeans/dtnn_statistic.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_statistic.hpp
@@ -23,6 +23,7 @@ class DTNNStatistic : public
lastUpperBound(DBL_MAX),
owner(size_t(-1)),
pruned(size_t(-1)),
+ staticPruned(false),
centroid()
{
// Nothing to do.
@@ -35,7 +36,8 @@ class DTNNStatistic : public
lowerBound(DBL_MAX),
lastUpperBound(DBL_MAX),
owner(size_t(-1)),
- pruned(size_t(-1))
+ pruned(size_t(-1)),
+ staticPruned(false)
{
// Empirically calculate the centroid.
centroid.zeros(node.Dataset().n_rows);
@@ -67,6 +69,9 @@ class DTNNStatistic : public
size_t Pruned() const { return pruned; }
size_t& Pruned() { return pruned; }
+ bool StaticPruned() const { return staticPruned; }
+ bool& StaticPruned() { return staticPruned; }
+
std::string ToString() const
{
std::ostringstream o;
@@ -75,6 +80,7 @@ class DTNNStatistic : public
o << " Lower bound: " << lowerBound << ".\n";
o << " Last upper bound: " << lastUpperBound << ".\n";
o << " Pruned: " << pruned << ".\n";
+ o << " Static pruned: " << staticPruned << ".\n";
o << " Owner: " << owner << ".\n";
return o.str();
}
@@ -85,6 +91,7 @@ class DTNNStatistic : public
double lastUpperBound;
size_t owner;
size_t pruned;
+ bool staticPruned;
arma::vec centroid;
};
More information about the mlpack-git mailing list