[mlpack-git] master: Some fairly serious refactoring, but it works. (56cbc1c)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:01:39 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit 56cbc1c3bbf7fc08fb7b4b1b8bb607902b57a7a6
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Jan 12 17:07:50 2015 -0500

    Some fairly serious refactoring, but it works.
    
    There are still significant optimizations to be made, but at this point we get
    some speedup over the naive algorithm.


>---------------------------------------------------------------

56cbc1c3bbf7fc08fb7b4b1b8bb607902b57a7a6
 .../methods/kmeans/dual_tree_kmeans_impl.hpp       |   5 +-
 .../methods/kmeans/dual_tree_kmeans_rules.hpp      |   2 -
 .../methods/kmeans/dual_tree_kmeans_rules_impl.hpp | 158 +++++----------------
 .../methods/kmeans/dual_tree_kmeans_statistic.hpp  |   2 +-
 4 files changed, 37 insertions(+), 130 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
index ff6ddf8..c59b496 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
@@ -88,6 +88,7 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
   typename TreeType::template BreadthFirstDualTreeTraverser<RulesType>
       traverser(rules);
 
+  tree->Stat().ClustersPruned() = 0; // The constructor sets this to -1.
   traverser.Traverse(*centroidTree, *tree);
 
   distanceCalculations += rules.DistanceCalculations();
@@ -220,8 +221,8 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::TreeUpdate(
   // We have to set the closest query node to NULL because the cluster tree will
   // be rebuilt.
   node->Stat().ClosestQueryNode() = NULL;
-//  node->Stat().MaxQueryNodeDistance() = DBL_MAX;
-//  node->Stat().MinQueryNodeDistance() = DBL_MAX;
+  node->Stat().MaxQueryNodeDistance() = DBL_MAX;
+  node->Stat().MinQueryNodeDistance() = DBL_MAX;
 
   for (size_t i = 0; i < node->NumChildren(); ++i)
     TreeUpdate(&node->Child(i), clusters, clusterDistances);
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp
index fe88edc..2fedea0 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp
@@ -68,8 +68,6 @@ class DualTreeKMeansRules
 
   TraversalInfoType traversalInfo;
 
-  double IterationUpdate(TreeType& referenceNode);
-
   bool IsDescendantOf(const TreeType& potentialParent, const TreeType&
       potentialChild) const;
 
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
index f0c4c00..b13e544 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
@@ -50,13 +50,9 @@ inline force_inline double DualTreeKMeansRules<MetricType, TreeType>::BaseCase(
     const size_t queryIndex,
     const size_t referenceIndex)
 {
-//  Log::Info << "Base case, query " << queryIndex << " (" << mappings[queryIndex]
-//      << "), reference " << referenceIndex << ".\n";
-
   // Collect the number of clusters that have been pruned during the traversal.
   // The ternary operator may not be necessary.
   const size_t traversalPruned = (traversalInfo.LastReferenceNode() != NULL) ?
-//      traversalInfo.LastReferenceNode()->Stat().Iteration() == iteration) ?
       traversalInfo.LastReferenceNode()->Stat().ClustersPruned() : 0;
 
   // It's possible that the reference node has been pruned before we got to the
@@ -87,8 +83,6 @@ inline force_inline double DualTreeKMeansRules<MetricType, TreeType>::BaseCase(
 
   if (visited[referenceIndex] + traversalPruned == centroids.n_cols)
   {
-//    Log::Warn << "Commit reference index " << referenceIndex << " to cluster "
-//        << assignments[referenceIndex] << ".\n";
     newCentroids.col(assignments[referenceIndex]) +=
         dataset.col(referenceIndex);
     ++counts(assignments[referenceIndex]);
@@ -99,12 +93,9 @@ inline force_inline double DualTreeKMeansRules<MetricType, TreeType>::BaseCase(
 
 template<typename MetricType, typename TreeType>
 double DualTreeKMeansRules<MetricType, TreeType>::Score(
-    const size_t /* queryIndex */,
+    const size_t queryIndex,
     TreeType& referenceNode)
 {
-  // Update from previous iteration, if necessary.
-//  IterationUpdate(referenceNode);
-
   // No pruning here, for now.
   return 0.0;
 }
@@ -114,6 +105,7 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
     TreeType& queryNode,
     TreeType& referenceNode)
 {
+  // This won't happen with the root since it is explicitly set to 0.
   if (referenceNode.Stat().ClustersPruned() == size_t(-1))
     referenceNode.Stat().ClustersPruned() =
         referenceNode.Parent()->Stat().ClustersPruned();
@@ -124,35 +116,34 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
 
   // We also have to update things if the closest query node is null.  This can
   // probably be improved.
-  if (score != DBL_MAX || referenceNode.Stat().ClosestQueryNode() == NULL)
+  const double minDistance = referenceNode.MinDistance(&queryNode);
+  const double maxDistance = referenceNode.MaxDistance(&queryNode);
+  distanceCalculations += 2;
+  score = PellegMooreScore(queryNode, referenceNode, minDistance);
+
+  if (referenceNode.Stat().MaxQueryNodeDistance() == DBL_MAX &&
+      referenceNode.Parent() != NULL &&
+      referenceNode.Parent()->Stat().MaxQueryNodeDistance() != DBL_MAX)
   {
-    // Can we update the minimum query node distance for this reference node?
-    const double minDistance = referenceNode.MinDistance(&queryNode);
-    const double maxDistance = referenceNode.MaxDistance(&queryNode);
-    distanceCalculations += 2;
-    if (maxDistance < referenceNode.Stat().MaxQueryNodeDistance())
-    {
-      referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
-      referenceNode.Stat().MinQueryNodeDistance() = minDistance;
-      referenceNode.Stat().MaxQueryNodeDistance() = maxDistance;
-//          referenceNode.MaxDistance(&queryNode);
-//      ++distanceCalculations;
-      return 0.0; // Pruning is not possible.
-    }
-
-    else if (IsDescendantOf(
-        *((TreeType*) referenceNode.Stat().ClosestQueryNode()), queryNode))
-    {
-      // Just update.
-      referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
-      referenceNode.Stat().MinQueryNodeDistance() = minDistance;
-      referenceNode.Stat().MaxQueryNodeDistance() =
-          referenceNode.MaxDistance(&queryNode);
-      ++distanceCalculations;
-      return 0.0; // Pruning is not possible.
-    }
+    referenceNode.Stat().ClosestQueryNode() =
+        referenceNode.Parent()->Stat().ClosestQueryNode();
+    referenceNode.Stat().MaxQueryNodeDistance() =
+        referenceNode.Parent()->Stat().MaxQueryNodeDistance();
+  }
 
-    score = PellegMooreScore(queryNode, referenceNode, minDistance);
+  if (maxDistance < referenceNode.Stat().MaxQueryNodeDistance() ||
+      referenceNode.Stat().ClosestQueryNode() == NULL)
+  {
+    referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
+    referenceNode.Stat().MinQueryNodeDistance() = minDistance;
+    referenceNode.Stat().MaxQueryNodeDistance() = maxDistance;
+  }
+  else if (IsDescendantOf(*((TreeType*)
+      referenceNode.Stat().ClosestQueryNode()), queryNode))
+  {
+    referenceNode.Stat().ClosestQueryNode() == (void*) &queryNode;
+    referenceNode.Stat().MinQueryNodeDistance() = minDistance;
+    referenceNode.Stat().MaxQueryNodeDistance() = maxDistance;
   }
 
   if (score == DBL_MAX)
@@ -176,17 +167,16 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
     else if (referenceNode.Stat().ClustersPruned() +
         visited[referenceNode.Descendant(0)] == centroids.n_cols)
     {
-      for (size_t i = 0; i < referenceNode.NumPoints(); ++i)
+      for (size_t i = 0; i < referenceNode.NumDescendants(); ++i)
       {
-        const size_t cluster = assignments[referenceNode.Point(i)];
-        newCentroids.col(cluster) += dataset.col(referenceNode.Point(i));
+        const size_t cluster = assignments[referenceNode.Descendant(i)];
+        newCentroids.col(cluster) += dataset.col(referenceNode.Descendant(i));
         counts(cluster)++;
       }
     }
   }
 
   return score;
-//  return 0.0;
 }
 
 template<typename MetricType, typename TreeType>
@@ -205,87 +195,6 @@ double DualTreeKMeansRules<MetricType, TreeType>::Rescore(
     const double oldScore) const
 {
   return oldScore;
-
-//  if (oldScore == DBL_MAX)
-//    return oldScore; // We can't unprune something.  This shouldn't happen.
-
-//  return ElkanTypeScore(queryNode, referenceNode, oldScore);
-}
-
-template<typename MetricType, typename TreeType>
-inline double DualTreeKMeansRules<MetricType, TreeType>::IterationUpdate(
-    TreeType& referenceNode)
-{
-  Log::Fatal << "Update! Why!\n";
-  if (referenceNode.Stat().Iteration() == iteration)
-    return 0;
-
-  const size_t itDiff = iteration - referenceNode.Stat().Iteration();
-  referenceNode.Stat().Iteration() = iteration;
-  referenceNode.Stat().ClustersPruned() = (referenceNode.Parent() == NULL) ?
-      0 : referenceNode.Parent()->Stat().ClustersPruned();
-  referenceNode.Stat().ClosestQueryNode() = (referenceNode.Parent() == NULL) ?
-      NULL : referenceNode.Parent()->Stat().ClosestQueryNode();
-
-  if (referenceNode.Stat().ClosestQueryNode() != NULL)
-  {
-    referenceNode.Stat().MinQueryNodeDistance() =
-        referenceNode.MinDistance((TreeType*)
-        referenceNode.Stat().ClosestQueryNode());
-    referenceNode.Stat().MaxQueryNodeDistance() =
-        referenceNode.MaxDistance((TreeType*)
-        referenceNode.Stat().ClosestQueryNode());
-    distanceCalculations += 2;
-  }
-
-
-  if (itDiff > 1)
-  {
-//    referenceNode.Stat().BestMaxDistance() = DBL_MAX;
-    referenceNode.Stat().MinQueryNodeDistance() = DBL_MAX;
-    referenceNode.Stat().MaxQueryNodeDistance() = DBL_MAX;
-  }
-  else
-  {
-    if (referenceNode.Stat().MinQueryNodeDistance() != DBL_MAX)
-    {
-      // Update the distance to the closest query node.  If this node has an
-      // owner, we know how far to increase the bound.  Otherwise, increase it
-      // by the furthest amount that any centroid moved.
-      if (referenceNode.Stat().Owner() < centroids.n_cols)
-      {
-        referenceNode.Stat().MinQueryNodeDistance() +=
-            clusterDistances(referenceNode.Stat().Owner());
-        referenceNode.Stat().MaxQueryNodeDistance() +=
-            clusterDistances(referenceNode.Stat().Owner());
-      }
-      else
-      {
-        referenceNode.Stat().MinQueryNodeDistance() +=
-            clusterDistances(centroids.n_cols);
-        referenceNode.Stat().MaxQueryNodeDistance() +=
-            clusterDistances(centroids.n_cols);
-      }
-    }
-    else
-    {
-      referenceNode.Stat().MinQueryNodeDistance() = DBL_MAX;
-      referenceNode.Stat().MaxQueryNodeDistance() = DBL_MAX;
-    }
-  }
-
-//    if (referenceNode.Stat().BestMaxDistance() != DBL_MAX)
-//    {
-//      if (referenceNode.Stat().Owner() < centroids.n_cols)
-//        referenceNode.Stat().BestMaxDistance() +=
-//            clusterDistances(referenceNode.Stat().Owner());
-//      else
-//        referenceNode.Stat().BestMaxDistance() +=
-//            clusterDistances(centroids.n_cols);
-//    }
-//  }
-
-  return 1;
 }
 
 template<typename MetricType, typename TreeType>
@@ -308,12 +217,11 @@ double DualTreeKMeansRules<MetricType, TreeType>::ElkanTypeScore(
 {
   // We have to calculate the minimum distance between the query node and the
   // reference node's best query node.  First, try to use the cached distance.
-//  const double minQueryDistance = queryNode.Stat().FirstBound();
+  const double minQueryDistance = queryNode.Stat().FirstBound();
   if (queryNode.NumDescendants() == 1)
   {
     const double score = ElkanTypeScore(queryNode, referenceNode,
         interclusterDistances[queryNode.Descendant(0)]);
-//    Log::Warn << "Elkan scoring: " << score << ".\n";
     return score;
   }
   else
@@ -344,7 +252,7 @@ double DualTreeKMeansRules<MetricType, TreeType>::ElkanTypeScore(
 
 template<typename MetricType, typename TreeType>
 double DualTreeKMeansRules<MetricType, TreeType>::PellegMooreScore(
-    TreeType& /* queryNode */,
+    TreeType& queryNode,
     TreeType& referenceNode,
     const double minDistance) const
 {
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
index 6d394b7..4874d5b 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
@@ -20,7 +20,7 @@ class DualTreeKMeansStatistic
       closestQueryNode(NULL),
       minQueryNodeDistance(DBL_MAX),
       maxQueryNodeDistance(DBL_MAX),
-      clustersPruned(0),
+      clustersPruned(size_t(-1)),
       iteration(size_t() - 1),
       firstBound(DBL_MAX),
       secondBound(DBL_MAX),



More information about the mlpack-git mailing list