[mlpack-svn] r16737 - in mlpack/trunk/src/mlpack: core/tree core/tree/rectangle_tree methods/neighbor_search

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Jul 2 13:28:24 EDT 2014


Author: andrewmw94
Date: Wed Jul  2 13:28:23 2014
New Revision: 16737

Log:
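Rectangle tree traversers: rename RectangleTreeTraverser to SingleTreeTraverser, add a stub DualTreeTraverser, and run R-tree searches through NeighborSearch in allknn.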

Added:
   mlpack/trunk/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
   mlpack/trunk/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
Modified:
   mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt
   mlpack/trunk/src/mlpack/core/tree/rectangle_tree.hpp
   mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
   mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp

Modified: mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt	(original)
+++ mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt	Wed Jul  2 13:28:23 2014
@@ -32,8 +32,10 @@
   rectangle_tree.hpp
   rectangle_tree/rectangle_tree.hpp
   rectangle_tree/rectangle_tree_impl.hpp
-  rectangle_tree/rectangle_tree_traverser.hpp
-  rectangle_tree/rectangle_tree_traverser_impl.hpp
+  rectangle_tree/single_tree_traverser.hpp
+  rectangle_tree/single_tree_traverser_impl.hpp
+  rectangle_tree/dual_tree_traverser.hpp
+  rectangle_tree/dual_tree_traverser_impl.hpp
   rectangle_tree/r_tree_split.hpp
   rectangle_tree/r_tree_split_impl.hpp
   rectangle_tree/r_tree_descent_heuristic.hpp

Modified: mlpack/trunk/src/mlpack/core/tree/rectangle_tree.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/rectangle_tree.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/tree/rectangle_tree.hpp	Wed Jul  2 13:28:23 2014
@@ -13,7 +13,8 @@
  */ 
 #include "bounds.hpp"
 #include "rectangle_tree/rectangle_tree.hpp"
-#include "rectangle_tree/rectangle_tree_traverser.hpp"
+#include "rectangle_tree/single_tree_traverser.hpp"
+#include "rectangle_tree/dual_tree_traverser.hpp"
 #include "rectangle_tree/r_tree_split.hpp"
 #include "rectangle_tree/r_tree_descent_heuristic.hpp"
 #include "rectangle_tree/traits.hpp"

Added: mlpack/trunk/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp	Wed Jul  2 13:28:23 2014
@@ -0,0 +1,110 @@
+/**
+  * @file dual_tree_traverser.hpp
+  * @author Andrew Wells
+  *
+  * A nested class of Rectangle Tree for traversing rectangle type trees
+  * with a given set of rules which indicate the branches to prune and the
+  * order in which to recurse.
+  *
+  * NOTE: dual-tree traversal of rectangle trees is not implemented yet.
+  * Traverse() is currently a stub that exists only so that code
+  * instantiating the dual-tree traverser compiles; use single-tree mode
+  * with rectangle trees for now.
+  */
+#ifndef __MLPACK_CORE_TREE_RECTANGLE_TREE_DUAL_TREE_TRAVERSER_HPP
+#define __MLPACK_CORE_TREE_RECTANGLE_TREE_DUAL_TREE_TRAVERSER_HPP
+
+#include <mlpack/core.hpp>
+
+#include "rectangle_tree.hpp"
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * Dual-tree traverser for RectangleTree.  It keeps a reference to the rule
+ * set and exposes counters for prunes, visited node combinations, scores,
+ * and base cases.
+ *
+ * @tparam RuleType Rule set used during traversal; it must provide a
+ *     TraversalInfoType typedef.
+ */
+template<typename SplitType,
+         typename DescentType,
+         typename StatisticType,
+         typename MatType>
+template<typename RuleType>
+class RectangleTree<SplitType, DescentType, StatisticType, MatType>::
+    DualTreeTraverser
+{
+ public:
+  /**
+   * Instantiate the dual-tree traverser with the given rule set.
+   *
+   * @param rule Rule set to use during traversal.  It is stored by
+   *     reference, so it must outlive this traverser.
+   */
+  DualTreeTraverser(RuleType& rule);
+
+  /**
+   * Traverse the two trees.  This does not reset the number of prunes.
+   *
+   * NOTE: not implemented yet; currently this returns without visiting any
+   * node combination.
+   *
+   * @param queryNode The query node to be traversed.
+   * @param referenceNode The reference node to be traversed.
+   */
+  void Traverse(
+      RectangleTree<SplitType, DescentType, StatisticType, MatType>& queryNode,
+      RectangleTree<SplitType, DescentType, StatisticType, MatType>&
+          referenceNode);
+
+  //! Get the number of prunes.
+  size_t NumPrunes() const { return numPrunes; }
+  //! Modify the number of prunes.
+  size_t& NumPrunes() { return numPrunes; }
+
+  //! Get the number of visited combinations.
+  size_t NumVisited() const { return numVisited; }
+  //! Modify the number of visited combinations.
+  size_t& NumVisited() { return numVisited; }
+
+  //! Get the number of times a node combination was scored.
+  size_t NumScores() const { return numScores; }
+  //! Modify the number of times a node combination was scored.
+  size_t& NumScores() { return numScores; }
+
+  //! Get the number of times a base case was calculated.
+  size_t NumBaseCases() const { return numBaseCases; }
+  //! Modify the number of times a base case was calculated.
+  size_t& NumBaseCases() { return numBaseCases; }
+
+ private:
+  //! Reference to the rules with which the trees will be traversed.
+  RuleType& rule;
+
+  //! The number of prunes.
+  size_t numPrunes;
+
+  //! The number of node combinations that have been visited during traversal.
+  size_t numVisited;
+
+  //! The number of times a node combination was scored.
+  size_t numScores;
+
+  //! The number of times a base case was calculated.
+  size_t numBaseCases;
+
+  //! Traversal information, held in the class so that it isn't continually
+  //! being reallocated.
+  typename RuleType::TraversalInfoType traversalInfo;
+};
+
+}; // namespace tree
+}; // namespace mlpack
+
+// Include implementation.
+#include "dual_tree_traverser_impl.hpp"
+
+#endif

Added: mlpack/trunk/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp	Wed Jul  2 13:28:23 2014
@@ -0,0 +1,63 @@
+/**
+  * @file dual_tree_traverser_impl.hpp
+  * @author Andrew Wells
+  *
+  * Implementation of the dual-tree traverser for rectangle type trees.
+  *
+  * NOTE: dual-tree traversal is not implemented yet; Traverse() is a stub
+  * that returns immediately without visiting any node combination.
+  */
+#ifndef __MLPACK_CORE_TREE_RECTANGLE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
+#define __MLPACK_CORE_TREE_RECTANGLE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
+
+#include "dual_tree_traverser.hpp"
+
+#include <algorithm>
+#include <stack>
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * Construct the traverser.  The rule set is stored by reference and every
+ * statistics counter is zero-initialized, so the NumPrunes(), NumVisited(),
+ * NumScores(), and NumBaseCases() accessors never read indeterminate values.
+ */
+template<typename SplitType,
+         typename DescentType,
+         typename StatisticType,
+         typename MatType>
+template<typename RuleType>
+RectangleTree<SplitType, DescentType, StatisticType, MatType>::
+DualTreeTraverser<RuleType>::DualTreeTraverser(RuleType& rule) :
+    rule(rule),
+    numPrunes(0),
+    numVisited(0),
+    numScores(0),
+    numBaseCases(0)
+{ /* Nothing to do */ }
+
+/**
+ * Traverse the query and reference trees.  Not implemented yet: this stub
+ * returns without visiting any node combination, scoring, or base case.
+ * The parameters are left unnamed to avoid unused-parameter warnings.
+ */
+template<typename SplitType,
+         typename DescentType,
+         typename StatisticType,
+         typename MatType>
+template<typename RuleType>
+void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
+DualTreeTraverser<RuleType>::Traverse(
+    RectangleTree<SplitType, DescentType, StatisticType, MatType>&
+        /* queryNode */,
+    RectangleTree<SplitType, DescentType, StatisticType, MatType>&
+        /* referenceNode */)
+{
+  // TODO: implement the dual-tree traversal for rectangle trees.
+}
+
+}; // namespace tree
+}; // namespace mlpack
+
+#endif

Modified: mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp	Wed Jul  2 13:28:23 2014
@@ -79,10 +79,13 @@
   //! So other classes can use TreeType::Mat.
   typedef MatType Mat;
 
-  //! A traverser for rectangle type trees.  See
-  //! rectangle_tree_traverser.hpp for implementation.
+  //! A single-tree traverser for rectangle type trees.  See
+  //! single_tree_traverser.hpp for implementation.
   template<typename RuleType>
-  class RectangleTreeTraverser;
+  class SingleTreeTraverser;
+  //! A dual-tree traverser; see dual_tree_traverser.hpp for implementation.
+  template<typename RuleType>
+  class DualTreeTraverser;
 
   /**
    * Construct this as the root node of a rectangle type tree using the given

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp	(original)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp	Wed Jul  2 13:28:23 2014
@@ -58,7 +58,7 @@
 PARAM_FLAG("cover_tree", "If true, use cover trees to perform the search "
     "(experimental, may be slow).", "c");
 PARAM_FLAG("r_tree", "If true, use an R-Tree to perform the search "
-    "(experimental, may be slow.  Currently automatically sets single_mode.).", "R");
+    "(experimental, may be slow; currently implies --single_mode).", "T");
 PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
     "random orthogonal basis.", "R");
 PARAM_INT("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0);
@@ -132,6 +132,7 @@
   } else if (!singleMode && CLI::HasParam("r_tree"))  // R_tree requires single mode.
   {
     Log::Warn << "--single_mode assumed because --r_tree is present." << endl;
+    singleMode = true;
   }
   
   if (naive)
@@ -269,16 +270,71 @@
       // Make sure to notify the user that they are using an r tree.
       Log::Info << "Using r tree for nearest-neighbor calculation." << endl;
       
-      // Build the reference tree.
+      // Because we may construct it differently, we need a pointer.
+      NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
+      RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+		    tree::RTreeDescentHeuristic,
+		    NeighborSearchStat<NearestNeighborSort>,
+		    arma::mat> >* allknn = NULL;
+
+      // Build trees by hand, so we can save memory: if we pass a tree to
+      // NeighborSearch, it does not copy the matrix.
       Log::Info << "Building reference tree..." << endl;
       Timer::Start("tree_building");
+
+      RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+		    tree::RTreeDescentHeuristic,
+		    NeighborSearchStat<NearestNeighborSort>,
+		    arma::mat>
+      refTree(referenceData, leafSize, leafSize/3, 5, 2, 0);
+
+      RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+		    tree::RTreeDescentHeuristic,
+		    NeighborSearchStat<NearestNeighborSort>,
+		    arma::mat>*
+      queryTree = NULL; // Empty for now.
+
+      Timer::Stop("tree_building");
+
+      if (CLI::GetParam<string>("query_file") != "")
+      {
+	Log::Info << "Loaded query data from '" << queryFile << "' ("
+	    << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+
+        allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
         RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
-                      tree::RTreeDescentHeuristic,
-                      NeighborSearchStat<NearestNeighborSort>,
-                      arma::mat>
-        refTree(referenceData, leafSize, leafSize/3, 5, 2, 0);
+	  	      tree::RTreeDescentHeuristic,
+  		      NeighborSearchStat<NearestNeighborSort>,
+  		      arma::mat> >(&refTree, queryTree,
+          referenceData, queryData, singleMode);
+      } else
+      {
+	      allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
+      RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+		    tree::RTreeDescentHeuristic,
+		    NeighborSearchStat<NearestNeighborSort>,
+		    arma::mat> >(&refTree,
+        referenceData, singleMode);
+      }
+      Log::Info << "Tree built." << endl;
+      
+      arma::mat distancesOut;
+      arma::Mat<size_t> neighborsOut;
+
+      Log::Info << "Computing " << k << " nearest neighbors..." << endl;
+      allknn->Search(k, neighborsOut, distancesOut);
+
+      Log::Info << "Neighbors computed." << endl;
+
+
+      delete allknn;
+            
+      
+      // Build the reference tree.
+      Log::Info << "Building reference tree..." << endl;
+      Timer::Start("tree_building");
+
       Timer::Stop("tree_building");
-      std::cout << "completed tree building " << refTree.NumDescendants() << std::endl;
     }
   }
   else // Cover trees.

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp	Wed Jul  2 13:28:23 2014
@@ -60,7 +60,7 @@
     metric(metric)
 {
   // C++11 will allow us to call out to other constructors so we can avoid this
-  // copypasta problem.
+  // copy/paste problem.
 
   // We'll time tree building, but only if we are building trees.
   Timer::Start("tree_building");

Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp	Wed Jul  2 13:28:23 2014
@@ -56,7 +56,7 @@
   ++baseCases;
 
   // If this distance is better than any of the current candidates, the
-  // SortDistance() function will give us the poto insert it into.
+  // SortDistance() function will give us the position to insert it into.
   arma::vec queryDist = distances.unsafe_col(queryIndex);
   arma::Col<size_t> queryIndices = neighbors.unsafe_col(queryIndex);
   const size_t insertPosition = SortPolicy::SortDistance(queryDist,



More information about the mlpack-svn mailing list