[mlpack-git] master: Start applying prune on intercluster distances. Not quite done yet. (95cb008)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:05:14 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit 95cb008c6ddde0e2b5a7b3fcc7027a509f766fd8
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Feb 18 18:34:34 2015 -0500

    Start applying prune on intercluster distances. Not quite done yet.


>---------------------------------------------------------------

95cb008c6ddde0e2b5a7b3fcc7027a509f766fd8
 src/mlpack/methods/kmeans/dtnn_kmeans.hpp      |  3 +-
 src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 47 ++++++++++++++------------
 2 files changed, 28 insertions(+), 22 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
index 6c23bd2..ec141db 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
@@ -97,7 +97,8 @@ class DTNNKMeans
   //! Update the bounds in the tree before the next iteration.
   //! centroids is the current (not yet searched) centroids.
   void UpdateTree(TreeType& node,
-                  const arma::mat& centroids);
+                  const arma::mat& centroids,
+                  const arma::mat& interclusterDistances);
 
   //! Extract the centroids of the clusters.
   void ExtractCentroids(TreeType& node,
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 6bc3228..ef6cc2f 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -94,26 +94,12 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
     arma::mat& newCentroids,
     arma::Col<size_t>& counts)
 {
-  // Reset information, if we need to.
-  if (iteration > 0)
-  {
-    UpdateTree(*tree, centroids);
-
-    for (size_t i = 0; i < dataset.n_cols; ++i)
-      visited[i] = false;
-  }
-  else
-  {
-    // Not initialized yet.
-    clusterDistances.set_size(centroids.n_cols + 1);
-  }
-
   // Build a tree on the centroids.
   arma::mat oldCentroids(centroids); // Slow. :(
   std::vector<size_t> oldFromNewCentroids;
   TreeType* centroidTree = BuildTree<TreeType>(
       const_cast<typename TreeType::Mat&>(centroids), oldFromNewCentroids);
-/*
+
   Timer::Start("knn");
   // Find the nearest neighbors of each of the clusters.
   neighbor::NeighborSearch<neighbor::NearestNeighborSort, MetricType, TreeType>
@@ -121,9 +107,23 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
   arma::mat interclusterDistances;
   arma::Mat<size_t> closestClusters; // We don't actually care about these.
   nns.Search(1, closestClusters, interclusterDistances);
-  distanceCalculations += nns.BaseCases() + nns.Scores();
+//  distanceCalculations += nns.BaseCases() + nns.Scores();
   Timer::Stop("knn");
-*/
+
+  // Reset information in the tree, if we need to.
+  if (iteration > 0)
+  {
+    UpdateTree(*tree, oldCentroids, interclusterDistances);
+
+    for (size_t i = 0; i < dataset.n_cols; ++i)
+      visited[i] = false;
+  }
+  else
+  {
+    // Not initialized yet.
+    clusterDistances.set_size(centroids.n_cols + 1);
+  }
+
   // We won't use the AllkNN class here because we have our own set of rules.
   typedef DTNNKMeansRules<MetricType, TreeType> RuleType;
   RuleType rules(centroids, dataset, assignments, upperBounds, lowerBounds,
@@ -187,7 +187,8 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
 template<typename MetricType, typename MatType, typename TreeType>
 void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
     TreeType& node,
-    const arma::mat& centroids)
+    const arma::mat& centroids,
+    const arma::mat& interclusterDistances)
 {
   const bool prunedLastIteration = node.Stat().StaticPruned();
   node.Stat().StaticPruned() = false;
@@ -208,7 +209,9 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
     // Adjust bounds.
     node.Stat().UpperBound() += clusterDistances[node.Stat().Owner()];
     node.Stat().LowerBound() -= clusterDistances[centroids.n_cols];
-    if (node.Stat().UpperBound() < node.Stat().LowerBound())
+    const double lowerBound = std::max(node.Stat().LowerBound(),
+        interclusterDistances[node.Stat().Owner()] / 2.0);
+    if (node.Stat().UpperBound() < lowerBound)
     {
       node.Stat().StaticPruned() = true;
     }
@@ -218,7 +221,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
       node.Stat().UpperBound() =
           node.MaxDistance(centroids.col(node.Stat().Owner()));
       ++distanceCalculations;
-      if (node.Stat().UpperBound() < node.Stat().LowerBound())
+      if (node.Stat().UpperBound() < lowerBound)
       {
         node.Stat().StaticPruned() = true;
       }
@@ -251,6 +254,8 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
       const size_t owner = assignments[node.Point(i)];
       const double lowerBound = std::min(lowerBounds[index] -
           clusterDistances[centroids.n_cols], node.Stat().LowerBound());
+//      const double pruningLowerBound = std::max(lowerBound,
+//          interclusterDistances[owner] / 2.0);
       if (upperBounds[index] + clusterDistances[owner] < lowerBound)
       {
         prunedPoints[index] = true;
@@ -295,7 +300,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
   }
 
   for (size_t i = 0; i < node.NumChildren(); ++i)
-    UpdateTree(node.Child(i), centroids);
+    UpdateTree(node.Child(i), centroids, interclusterDistances);
 
   if (!node.Stat().StaticPruned())
   {



More information about the mlpack-git mailing list