[mlpack-git] master: Refactor UpdateTree() to sometimes Hamerly prune. We aren't properly retaining pruned nodes between iterations, but this is definitely a start and it's basically as fast as any of these attempted algorithms I've written. (29a7f5f)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:03:58 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44
>---------------------------------------------------------------
commit 29a7f5f2dff3a9822e6ee9bdac4f6f60bdbc2772
Author: Ryan Curtin <ryan at ratml.org>
Date: Sat Jan 31 14:08:26 2015 -0500
Refactor UpdateTree() to sometimes Hamerly prune. We aren't properly retaining pruned nodes between iterations, but this is definitely a start and it's basically as fast as any of these attempted algorithms I've written.
>---------------------------------------------------------------
29a7f5f2dff3a9822e6ee9bdac4f6f60bdbc2772
src/mlpack/methods/kmeans/dtnn_kmeans.hpp | 23 ++-
src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 207 ++++++++++++++++++++++---
src/mlpack/methods/kmeans/dtnn_rules_impl.hpp | 3 +
src/mlpack/methods/kmeans/dtnn_statistic.hpp | 53 ++++++-
4 files changed, 257 insertions(+), 29 deletions(-)
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
index fddca15..e655e32 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
@@ -14,6 +14,8 @@
#include <mlpack/methods/neighbor_search/neighbor_search.hpp>
#include <mlpack/core/tree/cover_tree.hpp>
+#include "dtnn_statistic.hpp"
+
namespace mlpack {
namespace kmeans {
@@ -28,7 +30,7 @@ template<
typename MetricType,
typename MatType,
typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>,
- neighbor::NeighborSearchStat<neighbor::NearestNeighborSort> > >
+ DTNNStatistic> >
class DTNNKMeans
{
public:
@@ -74,9 +76,24 @@ class DTNNKMeans
//! Track distance calculations.
size_t distanceCalculations;
+ //! Track iteration number.
+ size_t iteration;
+
+ //! Centroids from pruning. Not normalized.
+ arma::mat prunedCentroids;
+ //! Counts from pruning. Not normalized.
+ arma::Col<size_t> prunedCounts;
//! Update the bounds in the tree before the next iteration.
- void UpdateTree(TreeType& node, const double tolerance);
+ void UpdateTree(TreeType& node,
+ const double tolerance,
+ const arma::mat& centroids,
+ const arma::Mat<size_t>& assignments,
+ const arma::mat& distances,
+ const arma::mat& clusterDistances,
+ const std::vector<size_t>& oldFromNewCentroids);
+
+ void PrecalculateCentroids(TreeType& node);
};
//! A template typedef for the DTNNKMeans algorithm with the default tree type
@@ -88,7 +105,7 @@ using DefaultDTNNKMeans = DTNNKMeans<MetricType, MatType>;
template<typename MetricType, typename MatType>
using CoverTreeDTNNKMeans = DTNNKMeans<MetricType, MatType,
tree::CoverTree<metric::EuclideanDistance, tree::FirstPointIsRoot,
- neighbor::NeighborSearchStat<neighbor::NearestNeighborSort> > >;
+ DTNNStatistic> >;
} // namespace kmeans
} // namespace mlpack
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 6112ca7..9aa0be1 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -41,7 +41,7 @@ TreeType* BuildTree(
tree::TreeTraits<TreeType>::RearrangesDataset == false, TreeType*
>::type = 0)
{
- return new TreeType(dataset);
+ return new TreeType(dataset, 1);
}
template<typename MetricType, typename MatType, typename TreeType>
@@ -51,7 +51,8 @@ DTNNKMeans<MetricType, MatType, TreeType>::DTNNKMeans(const MatType& dataset,
dataset(tree::TreeTraits<TreeType>::RearrangesDataset ? datasetCopy :
datasetOrig),
metric(metric),
- distanceCalculations(0)
+ distanceCalculations(0),
+ iteration(0)
{
Timer::Start("tree_building");
@@ -79,18 +80,25 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
arma::mat& newCentroids,
arma::Col<size_t>& counts)
{
+ if (iteration == 0)
+ {
+ prunedCentroids.zeros(centroids.n_rows, centroids.n_cols);
+ prunedCounts.zeros(centroids.n_cols);
+ }
+
newCentroids.zeros(centroids.n_rows, centroids.n_cols);
counts.zeros(centroids.n_cols);
// Build a tree on the centroids.
+ arma::mat oldCentroids(centroids); // Slow. :(
std::vector<size_t> oldFromNewCentroids;
TreeType* centroidTree = BuildTree<TreeType>(
const_cast<typename TreeType::Mat&>(centroids), oldFromNewCentroids);
// We won't use the AllkNN class here because we have our own set of rules.
// This is a lot of overhead. We don't need the distances.
- arma::mat distances(5, dataset.n_cols);
- arma::Mat<size_t> assignments(5, dataset.n_cols);
+ arma::mat distances(2, dataset.n_cols);
+ arma::Mat<size_t> assignments(2, dataset.n_cols);
distances.fill(DBL_MAX);
assignments.fill(size_t(-1));
typedef DTNNKMeansRules<MetricType, TreeType> RuleType;
@@ -101,27 +109,36 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
traverser.Traverse(*tree, *centroidTree);
+ Log::Info << "This iteration: " << rules.BaseCases() << " base cases, " <<
+ rules.Scores() << " scores.\n";
distanceCalculations += rules.BaseCases() + rules.Scores();
// From the assignments, calculate the new centroids and counts.
for (size_t i = 0; i < dataset.n_cols; ++i)
{
- if (tree::TreeTraits<TreeType>::RearrangesDataset)
- {
- newCentroids.col(oldFromNewCentroids[assignments(0, i)]) +=
- dataset.col(i);
- ++counts(oldFromNewCentroids[assignments(0, i)]);
- }
- else
+ if (assignments(0, i) != size_t(-1))
{
- newCentroids.col(assignments(0, i)) += dataset.col(i);
- ++counts(assignments(0, i));
+ if (tree::TreeTraits<TreeType>::RearrangesDataset)
+ {
+ newCentroids.col(oldFromNewCentroids[assignments(0, i)]) +=
+ dataset.col(i);
+ ++counts(oldFromNewCentroids[assignments(0, i)]);
+ }
+ else
+ {
+ newCentroids.col(assignments(0, i)) += dataset.col(i);
+ ++counts(assignments(0, i));
+ }
}
}
+ newCentroids += prunedCentroids;
+ counts += prunedCounts;
+
// Now, calculate how far the clusters moved, after normalizing them.
double residual = 0.0;
double maxMovement = 0.0;
+ arma::vec clusterDistances(centroids.n_cols + 1);
for (size_t c = 0; c < centroids.n_cols; ++c)
{
// Get the mapping to the old cluster, if necessary.
@@ -130,41 +147,187 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
if (counts[old] == 0)
{
newCentroids.col(old).fill(DBL_MAX);
+ clusterDistances[old] = 0;
}
else
{
newCentroids.col(old) /= counts(old);
const double movement = metric.Evaluate(centroids.col(c),
newCentroids.col(old));
+ clusterDistances[old] = movement;
residual += std::pow(movement, 2.0);
if (movement > maxMovement)
maxMovement = movement;
}
}
+ clusterDistances[centroids.n_cols] = maxMovement;
+ Log::Warn << clusterDistances.t();
distanceCalculations += centroids.n_cols;
- UpdateTree(*tree, maxMovement);
+ UpdateTree(*tree, maxMovement, oldCentroids, assignments, distances,
+ clusterDistances, oldFromNewCentroids);
+
+ // Reset centroids and counts for things we will collect during pruning.
+ prunedCentroids.zeros(centroids.n_rows, centroids.n_cols);
+ prunedCounts.zeros(centroids.n_cols);
+ PrecalculateCentroids(*tree);
delete centroidTree;
+ ++iteration;
+
return std::sqrt(residual);
}
template<typename MetricType, typename MatType, typename TreeType>
void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
TreeType& node,
- const double tolerance)
+ const double tolerance,
+ const arma::mat& centroids,
+ const arma::Mat<size_t>& assignments,
+ const arma::mat& distances,
+ const arma::mat& clusterDistances,
+ const std::vector<size_t>& oldFromNewCentroids)
{
- if (node.Stat().FirstBound() != DBL_MAX)
- node.Stat().FirstBound() += tolerance;
- if (node.Stat().SecondBound() != DBL_MAX)
- node.Stat().SecondBound() += tolerance;
- if (node.Stat().Bound() != DBL_MAX)
- node.Stat().Bound() += tolerance;
+ // Update iteration.
+// node.Stat().Iteration() = iteration;
+
+ if (node.Stat().Owner() == size_t(-1))
+ node.Stat().Owner() = centroids.n_cols;
+ // Do the tree update in a depth-first manner: leaves first.
+ bool childrenPruned = true;
for (size_t i = 0; i < node.NumChildren(); ++i)
- UpdateTree(node.Child(i), tolerance);
+ {
+ UpdateTree(node.Child(i), tolerance, centroids, assignments, distances,
+ clusterDistances, oldFromNewCentroids);
+ if (!node.Child(i).Stat().Pruned())
+ childrenPruned = false; // Not all children are pruned.
+ }
+
+ // Does the node have a single owner?
+ // It would be nice if we could do this during the traversal.
+ bool singleOwner = true;
+ size_t owner = centroids.n_cols + 1;
+ node.Stat().MaxClusterDistance() = 0.0;
+ node.Stat().SecondClusterBound() = DBL_MAX;
+ if (!node.Stat().Pruned() && childrenPruned)
+ {
+ for (size_t i = 0; i < node.NumPoints(); ++i)
+ {
+ // Don't forget to map back from the new cluster index.
+ if (owner == centroids.n_cols + 1)
+ owner = (tree::TreeTraits<TreeType>::RearrangesDataset) ?
+ oldFromNewCentroids[assignments(0, node.Point(i))] :
+ oldFromNewCentroids[assignments(0, node.Point(i))];
+ else if (owner != oldFromNewCentroids[assignments(0, node.Point(i))])
+ singleOwner = false;
+
+ // Update maximum cluster distance and second cluster bound.
+ if (distances(0, node.Point(i)) > node.Stat().MaxClusterDistance())
+ node.Stat().MaxClusterDistance() = distances(0, node.Point(i));
+ if (distances(1, node.Point(i)) < node.Stat().SecondClusterBound())
+ node.Stat().SecondClusterBound() = distances(1, node.Point(i));
+ }
+
+ for (size_t i = 0; i < node.NumChildren(); ++i)
+ {
+ if (owner == centroids.n_cols + 1)
+ owner = node.Child(i).Stat().Owner();
+ else if (node.Child(i).Stat().Owner() == centroids.n_cols)
+ singleOwner = false;
+ else if (owner != node.Child(i).Stat().Owner())
+ singleOwner = false;
+
+ // Update maximum cluster distance and second cluster bound.
+ if (node.Child(i).Stat().MaxClusterDistance() >
+ node.Stat().MaxClusterDistance())
+ node.Stat().MaxClusterDistance() =
+ node.Child(i).Stat().MaxClusterDistance();
+ if (node.Child(i).Stat().SecondClusterBound() <
+ node.Stat().SecondClusterBound())
+ node.Stat().SecondClusterBound() =
+ node.Child(i).Stat().SecondClusterBound();
+ }
+
+ // Okay, now we know if it's owned or not, and by which cluster.
+ if (singleOwner)
+ {
+ node.Stat().Owner() = owner;
+
+ // Sanity check: ensure the owner is right.
+ for (size_t i = 0; i < node.NumPoints(); ++i)
+ {
+ const double ownerDist = metric.Evaluate(dataset.col(node.Point(i)),
+ centroids.col(owner));
+ for (size_t j = 0; j < centroids.n_cols; ++j)
+ {
+ const double dist = metric.Evaluate(dataset.col(node.Point(i)),
+ centroids.col(j));
+ if (dist < ownerDist)
+ {
+ Log::Warn << node << "...\n" << *node.Parent();
+ Log::Fatal << "Point " << node.Point(i) << " was assigned to owner "
+ << owner << " but has true owner " << j << "! [" <<
+oldFromNewCentroids[assignments(0, node.Point(i))] << " -- " <<
+metric.Evaluate(dataset.col(node.Point(i)),
+centroids.col(oldFromNewCentroids[assignments(0, node.Point(i))])) << "] " <<
+distances(0, node.Point(i)) << " " <<
+oldFromNewCentroids[assignments(0, node.Point(i))] << " " <<
+oldFromNewCentroids[assignments(0, node.Point(i - 1))] << ".\n";
+ }
+ }
+ }
+
+ // What is the maximum distance to the closest cluster in the node?
+ if (node.Stat().MaxClusterDistance() +
+ clusterDistances[node.Stat().Owner()] <
+ node.Stat().SecondClusterBound() - clusterDistances[centroids.n_cols])
+ {
+ node.Stat().Pruned() = true;
+ }
+ }
+ }
+ else if (node.Stat().Pruned())
+ {
+ // The node was pruned last iteration. See if the node can remain pruned.
+ singleOwner = false;
+
+ node.Stat().Pruned() = false;
+ node.Stat().FirstBound() = DBL_MAX;
+ node.Stat().SecondBound() = DBL_MAX;
+ node.Stat().Bound() = DBL_MAX;
+ }
+ else
+ {
+ // The children haven't been pruned, so we can't.
+ // This node was not pruned last iteration, so we simply need to adjust the
+ // bounds.
+ if (node.Stat().FirstBound() != DBL_MAX)
+ node.Stat().FirstBound() += tolerance;
+ if (node.Stat().SecondBound() != DBL_MAX)
+ node.Stat().SecondBound() += tolerance;
+ if (node.Stat().Bound() != DBL_MAX)
+ node.Stat().Bound() += tolerance;
+ }
+}
+
+template<typename MetricType, typename MatType, typename TreeType>
+void DTNNKMeans<MetricType, MatType, TreeType>::PrecalculateCentroids(
+ TreeType& node)
+{
+ if (node.Stat().Pruned() && node.Stat().Owner() < prunedCentroids.n_cols)
+ {
+ prunedCentroids.col(node.Stat().Owner()) += node.Stat().Centroid() *
+ node.NumDescendants();
+ prunedCounts(node.Stat().Owner()) += node.NumDescendants();
+ }
+ else
+ {
+ for (size_t i = 0; i < node.NumChildren(); ++i)
+ PrecalculateCentroids(node.Child(i));
+ }
}
} // namespace kmeans
diff --git a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
index c4492c1..eface18 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
@@ -48,6 +48,9 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
TreeType& queryNode,
TreeType& referenceNode)
{
+ if (queryNode.Stat().Pruned())
+ return DBL_MAX;
+
// Check if the query node is Hamerly pruned, and if not, then don't continue.
return rules.Score(queryNode, referenceNode);
}
diff --git a/src/mlpack/methods/kmeans/dtnn_statistic.hpp b/src/mlpack/methods/kmeans/dtnn_statistic.hpp
index f6b0de6..ba3af0d 100644
--- a/src/mlpack/methods/kmeans/dtnn_statistic.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_statistic.hpp
@@ -17,19 +17,36 @@ class DTNNStatistic : public
{
public:
DTNNStatistic() :
+ neighbor::NeighborSearchStat<neighbor::NearestNeighborSort>(),
pruned(false),
iteration(0),
- neighbor::NeighborSearchStat<neighbor::NearestNeighborSort>()
+ maxClusterDistance(0.0),
+ secondClusterBound(0.0),
+ owner(size_t(-1)),
+ centroid()
{
// Nothing to do.
}
- DTNNStatistic(TreeType& /* node */) :
+ template<typename TreeType>
+ DTNNStatistic(TreeType& node) :
+ neighbor::NeighborSearchStat<neighbor::NearestNeighborSort>(),
pruned(false),
iteration(0),
- neighbor::NeighborSearchStat<neighbor::NearestNeighborSort>()
+ maxClusterDistance(0.0),
+ secondClusterBound(0.0),
+ owner(size_t(-1))
{
- // Nothing to do.
+ // Empirically calculate the centroid.
+ centroid.zeros(node.Dataset().n_rows);
+ for (size_t i = 0; i < node.NumPoints(); ++i)
+ centroid += node.Dataset().col(node.Point(i));
+
+ for (size_t i = 0; i < node.NumChildren(); ++i)
+ centroid += node.Child(i).NumDescendants() *
+ node.Child(i).Stat().Centroid();
+
+ centroid /= node.NumDescendants();
}
bool Pruned() const { return pruned; }
@@ -38,9 +55,37 @@ class DTNNStatistic : public
size_t Iteration() const { return iteration; }
size_t& Iteration() { return iteration; }
+ double MaxClusterDistance() const { return maxClusterDistance; }
+ double& MaxClusterDistance() { return maxClusterDistance; }
+
+ double SecondClusterBound() const { return secondClusterBound; }
+ double& SecondClusterBound() { return secondClusterBound; }
+
+ size_t Owner() const { return owner; }
+ size_t& Owner() { return owner; }
+
+ const arma::vec& Centroid() const { return centroid; }
+ arma::vec& Centroid() { return centroid; }
+
+ std::string ToString() const
+ {
+ std::ostringstream o;
+ o << "DTNNStatistic [" << this << "]:\n";
+ o << " Pruned: " << pruned << ".\n";
+ o << " Iteration: " << iteration << ".\n";
+ o << " MaxClusterDistance: " << maxClusterDistance << ".\n";
+ o << " SecondClusterBound: " << secondClusterBound << ".\n";
+ o << " Owner: " << owner << ".\n";
+ return o.str();
+ }
+
private:
bool pruned;
size_t iteration;
+ double maxClusterDistance;
+ double secondClusterBound;
+ size_t owner;
+ arma::vec centroid;
};
} // namespace kmeans
More information about the mlpack-git mailing list