[mlpack-git] master: Refactor UpdateTree() to sometimes Hamerly prune. We aren't properly retaining pruned nodes between iterations yet, but this is definitely a start, and it's basically as fast as any of the algorithms I've attempted so far. (29a7f5f)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:03:58 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit 29a7f5f2dff3a9822e6ee9bdac4f6f60bdbc2772
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sat Jan 31 14:08:26 2015 -0500

    Refactor UpdateTree() to sometimes Hamerly prune. We aren't properly retaining pruned nodes between iterations yet, but this is definitely a start, and it's basically as fast as any of the algorithms I've attempted so far.
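
    For context, the Hamerly-style prune that UpdateTree() now applies reduces
    to one comparison per node. The sketch below is illustrative only: the
    quantities mirror the statistic members added in this commit
    (MaxClusterDistance(), SecondClusterBound(), and the clusterDistances
    vector), but the helper function itself is not part of the patch.

    // Illustrative sketch of the per-node Hamerly prune used in UpdateTree().
    // All quantities refer to a node whose descendants share a single owner.
    //   maxClusterDistance : largest distance from any descendant to the owner
    //   secondClusterBound : smallest distance from any descendant to its
    //                        second-closest centroid
    //   ownerMovement      : how far the owning centroid moved this iteration
    //   maxMovement        : largest movement of any centroid this iteration
    // If the condition holds, no descendant can have changed owners, so the
    // node keeps its assignment and can be skipped by the next traversal.
    inline bool CanHamerlyPrune(const double maxClusterDistance,
                                const double secondClusterBound,
                                const double ownerMovement,
                                const double maxMovement)
    {
      return (maxClusterDistance + ownerMovement) <
          (secondClusterBound - maxMovement);
    }

    In the patch this corresponds to the check of
    node.Stat().MaxClusterDistance() + clusterDistances[node.Stat().Owner()]
    against node.Stat().SecondClusterBound() - clusterDistances[centroids.n_cols]
    in UpdateTree(), where the last element of clusterDistances holds the
    maximum centroid movement.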


>---------------------------------------------------------------

29a7f5f2dff3a9822e6ee9bdac4f6f60bdbc2772
 src/mlpack/methods/kmeans/dtnn_kmeans.hpp      |  23 ++-
 src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 207 ++++++++++++++++++++++---
 src/mlpack/methods/kmeans/dtnn_rules_impl.hpp  |   3 +
 src/mlpack/methods/kmeans/dtnn_statistic.hpp   |  53 ++++++-
 4 files changed, 257 insertions(+), 29 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
index fddca15..e655e32 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
@@ -14,6 +14,8 @@
 #include <mlpack/methods/neighbor_search/neighbor_search.hpp>
 #include <mlpack/core/tree/cover_tree.hpp>
 
+#include "dtnn_statistic.hpp"
+
 namespace mlpack {
 namespace kmeans {
 
@@ -28,7 +30,7 @@ template<
     typename MetricType,
     typename MatType,
     typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>,
-        neighbor::NeighborSearchStat<neighbor::NearestNeighborSort> > >
+        DTNNStatistic> >
 class DTNNKMeans
 {
  public:
@@ -74,9 +76,24 @@ class DTNNKMeans
 
   //! Track distance calculations.
   size_t distanceCalculations;
+  //! Track iteration number.
+  size_t iteration;
+
+  //! Centroids from pruning.  Not normalized.
+  arma::mat prunedCentroids;
+  //! Counts from pruning.  Not normalized.
+  arma::Col<size_t> prunedCounts;
 
   //! Update the bounds in the tree before the next iteration.
-  void UpdateTree(TreeType& node, const double tolerance);
+  void UpdateTree(TreeType& node,
+                  const double tolerance,
+                  const arma::mat& centroids,
+                  const arma::Mat<size_t>& assignments,
+                  const arma::mat& distances,
+                  const arma::mat& clusterDistances,
+                  const std::vector<size_t>& oldFromNewCentroids);
+
+  void PrecalculateCentroids(TreeType& node);
 };
 
 //! A template typedef for the DTNNKMeans algorithm with the default tree type
@@ -88,7 +105,7 @@ using DefaultDTNNKMeans = DTNNKMeans<MetricType, MatType>;
 template<typename MetricType, typename MatType>
 using CoverTreeDTNNKMeans = DTNNKMeans<MetricType, MatType,
     tree::CoverTree<metric::EuclideanDistance, tree::FirstPointIsRoot,
-    neighbor::NeighborSearchStat<neighbor::NearestNeighborSort> > >;
+    DTNNStatistic> >;
 
 } // namespace kmeans
 } // namespace mlpack
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 6112ca7..9aa0be1 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -41,7 +41,7 @@ TreeType* BuildTree(
         tree::TreeTraits<TreeType>::RearrangesDataset == false, TreeType*
     >::type = 0)
 {
-  return new TreeType(dataset);
+  return new TreeType(dataset, 1);
 }
 
 template<typename MetricType, typename MatType, typename TreeType>
@@ -51,7 +51,8 @@ DTNNKMeans<MetricType, MatType, TreeType>::DTNNKMeans(const MatType& dataset,
     dataset(tree::TreeTraits<TreeType>::RearrangesDataset ? datasetCopy :
         datasetOrig),
     metric(metric),
-    distanceCalculations(0)
+    distanceCalculations(0),
+    iteration(0)
 {
   Timer::Start("tree_building");
 
@@ -79,18 +80,25 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
     arma::mat& newCentroids,
     arma::Col<size_t>& counts)
 {
+  if (iteration == 0)
+  {
+    prunedCentroids.zeros(centroids.n_rows, centroids.n_cols);
+    prunedCounts.zeros(centroids.n_cols);
+  }
+
   newCentroids.zeros(centroids.n_rows, centroids.n_cols);
   counts.zeros(centroids.n_cols);
 
   // Build a tree on the centroids.
+  arma::mat oldCentroids(centroids); // Slow. :(
   std::vector<size_t> oldFromNewCentroids;
   TreeType* centroidTree = BuildTree<TreeType>(
       const_cast<typename TreeType::Mat&>(centroids), oldFromNewCentroids);
 
   // We won't use the AllkNN class here because we have our own set of rules.
   // This is a lot of overhead.  We don't need the distances.
-  arma::mat distances(5, dataset.n_cols);
-  arma::Mat<size_t> assignments(5, dataset.n_cols);
+  arma::mat distances(2, dataset.n_cols);
+  arma::Mat<size_t> assignments(2, dataset.n_cols);
   distances.fill(DBL_MAX);
   assignments.fill(size_t(-1));
   typedef DTNNKMeansRules<MetricType, TreeType> RuleType;
@@ -101,27 +109,36 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
 
   traverser.Traverse(*tree, *centroidTree);
 
+  Log::Info << "This iteration: " << rules.BaseCases() << " base cases, " <<
+      rules.Scores() << " scores.\n";
   distanceCalculations += rules.BaseCases() + rules.Scores();
 
   // From the assignments, calculate the new centroids and counts.
   for (size_t i = 0; i < dataset.n_cols; ++i)
   {
-    if (tree::TreeTraits<TreeType>::RearrangesDataset)
-    {
-      newCentroids.col(oldFromNewCentroids[assignments(0, i)]) +=
-          dataset.col(i);
-      ++counts(oldFromNewCentroids[assignments(0, i)]);
-    }
-    else
+    if (assignments(0, i) != size_t(-1))
     {
-      newCentroids.col(assignments(0, i)) += dataset.col(i);
-      ++counts(assignments(0, i));
+      if (tree::TreeTraits<TreeType>::RearrangesDataset)
+      {
+        newCentroids.col(oldFromNewCentroids[assignments(0, i)]) +=
+            dataset.col(i);
+        ++counts(oldFromNewCentroids[assignments(0, i)]);
+      }
+      else
+      {
+        newCentroids.col(assignments(0, i)) += dataset.col(i);
+        ++counts(assignments(0, i));
+      }
     }
   }
 
+  newCentroids += prunedCentroids;
+  counts += prunedCounts;
+
   // Now, calculate how far the clusters moved, after normalizing them.
   double residual = 0.0;
   double maxMovement = 0.0;
+  arma::vec clusterDistances(centroids.n_cols + 1);
   for (size_t c = 0; c < centroids.n_cols; ++c)
   {
     // Get the mapping to the old cluster, if necessary.
@@ -130,41 +147,187 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
     if (counts[old] == 0)
     {
       newCentroids.col(old).fill(DBL_MAX);
+      clusterDistances[old] = 0;
     }
     else
     {
       newCentroids.col(old) /= counts(old);
       const double movement = metric.Evaluate(centroids.col(c),
           newCentroids.col(old));
+      clusterDistances[old] = movement;
       residual += std::pow(movement, 2.0);
 
       if (movement > maxMovement)
         maxMovement = movement;
     }
   }
+  clusterDistances[centroids.n_cols] = maxMovement;
+  Log::Warn << clusterDistances.t();
   distanceCalculations += centroids.n_cols;
 
-  UpdateTree(*tree, maxMovement);
+  UpdateTree(*tree, maxMovement, oldCentroids, assignments, distances,
+      clusterDistances, oldFromNewCentroids);
+
+  // Reset centroids and counts for things we will collect during pruning.
+  prunedCentroids.zeros(centroids.n_rows, centroids.n_cols);
+  prunedCounts.zeros(centroids.n_cols);
+  PrecalculateCentroids(*tree);
 
   delete centroidTree;
 
+  ++iteration;
+
   return std::sqrt(residual);
 }
 
 template<typename MetricType, typename MatType, typename TreeType>
 void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
     TreeType& node,
-    const double tolerance)
+    const double tolerance,
+    const arma::mat& centroids,
+    const arma::Mat<size_t>& assignments,
+    const arma::mat& distances,
+    const arma::mat& clusterDistances,
+    const std::vector<size_t>& oldFromNewCentroids)
 {
-  if (node.Stat().FirstBound() != DBL_MAX)
-    node.Stat().FirstBound() += tolerance;
-  if (node.Stat().SecondBound() != DBL_MAX)
-    node.Stat().SecondBound() += tolerance;
-  if (node.Stat().Bound() != DBL_MAX)
-    node.Stat().Bound() += tolerance;
+  // Update iteration.
+//  node.Stat().Iteration() = iteration;
+
+  if (node.Stat().Owner() == size_t(-1))
+    node.Stat().Owner() = centroids.n_cols;
 
+  // Do the tree update in a depth-first manner: leaves first.
+  bool childrenPruned = true;
   for (size_t i = 0; i < node.NumChildren(); ++i)
-    UpdateTree(node.Child(i), tolerance);
+  {
+    UpdateTree(node.Child(i), tolerance, centroids, assignments, distances,
+        clusterDistances, oldFromNewCentroids);
+    if (!node.Child(i).Stat().Pruned())
+      childrenPruned = false; // Not all children are pruned.
+  }
+
+  // Does the node have a single owner?
+  // It would be nice if we could do this during the traversal.
+  bool singleOwner = true;
+  size_t owner = centroids.n_cols + 1;
+  node.Stat().MaxClusterDistance() = 0.0;
+  node.Stat().SecondClusterBound() = DBL_MAX;
+  if (!node.Stat().Pruned() && childrenPruned)
+  {
+    for (size_t i = 0; i < node.NumPoints(); ++i)
+    {
+      // Don't forget to map back from the new cluster index.
+      if (owner == centroids.n_cols + 1)
+        owner = (tree::TreeTraits<TreeType>::RearrangesDataset) ?
+            oldFromNewCentroids[assignments(0, node.Point(i))] :
+            assignments(0, node.Point(i));
+      else if (owner != ((tree::TreeTraits<TreeType>::RearrangesDataset) ?
+          oldFromNewCentroids[assignments(0, node.Point(i))] :
+          assignments(0, node.Point(i))))
+        singleOwner = false;
+
+      // Update maximum cluster distance and second cluster bound.
+      if (distances(0, node.Point(i)) > node.Stat().MaxClusterDistance())
+        node.Stat().MaxClusterDistance() = distances(0, node.Point(i));
+      if (distances(1, node.Point(i)) < node.Stat().SecondClusterBound())
+        node.Stat().SecondClusterBound() = distances(1, node.Point(i));
+    }
+
+    for (size_t i = 0; i < node.NumChildren(); ++i)
+    {
+      if (owner == centroids.n_cols + 1)
+        owner = node.Child(i).Stat().Owner();
+      else if (node.Child(i).Stat().Owner() == centroids.n_cols)
+        singleOwner = false;
+      else if (owner != node.Child(i).Stat().Owner())
+        singleOwner = false;
+
+      // Update maximum cluster distance and second cluster bound.
+      if (node.Child(i).Stat().MaxClusterDistance() >
+          node.Stat().MaxClusterDistance())
+        node.Stat().MaxClusterDistance() =
+            node.Child(i).Stat().MaxClusterDistance();
+      if (node.Child(i).Stat().SecondClusterBound() <
+          node.Stat().SecondClusterBound())
+        node.Stat().SecondClusterBound() =
+            node.Child(i).Stat().SecondClusterBound();
+    }
+
+    // Okay, now we know if it's owned or not, and by which cluster.
+    if (singleOwner)
+    {
+      node.Stat().Owner() = owner;
+
+      // Sanity check: ensure the owner is right.
+      for (size_t i = 0; i < node.NumPoints(); ++i)
+      {
+        const double ownerDist = metric.Evaluate(dataset.col(node.Point(i)),
+            centroids.col(owner));
+        for (size_t j = 0; j < centroids.n_cols; ++j)
+        {
+          const double dist = metric.Evaluate(dataset.col(node.Point(i)),
+              centroids.col(j));
+          if (dist < ownerDist)
+          {
+            Log::Warn << node << "...\n" << *node.Parent();
+            Log::Fatal << "Point " << node.Point(i) << " was assigned to owner "
+                << owner << " but has true owner " << j << "! [" <<
+                oldFromNewCentroids[assignments(0, node.Point(i))] << " -- " <<
+                metric.Evaluate(dataset.col(node.Point(i)),
+                    centroids.col(oldFromNewCentroids[assignments(0,
+                    node.Point(i))])) << "] " <<
+                distances(0, node.Point(i)) << " " <<
+                oldFromNewCentroids[assignments(0, node.Point(i))] << " " <<
+                oldFromNewCentroids[assignments(0, node.Point(i - 1))] << ".\n";
+          }
+        }
+      }
+
+      // What is the maximum distance to the closest cluster in the node?
+      if (node.Stat().MaxClusterDistance() +
+          clusterDistances[node.Stat().Owner()] <
+          node.Stat().SecondClusterBound() - clusterDistances[centroids.n_cols])
+      {
+        node.Stat().Pruned() = true;
+      }
+    }
+  }
+  else if (node.Stat().Pruned())
+  {
+    // The node was pruned last iteration.  For now we simply un-prune it and
+    // reset its bounds; properly retaining pruned nodes between iterations is
+    // still to be done (see the commit message).
+    singleOwner = false;
+
+    node.Stat().Pruned() = false;
+    node.Stat().FirstBound() = DBL_MAX;
+    node.Stat().SecondBound() = DBL_MAX;
+    node.Stat().Bound() = DBL_MAX;
+  }
+  else
+  {
+    // The children haven't been pruned, so we can't.
+    // This node was not pruned last iteration, so we simply need to adjust the
+    // bounds.
+    if (node.Stat().FirstBound() != DBL_MAX)
+      node.Stat().FirstBound() += tolerance;
+    if (node.Stat().SecondBound() != DBL_MAX)
+      node.Stat().SecondBound() += tolerance;
+    if (node.Stat().Bound() != DBL_MAX)
+      node.Stat().Bound() += tolerance;
+  }
+}
+
+template<typename MetricType, typename MatType, typename TreeType>
+void DTNNKMeans<MetricType, MatType, TreeType>::PrecalculateCentroids(
+    TreeType& node)
+{
+  if (node.Stat().Pruned() && node.Stat().Owner() < prunedCentroids.n_cols)
+  {
+    prunedCentroids.col(node.Stat().Owner()) += node.Stat().Centroid() *
+        node.NumDescendants();
+    prunedCounts(node.Stat().Owner()) += node.NumDescendants();
+  }
+  else
+  {
+    for (size_t i = 0; i < node.NumChildren(); ++i)
+      PrecalculateCentroids(node.Child(i));
+  }
 }
 
 } // namespace kmeans
diff --git a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
index c4492c1..eface18 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
@@ -48,6 +48,9 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
     TreeType& queryNode,
     TreeType& referenceNode)
 {
+  // If the query node was Hamerly pruned in UpdateTree(), nothing in it needs
+  // its nearest centroids recomputed, so prune this node combination outright.
+  if (queryNode.Stat().Pruned())
+    return DBL_MAX;
+
   // Otherwise, defer to the usual nearest-neighbor scoring rules.
   return rules.Score(queryNode, referenceNode);
 }
diff --git a/src/mlpack/methods/kmeans/dtnn_statistic.hpp b/src/mlpack/methods/kmeans/dtnn_statistic.hpp
index f6b0de6..ba3af0d 100644
--- a/src/mlpack/methods/kmeans/dtnn_statistic.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_statistic.hpp
@@ -17,19 +17,36 @@ class DTNNStatistic : public
 {
  public:
   DTNNStatistic() :
+      neighbor::NeighborSearchStat<neighbor::NearestNeighborSort>(),
       pruned(false),
       iteration(0),
-      neighbor::NeighborSearchStat<neighbor::NearestNeighborSort>()
+      maxClusterDistance(0.0),
+      secondClusterBound(0.0),
+      owner(size_t(-1)),
+      centroid()
   {
     // Nothing to do.
   }
 
-  DTNNStatistic(TreeType& /* node */) :
+  template<typename TreeType>
+  DTNNStatistic(TreeType& node) :
+      neighbor::NeighborSearchStat<neighbor::NearestNeighborSort>(),
       pruned(false),
       iteration(0),
-      neighbor::NeighborSearchStat<neighbor::NearestNeighborSort>()
+      maxClusterDistance(0.0),
+      secondClusterBound(0.0),
+      owner(size_t(-1))
   {
-    // Nothing to do.
+    // Empirically calculate the centroid.
+    centroid.zeros(node.Dataset().n_rows);
+    for (size_t i = 0; i < node.NumPoints(); ++i)
+      centroid += node.Dataset().col(node.Point(i));
+
+    for (size_t i = 0; i < node.NumChildren(); ++i)
+      centroid += node.Child(i).NumDescendants() *
+          node.Child(i).Stat().Centroid();
+
+    centroid /= node.NumDescendants();
   }
 
   bool Pruned() const { return pruned; }
@@ -38,9 +55,37 @@ class DTNNStatistic : public
   size_t Iteration() const { return iteration; }
   size_t& Iteration() { return iteration; }
 
+  double MaxClusterDistance() const { return maxClusterDistance; }
+  double& MaxClusterDistance() { return maxClusterDistance; }
+
+  double SecondClusterBound() const { return secondClusterBound; }
+  double& SecondClusterBound() { return secondClusterBound; }
+
+  size_t Owner() const { return owner; }
+  size_t& Owner() { return owner; }
+
+  const arma::vec& Centroid() const { return centroid; }
+  arma::vec& Centroid() { return centroid; }
+
+  std::string ToString() const
+  {
+    std::ostringstream o;
+    o << "DTNNStatistic [" << this << "]:\n";
+    o << "  Pruned: " << pruned << ".\n";
+    o << "  Iteration: " << iteration << ".\n";
+    o << "  MaxClusterDistance: " << maxClusterDistance << ".\n";
+    o << "  SecondClusterBound: " << secondClusterBound << ".\n";
+    o << "  Owner: " << owner << ".\n";
+    return o.str();
+  }
+
  private:
   bool pruned;
   size_t iteration;
+  double maxClusterDistance;
+  double secondClusterBound;
+  size_t owner;
+  arma::vec centroid;
 };
 
 } // namespace kmeans

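One more note, on PrecalculateCentroids(): folding a pruned node into
prunedCentroids is valid because the node's cached centroid, weighted by its
descendant count, contributes exactly the same sum as adding each descendant
point individually. A quick standalone check of that identity (purely
illustrative, not part of the commit; it only assumes Armadillo):

#include <armadillo>

// Check that (node centroid * number of descendants) equals the sum of the
// descendant points; this is why PrecalculateCentroids() can add a pruned
// node's contribution without visiting its points.
int main()
{
  arma::mat points(3, 5, arma::fill::randu);      // five hypothetical points
  arma::vec nodeCentroid = arma::mean(points, 1); // cached node centroid

  arma::vec pointwiseSum = arma::sum(points, 1);
  arma::vec prunedSum = nodeCentroid * points.n_cols;

  // The two sums agree up to floating-point error.
  return arma::approx_equal(pointwiseSum, prunedSum, "absdiff", 1e-10) ? 0 : 1;
}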

