[mlpack-git] master: Slight refactoring of dual-tree k-means rules. (c63a518)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:01:31 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit c63a518c8b6bf01e393028c5e91c7f7a95df662a
Author: ryan <ryan at ratml.org>
Date:   Thu Jan 1 18:56:00 2015 -0500

    Slight refactoring of dual-tree k-means rules.
    
    This is an attempt to clean up this code because it quickly became unwieldy and
    hard to work with.  A more modular approach could lead to a better, faster
    solution.


>---------------------------------------------------------------

c63a518c8b6bf01e393028c5e91c7f7a95df662a
 .../methods/kmeans/dual_tree_kmeans_impl.hpp       | 75 ++++++++++++++++++++++
 .../methods/kmeans/dual_tree_kmeans_rules.hpp      |  2 +-
 .../methods/kmeans/dual_tree_kmeans_rules_impl.hpp | 57 +++++++++++-----
 3 files changed, 118 insertions(+), 16 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
index fe63498..088d565 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
@@ -112,6 +112,81 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
   return std::sqrt(residual);
 }
 
+/*
+template<typename MetricType, typename MatType, typename TreeType>
+void DualTreeKMeans<MetricType, MatType, TreeType>::ClusterTreeUpdate(
+    TreeType* node)
+{
+  // We will abuse stat.owner to hold the cluster with the most change.
+  // stat.minQueryNodeDistance will hold the distance.
+  double maxChange = 0.0;
+  size_t maxChangeCluster = 0;
+
+  for (size_t i = 0; i < node->NumChildren(); ++i)
+  {
+    ClusterTreeUpdate(&node->Child(i));
+
+    const double nodeChange = node->Child(i).Stat().MinQueryNodeDistance();
+    if (nodeChange > maxChange)
+    {
+      maxChange = nodeChange;
+      maxChangeCluster = node->Child(i).Stat().Owner();
+    }
+  }
+
+  for (size_t i = 0; i < node->NumPoints(); ++i)
+  {
+    const size_t cluster = oldFromNewCentroids[node->Point(i)];
+    const double pointChange = clusterDistances[cluster];
+    if (pointChange > maxChange)
+    {
+      maxChange = pointChange;
+      maxChangeCluster = cluster;
+    }
+  }
+
+  node->Stat().Owner() = maxChangeCluster;
+  node->Stat().MinQueryNodeDistance() = maxChange;
+}
+
+template<typename MetricType, typename MatType, typename TreeType>
+void DualTreeKMeans<MetricType, MatType, TreeType>::TreeUpdate(
+    TreeType* node)
+{
+  // This is basically IterationUpdate(), but pulled out to be separate from the
+  // actual dual-tree algorithm.
+
+  // First, update the iteration.
+  const size_t itDiff = node->Stat().Iteration() - iteration;
+
+  if (itDiff == 1)
+  {
+    // The easy case.
+    if (node->Stat().Owner() < centroids.n_cols)
+    {
+      // During the last iteration, this node was pruned.  In addition, we have
+      // cached a lower bound on the second closest cluster.  So, use the
+      // triangle inequality: if the maximum distance between the point and the
+      // cluster centroid plus the distance that centroid moved is less than the
+      // lower bound minus the maximum moving centroid, then this cluster *must*
+      // still have the same owner.
+      const size_t owner = node->Stat().Owner();
+      const double closestUpperBound = node->Stat().MaxQueryNodeDistance() +
+          clusterDistances[owner];
+      const TreeType* nonOwner = (TreeType*) node->Stat().ClosestNonOwner();
+      const double tightestLowerBound = node->Stat().ClosestNonOwnerDistance() -
+          nonOwner->Stat().MinQueryNodeDistance() /* abused from earlier *;
+      if (closestUpperBound <= tightestLowerBound)
+      {
+        // Then the owner must not have changed.
+
+      }
+    }
+  }
+}
+*/
+
+
 } // namespace kmeans
 } // namespace mlpack
 
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp
index 978c12c..24cad31 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp
@@ -66,7 +66,7 @@ class DualTreeKMeansRules
 
   TraversalInfoType traversalInfo;
 
-  size_t IterationUpdate(TreeType& referenceNode) const;
+  double IterationUpdate(TreeType& referenceNode);
 
   bool IsDescendantOf(const TreeType& potentialParent, const TreeType&
       potentialChild) const;
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
index 33ea4ae..7676e1e 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
@@ -112,7 +112,11 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
     TreeType& queryNode,
     TreeType& referenceNode)
 {
-  IterationUpdate(referenceNode);
+  if (IterationUpdate(referenceNode) == DBL_MAX)
+  {
+    // The iteration update showed that the owner could not possibly change.
+    return DBL_MAX;
+  }
 
   traversalInfo.LastReferenceNode() = &referenceNode;
 
@@ -201,12 +205,13 @@ double DualTreeKMeansRules<MetricType, TreeType>::Rescore(
 }
 
 template<typename MetricType, typename TreeType>
-inline size_t DualTreeKMeansRules<MetricType, TreeType>::IterationUpdate(
-    TreeType& referenceNode) const
+inline double DualTreeKMeansRules<MetricType, TreeType>::IterationUpdate(
+    TreeType& referenceNode)
 {
   if (referenceNode.Stat().Iteration() == iteration)
     return 0;
 
+  const size_t itDiff = iteration - referenceNode.Stat().Iteration();
   referenceNode.Stat().Iteration() = iteration;
   referenceNode.Stat().ClustersPruned() = (referenceNode.Parent() == NULL) ?
       0 : referenceNode.Parent()->Stat().ClustersPruned();
@@ -214,15 +219,22 @@ inline size_t DualTreeKMeansRules<MetricType, TreeType>::IterationUpdate(
       NULL : referenceNode.Parent()->Stat().ClosestQueryNode();
 
   if (referenceNode.Stat().ClosestQueryNode() != NULL)
+  {
     referenceNode.Stat().MinQueryNodeDistance() =
         referenceNode.MinDistance((TreeType*)
         referenceNode.Stat().ClosestQueryNode());
+    referenceNode.Stat().MaxQueryNodeDistance() =
+        referenceNode.MaxDistance((TreeType*)
+        referenceNode.Stat().ClosestQueryNode());
+    distanceCalculations += 2;
+  }
+
 
-  const size_t itDiff = iteration - referenceNode.Stat().Iteration();
   if (itDiff > 1)
   {
-    // Maybe this can be tighter?
+//    referenceNode.Stat().BestMaxDistance() = DBL_MAX;
     referenceNode.Stat().MinQueryNodeDistance() = DBL_MAX;
+    referenceNode.Stat().MaxQueryNodeDistance() = DBL_MAX;
   }
   else
   {
@@ -231,24 +243,39 @@ inline size_t DualTreeKMeansRules<MetricType, TreeType>::IterationUpdate(
       // Update the distance to the closest query node.  If this node has an
       // owner, we know how far to increase the bound.  Otherwise, increase it
       // by the furthest amount that any centroid moved.
-//      if (referenceNode.Stat().Owner() < centroids.n_cols)
-//        referenceNode.Stat().MinQueryNodeDistance() +=
-//            clusterDistances(referenceNode.Stat().Owner());
-//      else
-//        referenceNode.Stat().MinQueryNodeDistance() = DBL_MAX;
-//            clusterDistances(centroids.n_cols);
-      if (referenceNode.Stat().MaxQueryNodeDistance() == DBL_MAX)
-        referenceNode.Stat().MinQueryNodeDistance() = DBL_MAX;
+      if (referenceNode.Stat().Owner() < centroids.n_cols)
+      {
+        referenceNode.Stat().MinQueryNodeDistance() +=
+            clusterDistances(referenceNode.Stat().Owner());
+        referenceNode.Stat().MaxQueryNodeDistance() +=
+            clusterDistances(referenceNode.Stat().Owner());
+      }
       else
       {
         referenceNode.Stat().MinQueryNodeDistance() +=
             clusterDistances(centroids.n_cols);
-//referenceNode.Stat().MaxQueryNodeDistance() +
-//clusterDistances(centroids.n_cols);
+        referenceNode.Stat().MaxQueryNodeDistance() +=
+            clusterDistances(centroids.n_cols);
       }
     }
+    else
+    {
+      referenceNode.Stat().MinQueryNodeDistance() = DBL_MAX;
+      referenceNode.Stat().MaxQueryNodeDistance() = DBL_MAX;
+    }
   }
 
+//    if (referenceNode.Stat().BestMaxDistance() != DBL_MAX)
+//    {
+//      if (referenceNode.Stat().Owner() < centroids.n_cols)
+//        referenceNode.Stat().BestMaxDistance() +=
+//            clusterDistances(referenceNode.Stat().Owner());
+//      else
+//        referenceNode.Stat().BestMaxDistance() +=
+//            clusterDistances(centroids.n_cols);
+//    }
+//  }
+
   return 1;
 }
 



More information about the mlpack-git mailing list