[mlpack-git] master: Refactor Elkan-type prune into its own method, for simplicity. (02950dc)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 22:02:46 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit 02950dc782feeca9c6e0e47db864b3fbb0b3e39e
Author: Ryan Curtin <ryan at ratml.org>
Date:   Fri Nov 7 20:54:39 2014 +0000

    Refactor Elkan-type prune into its own method, for simplicity.


>---------------------------------------------------------------

02950dc782feeca9c6e0e47db864b3fbb0b3e39e
 .../methods/kmeans/dual_tree_kmeans_rules.hpp      |  33 +++++
 .../methods/kmeans/dual_tree_kmeans_rules_impl.hpp | 143 +++++++++------------
 2 files changed, 94 insertions(+), 82 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp
index e9320d1..4a54192 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp
@@ -70,6 +70,39 @@ class DualTreeKMeansRules
 
   bool IsDescendantOf(const TreeType& potentialParent, const TreeType&
       potentialChild) const;
+
+  /**
+   * See if an Elkan-type prune can be performed.  If so, return DBL_MAX;
+   * otherwise, return a score.  The Elkan-type prune can occur when the minimum
+   * distance between the query node and the current best query node for the
+   * reference node (referenceNode.Stat().ClosestQueryNode()) is greater than
+   * two times the maximum distance between the reference node and the current
+   * best query node (again, referenceNode.Stat().ClosestQueryNode()).
+   *
+   * @param queryNode Query node.
+   * @param referenceNode Reference node.
+   */
+  double ElkanTypeScore(TreeType& queryNode, TreeType& referenceNode) const;
+
+  /**
+   * See if an Elkan-type prune can be performed.  If so, return DBL_MAX;
+   * otherwise, return a score.  The Elkan-type prune can occur when the minimum
+   * distance between the query node and the current best query node for the
+   * reference node (referenceNode.Stat().ClosestQueryNode()) is greater than
+   * two times the maximum distance between the reference node and the current
+   * best query node (again, referenceNode.Stat().ClosestQueryNode()).
+   *
+   * This particular overload is for when the minimum distance between the query
+   * node and the current best query node has already been calculated.
+   *
+   * @param queryNode Query node.
+   * @param referenceNode Reference node.
+   * @param minQueryDistance Minimum distance between query node and current
+   *      best query node for the reference node.
+   */
+  double ElkanTypeScore(TreeType& queryNode,
+                        TreeType& referenceNode,
+                        const double minQueryDistance) const;
 };
 
 } // namespace kmeans
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
index adcedad..1e352de 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
@@ -140,48 +140,9 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
     return 0.0; // Pruning is not possible.
   }
 
-  // See if we can do an Elkan-type prune on between-centroid distances.
-  const double maxDistance = referenceNode.Stat().MaxQueryNodeDistance();
-  const double minQueryDistance = queryNode.MinDistance((TreeType*)
-      referenceNode.Stat().ClosestQueryNode());
   ++distanceCalculations;
 
-  if (minQueryDistance > 2.0 * maxDistance)
-  {
-    // Then we can conclude d_max(best(N_r), N_r) <= d_min(N_q, N_r) which
-    // means that N_q cannot possibly hold any clusters that own any points in
-    // N_r.
-    referenceNode.Stat().ClustersPruned() += queryNode.NumDescendants();
-
-    // Have we pruned everything?
-    if (referenceNode.Stat().ClustersPruned() == centroids.n_cols - 1)
-    {
-      // Then the best query node must contain just one point.
-      const TreeType* bestQueryNode = (TreeType*)
-          referenceNode.Stat().ClosestQueryNode();
-      const size_t cluster = mappings[bestQueryNode->Descendant(0)];
-
-      referenceNode.Stat().Owner() = cluster;
-      newCentroids.col(cluster) += referenceNode.NumDescendants() *
-          referenceNode.Stat().Centroid();
-      counts(cluster) += referenceNode.NumDescendants();
-      referenceNode.Stat().ClustersPruned()++;
-    }
-    else if (referenceNode.Stat().ClustersPruned() +
-        visited[referenceNode.Descendant(0)] == centroids.n_cols)
-    {
-      for (size_t i = 0; i < referenceNode.NumPoints(); ++i)
-      {
-        const size_t cluster = assignments[referenceNode.Point(i)];
-        newCentroids.col(cluster) += dataset.col(referenceNode.Point(i));
-        counts(cluster)++;
-      }
-    }
-
-    return DBL_MAX;
-  }
-
-  return minQueryDistance;
+  return ElkanTypeScore(queryNode, referenceNode);
 }
 
 template<typename MetricType, typename TreeType>
@@ -202,48 +163,7 @@ double DualTreeKMeansRules<MetricType, TreeType>::Rescore(
   if (oldScore == DBL_MAX)
     return oldScore; // We can't unprune something.  This shouldn't happen.
 
-  // Can we update the minimum query node distance for this reference node?
-  const double minQueryDistance = oldScore;
-
-  // See if we can do an Elkan-type prune on between-centroid distances.
-  const double maxDistance = referenceNode.Stat().MaxQueryNodeDistance();
-
-  if (minQueryDistance > 2.0 * maxDistance)
-  {
-    // Then we can conclude d_max(best(N_r), N_r) <= d_min(N_q, N_r) which
-    // means that N_q cannot possibly hold any clusters that own any points in
-    // N_r.
-    referenceNode.Stat().ClustersPruned() += queryNode.NumDescendants();
-
-    // Have we pruned everything?
-    if (referenceNode.Stat().ClustersPruned() == centroids.n_cols - 1)
-    {
-      // Then the best query node must contain just one point.
-      const TreeType* bestQueryNode = (TreeType*)
-          referenceNode.Stat().ClosestQueryNode();
-      const size_t cluster = mappings[bestQueryNode->Descendant(0)];
-
-      referenceNode.Stat().Owner() = cluster;
-      newCentroids.col(cluster) += referenceNode.NumDescendants() *
-          referenceNode.Stat().Centroid();
-      counts(cluster) += referenceNode.NumDescendants();
-      referenceNode.Stat().ClustersPruned()++;
-    }
-    else if (referenceNode.Stat().ClustersPruned() +
-        visited[referenceNode.Descendant(0)] == centroids.n_cols)
-    {
-      for (size_t i = 0; i < referenceNode.NumPoints(); ++i)
-      {
-        const size_t cluster = assignments[referenceNode.Point(i)];
-        newCentroids.col(cluster) += dataset.col(referenceNode.Point(i));
-        counts(cluster)++;
-      }
-    }
-
-    return DBL_MAX;
-  }
-
-  return oldScore;
+  return ElkanTypeScore(queryNode, referenceNode, oldScore);
 }
 
 template<typename MetricType, typename TreeType>
@@ -311,6 +231,65 @@ bool DualTreeKMeansRules<MetricType, TreeType>::IsDescendantOf(
     return IsDescendantOf(potentialParent, *potentialChild.Parent());
 }
 
+template<typename MetricType, typename TreeType>
+double DualTreeKMeansRules<MetricType, TreeType>::ElkanTypeScore(
+    TreeType& queryNode,
+    TreeType& referenceNode) const
+{
+  // We have to calculate the minimum distance between the query node and the
+  // reference node's best query node.
+  const double minQueryDistance = queryNode.MinDistance((TreeType*)
+      referenceNode.Stat().ClosestQueryNode());
+  return ElkanTypeScore(queryNode, referenceNode, minQueryDistance);
+}
+
+template<typename MetricType, typename TreeType>
+double DualTreeKMeansRules<MetricType, TreeType>::ElkanTypeScore(
+    TreeType& queryNode,
+    TreeType& referenceNode,
+    const double minQueryDistance) const
+{
+  // See if we can do an Elkan-type prune on between-centroid distances.
+  const double maxDistance = referenceNode.Stat().MaxQueryNodeDistance();
+
+  if (minQueryDistance > 2.0 * maxDistance)
+  {
+    // Then we can conclude d_max(best(N_r), N_r) <= d_min(N_q, N_r) which
+    // means that N_q cannot possibly hold any clusters that own any points in
+    // N_r.
+    referenceNode.Stat().ClustersPruned() += queryNode.NumDescendants();
+
+    // Have we pruned everything?
+    if (referenceNode.Stat().ClustersPruned() == centroids.n_cols - 1)
+    {
+      // Then the best query node must contain just one point.
+      const TreeType* bestQueryNode = (TreeType*)
+          referenceNode.Stat().ClosestQueryNode();
+      const size_t cluster = mappings[bestQueryNode->Descendant(0)];
+
+      referenceNode.Stat().Owner() = cluster;
+      newCentroids.col(cluster) += referenceNode.NumDescendants() *
+          referenceNode.Stat().Centroid();
+      counts(cluster) += referenceNode.NumDescendants();
+      referenceNode.Stat().ClustersPruned()++;
+    }
+    else if (referenceNode.Stat().ClustersPruned() +
+        visited[referenceNode.Descendant(0)] == centroids.n_cols)
+    {
+      for (size_t i = 0; i < referenceNode.NumPoints(); ++i)
+      {
+        const size_t cluster = assignments[referenceNode.Point(i)];
+        newCentroids.col(cluster) += dataset.col(referenceNode.Point(i));
+        counts(cluster)++;
+      }
+    }
+
+    return DBL_MAX;
+  }
+
+  return minQueryDistance;
+}
+
 } // namespace kmeans
 } // namespace mlpack
 



More information about the mlpack-git mailing list