[mlpack-git] master: Basic static pruning. Minor speedup. (853b4bf)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:04:52 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit 853b4bf6e6e44724717a6c83917df4feb6fc53d4
Author: Ryan Curtin <ryan at ratml.org>
Date:   Tue Feb 17 16:38:35 2015 -0500

    Basic static pruning. Minor speedup.


>---------------------------------------------------------------

853b4bf6e6e44724717a6c83917df4feb6fc53d4
 src/mlpack/methods/kmeans/dtnn_kmeans.hpp      |  6 +-
 src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 98 +++++++++++++++++---------
 src/mlpack/methods/kmeans/dtnn_rules_impl.hpp  |  3 +
 src/mlpack/methods/kmeans/dtnn_statistic.hpp   |  9 ++-
 4 files changed, 78 insertions(+), 38 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
index e756669..26c90df 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
@@ -95,13 +95,15 @@ class DTNNKMeans
   //! Update the bounds in the tree before the next iteration.
   void UpdateTree(TreeType& node,
                   arma::vec& clusterDistances,
-                  std::vector<size_t>& oldFromNewCentroids);
+                  std::vector<size_t>& oldFromNewCentroids,
+                  arma::mat& newCentroids);
 
   //! Extract the centroids of the clusters.
   void ExtractCentroids(TreeType& node,
                         arma::mat& newCentroids,
                         arma::Col<size_t>& newCounts,
-                        std::vector<size_t>& oldFromNewCentroids);
+                        std::vector<size_t>& oldFromNewCentroids,
+                        arma::mat& centroids);
 };
 
 //! A template typedef for the DTNNKMeans algorithm with the default tree type
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 08482f8..ddf7c43 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -128,9 +128,8 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
   // Now we need to extract the clusters.
   newCentroids.zeros(centroids.n_rows, centroids.n_cols);
   counts.zeros(centroids.n_cols);
-  ExtractCentroids(*tree, newCentroids, counts, oldFromNewCentroids);
-  Log::Warn << "New counts: " << counts.t();
-  Log::Warn << accu(counts) << ".\n";
+  ExtractCentroids(*tree, newCentroids, counts, oldFromNewCentroids,
+      oldCentroids);
 
   // Now, calculate how far the clusters moved, after normalizing them.
   double residual = 0.0;
@@ -160,7 +159,7 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
   }
   distanceCalculations += centroids.n_cols;
 
-  UpdateTree(*tree, clusterDistances, oldFromNewCentroids);
+  UpdateTree(*tree, clusterDistances, oldFromNewCentroids, newCentroids);
 
   delete centroidTree;
 
@@ -173,45 +172,46 @@ template<typename MetricType, typename MatType, typename TreeType>
 void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
     TreeType& node,
     arma::vec& clusterDistances,
-    std::vector<size_t>& oldFromNewCentroids)
+    std::vector<size_t>& oldFromNewCentroids,
+    arma::mat& newCentroids)
 {
-  // Simply reset the bounds.
-  node.Stat().UpperBound() = DBL_MAX;
-  node.Stat().LowerBound() = DBL_MAX;
+  node.Stat().StaticPruned() = false;
+
   if ((node.Stat().Pruned() == clusterDistances.n_elem - 1) &&
       (node.Stat().Owner() < clusterDistances.n_elem - 1))
   {
-    const size_t owner = node.Stat().Owner();
-
-    node.Stat().LastUpperBound() = node.Stat().UpperBound() +
-        clusterDistances[owner];
-
-    // Update child bounds, at least a little.
-    for (size_t i = 0; i < node.NumChildren(); ++i)
+    // Adjust bounds.
+    node.Stat().UpperBound() += clusterDistances[node.Stat().Owner()];
+    node.Stat().LowerBound() -= clusterDistances[clusterDistances.n_elem - 1];
+    if (node.Stat().UpperBound() < node.Stat().LowerBound())
     {
-      node.Child(i).Stat().UpperBound() = node.Stat().UpperBound();
-      node.Child(i).Stat().LowerBound() = node.Stat().LowerBound();
-      node.Child(i).Stat().Owner() = node.Stat().Owner();
-      node.Child(i).Stat().Pruned() = node.Stat().Pruned();
+      node.Stat().StaticPruned() = true;
+    }
+    else
+    {
+      // Tighten bound.
+      node.Stat().UpperBound() =
+          node.MaxDistance(newCentroids.col(node.Stat().Owner()));
+      ++distanceCalculations;
+      if (node.Stat().UpperBound() < node.Stat().LowerBound())
+      {
+        node.Stat().StaticPruned() = true;
+      }
     }
   }
-  else if ((node.Stat().Pruned() == clusterDistances.n_elem - 1) &&
-           (node.Stat().Owner() >= clusterDistances.n_elem - 1))
-  {
-    Log::Warn << clusterDistances.n_cols - 1 << ".\n";
-    Log::Warn << node;
-    Log::Fatal << "Node is pruned, but has no owner!\n";
-  }
-  else
+
+  if (!node.Stat().StaticPruned())
   {
-    node.Stat().LastUpperBound() = node.Stat().UpperBound() +
-        clusterDistances[clusterDistances.n_elem - 1];
+    node.Stat().UpperBound() = DBL_MAX;
+    node.Stat().LowerBound() = DBL_MAX;
+    node.Stat().Pruned() = 0;
+    node.Stat().Owner() = clusterDistances.n_elem - 1;
+    node.Stat().StaticPruned() = false;
   }
-  node.Stat().Pruned() = size_t(-1);
-  node.Stat().Owner() = size_t(-1);
 
   for (size_t i = 0; i < node.NumChildren(); ++i)
-    UpdateTree(node.Child(i), clusterDistances, oldFromNewCentroids);
+    UpdateTree(node.Child(i), clusterDistances, oldFromNewCentroids,
+        newCentroids);
 }
 
 template<typename MetricType, typename MatType, typename TreeType>
@@ -219,14 +219,42 @@ void DTNNKMeans<MetricType, MatType, TreeType>::ExtractCentroids(
     TreeType& node,
     arma::mat& newCentroids,
     arma::Col<size_t>& newCounts,
-    std::vector<size_t>& oldFromNewCentroids)
+    std::vector<size_t>& oldFromNewCentroids,
+    arma::mat& centroids)
 {
   // Does this node own points?
-  if (node.Stat().Pruned() == newCentroids.n_cols)
+  if ((node.Stat().Pruned() == newCentroids.n_cols) ||
+      (node.Stat().StaticPruned() && node.Stat().Owner() < newCentroids.n_cols))
   {
     const size_t owner = node.Stat().Owner();
     newCentroids.col(owner) += node.Stat().Centroid() * node.NumDescendants();
     newCounts[owner] += node.NumDescendants();
+
+    // Perform the sanity check here.
+/*
+    for (size_t i = 0; i < node.NumDescendants(); ++i)
+    {
+      const size_t index = node.Descendant(i);
+      arma::vec trueDistances(centroids.n_cols);
+      for (size_t j = 0; j < centroids.n_cols; ++j)
+      {
+        const double dist = metric.Evaluate(dataset.col(index),
+                                            centroids.col(j));
+        trueDistances[j] = dist;
+      }
+
+      arma::uword minIndex;
+      const double minDist = trueDistances.min(minIndex);
+      if (size_t(minIndex) != owner)
+      {
+        Log::Warn << trueDistances.t();
+        Log::Fatal << "Point " << index << " of node " << node.Point(0) << "c"
+<< node.NumDescendants() << " has true minimum cluster " << minIndex << " with "
+      << "distance " << minDist << " but node is pruned with upper bound " <<
+node.Stat().UpperBound() << " and owner " << node.Stat().Owner() << ".\n";
+      }
+    }
+*/
   }
   else
   {
@@ -245,7 +273,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::ExtractCentroids(
     // The node is not entirely owned by a cluster.  Recurse.
     for (size_t i = 0; i < node.NumChildren(); ++i)
       ExtractCentroids(node.Child(i), newCentroids, newCounts,
-          oldFromNewCentroids);
+          oldFromNewCentroids, centroids);
   }
 }
 
diff --git a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
index d5268f4..26e549a 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
@@ -87,6 +87,9 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
     TreeType& queryNode,
     TreeType& referenceNode)
 {
+  if (queryNode.Stat().StaticPruned() == true)
+    return DBL_MAX;
+
   // Pruned() for the root node must never be set to size_t(-1).
   if (queryNode.Stat().Pruned() == size_t(-1))
   {
diff --git a/src/mlpack/methods/kmeans/dtnn_statistic.hpp b/src/mlpack/methods/kmeans/dtnn_statistic.hpp
index ee24b23..e549afe 100644
--- a/src/mlpack/methods/kmeans/dtnn_statistic.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_statistic.hpp
@@ -23,6 +23,7 @@ class DTNNStatistic : public
       lastUpperBound(DBL_MAX),
       owner(size_t(-1)),
       pruned(size_t(-1)),
+      staticPruned(false),
       centroid()
   {
     // Nothing to do.
@@ -35,7 +36,8 @@ class DTNNStatistic : public
       lowerBound(DBL_MAX),
       lastUpperBound(DBL_MAX),
       owner(size_t(-1)),
-      pruned(size_t(-1))
+      pruned(size_t(-1)),
+      staticPruned(false)
   {
     // Empirically calculate the centroid.
     centroid.zeros(node.Dataset().n_rows);
@@ -67,6 +69,9 @@ class DTNNStatistic : public
   size_t Pruned() const { return pruned; }
   size_t& Pruned() { return pruned; }
 
+  bool StaticPruned() const { return staticPruned; }
+  bool& StaticPruned() { return staticPruned; }
+
   std::string ToString() const
   {
     std::ostringstream o;
@@ -75,6 +80,7 @@ class DTNNStatistic : public
     o << "  Lower bound: " << lowerBound << ".\n";
     o << "  Last upper bound: " << lastUpperBound << ".\n";
     o << "  Pruned: " << pruned << ".\n";
+    o << "  Static pruned: " << staticPruned << ".\n";
     o << "  Owner: " << owner << ".\n";
     return o.str();
   }
@@ -85,6 +91,7 @@ class DTNNStatistic : public
   double lastUpperBound;
   size_t owner;
   size_t pruned;
+  bool staticPruned;
   arma::vec centroid;
 };
 



More information about the mlpack-git mailing list