[mlpack-git] master: Refactor to perform allknn on the clusters. This is done at the start of each iteration. This allows Elkan pruning without an additional distance calculation. (347b8c0)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:01:33 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit 347b8c08aab02e007fd56e0203a729e482f2c745
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Jan 7 15:05:35 2015 -0500

    Refactor to perform an all-nearest-neighbor search (allknn) on the cluster centroids. This is done at the start of each iteration, and the resulting inter-centroid distances allow Elkan-type pruning without an additional distance calculation.


>---------------------------------------------------------------

347b8c08aab02e007fd56e0203a729e482f2c745
 .../methods/kmeans/dual_tree_kmeans_impl.hpp       | 11 ++-
 .../methods/kmeans/dual_tree_kmeans_rules.hpp      |  2 +
 .../methods/kmeans/dual_tree_kmeans_rules_impl.hpp | 82 ++++++++++++----------
 .../methods/kmeans/dual_tree_kmeans_statistic.hpp  | 38 +++++++++-
 4 files changed, 94 insertions(+), 39 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
index 0937cb8..ff6ddf8 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
@@ -11,6 +11,8 @@
 #include "dual_tree_kmeans.hpp"
 #include "dual_tree_kmeans_rules.hpp"
 
+#include <mlpack/methods/neighbor_search/neighbor_search.hpp>
+
 namespace mlpack {
 namespace kmeans {
 
@@ -68,11 +70,18 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
   TreeType* centroidTree = BuildTree<TreeType>(
       const_cast<typename TreeType::Mat&>(centroids), oldFromNewCentroids);
 
+  // Now calculate distances between centroids.
+  neighbor::NeighborSearch<neighbor::NearestNeighborSort, MetricType, TreeType>
+      nns(centroidTree, centroids);
+  arma::mat interclusterDistances;
+  arma::Mat<size_t> closestClusters; // We don't actually care about these.
+  nns.Search(1, closestClusters, interclusterDistances);
+
   // Now run the dual-tree algorithm.
   typedef DualTreeKMeansRules<MetricType, TreeType> RulesType;
   RulesType rules(dataset, centroids, newCentroids, counts, oldFromNewCentroids,
       iteration, clusterDistances, distances, assignments, distanceIteration,
-      metric);
+      interclusterDistances, metric);
 
   // Use the dual-tree traverser.
 //typename TreeType::template DualTreeTraverser<RulesType> traverser(rules);
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp
index 24cad31..fe88edc 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp
@@ -24,6 +24,7 @@ class DualTreeKMeansRules
                       arma::vec& distances,
                       arma::Col<size_t>& assignments,
                       arma::Col<size_t>& distanceIteration,
+                      const arma::mat& interclusterDistances,
                       MetricType& metric);
 
   double BaseCase(const size_t queryIndex, const size_t referenceIndex);
@@ -60,6 +61,7 @@ class DualTreeKMeansRules
   arma::Col<size_t>& assignments;
   arma::Col<size_t> visited;
   arma::Col<size_t>& distanceIteration;
+  const arma::mat& interclusterDistances;
   MetricType& metric;
 
   size_t distanceCalculations;
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
index 5678521..aa14d9d 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
@@ -25,6 +25,7 @@ DualTreeKMeansRules<MetricType, TreeType>::DualTreeKMeansRules(
     arma::vec& distances,
     arma::Col<size_t>& assignments,
     arma::Col<size_t>& distanceIteration,
+    const arma::mat& interclusterDistances,
     MetricType& metric) :
     dataset(dataset),
     centroids(centroids),
@@ -36,6 +37,7 @@ DualTreeKMeansRules<MetricType, TreeType>::DualTreeKMeansRules(
     distances(distances),
     assignments(assignments),
     distanceIteration(distanceIteration),
+    interclusterDistances(interclusterDistances),
     metric(metric),
     distanceCalculations(0)
 {
@@ -112,48 +114,46 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
     TreeType& queryNode,
     TreeType& referenceNode)
 {
-//  if (IterationUpdate(referenceNode) == DBL_MAX)
-//  {
-    // The iteration update showed that the owner could not possibly change.
-//    return DBL_MAX;
-//  }
-
   if (referenceNode.Stat().ClustersPruned() == size_t(-1))
     referenceNode.Stat().ClustersPruned() =
         referenceNode.Parent()->Stat().ClustersPruned();
 
   traversalInfo.LastReferenceNode() = &referenceNode;
 
-  // Can we update the minimum query node distance for this reference node?
-  const double minDistance = referenceNode.MinDistance(&queryNode);
-  const double maxDistance = referenceNode.MaxDistance(&queryNode);
-  distanceCalculations += 2;
-  if (maxDistance < referenceNode.Stat().MaxQueryNodeDistance())
-  {
-    referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
-    referenceNode.Stat().MinQueryNodeDistance() = minDistance;
-    referenceNode.Stat().MaxQueryNodeDistance() = maxDistance;
-        referenceNode.MaxDistance(&queryNode);
-//    ++distanceCalculations;
-    return 0.0; // Pruning is not possible.
-  }
+  double score = ElkanTypeScore(queryNode, referenceNode);
 
-  else if (IsDescendantOf(
-      *((TreeType*) referenceNode.Stat().ClosestQueryNode()), queryNode))
+  // We also have to update things if the closest query node is null.  This can
+  // probably be improved.
+  if (score != DBL_MAX || referenceNode.Stat().ClosestQueryNode() == NULL)
   {
-    // Just update.
-    referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
-    referenceNode.Stat().MinQueryNodeDistance() = minDistance;
-    referenceNode.Stat().MaxQueryNodeDistance() =
-        referenceNode.MaxDistance(&queryNode);
-    ++distanceCalculations;
-    return 0.0; // Pruning is not possible.
-  }
+    // Can we update the minimum query node distance for this reference node?
+    const double minDistance = referenceNode.MinDistance(&queryNode);
+    const double maxDistance = referenceNode.MaxDistance(&queryNode);
+    distanceCalculations += 2;
+    if (maxDistance < referenceNode.Stat().MaxQueryNodeDistance())
+    {
+      referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
+      referenceNode.Stat().MinQueryNodeDistance() = minDistance;
+      referenceNode.Stat().MaxQueryNodeDistance() = maxDistance;
+//          referenceNode.MaxDistance(&queryNode);
+//      ++distanceCalculations;
+      return 0.0; // Pruning is not possible.
+    }
 
-//  double score = ElkanTypeScore(queryNode, referenceNode);
-//  if (score != DBL_MAX)
+    else if (IsDescendantOf(
+        *((TreeType*) referenceNode.Stat().ClosestQueryNode()), queryNode))
+    {
+      // Just update.
+      referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
+      referenceNode.Stat().MinQueryNodeDistance() = minDistance;
+      referenceNode.Stat().MaxQueryNodeDistance() =
+          referenceNode.MaxDistance(&queryNode);
+      ++distanceCalculations;
+      return 0.0; // Pruning is not possible.
+    }
 
-  double score = PellegMooreScore(queryNode, referenceNode, minDistance);
+    score = PellegMooreScore(queryNode, referenceNode, minDistance);
+  }
 
   if (score == DBL_MAX)
   {
@@ -307,11 +307,17 @@ double DualTreeKMeansRules<MetricType, TreeType>::ElkanTypeScore(
     TreeType& referenceNode)
 {
   // We have to calculate the minimum distance between the query node and the
-  // reference node's best query node.
-  const double minQueryDistance = queryNode.MinDistance((TreeType*)
-      referenceNode.Stat().ClosestQueryNode());
-  ++distanceCalculations;
-  return ElkanTypeScore(queryNode, referenceNode, minQueryDistance);
+  // reference node's best query node.  First, try to use the cached distance.
+//  const double minQueryDistance = queryNode.Stat().FirstBound();
+  if (queryNode.NumDescendants() == 1)
+  {
+    const double score = ElkanTypeScore(queryNode, referenceNode,
+        interclusterDistances[queryNode.Descendant(0)]);
+//    Log::Warn << "Elkan scoring: " << score << ".\n";
+    return score;
+  }
+  else
+    return 0.0;
 }
 
 template<typename MetricType, typename TreeType>
@@ -322,6 +328,8 @@ double DualTreeKMeansRules<MetricType, TreeType>::ElkanTypeScore(
 {
   // See if we can do an Elkan-type prune on between-centroid distances.
   const double maxDistance = referenceNode.Stat().MaxQueryNodeDistance();
+  if (maxDistance == DBL_MAX)
+    return minQueryDistance;
 
   if (minQueryDistance > 2.0 * maxDistance)
   {
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
index 21481da..6d394b7 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
@@ -21,7 +21,12 @@ class DualTreeKMeansStatistic
       minQueryNodeDistance(DBL_MAX),
       maxQueryNodeDistance(DBL_MAX),
       clustersPruned(0),
-      iteration(size_t() - 1)
+      iteration(size_t() - 1),
+      firstBound(DBL_MAX),
+      secondBound(DBL_MAX),
+      bound(DBL_MAX),
+      lastDistanceNode(NULL),
+      lastDistance(0.0)
   {
     // Empirically calculate the centroid.
     centroid.zeros(node.Dataset().n_rows);
@@ -70,6 +75,29 @@ class DualTreeKMeansStatistic
   //! Modify the current owner (if any) of these reference points.
   size_t& Owner() { return owner; }
 
+  // For nearest neighbor search.
+
+  //! Get the first bound.
+  double FirstBound() const { return firstBound; }
+  //! Modify the first bound.
+  double& FirstBound() { return firstBound; }
+  //! Get the second bound.
+  double SecondBound() const { return secondBound; }
+  //! Modify the second bound.
+  double& SecondBound() { return secondBound; }
+  //! Get the overall bound.
+  double Bound() const { return bound; }
+  //! Modify the overall bound.
+  double& Bound() { return bound; }
+  //! Get the last distance evaluation node.
+  void* LastDistanceNode() const { return lastDistanceNode; }
+  //! Modify the last distance evaluation node.
+  void*& LastDistanceNode() { return lastDistanceNode; }
+  //! Get the last distance calculation.
+  double LastDistance() const { return lastDistance; }
+  //! Modify the last distance calculation.
+  double& LastDistance() { return lastDistance; }
+
  private:
   //! The empirically calculated centroid of the node.
   arma::vec centroid;
@@ -88,6 +116,14 @@ class DualTreeKMeansStatistic
   //! The owner of these reference nodes (centroids.n_cols if there is no
   //! owner).
   size_t owner;
+
+  // For nearest neighbor search.
+
+  double firstBound;
+  double secondBound;
+  double bound;
+  void* lastDistanceNode;
+  double lastDistance;
 };
 
 } // namespace kmeans



More information about the mlpack-git mailing list