[mlpack-git] master: Add traversers for octree. (4175ead)
gitdub at mlpack.org
gitdub at mlpack.org
Sat Sep 24 12:44:33 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/9ef7339d40550a974b3939e9fcb966fac2c09065...ebdb5abeaa3fd621a06ae663862bb72df76d2b40
>---------------------------------------------------------------
commit 4175eadaf2ba61263be8b8c569aaa8c4434e6159
Author: Ryan Curtin <ryan at ratml.org>
Date: Sat Sep 24 12:44:33 2016 -0400
Add traversers for octree.
>---------------------------------------------------------------
4175eadaf2ba61263be8b8c569aaa8c4434e6159
src/mlpack/core/tree/CMakeLists.txt | 4 +
src/mlpack/core/tree/octree.hpp | 2 +
.../core/tree/octree/dual_tree_traverser.hpp | 78 +++++++++++
.../core/tree/octree/dual_tree_traverser_impl.hpp | 147 +++++++++++++++++++++
src/mlpack/core/tree/octree/octree.hpp | 13 ++
.../core/tree/octree/single_tree_traverser.hpp | 53 ++++++++
.../tree/octree/single_tree_traverser_impl.hpp | 67 ++++++++++
7 files changed, 364 insertions(+)
diff --git a/src/mlpack/core/tree/CMakeLists.txt b/src/mlpack/core/tree/CMakeLists.txt
index 7cf2166..77e9026 100644
--- a/src/mlpack/core/tree/CMakeLists.txt
+++ b/src/mlpack/core/tree/CMakeLists.txt
@@ -53,6 +53,10 @@ set(SOURCES
octree.hpp
octree/octree.hpp
octree/octree_impl.hpp
+ octree/single_tree_traverser.hpp
+ octree/single_tree_traverser_impl.hpp
+ octree/dual_tree_traverser.hpp
+ octree/dual_tree_traverser_impl.hpp
octree/traits.hpp
rectangle_tree.hpp
rectangle_tree/rectangle_tree.hpp
diff --git a/src/mlpack/core/tree/octree.hpp b/src/mlpack/core/tree/octree.hpp
index f60f738..ed72ae5 100644
--- a/src/mlpack/core/tree/octree.hpp
+++ b/src/mlpack/core/tree/octree.hpp
@@ -11,5 +11,7 @@
#include "bounds.hpp"
#include "octree/octree.hpp"
#include "octree/traits.hpp"
+#include "octree/single_tree_traverser.hpp"
+#include "octree/dual_tree_traverser.hpp"
#endif
diff --git a/src/mlpack/core/tree/octree/dual_tree_traverser.hpp b/src/mlpack/core/tree/octree/dual_tree_traverser.hpp
new file mode 100644
index 0000000..ec9774a
--- /dev/null
+++ b/src/mlpack/core/tree/octree/dual_tree_traverser.hpp
@@ -0,0 +1,78 @@
+/**
+ * @file dual_tree_traverser.hpp
+ * @author Ryan Curtin
+ *
+ * Define the dual-tree traverser for the Octree.
+ */
+#ifndef MLPACK_CORE_TREE_OCTREE_DUAL_TREE_TRAVERSER_HPP
+#define MLPACK_CORE_TREE_OCTREE_DUAL_TREE_TRAVERSER_HPP
+
+#include <mlpack/core.hpp>
+#include "octree.hpp"
+
+namespace mlpack {
+namespace tree {
+
+template<typename MetricType,
+         typename StatisticType,
+         typename MatType>
+template<typename RuleType>
+class Octree<MetricType, StatisticType, MatType>::DualTreeTraverser
+{
+ public:
+  /**
+   * Instantiate the given dual-tree traverser with the given rule set.
+   *
+   * @param rule Rule set used to score node combinations and run base cases.
+   *     It is held by reference, so it must outlive the traverser.
+   */
+  DualTreeTraverser(RuleType& rule);
+
+  /**
+   * Traverse the two trees.  This does not reset the statistics of the
+   * traversals (it just adds to them).
+   *
+   * @param queryNode Node of the query tree to start the traversal at.
+   * @param referenceNode Node of the reference tree to start the traversal at.
+   */
+  void Traverse(Octree& queryNode, Octree& referenceNode);
+
+  //! Get the number of pruned nodes.
+  size_t NumPrunes() const { return numPrunes; }
+  //! Modify the number of pruned nodes (i.e. to reset it).
+  size_t& NumPrunes() { return numPrunes; }
+
+  //! Get the number of visited node combinations.
+  size_t NumVisited() const { return numVisited; }
+  //! Modify the number of visited node combinations (i.e. to reset it).
+  size_t& NumVisited() { return numVisited; }
+  //! Misspelled alias of the modifiable NumVisited(); kept only so that
+  //! existing callers keep compiling.  Prefer NumVisited().
+  size_t& NumVistied() { return numVisited; }
+
+  //! Get the number of times a node was scored.
+  size_t NumScores() const { return numScores; }
+  //! Modify the number of times a node was scored.
+  size_t& NumScores() { return numScores; }
+
+  //! Get the number of times a base case was computed.
+  size_t NumBaseCases() const { return numBaseCases; }
+  //! Modify the number of times a base case was computed.
+  size_t& NumBaseCases() { return numBaseCases; }
+
+ private:
+  //! The rule type to use (not owned).
+  RuleType& rule;
+
+  //! The number of prunes.
+  size_t numPrunes;
+  //! The number of visited node combinations.
+  size_t numVisited;
+  //! The number of times a node was scored.
+  size_t numScores;
+  //! The number of times a base case was calculated.
+  size_t numBaseCases;
+
+  //! Storage for the rule's traversal information.
+  typename RuleType::TraversalInfoType traversalInfo;
+};
+
+} // namespace tree
+} // namespace mlpack
+
+// Include implementation.
+#include "dual_tree_traverser_impl.hpp"
+
+#endif
diff --git a/src/mlpack/core/tree/octree/dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/octree/dual_tree_traverser_impl.hpp
new file mode 100644
index 0000000..7058adb
--- /dev/null
+++ b/src/mlpack/core/tree/octree/dual_tree_traverser_impl.hpp
@@ -0,0 +1,147 @@
+/**
+ * @file dual_tree_traverser_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the dual-tree traverser for the octree.
+ */
+#ifndef MLPACK_CORE_TREE_OCTREE_DUAL_TREE_TRAVERSER_IMPL_HPP
+#define MLPACK_CORE_TREE_OCTREE_DUAL_TREE_TRAVERSER_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "dual_tree_traverser.hpp"
+
+namespace mlpack {
+namespace tree {
+
+template<typename MetricType, typename StatisticType, typename MatType>
+template<typename RuleType>
+Octree<MetricType, StatisticType, MatType>::DualTreeTraverser<RuleType>::
+    DualTreeTraverser(RuleType& rule) :
+    rule(rule)
+{
+  // Bind the (borrowed) rule set above; every traversal statistic starts out
+  // at zero and is only ever accumulated by Traverse().
+  numPrunes = numVisited = numScores = numBaseCases = 0;
+}
+
+/**
+ * Recursively traverse the given query and reference nodes.  Combinations
+ * whose score is DBL_MAX are pruned.  Reference children are visited in
+ * ascending order of score; query children are visited in index order.
+ *
+ * @param queryNode Current node of the query tree.
+ * @param referenceNode Current node of the reference tree.
+ */
+template<typename MetricType, typename StatisticType, typename MatType>
+template<typename RuleType>
+void Octree<MetricType, StatisticType, MatType>::DualTreeTraverser<RuleType>::
+    Traverse(Octree& queryNode, Octree& referenceNode)
+{
+  // Increment the visit counter.
+  ++numVisited;
+
+  // Save the traversal info for this node combination.  This must be a local
+  // (per call frame) copy: the recursive Traverse() calls below would
+  // overwrite any shared member storage, and every child combination has to
+  // be scored starting from *this* combination's info, not from whatever the
+  // deepest recursive call left behind.
+  const typename RuleType::TraversalInfoType parentInfo = rule.TraversalInfo();
+
+  if (queryNode.IsLeaf() && referenceNode.IsLeaf())
+  {
+    const size_t begin = queryNode.Point(0);
+    const size_t end = begin + queryNode.NumPoints();
+    for (size_t q = begin; q < end; ++q)
+    {
+      // First, see if we can prune the reference node for this query point.
+      rule.TraversalInfo() = parentInfo;
+      const double score = rule.Score(q, referenceNode);
+      ++numScores;
+      if (score == DBL_MAX)
+      {
+        ++numPrunes;
+        continue;
+      }
+
+      const size_t rBegin = referenceNode.Point(0);
+      const size_t rEnd = rBegin + referenceNode.NumPoints();
+      for (size_t r = rBegin; r < rEnd; ++r)
+        rule.BaseCase(q, r);
+
+      numBaseCases += referenceNode.NumPoints();
+    }
+  }
+  else if (!queryNode.IsLeaf() && referenceNode.IsLeaf())
+  {
+    // Recurse down the query node only.  Order does not matter.
+    for (size_t i = 0; i < queryNode.NumChildren(); ++i)
+    {
+      rule.TraversalInfo() = parentInfo;
+      const double score = rule.Score(queryNode.Child(i), referenceNode);
+      ++numScores;
+      if (score == DBL_MAX)
+      {
+        ++numPrunes;
+        continue;
+      }
+
+      Traverse(queryNode.Child(i), referenceNode);
+    }
+  }
+  else if (queryNode.IsLeaf() && !referenceNode.IsLeaf())
+  {
+    // Recurse down the reference node only, in ascending order of score.
+    // Remember the traversal info produced by each scoring so that it can be
+    // restored right before the corresponding recursion.
+    arma::vec scores(referenceNode.NumChildren());
+    std::vector<typename RuleType::TraversalInfoType> tis;
+    tis.reserve(referenceNode.NumChildren());
+    for (size_t i = 0; i < referenceNode.NumChildren(); ++i)
+    {
+      rule.TraversalInfo() = parentInfo;
+      scores[i] = rule.Score(queryNode, referenceNode.Child(i));
+      ++numScores;
+      tis.push_back(rule.TraversalInfo());
+    }
+
+    // sort_index() sorts ascending, so once a pruned child (DBL_MAX) is
+    // reached, every child after it in this order is pruned too.
+    const arma::uvec scoreOrder = arma::sort_index(scores);
+    for (size_t i = 0; i < scoreOrder.n_elem; ++i)
+    {
+      if (scores[scoreOrder[i]] == DBL_MAX)
+      {
+        numPrunes += scoreOrder.n_elem - i;
+        break;
+      }
+
+      rule.TraversalInfo() = tis[scoreOrder[i]];
+      Traverse(queryNode, referenceNode.Child(scoreOrder[i]));
+    }
+  }
+  else
+  {
+    // Recurse down both nodes.  Query order does not matter, so the query
+    // children are visited in sequence; for each of them the reference
+    // children are visited in ascending order of score.  The buffers are
+    // allocated once at this level and reused for every query child (the
+    // recursive calls use their own buffers).
+    arma::vec scores(referenceNode.NumChildren());
+    std::vector<typename RuleType::TraversalInfoType>
+        tis(referenceNode.NumChildren());
+    for (size_t j = 0; j < queryNode.NumChildren(); ++j)
+    {
+      for (size_t i = 0; i < referenceNode.NumChildren(); ++i)
+      {
+        rule.TraversalInfo() = parentInfo;
+        scores[i] = rule.Score(queryNode.Child(j), referenceNode.Child(i));
+        ++numScores;
+        tis[i] = rule.TraversalInfo();
+      }
+
+      const arma::uvec scoreOrder = arma::sort_index(scores);
+      for (size_t i = 0; i < scoreOrder.n_elem; ++i)
+      {
+        if (scores[scoreOrder[i]] == DBL_MAX)
+        {
+          // All children past here are pruned as well.
+          numPrunes += scoreOrder.n_elem - i;
+          break;
+        }
+
+        rule.TraversalInfo() = tis[scoreOrder[i]];
+        Traverse(queryNode.Child(j), referenceNode.Child(scoreOrder[i]));
+      }
+    }
+  }
+}
+
+} // namespace tree
+} // namespace mlpack
+
+#endif
diff --git a/src/mlpack/core/tree/octree/octree.hpp b/src/mlpack/core/tree/octree/octree.hpp
index f7fa385..0a8ffd2 100644
--- a/src/mlpack/core/tree/octree/octree.hpp
+++ b/src/mlpack/core/tree/octree/octree.hpp
@@ -25,6 +25,14 @@ class Octree
//! The type of element held in MatType.
typedef typename MatType::elem_type ElemType;
+ //! A single-tree traverser; see single_tree_traverser.hpp.
+ template<typename RuleType>
+ class SingleTreeTraverser;
+
+ //! A dual-tree traverser; see dual_tree_traverser.hpp.
+ template<typename RuleType>
+ class DualTreeTraverser;
+
private:
//! The children held by this node.
std::vector<Octree*> children;
@@ -265,6 +273,11 @@ class Octree
typename boost::enable_if<IsVector<VecType> >::type* = 0) const;
/**
+ * Return whether or not the node is a leaf, i.e. whether it has no children
+ * (NumChildren() == 0).
+ */
+ bool IsLeaf() const { return NumChildren() == 0; }
+
+ /**
* Return the index of the nearest child node to the given query node. If it
* can't decide, it will return NumChildren() (invalid index).
*/
diff --git a/src/mlpack/core/tree/octree/single_tree_traverser.hpp b/src/mlpack/core/tree/octree/single_tree_traverser.hpp
new file mode 100644
index 0000000..a82149f
--- /dev/null
+++ b/src/mlpack/core/tree/octree/single_tree_traverser.hpp
@@ -0,0 +1,53 @@
+/**
+ * @file single_tree_traverser.hpp
+ * @author Ryan Curtin
+ *
+ * Definition of the single tree traverser for the octree.
+ */
+#ifndef MLPACK_CORE_TREE_OCTREE_SINGLE_TREE_TRAVERSER_HPP
+#define MLPACK_CORE_TREE_OCTREE_SINGLE_TREE_TRAVERSER_HPP
+
+#include <mlpack/core.hpp>
+#include "octree.hpp"
+
+namespace mlpack {
+namespace tree {
+
+template<typename MetricType, typename StatisticType, typename MatType>
+template<typename RuleType>
+class Octree<MetricType, StatisticType, MatType>::SingleTreeTraverser
+{
+ public:
+  /**
+   * Create a single-tree traverser that uses the given rule set.
+   *
+   * @param rule Rule set consulted for scores and base cases; it is held by
+   *     reference and must outlive the traverser.
+   */
+  SingleTreeTraverser(RuleType& rule);
+
+  /**
+   * Traverse the reference tree rooted at referenceNode for one query point.
+   * The prune counter is accumulated across calls, never reset here.
+   *
+   * @param queryIndex Index of query point.
+   * @param referenceNode Node in reference tree.
+   */
+  void Traverse(const size_t queryIndex, Octree& referenceNode);
+
+  //! Return the number of pruned nodes.
+  size_t NumPrunes() const { return numPrunes; }
+  //! Access the number of pruned nodes for modification (e.g. to reset it).
+  size_t& NumPrunes() { return numPrunes; }
+
+ private:
+  //! Borrowed rule set driving the traversal.
+  RuleType& rule;
+  //! Running count of reference nodes that were pruned.
+  size_t numPrunes;
+};
+
+} // namespace tree
+} // namespace mlpack
+
+// Include implementation.
+#include "single_tree_traverser_impl.hpp"
+
+#endif
diff --git a/src/mlpack/core/tree/octree/single_tree_traverser_impl.hpp b/src/mlpack/core/tree/octree/single_tree_traverser_impl.hpp
new file mode 100644
index 0000000..49ba9c2
--- /dev/null
+++ b/src/mlpack/core/tree/octree/single_tree_traverser_impl.hpp
@@ -0,0 +1,67 @@
+/**
+ * @file single_tree_traverser_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the single tree traverser for octrees.
+ */
+#ifndef MLPACK_CORE_TREE_OCTREE_SINGLE_TREE_TRAVERSER_IMPL_HPP
+#define MLPACK_CORE_TREE_OCTREE_SINGLE_TREE_TRAVERSER_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "single_tree_traverser.hpp"
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * Construct the traverser around the given rule set.  The prune counter starts
+ * at zero.
+ */
+template<typename MetricType, typename StatisticType, typename MatType>
+template<typename RuleType>
+Octree<MetricType, StatisticType, MatType>::SingleTreeTraverser<RuleType>::
+    SingleTreeTraverser(RuleType& rule) :
+    rule(rule),
+    numPrunes(0)
+{
+  // numPrunes must be zeroed explicitly: as a plain size_t member left out of
+  // the initializer list it would be indeterminate, and Traverse() only ever
+  // adds to it.
+}
+
+/**
+ * Recursively traverse the reference tree for a single query point.  Leaves
+ * run the base case against each of their points; internal nodes score all
+ * children and descend into them in ascending order of score, stopping at the
+ * first pruned (DBL_MAX) child.
+ */
+template<typename MetricType, typename StatisticType, typename MatType>
+template<typename RuleType>
+void Octree<MetricType, StatisticType, MatType>::SingleTreeTraverser<RuleType>::
+    Traverse(const size_t queryIndex, Octree& referenceNode)
+{
+  // Leaf: evaluate the base case against every point held by this node.
+  if (referenceNode.IsLeaf())
+  {
+    const size_t first = referenceNode.Point(0);
+    const size_t count = referenceNode.NumPoints();
+    for (size_t offset = 0; offset < count; ++offset)
+      rule.BaseCase(queryIndex, first + offset);
+    return;
+  }
+
+  // Internal node: score every child up front, then visit them by score.
+  const size_t numChildren = referenceNode.NumChildren();
+  arma::vec childScores(numChildren);
+  for (size_t c = 0; c < numChildren; ++c)
+    childScores[c] = rule.Score(queryIndex, referenceNode.Child(c));
+
+  const arma::uvec visitOrder = arma::sort_index(childScores);
+  for (size_t k = 0; k < visitOrder.n_elem; ++k)
+  {
+    const size_t child = visitOrder[k];
+
+    // The order is ascending, so the first pruned child means all of the
+    // remaining ones are pruned too.
+    if (childScores[child] == DBL_MAX)
+    {
+      numPrunes += (visitOrder.n_elem - k);
+      return;
+    }
+
+    Traverse(queryIndex, referenceNode.Child(child));
+  }
+}
+
+} // namespace tree
+} // namespace mlpack
+
+#endif
More information about the mlpack-git
mailing list