[mlpack-git] master: Perform tree update at start of iteration. Cache some variables inside DTNNKMeans. (f440ca0)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:04:39 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit f440ca00db79bac726e2e3e7f825259cf0a575f9
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Feb 18 18:12:11 2015 -0500

    Perform tree update at start of iteration. Cache some variables inside DTNNKMeans.


>---------------------------------------------------------------

f440ca00db79bac726e2e3e7f825259cf0a575f9
 src/mlpack/methods/kmeans/dtnn_kmeans.hpp      |  8 ++--
 src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 58 +++++++++++++-------------
 2 files changed, 34 insertions(+), 32 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
index bf83533..6c23bd2 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
@@ -92,17 +92,17 @@ class DTNNKMeans
 
   arma::mat lastIterationCentroids; // For sanity checks.
 
+  arma::vec clusterDistances; // The amount the clusters moved last iteration.
+
   //! Update the bounds in the tree before the next iteration.
+  //! centroids is the current (not yet searched) centroids.
   void UpdateTree(TreeType& node,
-                  arma::vec& clusterDistances,
-                  std::vector<size_t>& oldFromNewCentroids,
-                  arma::mat& newCentroids);
+                  const arma::mat& centroids);
 
   //! Extract the centroids of the clusters.
   void ExtractCentroids(TreeType& node,
                         arma::mat& newCentroids,
                         arma::Col<size_t>& newCounts,
-                        std::vector<size_t>& oldFromNewCentroids,
                         arma::mat& centroids);
 
   void CoalesceTree(TreeType& node, const size_t child = 0);
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 0ab1bd4..6bc3228 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -71,7 +71,10 @@ DTNNKMeans<MetricType, MatType, TreeType>::DTNNKMeans(const MatType& dataset,
   Timer::Stop("tree_building");
 
   for (size_t i = 0; i < dataset.n_cols; ++i)
+  {
     prunedPoints[i] = false;
+    visited[i] = false;
+  }
   assignments.fill(size_t(-1));
   upperBounds.fill(DBL_MAX);
   lowerBounds.fill(DBL_MAX);
@@ -91,16 +94,25 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
     arma::mat& newCentroids,
     arma::Col<size_t>& counts)
 {
-  // Reset information.
-  for (size_t i = 0; i < dataset.n_cols; ++i)
-    visited[i] = false;
+  // Reset information, if we need to.
+  if (iteration > 0)
+  {
+    UpdateTree(*tree, centroids);
+
+    for (size_t i = 0; i < dataset.n_cols; ++i)
+      visited[i] = false;
+  }
+  else
+  {
+    // Not initialized yet.
+    clusterDistances.set_size(centroids.n_cols + 1);
+  }
 
   // Build a tree on the centroids.
   arma::mat oldCentroids(centroids); // Slow. :(
   std::vector<size_t> oldFromNewCentroids;
   TreeType* centroidTree = BuildTree<TreeType>(
       const_cast<typename TreeType::Mat&>(centroids), oldFromNewCentroids);
-
 /*
   Timer::Start("knn");
   // Find the nearest neighbors of each of the clusters.
@@ -112,7 +124,6 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
   distanceCalculations += nns.BaseCases() + nns.Scores();
   Timer::Stop("knn");
 */
-
   // We won't use the AllkNN class here because we have our own set of rules.
   typedef DTNNKMeansRules<MetricType, TreeType> RuleType;
   RuleType rules(centroids, dataset, assignments, upperBounds, lowerBounds,
@@ -137,12 +148,10 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
   // Now we need to extract the clusters.
   newCentroids.zeros(centroids.n_rows, centroids.n_cols);
   counts.zeros(centroids.n_cols);
-  ExtractCentroids(*tree, newCentroids, counts, oldFromNewCentroids,
-      oldCentroids);
+  ExtractCentroids(*tree, newCentroids, counts, oldCentroids);
 
   // Now, calculate how far the clusters moved, after normalizing them.
   double residual = 0.0;
-  arma::vec clusterDistances(centroids.n_cols + 1);
   clusterDistances[centroids.n_cols] = 0.0;
   for (size_t c = 0; c < centroids.n_cols; ++c)
   {
@@ -168,8 +177,6 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
   }
   distanceCalculations += centroids.n_cols;
 
-  UpdateTree(*tree, clusterDistances, oldFromNewCentroids, newCentroids);
-
   delete centroidTree;
 
   ++iteration;
@@ -180,16 +187,14 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
 template<typename MetricType, typename MatType, typename TreeType>
 void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
     TreeType& node,
-    arma::vec& clusterDistances,
-    std::vector<size_t>& oldFromNewCentroids,
-    arma::mat& newCentroids)
+    const arma::mat& centroids)
 {
   const bool prunedLastIteration = node.Stat().StaticPruned();
   node.Stat().StaticPruned() = false;
 
   // Grab information from the parent, if we can.
   if (node.Parent() != NULL &&
-      node.Parent()->Stat().Pruned() == clusterDistances.n_elem - 1)
+      node.Parent()->Stat().Pruned() == centroids.n_cols)
   {
     node.Stat().UpperBound() = node.Parent()->Stat().UpperBound();
     node.Stat().LowerBound() = node.Parent()->Stat().LowerBound();
@@ -197,12 +202,12 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
     node.Stat().Owner() = node.Parent()->Stat().Owner();
   }
 
-  if ((node.Stat().Pruned() == clusterDistances.n_elem - 1) &&
-      (node.Stat().Owner() < clusterDistances.n_elem - 1))
+  if ((node.Stat().Pruned() == centroids.n_cols) &&
+      (node.Stat().Owner() < centroids.n_cols))
   {
     // Adjust bounds.
     node.Stat().UpperBound() += clusterDistances[node.Stat().Owner()];
-    node.Stat().LowerBound() -= clusterDistances[clusterDistances.n_elem - 1];
+    node.Stat().LowerBound() -= clusterDistances[centroids.n_cols];
     if (node.Stat().UpperBound() < node.Stat().LowerBound())
     {
       node.Stat().StaticPruned() = true;
@@ -211,7 +216,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
     {
       // Tighten bound.
       node.Stat().UpperBound() =
-          node.MaxDistance(newCentroids.col(node.Stat().Owner()));
+          node.MaxDistance(centroids.col(node.Stat().Owner()));
       ++distanceCalculations;
       if (node.Stat().UpperBound() < node.Stat().LowerBound())
       {
@@ -221,7 +226,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
   }
   else
   {
-    node.Stat().LowerBound() -= clusterDistances[clusterDistances.n_elem - 1];
+    node.Stat().LowerBound() -= clusterDistances[centroids.n_cols];
   }
 
   if (!node.Stat().StaticPruned())
@@ -245,7 +250,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
       prunedPoints[index] = false;
       const size_t owner = assignments[node.Point(i)];
       const double lowerBound = std::min(lowerBounds[index] -
-          clusterDistances[newCentroids.n_cols], node.Stat().LowerBound());
+          clusterDistances[centroids.n_cols], node.Stat().LowerBound());
       if (upperBounds[index] + clusterDistances[owner] < lowerBound)
       {
         prunedPoints[index] = true;
@@ -256,7 +261,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
       {
         // Attempt to tighten the bound.
         upperBounds[index] = metric.Evaluate(dataset.col(index),
-                                             newCentroids.col(owner));
+                                             centroids.col(owner));
         ++distanceCalculations;
         if (upperBounds[index] < lowerBound)
         {
@@ -280,7 +285,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
       node.Stat().StaticUpperBoundMovement() +=
           clusterDistances[node.Stat().Owner()];
       node.Stat().StaticLowerBoundMovement() +=
-          clusterDistances[newCentroids.n_cols];
+          clusterDistances[centroids.n_cols];
     }
     else
     {
@@ -290,15 +295,14 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
   }
 
   for (size_t i = 0; i < node.NumChildren(); ++i)
-    UpdateTree(node.Child(i), clusterDistances, oldFromNewCentroids,
-        newCentroids);
+    UpdateTree(node.Child(i), centroids);
 
   if (!node.Stat().StaticPruned())
   {
     node.Stat().UpperBound() = DBL_MAX;
     node.Stat().LowerBound() = DBL_MAX;
     node.Stat().Pruned() = size_t(-1);
-    node.Stat().Owner() = clusterDistances.n_elem - 1;
+    node.Stat().Owner() = centroids.n_cols;
     node.Stat().StaticPruned() = false;
   }
 }
@@ -308,7 +312,6 @@ void DTNNKMeans<MetricType, MatType, TreeType>::ExtractCentroids(
     TreeType& node,
     arma::mat& newCentroids,
     arma::Col<size_t>& newCounts,
-    std::vector<size_t>& oldFromNewCentroids,
     arma::mat& centroids)
 {
   // Does this node own points?
@@ -386,8 +389,7 @@ assignments[node.Point(0)] << " with ub " << upperBounds[node.Point(0)] <<
 
     // The node is not entirely owned by a cluster.  Recurse.
     for (size_t i = 0; i < node.NumChildren(); ++i)
-      ExtractCentroids(node.Child(i), newCentroids, newCounts,
-          oldFromNewCentroids, centroids);
+      ExtractCentroids(node.Child(i), newCentroids, newCounts, centroids);
   }
 }
 



More information about the mlpack-git mailing list