[mlpack-git] master, mlpack-1.0.x: rectangle tree traverser (e496a4c)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:50:24 EST 2015
Repository : https://github.com/mlpack/mlpack
On branches: master,mlpack-1.0.x
Link : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40
>---------------------------------------------------------------
commit e496a4c1e473e1f3546134edd1107348cc9710fb
Author: andrewmw94 <andrewmw94 at gmail.com>
Date: Wed Jul 2 17:28:23 2014 +0000
rectangle tree traverser
>---------------------------------------------------------------
e496a4c1e473e1f3546134edd1107348cc9710fb
src/mlpack/core/tree/CMakeLists.txt | 6 +-
src/mlpack/core/tree/rectangle_tree.hpp | 3 +-
.../dual_tree_traverser.hpp | 38 ++++++------
.../rectangle_tree/dual_tree_traverser_impl.hpp | 49 +++++++++++++++
.../core/tree/rectangle_tree/rectangle_tree.hpp | 9 ++-
src/mlpack/methods/neighbor_search/allknn_main.cpp | 70 +++++++++++++++++++---
.../neighbor_search/neighbor_search_impl.hpp | 2 +-
.../neighbor_search/neighbor_search_rules_impl.hpp | 2 +-
8 files changed, 144 insertions(+), 35 deletions(-)
diff --git a/src/mlpack/core/tree/CMakeLists.txt b/src/mlpack/core/tree/CMakeLists.txt
index c526a70..7988285 100644
--- a/src/mlpack/core/tree/CMakeLists.txt
+++ b/src/mlpack/core/tree/CMakeLists.txt
@@ -32,8 +32,10 @@ set(SOURCES
rectangle_tree.hpp
rectangle_tree/rectangle_tree.hpp
rectangle_tree/rectangle_tree_impl.hpp
- rectangle_tree/rectangle_tree_traverser.hpp
- rectangle_tree/rectangle_tree_traverser_impl.hpp
+ rectangle_tree/single_tree_traverser.hpp
+ rectangle_tree/single_tree_traverser_impl.hpp
+ rectangle_tree/dual_tree_traverser.hpp
+ rectangle_tree/dual_tree_traverser_impl.hpp
rectangle_tree/r_tree_split.hpp
rectangle_tree/r_tree_split_impl.hpp
rectangle_tree/r_tree_descent_heuristic.hpp
diff --git a/src/mlpack/core/tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree.hpp
index 743f7f5..f0d156e 100644
--- a/src/mlpack/core/tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree.hpp
@@ -13,7 +13,8 @@
*/
#include "bounds.hpp"
#include "rectangle_tree/rectangle_tree.hpp"
-#include "rectangle_tree/rectangle_tree_traverser.hpp"
+#include "rectangle_tree/single_tree_traverser.hpp"
+#include "rectangle_tree/dual_tree_traverser.hpp"
#include "rectangle_tree/r_tree_split.hpp"
#include "rectangle_tree/r_tree_descent_heuristic.hpp"
#include "rectangle_tree/traits.hpp"
diff --git a/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
similarity index 69%
copy from src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp
copy to src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
index 7cd1871..1091224 100644
--- a/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
@@ -1,28 +1,27 @@
/**
- * @file dual_tree_traverser.hpp
- * @author Ryan Curtin
- *
- * Defines the DualTreeTraverser for the BinarySpaceTree tree type. This is a
- * nested class of BinarySpaceTree which traverses two trees in a depth-first
- * manner with a given set of rules which indicate the branches which can be
- * pruned and the order in which to recurse.
- */
-#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_HPP
-#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_HPP
+ * @file dual_tree_traverser.hpp
+ * @author Andrew Wells
+ *
+ * A nested class of Rectangle Tree for traversing rectangle type trees
+ * with a given set of rules which indicate the branches to prune and the
+ * order in which to recurse. This is just here to make it compile.
+ */
+#ifndef __MLPACK_CORE_TREE_RECTANGLE_TREE_DUAL_TREE_TRAVERSER_HPP
+#define __MLPACK_CORE_TREE_RECTANGLE_TREE_DUAL_TREE_TRAVERSER_HPP
#include <mlpack/core.hpp>
-#include "binary_space_tree.hpp"
+#include "rectangle_tree.hpp"
namespace mlpack {
namespace tree {
-template<typename BoundType,
- typename StatisticType,
- typename MatType,
- typename SplitType>
+template<typename SplitType,
+ typename DescentType,
+ typename StatisticType,
+ typename MatType>
template<typename RuleType>
-class BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+class RectangleTree<SplitType, DescentType, StatisticType, MatType>::
DualTreeTraverser
{
public:
@@ -38,8 +37,8 @@ class BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
* @param referenceNode The reference node to be traversed.
* @param score The score of the current node combination.
*/
- void Traverse(BinarySpaceTree& queryNode,
- BinarySpaceTree& referenceNode);
+ void Traverse(RectangleTree<SplitType, DescentType, StatisticType, MatType>& queryNode,
+ RectangleTree<SplitType, DescentType, StatisticType, MatType>& referenceNode);
//! Get the number of prunes.
size_t NumPrunes() const { return numPrunes; }
@@ -88,5 +87,4 @@ class BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
// Include implementation.
#include "dual_tree_traverser_impl.hpp"
-#endif // __MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_HPP
-
+#endif
diff --git a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
new file mode 100644
index 0000000..8af988e
--- /dev/null
+++ b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
@@ -0,0 +1,49 @@
+/**
+ * @file dual_tree_traverser_impl.hpp
+ * @author Andrew Wells
+ *
+ * Implementation of RectangleTree::DualTreeTraverser.  Traverse() is
+ * currently a placeholder that visits no nodes and never calls the rules;
+ * it exists only so that code instantiating the traverser compiles.
+ */
+#ifndef __MLPACK_CORE_TREE_RECTANGLE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
+#define __MLPACK_CORE_TREE_RECTANGLE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
+
+#include "dual_tree_traverser.hpp"
+
+#include <algorithm>
+#include <stack>
+
+namespace mlpack {
+namespace tree {
+
+template<typename SplitType,
+         typename DescentType,
+         typename StatisticType,
+         typename MatType>
+template<typename RuleType>
+RectangleTree<SplitType, DescentType, StatisticType, MatType>::
+DualTreeTraverser<RuleType>::DualTreeTraverser(RuleType& rule) :
+    rule(rule),
+    numPrunes(0)
+{ /* Only stores the rules object and zeroes the prune count. */ }
+
+template<typename SplitType,
+         typename DescentType,
+         typename StatisticType,
+         typename MatType>
+template<typename RuleType>
+void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
+DualTreeTraverser<RuleType>::Traverse(RectangleTree& queryNode,
+                                      RectangleTree& referenceNode)
+{
+  // Placeholder: no node pairs are visited, so numPrunes stays at zero.
+  // The casts only silence unused-parameter warnings.
+  (void) queryNode;
+  (void) referenceNode;
+}
+
+} // namespace tree
+} // namespace mlpack
+
+#endif // __MLPACK_CORE_TREE_RECTANGLE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
\ No newline at end of file
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
index e6833b9..e50d372 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
@@ -79,10 +79,13 @@ class RectangleTree
//! So other classes can use TreeType::Mat.
typedef MatType Mat;
- //! A traverser for rectangle type trees. See
- //! rectangle_tree_traverser.hpp for implementation.
+ //! A single traverser for rectangle type trees. See
+ //! single_tree_traverser.hpp for implementation.
template<typename RuleType>
- class RectangleTreeTraverser;
+ class SingleTreeTraverser;
+ //! A dual tree traverser for rectangle type trees.
+ template<typename RuleType>
+ class DualTreeTraverser;
/**
* Construct this as the root node of a rectangle type tree using the given
diff --git a/src/mlpack/methods/neighbor_search/allknn_main.cpp b/src/mlpack/methods/neighbor_search/allknn_main.cpp
index 432c23e..b4dc0cf 100644
--- a/src/mlpack/methods/neighbor_search/allknn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/allknn_main.cpp
@@ -58,7 +58,7 @@ PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
PARAM_FLAG("cover_tree", "If true, use cover trees to perform the search "
"(experimental, may be slow).", "c");
PARAM_FLAG("r_tree", "If true, use an R-Tree to perform the search "
- "(experimental, may be slow. Currently automatically sets single_mode.).", "R");
+ "(experimental, may be slow. Currently automatically sets single_mode.).", "T");
PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
"random orthogonal basis.", "R");
PARAM_INT("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0);
@@ -132,6 +132,7 @@ int main(int argc, char *argv[])
} else if (!singleMode && CLI::HasParam("r_tree")) // R_tree requires single mode.
{
Log::Warn << "--single_mode assumed because --r_tree is present." << endl;
+ singleMode = true;
}
if (naive)
@@ -269,16 +270,71 @@ int main(int argc, char *argv[])
// Make sure to notify the user that they are using an r tree.
Log::Info << "Using r tree for nearest-neighbor calculation." << endl;
- // Build the reference tree.
+ // Because we may construct it differently, we need a pointer.
+ NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
+ RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+ tree::RTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat> >* allknn = NULL;
+
+ // Build trees by hand, so we can save memory: if we pass a tree to
+ // NeighborSearch, it does not copy the matrix.
Log::Info << "Building reference tree..." << endl;
Timer::Start("tree_building");
+
+ RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+ tree::RTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat>
+ refTree(referenceData, leafSize, leafSize/3, 5, 2, 0);
+
+ RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+ tree::RTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat>*
+ queryTree = NULL; // Empty for now.
+
+ Timer::Stop("tree_building");
+
+ if (CLI::GetParam<string>("query_file") != "")
+ {
+ Log::Info << "Loaded query data from '" << queryFile << "' ("
+ << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+
+ allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
- tree::RTreeDescentHeuristic,
- NeighborSearchStat<NearestNeighborSort>,
- arma::mat>
- refTree(referenceData, leafSize, leafSize/3, 5, 2, 0);
+ tree::RTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat> >(&refTree, queryTree,
+ referenceData, queryData, singleMode);
+ } else
+ {
+ allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
+ RectangleTree<tree::RTreeSplit<tree::RTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+ tree::RTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat> >(&refTree,
+ referenceData, singleMode);
+ }
+ Log::Info << "Tree built." << endl;
+
+ arma::mat distancesOut;
+ arma::Mat<size_t> neighborsOut;
+
+ Log::Info << "Computing " << k << " nearest neighbors..." << endl;
+ allknn->Search(k, neighborsOut, distancesOut);
+
+ Log::Info << "Neighbors computed." << endl;
+
+
+ delete allknn;
+
+
+ // Build the reference tree.
+ Log::Info << "Building reference tree..." << endl;
+ Timer::Start("tree_building");
+
Timer::Stop("tree_building");
- std::cout << "completed tree building " << refTree.NumDescendants() << std::endl;
}
}
else // Cover trees.
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index dac46c0..bcada1f 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -60,7 +60,7 @@ NeighborSearch(const typename TreeType::Mat& referenceSetIn,
metric(metric)
{
// C++11 will allow us to call out to other constructors so we can avoid this
- // copypasta problem.
+ // copy/paste problem.
// We'll time tree building, but only if we are building trees.
Timer::Start("tree_building");
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
index c96b10a..5b4eca9 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
@@ -56,7 +56,7 @@ BaseCase(const size_t queryIndex, const size_t referenceIndex)
++baseCases;
// If this distance is better than any of the current candidates, the
- // SortDistance() function will give us the poto insert it into.
+ // SortDistance() function will give us the position to insert it into.
arma::vec queryDist = distances.unsafe_col(queryIndex);
arma::Col<size_t> queryIndices = neighbors.unsafe_col(queryIndex);
const size_t insertPosition = SortPolicy::SortDistance(queryDist,
More information about the mlpack-git mailing list