[mlpack-git] master: Refactor UpdateTree() to pass parent bounds. This is a little bit tighter, I think, and is necessary for cover trees, where the MaxDistance() to a node is not necessarily bounded above by the MaxDistance() to the node's parent. (64b52e0)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed May 20 23:05:53 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/77d750c8fd46140b1d6060424f68768a21c89377...7e9cd46afb53817ae93ccbd02637d7726137ce4d

>---------------------------------------------------------------

commit 64b52e0c1296e4b4c9edd9ec05d780f4afedbcff
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon May 18 23:40:54 2015 -0400

    Refactor UpdateTree() to pass parent bounds.
    This is a little bit tighter, I think, and is necessary for cover trees, where the MaxDistance() to a node is not necessarily bounded above by the MaxDistance() to the node's parent.


>---------------------------------------------------------------

64b52e0c1296e4b4c9edd9ec05d780f4afedbcff
 src/mlpack/methods/kmeans/dual_tree_kmeans.hpp     |   6 +-
 .../methods/kmeans/dual_tree_kmeans_impl.hpp       | 124 +++++++++++++++++----
 2 files changed, 110 insertions(+), 20 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
index d849379..f4fd3aa 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
@@ -100,7 +100,11 @@ class DualTreeKMeans
   //! centroids is the current (not yet searched) centroids.
   void UpdateTree(TreeType& node,
                   const arma::mat& centroids,
-                  const arma::vec& interclusterDistances);
+                  const arma::vec& interclusterDistances,
+                  const double parentUpperBound = 0.0,
+                  const double adjustedParentUpperBound = DBL_MAX,
+                  const double parentLowerBound = DBL_MAX,
+                  const double adjustedParentLowerBound = 0.0);
 
   //! Extract the centroids of the clusters.
   void ExtractCentroids(TreeType& node,
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
index 8ccdb43..27e7591 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
@@ -136,7 +136,7 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
   }
 
   // We won't use the AllkNN class here because we have our own set of rules.
-  //lastIterationCentroids = oldCentroids;
+  lastIterationCentroids = oldCentroids;
   typedef DualTreeKMeansRules<MetricType, TreeType> RuleType;
   RuleType rules(centroids, dataset, assignments, upperBounds, lowerBounds,
       metric, prunedPoints, oldFromNewCentroids, visited);
@@ -192,6 +192,8 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
   delete centroidTree;
 
   ++iteration;
+  Log::Debug << counts.t();
+  Log::Debug << arma::accu(counts) << "!\n";
 
   return std::sqrt(residual);
 }
@@ -200,25 +202,46 @@ template<typename MetricType, typename MatType, typename TreeType>
 void DualTreeKMeans<MetricType, MatType, TreeType>::UpdateTree(
     TreeType& node,
     const arma::mat& centroids,
-    const arma::vec& interclusterDistances)
+    const arma::vec& interclusterDistances,
+    const double parentUpperBound,
+    const double adjustedParentUpperBound,
+    const double parentLowerBound,
+    const double adjustedParentLowerBound)
 {
   const bool prunedLastIteration = node.Stat().StaticPruned();
   node.Stat().StaticPruned() = false;
 
   // Grab information from the parent, if we can.
   if (node.Parent() != NULL &&
-      node.Parent()->Stat().Pruned() == centroids.n_cols)
+      node.Parent()->Stat().Pruned() == centroids.n_cols &&
+      node.Parent()->Stat().Owner() < centroids.n_cols)
   {
-    node.Stat().UpperBound() = node.Parent()->Stat().UpperBound();
-    node.Stat().LowerBound() = node.Parent()->Stat().LowerBound() +
-        clusterDistances[centroids.n_cols];
+    if (node.Point(0) == 10475 || node.Point(0) == 12756)
+      Log::Debug << "Update upper bound for node " << node.Point(0) << "c" <<
+node.NumDescendants() << " from parent " << node.Parent()->Point(0) << "c" <<
+node.Parent()->NumDescendants() << " from " << node.Stat().UpperBound() << " to " << parentUpperBound << ".\n";
+    // When taking bounds from the parent, note that the parent has already
+    // adjusted the bounds according to the cluster movements, so we need to
+    // de-adjust them since we'll adjust them again.  Maybe there is a smarter
+    // way to do this...
+    if (node.Stat().UpperBound() < parentUpperBound && !prunedLastIteration)
+    {
+      Log::Warn << node;
+      Log::Fatal << "Wat, upper bound not DBL_MAX.\n";
+    }
+    node.Stat().UpperBound() = parentUpperBound;
+    node.Stat().LowerBound() = parentLowerBound;
     node.Stat().Pruned() = node.Parent()->Stat().Pruned();
     node.Stat().Owner() = node.Parent()->Stat().Owner();
   }
-
+  const double unadjustedUpperBound = node.Stat().UpperBound();
+  double adjustedUpperBound = adjustedParentUpperBound;
+  const double unadjustedLowerBound = node.Stat().LowerBound();
+  double adjustedLowerBound = adjustedParentLowerBound;
 
   // Exhaustive lower bound check. Sigh.
-/*  if (!prunedLastIteration)
+
+  if (!prunedLastIteration)
   {
     for (size_t i = 0; i < node.NumDescendants(); ++i)
     {
@@ -239,7 +262,6 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::UpdateTree(
         else if (dist < secondClosest)
           secondClosest = dist;
       }
-
       if (closest - 1e-10 > node.Stat().UpperBound())
       {
         Log::Warn << distances.t();
@@ -265,31 +287,82 @@ visited[node.Descendant(i)] << ".\n";
       }
     }
   }
-  }*/
+  }
+
 
 
   if ((node.Stat().Pruned() == centroids.n_cols) &&
       (node.Stat().Owner() < centroids.n_cols))
   {
     // Adjust bounds.
+    if (node.Point(0) == 10475 || node.Point(0) == 12756)
+      Log::Debug << "Update upper bound for node " << node.Point(0) << "c" <<
+node.NumDescendants() << " from " << node.Stat().UpperBound() << " to " <<
+node.Stat().UpperBound() + clusterDistances[node.Stat().Owner()] << " as cluster"
+ << " adjustment.\n";
     node.Stat().UpperBound() += clusterDistances[node.Stat().Owner()];
     node.Stat().LowerBound() -= clusterDistances[centroids.n_cols];
+    if (node.Point(0) == 10475 || node.Point(0) == 12756)
+      Log::Debug << node;
+
+    if (adjustedParentUpperBound < node.Stat().UpperBound())
+    {
+      if (node.Point(0) == 10475 || node.Point(0) == 12756)
+      {
+        Log::Debug << "Take adjusted parent upper bound of " <<
+adjustedParentUpperBound << " vs. adjusted bound of " <<
+node.Stat().UpperBound() << ".\n";
+      }
+      node.Stat().UpperBound() = adjustedParentUpperBound;
+      if (node.Point(0) == 10475 || node.Point(0) == 12756)
+        Log::Debug << node;
+    }
+
+    if (adjustedParentLowerBound > node.Stat().LowerBound())
+      node.Stat().LowerBound() = adjustedParentLowerBound;
+
     const double interclusterBound = interclusterDistances[node.Stat().Owner()]
         / 2.0;
     if (interclusterBound > node.Stat().LowerBound())
+    {
       node.Stat().LowerBound() = interclusterBound;
+      adjustedLowerBound = node.Stat().LowerBound();
+    }
     if (node.Stat().UpperBound() < node.Stat().LowerBound())
     {
+      if (node.Point(0) == 10475 || node.Point(0) == 12756)
+        Log::Warn << "Mark r" << node.Point(0) << "c" << node.NumDescendants()
+<< " as statically pruned.\n" << node;
       node.Stat().StaticPruned() = true;
     }
     else
     {
       // Tighten bound.
+      if (node.Point(0) == 10475 || node.Point(0) == 12756)
+      {
+        Log::Debug << centroids.col(node.Stat().Owner()).t() << ".\n";
+        Log::Debug << dataset.col(node.Point(0)).t() << ".\n";
+        Log::Debug << "FDD " << node.FurthestDescendantDistance() << ".\n";
+        Log::Debug << metric.Evaluate(centroids.col(node.Stat().Owner()),
+dataset.col(node.Point(0))) << ".\n";
+        Log::Debug << "Tighten upper bound for node " << node.Point(0) << "c" <<
+node.NumDescendants() << " from " << node.Stat().UpperBound() << " to " <<
+node.MaxDistance(centroids.col(node.Stat().Owner())) << " with owner " <<
+node.Stat().Owner() << ".\n";
+      }
       node.Stat().UpperBound() =
-          node.MaxDistance(centroids.col(node.Stat().Owner()));
+          std::min(node.Stat().UpperBound(),
+                   node.MaxDistance(centroids.col(node.Stat().Owner())));
+      adjustedUpperBound = node.Stat().UpperBound();
+      if (node.Point(0) == 10475 || node.Point(0) == 12756)
+        Log::Debug << node;
+
       ++distanceCalculations;
       if (node.Stat().UpperBound() < node.Stat().LowerBound())
       {
+        if (node.Point(0) == 10475 || node.Point(0) == 12756)
+          Log::Warn << "Mark r" << node.Point(0) << "c" << node.NumDescendants()
+<< " as statically pruned.\n" << node;
         node.Stat().StaticPruned() = true;
       }
     }
@@ -362,14 +435,20 @@ visited[node.Descendant(i)] << ".\n";
   bool allChildrenPruned = true;
   for (size_t i = 0; i < node.NumChildren(); ++i)
   {
-    UpdateTree(node.Child(i), centroids, interclusterDistances);
+    UpdateTree(node.Child(i), centroids, interclusterDistances,
+        unadjustedUpperBound, adjustedUpperBound, unadjustedLowerBound,
+        adjustedLowerBound);
     if (!node.Child(i).Stat().StaticPruned())
       allChildrenPruned = false;
+//      node.Child(i).Stat().StaticPruned() = true;
   }
 
   if (node.Stat().StaticPruned() && !allChildrenPruned)
   {
     Log::Warn << node;
+    for (size_t i = 0; i < node.NumChildren(); ++i)
+      Log::Warn << "child " << i << ":\n" << node.Child(i);
+//    Log::Warn << "grandchild: " << node.Child(1).Child(1);
     Log::Fatal << "Node is statically pruned but not all its children are!\n";
   }
 
@@ -408,6 +487,9 @@ visited[node.Descendant(i)] << ".\n";
           clusterDistances[centroids.n_cols];
     }
   }
+
+  if (node.Point(0) == 12756)
+    Log::Debug << "Node at end of iteration:\n" << node;
 }
 
 template<typename MetricType, typename MatType, typename TreeType>
@@ -426,7 +508,7 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::ExtractCentroids(
     newCounts[owner] += node.NumDescendants();
 
     // Perform the sanity check here.
-/*
+
     for (size_t i = 0; i < node.NumDescendants(); ++i)
     {
       const size_t index = node.Descendant(i);
@@ -450,7 +532,7 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::ExtractCentroids(
 node.Stat().UpperBound() << " and owner " << node.Stat().Owner() << ".\n";
       }
     }
-*/
+
   }
   else
   {
@@ -464,7 +546,7 @@ node.Stat().UpperBound() << " and owner " << node.Stat().Owner() << ".\n";
         newCentroids.col(owner) += dataset.col(node.Point(i));
         ++newCounts[owner];
 
-/*
+
         const size_t index = node.Point(i);
         arma::vec trueDistances(centroids.n_cols);
         for (size_t j = 0; j < centroids.n_cols; ++j)
@@ -489,7 +571,7 @@ assignments[node.Point(i)] << " with ub " << upperBounds[node.Point(i)] <<
 (visited[node.Point(i)] ? "true"
 : "false") << ".\n";
         }
-*/
+
       }
     }
 
@@ -501,9 +583,10 @@ assignments[node.Point(i)] << " with ub " << upperBounds[node.Point(i)] <<
 
 template<typename MetricType, typename MatType, typename TreeType>
 void DualTreeKMeans<MetricType, MatType, TreeType>::CoalesceTree(
-    TreeType& node,
-    const size_t child /* Which child are we? */)
+    TreeType& /*node*/,
+    const size_t /*child /* Which child are we? */)
 {
+/*
   // If all children except one are pruned, we can hide this node.
   if (node.NumChildren() == 0)
     return; // We can't do anything.
@@ -544,12 +627,14 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::CoalesceTree(
     for (size_t i = 0; i < node.NumChildren(); ++i)
       CoalesceTree(node.Child(i), i);
   }
+*/
 }
 
 template<typename MetricType, typename MatType, typename TreeType>
 void DualTreeKMeans<MetricType, MatType, TreeType>::DecoalesceTree(
-    TreeType& node)
+    TreeType& /*node*/)
 {
+/*
   node.Parent() = (TreeType*) node.Stat().TrueParent();
   for (size_t i = 0; i < node.NumChildren(); ++i)
     node.ChildPtr(i) = (TreeType*) node.Stat().TrueChild(i);
@@ -559,6 +644,7 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::DecoalesceTree(
     DecoalesceTree(node.Child(0));
     DecoalesceTree(node.Child(1));
   }
+*/
 }
 
 } // namespace kmeans



More information about the mlpack-git mailing list