[mlpack-svn] r12664 - mlpack/trunk/src/mlpack/methods/neighbor_search
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed May 9 16:26:02 EDT 2012
Author: rcurtin
Date: 2012-05-09 16:26:02 -0400 (Wed, 09 May 2012)
New Revision: 12664
Modified:
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
Log:
Add implementation of rules for dual-tree search.
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp 2012-05-09 20:25:47 UTC (rev 12663)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp 2012-05-09 20:26:02 UTC (rev 12664)
@@ -26,6 +26,19 @@
// For single-tree traversal.
bool CanPrune(const size_t queryIndex, TreeType& referenceNode);
+ // For dual-tree traversal: returns true (prune) when the best-case
+ // node-to-node distance between queryNode and referenceNode, as given by
+ // SortPolicy::BestNodeToNodeDistance(), is not better than queryNode's
+ // current bound (queryNode.Stat().Bound()).
+ bool CanPrune(TreeType& queryNode, TreeType& referenceNode);
+
+ // Get the order in which to recurse into child nodes.  queryRecurse and
+ // referenceRecurse are set to whether queryNode / referenceNode are non-leaf
+ // nodes.  Each column of recursionOrder holds a child index pair: row 0 is a
+ // child of queryNode and row 1 is a child of referenceNode.  Columns are
+ // sorted best-first by SortPolicy::BestNodeToNodeDistance().  When only one
+ // node recurses, only that node's row is meaningful.  When both recurse,
+ // query children that CanPrune() rejects against referenceNode are omitted.
+ void RecursionOrder(TreeType& queryNode,
+ TreeType& referenceNode,
+ arma::Mat<size_t>& recursionOrder,
+ bool& queryRecurse,
+ bool& referenceRecurse);
+
+ // Recompute queryNode's bound (queryNode.Stat().Bound()) after recursion.
+ // The new bound is the worst, per SortPolicy, of its children's bounds and
+ // the last-row entries of the distances matrix for the points it holds.
+ // referenceNode is currently unused.  TODO: this needs a better name.
+ void UpdateAfterRecursion(TreeType& queryNode, TreeType& referenceNode);
+
private:
//! The reference set.
const arma::mat& referenceSet;
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp 2012-05-09 20:25:47 UTC (rev 12663)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp 2012-05-09 20:26:02 UTC (rev 12664)
@@ -69,6 +69,202 @@
return true; // There cannot be anything better in this node. So prune it.
}
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline bool NeighborSearchRules<SortPolicy, MetricType, TreeType>::CanPrune(
+    TreeType& queryNode,
+    TreeType& referenceNode)
+{
+  // Best-case distance (according to SortPolicy) between any point of
+  // queryNode and any point of referenceNode.
+  const double nodeDistance = SortPolicy::BestNodeToNodeDistance(&queryNode,
+      &referenceNode);
+
+  // The pair is only worth visiting if that best case beats the bound the
+  // query node currently holds; otherwise it can be pruned.
+  return !SortPolicy::IsBetter(nodeDistance, queryNode.Stat().Bound());
+}
+
+// Helper for RecursionOrder(): insert the candidate pair (queryChild,
+// referenceChild) with the given distance into the first `count` columns of
+// `order` (and the first `count` entries of `distances`), which are kept
+// sorted best-first according to SortPolicy::IsBetter().  The new candidate is
+// placed after any existing entries with an equal distance, so ties keep their
+// insertion order.  Both `order` and `distances` must have room for at least
+// count + 1 entries.
+template<typename SortPolicy>
+inline void InsertRecursionCandidate(arma::Mat<size_t>& order,
+                                     arma::vec& distances,
+                                     const size_t count,
+                                     const size_t queryChild,
+                                     const size_t referenceChild,
+                                     const double distance)
+{
+  // Find the first existing entry that the new distance is strictly better
+  // than.
+  size_t position = 0;
+  while ((position < count) &&
+         !SortPolicy::IsBetter(distance, distances[position]))
+    ++position;
+
+  // Shift the entries in [position, count) one slot to the right.
+  for (size_t k = count; k > position; --k)
+  {
+    distances[k] = distances[k - 1];
+    order(0, k) = order(0, k - 1);
+    order(1, k) = order(1, k - 1);
+  }
+
+  distances[position] = distance;
+  order(0, position) = queryChild;
+  order(1, position) = referenceChild;
+}
+
+// Return the order in which we should recurse.
+//
+// queryRecurse / referenceRecurse are set to whether the respective node is a
+// non-leaf.  Each column of recursionOrder is a (query child, reference child)
+// index pair (rows 0 and 1), sorted best-first by
+// SortPolicy::BestNodeToNodeDistance().  When only one node recurses, the other
+// node's row is set to 0 and should be ignored.  When both recurse, query
+// children that CanPrune() rejects against referenceNode are skipped.  When
+// neither recurses, recursionOrder is left with no columns.
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline void NeighborSearchRules<SortPolicy, MetricType, TreeType>::
+RecursionOrder(TreeType& queryNode,
+               TreeType& referenceNode,
+               arma::Mat<size_t>& recursionOrder,
+               bool& queryRecurse,
+               bool& referenceRecurse)
+{
+  queryRecurse = !(queryNode.IsLeaf());
+  referenceRecurse = !(referenceNode.IsLeaf());
+
+  if (queryRecurse && !referenceRecurse)
+  {
+    // Only the query children are visited (against referenceNode itself);
+    // row 1 is filled with 0 and should be ignored.
+    const size_t numChildren = queryNode.NumChildren();
+    recursionOrder.set_size(2, numChildren);
+    arma::vec recursionDistances(numChildren);
+
+    // Every child is kept, so the number inserted so far is simply i.
+    for (size_t i = 0; i < numChildren; ++i)
+    {
+      const double distance = SortPolicy::BestNodeToNodeDistance(
+          &queryNode.Child(i), &referenceNode);
+      InsertRecursionCandidate<SortPolicy>(recursionOrder, recursionDistances,
+          i, i, 0, distance);
+    }
+  }
+  else if (!queryRecurse && referenceRecurse)
+  {
+    // Only the reference children are visited (against queryNode itself);
+    // row 0 is filled with 0 and should be ignored.
+    const size_t numChildren = referenceNode.NumChildren();
+    recursionOrder.set_size(2, numChildren);
+    arma::vec recursionDistances(numChildren);
+
+    for (size_t i = 0; i < numChildren; ++i)
+    {
+      const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
+          &referenceNode.Child(i));
+      InsertRecursionCandidate<SortPolicy>(recursionOrder, recursionDistances,
+          i, 0, i, distance);
+    }
+  }
+  else if (queryRecurse && referenceRecurse)
+  {
+    // Consider every (query child, reference child) combination.
+    const size_t maxChildren = queryNode.NumChildren() *
+        referenceNode.NumChildren();
+    recursionOrder.set_size(2, maxChildren);
+    arma::vec recursionDistances(maxChildren);
+    size_t children = 0; // Number of pairs kept so far.
+
+    for (size_t i = 0; i < queryNode.NumChildren(); ++i)
+    {
+      // Skip query children whose bound cannot be improved by referenceNode.
+      if (CanPrune(queryNode.Child(i), referenceNode))
+        continue;
+
+      for (size_t j = 0; j < referenceNode.NumChildren(); ++j)
+      {
+        const double distance = SortPolicy::BestNodeToNodeDistance(
+            &queryNode.Child(i), &referenceNode.Child(j));
+        InsertRecursionCandidate<SortPolicy>(recursionOrder,
+            recursionDistances, children, i, j, distance);
+        ++children;
+      }
+    }
+
+    // Drop the columns left unused by pruned query children.
+    if (children < maxChildren)
+      recursionOrder.shed_cols(children, maxChildren - 1);
+  }
+  else
+  {
+    // Both nodes are leaves, so there is nothing to recurse into.  Return an
+    // empty order instead of leaving the caller's previous contents in place.
+    recursionOrder.set_size(2, 0);
+  }
+}
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
+void NeighborSearchRules<SortPolicy, MetricType, TreeType>::
+UpdateAfterRecursion(TreeType& queryNode, TreeType& /* referenceNode */)
+{
+  // The new bound of queryNode is the worst value (according to SortPolicy)
+  // among its children's bounds and the last-row distances of its points.
+  // Start from the best possible distance so that any candidate replaces it.
+  double bound = SortPolicy::BestDistance();
+
+  // Children's bounds.
+  const size_t numChildren = queryNode.NumChildren();
+  for (size_t c = 0; c < numChildren; ++c)
+  {
+    const double childBound = queryNode.Child(c).Stat().Bound();
+    if (SortPolicy::IsBetter(bound, childBound))
+      bound = childBound;
+  }
+
+  // Points held by the node.  NOTE(review): the last row of `distances` is
+  // presumably each query point's current worst kept candidate (k-th
+  // neighbor); confirm against the insertion helper.
+  const size_t lastRow = distances.n_rows - 1;
+  const size_t numPoints = queryNode.NumPoints();
+  for (size_t p = 0; p < numPoints; ++p)
+  {
+    const double pointBound = distances(lastRow, queryNode.Point(p));
+    if (SortPolicy::IsBetter(bound, pointBound))
+      bound = pointBound;
+  }
+
+  queryNode.Stat().Bound() = bound;
+}
+
/**
* Helper function to insert a point into the neighbors and distances matrices.
*
More information about the mlpack-svn
mailing list