[mlpack-svn] r15779 - mlpack/trunk/src/mlpack/core/tree/cover_tree

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Sep 13 16:09:48 EDT 2013


Author: rcurtin
Date: Fri Sep 13 16:09:48 2013
New Revision: 15779

Log:
This should fix the broken build; I will clean it up over the next few days.
The traversers have been reimplemented so that the Score() and Rescore()
overloads that take base-case values are no longer needed, and the dual-tree
traverser now correctly satisfies the definition given in the ICML paper.


Modified:
   mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree.hpp
   mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
   mlpack/trunk/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp
   mlpack/trunk/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp
   mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp

Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree.hpp	Fri Sep 13 16:09:48 2013
@@ -458,6 +458,12 @@
    * Returns a string representation of this object.
    */
   std::string ToString() const;
+
+  size_t DistanceComps() const { return distanceComps; }
+  size_t& DistanceComps() { return distanceComps; }
+
+ private:
+  size_t distanceComps;
 };
 
 }; // namespace tree

Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp	Fri Sep 13 16:09:48 2013
@@ -31,7 +31,8 @@
     parentDistance(0),
     furthestDescendantDistance(0),
     localMetric(metric == NULL),
-    metric(metric)
+    metric(metric),
+    distanceComps(0)
 {
   // If we need to create a metric, do that.  We'll just do it on the heap.
   if (localMetric)
@@ -66,6 +67,9 @@
 
   // Initialize statistic.
   stat = StatisticType(*this);
+
+  Log::Info << distanceComps << " distance computations during tree "
+      << "construction." << std::endl;
 }
 
 template<typename MetricType, typename RootPointPolicy, typename StatisticType>
@@ -82,7 +86,8 @@
     parentDistance(0),
     furthestDescendantDistance(0),
     localMetric(false),
-    metric(&metric)
+    metric(&metric),
+    distanceComps(0)
 {
   // If there is only one point in the dataset, uh, we're done.
   if (dataset.n_cols == 1)
@@ -113,6 +118,9 @@
 
   // Initialize statistic.
   stat = StatisticType(*this);
+
+  Log::Info << distanceComps << " distance computations during tree "
+      << "construction." << std::endl;
 }
 
 template<typename MetricType, typename RootPointPolicy, typename StatisticType>
@@ -138,7 +146,8 @@
     parentDistance(parentDistance),
     furthestDescendantDistance(0),
     localMetric(false),
-    metric(&metric)
+    metric(&metric),
+    distanceComps(0)
 {
   // If the size of the near set is 0, this is a leaf.
   if (nearSetSize == 0)
@@ -176,7 +185,8 @@
     parentDistance(parentDistance),
     furthestDescendantDistance(furthestDescendantDistance),
     localMetric(metric == NULL),
-    metric(metric)
+    metric(metric),
+    distanceComps(0)
 {
   // If necessary, create a local metric.
   if (localMetric)
@@ -199,7 +209,8 @@
     parentDistance(other.parentDistance),
     furthestDescendantDistance(other.furthestDescendantDistance),
     localMetric(false),
-    metric(other.metric)
+    metric(other.metric),
+    distanceComps(0)
 {
   // Copy each child by hand.
   for (size_t i = 0; i < other.NumChildren(); ++i)
@@ -409,6 +420,7 @@
     size_t tempSize = 0;
     children.push_back(new CoverTree(dataset, base, point, INT_MIN, this, 0,
         indices, distances, 0, tempSize, usedSetSize, *metric));
+    distanceComps += children.back()->DistanceComps();
 
     // Every point in the near set should be a leaf.
     for (size_t i = 0; i < nearSetSize; ++i)
@@ -417,6 +429,7 @@
       children.push_back(new CoverTree(dataset, base, indices[i],
           INT_MIN, this, distances[i], indices, distances, 0, tempSize,
           usedSetSize, *metric));
+      distanceComps += children.back()->DistanceComps();
       usedSetSize++;
     }
 
@@ -458,6 +471,8 @@
   // Remove any implicit nodes we may have created.
   RemoveNewImplicitNodes();
 
+  distanceComps += children[0]->DistanceComps();
+
   // Now the arrays, in memory, look like this:
   // [ childFar | childUsed | far | used ]
   // but we need to move the used points past our far set:
@@ -504,7 +519,8 @@
       children.push_back(new CoverTree(dataset, base, indices[0], nextScale,
           this, distances[0], indices, distances, childNearSetSize, farSetSize,
           usedSetSize, *metric));
-      numDescendants += children[children.size() - 1]->NumDescendants();
+      distanceComps += children.back()->DistanceComps();
+      numDescendants += children.back()->NumDescendants();
 
       // Because the far set size is 0, we don't have to do any swapping to
       // move the point into the used set.
@@ -545,11 +561,13 @@
     children.push_back(new CoverTree(dataset, base, indices[0], nextScale,
         this, distances[0], childIndices, childDistances, childNearSetSize,
         childFarSetSize, childUsedSetSize, *metric));
-    numDescendants += children[children.size() - 1]->NumDescendants();
+    numDescendants += children.back()->NumDescendants();
 
     // Remove any implicit nodes.
     RemoveNewImplicitNodes();
 
+    distanceComps += children.back()->DistanceComps();
+
     // Now with the child created, it returns the childIndices and
     // childDistances vectors in this form:
     // [ childFar | childUsed ]
@@ -628,6 +646,7 @@
 {
   // For each point, rebuild the distances.  The indices do not need to be
   // modified.
+  distanceComps += pointSetSize;
   for (size_t i = 0; i < pointSetSize; ++i)
   {
     distances[i] = metric->Evaluate(dataset.unsafe_col(pointIndex),
@@ -894,6 +913,7 @@
     // Set its parent correctly.
     old->Child(0).Parent() = this;
     old->Child(0).ParentDistance() = old->ParentDistance();
+    old->Child(0).DistanceComps() = old->DistanceComps();
 
     // Remove its child (so it doesn't delete it).
     old->Children().erase(old->Children().begin() + old->Children().size() - 1);

Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp	Fri Sep 13 16:09:48 2013
@@ -48,6 +48,12 @@
   //! Modify the number of pruned nodes.
   size_t& NumPrunes() { return numPrunes; }
 
+  ///// These are all fake because this is a patch for kd-trees only and I still
+  ///// want it to compile!
+  size_t NumVisited() const { return 0; }
+  size_t NumScores() const { return 0; }
+  size_t NumBaseCases() const { return 0; }
+
  private:
   //! The instantiated rule set for pruning branches.
   RuleType& rule;
@@ -57,18 +63,12 @@
 
   //! Prepare map for recursion.
   void PruneMap(CoverTree& queryNode,
-                CoverTree& candidateQueryNode,
                 std::map<int, std::vector<DualCoverTreeMapEntry<
                     MetricType, RootPointPolicy, StatisticType> > >&
                     referenceMap,
                 std::map<int, std::vector<DualCoverTreeMapEntry<
                     MetricType, RootPointPolicy, StatisticType> > >& childMap);
 
-  void PruneMapForSelfChild(CoverTree& candidateQueryNode,
-                            std::map<int, std::vector<DualCoverTreeMapEntry<
-                                MetricType, RootPointPolicy, StatisticType> > >&
-                                referenceMap);
-
   void ReferenceRecursion(CoverTree& queryNode,
                           std::map<int, std::vector<DualCoverTreeMapEntry<
                               MetricType, RootPointPolicy, StatisticType> > >&

Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp	Fri Sep 13 16:09:48 2013
@@ -21,12 +21,6 @@
   CoverTree<MetricType, RootPointPolicy, StatisticType>* referenceNode;
   //! The score of the node.
   double score;
-  //! The index of the reference node used for the base case evaluation.
-  size_t referenceIndex;
-  //! The index of the query node used for the base case evaluation.
-  size_t queryIndex;
-  //! The base case evaluation.
-  double baseCase;
 
   //! Comparison operator, for sorting within the map.
   bool operator<(const DualCoverTreeMapEntry& other) const
@@ -53,18 +47,13 @@
   typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
       MapEntryType;
 
-  // Start by creating a map and adding the reference node to it.
+  // Start by creating a map and adding the reference root node to it.
   std::map<int, std::vector<MapEntryType> > refMap;
 
   MapEntryType rootRefEntry;
 
   rootRefEntry.referenceNode = &referenceNode;
   rootRefEntry.score = 0.0; // Must recurse into.
-  rootRefEntry.referenceIndex = referenceNode.Point();
-  rootRefEntry.queryIndex = queryNode.Point();
-  rootRefEntry.baseCase = rule.BaseCase(queryNode.Point(),
-      referenceNode.Point());
-//  rule.UpdateAfterRecursion(queryNode, referenceNode);
 
   refMap[referenceNode.Scale()].push_back(rootRefEntry);
 
@@ -75,13 +64,10 @@
 template<typename RuleType>
 void CoverTree<MetricType, RootPointPolicy, StatisticType>::
 DualTreeTraverser<RuleType>::Traverse(
-  CoverTree<MetricType, RootPointPolicy, StatisticType>& queryNode,
-  std::map<int, std::vector<DualCoverTreeMapEntry<MetricType, RootPointPolicy,
-      StatisticType> > >& referenceMap)
+    CoverTree<MetricType, RootPointPolicy, StatisticType>& queryNode,
+    std::map<int, std::vector<DualCoverTreeMapEntry<MetricType, RootPointPolicy,
+        StatisticType> > >& referenceMap)
 {
-//  Log::Debug << "Recursed into query node " << queryNode.Point() << ", scale "
-//      << queryNode.Scale() << "\n";
-
   // Convenience typedef.
   typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
       MapEntryType;
@@ -99,26 +85,17 @@
   {
     // Recurse into the non-self-children first.  The recursion order cannot
     // affect the runtime of the algorithm, because each query child recursion's
-    // results are separate and independent.
-    for (size_t i = 1; i < queryNode.NumChildren(); ++i)
+    // results are separate and independent.  I don't think this is true in
+    // every case, and we may have to modify this section to consider scores in
+    // the future.
+    std::map<int, std::vector<MapEntryType> > childMap;
+    PruneMap(queryNode, referenceMap, childMap);
+    for (size_t i = 0; i < queryNode.NumChildren(); ++i)
     {
-      std::map<int, std::vector<MapEntryType> > childMap;
-      PruneMap(queryNode, queryNode.Child(i), referenceMap, childMap);
-
-//      Log::Debug << "Recurse into query child " << i << ": " <<
-//          queryNode.Child(i).Point() << " scale " << queryNode.Child(i).Scale()
-//          << "; this parent is " << queryNode.Point() << " scale " <<
-//          queryNode.Scale() << std::endl;
-      Traverse(queryNode.Child(i), childMap);
+      // We need a copy of the map for this child.
+      std::map<int, std::vector<MapEntryType> > thisChildMap = childMap;
+      Traverse(queryNode.Child(i), thisChildMap);
     }
-
-    PruneMapForSelfChild(queryNode.Child(0), referenceMap);
-
-    // Now we can use the existing map (without a copy) for the self-child.
-//    Log::Warn << "Recurse into query self-child " << queryNode.Child(0).Point()
-//        << " scale " << queryNode.Child(0).Scale() << "; this parent is "
-//        << queryNode.Point() << " scale " << queryNode.Scale() << std::endl;
-    Traverse(queryNode.Child(0), referenceMap);
   }
 
   if (queryNode.Scale() != INT_MIN)
@@ -129,7 +106,6 @@
   Log::Assert((*referenceMap.begin()).first == INT_MIN);
   Log::Assert(queryNode.Scale() == INT_MIN);
   std::vector<MapEntryType>& pointVector = (*referenceMap.begin()).second;
-//  Log::Debug << "Onto base case evaluations\n";
 
   for (size_t i = 0; i < pointVector.size(); ++i)
   {
@@ -138,36 +114,26 @@
 
     CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
         frame.referenceNode;
-    const double oldScore = frame.score;
-    const size_t refIndex = frame.referenceIndex;
-    const size_t queryIndex = frame.queryIndex;
-//    Log::Debug << "Consider query " << queryNode.Point() << ", reference "
-//        << refNode->Point() << "\n";
-//    Log::Debug << "Old score " << oldScore << " with refParent " << refIndex
-//        << " and parent query node " << queryIndex << "\n";
 
-    // First, ensure that we have not already calculated the base case.
-    if ((refIndex == refNode->Point()) && (queryIndex == queryNode.Point()))
+    // If the point is the same as both parents, then we have already done this
+    // base case.
+    if ((refNode->Point() == refNode->Parent()->Point()) &&
+        (queryNode.Point() == queryNode.Parent()->Point()))
     {
-//      Log::Debug << "Pruned because we already did the base case and its value "
-//          << " was " << frame.baseCase << std::endl;
       ++numPrunes;
       continue;
     }
 
-    // Now, check if we can prune it.
-    const double rescore = rule.Rescore(queryNode, *refNode, oldScore);
+    // Score the node, to see if we can prune it.
+    double score = rule.Score(queryNode, *refNode);
 
-    if (rescore == DBL_MAX)
+    if (score == DBL_MAX)
     {
-//      Log::Debug << "Pruned after rescoring\n";
       ++numPrunes;
       continue;
     }
 
     // If not, compute the base case.
-//    Log::Debug << "Not pruned, performing base case " << queryNode.Point() <<
-//        " " << pointVector[i].referenceNode->Point() << "\n";
     rule.BaseCase(queryNode.Point(), pointVector[i].referenceNode->Point());
   }
 }
@@ -176,16 +142,12 @@
 template<typename RuleType>
 void CoverTree<MetricType, RootPointPolicy, StatisticType>::
 DualTreeTraverser<RuleType>::PruneMap(
-    CoverTree& /* queryNode */,
-    CoverTree& candidateQueryNode,
+    CoverTree& queryNode,
     std::map<int, std::vector<DualCoverTreeMapEntry<MetricType,
         RootPointPolicy, StatisticType> > >& referenceMap,
     std::map<int, std::vector<DualCoverTreeMapEntry<MetricType,
         RootPointPolicy, StatisticType> > >& childMap)
 {
-//  Log::Debug << "Prep for recurse into query child point " <<
-//      candidateQueryNode.Point() << " scale " <<
-//      candidateQueryNode.Scale() << std::endl;
   typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
       MapEntryType;
 
@@ -194,15 +156,17 @@
   typename std::map<int, std::vector<MapEntryType> >::reverse_iterator it =
       referenceMap.rbegin();
 
-  while ((it != referenceMap.rend()) && ((*it).first != INT_MIN))
+  while ((it != referenceMap.rend()))
   {
     // Get a reference to the vector representing the entries at this scale.
-    const std::vector<MapEntryType>& scaleVector = (*it).second;
+    std::vector<MapEntryType>& scaleVector = (*it).second;
+
+    // Before traversing all the points in this scale, sort by score.
+    std::sort(scaleVector.begin(), scaleVector.end());
 
     const int thisScale = (*it).first;
     childMap[thisScale].reserve(scaleVector.size());
     std::vector<MapEntryType>& newScaleVector = childMap[thisScale];
-//    newScaleVector.reserve(scaleVector.size()); // Maximum possible size.
 
     // Loop over each entry in the vector.
     for (size_t j = 0; j < scaleVector.size(); ++j)
@@ -212,27 +176,9 @@
       // First evaluate if we can prune without performing the base case.
       CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
           frame.referenceNode;
-      const double oldScore = frame.score;
-
-      // Try to prune based on shell().  This is hackish and will need to be
-      // refined or cleaned at some point.
-//      double score = rule.PrescoreQ(queryNode, candidateQueryNode, *refNode,
-//          frame.baseCase);
-
-//      if (score == DBL_MAX)
-//      {
-//        ++numPrunes;
-//        continue;
-//      }
-
-//      Log::Debug << "Recheck reference node " << refNode->Point() <<
-//          " scale " << refNode->Scale() << " which has old score " <<
-//          oldScore << " with old reference index " << frame.referenceIndex
-//          << " and old query index " << frame.queryIndex << std::endl;
-
-      double score = rule.Rescore(candidateQueryNode, *refNode, oldScore);
 
-//      Log::Debug << "Rescored as " << score << std::endl;
+      // Perform the actual scoring.
+      double score = rule.Score(queryNode, *refNode);
 
       if (score == DBL_MAX)
       {
@@ -241,29 +187,12 @@
         continue;
       }
 
-      // Evaluate base case.
-//      Log::Debug << "Must evaluate base case "
-//          << candidateQueryNode.Point() << " " << refNode->Point()
-//          << "\n";
-      double baseCase = rule.BaseCase(candidateQueryNode.Point(),
-          refNode->Point());
-//      Log::Debug << "Base case was " << baseCase << std::endl;
-
-      score = rule.Score(candidateQueryNode, *refNode, baseCase);
-
-      if (score == DBL_MAX)
-      {
-        // Pruned.  Move on.
-        ++numPrunes;
-        continue;
-      }
+      // If it isn't pruned, we must evaluate the base case.
+      rule.BaseCase(queryNode.Point(), refNode->Point());
 
       // Add to child map.
       newScaleVector.push_back(frame);
       newScaleVector.back().score = score;
-      newScaleVector.back().baseCase = baseCase;
-      newScaleVector.back().referenceIndex = refNode->Point();
-      newScaleVector.back().queryIndex = candidateQueryNode.Point();
     }
 
     // If we didn't add anything, then strike this vector from the map.
@@ -272,108 +201,6 @@
 
     ++it; // Advance to next scale.
   }
-
-  childMap[INT_MIN] = referenceMap[INT_MIN];
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-template<typename RuleType>
-void CoverTree<MetricType, RootPointPolicy, StatisticType>::
-DualTreeTraverser<RuleType>::PruneMapForSelfChild(
-    CoverTree& candidateQueryNode,
-    std::map<int, std::vector<DualCoverTreeMapEntry<MetricType, RootPointPolicy,
-        StatisticType> > >& referenceMap)
-{
-//  Log::Debug << "Prep for recurse into query self-child point " <<
-//      candidateQueryNode.Point() << " scale " <<
-//      candidateQueryNode.Scale() << std::endl;
-  typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
-      MapEntryType;
-
-  // Create the child reference map.  We will do this by recursing through
-  // every entry in the reference map and evaluating (or pruning) it.  But
-  // in this setting we do not recurse into any children of the reference
-  // entries.
-  if (referenceMap.empty())
-    return; // Nothing to do.
-  typename std::map<int, std::vector<MapEntryType> >::reverse_iterator it =
-      referenceMap.rbegin();
-
-  while (it != referenceMap.rend() && (*it).first != INT_MIN)
-  {
-    // Get a reference to the vector representing the entries at this scale.
-    std::vector<MapEntryType>& newScaleVector = (*it).second;
-    const std::vector<MapEntryType> scaleVector = newScaleVector;
-
-    newScaleVector.clear();
-    newScaleVector.reserve(scaleVector.size());
-
-    // Loop over each entry in the vector.
-    for (size_t j = 0; j < scaleVector.size(); ++j)
-    {
-      const MapEntryType& frame = scaleVector[j];
-
-      // First evaluate if we can prune without performing the base case.
-      CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
-        frame.referenceNode;
-      const double oldScore = frame.score;
-      double baseCase = frame.baseCase;
-      const size_t queryIndex = frame.queryIndex;
-      const size_t refIndex = frame.referenceIndex;
-
-//      Log::Debug << "Recheck reference node " << refNode->Point() << " scale "
-//          << refNode->Scale() << " which has old score " << oldScore
-//          << std::endl;
-
-      // Have we performed the base case yet?
-      double score;
-      if ((refIndex != refNode->Point()) ||
-          (queryIndex != candidateQueryNode.Point()))
-      {
-        // Attempt to rescore before performing the base case.
-        score = rule.Rescore(candidateQueryNode, *refNode, oldScore);
-
-        if (score == DBL_MAX)
-        {
-          ++numPrunes;
-          continue;
-        }
-
-        // If not pruned, we have to perform the base case.
-        baseCase = rule.BaseCase(candidateQueryNode.Point(), refNode->Point());
-      }
-
-      score = rule.Score(candidateQueryNode, *refNode, baseCase);
-
-//      Log::Debug << "Rescored as " << score << std::endl;
-
-      if (score == DBL_MAX)
-      {
-        // Pruned.  Move on.
-        ++numPrunes;
-        continue;
-      }
-
-//      Log::Debug << "Kept in map\n";
-
-      // Add to child map.
-      newScaleVector.push_back(frame);
-      newScaleVector.back().score = score;
-      newScaleVector.back().baseCase = baseCase;
-      newScaleVector.back().queryIndex = candidateQueryNode.Point();
-      newScaleVector.back().referenceIndex = refNode->Point();
-    }
-
-    // If we didn't add anything, then strike this vector from the map.
-    if (newScaleVector.size() == 0)
-    {
-      referenceMap.erase((*it).first);
-      if (referenceMap.empty())
-        break;
-    }
-
-    ++it; // Advance to next scale.
-  }
 }
 
 template<typename MetricType, typename RootPointPolicy, typename StatisticType>
@@ -411,115 +238,30 @@
 
       CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
           frame.referenceNode;
-      const double score = frame.score;
-      const size_t refIndex = frame.referenceIndex;
-      const size_t refPoint = refNode->Point();
-      const size_t queryIndex = frame.queryIndex;
-      const size_t queryPoint = queryNode.Point();
-      double baseCase = frame.baseCase;
-
-//      Log::Debug << "Currently inspecting reference node " << refNode->Point()
-//          << " scale " << refNode->Scale() << " earlier query index " <<
-//          queryIndex << std::endl;
-
-//      Log::Debug << "Old score " << score << " with refParent " << refIndex
-//          << " and queryIndex " << queryIndex << "\n";
-
-      // First we recalculate the score of this node to find if we can prune it.
-      if (rule.Rescore(queryNode, *refNode, score) == DBL_MAX)
-      {
-//        Log::Warn << "Pruned after rescore\n";
-        ++numPrunes;
-        continue;
-      }
-
-      // If this is a self-child, the base case has already been evaluated.
-      // We also must ensure that the base case was evaluated with this query
-      // point.
-      if ((refPoint != refIndex) || (queryPoint != queryIndex))
-      {
-//        Log::Warn << "Must evaluate base case " << queryNode.Point() << " "
-//            << refPoint << "\n";
-        baseCase = rule.BaseCase(queryPoint, refPoint);
-//        Log::Debug << "Base case " << baseCase << std::endl;
-      }
 
       // Create the score for the children.
-      double childScore = rule.Score(queryNode, *refNode, baseCase);
+      double score = rule.Score(queryNode, *refNode);
 
       // Now if this childScore is DBL_MAX we can prune all children.  In this
       // recursion setup pruning is all or nothing for children.
-      if (childScore == DBL_MAX)
+      if (score == DBL_MAX)
       {
-//        Log::Warn << "Pruned all children.\n";
-        numPrunes += refNode->NumChildren();
+        ++numPrunes;
         continue;
       }
 
-      // We must treat the self-leaf differently.  The base case has already
-      // been performed.
-      childScore = rule.Score(queryNode, refNode->Child(0), baseCase);
+      // If it is not pruned, we must evaluate the base case.
+      rule.BaseCase(queryNode.Point(), refNode->Point());
 
-      if (childScore != DBL_MAX)
+      // Add the children.
+      for (size_t j = 0; j < refNode->NumChildren(); ++j)
       {
-        MapEntryType newFrame;
-        newFrame.referenceNode = &refNode->Child(0);
-        newFrame.score = childScore;
-        newFrame.baseCase = baseCase;
-        newFrame.referenceIndex = refPoint;
-        newFrame.queryIndex = queryNode.Point();
-
-        referenceMap[newFrame.referenceNode->Scale()].push_back(newFrame);
-      }
-      else
-      {
-        ++numPrunes;
-      }
-
-      // Add the non-self-leaf children.
-      for (size_t j = 1; j < refNode->NumChildren(); ++j)
-      {
-        const size_t queryIndex = queryNode.Point();
-        const size_t refIndex = refNode->Child(j).Point();
-
-        // We need to incorporate shell() here to try and avoid base case
-        // computations.  TODO
-//        Log::Debug << "Prescore query " << queryNode.Point() << " scale "
-//            << queryNode.Scale() << ", reference " << refNode->Point() <<
-//            " scale " << refNode->Scale() << ", reference child " <<
-//            refNode->Child(j).Point() << " scale " << refNode->Child(j).Scale()
-//            << " with base case " << baseCase;
-//        childScore = rule.Prescore(queryNode, *refNode, refNode->Child(j),
-//            frame.baseCase);
-//        Log::Debug << " and result " << childScore << ".\n";
-
-//        if (childScore == DBL_MAX)
-//        {
-//          ++numPrunes;
-//          continue;
-//        }
-
-        // Calculate the base case of each child.
-        baseCase = rule.BaseCase(queryIndex, refIndex);
-
-        // See if we can prune it.
-        double childScore = rule.Score(queryNode, refNode->Child(j), baseCase);
-
-        if (childScore == DBL_MAX)
-        {
-          ++numPrunes;
-          continue;
-        }
+//        const size_t queryIndex = queryNode.Point();
+//        const size_t refIndex = refNode->Child(j).Point();
 
         MapEntryType newFrame;
         newFrame.referenceNode = &refNode->Child(j);
-        newFrame.score = childScore;
-        newFrame.baseCase = baseCase;
-        newFrame.referenceIndex = refIndex;
-        newFrame.queryIndex = queryIndex;
-
-//        Log::Debug << "Push onto map child " << refNode->Child(j).Point() <<
-//            " scale " << refNode->Child(j).Scale() << std::endl;
+        newFrame.score = score; // Use the score of the parent.
 
         referenceMap[newFrame.referenceNode->Scale()].push_back(newFrame);
       }
@@ -528,7 +270,6 @@
     // Now clear the memory for this scale; it isn't needed anymore.
     referenceMap.erase((*referenceMap.rbegin()).first);
   }
-
 }
 
 }; // namespace tree

Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp	Fri Sep 13 16:09:48 2013
@@ -64,12 +64,8 @@
   // largest scale.
   std::map<int, std::vector<MapEntryType> > mapQueue;
 
-  // Manually add the children of the first node.  These cannot be pruned
-  // anyway.
-  double rootBaseCase = rule.BaseCase(queryIndex, referenceNode.Point());
-
   // Create the score for the children.
-  double rootChildScore = rule.Score(queryIndex, referenceNode, rootBaseCase);
+  double rootChildScore = rule.Score(queryIndex, referenceNode);
 
   if (rootChildScore == DBL_MAX)
   {
@@ -77,6 +73,12 @@
   }
   else
   {
+    // Manually add the children of the first node.
+    // Often, a ruleset will return without doing any computation on cover trees
+    // using TreeTraits::FirstPointIsCentroid; this is an optimization that
+    // (theoretically) the compiler should get right.
+    double rootBaseCase = rule.BaseCase(queryIndex, referenceNode.Point());
+
     // Don't add the self-leaf.
     size_t i = 0;
     if (referenceNode.Child(0).NumChildren() == 0)
@@ -132,12 +134,8 @@
         continue;
       }
 
-      // If we are a self-child, the base case has already been evaluated.
-      if (point != parent)
-        baseCase = rule.BaseCase(queryIndex, point);
-
       // Create the score for the children.
-      const double childScore = rule.Score(queryIndex, *node, baseCase);
+      const double childScore = rule.Score(queryIndex, *node);
 
       // Now if this childScore is DBL_MAX we can prune all children.  In this
       // recursion setup pruning is all or nothing for children.
@@ -147,6 +145,13 @@
         continue;
       }
 
+      // If we are a self-child, the base case has already been evaluated.
+      // Often, a ruleset will return without doing any computation on cover
+      // trees using TreeTraits::FirstPointIsCentroid; this is an optimization
+      // that (theoretically) the compiler should get right.
+      if (point != parent)
+        baseCase = rule.BaseCase(queryIndex, point);
+
       // Don't add the self-leaf.
       size_t j = 0;
       if (node->Child(0).NumChildren() == 0)
@@ -181,16 +186,32 @@
     const size_t point = node->Point();
 
     // First, recalculate the score of this node to find if we can prune it.
-    double actualScore = rule.Rescore(queryIndex, *node, score);
+    double rescore = rule.Rescore(queryIndex, *node, score);
 
-    if (actualScore == DBL_MAX)
+    if (rescore == DBL_MAX)
     {
       ++numPrunes;
       continue;
     }
 
-    // There are no self-leaves; evaluate the base case.
-    rule.BaseCase(queryIndex, point);
+    // For this to be a valid dual-tree algorithm, we *must* evaluate the
+    // combination, even if pruning it will make no difference.  It's the
+    // definition.
+    const double actualScore = rule.Score(queryIndex, *node);
+
+    if (actualScore == DBL_MAX)
+    {
+      ++numPrunes;
+      continue;
+    }
+    else
+    {
+      // Evaluate the base case, since the combination was not pruned.
+      // Often, a ruleset will return without doing any computation on cover
+      // trees using TreeTraits::FirstPointIsCentroid; this is an optimization
+      // that (theoretically) the compiler should get right.
+      rule.BaseCase(queryIndex, point);
+    }
   }
 }
 



More information about the mlpack-svn mailing list