[mlpack-svn] r12664 - mlpack/trunk/src/mlpack/methods/neighbor_search

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed May 9 16:26:02 EDT 2012


Author: rcurtin
Date: 2012-05-09 16:26:02 -0400 (Wed, 09 May 2012)
New Revision: 12664

Modified:
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
Log:
Add implementation of rules for dual-tree search.


Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp	2012-05-09 20:25:47 UTC (rev 12663)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp	2012-05-09 20:26:02 UTC (rev 12664)
@@ -26,6 +26,19 @@
   // For single-tree traversal.
   bool CanPrune(const size_t queryIndex, TreeType& referenceNode);
 
+  // For dual-tree traversal: can this query/reference node pair be pruned?
+  bool CanPrune(TreeType& queryNode, TreeType& referenceNode);
+
+  // Get the order of child node combinations to recurse into.
+  void RecursionOrder(TreeType& queryNode,
+                      TreeType& referenceNode,
+                      arma::Mat<size_t>& recursionOrder,
+                      bool& queryRecurse,
+                      bool& referenceRecurse);
+
+  // Update the query node's bound after recursion.  Needs a better name.
+  void UpdateAfterRecursion(TreeType& queryNode, TreeType& referenceNode);
+
  private:
   //! The reference set.
   const arma::mat& referenceSet;

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp	2012-05-09 20:25:47 UTC (rev 12663)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp	2012-05-09 20:26:02 UTC (rev 12664)
@@ -69,6 +69,202 @@
     return true; // There cannot be anything better in this node.  So prune it.
 }
 
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline bool NeighborSearchRules<SortPolicy, MetricType, TreeType>::CanPrune(
+    TreeType& queryNode,
+    TreeType& referenceNode)
+{
+  const double nodeDistance = SortPolicy::BestNodeToNodeDistance(
+      &queryNode, &referenceNode);
+  const double currentBound = queryNode.Stat().Bound();
+
+  // The pair may be pruned exactly when the best node-to-node distance is not
+  // better (per SortPolicy) than the bound stored on the query node.
+  const bool worthVisiting = SortPolicy::IsBetter(nodeDistance, currentBound);
+  return !worthVisiting;
+}
+
+// Decide which side(s) to descend and fill recursionOrder best-first.
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline void NeighborSearchRules<
+    SortPolicy,
+    MetricType,
+    TreeType>::
+RecursionOrder(TreeType& queryNode,
+               TreeType& referenceNode,
+               arma::Mat<size_t>& recursionOrder,
+               bool& queryRecurse,
+               bool& referenceRecurse)
+{
+  queryRecurse = !(queryNode.IsLeaf());
+  referenceRecurse = !(referenceNode.IsLeaf());
+
+  if (queryRecurse && !referenceRecurse)
+  {
+    // Only the query node has children.  Row 0 of recursionOrder holds the
+    // query child index; row 1 is never written in this branch.
+    recursionOrder.set_size(2, queryNode.NumChildren());
+    arma::vec recursionDistances(queryNode.NumChildren());
+    recursionDistances.fill(SortPolicy::WorstDistance());
+    size_t children = 0; // Number of columns filled so far.
+
+    for (size_t i = 0; i < queryNode.NumChildren(); ++i)
+    {
+      double distance = SortPolicy::BestNodeToNodeDistance(&queryNode.Child(i),
+          &referenceNode);
+
+      // Linear scan for the first stored distance this one is better than.
+      size_t insertPosition;
+      for (insertPosition = 0; insertPosition < children; ++insertPosition)
+        if (SortPolicy::IsBetter(distance, recursionDistances[insertPosition]))
+          break;
+
+      // Shift later entries right by one slot (one column = 2 elements).
+      if ((children - insertPosition) > 0)
+      {
+        memmove(recursionDistances.memptr() + insertPosition + 1,
+                recursionDistances.memptr() + insertPosition,
+                sizeof(double) * (children - insertPosition));
+        memmove(recursionOrder.memptr() + (insertPosition + 1) * 2,
+                recursionOrder.memptr() + (insertPosition * 2),
+                sizeof(size_t) * (children - insertPosition) * 2);
+      }
+
+      // Write the new entry into the gap.
+      recursionDistances[insertPosition] = distance;
+      recursionOrder(0, insertPosition) = i;
+      ++children;
+    }
+
+    // Strip unused columns (none here: no child is skipped above).
+    if (children < queryNode.NumChildren())
+      recursionOrder.shed_cols(children, queryNode.NumChildren() - 1);
+  }
+  else if (!queryRecurse && referenceRecurse)
+  {
+    // Only the reference node has children.  Row 1 of recursionOrder holds
+    // the reference child index; row 0 is never written in this branch.
+    recursionOrder.set_size(2, referenceNode.NumChildren());
+    arma::vec recursionDistances(referenceNode.NumChildren());
+    recursionDistances.fill(SortPolicy::WorstDistance());
+    size_t children = 0; // Number of columns filled so far.
+
+    for (size_t i = 0; i < referenceNode.NumChildren(); ++i)
+    {
+      double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
+          &referenceNode.Child(i));
+
+      // Linear scan for the first stored distance this one is better than.
+      size_t insertPosition;
+      for (insertPosition = 0; insertPosition < children; ++insertPosition)
+        if (SortPolicy::IsBetter(distance, recursionDistances[insertPosition]))
+          break;
+
+      // Shift later entries right by one slot (one column = 2 elements).
+      if ((children - insertPosition) > 0)
+      {
+        memmove(recursionDistances.memptr() + insertPosition + 1,
+                recursionDistances.memptr() + insertPosition,
+                sizeof(double) * (children - insertPosition));
+        memmove(recursionOrder.memptr() + (insertPosition + 1) * 2,
+                recursionOrder.memptr() + (insertPosition * 2),
+                sizeof(size_t) * (children - insertPosition) * 2);
+      }
+
+      // Write the new entry into the gap.
+      recursionDistances[insertPosition] = distance;
+      recursionOrder(1, insertPosition) = i;
+      ++children;
+    }
+
+    // Strip unused columns (none here: no child is skipped above).
+    if (children < referenceNode.NumChildren())
+      recursionOrder.shed_cols(children, referenceNode.NumChildren() - 1);
+  }
+  else if (queryRecurse && referenceRecurse)
+  {
+    // Both nodes have children: order all (query, reference) child pairs.
+    const size_t maxChildren = referenceNode.NumChildren() *
+        queryNode.NumChildren();
+    recursionOrder.set_size(2, maxChildren);
+    arma::vec recursionDistances(maxChildren);
+    recursionDistances.fill(SortPolicy::WorstDistance());
+    size_t children = 0; // Number of pairs kept so far.
+
+    for (size_t i = 0; i < queryNode.NumChildren(); ++i)
+    {
+      // Skip this query child if it can be pruned against the reference node.
+      if (CanPrune(queryNode.Child(i), referenceNode))
+        continue; // Pruned: add none of its pairs.
+
+      for (size_t j = 0; j < referenceNode.NumChildren(); ++j)
+      {
+        double distance = SortPolicy::BestNodeToNodeDistance(
+            &queryNode.Child(i), &referenceNode.Child(j));
+
+        // Linear scan for the first stored distance this one is better than.
+        size_t insertPosition;
+        for (insertPosition = 0; insertPosition < children; ++insertPosition)
+          if (SortPolicy::IsBetter(distance,
+              recursionDistances[insertPosition]))
+            break;
+
+        // Shift later entries right by one slot (one column = 2 elements).
+        if ((children - insertPosition) > 0)
+        {
+          memmove(recursionDistances.memptr() + insertPosition + 1,
+                  recursionDistances.memptr() + insertPosition,
+                  sizeof(double) * (children - insertPosition));
+          memmove(recursionOrder.memptr() + (insertPosition + 1) * 2,
+                  recursionOrder.memptr() + (insertPosition * 2),
+                  sizeof(size_t) * (children - insertPosition) * 2);
+        }
+
+        // Write the (query child, reference child) pair into the gap.
+        recursionDistances[insertPosition] = distance;
+        recursionOrder(0, insertPosition) = i;
+        recursionOrder(1, insertPosition) = j;
+        ++children;
+      }
+    }
+
+    // Drop the columns left unused by pruned query children.
+    if (children < maxChildren)
+      recursionOrder.shed_cols(children, maxChildren - 1);
+  }
+} // NOTE: when both nodes are leaves, recursionOrder is not modified.
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+void NeighborSearchRules<
+    SortPolicy,
+    MetricType,
+    TreeType>::
+UpdateAfterRecursion(TreeType& queryNode, TreeType& /* referenceNode */)
+{
+  // Seed with the best value so any worse candidate below replaces it.
+  double loosestBound = SortPolicy::BestDistance();
+
+  // Pass 1: bounds stored on the query node's children.
+  for (size_t c = 0; c < queryNode.NumChildren(); ++c)
+  {
+    const double childBound = queryNode.Child(c).Stat().Bound();
+    if (SortPolicy::IsBetter(loosestBound, childBound))
+      loosestBound = childBound;
+  }
+
+  // Pass 2: last-row entry of distances for each point held by the node.
+  for (size_t p = 0; p < queryNode.NumPoints(); ++p)
+  {
+    const double pointBound = distances(distances.n_rows - 1,
+        queryNode.Point(p));
+    if (SortPolicy::IsBetter(loosestBound, pointBound))
+      loosestBound = pointBound;
+  }
+
+  // The worst value seen across children and points becomes the new bound.
+  queryNode.Stat().Bound() = loosestBound;
+}
+
 /**
  * Helper function to insert a point into the neighbors and distances matrices.
  *




More information about the mlpack-svn mailing list