[mlpack-git] master, mlpack-1.0.x: rectangle tree traverser (e496a4c)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:50:24 EST 2015


Repository : https://github.com/mlpack/mlpack

On branches: master,mlpack-1.0.x
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit e496a4c1e473e1f3546134edd1107348cc9710fb
Author: andrewmw94 <andrewmw94 at gmail.com>
Date:   Wed Jul 2 17:28:23 2014 +0000

    rectangle tree traverser


>---------------------------------------------------------------

e496a4c1e473e1f3546134edd1107348cc9710fb
 src/mlpack/core/tree/CMakeLists.txt                |  6 +-
 src/mlpack/core/tree/rectangle_tree.hpp            |  3 +-
 .../dual_tree_traverser.hpp                        | 38 ++++++------
 .../rectangle_tree/dual_tree_traverser_impl.hpp    | 49 +++++++++++++++
 .../core/tree/rectangle_tree/rectangle_tree.hpp    |  9 ++-
 src/mlpack/methods/neighbor_search/allknn_main.cpp | 70 +++++++++++++++++++---
 .../neighbor_search/neighbor_search_impl.hpp       |  2 +-
 .../neighbor_search/neighbor_search_rules_impl.hpp |  2 +-
 8 files changed, 144 insertions(+), 35 deletions(-)

diff --git a/src/mlpack/core/tree/CMakeLists.txt b/src/mlpack/core/tree/CMakeLists.txt
index c526a70..7988285 100644
--- a/src/mlpack/core/tree/CMakeLists.txt
+++ b/src/mlpack/core/tree/CMakeLists.txt
@@ -32,8 +32,10 @@ set(SOURCES
   rectangle_tree.hpp
   rectangle_tree/rectangle_tree.hpp
   rectangle_tree/rectangle_tree_impl.hpp
-  rectangle_tree/rectangle_tree_traverser.hpp
-  rectangle_tree/rectangle_tree_traverser_impl.hpp
+  rectangle_tree/single_tree_traverser.hpp
+  rectangle_tree/single_tree_traverser_impl.hpp
+  rectangle_tree/dual_tree_traverser.hpp
+  rectangle_tree/dual_tree_traverser_impl.hpp
   rectangle_tree/r_tree_split.hpp
   rectangle_tree/r_tree_split_impl.hpp
   rectangle_tree/r_tree_descent_heuristic.hpp
diff --git a/src/mlpack/core/tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree.hpp
index 743f7f5..f0d156e 100644
--- a/src/mlpack/core/tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree.hpp
@@ -13,7 +13,8 @@
  */ 
 #include "bounds.hpp"
 #include "rectangle_tree/rectangle_tree.hpp"
-#include "rectangle_tree/rectangle_tree_traverser.hpp"
+#include "rectangle_tree/single_tree_traverser.hpp"
+#include "rectangle_tree/dual_tree_traverser.hpp"
 #include "rectangle_tree/r_tree_split.hpp"
 #include "rectangle_tree/r_tree_descent_heuristic.hpp"
 #include "rectangle_tree/traits.hpp"
diff --git a/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
similarity index 69%
copy from src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp
copy to src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
index 7cd1871..1091224 100644
--- a/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
@@ -1,28 +1,27 @@
 /**
- * @file dual_tree_traverser.hpp
- * @author Ryan Curtin
- *
- * Defines the DualTreeTraverser for the BinarySpaceTree tree type.  This is a
- * nested class of BinarySpaceTree which traverses two trees in a depth-first
- * manner with a given set of rules which indicate the branches which can be
- * pruned and the order in which to recurse.
- */
-#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_HPP
-#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_HPP
+  * @file dual_tree_traverser.hpp
+  * @author Andrew Wells
+  *
+  * A nested class of Rectangle Tree for traversing rectangle type trees
+  * with a given set of rules which indicate the branches to prune and the
+  * order in which to recurse.  Currently a placeholder so that the code compiles.
+  */
+#ifndef __MLPACK_CORE_TREE_RECTANGLE_TREE_DUAL_TREE_TRAVERSER_HPP
+#define __MLPACK_CORE_TREE_RECTANGLE_TREE_DUAL_TREE_TRAVERSER_HPP
 
 #include <mlpack/core.hpp>
 
-#include "binary_space_tree.hpp"
+#include "rectangle_tree.hpp"
 
 namespace mlpack {
 namespace tree {
 
-template<typename BoundType,
-         typename StatisticType,
-         typename MatType,
-         typename SplitType>
+template<typename SplitType,
+         typename DescentType,
+	 typename StatisticType,
+         typename MatType>
 template<typename RuleType>
-class BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+class RectangleTree<SplitType, DescentType, StatisticType, MatType>::
     DualTreeTraverser
 {
  public:
@@ -38,8 +37,8 @@ class BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
    * @param referenceNode The reference node to be traversed.
    * @param score The score of the current node combination.
    */
-  void Traverse(BinarySpaceTree& queryNode,
-                BinarySpaceTree& referenceNode);
+  void Traverse(RectangleTree<SplitType, DescentType, StatisticType, MatType>& queryNode,
+		RectangleTree<SplitType, DescentType, StatisticType, MatType>& referenceNode);
 
   //! Get the number of prunes.
   size_t NumPrunes() const { return numPrunes; }
@@ -88,5 +87,4 @@ class BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
 // Include implementation.
 #include "dual_tree_traverser_impl.hpp"
 
-#endif // __MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_HPP
-
+#endif
diff --git a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
new file mode 100644
index 0000000..8af988e
--- /dev/null
+++ b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
@@ -0,0 +1,49 @@
+/**
+  * @file dual_tree_traverser_impl.hpp
+  * @author Andrew Wells
+  *
+  * A class for traversing rectangle type trees with a given set of rules
+  * which indicate the branches to prune and the order in which to recurse.
+  * This is a depth-first traverser.
+  */
+#ifndef __MLPAC_CORE_TREE_RECTANGLE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
+#define __MLPAC_CORE_TREE_RECTANGLE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
+
+#include "dual_tree_traverser.hpp"
+
+#include <algorithm>
+#include <stack>
+
+namespace mlpack {
+namespace tree {
+
+template<typename SplitType,
+         typename DescentType,
+	 typename StatisticType,
+         typename MatType>
+template<typename RuleType>
+RectangleTree<SplitType, DescentType, StatisticType, MatType>::
+DualTreeTraverser<RuleType>::DualTreeTraverser(RuleType& rule) :
+    rule(rule),
+    numPrunes(0)
+{ /* Nothing to do */ }
+
+template<typename SplitType,
+         typename DescentType,
+	 typename StatisticType,
+         typename MatType>
+template<typename RuleType>
+void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
+DualTreeTraverser<RuleType>::Traverse(RectangleTree<SplitType, DescentType, StatisticType, MatType>& queryNode,
+		RectangleTree<SplitType, DescentType, StatisticType, MatType>& referenceNode)
+{
+  // No-op stub.  The check below only uses both arguments to prevent warnings.
+  if(queryNode.NumDescendants() > referenceNode.NumDescendants())
+    return;
+  return;
+}
+
+}; // namespace tree
+}; // namespace mlpack
+
+#endif
\ No newline at end of file
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
index e6833b9..e50d372 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
@@ -79,10 +79,13 @@ class RectangleTree
   //! So other classes can use TreeType::Mat.
   typedef MatType Mat;
 
-  //! A traverser for rectangle type trees.  See
-  //! rectangle_tree_traverser.hpp for implementation.
+  //! A single-tree traverser for rectangle type trees.  See
+  //! single_tree_traverser.hpp for implementation.
   template<typename RuleType>
-  class RectangleTreeTraverser;
+  class SingleTreeTraverser;
+  //! A dual tree traverser for rectangle type trees.
+  template<typename RuleType>
+  class DualTreeTraverser;
 
   /**
    * Construct this as the root node of a rectangle type tree using the given
diff --git a/src/mlpack/methods/neighbor_search/allknn_main.cpp b/src/mlpack/methods/neighbor_search/allknn_main.cpp
index 432c23e..b4dc0cf 100644
--- a/src/mlpack/methods/neighbor_search/allknn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/allknn_main.cpp
@@ -58,7 +58,7 @@ PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
 PARAM_FLAG("cover_tree", "If true, use cover trees to perform the search "
     "(experimental, may be slow).", "c");
 PARAM_FLAG("r_tree", "If true, use an R-Tree to perform the search "
-    "(experimental, may be slow.  Currently automatically sets single_mode.).", "R");
+    "(experimental, may be slow.  Currently automatically sets single_mode.).", "T");
 PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
     "random orthogonal basis.", "R");
 PARAM_INT("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0);
@@ -132,6 +132,7 @@ int main(int argc, char *argv[])
   } else if (!singleMode && CLI::HasParam("r_tree"))  // R_tree requires single mode.
   {
     Log::Warn << "--single_mode assumed because --r_tree is present." << endl;
+    singleMode = true;
   }
   
   if (naive)
@@ -269,16 +270,71 @@ int main(int argc, char *argv[])
       // Make sure to notify the user that they are using an r tree.
       Log::Info << "Using r tree for nearest-neighbor calculation." << endl;
       
-      // Build the reference tree.
+      // Because we may construct it differently, we need a pointer.
+      NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
+      RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+		    tree::RTreeDescentHeuristic,
+		    NeighborSearchStat<NearestNeighborSort>,
+		    arma::mat> >* allknn = NULL;
+
+      // Build trees by hand, so we can save memory: if we pass a tree to
+      // NeighborSearch, it does not copy the matrix.
       Log::Info << "Building reference tree..." << endl;
       Timer::Start("tree_building");
+
+      RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+		    tree::RTreeDescentHeuristic,
+		    NeighborSearchStat<NearestNeighborSort>,
+		    arma::mat>
+      refTree(referenceData, leafSize, leafSize/3, 5, 2, 0);
+
+      RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+		    tree::RTreeDescentHeuristic,
+		    NeighborSearchStat<NearestNeighborSort>,
+		    arma::mat>*
+      queryTree = NULL; // Empty for now.
+
+      Timer::Stop("tree_building");
+
+      if (CLI::GetParam<string>("query_file") != "")
+      {
+	Log::Info << "Loaded query data from '" << queryFile << "' ("
+	    << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+
+        allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
         RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
-                      tree::RTreeDescentHeuristic,
-                      NeighborSearchStat<NearestNeighborSort>,
-                      arma::mat>
-        refTree(referenceData, leafSize, leafSize/3, 5, 2, 0);
+	  	      tree::RTreeDescentHeuristic,
+  		      NeighborSearchStat<NearestNeighborSort>,
+  		      arma::mat> >(&refTree, queryTree,
+          referenceData, queryData, singleMode);
+      } else
+      {
+	      allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
+      RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+		    tree::RTreeDescentHeuristic,
+		    NeighborSearchStat<NearestNeighborSort>,
+		    arma::mat> >(&refTree,
+        referenceData, singleMode);
+      }
+      Log::Info << "Tree built." << endl;
+      
+      arma::mat distancesOut;
+      arma::Mat<size_t> neighborsOut;
+
+      Log::Info << "Computing " << k << " nearest neighbors..." << endl;
+      allknn->Search(k, neighborsOut, distancesOut);
+
+      Log::Info << "Neighbors computed." << endl;
+
+
+      delete allknn;
+            
+      
+      // Build the reference tree.
+      Log::Info << "Building reference tree..." << endl;
+      Timer::Start("tree_building");
+
       Timer::Stop("tree_building");
-      std::cout << "completed tree building " << refTree.NumDescendants() << std::endl;
     }
   }
   else // Cover trees.
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index dac46c0..bcada1f 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -60,7 +60,7 @@ NeighborSearch(const typename TreeType::Mat& referenceSetIn,
     metric(metric)
 {
   // C++11 will allow us to call out to other constructors so we can avoid this
-  // copypasta problem.
+  // copy/paste problem.
 
   // We'll time tree building, but only if we are building trees.
   Timer::Start("tree_building");
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
index c96b10a..5b4eca9 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
@@ -56,7 +56,7 @@ BaseCase(const size_t queryIndex, const size_t referenceIndex)
   ++baseCases;
 
   // If this distance is better than any of the current candidates, the
-  // SortDistance() function will give us the poto insert it into.
+  // SortDistance() function will give us the position to insert it into.
   arma::vec queryDist = distances.unsafe_col(queryIndex);
   arma::Col<size_t> queryIndices = neighbors.unsafe_col(queryIndex);
   const size_t insertPosition = SortPolicy::SortDistance(queryDist,



More information about the mlpack-git mailing list