[mlpack-git] master: A tighter, but random, prune. Minor speedup. Also add a commented-out check for the lower bounds, and fix a minor bug that screwed them up. (01ec8a5)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:05:06 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit 01ec8a54be0e047cfd45872d391ff59715799051
Author: Ryan Curtin <ryan at ratml.org>
Date:   Tue Feb 24 17:40:27 2015 -0500

    A tighter, but random, prune. Minor speedup. Also add a commented-out check for the lower bounds, and fix a minor bug that screwed them up.


>---------------------------------------------------------------

01ec8a54be0e047cfd45872d391ff59715799051
 src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 62 +++++++++++++++++++++++++-
 src/mlpack/methods/kmeans/dtnn_rules_impl.hpp  | 10 ++++-
 2 files changed, 69 insertions(+), 3 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 8a76d06..ffd818f 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -132,6 +132,7 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
   }
 
   // We won't use the AllkNN class here because we have our own set of rules.
+  //lastIterationCentroids = oldCentroids;
   typedef DTNNKMeansRules<MetricType, TreeType> RuleType;
   RuleType rules(centroids, dataset, assignments, upperBounds, lowerBounds,
       metric, prunedPoints, oldFromNewCentroids, visited);
@@ -205,11 +206,64 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
       node.Parent()->Stat().Pruned() == centroids.n_cols)
   {
     node.Stat().UpperBound() = node.Parent()->Stat().UpperBound();
-    node.Stat().LowerBound() = node.Parent()->Stat().LowerBound();
+    node.Stat().LowerBound() = node.Parent()->Stat().LowerBound() +
+        clusterDistances[centroids.n_cols];
     node.Stat().Pruned() = node.Parent()->Stat().Pruned();
     node.Stat().Owner() = node.Parent()->Stat().Owner();
   }
 
+
+  // Exhaustive lower bound check. Sigh.
+/*  if (!prunedLastIteration)
+  {
+    for (size_t i = 0; i < node.NumDescendants(); ++i)
+    {
+      double closest = DBL_MAX;
+      double secondClosest = DBL_MAX;
+      arma::vec distances(centroids.n_cols);
+      for (size_t j = 0; j < centroids.n_cols; ++j)
+      {
+        const double dist = metric.Evaluate(dataset.col(node.Descendant(i)),
+            lastIterationCentroids.col(j));
+        distances(j) = dist;
+
+        if (dist < closest)
+        {
+          secondClosest = closest;
+          closest = dist;
+        }
+        else if (dist < secondClosest)
+          secondClosest = dist;
+      }
+
+      if (closest - 1e-10 > node.Stat().UpperBound())
+      {
+        Log::Warn << distances.t();
+      Log::Fatal << "Point " << node.Descendant(i) << " in " << node.Point(0) <<
+"c" << node.NumDescendants() << " invalidates upper bound " <<
+node.Stat().UpperBound() << " with closest cluster distance " << closest <<
+".\n";
+      }
+
+    if (node.NumChildren() == 0)
+    {
+      if (secondClosest + 1e-10 < std::min(lowerBounds[node.Descendant(i)],
+  node.Stat().LowerBound()))
+      {
+      Log::Warn << distances.t();
+      Log::Fatal << "Point " << node.Descendant(i) << " in " << node.Point(0) <<
+"c" << node.NumDescendants() << " invalidates lower bound " <<
+std::min(lowerBounds[node.Descendant(i)], node.Stat().LowerBound()) << " (" <<
+lowerBounds[node.Descendant(i)] << ", " << node.Stat().LowerBound() << ") with "
+      << "second closest cluster distance " << secondClosest << ". cd " <<
+closest << "; pruned " << prunedPoints[node.Descendant(i)] << " visited " <<
+visited[node.Descendant(i)] << ".\n";
+      }
+    }
+  }
+  }*/
+
+
   if ((node.Stat().Pruned() == centroids.n_cols) &&
       (node.Stat().Owner() < centroids.n_cols))
   {
@@ -307,6 +361,12 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
       allChildrenPruned = false;
   }
 
+  if (node.Stat().StaticPruned() && !allChildrenPruned)
+  {
+    Log::Warn << node;
+    Log::Fatal << "Node is statically pruned but not all its children are!\n";
+  }
+
   // If all of the children and points are pruned, we may mark this node as
   // pruned.
   if (allChildrenPruned && allPointsPruned && !node.Stat().StaticPruned())
diff --git a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
index c87c1da..6eb714d 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
@@ -217,7 +217,8 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
       if (adjustedScore < queryNode.Stat().LowerBound())
       {
         // If this might affect the lower bound, make it more exact.
-        queryNode.Stat().LowerBound() = queryNode.MinDistance(&referenceNode);
+        queryNode.Stat().LowerBound() = std::min(queryNode.Stat().LowerBound(),
+            queryNode.MinDistance(&referenceNode));
         ++scores;
       }
 
@@ -229,7 +230,12 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
   if (score != DBL_MAX)
   {
     // Get minimum and maximum distances.
-    math::Range distances = queryNode.RangeDistance(&referenceNode);
+//    math::Range distances = queryNode.RangeDistance(&referenceNode);
+    math::Range distances;
+    distances.Lo() = queryNode.MinDistance(&referenceNode);
+    distances.Hi() =
+        queryNode.MaxDistance(centroids.col(referenceNode.Descendant(0)));
+
     score = distances.Lo();
     ++scores;
     if (distances.Lo() > queryNode.Stat().UpperBound())



More information about the mlpack-git mailing list