[mlpack-git] master: Add traversers for octree. (4175ead)

gitdub at mlpack.org gitdub at mlpack.org
Sat Sep 24 12:44:33 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/9ef7339d40550a974b3939e9fcb966fac2c09065...ebdb5abeaa3fd621a06ae663862bb72df76d2b40

>---------------------------------------------------------------

commit 4175eadaf2ba61263be8b8c569aaa8c4434e6159
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sat Sep 24 12:44:33 2016 -0400

    Add traversers for octree.


>---------------------------------------------------------------

4175eadaf2ba61263be8b8c569aaa8c4434e6159
 src/mlpack/core/tree/CMakeLists.txt                |   4 +
 src/mlpack/core/tree/octree.hpp                    |   2 +
 .../core/tree/octree/dual_tree_traverser.hpp       |  78 +++++++++++
 .../core/tree/octree/dual_tree_traverser_impl.hpp  | 147 +++++++++++++++++++++
 src/mlpack/core/tree/octree/octree.hpp             |  13 ++
 .../core/tree/octree/single_tree_traverser.hpp     |  53 ++++++++
 .../tree/octree/single_tree_traverser_impl.hpp     |  67 ++++++++++
 7 files changed, 364 insertions(+)

diff --git a/src/mlpack/core/tree/CMakeLists.txt b/src/mlpack/core/tree/CMakeLists.txt
index 7cf2166..77e9026 100644
--- a/src/mlpack/core/tree/CMakeLists.txt
+++ b/src/mlpack/core/tree/CMakeLists.txt
@@ -53,6 +53,10 @@ set(SOURCES
   octree.hpp
   octree/octree.hpp
   octree/octree_impl.hpp
+  octree/single_tree_traverser.hpp
+  octree/single_tree_traverser_impl.hpp
+  octree/dual_tree_traverser.hpp
+  octree/dual_tree_traverser_impl.hpp
   octree/traits.hpp
   rectangle_tree.hpp
   rectangle_tree/rectangle_tree.hpp
diff --git a/src/mlpack/core/tree/octree.hpp b/src/mlpack/core/tree/octree.hpp
index f60f738..ed72ae5 100644
--- a/src/mlpack/core/tree/octree.hpp
+++ b/src/mlpack/core/tree/octree.hpp
@@ -11,5 +11,7 @@
 #include "bounds.hpp"
 #include "octree/octree.hpp"
 #include "octree/traits.hpp"
+#include "octree/single_tree_traverser.hpp"
+#include "octree/dual_tree_traverser.hpp"
 
 #endif
diff --git a/src/mlpack/core/tree/octree/dual_tree_traverser.hpp b/src/mlpack/core/tree/octree/dual_tree_traverser.hpp
new file mode 100644
index 0000000..ec9774a
--- /dev/null
+++ b/src/mlpack/core/tree/octree/dual_tree_traverser.hpp
@@ -0,0 +1,96 @@
+/**
+ * @file dual_tree_traverser.hpp
+ * @author Ryan Curtin
+ *
+ * Define the dual-tree traverser for the Octree.
+ */
+#ifndef MLPACK_CORE_TREE_OCTREE_DUAL_TREE_TRAVERSER_HPP
+#define MLPACK_CORE_TREE_OCTREE_DUAL_TREE_TRAVERSER_HPP
+
+#include <mlpack/core.hpp>
+#include "octree.hpp"
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * A dual-tree traverser for the Octree.  Starting from a (query node,
+ * reference node) combination, it scores candidate combinations with the rule
+ * set, recurses into those that are not pruned (a score of DBL_MAX means
+ * "prune"), and calls RuleType::BaseCase() on point pairs held in two leaves.
+ *
+ * @tparam RuleType Rule set providing BaseCase(), Score(), TraversalInfo(),
+ *     and the nested TraversalInfoType.
+ */
+template<typename MetricType,
+         typename StatisticType,
+         typename MatType>
+template<typename RuleType>
+class Octree<MetricType, StatisticType, MatType>::DualTreeTraverser
+{
+ public:
+  /**
+   * Instantiate the given dual-tree traverser with the given rule set.  All
+   * statistics counters start at zero.
+   *
+   * @param rule Rule set to use.  Only a reference is stored, so the rule set
+   *     must outlive the traverser.
+   */
+  DualTreeTraverser(RuleType& rule);
+
+  /**
+   * Traverse the two trees.  This does not reset the statistics of the
+   * traversals (it just adds to them).
+   *
+   * @param queryNode Node in the query tree.
+   * @param referenceNode Node in the reference tree.
+   */
+  void Traverse(Octree& queryNode, Octree& referenceNode);
+
+  //! Get the number of pruned nodes.
+  size_t NumPrunes() const { return numPrunes; }
+  //! Modify the number of pruned nodes (i.e. to reset it).
+  size_t& NumPrunes() { return numPrunes; }
+
+  //! Get the number of visited node combinations.
+  size_t NumVisited() const { return numVisited; }
+  //! Modify the number of visited node combinations.
+  size_t& NumVisited() { return numVisited; }
+  //! Misspelled alias of NumVisited(), kept for backward compatibility.
+  size_t& NumVistied() { return numVisited; }
+
+  //! Get the number of times a node was scored.
+  size_t NumScores() const { return numScores; }
+  //! Modify the number of times a node was scored.
+  size_t& NumScores() { return numScores; }
+
+  //! Get the number of times a base case was computed.
+  size_t NumBaseCases() const { return numBaseCases; }
+  //! Modify the number of times a base case was computed.
+  size_t& NumBaseCases() { return numBaseCases; }
+
+ private:
+  //! The rule type to use.
+  RuleType& rule;
+
+  //! The number of prunes.
+  size_t numPrunes;
+  //! The number of visited node combinations.
+  size_t numVisited;
+  //! The number of times a node was scored.
+  size_t numScores;
+  //! The number of times a base case was calculated.
+  size_t numBaseCases;
+
+  //! Traversal information, held in the class so that it isn't continually
+  //! being reallocated.
+  typename RuleType::TraversalInfoType traversalInfo;
+};
+
+} // namespace tree
+} // namespace mlpack
+
+// Include implementation.
+#include "dual_tree_traverser_impl.hpp"
+
+#endif
diff --git a/src/mlpack/core/tree/octree/dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/octree/dual_tree_traverser_impl.hpp
new file mode 100644
index 0000000..7058adb
--- /dev/null
+++ b/src/mlpack/core/tree/octree/dual_tree_traverser_impl.hpp
@@ -0,0 +1,166 @@
+/**
+ * @file dual_tree_traverser_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the dual-tree traverser for the octree.
+ */
+#ifndef MLPACK_CORE_TREE_OCTREE_DUAL_TREE_TRAVERSER_IMPL_HPP
+#define MLPACK_CORE_TREE_OCTREE_DUAL_TREE_TRAVERSER_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "dual_tree_traverser.hpp"
+
+namespace mlpack {
+namespace tree {
+
+//! Construct the traverser; all statistics counters start at zero.
+template<typename MetricType, typename StatisticType, typename MatType>
+template<typename RuleType>
+Octree<MetricType, StatisticType, MatType>::DualTreeTraverser<RuleType>::
+    DualTreeTraverser(RuleType& rule) :
+    rule(rule),
+    numPrunes(0),
+    numVisited(0),
+    numScores(0),
+    numBaseCases(0)
+{
+  // Nothing to do.
+}
+
+/**
+ * Traverse the given combination of query and reference node, adding to the
+ * statistics: every rule.Score() call increments numScores, every pruned
+ * combination increments numPrunes, and every rule.BaseCase() call increments
+ * numBaseCases.
+ *
+ * The traversal info that is active on entry is saved in a local variable and
+ * restored before each Score() call, so that recursive Traverse() calls (which
+ * change the rule's traversal info) cannot clobber the saved value.
+ */
+template<typename MetricType, typename StatisticType, typename MatType>
+template<typename RuleType>
+void Octree<MetricType, StatisticType, MatType>::DualTreeTraverser<RuleType>::
+    Traverse(Octree& queryNode, Octree& referenceNode)
+{
+  // Increment the visit counter.
+  ++numVisited;
+
+  // Save the traversal info for this node combination, so it can be restored
+  // before every Score() call below.
+  const typename RuleType::TraversalInfoType parentInfo = rule.TraversalInfo();
+
+  if (queryNode.IsLeaf() && referenceNode.IsLeaf())
+  {
+    const size_t begin = queryNode.Point(0);
+    const size_t end = begin + queryNode.NumPoints();
+    for (size_t q = begin; q < end; ++q)
+    {
+      // First, see if we can prune the reference node for this query point.
+      rule.TraversalInfo() = parentInfo;
+      const double score = rule.Score(q, referenceNode);
+      ++numScores;
+      if (score == DBL_MAX)
+      {
+        ++numPrunes;
+        continue;
+      }
+
+      const size_t rBegin = referenceNode.Point(0);
+      const size_t rEnd = rBegin + referenceNode.NumPoints();
+      for (size_t r = rBegin; r < rEnd; ++r)
+        rule.BaseCase(q, r);
+
+      numBaseCases += referenceNode.NumPoints();
+    }
+  }
+  else if (!queryNode.IsLeaf() && referenceNode.IsLeaf())
+  {
+    // We have to recurse down the query node.  Order does not matter.
+    for (size_t i = 0; i < queryNode.NumChildren(); ++i)
+    {
+      rule.TraversalInfo() = parentInfo;
+      const double score = rule.Score(queryNode.Child(i), referenceNode);
+      ++numScores;
+      if (score == DBL_MAX)
+      {
+        ++numPrunes;
+        continue;
+      }
+
+      Traverse(queryNode.Child(i), referenceNode);
+    }
+  }
+  else if (queryNode.IsLeaf() && !referenceNode.IsLeaf())
+  {
+    // We have to recurse down the reference node, so we need to do it in an
+    // ordered manner.  First score every child, remembering the traversal
+    // info that each score produced.
+    const size_t numRefChildren = referenceNode.NumChildren();
+    arma::vec scores(numRefChildren);
+    std::vector<typename RuleType::TraversalInfoType> tis(numRefChildren);
+    for (size_t i = 0; i < numRefChildren; ++i)
+    {
+      rule.TraversalInfo() = parentInfo;
+      scores[i] = rule.Score(queryNode, referenceNode.Child(i));
+      ++numScores;
+      tis[i] = rule.TraversalInfo();
+    }
+
+    // Visit the children in order of increasing score.
+    const arma::uvec scoreOrder = arma::sort_index(scores);
+    for (size_t i = 0; i < scoreOrder.n_elem; ++i)
+    {
+      if (scores[scoreOrder[i]] == DBL_MAX)
+      {
+        // We don't need to check any more---all children past here are pruned.
+        numPrunes += scoreOrder.n_elem - i;
+        break;
+      }
+
+      rule.TraversalInfo() = tis[scoreOrder[i]];
+      Traverse(queryNode, referenceNode.Child(scoreOrder[i]));
+    }
+  }
+  else
+  {
+    // We have to recurse down both the query and reference nodes.  Query order
+    // does not matter, so we will do that in sequence.  However we will
+    // allocate the arrays for recursion at this level.
+    const size_t numRefChildren = referenceNode.NumChildren();
+    arma::vec scores(numRefChildren);
+    std::vector<typename RuleType::TraversalInfoType> tis(numRefChildren);
+    for (size_t j = 0; j < queryNode.NumChildren(); ++j)
+    {
+      // Now we have to recurse down the reference node, which we will do in a
+      // prioritized manner.
+      for (size_t i = 0; i < numRefChildren; ++i)
+      {
+        rule.TraversalInfo() = parentInfo;
+        scores[i] = rule.Score(queryNode.Child(j), referenceNode.Child(i));
+        ++numScores;
+        tis[i] = rule.TraversalInfo();
+      }
+
+      // Visit the reference children in order of increasing score.
+      const arma::uvec scoreOrder = arma::sort_index(scores);
+      for (size_t i = 0; i < scoreOrder.n_elem; ++i)
+      {
+        if (scores[scoreOrder[i]] == DBL_MAX)
+        {
+          // We don't need to check any more---all children past here are
+          // pruned.
+          numPrunes += scoreOrder.n_elem - i;
+          break;
+        }
+
+        rule.TraversalInfo() = tis[scoreOrder[i]];
+        Traverse(queryNode.Child(j), referenceNode.Child(scoreOrder[i]));
+      }
+    }
+  }
+}
+
+} // namespace tree
+} // namespace mlpack
+
+#endif
diff --git a/src/mlpack/core/tree/octree/octree.hpp b/src/mlpack/core/tree/octree/octree.hpp
index f7fa385..0a8ffd2 100644
--- a/src/mlpack/core/tree/octree/octree.hpp
+++ b/src/mlpack/core/tree/octree/octree.hpp
@@ -25,6 +25,14 @@ class Octree
   //! The type of element held in MatType.
   typedef typename MatType::elem_type ElemType;
 
+  //! Single-tree traverser for a given RuleType; see single_tree_traverser.hpp.
+  template<typename RuleType>
+  class SingleTreeTraverser;
+
+  //! Dual-tree traverser for a given RuleType; see dual_tree_traverser.hpp.
+  template<typename RuleType>
+  class DualTreeTraverser;
+
  private:
   //! The children held by this node.
   std::vector<Octree*> children;
@@ -265,6 +273,11 @@ class Octree
       typename boost::enable_if<IsVector<VecType> >::type* = 0) const;
 
   /**
+   * Return whether the node is a leaf, i.e. whether it has no children.
+   */
+  bool IsLeaf() const { return NumChildren() == 0; }
+
+  /**
    * Return the index of the nearest child node to the given query node.  If it
    * can't decide, it will return NumChildren() (invalid index).
    */
diff --git a/src/mlpack/core/tree/octree/single_tree_traverser.hpp b/src/mlpack/core/tree/octree/single_tree_traverser.hpp
new file mode 100644
index 0000000..a82149f
--- /dev/null
+++ b/src/mlpack/core/tree/octree/single_tree_traverser.hpp
@@ -0,0 +1,65 @@
+/**
+ * @file single_tree_traverser.hpp
+ * @author Ryan Curtin
+ *
+ * Definition of the single tree traverser for the octree.
+ */
+#ifndef MLPACK_CORE_TREE_OCTREE_SINGLE_TREE_TRAVERSER_HPP
+#define MLPACK_CORE_TREE_OCTREE_SINGLE_TREE_TRAVERSER_HPP
+
+#include <mlpack/core.hpp>
+#include "octree.hpp"
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * A single-tree traverser for the Octree.  For one query point it walks the
+ * reference tree: leaves get RuleType::BaseCase() calls for each of their
+ * points, and the children of an internal node are visited in order of
+ * increasing RuleType::Score(), with DBL_MAX-scored children pruned.
+ *
+ * @tparam RuleType Rule set providing Score() and BaseCase().
+ */
+template<typename MetricType, typename StatisticType, typename MatType>
+template<typename RuleType>
+class Octree<MetricType, StatisticType, MatType>::SingleTreeTraverser
+{
+ public:
+  /**
+   * Create a traverser that consults the given rule set.
+   *
+   * @param rule Rule set to use.  Only a reference is kept, so the rule set
+   *     must outlive the traverser.
+   */
+  SingleTreeTraverser(RuleType& rule);
+
+  /**
+   * Run the traversal for one query point, starting at the given reference
+   * node.  The prune counter is only added to; it is never reset here.
+   *
+   * @param queryIndex Index of query point.
+   * @param referenceNode Node in reference tree.
+   */
+  void Traverse(const size_t queryIndex, Octree& referenceNode);
+
+  //! Return the number of reference nodes pruned so far.
+  size_t NumPrunes() const { return numPrunes; }
+  //! Modify the number of pruned nodes (e.g. to reset it to zero).
+  size_t& NumPrunes() { return numPrunes; }
+
+ private:
+  //! Rule set consulted during the traversal (not owned).
+  RuleType& rule;
+
+  //! Running count of reference nodes that have been pruned.
+  size_t numPrunes;
+};
+
+} // namespace tree
+} // namespace mlpack
+
+// Include implementation.
+#include "single_tree_traverser_impl.hpp"
+
+#endif
diff --git a/src/mlpack/core/tree/octree/single_tree_traverser_impl.hpp b/src/mlpack/core/tree/octree/single_tree_traverser_impl.hpp
new file mode 100644
index 0000000..49ba9c2
--- /dev/null
+++ b/src/mlpack/core/tree/octree/single_tree_traverser_impl.hpp
@@ -0,0 +1,78 @@
+/**
+ * @file single_tree_traverser_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the single tree traverser for octrees.
+ */
+#ifndef MLPACK_CORE_TREE_OCTREE_SINGLE_TREE_TRAVERSER_IMPL_HPP
+#define MLPACK_CORE_TREE_OCTREE_SINGLE_TREE_TRAVERSER_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "single_tree_traverser.hpp"
+
+namespace mlpack {
+namespace tree {
+
+//! Construct the traverser.  The prune counter must start at zero, because
+//! Traverse() only ever adds to it.
+template<typename MetricType, typename StatisticType, typename MatType>
+template<typename RuleType>
+Octree<MetricType, StatisticType, MatType>::SingleTreeTraverser<RuleType>::
+    SingleTreeTraverser(RuleType& rule) :
+    rule(rule),
+    numPrunes(0)
+{
+  // Nothing to do.
+}
+
+/**
+ * Traverse the subtree rooted at referenceNode for the given query point.  At
+ * a leaf, rule.BaseCase() is called for every point the leaf holds.  Otherwise
+ * all children are scored and visited in order of increasing score; once a
+ * child scored DBL_MAX is reached, it and every remaining child are counted
+ * as pruned and skipped.
+ */
+template<typename MetricType, typename StatisticType, typename MatType>
+template<typename RuleType>
+void Octree<MetricType, StatisticType, MatType>::SingleTreeTraverser<RuleType>::
+    Traverse(const size_t queryIndex, Octree& referenceNode)
+{
+  // If we are a leaf, run the base cases.
+  if (referenceNode.IsLeaf())
+  {
+    const size_t refBegin = referenceNode.Point(0);
+    const size_t refEnd = refBegin + referenceNode.NumPoints();
+    for (size_t r = refBegin; r < refEnd; ++r)
+      rule.BaseCase(queryIndex, r);
+
+    return;
+  }
+
+  // Do a prioritized recursion, by scoring all candidates and then sorting
+  // them.
+  arma::vec scores(referenceNode.NumChildren());
+  for (size_t i = 0; i < scores.n_elem; ++i)
+    scores[i] = rule.Score(queryIndex, referenceNode.Child(i));
+
+  // Sort the scores in ascending order, so the lowest-scored child is visited
+  // first.
+  const arma::uvec sortedIndices = arma::sort_index(scores);
+
+  for (size_t i = 0; i < sortedIndices.n_elem; ++i)
+  {
+    // If the node is pruned, all subsequent nodes in sorted order will also
+    // be pruned.
+    if (scores[sortedIndices[i]] == DBL_MAX)
+    {
+      numPrunes += (sortedIndices.n_elem - i);
+      break;
+    }
+
+    Traverse(queryIndex, referenceNode.Child(sortedIndices[i]));
+  }
+}
+
+} // namespace tree
+} // namespace mlpack
+
+#endif




More information about the mlpack-git mailing list