[mlpack-git] master: Correct Pelleg-Moore prunes that finish a node. There were cases where a Pelleg-Moore prune would happen before committing the point. This is actually getting pretty fast in terms of base cases, so I am happy with that (for once). (6300189)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:02:32 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit 6300189ab44849d38795732ac29d8df20d379075
Author: Ryan Curtin <ryan at ratml.org>
Date:   Thu Jan 29 16:59:56 2015 -0500

    Correct Pelleg-Moore prunes that finish a node. There were cases where a Pelleg-Moore prune would happen before committing the point. This is actually getting pretty fast in terms of base cases, so I am happy with that (for once).


>---------------------------------------------------------------

6300189ab44849d38795732ac29d8df20d379075
 src/mlpack/methods/kmeans/dual_tree_kmeans.hpp     |  1 +
 .../methods/kmeans/dual_tree_kmeans_impl.hpp       | 42 ++++++++++++++--------
 .../methods/kmeans/dual_tree_kmeans_rules.hpp      |  3 +-
 .../methods/kmeans/dual_tree_kmeans_rules_impl.hpp | 36 ++++++++++++++++---
 4 files changed, 63 insertions(+), 19 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
index 6b2c81f..68714cd 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
@@ -50,6 +50,7 @@ class DualTreeKMeans
   arma::vec clusterDistances;
   arma::Col<size_t> assignments;
   arma::vec distances;
+  arma::Col<size_t> visited;
   arma::Col<size_t> distanceIteration;
 
   //! The current iteration.
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
index dcf0aa5..27d2fd0 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
@@ -30,6 +30,7 @@ DualTreeKMeans<MetricType, MatType, TreeType>::DualTreeKMeans(
   distances.set_size(dataset.n_cols);
   distances.fill(DBL_MAX);
   assignments.zeros(dataset.n_cols);
+  visited.zeros(dataset.n_cols);
   distanceIteration.zeros(dataset.n_cols);
 
   Timer::Start("tree_building");
@@ -89,9 +90,10 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
 
   // Now run the dual-tree algorithm.
   typedef DualTreeKMeansRules<MetricType, TreeType> RulesType;
+  visited.zeros(dataset.n_cols);
   RulesType rules(dataset, centroids, newCentroids, counts, oldFromNewCentroids,
-      iteration, clusterDistances, distances, assignments, distanceIteration,
-      interclusterDistances, metric);
+      iteration, clusterDistances, distances, assignments, visited,
+      distanceIteration, interclusterDistances, metric);
 
   // Use the dual-tree traverser.
 //typename TreeType::template DualTreeTraverser<RulesType> traverser(rules);
@@ -99,9 +101,7 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
       traverser(rules);
 
   tree->Stat().ClustersPruned() = 0; // The constructor sets this to -1.
-  Log::Info << "Traversal begins.\n";
   traverser.Traverse(*centroidTree, *tree);
-  Log::Info << "Traversal done.\n";
 
   distanceCalculations += rules.DistanceCalculations();
 
@@ -131,12 +131,10 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
   size_t hamerlyPruned = 0;
   size_t hamerlyPrunedNodes = 0;
   size_t totalNodes = 0;
-  Log::Info << "Update tree.\n";
   UpdateOwner(tree, centroids.n_cols, assignments);
   TreeUpdate(tree, centroids.n_cols, clusterDistances, assignments,
       oldCentroids, dataset, oldFromNewCentroids, hamerlyPruned,
       hamerlyPrunedNodes, totalNodes);
-  Log::Info << "Update tree done.\n";
 
   delete centroidTree;
 
@@ -241,19 +239,30 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::TreeUpdate(
   node->Stat().HamerlyPruned() = false;
   ++totalNodes;
 
+/*
+  for (size_t i = 0; i < node->NumPoints(); ++i)
+  {
+    if (!prunedLastIteration &&
+        distanceIteration[node->Point(i)] < iteration)
+      Log::Warn << "Point " << node->Point(i) << " was never visited!"
+<< " (" << distanceIteration[node->Point(i)] << ", " << prunedLastIteration
+<< ")\n";
+    if (!prunedLastIteration &&
+        node->Stat().ClustersPruned() + visited[node->Point(i)] < clusters)
+      Log::Fatal << "Point " << node->Point(i) << " was only visited " <<
+node->Stat().ClustersPruned() << " + " << visited[node->Point(i)] << 
+" times!\n";
+  }
+*/
+
   // The easy case: this node had an owner.
   if (node->Stat().Owner() < clusters)
   {
 /*
+    if (prunedLastIteration && node->Stat().MaxQueryNodeDistance() == DBL_MAX)
+      Log::Fatal << "r" << node->Begin() << "c" << node->Count() << " was "
+          << "Hamerly pruned but was not visited!\n";
     // Verify correctness...
-    for (size_t i = 0; i < node->NumPoints(); ++i)
-    {
-      if (!prunedLastIteration &&
-          distanceIteration[node->Descendant(i)] < iteration)
-        Log::Fatal << "Point " << node->Descendant(i) << " was never visited!"
-<< " (" << distanceIteration[node->Descendant(i)] << ", " << prunedLastIteration
-<< ")\n";
-    }
     for (size_t i = 0; i < node->NumDescendants(); ++i)
     {
       size_t closest = clusters;
@@ -319,6 +328,11 @@ closest << "!  It's part of node r" << node->Begin() << "c" << node->Count() <<
     {
       if (node->Parent() != NULL && node->Parent()->Stat().HamerlyPruned())
       {
+//        Log::Warn << "Extra prune via parent: r" << node->Begin() << "c" <<
+//node->Count() << ".\n";
+        if (node->Stat().Owner() != node->Parent()->Stat().Owner())
+          Log::Fatal << "Holy shit!\n";
+
         node->Stat().HamerlyPruned() = true;
         node->Stat().MinQueryNodeDistance() = DBL_MAX;
       }
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp
index e47651e..8cf4ef5 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp
@@ -23,6 +23,7 @@ class DualTreeKMeansRules
                       const arma::vec& clusterDistances,
                       arma::vec& distances,
                       arma::Col<size_t>& assignments,
+                      arma::Col<size_t>& visited,
                       arma::Col<size_t>& distanceIteration,
                       const arma::mat& interclusterDistances,
                       MetricType& metric);
@@ -59,7 +60,7 @@ class DualTreeKMeansRules
   const arma::vec& clusterDistances;
   arma::vec& distances;
   arma::Col<size_t>& assignments;
-  arma::Col<size_t> visited;
+  arma::Col<size_t>& visited;
   arma::Col<size_t>& distanceIteration;
   const arma::mat& interclusterDistances;
   MetricType& metric;
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
index 3064d88..7857d88 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
@@ -24,6 +24,7 @@ DualTreeKMeansRules<MetricType, TreeType>::DualTreeKMeansRules(
     const arma::vec& clusterDistances,
     arma::vec& distances,
     arma::Col<size_t>& assignments,
+    arma::Col<size_t>& visited,
     arma::Col<size_t>& distanceIteration,
     const arma::mat& interclusterDistances,
     MetricType& metric) :
@@ -36,14 +37,12 @@ DualTreeKMeansRules<MetricType, TreeType>::DualTreeKMeansRules(
     clusterDistances(clusterDistances),
     distances(distances),
     assignments(assignments),
+    visited(visited),
     distanceIteration(distanceIteration),
     interclusterDistances(interclusterDistances),
     metric(metric),
     distanceCalculations(0)
-{
-  // Nothing has been visited yet.
-  visited.zeros(dataset.n_cols);
-}
+{ }
 
 template<typename MetricType, typename TreeType>
 inline force_inline double DualTreeKMeansRules<MetricType, TreeType>::BaseCase(
@@ -57,6 +56,10 @@ inline force_inline double DualTreeKMeansRules<MetricType, TreeType>::BaseCase(
 
   // It's possible that the reference node has been pruned before we got to the
   // base case.  In that case, don't do the base case, and just return.
+//  if (referenceIndex == 37447)
+//    Log::Warn << "Visit " << referenceIndex << ", q" << queryIndex << ".  " <<
+//traversalInfo.LastReferenceNode()->Stat().ClustersPruned() +
+//visited[referenceIndex] << ".\n";
   if (traversalInfo.LastReferenceNode()->Stat().ClustersPruned() +
       visited[referenceIndex] == centroids.n_cols)
     return 0.0;
@@ -86,6 +89,7 @@ inline force_inline double DualTreeKMeansRules<MetricType, TreeType>::BaseCase(
     newCentroids.col(assignments[referenceIndex]) +=
         dataset.col(referenceIndex);
     ++counts(assignments[referenceIndex]);
+//    Log::Warn << "Commit base case " << referenceIndex << ".\n";
   }
 
   return distance;
@@ -105,6 +109,12 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
     TreeType& queryNode,
     TreeType& referenceNode)
 {
+//  if (referenceNode.Begin() == 33313 || referenceNode.Begin() == 37121 ||
+//  if (referenceNode.Begin() == 37447)
+//    Log::Warn << "Visit r" << referenceNode.Begin() << "c" <<
+//referenceNode.Count() << ", q" << queryNode.Begin() << "c" << queryNode.Count()
+//<< ":\n" << referenceNode.Stat();
+
   // This won't happen with the root since it is explicitly set to 0.
   if (referenceNode.Stat().ClustersPruned() == size_t(-1))
     referenceNode.Stat().ClustersPruned() =
@@ -150,7 +160,25 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
   }
   else if (distances.Lo() > referenceNode.Stat().SecondMaxQueryNodeDistance())
   {
+//    if (referenceNode.Begin() == 37447)
+//      Log::Warn << "Pelleg-Moore pruned.\n";
     referenceNode.Stat().ClustersPruned() += queryNode.NumDescendants();
+
+    // Is everything pruned?  Then commit the points.
+    if (referenceNode.Stat().ClustersPruned() +
+        visited[referenceNode.Descendant(0)] == centroids.n_cols)
+    {
+//      Log::Warn << "Commit points in r" << referenceNode.Begin() << "c" <<
+//referenceNode.Count() << ".\n";
+      for (size_t i = 0; i < referenceNode.NumDescendants(); ++i)
+      {
+        const size_t index = referenceNode.Descendant(i);
+        const size_t cluster = assignments[index];
+        referenceNode.Stat().Owner() = cluster;
+        newCentroids.col(cluster) += dataset.col(index);
+        ++counts(cluster);
+      }
+    }
     return DBL_MAX;
   }
 



More information about the mlpack-git mailing list