[mlpack-git] master: Refactor: do tree update before kNN search. This will allow us to make tighter prunes. (d631986)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:03:52 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit d631986bf08dd5fc00f5b6a928e04a432c1e4b0a
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Feb 2 20:57:51 2015 -0500

    Refactor: do tree update before kNN search. This will allow us to make tighter prunes.


>---------------------------------------------------------------

d631986bf08dd5fc00f5b6a928e04a432c1e4b0a
 src/mlpack/methods/kmeans/dtnn_kmeans.hpp      | 13 ++++---
 src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 54 +++++++++++++-------------
 2 files changed, 34 insertions(+), 33 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
index abc8236..b8c2792 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
@@ -84,6 +84,9 @@ class DTNNKMeans
   //! Counts from pruning.  Not normalized.
   arma::Col<size_t> prunedCounts;
 
+  //! Distances that the clusters moved last iteration.
+  arma::vec clusterDistances;
+
   //! Upper bounds on cluster distances for each point.
   arma::vec upperBounds;
   //! Lower bounds on second closest cluster distance for each point.
@@ -93,14 +96,14 @@ class DTNNKMeans
   //! The last cluster each point was assigned to.
   arma::Col<size_t> lastOwners;
 
+  arma::mat distances;
+  arma::Mat<size_t> assignments;
+
+  std::vector<size_t> lastOldFromNewCentroids;
+
   //! Update the bounds in the tree before the next iteration.
   void UpdateTree(TreeType& node,
-                  const double tolerance,
                   const arma::mat& centroids,
-                  const arma::Mat<size_t>& assignments,
-                  const arma::mat& distances,
-                  const arma::mat& clusterDistances,
-                  const std::vector<size_t>& oldFromNewCentroids,
                   const arma::mat& interclusterDistances,
                   const std::vector<size_t>& newFromOldCentroids);
 
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 305fbc1..407b16c 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -52,7 +52,9 @@ DTNNKMeans<MetricType, MatType, TreeType>::DTNNKMeans(const MatType& dataset,
         datasetOrig),
     metric(metric),
     distanceCalculations(0),
-    iteration(0)
+    iteration(0),
+    distances(2, dataset.n_cols),
+    assignments(2, dataset.n_cols)
 {
   prunedPoints.resize(dataset.n_cols, false); // Fill with false.
   upperBounds.set_size(dataset.n_cols);
@@ -90,6 +92,8 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
   {
     prunedCentroids.zeros(centroids.n_rows, centroids.n_cols);
     prunedCounts.zeros(centroids.n_cols);
+    // The last element stores the maximum.
+    clusterDistances.zeros(centroids.n_cols + 1);
   }
 
   newCentroids.zeros(centroids.n_rows, centroids.n_cols);
@@ -117,10 +121,21 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
   nns.Search(1, closestClusters, interclusterDistances);
   distanceCalculations += nns.BaseCases() + nns.Scores();
 
+  if (iteration != 0)
+  {
+    // Do the tree update for the previous iteration.
+    Log::Warn << "Performing tree update.\n";
+
+    // Reset centroids and counts for things we will collect during pruning.
+    prunedCentroids.zeros(centroids.n_rows, centroids.n_cols);
+    prunedCounts.zeros(centroids.n_cols);
+    UpdateTree(*tree, oldCentroids, interclusterDistances, newFromOldCentroids);
+
+    PrecalculateCentroids(*tree);
+  }
+
   // We won't use the AllkNN class here because we have our own set of rules.
   // This is a lot of overhead.  We don't need the distances.
-  arma::mat distances(2, dataset.n_cols);
-  arma::Mat<size_t> assignments(2, dataset.n_cols);
   distances.fill(DBL_MAX);
   assignments.fill(size_t(-1));
   typedef DTNNKMeansRules<MetricType, TreeType> RuleType;
@@ -161,7 +176,6 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
   // Now, calculate how far the clusters moved, after normalizing them.
   double residual = 0.0;
   double maxMovement = 0.0;
-  arma::vec clusterDistances(centroids.n_cols + 1);
   for (size_t c = 0; c < centroids.n_cols; ++c)
   {
     // Get the mapping to the old cluster, if necessary.
@@ -187,14 +201,7 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
   clusterDistances[centroids.n_cols] = maxMovement;
   distanceCalculations += centroids.n_cols;
 
-  // Reset centroids and counts for things we will collect during pruning.
-  prunedCentroids.zeros(centroids.n_rows, centroids.n_cols);
-  prunedCounts.zeros(centroids.n_cols);
-  UpdateTree(*tree, maxMovement, oldCentroids, assignments, distances,
-      clusterDistances, oldFromNewCentroids, interclusterDistances,
-      newFromOldCentroids);
-
-  PrecalculateCentroids(*tree);
+  lastOldFromNewCentroids = oldFromNewCentroids;
 
   delete centroidTree;
 
@@ -206,12 +213,7 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
 template<typename MetricType, typename MatType, typename TreeType>
 void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
     TreeType& node,
-    const double tolerance,
     const arma::mat& centroids,
-    const arma::Mat<size_t>& assignments,
-    const arma::mat& distances,
-    const arma::mat& clusterDistances,
-    const std::vector<size_t>& oldFromNewCentroids,
     const arma::mat& interclusterDistances,
     const std::vector<size_t>& newFromOldCentroids)
 {
@@ -225,8 +227,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
   bool childrenPruned = true;
   for (size_t i = 0; i < node.NumChildren(); ++i)
   {
-    UpdateTree(node.Child(i), tolerance, centroids, assignments, distances,
-        clusterDistances, oldFromNewCentroids, interclusterDistances,
+    UpdateTree(node.Child(i), centroids, interclusterDistances,
         newFromOldCentroids);
     if (!node.Child(i).Stat().Pruned())
       childrenPruned = false; // Not all children are pruned.
@@ -249,7 +250,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
       size_t c;
       if (!prunedPoints[node.Point(i)])
         c = (tree::TreeTraits<TreeType>::RearrangesDataset) ?
-            oldFromNewCentroids[assignments(0, node.Point(i))] :
+            lastOldFromNewCentroids[assignments(0, node.Point(i))] :
             assignments(0, node.Point(i));
       else
         c = lastOwners[node.Point(i)];
@@ -293,10 +294,6 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
         newSecondClusterBound = node.Child(i).Stat().SecondClusterBound();
     }
 
-//    if (node.NumChildren() > 0)
-//      Log::Warn << "Node:\n" << node << "single owner: " << singleOwner <<
-//".l\n" << node.Child(0) << ".r\n" << node.Child(1) << ".\n";
-
     // What do we do with the new cluster bounds?
     if (newMaxClusterDistance > 0.0 && newMaxClusterDistance <
         node.Stat().MaxClusterDistance())
@@ -483,7 +480,8 @@ node.Child(0) << ", r\n" << node.Child(1) << ".\n";
       if (!prunedLastIteration && !prunedPoints[index])
       {
         owner = (tree::TreeTraits<TreeType>::RearrangesDataset) ?
-            oldFromNewCentroids[assignments(0, index)] : assignments(0, index);
+            lastOldFromNewCentroids[assignments(0, index)] :
+            assignments(0, index);
         // Establish bounds, since these points were searched this iteration.
         upperBounds[index] = distances(0, index);
         lowerSecondBounds[index] = distances(1, index);
@@ -564,11 +562,11 @@ prunedPoints[index] << ", lastOwner " << lastOwners[index] << ": invalid "
   }
 
   if (node.Stat().FirstBound() != DBL_MAX)
-    node.Stat().FirstBound() += tolerance;
+    node.Stat().FirstBound() += clusterDistances[centroids.n_cols];
   if (node.Stat().SecondBound() != DBL_MAX)
-    node.Stat().SecondBound() += tolerance;
+    node.Stat().SecondBound() += clusterDistances[centroids.n_cols];
   if (node.Stat().Bound() != DBL_MAX)
-    node.Stat().Bound() += tolerance;
+    node.Stat().Bound() += clusterDistances[centroids.n_cols];
 }
 
 template<typename MetricType, typename MatType, typename TreeType>



More information about the mlpack-git mailing list