[mlpack-git] master, mlpack-1.0.x: Modify CoverTree::DualTreeTraverser to properly handle TraversalInfo objects. (41a730e)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:42:31 EST 2015


Repository : https://github.com/mlpack/mlpack

On branches: master,mlpack-1.0.x
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit 41a730ee43311e15bacae3e35dac9328193d97d4
Author: Ryan Curtin <ryan at ratml.org>
Date:   Thu Feb 6 20:21:02 2014 +0000

    Modify CoverTree::DualTreeTraverser to properly handle TraversalInfo objects.


>---------------------------------------------------------------

41a730ee43311e15bacae3e35dac9328193d97d4
 .../core/tree/cover_tree/dual_tree_traverser.hpp   |  51 +++---
 .../tree/cover_tree/dual_tree_traverser_impl.hpp   | 180 +++++++++++++--------
 2 files changed, 147 insertions(+), 84 deletions(-)

diff --git a/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp b/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp
index 0e7142e..e09d2d0 100644
--- a/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp
+++ b/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp
@@ -13,10 +13,6 @@
 namespace mlpack {
 namespace tree {
 
-//! Forward declaration of struct to be used for traversal.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-struct DualCoverTreeMapEntry;
-
 template<typename MetricType, typename RootPointPolicy, typename StatisticType>
 template<typename RuleType>
 class CoverTree<MetricType, RootPointPolicy, StatisticType>::DualTreeTraverser
@@ -35,14 +31,6 @@ class CoverTree<MetricType, RootPointPolicy, StatisticType>::DualTreeTraverser
    */
   void Traverse(CoverTree& queryNode, CoverTree& referenceNode);
 
-  /**
-   * Helper function for traversal of the two trees.
-   */
-  void Traverse(CoverTree& queryNode,
-                std::map<int, std::vector<DualCoverTreeMapEntry<
-                    MetricType, RootPointPolicy, StatisticType> > >&
-                    referenceMap);
-
   //! Get the number of pruned nodes.
   size_t NumPrunes() const { return numPrunes; }
   //! Modify the number of pruned nodes.
@@ -61,17 +49,44 @@ class CoverTree<MetricType, RootPointPolicy, StatisticType>::DualTreeTraverser
   //! The number of pruned nodes.
   size_t numPrunes;
 
+  //! Struct used for traversal.
+  struct DualCoverTreeMapEntry
+  {
+    //! The node this entry refers to.
+    CoverTree<MetricType, RootPointPolicy, StatisticType>* referenceNode;
+    //! The score of the node.
+    double score;
+    //! The base case.
+    double baseCase;
+    //! The traversal info associated with the call to Score() for this entry.
+    typename RuleType::TraversalInfoType traversalInfo;
+
+    //! Comparison operator, for sorting within the map.
+    bool operator<(const DualCoverTreeMapEntry& other) const
+    {
+      if (score == other.score)
+        return (baseCase < other.baseCase);
+      else
+        return (score < other.score);
+    }
+  };
+
+  /**
+   * Helper function for traversal of the two trees.
+   */
+  void Traverse(CoverTree& queryNode,
+                std::map<int, std::vector<DualCoverTreeMapEntry> >&
+                    referenceMap);
+
   //! Prepare map for recursion.
   void PruneMap(CoverTree& queryNode,
-                std::map<int, std::vector<DualCoverTreeMapEntry<
-                    MetricType, RootPointPolicy, StatisticType> > >&
+                std::map<int, std::vector<DualCoverTreeMapEntry> >&
                     referenceMap,
-                std::map<int, std::vector<DualCoverTreeMapEntry<
-                    MetricType, RootPointPolicy, StatisticType> > >& childMap);
+                std::map<int, std::vector<DualCoverTreeMapEntry> >&
+                    childMap);
 
   void ReferenceRecursion(CoverTree& queryNode,
-                          std::map<int, std::vector<DualCoverTreeMapEntry<
-                              MetricType, RootPointPolicy, StatisticType> > >&
+                          std::map<int, std::vector<DualCoverTreeMapEntry> >&
                               referenceMap);
 };
 
diff --git a/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp
index 55d88a7..741e257 100644
--- a/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp
@@ -13,22 +13,6 @@
 namespace mlpack {
 namespace tree {
 
-//! The object placed in the map for tree traversal.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-struct DualCoverTreeMapEntry
-{
-  //! The node this entry refers to.
-  CoverTree<MetricType, RootPointPolicy, StatisticType>* referenceNode;
-  //! The score of the node.
-  double score;
-
-  //! Comparison operator, for sorting within the map.
-  bool operator<(const DualCoverTreeMapEntry& other) const
-  {
-    return (score < other.score);
-  }
-};
-
 template<typename MetricType, typename RootPointPolicy, typename StatisticType>
 template<typename RuleType>
 CoverTree<MetricType, RootPointPolicy, StatisticType>::
@@ -44,16 +28,15 @@ DualTreeTraverser<RuleType>::Traverse(
     CoverTree<MetricType, RootPointPolicy, StatisticType>& queryNode,
     CoverTree<MetricType, RootPointPolicy, StatisticType>& referenceNode)
 {
-  typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
-      MapEntryType;
-
   // Start by creating a map and adding the reference root node to it.
-  std::map<int, std::vector<MapEntryType> > refMap;
+  std::map<int, std::vector<DualCoverTreeMapEntry> > refMap;
 
-  MapEntryType rootRefEntry;
+  DualCoverTreeMapEntry rootRefEntry;
 
   rootRefEntry.referenceNode = &referenceNode;
   rootRefEntry.score = 0.0; // Must recurse into.
+  rootRefEntry.baseCase = 0.0;
+  rootRefEntry.traversalInfo = rule.TraversalInfo();
 
   refMap[referenceNode.Scale()].push_back(rootRefEntry);
 
@@ -65,13 +48,8 @@ template<typename RuleType>
 void CoverTree<MetricType, RootPointPolicy, StatisticType>::
 DualTreeTraverser<RuleType>::Traverse(
     CoverTree<MetricType, RootPointPolicy, StatisticType>& queryNode,
-    std::map<int, std::vector<DualCoverTreeMapEntry<MetricType, RootPointPolicy,
-        StatisticType> > >& referenceMap)
+    std::map<int, std::vector<DualCoverTreeMapEntry> >& referenceMap)
 {
-  // Convenience typedef.
-  typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
-      MapEntryType;
-
   if (referenceMap.size() == 0)
     return; // Nothing to do!
 
@@ -92,14 +70,16 @@ DualTreeTraverser<RuleType>::Traverse(
     // results are separate and independent.  I don't think this is true in
     // every case, and we may have to modify this section to consider scores in
     // the future.
-    std::map<int, std::vector<MapEntryType> > childMap;
-    PruneMap(queryNode, referenceMap, childMap);
-    for (size_t i = 0; i < queryNode.NumChildren(); ++i)
+    for (size_t i = 1; i < queryNode.NumChildren(); ++i)
     {
       // We need a copy of the map for this child.
-      std::map<int, std::vector<MapEntryType> > thisChildMap = childMap;
-      Traverse(queryNode.Child(i), thisChildMap);
+      std::map<int, std::vector<DualCoverTreeMapEntry> > childMap;
+      PruneMap(queryNode.Child(i), referenceMap, childMap);
+      Traverse(queryNode.Child(i), childMap);
     }
+    std::map<int, std::vector<DualCoverTreeMapEntry> > selfChildMap;
+    PruneMap(queryNode.Child(0), referenceMap, selfChildMap);
+    Traverse(queryNode.Child(0), selfChildMap);
   }
 
   if (queryNode.Scale() != INT_MIN)
@@ -109,12 +89,13 @@ DualTreeTraverser<RuleType>::Traverse(
   // evaluations to do.
   Log::Assert((*referenceMap.begin()).first == INT_MIN);
   Log::Assert(queryNode.Scale() == INT_MIN);
-  std::vector<MapEntryType>& pointVector = (*referenceMap.begin()).second;
+  std::vector<DualCoverTreeMapEntry>& pointVector =
+      (*referenceMap.begin()).second;
 
   for (size_t i = 0; i < pointVector.size(); ++i)
   {
     // Get a reference to the frame.
-    const MapEntryType& frame = pointVector[i];
+    const DualCoverTreeMapEntry& frame = pointVector[i];
 
     CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
         frame.referenceNode;
@@ -128,7 +109,9 @@ DualTreeTraverser<RuleType>::Traverse(
       continue;
     }
 
-    // Score the node, to see if we can prune it.
+    // Score the node, to see if we can prune it, after restoring the traversal
+    // info.
+    rule.TraversalInfo() = frame.traversalInfo;
     double score = rule.Score(queryNode, *refNode);
 
     if (score == DBL_MAX)
@@ -147,41 +130,91 @@ template<typename RuleType>
 void CoverTree<MetricType, RootPointPolicy, StatisticType>::
 DualTreeTraverser<RuleType>::PruneMap(
     CoverTree& queryNode,
-    std::map<int, std::vector<DualCoverTreeMapEntry<MetricType,
-        RootPointPolicy, StatisticType> > >& referenceMap,
-    std::map<int, std::vector<DualCoverTreeMapEntry<MetricType,
-        RootPointPolicy, StatisticType> > >& childMap)
+    std::map<int, std::vector<DualCoverTreeMapEntry> >& referenceMap,
+    std::map<int, std::vector<DualCoverTreeMapEntry> >& childMap)
 {
-  typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
-      MapEntryType;
-
   if (referenceMap.empty())
     return; // Nothing to do.
-  typename std::map<int, std::vector<MapEntryType> >::reverse_iterator it =
-      referenceMap.rbegin();
 
-  while ((it != referenceMap.rend()))
+  // Copy the zero set (the entries at scale INT_MIN) first, pruning by score.
+  if ((*referenceMap.begin()).first == INT_MIN)
   {
     // Get a reference to the vector representing the entries at this scale.
-    std::vector<MapEntryType>& scaleVector = (*it).second;
+    std::vector<DualCoverTreeMapEntry>& scaleVector =
+        (*referenceMap.begin()).second;
 
     // Before traversing all the points in this scale, sort by score.
     std::sort(scaleVector.begin(), scaleVector.end());
 
+    const int thisScale = (*referenceMap.begin()).first;
+    childMap[thisScale].reserve(scaleVector.size());
+    std::vector<DualCoverTreeMapEntry>& newScaleVector = childMap[thisScale];
+
+    // Loop over each entry in the vector.
+    for (size_t j = 0; j < scaleVector.size(); ++j)
+    {
+      const DualCoverTreeMapEntry& frame = scaleVector[j];
+
+      // First evaluate if we can prune without performing the base case.
+      CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
+          frame.referenceNode;
+
+      // Perform the actual scoring, after restoring the traversal info.
+      rule.TraversalInfo() = frame.traversalInfo;
+      double score = rule.Score(queryNode, *refNode);
+
+      if (score == DBL_MAX)
+      {
+        // Pruned.  Move on.
+        ++numPrunes;
+        continue;
+      }
+
+      // If it isn't pruned, we must evaluate the base case.
+      const double baseCase = rule.BaseCase(queryNode.Point(),
+          refNode->Point());
+
+      // Add to child map.
+      newScaleVector.push_back(frame);
+      newScaleVector.back().score = score;
+      newScaleVector.back().baseCase = baseCase;
+      newScaleVector.back().traversalInfo = rule.TraversalInfo();
+    }
+
+    // If we didn't add anything, then strike this vector from the map.
+    if (newScaleVector.size() == 0)
+      childMap.erase((*referenceMap.begin()).first);
+  }
+
+  typename std::map<int, std::vector<DualCoverTreeMapEntry> >::reverse_iterator
+      it = referenceMap.rbegin();
+
+  while ((it != referenceMap.rend()))
+  {
     const int thisScale = (*it).first;
+    if (thisScale == INT_MIN) // We already did it.
+      break;
+
+    // Get a reference to the vector representing the entries at this scale.
+    std::vector<DualCoverTreeMapEntry>& scaleVector = (*it).second;
+
+    // Before traversing all the points in this scale, sort by score.
+    std::sort(scaleVector.begin(), scaleVector.end());
+
     childMap[thisScale].reserve(scaleVector.size());
-    std::vector<MapEntryType>& newScaleVector = childMap[thisScale];
+    std::vector<DualCoverTreeMapEntry>& newScaleVector = childMap[thisScale];
 
     // Loop over each entry in the vector.
     for (size_t j = 0; j < scaleVector.size(); ++j)
     {
-      const MapEntryType& frame = scaleVector[j];
+      const DualCoverTreeMapEntry& frame = scaleVector[j];
 
       // First evaluate if we can prune without performing the base case.
       CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
           frame.referenceNode;
 
-      // Perform the actual scoring.
+      // Perform the actual scoring, after restoring the traversal info.
+      rule.TraversalInfo() = frame.traversalInfo;
       double score = rule.Score(queryNode, *refNode);
 
       if (score == DBL_MAX)
@@ -192,11 +225,14 @@ DualTreeTraverser<RuleType>::PruneMap(
       }
 
       // If it isn't pruned, we must evaluate the base case.
-      rule.BaseCase(queryNode.Point(), refNode->Point());
+      const double baseCase = rule.BaseCase(queryNode.Point(),
+          refNode->Point());
 
       // Add to child map.
       newScaleVector.push_back(frame);
       newScaleVector.back().score = score;
+      newScaleVector.back().baseCase = baseCase;
+      newScaleVector.back().traversalInfo = rule.TraversalInfo();
     }
 
     // If we didn't add anything, then strike this vector from the map.
@@ -212,17 +248,19 @@ template<typename RuleType>
 void CoverTree<MetricType, RootPointPolicy, StatisticType>::
 DualTreeTraverser<RuleType>::ReferenceRecursion(
     CoverTree& queryNode,
-    std::map<int, std::vector<DualCoverTreeMapEntry<MetricType, RootPointPolicy,
-        StatisticType> > >& referenceMap)
+    std::map<int, std::vector<DualCoverTreeMapEntry> >& referenceMap)
 {
-  typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
-      MapEntryType;
-
   // First, reduce the maximum scale in the reference map down to the scale of
   // the query node.
-  while (!referenceMap.empty() &&
-      ((*referenceMap.rbegin()).first > queryNode.Scale()))
+  while (!referenceMap.empty())
   {
+    // Hack to imitate the jl cover tree: the query root uses <, not <=.
+    if (queryNode.Parent() == NULL && (*referenceMap.rbegin()).first <
+        queryNode.Scale())
+      break;
+    if (queryNode.Parent() != NULL && (*referenceMap.rbegin()).first <=
+        queryNode.Scale())
+      break;
     // If the query node's scale is INT_MIN and the reference map's maximum
     // scale is INT_MIN, don't try to recurse...
     if ((queryNode.Scale() == INT_MIN) &&
@@ -230,7 +268,7 @@ DualTreeTraverser<RuleType>::ReferenceRecursion(
       break;
 
     // Get a reference to the current largest scale.
-    std::vector<MapEntryType>& scaleVector = (*referenceMap.rbegin()).second;
+    std::vector<DualCoverTreeMapEntry>& scaleVector = (*referenceMap.rbegin()).second;
 
     // Before traversing all the points in this scale, sort by score.
     std::sort(scaleVector.begin(), scaleVector.end());
@@ -239,13 +277,13 @@ DualTreeTraverser<RuleType>::ReferenceRecursion(
     for (size_t i = 0; i < scaleVector.size(); ++i)
     {
       // Get a reference to the current element.
-      const MapEntryType& frame = scaleVector.at(i);
+      const DualCoverTreeMapEntry& frame = scaleVector.at(i);
 
       CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
           frame.referenceNode;
 
       // Create the score for the children.
-      double score = rule.Score(queryNode, *refNode);
+      double score = rule.Rescore(queryNode, *refNode, frame.score);
 
       // Now if this childScore is DBL_MAX we can prune all children.  In this
       // recursion setup pruning is all or nothing for children.
@@ -256,17 +294,27 @@ DualTreeTraverser<RuleType>::ReferenceRecursion(
       }
 
       // If it is not pruned, we must evaluate the base case.
-      rule.BaseCase(queryNode.Point(), refNode->Point());
 
       // Add the children.
       for (size_t j = 0; j < refNode->NumChildren(); ++j)
       {
-//        const size_t queryIndex = queryNode.Point();
-//        const size_t refIndex = refNode->Child(j).Point();
-
-        MapEntryType newFrame;
+        rule.TraversalInfo() = frame.traversalInfo;
+        double childScore = rule.Score(queryNode, refNode->Child(j));
+        if (childScore == DBL_MAX)
+        {
+          ++numPrunes;
+          continue;
+        }
+
+        // It wasn't pruned; evaluate the base case.
+        const double baseCase = rule.BaseCase(queryNode.Point(),
+            refNode->Child(j).Point());
+
+        DualCoverTreeMapEntry newFrame;
         newFrame.referenceNode = &refNode->Child(j);
-        newFrame.score = score; // Use the score of the parent.
+        newFrame.score = childScore; // Use the child's own score.
+        newFrame.baseCase = baseCase;
+        newFrame.traversalInfo = rule.TraversalInfo();
 
         referenceMap[newFrame.referenceNode->Scale()].push_back(newFrame);
       }



More information about the mlpack-git mailing list