[mlpack-git] master: An attempt at making things faster. Might work? (8c7c819)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:04:23 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit 8c7c81981ee3e89cfbf29f13a0f8e9ee4ca4cd50
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Mar 2 17:32:42 2015 -0500

    An attempt at making things faster. Might work?


>---------------------------------------------------------------

8c7c81981ee3e89cfbf29f13a0f8e9ee4ca4cd50
 .../breadth_first_dual_tree_traverser_impl.hpp     | 52 ++++++++++++++--------
 src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp     | 16 ++++---
 src/mlpack/methods/kmeans/dtnn_rules_impl.hpp      | 14 ++++--
 3 files changed, 53 insertions(+), 29 deletions(-)

diff --git a/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp
index f48bf95..4513db6 100644
--- a/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp
@@ -160,25 +160,39 @@ BreadthFirstDualTreeTraverser<RuleType>::Traverse(
     }
     else
     {
-      // We have to recurse down both query and reference nodes.  Because the
-      // query descent order does not matter, we will go to the left query child
-      // first.  Before recursing, we have to set the traversal information
-      // correctly.
-      QueueFrameType fll = { queryNode.Left(), referenceNode.Left(),
-          queryDepth + 1, score, rule.TraversalInfo() };
-      leftChildQueue.push(fll);
-
-      QueueFrameType flr = { queryNode.Left(), referenceNode.Right(),
-          queryDepth + 1, score, rule.TraversalInfo() };
-      leftChildQueue.push(flr);
-
-      QueueFrameType frl = { queryNode.Right(), referenceNode.Left(),
-          queryDepth + 1, score, rule.TraversalInfo() };
-      rightChildQueue.push(frl);
-
-      QueueFrameType frr = { queryNode.Right(), referenceNode.Right(),
-          queryDepth + 1, score, rule.TraversalInfo() };
-      rightChildQueue.push(frr);
+      if (score >= 0.0)
+      {
+        // We have to recurse down both query and reference nodes.  Because the
+        // query descent order does not matter, we will go to the left query
+        // child first.  Before recursing, we have to set the traversal
+        // information correctly.
+        QueueFrameType fll = { queryNode.Left(), referenceNode.Left(),
+            queryDepth + 1, score, rule.TraversalInfo() };
+        leftChildQueue.push(fll);
+
+        QueueFrameType flr = { queryNode.Left(), referenceNode.Right(),
+            queryDepth + 1, score, rule.TraversalInfo() };
+        leftChildQueue.push(flr);
+
+        QueueFrameType frl = { queryNode.Right(), referenceNode.Left(),
+            queryDepth + 1, score, rule.TraversalInfo() };
+        rightChildQueue.push(frl);
+
+        QueueFrameType frr = { queryNode.Right(), referenceNode.Right(),
+            queryDepth + 1, score, rule.TraversalInfo() };
+        rightChildQueue.push(frr);
+      }
+      else
+      {
+        // Only recurse down the references.
+        QueueFrameType fl = { &queryNode, referenceNode.Left(), queryDepth,
+            score, rule.TraversalInfo() };
+        referenceQueue.push(fl);
+
+        QueueFrameType fr = { &queryNode, referenceNode.Right(), queryDepth,
+            score, ti };
+        referenceQueue.push(fr);
+      }
     }
   }
 
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index f5eaf12..c788c34 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -132,7 +132,7 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
   }
 
   // We won't use the AllkNN class here because we have our own set of rules.
-  //lastIterationCentroids = oldCentroids;
+  lastIterationCentroids = oldCentroids;
   typedef DTNNKMeansRules<MetricType, TreeType> RuleType;
   RuleType rules(centroids, dataset, assignments, upperBounds, lowerBounds,
       metric, prunedPoints, oldFromNewCentroids, visited);
@@ -212,9 +212,9 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
     node.Stat().Owner() = node.Parent()->Stat().Owner();
   }
 
-
   // Exhaustive lower bound check. Sigh.
-/*  if (!prunedLastIteration)
+/*
+  if (!prunedLastIteration)
   {
     for (size_t i = 0; i < node.NumDescendants(); ++i)
     {
@@ -238,11 +238,14 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
 
       if (closest - 1e-10 > node.Stat().UpperBound())
       {
+        Log::Warn << node;
         Log::Warn << distances.t();
       Log::Fatal << "Point " << node.Descendant(i) << " in " << node.Point(0) <<
 "c" << node.NumDescendants() << " invalidates upper bound " <<
 node.Stat().UpperBound() << " with closest cluster distance " << closest <<
-".\n";
+", lb " << lowerBounds[node.Descendant(i)] << " n " << node.Stat().LowerBound()
+<< " pp " << prunedPoints[node.Descendant(i)] << " visited " <<
+visited[node.Descendant(i)] << ".\n";
       }
 
     if (node.NumChildren() == 0)
@@ -250,6 +253,7 @@ node.Stat().UpperBound() << " with closest cluster distance " << closest <<
       if (secondClosest + 1e-10 < std::min(lowerBounds[node.Descendant(i)],
   node.Stat().LowerBound()))
       {
+      Log::Warn << node;
       Log::Warn << distances.t();
       Log::Fatal << "Point " << node.Descendant(i) << " in " << node.Point(0) <<
 "c" << node.NumDescendants() << " invalidates lower bound " <<
@@ -261,8 +265,8 @@ visited[node.Descendant(i)] << ".\n";
       }
     }
   }
-  }*/
-
+  }
+*/
 
   if ((node.Stat().Pruned() == centroids.n_cols) &&
       (node.Stat().Owner() < centroids.n_cols))
diff --git a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
index d7890aa..4c0cfc2 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
@@ -124,7 +124,7 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
   const double queryDescDist = queryNode.FurthestDescendantDistance();
   const double refParentDist = referenceNode.ParentDistance();
   const double refDescDist = referenceNode.FurthestDescendantDistance();
-  const double lastScore = traversalInfo.LastScore();
+  const double lastScore = std::abs(traversalInfo.LastScore());
   double adjustedScore;
   double score = 0.0;
 
@@ -266,6 +266,10 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
             referenceNode.Descendant(0);
       }
     }
+
+    if (distances.Hi() > queryNode.Stat().UpperBound() &&
+        referenceNode.NumDescendants() > 1 && score != DBL_MAX)
+      score = -score; // Invert score for smarter strategy.
   }
   
   // Is everything pruned?
@@ -302,14 +306,16 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Rescore(
   if (oldScore == DBL_MAX)
     return DBL_MAX; // It's already pruned.
 
+  const double realScore = std::abs(oldScore);
+
   // oldScore contains the minimum distance between queryNode and referenceNode.
   // In the time since Score() has been called, the upper bound *may* have
   // tightened.  If it has tightened enough, we may prune this node now.
-  if (oldScore > queryNode.Stat().UpperBound())
+  if (realScore > queryNode.Stat().UpperBound())
   {
     // We may still be able to improve the lower bound on pruned nodes.
-    if (oldScore < queryNode.Stat().LowerBound())
-      queryNode.Stat().LowerBound() = oldScore;
+    if (realScore < queryNode.Stat().LowerBound())
+      queryNode.Stat().LowerBound() = realScore;
 
     // This assumes that reference clusters don't appear elsewhere in the tree.
     queryNode.Stat().Pruned() += referenceNode.NumDescendants();



More information about the mlpack-git mailing list