[mlpack-svn] r13335 - in mlpack/trunk/src/mlpack/methods: emst kmeans maxip neighbor_search range_search

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Sat Aug 4 23:01:42 EDT 2012


Author: rcurtin
Date: 2012-08-04 23:01:42 -0400 (Sat, 04 Aug 2012)
New Revision: 13335

Modified:
   mlpack/trunk/src/mlpack/methods/emst/dtb.hpp
   mlpack/trunk/src/mlpack/methods/kmeans/kmeans.hpp
   mlpack/trunk/src/mlpack/methods/kmeans/kmeans_impl.hpp
   mlpack/trunk/src/mlpack/methods/maxip/max_ip.hpp
   mlpack/trunk/src/mlpack/methods/maxip/max_ip_impl.hpp
   mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules.hpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
   mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp
Log:
Change include files and APIs to use the new tree traverser setup.


Modified: mlpack/trunk/src/mlpack/methods/emst/dtb.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb.hpp	2012-08-05 03:00:55 UTC (rev 13334)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb.hpp	2012-08-05 03:01:42 UTC (rev 13335)
@@ -23,10 +23,10 @@
 #include "edge_pair.hpp"
 
 #include <mlpack/core.hpp>
-#include <mlpack/core/tree/bounds.hpp>
-#include <mlpack/core/tree/binary_space_tree.hpp>
 #include <mlpack/core/metrics/lmetric.hpp>
 
+#include <mlpack/core/tree/binary_space_tree.hpp>
+
 namespace mlpack {
 namespace emst /** Euclidean Minimum Spanning Trees. */ {
 

Modified: mlpack/trunk/src/mlpack/methods/kmeans/kmeans.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/kmeans/kmeans.hpp	2012-08-05 03:00:55 UTC (rev 13334)
+++ mlpack/trunk/src/mlpack/methods/kmeans/kmeans.hpp	2012-08-05 03:01:42 UTC (rev 13335)
@@ -13,6 +13,8 @@
 #include "random_partition.hpp"
 #include "max_variance_new_cluster.hpp"
 
+#include <mlpack/core/tree/binary_space_tree.hpp>
+
 namespace mlpack {
 namespace kmeans /** K-Means clustering. */ {
 

Modified: mlpack/trunk/src/mlpack/methods/kmeans/kmeans_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/kmeans/kmeans_impl.hpp	2012-08-05 03:00:55 UTC (rev 13334)
+++ mlpack/trunk/src/mlpack/methods/kmeans/kmeans_impl.hpp	2012-08-05 03:01:42 UTC (rev 13335)
@@ -7,8 +7,6 @@
  */
 #include "kmeans.hpp"
 
-#include <mlpack/core/tree/binary_space_tree.hpp>
-#include <mlpack/core/tree/hrectbound.hpp>
 #include <mlpack/core/tree/mrkd_statistic.hpp>
 #include <mlpack/core/metrics/lmetric.hpp>
 

Modified: mlpack/trunk/src/mlpack/methods/maxip/max_ip.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/maxip/max_ip.hpp	2012-08-05 03:00:55 UTC (rev 13334)
+++ mlpack/trunk/src/mlpack/methods/maxip/max_ip.hpp	2012-08-05 03:01:42 UTC (rev 13335)
@@ -9,7 +9,7 @@
 
 #include <mlpack/core.hpp>
 #include "ip_metric.hpp"
-#include <mlpack/core/tree/cover_tree.hpp>
+#include <mlpack/core/tree/cover_tree/cover_tree.hpp>
 
 namespace mlpack {
 namespace maxip {

Modified: mlpack/trunk/src/mlpack/methods/maxip/max_ip_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/maxip/max_ip_impl.hpp	2012-08-05 03:00:55 UTC (rev 13334)
+++ mlpack/trunk/src/mlpack/methods/maxip/max_ip_impl.hpp	2012-08-05 03:01:42 UTC (rev 13335)
@@ -13,6 +13,7 @@
 #include "max_ip_rules.hpp"
 
 #include <mlpack/core/kernels/gaussian_kernel.hpp>
+#include <queue>
 
 namespace mlpack {
 namespace maxip {

Modified: mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules.hpp	2012-08-05 03:00:55 UTC (rev 13334)
+++ mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules.hpp	2012-08-05 03:01:42 UTC (rev 13335)
@@ -9,7 +9,7 @@
 #define __MLPACK_METHODS_MAXIP_MAX_IP_RULES_HPP
 
 #include <mlpack/core.hpp>
-#include <mlpack/core/tree/cover_tree.hpp>
+#include <mlpack/core/tree/cover_tree/cover_tree.hpp>
 
 namespace mlpack {
 namespace maxip {

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp	2012-08-05 03:00:55 UTC (rev 13334)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp	2012-08-05 03:01:42 UTC (rev 13335)
@@ -9,11 +9,11 @@
 #define __MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_HPP
 
 #include <mlpack/core.hpp>
-#include <mlpack/core/tree/bounds.hpp>
-#include <mlpack/core/tree/binary_space_tree.hpp>
 #include <vector>
 #include <string>
 
+#include <mlpack/core/tree/binary_space_tree.hpp>
+
 #include <mlpack/core/metrics/lmetric.hpp>
 #include "sort_policies/nearest_neighbor_sort.hpp"
 

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp	2012-08-05 03:00:55 UTC (rev 13334)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp	2012-08-05 03:01:42 UTC (rev 13335)
@@ -10,9 +10,6 @@
 
 #include <mlpack/core.hpp>
 
-#include <mlpack/core/tree/traversers/single_tree_depth_first_traverser.hpp>
-#include <mlpack/core/tree/traversers/single_tree_breadth_first_traverser.hpp>
-#include <mlpack/core/tree/traversers/dual_tree_depth_first_traverser.hpp>
 #include "neighbor_search_rules.hpp"
 
 using namespace mlpack::neighbor;
@@ -180,13 +177,11 @@
   if (singleMode)
   {
     // Create the helper object for the tree traversal.
-    NeighborSearchRules<SortPolicy, MetricType, TreeType> rules(referenceSet,
-        querySet, *neighborPtr, *distancePtr, metric);
+    typedef NeighborSearchRules<SortPolicy, MetricType, TreeType> RuleType;
+    RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr, metric);
 
     // Create the traverser.
-    typename TreeType::template PreferredTraverser<
-      NeighborSearchRules<SortPolicy, MetricType, TreeType> >::Type
-      traverser(rules);
+    typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
 
     // Now have it traverse for each point.
     for (size_t i = 0; i < querySet.n_cols; ++i)
@@ -196,17 +191,12 @@
   }
   else // Dual-tree recursion.
   {
-    // Breaking a lot of design rules here...
-    typedef typename TreeType::template PreferredRules<SortPolicy, MetricType,
-        TreeType>::Type RuleType;
-
+    // Create the helper object for the tree traversal.
+    typedef NeighborSearchRules<SortPolicy, MetricType, TreeType> RuleType;
     RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr, metric);
 
-    typedef typename TreeType::template PreferredDualTraverser<RuleType>::Type
-        TraverserType;
+    typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
 
-    TraverserType traverser(rules);
-
     if (queryTree)
       traverser.Traverse(*queryTree, *referenceTree);
     else

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp	2012-08-05 03:00:55 UTC (rev 13334)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp	2012-08-05 03:01:42 UTC (rev 13335)
@@ -30,11 +30,8 @@
   bool CanPrune(TreeType& queryNode, TreeType& referenceNode);
 
   // Get the order of points to recurse to.
-  void RecursionOrder(TreeType& queryNode,
-                      TreeType& referenceNode,
-                      arma::Mat<size_t>& recursionOrder,
-                      bool& queryRecurse,
-                      bool& referenceRecurse);
+  bool LeftFirst(const size_t queryIndex, TreeType& referenceNode);
+  bool LeftFirst(TreeType& staticNode, TreeType& recurseNode);
 
   // Update bounds.  Needs a better name.
   void UpdateAfterRecursion(TreeType& queryNode, TreeType& referenceNode);

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp	2012-08-05 03:00:55 UTC (rev 13334)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp	2012-08-05 03:01:42 UTC (rev 13335)
@@ -84,154 +84,34 @@
     return true;
 }
 
-// Return the order in which we should recurse.
 template<typename SortPolicy, typename MetricType, typename TreeType>
-inline void NeighborSearchRules<
-    SortPolicy,
-    MetricType,
-    TreeType>::
-RecursionOrder(TreeType& queryNode,
-               TreeType& referenceNode,
-               arma::Mat<size_t>& recursionOrder,
-               bool& queryRecurse,
-               bool& referenceRecurse)
+inline bool NeighborSearchRules<SortPolicy, MetricType, TreeType>::LeftFirst(
+    const size_t queryIndex,
+    TreeType& referenceNode)
 {
-  queryRecurse = !(queryNode.IsLeaf());
-  referenceRecurse = !(referenceNode.IsLeaf());
+  // This ends up with us calculating this distance twice (it will be done again
+  // in CanPrune()), but because single-tree recursion is not the most
+  // important case in this method, we can let it slide.
+  const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
+  const double leftDistance = SortPolicy::BestPointToNodeDistance(queryPoint,
+      referenceNode.Left());
+  const double rightDistance = SortPolicy::BestPointToNodeDistance(queryPoint,
+      referenceNode.Right());
 
-  if (queryRecurse && !referenceRecurse)
-  {
-    // We only need to recurse into the query children.  Therefore, the elements
-    // in row 1 can be ignored.
-    recursionOrder.set_size(2, queryNode.NumChildren());
-    arma::vec recursionDistances(queryNode.NumChildren());
-    recursionDistances.fill(SortPolicy::WorstDistance());
-    size_t children = 0; // Number of children to recurse to.
+  return SortPolicy::IsBetter(leftDistance, rightDistance);
+}
 
-    for (size_t i = 0; i < queryNode.NumChildren(); ++i)
-    {
-      double distance = SortPolicy::BestNodeToNodeDistance(&queryNode.Child(i),
-          &referenceNode);
+template<typename SortPolicy, typename MetricType, typename TreeType>
+inline bool NeighborSearchRules<SortPolicy, MetricType, TreeType>::LeftFirst(
+    TreeType& staticNode,
+    TreeType& recurseNode)
+{
+  const double leftDistance = SortPolicy::BestNodeToNodeDistance(&staticNode,
+      recurseNode.Left());
+  const double rightDistance = SortPolicy::BestNodeToNodeDistance(&staticNode,
+      recurseNode.Right());
 
-      // Find where to insert.
-      size_t insertPosition;
-      for (insertPosition = 0; insertPosition < children; ++insertPosition)
-        if (SortPolicy::IsBetter(distance, recursionDistances[insertPosition]))
-          break;
-
-      // Now perform the actual insertion.
-      if ((children - insertPosition) > 0)
-      {
-        memmove(recursionDistances.memptr() + insertPosition + 1,
-                recursionDistances.memptr() + insertPosition,
-                sizeof(double) * (children - insertPosition));
-        memmove(recursionOrder.memptr() + (insertPosition + 1) * 2,
-                recursionOrder.memptr() + (insertPosition * 2),
-                sizeof(size_t) * (children - insertPosition) * 2);
-      }
-
-      // Insert.
-      recursionDistances[insertPosition] = distance;
-      recursionOrder(0, insertPosition) = i;
-      ++children;
-    }
-
-    // Strip extra columns.
-    if (children < queryNode.NumChildren())
-      recursionOrder.shed_cols(children, queryNode.NumChildren() - 1);
-  }
-  else if (!queryRecurse && referenceRecurse)
-  {
-    // We only need to recurse into the reference children.  Therefore, the
-    // elements in row 0 can be ignored.
-    recursionOrder.set_size(2, referenceNode.NumChildren());
-    arma::vec recursionDistances(referenceNode.NumChildren());
-    recursionDistances.fill(SortPolicy::WorstDistance());
-    size_t children = 0; // Number of children to recurse into.
-
-    for (size_t i = 0; i < referenceNode.NumChildren(); ++i)
-    {
-      double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
-          &referenceNode.Child(i));
-
-      // Find where to insert.
-      size_t insertPosition;
-      for (insertPosition = 0; insertPosition < children; ++insertPosition)
-        if (SortPolicy::IsBetter(distance, recursionDistances[insertPosition]))
-          break;
-
-      // Now perform the actual insertion.
-      if ((children - insertPosition) > 0)
-      {
-        memmove(recursionDistances.memptr() + insertPosition + 1,
-                recursionDistances.memptr() + insertPosition,
-                sizeof(double) * (children - insertPosition));
-        memmove(recursionOrder.memptr() + (insertPosition + 1) * 2,
-                recursionOrder.memptr() + (insertPosition * 2),
-                sizeof(size_t) * (children - insertPosition) * 2);
-      }
-
-      // Insert.
-      recursionDistances[insertPosition] = distance;
-      recursionOrder(1, insertPosition) = i;
-      ++children;
-    }
-
-    // Strip extra columns.
-    if (children < referenceNode.NumChildren())
-      recursionOrder.shed_cols(children, referenceNode.NumChildren() - 1);
-  }
-  else if (queryRecurse && referenceRecurse)
-  {
-    // We need to recurse into both children.
-    const size_t maxChildren = referenceNode.NumChildren() *
-        queryNode.NumChildren();
-    recursionOrder.set_size(2, maxChildren);
-    arma::vec recursionDistances(maxChildren);
-    recursionDistances.fill(SortPolicy::WorstDistance());
-    size_t children = 0; // Number of children to recurse into.
-
-    for (size_t i = 0; i < queryNode.NumChildren(); ++i)
-    {
-      // Check if we should even continue this direction.
-      if (CanPrune(queryNode.Child(i), referenceNode))
-        continue; // Don't go this way.
-
-      for (size_t j = 0; j < referenceNode.NumChildren(); ++j)
-      {
-        double distance = SortPolicy::BestNodeToNodeDistance(
-            &queryNode.Child(i), &referenceNode.Child(j));
-
-        // Find where to insert.
-        size_t insertPosition;
-        for (insertPosition = 0; insertPosition < children; ++insertPosition)
-          if (SortPolicy::IsBetter(distance,
-              recursionDistances[insertPosition]))
-            break;
-
-        // Move things to prepare for insertion.
-        if ((children - insertPosition) > 0)
-        {
-          memmove(recursionDistances.memptr() + insertPosition + 1,
-                  recursionDistances.memptr() + insertPosition,
-                  sizeof(double) * (children - insertPosition));
-          memmove(recursionOrder.memptr() + (insertPosition + 1) * 2,
-                  recursionOrder.memptr() + (insertPosition * 2),
-                  sizeof(size_t) * (children - insertPosition) * 2);
-        }
-
-        // Insert.
-        recursionDistances[insertPosition] = distance;
-        recursionOrder(0, insertPosition) = i;
-        recursionOrder(1, insertPosition) = j;
-        ++children;
-      }
-    }
-
-    // Strip extra columns.
-    if (children < maxChildren)
-      recursionOrder.shed_cols(children, maxChildren - 1);
-  }
+  return SortPolicy::IsBetter(leftDistance, rightDistance);
 }
 
 template<typename SortPolicy, typename MetricType, typename TreeType>

Modified: mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp	2012-08-05 03:00:55 UTC (rev 13334)
+++ mlpack/trunk/src/mlpack/methods/range_search/range_search.hpp	2012-08-05 03:01:42 UTC (rev 13335)
@@ -9,11 +9,11 @@
 #define __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_HPP
 
 #include <mlpack/core.hpp>
-#include <mlpack/core/tree/bounds.hpp>
-#include <mlpack/core/tree/binary_space_tree.hpp>
 
 #include <mlpack/core/metrics/lmetric.hpp>
 
+#include <mlpack/core/tree/binary_space_tree.hpp>
+
 namespace mlpack {
 namespace range /** Range-search routines. */ {
 




More information about the mlpack-svn mailing list