[mlpack-git] master: Handle SecondClosestBound() a little better. Debugging information is still there, and it is going to need to be seriously refactored. We still don't have properly working Hamerly prunes; they incorrectly go away after a couple of iterations. (db0d792)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:01:56 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit db0d792ef83a231ca777a6c12a5450745d6cffdf
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Jan 26 15:30:01 2015 -0500

    Handle SecondClosestBound() a little better. Debugging information is still there, and it is going to need to be seriously refactored. We still don't have properly working Hamerly prunes; they incorrectly go away after a couple of iterations.


>---------------------------------------------------------------

db0d792ef83a231ca777a6c12a5450745d6cffdf
 .../methods/kmeans/dual_tree_kmeans_impl.hpp       | 97 ++++++++++++++++++----
 .../methods/kmeans/dual_tree_kmeans_rules_impl.hpp | 69 ++++++++++++---
 .../methods/kmeans/dual_tree_kmeans_statistic.hpp  | 16 ++++
 3 files changed, 157 insertions(+), 25 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
index ab293af..592a175 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
@@ -70,6 +70,9 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
   std::vector<size_t> oldFromNewCentroids;
   TreeType* centroidTree = BuildTree<TreeType>(
       const_cast<typename TreeType::Mat&>(centroids), oldFromNewCentroids);
+  for (size_t i = 0; i < oldFromNewCentroids.size(); ++i)
+    Log::Warn << oldFromNewCentroids[i] << " ";
+  Log::Warn << "\n";
 
   // Now calculate distances between centroids.
   neighbor::NeighborSearch<neighbor::NearestNeighborSort, MetricType, TreeType>
@@ -154,6 +157,19 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::ClusterTreeUpdate(
   node->Stat().FirstBound() = firstBound;
 }
 
+template<typename TreeType>
+bool IsDescendantOf(
+    const TreeType& potentialParent,
+    const TreeType& potentialChild)
+{
+  if (potentialChild.Parent() == &potentialParent)
+    return true;
+  else if (potentialChild.Parent() == NULL)
+    return false;
+  else
+    return IsDescendantOf(potentialParent, *potentialChild.Parent());
+}
+
 template<typename MetricType, typename MatType, typename TreeType>
 void DualTreeKMeans<MetricType, MatType, TreeType>::TreeUpdate(
     TreeType* node,
@@ -201,12 +217,64 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::TreeUpdate(
     // same owner in the next iteration.  Note that MaxQueryNodeDistance() has
     // already been adjusted for cluster movement.
 
+    // Re-set second closest bound if necessary.
+    if (node->Stat().SecondClosestBound() == DBL_MAX)
+    {
+      if (node->Parent() == NULL)
+        node->Stat().SecondClosestBound() = 0.0; // Don't prune the root.
+
+      else
+      {
+        if (node->Parent()->Stat().SecondClosestBound() != DBL_MAX &&
+node->Stat().LastSecondClosestBound() != DBL_MAX)
+          node->Stat().SecondClosestBound() =
+std::max(node->Parent()->Stat().SecondClosestBound(),
+node->Stat().LastSecondClosestBound());
+        else
+          node->Stat().SecondClosestBound() =
+std::min(node->Parent()->Stat().SecondClosestBound(),
+node->Stat().LastSecondClosestBound());
+      }
+//      if (node->Begin() == 35871)
+//        Log::Warn << "Update second closest bound for r35871c" <<
+//node->Count() << " to " << node->Stat().SecondClosestBound() << ", which could "
+//      << "have been parent's (" << node->Parent()->Stat().SecondClosestBound()
+//<< ") or adjusted last iteration's (" << node->Stat().LastSecondClosestBound()
+//<< ").\n";
+    }
+
+//    if (node->Begin() == 35871)
+//      Log::Warn << "r35871c" << node->Count() << " has second bound " <<
+//node->Stat().SecondClosestBound() << " (q" << ((TreeType*)
+//node->Stat().SecondClosestQueryNode())->Begin() << "c" << ((TreeType*)
+//node->Stat().SecondClosestQueryNode())->Count() << ") and parent has second "
+//          << "bound " << node->Parent()->Stat().SecondClosestBound() << " (q"
+//          << ((TreeType*)
+//node->Parent()->Stat().SecondClosestQueryNode())->Begin() << "c" << ((TreeType*)
+//node->Parent()->Stat().SecondClosestQueryNode())->Count() << ").\n";
+
+    if (node->Parent() != NULL &&
+node->Parent()->Stat().SecondClosestQueryNode() != NULL &&
+node->Stat().SecondClosestQueryNode() != NULL && !IsDescendantOf(*((TreeType*)
+node->Stat().SecondClosestQueryNode()), *((TreeType*)
+node->Parent()->Stat().SecondClosestQueryNode())) &&
+node->Parent()->Stat().SecondClosestBound() < node->Stat().SecondClosestBound())
+    {
+//      if (node->Begin() == 35871)
+//        Log::Warn << "Take second closest bound for r35871c" <<
+//node->Count() << " from parent: " << node->Parent()->Stat().SecondClosestBound()
+//<< " (was " << node->Stat().SecondClosestBound() << ").\n";
+          node->Stat().SecondClosestBound() =
+node->Parent()->Stat().SecondClosestBound();
+    }
+
     if (node->Stat().MaxQueryNodeDistance() < node->Stat().SecondClosestBound()
         - clusterDistances[clusters])
     {
       node->Stat().HamerlyPruned() = true;
-      Log::Warn << "Mark r" << node->Begin() << "c" << node->Count() << " as "
-          << "Hamerly pruned.\n";
+//      if (node->Begin() == 35871)
+        Log::Warn << "Mark r" << node->Begin() << "c" << node->Count() << " as "
+            << "Hamerly pruned.\n";
 
       // Check the second bound.  (This is time-consuming...)
       for (size_t j = 0; j < node->NumDescendants(); ++j)
@@ -223,13 +291,6 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::TreeUpdate(
           distances(i) = distance;
         }
 
-        // Re-set second closest bound if necessary.
-        if (node->Stat().ClustersPruned() == size_t(-1))
-        {
-//          Log::Warn << "Update second closest bound!\n";
-          node->Stat().SecondClosestBound() = node->Parent()->Stat().SecondClosestBound();
-        }
-
         if (secondClosestDist < node->Stat().SecondClosestBound() - 1e-15)
         {
           Log::Warn << "Owner " << node->Stat().Owner() << ", mqnd " <<
@@ -240,9 +301,10 @@ node->Stat().MinQueryNodeDistance() << ".\n";
 node->Stat().SecondClosestBound() << " is too loose! -- " << secondClosestDist
               << "! (" << node->Stat().SecondClosestBound() - secondClosestDist
 << ")\n";
+
         }
-//        if (node->Begin() == 37591)
-//          Log::Warn << "r37591c" << node->Count() << ": " << distances.t();
+//        if (node->Begin() == 35871)
+//          Log::Warn << "r35871c" << node->Count() << ": " << distances.t();
       }
     }
 //    else
@@ -280,16 +342,21 @@ node->Stat().SecondClosestBound() << " is too loose! -- " << secondClosestDist
   // be rebuilt.
   node->Stat().ClosestQueryNode() = NULL;
 
-//  if (node->Begin() == 37591)
-//    Log::Warn << "scb for r37591c" << node->Count() << " updated to " <<
+//  if (node->Begin() == 35871)
+//    Log::Warn << "scb for r35871c" << node->Count() << " updated to " <<
 //node->Stat().SecondClosestBound() << ".\n";
 
-//  if (!node->Stat().HamerlyPruned())
+  if (!node->Stat().HamerlyPruned())
     for (size_t i = 0; i < node->NumChildren(); ++i)
       TreeUpdate(&node->Child(i), clusters, clusterDistances, assignments,
           centroids, dataset);
-}
 
+  node->Stat().LastSecondClosestBound() = node->Stat().SecondClosestBound() -
+      clusterDistances[clusters];
+  // This should change later, but I'm not yet sure how to do it.
+  node->Stat().SecondClosestBound() = DBL_MAX;
+  node->Stat().SecondClosestQueryNode() = NULL;
+}
 
 } // namespace kmeans
 } // namespace mlpack
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
index f51c380..0a84376 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
@@ -136,9 +136,9 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
     referenceNode.Stat().MaxQueryNodeDistance() = std::min(
         referenceNode.Parent()->Stat().MaxQueryNodeDistance(),
         referenceNode.Stat().MaxQueryNodeDistance());
-    referenceNode.Stat().SecondClosestBound() = std::min(
-        referenceNode.Parent()->Stat().SecondClosestBound(),
-        referenceNode.Stat().SecondClosestBound());
+//    referenceNode.Stat().SecondClosestBound() = std::min(
+//        referenceNode.Parent()->Stat().SecondClosestBound(),
+//        referenceNode.Stat().SecondClosestBound());
 //    if (referenceNode.Begin() == 37591)
 //      Log::Warn << "Update second closest bound for r37591c" <<
 //referenceNode.Count() << " to parent's, which "
@@ -175,6 +175,9 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
       const double minDistance = referenceNode.MinDistance(&queryNode);
       ++distanceCalculations;
       score = PellegMooreScore(queryNode, referenceNode, minDistance);
+//      if (referenceNode.Begin() == 37591)
+//        Log::Warn << "mQND for r37591c" << referenceNode.Count() << " is "
+//            << referenceNode.Stat().MinQueryNodeDistance() << ".\n";
 
       if (minDistance < referenceNode.Stat().MinQueryNodeDistance())
       {
@@ -187,24 +190,49 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
         {
           referenceNode.Stat().SecondClosestBound() =
               referenceNode.Stat().MinQueryNodeDistance();
+          referenceNode.Stat().SecondClosestQueryNode() =
+              referenceNode.Stat().ClosestQueryNode();
 //          if (referenceNode.Begin() == 37591)
 //            Log::Warn << "scb for r37591c" << referenceNode.Count() << " taken "
 //                << "from minDistance, which is " <<
 //referenceNode.Stat().MinQueryNodeDistance() << ".\n";
         }
-        ++distanceCalculations;
-        referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
-        referenceNode.Stat().MinQueryNodeDistance() = minDistance;
-        referenceNode.Stat().MaxQueryNodeDistance() = maxDistance;
 
-//        if (referenceNode.Begin() == 37591)
-//          Log::Warn << "mQND for r37591c" << referenceNode.Count() << " updated to " << minDistance << " and "
+        if (referenceNode.Stat().MinQueryNodeDistance() == DBL_MAX &&
+            score == DBL_MAX &&
+            minDistance < referenceNode.Stat().SecondClosestBound())
+        {
+          referenceNode.Stat().SecondClosestBound() = minDistance;
+          referenceNode.Stat().SecondClosestQueryNode() = &queryNode;
+//          if (referenceNode.Begin() == 37591)
+//            Log::Warn << "scb for r37591c" << referenceNode.Count() << " taken "
+//                << "from minDistance for pruned query node, which is " <<
+//minDistance << ".\n";
+        }
+
+        if (score != DBL_MAX)
+        {
+          ++distanceCalculations;
+          referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
+          referenceNode.Stat().MinQueryNodeDistance() = minDistance;
+          referenceNode.Stat().MaxQueryNodeDistance() = maxDistance;
+
+//          if (referenceNode.Begin() == 37591)
+//            Log::Warn << "mQND for r37591c" << referenceNode.Count() << " updated to " << minDistance << " and "
 //              << "MQND to " << maxDistance << " with furthest query node " <<
 //              queryNode.Begin() << "c" << queryNode.Count() << ".\n";
+        }
       }
       else if (IsDescendantOf(*((TreeType*)
           referenceNode.Stat().ClosestQueryNode()), queryNode))
       {
+//        if (referenceNode.Begin() == 37591)
+//          Log::Warn << "Old closest for r37591c" << referenceNode.Count() <<
+//              " is q" << ((TreeType*)
+//referenceNode.Stat().ClosestQueryNode())->Begin() << "c" << ((TreeType*)
+//referenceNode.Stat().ClosestQueryNode())->Count() << " with mQND " <<
+//referenceNode.Stat().MinQueryNodeDistance() << " and MQND " <<
+//referenceNode.Stat().MaxQueryNodeDistance() << ".\n";
         const double maxDistance = referenceNode.MaxDistance(&queryNode);
         ++distanceCalculations;
         referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
@@ -219,6 +247,7 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
       else if (minDistance < referenceNode.Stat().SecondClosestBound())
       {
         referenceNode.Stat().SecondClosestBound() = minDistance;
+        referenceNode.Stat().SecondClosestQueryNode() = &queryNode;
 //        if (referenceNode.Begin() == 37591)
 //          Log::Warn << "scb for r37591c" << referenceNode.Count() << " updated to " << minDistance << " via "
 //              << queryNode.Begin() << "c" << queryNode.Count() << ".\n";
@@ -226,6 +255,14 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
     }
   }
 
+  if (((TreeType*) referenceNode.Stat().ClosestQueryNode())->NumDescendants() > 1)
+  {
+    referenceNode.Stat().SecondClosestBound() =
+        referenceNode.Stat().MinQueryNodeDistance();
+    referenceNode.Stat().SecondClosestQueryNode() =
+        referenceNode.Stat().ClosestQueryNode();
+  }
+
   if (score == DBL_MAX)
   {
     referenceNode.Stat().ClustersPruned() += queryNode.NumDescendants();
@@ -335,6 +372,9 @@ double DualTreeKMeansRules<MetricType, TreeType>::ElkanTypeScore(
           queryNode)) &&
       (&queryNode != (TreeType*) referenceNode.Stat().ClosestQueryNode()))
   {
+//    if (referenceNode.Begin() == 37591)
+//      Log::Warn << "Elkan prune r37591c" << referenceNode.Count() << ", q" <<
+//queryNode.Begin() << "c" << queryNode.Count() << "!\n";
     // Then we can conclude d_max(best(N_r), N_r) <= d_min(N_q, N_r) which
     // means that N_q cannot possibly hold any clusters that own any points in
     // N_r.
@@ -346,15 +386,24 @@ double DualTreeKMeansRules<MetricType, TreeType>::ElkanTypeScore(
 
 template<typename MetricType, typename TreeType>
 double DualTreeKMeansRules<MetricType, TreeType>::PellegMooreScore(
-    TreeType& /* queryNode */,
+    TreeType& queryNode,
     TreeType& referenceNode,
     const double minDistance) const
 {
   // If the minimum distance to the node is greater than the bound, then every
   // cluster in the query node cannot possibly be the nearest neighbor of any of
   // the points in the reference node.
+//  if (referenceNode.Begin() == 37591)
+//      Log::Warn << "Pelleg-Moore prune attempt r37591c" << referenceNode.Count() << ", "
+//          << "q" << queryNode.Begin() << "c" << queryNode.Count() << "; "
+//          << "minDistance " << minDistance << ", MQND " <<
+//referenceNode.Stat().MaxQueryNodeDistance() << ".\n";
   if (minDistance > referenceNode.Stat().MaxQueryNodeDistance())
+  {
+//    if (referenceNode.Begin() == 37591)
+//      Log::Warn << "Attempt successful!\n";
     return DBL_MAX;
+  }
 
   return minDistance;
 }
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
index 87e4368..0b01fa6 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
@@ -21,6 +21,8 @@ class DualTreeKMeansStatistic
       minQueryNodeDistance(DBL_MAX),
       maxQueryNodeDistance(DBL_MAX),
       secondClosestBound(DBL_MAX),
+      secondClosestQueryNode(NULL),
+      lastSecondClosestBound(DBL_MAX),
       hamerlyPruned(false),
       clustersPruned(size_t(-1)),
       iteration(size_t() - 1),
@@ -68,6 +70,16 @@ class DualTreeKMeansStatistic
   //! Modify the lower bound on the second closest cluster distance.
   double& SecondClosestBound() { return secondClosestBound; }
 
+  //! Get the second closest query node.
+  void* SecondClosestQueryNode() const { return secondClosestQueryNode; }
+  //! Modify the second closest query node.
+  void*& SecondClosestQueryNode() { return secondClosestQueryNode; }
+
+  //! Get last iteration's second closest bound.
+  double LastSecondClosestBound() const { return lastSecondClosestBound; }
+  //! Modify last iteration's second closest bound.
+  double& LastSecondClosestBound() { return lastSecondClosestBound; }
+
   //! Get whether or not this node is Hamerly pruned this iteration.
   bool HamerlyPruned() const { return hamerlyPruned; }
   //! Modify whether or not this node is Hamerly pruned this iteration.
@@ -123,6 +135,10 @@ class DualTreeKMeansStatistic
   double maxQueryNodeDistance;
   //! A lower bound on the distance to the second closest cluster.
   double secondClosestBound;
+  //! The second closest query node.
+  void* secondClosestQueryNode;
+  //! The second closest lower bound, on the previous iteration.
+  double lastSecondClosestBound;
   //! Whether or not this node is pruned for the next iteration.
   bool hamerlyPruned;
 



More information about the mlpack-git mailing list