[mlpack-git] master: Better speedups, provide more output on prunes. (c413ad2)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:02:18 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit c413ad23996be0c7cf504700da17b7cf41b6a3e2
Author: Ryan Curtin <ryan at ratml.org>
Date:   Thu Jan 29 14:29:56 2015 -0500

    Better speedups, provide more output on prunes.


>---------------------------------------------------------------

c413ad23996be0c7cf504700da17b7cf41b6a3e2
 src/mlpack/methods/kmeans/dual_tree_kmeans.hpp     |  4 +-
 .../methods/kmeans/dual_tree_kmeans_impl.hpp       | 69 +++++++++++++++++++---
 .../methods/kmeans/dual_tree_kmeans_rules_impl.hpp | 17 +-----
 .../methods/kmeans/dual_tree_kmeans_statistic.hpp  | 16 ++---
 4 files changed, 74 insertions(+), 32 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
index 1d26273..6b2c81f 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
@@ -72,7 +72,9 @@ class DualTreeKMeans
                   const arma::mat& oldCentroids,
                   const arma::mat& dataset,
                   const std::vector<size_t>& oldFromNew,
-                  size_t& hamerlyPruned);
+                  size_t& hamerlyPruned,
+                  size_t& hamerlyPrunedNodes,
+                  size_t& totalNodes);
 };
 
 template<typename MetricType, typename MatType>
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
index 756cceb..dcf0aa5 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
@@ -99,7 +99,9 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
       traverser(rules);
 
   tree->Stat().ClustersPruned() = 0; // The constructor sets this to -1.
+  Log::Info << "Traversal begins.\n";
   traverser.Traverse(*centroidTree, *tree);
+  Log::Info << "Traversal done.\n";
 
   distanceCalculations += rules.DistanceCalculations();
 
@@ -127,9 +129,14 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
 
   // Update the tree with the centroid movement information.
   size_t hamerlyPruned = 0;
+  size_t hamerlyPrunedNodes = 0;
+  size_t totalNodes = 0;
+  Log::Info << "Update tree.\n";
   UpdateOwner(tree, centroids.n_cols, assignments);
   TreeUpdate(tree, centroids.n_cols, clusterDistances, assignments,
-      oldCentroids, dataset, oldFromNewCentroids, hamerlyPruned);
+      oldCentroids, dataset, oldFromNewCentroids, hamerlyPruned,
+      hamerlyPrunedNodes, totalNodes);
+  Log::Info << "Update tree done.\n";
 
   delete centroidTree;
 
@@ -224,18 +231,29 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::TreeUpdate(
     const arma::mat& centroids,
     const arma::mat& dataset,
     const std::vector<size_t>& oldFromNew,
-    size_t& hamerlyPruned)
+    size_t& hamerlyPruned,
+    size_t& hamerlyPrunedNodes,
+    size_t& totalNodes)
 {
   // This is basically IterationUpdate(), but pulled out to be separate from the
   // actual dual-tree algorithm.
   const bool prunedLastIteration = node->Stat().HamerlyPruned();
   node->Stat().HamerlyPruned() = false;
+  ++totalNodes;
 
   // The easy case: this node had an owner.
   if (node->Stat().Owner() < clusters)
   {
+/*
     // Verify correctness...
-    /*
+    for (size_t i = 0; i < node->NumPoints(); ++i)
+    {
+      if (!prunedLastIteration &&
+          distanceIteration[node->Descendant(i)] < iteration)
+        Log::Fatal << "Point " << node->Descendant(i) << " was never visited!"
+<< " (" << distanceIteration[node->Descendant(i)] << ", " << prunedLastIteration
+<< ")\n";
+    }
     for (size_t i = 0; i < node->NumDescendants(); ++i)
     {
       size_t closest = clusters;
@@ -262,7 +280,7 @@ closest << "!  It's part of node r" << node->Begin() << "c" << node->Count() <<
 ".\n";
       }
     }
-    */
+*/
 
     // During the last iteration, this node was pruned.
     const size_t owner = node->Stat().Owner();
@@ -293,8 +311,27 @@ closest << "!  It's part of node r" << node->Begin() << "c" << node->Count() <<
         node->Stat().HamerlyPruned() = true;
         if (!node->Parent()->Stat().HamerlyPruned())
           hamerlyPruned += node->NumDescendants();
+        ++hamerlyPrunedNodes;
+      }
+    }
+
+    if (!node->Stat().HamerlyPruned())
+    {
+      if (node->Parent() != NULL && node->Parent()->Stat().HamerlyPruned())
+      {
+        node->Stat().HamerlyPruned() = true;
+        node->Stat().MinQueryNodeDistance() = DBL_MAX;
+      }
+      else
+      {
+        if (node->Stat().SecondMaxQueryNodeDistance() != DBL_MAX)
+          node->Stat().SecondMaxQueryNodeDistance() += clusterDistances[clusters];
+        if (node->Stat().SecondMinQueryNodeDistance() != DBL_MAX)
+          node->Stat().SecondMinQueryNodeDistance() += clusterDistances[clusters];
       }
     }
+    else
+      node->Stat().MinQueryNodeDistance() = DBL_MAX;
   }
   else
   {
@@ -306,6 +343,10 @@ closest << "!  It's part of node r" << node->Begin() << "c" << node->Count() <<
       node->Stat().MaxQueryNodeDistance() += clusterDistances[clusters];
     if (node->Stat().MinQueryNodeDistance() != DBL_MAX)
       node->Stat().MinQueryNodeDistance() += clusterDistances[clusters];
+    if (node->Stat().SecondMaxQueryNodeDistance() != DBL_MAX)
+      node->Stat().SecondMaxQueryNodeDistance() += clusterDistances[clusters];
+    if (node->Stat().SecondMinQueryNodeDistance() != DBL_MAX)
+      node->Stat().SecondMinQueryNodeDistance() += clusterDistances[clusters];
 
     // Since the node didn't have an owner, it can't be Hamerly pruned.
     node->Stat().HamerlyPruned() = false;
@@ -317,7 +358,8 @@ closest << "!  It's part of node r" << node->Begin() << "c" << node->Count() <<
   for (size_t i = 0; i < node->NumChildren(); ++i)
   {
     TreeUpdate(&node->Child(i), clusters, clusterDistances, assignments,
-        centroids, dataset, oldFromNew, hamerlyPruned);
+        centroids, dataset, oldFromNew, hamerlyPruned, hamerlyPrunedNodes,
+        totalNodes);
     if (!node->Child(i).Stat().HamerlyPruned())
       allPruned = false;
     else if (owner == clusters)
@@ -330,30 +372,41 @@ closest << "!  It's part of node r" << node->Begin() << "c" << node->Count() <<
     allPruned = false;
 
   if (allPruned && owner < clusters && !node->Stat().HamerlyPruned())
+  {
+    node->Stat().MinQueryNodeDistance() = DBL_MAX;
     node->Stat().HamerlyPruned() = true;
+    hamerlyPrunedNodes++;
+  }
 
   node->Stat().Iteration() = iteration;
   node->Stat().ClustersPruned() = (node->Parent() == NULL) ? 0 : -1;
   // We have to set the closest query node to NULL because the cluster tree will
   // be rebuilt.
-  node->Stat().ClosestQueryNode() = NULL;
+//  node->Stat().ClosestQueryNode() = NULL;
 
   if (prunedLastIteration)
     node->Stat().LastSecondClosestBound() -= clusterDistances[clusters];
   else
     node->Stat().LastSecondClosestBound() =
         node->Stat().SecondMinQueryNodeDistance() - clusterDistances[clusters];
+//  node->Stat().MinQueryNodeDistance() = DBL_MAX;
   node->Stat().MinQueryNodeDistance() = DBL_MAX;
+  node->Stat().SecondMinQueryNodeDistance() = DBL_MAX;
   if (prunedLastIteration && !node->Stat().HamerlyPruned())
+  {
     node->Stat().MaxQueryNodeDistance() = DBL_MAX;
-  node->Stat().SecondMinQueryNodeDistance() = DBL_MAX;
-  node->Stat().SecondMaxQueryNodeDistance() = DBL_MAX;
+    node->Stat().SecondMaxQueryNodeDistance() = DBL_MAX;
+  }
   // This should change later, but I'm not yet sure how to do it.
 //  node->Stat().SecondClosestBound() = DBL_MAX;
 //  node->Stat().SecondClosestQueryNode() = NULL;
 
   if (node->Parent() == NULL)
+  {
     Log::Info << "Total Hamerly pruned points: " << hamerlyPruned << ".\n";
+    Log::Info << "Total pruned Hamerly nodes: " << hamerlyPrunedNodes << ".\n";
+    Log::Info << "Total nodes in tree: " << totalNodes << ".\n";
+  }
 }
 
 } // namespace kmeans
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
index ced3ffa..3064d88 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
@@ -128,46 +128,33 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
   // Calculate distance to node.
   // This costs about the same (in terms of runtime) as a single MinDistance()
   // call, so there only need to add one distance computation.
-  math::Range distances = referenceNode.RangeDistance(&queryNode);
+  const math::Range distances = referenceNode.RangeDistance(&queryNode);
   ++distanceCalculations;
 
   // Is this closer than the current best query node?
   if (distances.Lo() < referenceNode.Stat().MinQueryNodeDistance())
   {
     // This is the new closest node.
-    referenceNode.Stat().SecondClosestQueryNode() =
-        referenceNode.Stat().ClosestQueryNode();
     referenceNode.Stat().SecondMinQueryNodeDistance() =
         referenceNode.Stat().MinQueryNodeDistance();
     referenceNode.Stat().SecondMaxQueryNodeDistance() =
         referenceNode.Stat().MaxQueryNodeDistance();
-    referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
     referenceNode.Stat().MinQueryNodeDistance() = distances.Lo();
     referenceNode.Stat().MaxQueryNodeDistance() = distances.Hi();
   }
   else if (distances.Lo() < referenceNode.Stat().SecondMinQueryNodeDistance())
   {
     // This is the new second closest node.
-    referenceNode.Stat().SecondClosestQueryNode() = (void*) &queryNode;
     referenceNode.Stat().SecondMinQueryNodeDistance() = distances.Lo();
     referenceNode.Stat().SecondMaxQueryNodeDistance() = distances.Hi();
   }
   else if (distances.Lo() > referenceNode.Stat().SecondMaxQueryNodeDistance())
   {
-    // This is a Pelleg-Moore type prune.
-//    Log::Warn << "Pelleg-Moore prune: " << distances.Lo() << "/" <<
-//distances.Hi() << ", r" << referenceNode.Begin() << "c" << referenceNode.Count()
-//<< ", q" << queryNode.Begin() << "c" << queryNode.Count() << "; mQND " <<
-//referenceNode.Stat().MinQueryNodeDistance() << ", MQND " <<
-//referenceNode.Stat().MaxQueryNodeDistance() << ", smQND " <<
-//referenceNode.Stat().SecondMinQueryNodeDistance() << ", sMQND " <<
-//referenceNode.Stat().SecondMaxQueryNodeDistance() << ".\n";
-
     referenceNode.Stat().ClustersPruned() += queryNode.NumDescendants();
     return DBL_MAX;
   }
 
-  return 0.0; // No pruning allowed at this time.
+  return distances.Lo(); // No pruning allowed at this time.
 }
 
 template<typename MetricType, typename TreeType>
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
index c42eabe..8d762fd 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
@@ -17,8 +17,8 @@ class DualTreeKMeansStatistic
 
   template<typename TreeType>
   DualTreeKMeansStatistic(TreeType& node) :
-      closestQueryNode(NULL),
-      secondClosestQueryNode(NULL),
+//      closestQueryNode(NULL),
+//      secondClosestQueryNode(NULL),
       minQueryNodeDistance(DBL_MAX),
       maxQueryNodeDistance(DBL_MAX),
       secondMinQueryNodeDistance(DBL_MAX),
@@ -52,14 +52,14 @@ class DualTreeKMeansStatistic
   arma::vec& Centroid() { return centroid; }
 
   //! Get the current closest query node.
-  void* ClosestQueryNode() const { return closestQueryNode; }
+//  void* ClosestQueryNode() const { return closestQueryNode; }
   //! Modify the current closest query node.
-  void*& ClosestQueryNode() { return closestQueryNode; }
+//  void*& ClosestQueryNode() { return closestQueryNode; }
 
   //! Get the second closest query node.
-  void* SecondClosestQueryNode() const { return secondClosestQueryNode; }
+//  void* SecondClosestQueryNode() const { return secondClosestQueryNode; }
   //! Modify the second closest query node.
-  void*& SecondClosestQueryNode() { return secondClosestQueryNode; }
+//  void*& SecondClosestQueryNode() { return secondClosestQueryNode; }
 
   //! Get the minimum distance to the closest query node.
   double MinQueryNodeDistance() const { return minQueryNodeDistance; }
@@ -136,9 +136,9 @@ class DualTreeKMeansStatistic
   arma::vec centroid;
 
   //! The current closest query node to this reference node.
-  void* closestQueryNode;
+//  void* closestQueryNode;
   //! The second closest query node.
-  void* secondClosestQueryNode;
+//  void* secondClosestQueryNode;
   //! The minimum distance to the closest query node.
   double minQueryNodeDistance;
   //! The maximum distance to the closest query node.



More information about the mlpack-git mailing list