[mlpack-git] master: Coalesce the tree before kNN. Speedup! This is a fairly significant speedup, actually. (e36e25f)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:03:42 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit e36e25f2d9aa1dd4b113a2d036563eeee15069f7
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Feb 4 14:15:57 2015 -0500

    Coalesce the tree before kNN. Speedup! This is a fairly significant speedup, actually.


>---------------------------------------------------------------

e36e25f2d9aa1dd4b113a2d036563eeee15069f7
 .../tree/binary_space_tree/binary_space_tree.hpp   |  3 +
 src/mlpack/core/tree/cover_tree/cover_tree.hpp     |  2 +
 src/mlpack/methods/kmeans/dtnn_kmeans.hpp          |  3 +
 src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp     | 72 ++++++++++++++++++++++
 src/mlpack/methods/kmeans/dtnn_rules_impl.hpp      | 10 +++
 src/mlpack/methods/kmeans/dtnn_statistic.hpp       | 22 ++++++-
 6 files changed, 110 insertions(+), 2 deletions(-)

diff --git a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
index db1ece7..6f98c6c 100644
--- a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
@@ -332,6 +332,9 @@ class BinarySpaceTree
    */
   BinarySpaceTree& Child(const size_t child) const;
 
+  BinarySpaceTree*& ChildPtr(const size_t child) // Modifiable child pointer;
+  { return (child == 0) ? left : right; } // 0 = left, anything else = right.
+
   //! Return the number of points in this node (0 if not a leaf).
   size_t NumPoints() const;
 
diff --git a/src/mlpack/core/tree/cover_tree/cover_tree.hpp b/src/mlpack/core/tree/cover_tree/cover_tree.hpp
index bad4b8b..41db6fb 100644
--- a/src/mlpack/core/tree/cover_tree/cover_tree.hpp
+++ b/src/mlpack/core/tree/cover_tree/cover_tree.hpp
@@ -224,6 +224,8 @@ class CoverTree
   //! Modify a particular child node.
   CoverTree& Child(const size_t index) { return *children[index]; }
 
+  CoverTree*& ChildPtr(const size_t index) { return children[index]; } // Modify a child pointer.
+
   //! Get the number of children.
   size_t NumChildren() const { return children.size(); }
 
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
index 547842b..e6ccd58 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
@@ -108,6 +108,9 @@ class DTNNKMeans
                   const std::vector<size_t>& newFromOldCentroids);
 
   void PrecalculateCentroids(TreeType& node);
+  //! Splice out non-root nodes with one pruned child; child = index in parent.
+  void CoalesceTree(TreeType& node, const size_t child = 0);
+  void DecoalesceTree(TreeType& node); // Undo CoalesceTree().
 };
 
 //! A template typedef for the DTNNKMeans algorithm with the default tree type
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 40dec27..f30f85d 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -118,6 +118,7 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
       newFromOldCentroids[oldFromNewCentroids[i]] = i;
   }
 
+  Timer::Start("knn");
   // Find the nearest neighbors of each of the clusters.
   neighbor::NeighborSearch<neighbor::NearestNeighborSort, MetricType, TreeType>
       nns(centroidTree, centroids);
@@ -125,21 +126,29 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
   arma::Mat<size_t> closestClusters; // We don't actually care about these.
   nns.Search(1, closestClusters, interclusterDistances);
   distanceCalculations += nns.BaseCases() + nns.Scores();
+  Timer::Stop("knn");
 
   if (iteration != 0)
   {
     // Do the tree update for the previous iteration.
 
     // Reset centroids and counts for things we will collect during pruning.
+    Timer::Start("it_update");
     prunedCentroids.zeros(centroids.n_rows, centroids.n_cols);
     prunedCounts.zeros(centroids.n_cols);
     UpdateTree(*tree, oldCentroids, interclusterDistances, newFromOldCentroids);
 
     PrecalculateCentroids(*tree);
+    Timer::Stop("it_update");
   }
 
+  Timer::Start("tree_mod");
+  CoalesceTree(*tree); // Temporarily bypass nodes with one pruned child.
+  Timer::Stop("tree_mod");
+
   // We won't use the AllkNN class here because we have our own set of rules.
   // This is a lot of overhead.  We don't need the distances.
+  Timer::Start("knn");
   typedef DTNNKMeansRules<MetricType, TreeType> RuleType;
   RuleType rules(centroids, dataset, assignments, distances, metric,
       prunedPoints, oldFromNewCentroids, visited);
@@ -148,6 +157,11 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
   typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
 
   traverser.Traverse(*tree, *centroidTree);
+  Timer::Stop("knn");
+
+  Timer::Start("tree_mod");
+  DecoalesceTree(*tree); // Undo CoalesceTree() before the tree is reused.
+  Timer::Stop("tree_mod");
 
   Log::Info << "This iteration: " << rules.BaseCases() << " base cases, " <<
       rules.Scores() << " scores.\n";
@@ -482,6 +496,64 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
 }

template<typename MetricType, typename MatType, typename TreeType>
+void DTNNKMeans<MetricType, MatType, TreeType>::CoalesceTree(
+    TreeType& node,
+    const size_t child /* Index of this node among its parent's children. */)
+{
+  // Splice out every non-root node that has exactly one unpruned child by
+  // linking that child straight to the node's parent.  Score() returns
+  // DBL_MAX for pruned query nodes, so such a node would only add a useless
+  // level of recursion to the traversal.  This assumes a binary tree
+  // (children 0 and 1); DecoalesceTree() restores the original links.
+  if (node.NumChildren() == 0)
+    return; // Leaves have nothing to coalesce.
+
+  const bool leftPruned = node.Child(0).Stat().Pruned();
+  const bool rightPruned = node.Child(1).Stat().Pruned();
+
+  // The root has no parent to link to, so it is never spliced out.
+  if (node.Parent() != NULL && leftPruned != rightPruned)
+  {
+    // Coalesce below the surviving child first: that may replace
+    // node.ChildPtr(survivor) with a deeper node, which is then the node
+    // that gets linked to our parent.
+    const size_t survivor = leftPruned ? 1 : 0;
+    CoalesceTree(node.Child(survivor), survivor);
+
+    node.Child(survivor).Parent() = node.Parent();
+    node.Parent()->ChildPtr(child) = node.ChildPtr(survivor);
+    return;
+  }
+
+  // Otherwise keep this node and coalesce each unpruned subtree.  Pruned
+  // subtrees are left alone, also under the root: coalescing inside a pruned
+  // child could replace it with an unpruned descendant, and the traversal
+  // would then enter a subtree that should have been pruned.  If both
+  // children are pruned, there is nothing to do.
+  if (!leftPruned)
+    CoalesceTree(node.Child(0), 0);
+
+  if (!rightPruned)
+    CoalesceTree(node.Child(1), 1);
+}
+
+template<typename MetricType, typename MatType, typename TreeType>
+void DTNNKMeans<MetricType, MatType, TreeType>::DecoalesceTree(TreeType& node)
+{
+  // Undo CoalesceTree() by restoring the links saved in the node's statistic.
+  node.Parent() = static_cast<TreeType*>(node.Stat().TrueParent());
+  node.ChildPtr(0) = static_cast<TreeType*>(node.Stat().TrueLeft());
+  node.ChildPtr(1) = static_cast<TreeType*>(node.Stat().TrueRight());
+
+  // Walk the whole restored tree; links anywhere below may have been rewired.
+  if (node.NumChildren() > 0)
+  {
+    DecoalesceTree(node.Child(0));
+    DecoalesceTree(node.Child(1));
+  }
+}
+
+template<typename MetricType, typename MatType, typename TreeType>
 void DTNNKMeans<MetricType, MatType, TreeType>::PrecalculateCentroids(
     TreeType& node)
 {
diff --git a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
index 7e4e48b..a526c94 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
@@ -37,8 +37,11 @@ inline force_inline double DTNNKMeansRules<MetricType, TreeType>::BaseCase(
     const size_t referenceIndex)
 {
   // We'll check if the query point has been pruned.  If so, don't continue.
+//  Log::Debug << "Base case " << queryIndex << ", " << referenceIndex <<
+//      ".\n";
   if (prunedPoints[queryIndex])
     return 0.0; // Returning 0 shouldn't be a problem.
+//  Log::Debug << "(not pruned.)\n";
 
   // Any base cases imply that we will get a result.
   visited[queryIndex] = true;
@@ -107,8 +110,11 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
     TreeType& referenceNode)
 {
   // If the query point has already been pruned, then don't recurse further.
+//  Log::Debug << "Score " << queryIndex << ", r" << referenceNode.Point(0) << "c"
+//      << referenceNode.NumDescendants() << ".\n";
   if (prunedPoints[queryIndex])
     return DBL_MAX;
+//  Log::Debug << "(not pruned)\n";
 
   return neighbor::NeighborSearchRules<neighbor::NearestNeighborSort,
       MetricType, TreeType>::Score(queryIndex, referenceNode);
@@ -119,8 +125,12 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
     TreeType& queryNode,
     TreeType& referenceNode)
 {
+//  Log::Debug << "Score q" << queryNode.Point(0) << "c" <<
+//      queryNode.NumDescendants() << ", r" << referenceNode.Point(0) << "c" <<
+//      referenceNode.NumDescendants() << ".\n";
   if (queryNode.Stat().Pruned())
     return DBL_MAX;
+//  Log::Debug << "(not pruned.)\n";
 
   // Check if the query node is Hamerly pruned, and if not, then don't continue.
   return neighbor::NeighborSearchRules<neighbor::NearestNeighborSort,
diff --git a/src/mlpack/methods/kmeans/dtnn_statistic.hpp b/src/mlpack/methods/kmeans/dtnn_statistic.hpp
index 5ae7f30..691a643 100644
--- a/src/mlpack/methods/kmeans/dtnn_statistic.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_statistic.hpp
@@ -23,7 +23,10 @@ class DTNNStatistic : public
       maxClusterDistance(DBL_MAX),
       secondClusterBound(0.0),
       owner(size_t(-1)),
-      centroid()
+      centroid(),
+      trueLeft(NULL),
+      trueRight(NULL),
+      trueParent(NULL)
   {
     // Nothing to do.
   }
@@ -35,7 +38,10 @@ class DTNNStatistic : public
       iteration(0),
       maxClusterDistance(DBL_MAX),
       secondClusterBound(0.0),
-      owner(size_t(-1))
+      owner(size_t(-1)),
+      trueLeft((node.NumChildren() > 0) ? (void*) &node.Child(0) : NULL),
+      trueRight((node.NumChildren() > 1) ? (void*) &node.Child(1) : NULL),
+      trueParent((void*) node.Parent())
   {
     // Empirically calculate the centroid.
     centroid.zeros(node.Dataset().n_rows);
@@ -67,6 +73,15 @@ class DTNNStatistic : public
   const arma::vec& Centroid() const { return centroid; }
   arma::vec& Centroid() { return centroid; }
 
+  //! Links this node had when the statistic was built.  CoalesceTree() may
+  //! rewire the tree; DecoalesceTree() restores these.  NULL where absent.
+  const void* TrueLeft() const { return trueLeft; }
+  void*& TrueLeft() { return trueLeft; }
+  const void* TrueRight() const { return trueRight; }
+  void*& TrueRight() { return trueRight; }
+  const void* TrueParent() const { return trueParent; }
+  void*& TrueParent() { return trueParent; }
+
   std::string ToString() const
   {
     std::ostringstream o;
@@ -86,6 +101,9 @@ class DTNNStatistic : public
   double secondClusterBound;
   size_t owner;
   arma::vec centroid;
+  void* trueLeft;
+  void* trueRight;
+  void* trueParent;
 };
 
 } // namespace kmeans



More information about the mlpack-git mailing list