[mlpack-git] master: A first attempt at a working Hamerly prune. The bounds tighten too much and don't reset, so there's not much speedup, but it's a start. (c83b94b)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:02:28 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit c83b94bc2243dcb3143cbcb521b4bcbe2aa757db
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Jan 21 16:50:07 2015 -0500

    A first attempt at a working Hamerly prune. The bounds tighten too much and don't reset, so there's not much speedup, but it's a start.


>---------------------------------------------------------------

c83b94bc2243dcb3143cbcb521b4bcbe2aa757db
 src/mlpack/methods/kmeans/dual_tree_kmeans.hpp     |   5 +-
 .../methods/kmeans/dual_tree_kmeans_impl.hpp       | 116 ++++++++++++++++-----
 .../methods/kmeans/dual_tree_kmeans_rules_impl.hpp |  90 ++++++++++------
 3 files changed, 152 insertions(+), 59 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
index 27bcf25..b7e3c61 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
@@ -63,7 +63,10 @@ class DualTreeKMeans
 
   void TreeUpdate(TreeType* node,
                   const size_t clusters,
-                  const arma::vec& clusterDistances);
+                  const arma::vec& clusterDistances,
+                  const arma::Col<size_t>& assignments,
+                  const arma::mat& oldCentroids,
+                  const arma::mat& dataset);
 };
 
 template<typename MetricType, typename MatType>
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
index 068a74f..747e69f 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
@@ -66,6 +66,7 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
   }
 
   // Build a tree on the centroids.
+  arma::mat oldCentroids(centroids);
   std::vector<size_t> oldFromNewCentroids;
   TreeType* centroidTree = BuildTree<TreeType>(
       const_cast<typename TreeType::Mat&>(centroids), oldFromNewCentroids);
@@ -120,10 +121,10 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
       residual += std::pow(dist, 2.0);
     }
   }
-//  Log::Info << clusterDistances.t();
 
   // Update the tree with the centroid movement information.
-  TreeUpdate(tree, centroids.n_cols, clusterDistances);
+  TreeUpdate(tree, centroids.n_cols, clusterDistances, assignments,
+      oldCentroids, dataset);
 
   delete centroidTree;
 
@@ -157,7 +158,10 @@ template<typename MetricType, typename MatType, typename TreeType>
 void DualTreeKMeans<MetricType, MatType, TreeType>::TreeUpdate(
     TreeType* node,
     const size_t clusters,
-    const arma::vec& clusterDistances)
+    const arma::vec& clusterDistances,
+    const arma::Col<size_t>& assignments,
+    const arma::mat& centroids,
+    const arma::mat& dataset)
 {
   // This is basically IterationUpdate(), but pulled out to be separate from the
   // actual dual-tree algorithm.
@@ -165,6 +169,22 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::TreeUpdate(
   if (node->Parent() != NULL && node->Parent()->Stat().Owner() < clusters)
     node->Stat().Owner() = node->Parent()->Stat().Owner();
 
+  const size_t cluster = assignments[node->Descendant(0)];
+  bool allSame = true;
+  for (size_t i = 1; i < node->NumDescendants(); ++i)
+  {
+    if (assignments[node->Descendant(i)] != cluster)
+    {
+      allSame = false;
+      break;
+    }
+  }
+
+  if (allSame)
+    node->Stat().Owner() = cluster;
+
+  node->Stat().HamerlyPruned() = false;
+
   // The easy case: this node had an owner.
   if (node->Stat().Owner() < clusters)
   {
@@ -175,24 +195,62 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::TreeUpdate(
     if (node->Stat().MinQueryNodeDistance() != DBL_MAX)
       node->Stat().MinQueryNodeDistance() += clusterDistances[owner];
 
-/*
-    // During the last iteration, this node was pruned.  In addition, we have
-    // cached a lower bound on the second closest cluster.  So, use the
-    // triangle inequality: if the maximum distance between the point and the
-    // cluster centroid plus the distance that centroid moved is less than the
-    // lower bound minus the maximum moving centroid, then this cluster *must*
-    // still have the same owner.
-    const size_t owner = node->Stat().Owner();
-    const double closestUpperBound = node->Stat().MaxQueryNodeDistance() +
-        clusterDistances[owner];
-    const TreeType* nonOwner = (TreeType*) node->Stat().ClosestNonOwner();
-    const double tightestLowerBound = node->Stat().ClosestNonOwnerDistance() -
-        nonOwner->Stat().MinQueryNodeDistance();
-    if (closestUpperBound <= tightestLowerBound)
+    // Check if we can perform a Hamerly prune: if the node has an owner, and
+    // the second closest cluster could not have moved close enough that any
+    // points could have changed assignment, then this node *must* belong to the
+    // same owner in the next iteration.  Note that MaxQueryNodeDistance() has
+    // already been adjusted for cluster movement.
+
+    if (node->Stat().MaxQueryNodeDistance() < node->Stat().SecondClosestBound()
+        - clusterDistances[clusters])
     {
-      // Then the owner must not have changed.
+      node->Stat().HamerlyPruned() = true;
+      Log::Warn << "Mark r" << node->Begin() << "c" << node->Count() << " as "
+          << "Hamerly pruned.\n";
+
+      // Check the second bound.  (This is time-consuming...)
+      for (size_t j = 0; j < node->NumDescendants(); ++j)
+      {
+        arma::vec distances(centroids.n_cols);
+        double secondClosestDist = DBL_MAX;
+        for (size_t i = 0; i < centroids.n_cols; ++i)
+        {
+          const double distance = MetricType::Evaluate(centroids.col(i),
+              dataset.col(node->Descendant(j)));
+          if (distance < secondClosestDist && i != node->Stat().Owner())
+            secondClosestDist = distance;
+
+          distances(i) = distance;
+        }
+
+        if (secondClosestDist < node->Stat().SecondClosestBound() - 1e-15)
+        {
+          Log::Warn << "Owner " << node->Stat().Owner() << ", mqnd " <<
+node->Stat().MaxQueryNodeDistance() << ", mnqnd " <<
+node->Stat().MinQueryNodeDistance() << ".\n";
+          Log::Warn << distances.t();
+          Log::Fatal << "Second closest bound " <<
+node->Stat().SecondClosestBound() << " is too loose! -- " << secondClosestDist
+              << "! (" << node->Stat().SecondClosestBound() - secondClosestDist
+<< ")\n";
+        }
+//        if (node->Begin() == 37591)
+//          Log::Warn << "r37591c" << node->Count() << ": " << distances.t();
+      }
     }
-*/
+//    else
+//    {
+//      Log::Warn << "Failed Hamerly prune for r" << node->Begin() << "c" <<
+//          node->Count() << "; mqnd " << node->Stat().MaxQueryNodeDistance() <<
+//          ", scb " << node->Stat().SecondClosestBound() << ".\n";
+//    }
+
+//    if (node->Stat().SecondClosestBound() == DBL_MAX)
+//   {
+//      Log::Warn << "r" << node->Begin() << "c" << node->Count() << " never had "
+//          << "the second bound updated.\n";
+//    }
+
   }
   else
   {
@@ -204,6 +262,9 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::TreeUpdate(
       node->Stat().MaxQueryNodeDistance() += clusterDistances[clusters];
     if (node->Stat().MinQueryNodeDistance() != DBL_MAX)
       node->Stat().MinQueryNodeDistance() += clusterDistances[clusters];
+
+    // Since the node didn't have an owner, it can't be Hamerly pruned.
+    node->Stat().HamerlyPruned() = false;
   }
 
   node->Stat().Iteration() = iteration;
@@ -211,11 +272,18 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::TreeUpdate(
   // We have to set the closest query node to NULL because the cluster tree will
   // be rebuilt.
   node->Stat().ClosestQueryNode() = NULL;
-//  node->Stat().MaxQueryNodeDistance() = DBL_MAX;
-//  node->Stat().MinQueryNodeDistance() = DBL_MAX;
-
-  for (size_t i = 0; i < node->NumChildren(); ++i)
-    TreeUpdate(&node->Child(i), clusters, clusterDistances);
+  node->Stat().SecondClosestBound() -= clusterDistances[clusters];
+  if (node->Stat().SecondClosestBound() < 0)
+    node->Stat().SecondClosestBound() = 0;
+
+//  if (node->Begin() == 37591)
+//    Log::Warn << "scb for r37591c" << node->Count() << " updated to " <<
+//node->Stat().SecondClosestBound() << ".\n";
+
+//  if (!node->Stat().HamerlyPruned())
+    for (size_t i = 0; i < node->NumChildren(); ++i)
+      TreeUpdate(&node->Child(i), clusters, clusterDistances, assignments,
+          centroids, dataset);
 }
 
 
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
index 15ea586..6e8cdb2 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
@@ -106,6 +106,7 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
     TreeType& referenceNode)
 {
   // This won't happen with the root since it is explicitly set to 0.
+  const size_t origPruned = referenceNode.Stat().ClustersPruned();
   if (referenceNode.Stat().ClustersPruned() == size_t(-1))
     referenceNode.Stat().ClustersPruned() =
         referenceNode.Parent()->Stat().ClustersPruned();
@@ -123,34 +124,67 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
     referenceNode.Stat().MaxQueryNodeDistance() = std::min(
         referenceNode.Parent()->Stat().MaxQueryNodeDistance(),
         referenceNode.Stat().MaxQueryNodeDistance());
+    referenceNode.Stat().SecondClosestBound() = std::min(
+        referenceNode.Parent()->Stat().SecondClosestBound(),
+        referenceNode.Stat().SecondClosestBound());
   }
 
-  double score = ElkanTypeScore(queryNode, referenceNode);
+  double score = HamerlyTypeScore(referenceNode);
+  if (score == DBL_MAX)
+  {
+    if (origPruned == size_t(-1))
+    {
+      const size_t cluster = referenceNode.Stat().Owner();
+      newCentroids.col(cluster) += referenceNode.Stat().Centroid() *
+          referenceNode.NumDescendants();
+      counts(cluster) += referenceNode.NumDescendants();
+      referenceNode.Stat().ClustersPruned() += queryNode.NumDescendants();
+    }
+    return DBL_MAX; // No other bookkeeping to do.
+  }
 
   if (score != DBL_MAX)
   {
-    // We also have to update things if the closest query node is null.  This
-    // can probably be improved.
-    const double minDistance = referenceNode.MinDistance(&queryNode);
-    ++distanceCalculations;
-    score = PellegMooreScore(queryNode, referenceNode, minDistance);
+    score = ElkanTypeScore(queryNode, referenceNode);
 
-    if (minDistance < referenceNode.Stat().MinQueryNodeDistance())
+    if (score != DBL_MAX)
     {
-      const double maxDistance = referenceNode.MaxDistance(&queryNode);
+      // We also have to update things if the closest query node is null.  This
+      // can probably be improved.
+      const double minDistance = referenceNode.MinDistance(&queryNode);
       ++distanceCalculations;
-      referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
-      referenceNode.Stat().MinQueryNodeDistance() = minDistance;
-      referenceNode.Stat().MaxQueryNodeDistance() = maxDistance;
-    }
-    else if (IsDescendantOf(*((TreeType*)
-        referenceNode.Stat().ClosestQueryNode()), queryNode))
-    {
-      const double maxDistance = referenceNode.MaxDistance(&queryNode);
-      ++distanceCalculations;
-      referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
-      referenceNode.Stat().MinQueryNodeDistance() = minDistance;
-      referenceNode.Stat().MaxQueryNodeDistance() = maxDistance;
+      score = PellegMooreScore(queryNode, referenceNode, minDistance);
+
+      if (minDistance < referenceNode.Stat().MinQueryNodeDistance())
+      {
+        const double maxDistance = referenceNode.MaxDistance(&queryNode);
+        // Only take the previous minimum query node distance in some
+        // circumstances.
+        if (!IsDescendantOf(*((TreeType*)
+            referenceNode.Stat().ClosestQueryNode()), queryNode) &&
+            referenceNode.Stat().MinQueryNodeDistance() != DBL_MAX &&
+            referenceNode.Stat().MinQueryNodeDistance() <
+                referenceNode.Stat().SecondClosestBound())
+          referenceNode.Stat().SecondClosestBound() =
+              referenceNode.Stat().MinQueryNodeDistance();
+        ++distanceCalculations;
+        referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
+        referenceNode.Stat().MinQueryNodeDistance() = minDistance;
+        referenceNode.Stat().MaxQueryNodeDistance() = maxDistance;
+      }
+      else if (IsDescendantOf(*((TreeType*)
+          referenceNode.Stat().ClosestQueryNode()), queryNode))
+      {
+        const double maxDistance = referenceNode.MaxDistance(&queryNode);
+        ++distanceCalculations;
+        referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
+        referenceNode.Stat().MinQueryNodeDistance() = minDistance;
+        referenceNode.Stat().MaxQueryNodeDistance() = maxDistance;
+      }
+      else if (minDistance < referenceNode.Stat().SecondClosestBound())
+      {
+        referenceNode.Stat().SecondClosestBound() = minDistance;
+      }
     }
   }
 
@@ -209,20 +243,8 @@ template<typename MetricType, typename TreeType>
 double DualTreeKMeansRules<MetricType, TreeType>::HamerlyTypeScore(
     TreeType& referenceNode)
 {
-  // Does the reference node have an owner?
-  if (referenceNode.Owner() < centroids.n_cols)
-  {
-    // Has the owner stayed stationary enough and no other centroids moved
-    // enough that this owner _must_ be the continued owner?
-    if (referenceNode.MaxQueryNodeDistance() +
-        clusterDistances[referenceNode.Owner()] <
-        referenceNode.SecondClosestQueryNodeDistance() -
-        clusterDistances[centroids.n_cols])
-    {
-      return DBL_MAX;
-      // Not yet handled: when to add this to the finished counts?
-    }
-  }
+  if (referenceNode.Stat().HamerlyPruned())
+    return DBL_MAX;
 
   return 0.0;
 }



More information about the mlpack-git mailing list