[mlpack-svn] r15802 - mlpack/trunk/src/mlpack/methods/range_search

From: fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Sep 17 22:05:42 EDT 2013


Author: rcurtin
Date: Tue Sep 17 22:05:42 2013
New Revision: 15802

Log:
Refactor RangeSearch to work properly for both cover trees and kd-trees with the
new cover tree traversal.


Modified:
   mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp
   mlpack/trunk/src/mlpack/methods/range_search/range_search_impl.hpp
   mlpack/trunk/src/mlpack/methods/range_search/range_search_rules.hpp
   mlpack/trunk/src/mlpack/methods/range_search/range_search_rules_impl.hpp

Modified: mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp	Tue Sep 17 22:05:42 2013
@@ -207,10 +207,11 @@
   //! Mappings to old query indices (used when this object builds trees).
   std::vector<size_t> oldFromNewQueries;
 
-  //! Indicates ownership of the reference tree (meaning we need to delete it).
-  bool ownReferenceTree;
-  //! Indicates ownership of the query tree (meaning we need to delete it).
-  bool ownQueryTree;
+  //! If true, this object is responsible for deleting the trees.
+  bool treeOwner;
+  //! If true, a query set was passed; if false, the query set is the reference
+  //! set.
+  bool hasQuerySet;
 
   //! If true, O(n^2) naive computation is used.
   bool naive;

Modified: mlpack/trunk/src/mlpack/methods/range_search/range_search_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/range_search_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/range_search/range_search_impl.hpp	Tue Sep 17 22:05:42 2013
@@ -28,8 +28,8 @@
     queryCopy(querySet),
     referenceSet(referenceCopy),
     querySet(queryCopy),
-    ownReferenceTree(true),
-    ownQueryTree(true),
+    treeOwner(true),
+    hasQuerySet(true),
     naive(naive),
     singleMode(!naive && singleMode), // Naive overrides single mode.
     metric(metric),
@@ -59,8 +59,8 @@
     referenceSet(referenceCopy),
     querySet(referenceCopy),
     queryTree(NULL),
-    ownReferenceTree(true),
-    ownQueryTree(false),
+    treeOwner(true),
+    hasQuerySet(false),
     naive(naive),
     singleMode(!naive && singleMode), // Naive overrides single mode.
     metric(metric),
@@ -73,6 +73,10 @@
   referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
       (naive ? referenceCopy.n_cols : leafSize));
 
+  // If using dual-tree mode, then we need a second tree.
+  if (!singleMode)
+    queryTree = new TreeType(*referenceTree);
+
   Timer::Stop("range_search/tree_building");
 }
 
@@ -88,8 +92,8 @@
     querySet(querySet),
     referenceTree(referenceTree),
     queryTree(queryTree),
-    ownReferenceTree(false),
-    ownQueryTree(false),
+    treeOwner(false),
+    hasQuerySet(true),
     naive(false),
     singleMode(singleMode),
     metric(metric),
@@ -108,22 +112,31 @@
     querySet(referenceSet),
     referenceTree(referenceTree),
     queryTree(NULL),
-    ownReferenceTree(false),
-    ownQueryTree(false),
+    treeOwner(false),
+    hasQuerySet(false),
     naive(false),
     singleMode(singleMode),
     metric(metric),
     numPrunes(0)
 {
-  // Nothing else to initialize.
+  // If doing dual-tree range search, we must clone the reference tree.
+  if (!singleMode)
+    queryTree = new TreeType(*referenceTree);
 }
 
 template<typename MetricType, typename TreeType>
 RangeSearch<MetricType, TreeType>::~RangeSearch()
 {
-  if (ownReferenceTree)
-    delete referenceTree;
-  if (ownQueryTree)
+  if (treeOwner)
+  {
+    if (referenceTree)
+      delete referenceTree;
+    if (queryTree)
+      delete queryTree;
+  }
+
+  // If doing dual-tree search with one dataset, we cloned the reference tree.
+  if (!treeOwner && !hasQuerySet && !singleMode)
     delete queryTree;
 }
 
@@ -145,9 +158,9 @@
   std::vector<std::vector<size_t> >* neighborPtr = &neighbors;
   std::vector<std::vector<double> >* distancePtr = &distances;
 
-  if (ownQueryTree || (ownReferenceTree && !queryTree))
+  if (treeOwner && !(singleMode && hasQuerySet))
     distancePtr = new std::vector<std::vector<double> >;
-  if (ownReferenceTree || ownQueryTree)
+  if (treeOwner)
     neighborPtr = new std::vector<std::vector<size_t> >;
 
   // Resize each vector.
@@ -177,10 +190,7 @@
     // Create the traverser.
     typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
 
-    if (queryTree)
-      traverser.Traverse(*queryTree, *referenceTree);
-    else
-      traverser.Traverse(*referenceTree, *referenceTree);
+    traverser.Traverse(*queryTree, *referenceTree);
 
     numPrunes = traverser.NumPrunes();
   }
@@ -192,12 +202,12 @@
       << "." << std::endl;
 
   // Map points back to original indices, if necessary.
-  if (!ownReferenceTree && !ownQueryTree)
+  if (!treeOwner)
   {
     // No mapping needed.  We are done.
     return;
   }
-  else if (ownReferenceTree && ownQueryTree) // Map references and queries.
+  else if (treeOwner && hasQuerySet && !singleMode) // Map both sets.
   {
     neighbors.clear();
     neighbors.resize(querySet.n_cols);
@@ -222,71 +232,48 @@
     delete neighborPtr;
     delete distancePtr;
   }
-  else if (ownReferenceTree)
+  else if (treeOwner && !hasQuerySet)
   {
-    if (!queryTree) // No query tree -- map both references and queries.
-    {
-      neighbors.clear();
-      neighbors.resize(querySet.n_cols);
-      distances.clear();
-      distances.resize(querySet.n_cols);
-
-      for (size_t i = 0; i < distances.size(); i++)
-      {
-        // Map distances (copy a column).
-        size_t refMapping = oldFromNewReferences[i];
-        distances[refMapping] = (*distancePtr)[i];
-
-        // Copy each neighbor individually, because we need to map it.
-        neighbors[refMapping].resize(distances[refMapping].size());
-        for (size_t j = 0; j < distances[refMapping].size(); j++)
-        {
-          neighbors[refMapping][j] = oldFromNewReferences[(*neighborPtr)[i][j]];
-        }
-      }
+    neighbors.clear();
+    neighbors.resize(querySet.n_cols);
+    distances.clear();
+    distances.resize(querySet.n_cols);
 
-      // Finished with temporary objects.
-      delete neighborPtr;
-      delete distancePtr;
-    }
-    else // Map only references.
+    for (size_t i = 0; i < distances.size(); i++)
     {
-      neighbors.clear();
-      neighbors.resize(querySet.n_cols);
+      // Map distances (copy a column).
+      size_t refMapping = oldFromNewReferences[i];
+      distances[refMapping] = (*distancePtr)[i];
 
-      // Map indices of neighbors.
-      for (size_t i = 0; i < neighbors.size(); i++)
+      // Copy each neighbor individually, because we need to map it.
+      neighbors[refMapping].resize(distances[refMapping].size());
+      for (size_t j = 0; j < distances[refMapping].size(); j++)
       {
-        neighbors[i].resize((*neighborPtr)[i].size());
-        for (size_t j = 0; j < neighbors[i].size(); j++)
-        {
-          neighbors[i][j] = oldFromNewReferences[(*neighborPtr)[i][j]];
-        }
+        neighbors[refMapping][j] = oldFromNewReferences[(*neighborPtr)[i][j]];
       }
-
-      // Finished with temporary object.
-      delete neighborPtr;
     }
+
+    // Finished with temporary objects.
+    delete neighborPtr;
+    delete distancePtr;
   }
-  else if (ownQueryTree)
+  else if (treeOwner && hasQuerySet && singleMode) // Map only references.
   {
     neighbors.clear();
     neighbors.resize(querySet.n_cols);
-    distances.clear();
-    distances.resize(querySet.n_cols);
 
-    for (size_t i = 0; i < distances.size(); i++)
+    // Map indices of neighbors.
+    for (size_t i = 0; i < neighbors.size(); i++)
     {
-      // Map distances (copy a column).
-      distances[oldFromNewQueries[i]] = (*distancePtr)[i];
-
-      // Map neighbors.
-      neighbors[oldFromNewQueries[i]] = (*neighborPtr)[i];
+      neighbors[i].resize((*neighborPtr)[i].size());
+      for (size_t j = 0; j < neighbors[i].size(); j++)
+      {
+        neighbors[i][j] = oldFromNewReferences[(*neighborPtr)[i][j]];
+      }
     }
 
-    // Finished with temporary objects.
+    // Finished with temporary object.
     delete neighborPtr;
-    delete distancePtr;
   }
 }
 

Modified: mlpack/trunk/src/mlpack/methods/range_search/range_search_rules.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/range_search_rules.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/range_search/range_search_rules.hpp	Tue Sep 17 22:05:42 2013
@@ -52,20 +52,6 @@
   double Score(const size_t queryIndex, TreeType& referenceNode);
 
   /**
-   * Get the score for recursion order, passing the base case result (in the
-   * situation where it may be needed to calculate the recursion order).  A low
-   * score indicates priority for recursion, while DBL_MAX indicates that the
-   * node should not be recursed into at all (it should be pruned).
-   *
-   * @param queryIndex Index of query point.
-   * @param referenceNode Candidate node to be recursed into.
-   * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode).
-   */
-  double Score(const size_t queryIndex,
-               TreeType& referenceNode,
-               const double baseCaseResult);
-
-  /**
    * Re-evaluate the score for recursion order.  A low score indicates priority
    * for recursion, while DBL_MAX indicates that the node should not be recursed
    * into at all (it should be pruned).  This is used when the score has already
@@ -91,20 +77,6 @@
   double Score(TreeType& queryNode, TreeType& referenceNode);
 
   /**
-   * Get the score for recursion order, passing the base case result (in the
-   * situation where it may be needed to calculate the recursion order).  A low
-   * score indicates priority for recursion, while DBL_MAX indicates that the
-   * node should not be recursed into at all (it should be pruned).
-   *
-   * @param queryNode Candidate query node to recurse into.
-   * @param referenceNode Candidate reference node to recurse into.
-   * @param baseCaseResult Result of BaseCase(queryNode, referenceNode).
-   */
-  double Score(TreeType& queryNode,
-               TreeType& referenceNode,
-               const double baseCaseResult);
-
-  /**
    * Re-evaluate the score for recursion order.  A low score indicates priority
    * for recursion, while DBL_MAX indicates that the node should not be recursed
    * into at all (it should be pruned).  This is used when the score has already
@@ -138,12 +110,16 @@
   //! The instantiated metric.
   MetricType& metric;
 
+  //! The last query index.
+  size_t lastQueryIndex;
+  //! The last reference index.
+  size_t lastReferenceIndex;
+
   //! Add all the points in the given node to the results for the given query
   //! point.  If the base case has already been calculated, we make sure to not
   //! add that to the results twice.
   void AddResult(const size_t queryIndex,
-                 TreeType& referenceNode,
-                 const bool hasBaseCase);
+                 TreeType& referenceNode);
 };
 
 }; // namespace range

Modified: mlpack/trunk/src/mlpack/methods/range_search/range_search_rules_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/range_search_rules_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/range_search/range_search_rules_impl.hpp	Tue Sep 17 22:05:42 2013
@@ -26,7 +26,9 @@
     range(range),
     neighbors(neighbors),
     distances(distances),
-    metric(metric)
+    metric(metric),
+    lastQueryIndex(querySet.n_cols),
+    lastReferenceIndex(referenceSet.n_cols)
 {
   // Nothing to do.
 }
@@ -42,9 +44,20 @@
   if ((&referenceSet == &querySet) && (queryIndex == referenceIndex))
     return 0.0;
 
+  // If we have just performed this base case, don't do it again.
+  if ((lastQueryIndex == queryIndex) && (lastReferenceIndex == referenceIndex))
+    return 0.0; // No value to return... this shouldn't do anything bad.
+
+//  if (queryIndex == 0 && referenceIndex == 0)
+//    Log::Warn << "base case 0 0 called!\n";
+
   const double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
       referenceSet.unsafe_col(referenceIndex));
 
+  // Update last indices, so we don't accidentally perform a base case twice.
+  lastQueryIndex = queryIndex;
+  lastReferenceIndex = referenceIndex;
+
   if (range.Contains(distance))
   {
     neighbors[queryIndex].push_back(referenceIndex);
@@ -59,36 +72,43 @@
 double RangeSearchRules<MetricType, TreeType>::Score(const size_t queryIndex,
                                                      TreeType& referenceNode)
 {
-  const math::Range distances =
-      referenceNode.RangeDistance(querySet.unsafe_col(queryIndex));
+  // We must get the minimum and maximum distances and store them in this
+  // object.
+  math::Range distances;
 
-  // If the ranges do not overlap, prune this node.
-  if (!distances.Contains(range))
-    return DBL_MAX;
+  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
+  {
+    // In this situation, we calculate the base case.  So we should check to be
+    // sure we haven't already done that.
+    double baseCase;
+    if (tree::TreeTraits<TreeType>::HasSelfChildren &&
+        (referenceNode.Parent() != NULL) &&
+        (referenceNode.Point(0) == referenceNode.Parent()->Point(0)))
+    {
+      // If the tree has self-children and this is a self-child, the base case
+      // was already calculated.
+      baseCase = referenceNode.Parent()->Stat().LastDistance();
+      lastQueryIndex = queryIndex;
+      lastReferenceIndex = referenceNode.Point(0);
+    }
+    else
+    {
+      // We must calculate the base case by hand.
+      baseCase = BaseCase(queryIndex, referenceNode.Point(0));
+    }
+
+    // This may be possibly loose for non-ball bound trees.
+    distances.Lo() = baseCase - referenceNode.FurthestDescendantDistance();
+    distances.Hi() = baseCase + referenceNode.FurthestDescendantDistance();
 
-  // In this case, all of the points in the reference node will be part of the
-  // results.
-  if ((distances.Lo() >= range.Lo()) && (distances.Hi() <= range.Hi()))
+    // Update last distance calculation.
+    referenceNode.Stat().LastDistance() = baseCase;
+  }
+  else
   {
-    AddResult(queryIndex, referenceNode, false);
-    return DBL_MAX; // We don't need to go any deeper.
+    distances = referenceNode.RangeDistance(querySet.unsafe_col(queryIndex));
   }
 
-  // Otherwise the score doesn't matter.  Recursion order is irrelevant in range
-  // search.
-  return 0.0;
-}
-
-//! Single-tree scoring function.
-template<typename MetricType, typename TreeType>
-double RangeSearchRules<MetricType, TreeType>::Score(
-    const size_t queryIndex,
-    TreeType& referenceNode,
-    const double baseCaseResult)
-{
-  const math::Range distances = referenceNode.RangeDistance(
-      querySet.unsafe_col(queryIndex), baseCaseResult);
-
   // If the ranges do not overlap, prune this node.
   if (!distances.Contains(range))
     return DBL_MAX;
@@ -97,12 +117,12 @@
   // results.
   if ((distances.Lo() >= range.Lo()) && (distances.Hi() <= range.Hi()))
   {
-    AddResult(queryIndex, referenceNode, true);
+    AddResult(queryIndex, referenceNode);
     return DBL_MAX; // We don't need to go any deeper.
   }
 
-  // Otherwise the score doesn't matter.  Recursion order is irrelevant in range
-  // search.
+  // Otherwise the score doesn't matter.  Recursion order is irrelevant in
+  // range search.
   return 0.0;
 }
 
@@ -122,35 +142,90 @@
 double RangeSearchRules<MetricType, TreeType>::Score(TreeType& queryNode,
                                                      TreeType& referenceNode)
 {
-  const math::Range distances = referenceNode.RangeDistance(&queryNode);
-
-  // If the ranges do not overlap, prune this node.
-  if (!distances.Contains(range))
-    return DBL_MAX;
-
-  // In this case, all of the points in the reference node will be part of all
-  // the results for each point in the query node.
-  if ((distances.Lo() >= range.Lo()) && (distances.Hi() <= range.Hi()))
+  math::Range distances;
+  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
   {
-    for (size_t i = 0; i < queryNode.NumDescendants(); ++i)
-      AddResult(queryNode.Descendant(i), referenceNode, false);
-    return DBL_MAX; // We don't need to go any deeper.
+    // It is possible that the base case has already been calculated.
+    double baseCase = 0.0;
+    bool alreadyDone = false;
+    if (tree::TreeTraits<TreeType>::HasSelfChildren)
+    {
+      TreeType* lastQuery = (TreeType*) referenceNode.Stat().LastDistanceNode();
+      TreeType* lastRef = (TreeType*) queryNode.Stat().LastDistanceNode();
+
+      // Did the query node's last combination do the base case?
+      if ((lastRef != NULL) && (referenceNode.Point(0) == lastRef->Point(0)))
+      {
+        baseCase = queryNode.Stat().LastDistance();
+        alreadyDone = true;
+      }
+
+      // Did the reference node's last combination do the base case?
+      if ((lastQuery != NULL) && (queryNode.Point(0) == lastQuery->Point(0)))
+      {
+        baseCase = referenceNode.Stat().LastDistance();
+        alreadyDone = true;
+      }
+
+      // If the query node is a self-child, did the query parent's last
+      // combination do the base case?
+      if ((queryNode.Parent() != NULL) &&
+          (queryNode.Point(0) == queryNode.Parent()->Point(0)))
+      {
+        TreeType* lastParentRef = (TreeType*)
+            queryNode.Parent()->Stat().LastDistanceNode();
+        if ((lastParentRef != NULL) &&
+            (referenceNode.Point(0) == lastParentRef->Point(0)))
+        {
+          baseCase = queryNode.Parent()->Stat().LastDistance();
+          alreadyDone = true;
+        }
+      }
+
+      // If the reference node is a self-child, did the reference parent's last
+      // combination do the base case?
+      if ((referenceNode.Parent() != NULL) &&
+          (referenceNode.Point(0) == referenceNode.Parent()->Point(0)))
+      {
+        TreeType* lastQueryRef = (TreeType*)
+            referenceNode.Parent()->Stat().LastDistanceNode();
+        if ((lastQueryRef != NULL) &&
+            (queryNode.Point(0) == lastQueryRef->Point(0)))
+        {
+          baseCase = referenceNode.Parent()->Stat().LastDistance();
+          alreadyDone = true;
+        }
+      }
+    }
+
+    if (!alreadyDone)
+    {
+      // We must calculate the base case.
+      baseCase = BaseCase(queryNode.Point(0), referenceNode.Point(0));
+    }
+    else
+    {
+      // Make sure that if BaseCase() is called, we don't duplicate results.
+      lastQueryIndex = queryNode.Point(0);
+      lastReferenceIndex = referenceNode.Point(0);
+    }
+
+    distances.Lo() = baseCase - queryNode.FurthestDescendantDistance()
+        - referenceNode.FurthestDescendantDistance();
+    distances.Hi() = baseCase + queryNode.FurthestDescendantDistance()
+        + referenceNode.FurthestDescendantDistance();
+
+    // Update the last distances performed for the query and reference node.
+    queryNode.Stat().LastDistanceNode() = (void*) &referenceNode;
+    queryNode.Stat().LastDistance() = baseCase;
+    referenceNode.Stat().LastDistanceNode() = (void*) &queryNode;
+    referenceNode.Stat().LastDistance() = baseCase;
+  }
+  else
+  {
+    // Just perform the calculation.
+    distances = referenceNode.RangeDistance(&queryNode);
   }
-
-  // Otherwise the score doesn't matter.  Recursion order is irrelevant in range
-  // search.
-  return 0.0;
-}
-
-//! Dual-tree scoring function.
-template<typename MetricType, typename TreeType>
-double RangeSearchRules<MetricType, TreeType>::Score(
-    TreeType& queryNode,
-    TreeType& referenceNode,
-    const double baseCaseResult)
-{
-  const math::Range distances = referenceNode.RangeDistance(&queryNode,
-      baseCaseResult);
 
   // If the ranges do not overlap, prune this node.
   if (!distances.Contains(range))
@@ -160,11 +235,8 @@
   // the results for each point in the query node.
   if ((distances.Lo() >= range.Lo()) && (distances.Hi() <= range.Hi()))
   {
-    AddResult(queryNode.Descendant(0), referenceNode, true);
-    // We have not calculated the base case for any descendants other than the
-    // first point.
-    for (size_t i = 1; i < queryNode.NumDescendants(); ++i)
-      AddResult(queryNode.Descendant(i), referenceNode, false);
+    for (size_t i = 0; i < queryNode.NumDescendants(); ++i)
+      AddResult(queryNode.Descendant(i), referenceNode);
     return DBL_MAX; // We don't need to go any deeper.
   }
 
@@ -188,14 +260,15 @@
 //! point.
 template<typename MetricType, typename TreeType>
 void RangeSearchRules<MetricType, TreeType>::AddResult(const size_t queryIndex,
-                                                       TreeType& referenceNode,
-                                                       const bool hasBaseCase)
+                                                       TreeType& referenceNode)
 {
   // Some types of trees calculate the base case evaluation before Score() is
   // called, so if the base case has already been calculated, then we must avoid
   // adding that point to the results again.
   size_t baseCaseMod = 0;
-  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid && hasBaseCase)
+  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid &&
+      (queryIndex == lastQueryIndex) &&
+      (referenceNode.Point(0) == lastReferenceIndex))
   {
     baseCaseMod = 1;
   }



More information about the mlpack-svn mailing list