[mlpack-git] master: Move PerformSplit and fix edge case bugs. (e140a47)

gitdub at mlpack.org gitdub at mlpack.org
Thu Sep 29 11:49:16 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/9ef7339d40550a974b3939e9fcb966fac2c09065...ebdb5abeaa3fd621a06ae663862bb72df76d2b40

>---------------------------------------------------------------

commit e140a470c7feeeebd98fc0ad0543028cd34b4d3e
Author: Ryan Curtin <ryan at ratml.org>
Date:   Thu Sep 29 11:49:16 2016 -0400

    Move PerformSplit and fix edge case bugs.
    
    Those bugs never occurred with BinarySpaceTree due to assumptions that
    BinarySpaceTree can make, but they can arise in other situations.


>---------------------------------------------------------------

e140a470c7feeeebd98fc0ad0543028cd34b4d3e
 src/mlpack/core/tree/CMakeLists.txt                |   2 +-
 .../distributed_binary_traversal_impl.hpp          | 182 +++++++++++++++++++++
 .../core/tree/binary_space_tree/mean_split.hpp     |   2 +-
 .../core/tree/binary_space_tree/midpoint_split.hpp |   2 +-
 .../tree/binary_space_tree/rp_tree_max_split.hpp   |   2 +-
 .../tree/binary_space_tree/rp_tree_mean_split.hpp  |   2 +-
 .../tree/binary_space_tree/vantage_point_split.hpp |   2 +-
 .../tree/{binary_space_tree => }/perform_split.hpp |  22 ++-
 8 files changed, 203 insertions(+), 13 deletions(-)

diff --git a/src/mlpack/core/tree/CMakeLists.txt b/src/mlpack/core/tree/CMakeLists.txt
index 77e9026..adc5bae 100644
--- a/src/mlpack/core/tree/CMakeLists.txt
+++ b/src/mlpack/core/tree/CMakeLists.txt
@@ -15,7 +15,6 @@ set(SOURCES
   binary_space_tree/mean_split_impl.hpp
   binary_space_tree/midpoint_split.hpp
   binary_space_tree/midpoint_split_impl.hpp
-  binary_space_tree/perform_split.hpp
   binary_space_tree/rp_tree_max_split.hpp
   binary_space_tree/rp_tree_max_split_impl.hpp
   binary_space_tree/rp_tree_mean_split.hpp
@@ -58,6 +57,7 @@ set(SOURCES
   octree/dual_tree_traverser.hpp
   octree/dual_tree_traverser_impl.hpp
   octree/traits.hpp
+  perform_split.hpp
   rectangle_tree.hpp
   rectangle_tree/rectangle_tree.hpp
   rectangle_tree/rectangle_tree_impl.hpp
diff --git a/src/mlpack/core/tree/binary_space_tree/distributed_binary_traversal_impl.hpp b/src/mlpack/core/tree/binary_space_tree/distributed_binary_traversal_impl.hpp
new file mode 100644
index 0000000..cccd703
--- /dev/null
+++ b/src/mlpack/core/tree/binary_space_tree/distributed_binary_traversal_impl.hpp
@@ -0,0 +1,182 @@
+/**
+ * @file distributed_binary_traversal_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Use MPI to perform a distributed traversal.
+ */
+#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_DISTRIBUTED_BINARY_TRAVERSAL_IMPL_HPP
+#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_DISTRIBUTED_BINARY_TRAVERSAL_IMPL_HPP
+
+#include "distributed_binary_traversal.hpp"
+#include "../binary_space_tree.hpp"
+#include "dual_tree_traverser.hpp"
+#include <boost/mpi.hpp>
+
+namespace mlpack {
+namespace tree {
+
+template<typename RuleType>
+DistributedBinaryTraversal<RuleType>::DistributedBinaryTraversal(
+    RuleType& rule) :
+    rule(&rule), // Keep a pointer to the caller's rule object.
+    world() // Default-construct the `world` member.
+{
+  // Nothing to do in the body; members are set in the initializer list.
+}
+
+template<typename RuleType>
+DistributedBinaryTraversal<RuleType>::DistributedBinaryTraversal() :
+    rule(NULL), // Replaced below by wrapper.Rules().
+    world()
+{
+  // We are an MPI child.  We must receive and construct our own RuleType
+  // object, query tree, and reference tree.  Once we have done that, we kick
+  // off the usual recursion, and when we're done, we send the results back.
+  typename RuleType::MPIWrapper wrapper;
+  Log::Info << "Process " << world.rank() << " is waiting for a message.\n";
+  Timer::Start("child_receive");
+  world.recv(0, 0, wrapper); // Presumably source rank 0, tag 0 -- confirm.
+  Timer::Stop("child_receive");
+  Log::Info << "Process " << world.rank() << " has received a message.\n";
+
+
+  // We've now received our information.  Start the recursion.
+  this->rule = wrapper.Rules(); // NOTE(review): `wrapper` is local; lifetime?
+  Timer::Start("child_traversal");
+  Traverse(*wrapper.QueryTree(), *wrapper.ReferenceTree());
+  Timer::Stop("child_traversal");
+
+  // Now, we have to ship the neighbors and distances back to the master.
+  typename RuleType::MPIResultsWrapper resultsWrapper(rule->Neighbors(),
+                                                      rule->Distances());
+  Log::Info << "Process " << world.rank() << " is sending results.\n";
+  Timer::Start("send_results");
+  world.send(0, 0, resultsWrapper); // Presumably to rank 0, tag 0 -- confirm.
+  Timer::Stop("send_results");
+  Log::Info << "Process " << world.rank() << " is finished.\n";
+}
+
+template<typename RuleType>
+template<typename TreeType>
+void DistributedBinaryTraversal<RuleType>::Traverse(const size_t queryIndex,
+                                                    TreeType& referenceNode)
+{
+  // Empty body: this overload performs no work and ignores both arguments.
+}
+
+template<typename RuleType>
+template<typename TreeType>
+void DistributedBinaryTraversal<RuleType>::Traverse(TreeType& queryNode,
+                                                    TreeType& referenceNode)
+{
+  // Dispatch on rank: rank 0 runs MasterTraverse(); every other rank runs
+  // ChildTraverse().
+  if (world.rank() == 0)
+  {
+    // Start the traversal, and pass the work to the children.
+    MasterTraverse(queryNode, referenceNode);
+  }
+  else
+  {
+    ChildTraverse(queryNode, referenceNode); // Plain dual-tree traversal.
+  }
+}
+
+template<typename RuleType>
+template<typename TreeType>
+void DistributedBinaryTraversal<RuleType>::MasterTraverse(
+    TreeType& queryNode,
+    TreeType& referenceNode)
+{
+  // A list of jobs to be done.  NOTE(review): push() below needs one pair.
+  std::queue<std::pair<TreeType*, TreeType*>> jobs;
+  jobs.push(&queryNode, &referenceNode);
+
+  // Which workers are busy.  NOTE(review): never read after this line.
+  std::vector<bool> busy(world.size() - 1, false);
+
+  while (!jobs.empty()) // NOTE(review): `jobs` is never popped below.
+  {
+    // Wait for any worker.  NOTE(review): `communicator` is presumably `world`.
+    RuleType::MPIResultType result; // NOTE(review): needs `typename`.
+    boost::mpi::status status;
+    status = communicator.recv(boost::mpi::any_source, boost::mpi::any_tag,
+        result);
+
+    // Hand that worker the front job.  NOTE(review): `queue` presumably `jobs`.
+    RuleType::MPIWorkType work(queue.front().first, queue.front().second);
+    communicator.send(status.source(), 0 /* zero tag */, work);
+
+    if (result.tag() == 1) // Initialization tag; no data.  NOTE(review): `!=`?
+    {
+      // Add new jobs.  NOTE(review): std::queue has no operator[]/push_back().
+      const RuleType::MPIWorkType& job = jobs[status.source()];
+      for (size_t i = 0; i < result.NumNewTasks(); ++i)
+        jobs.push_back(result.NewTask(job.QueryNode(), job.ReferenceNode()));
+
+      // Merge results into our tree.  NOTE(review): `i` unused in both loops.
+      for (size_t i = 0; i < result.NumTreeUpdates(); ++i)
+        result.MergeResult(job.QueryNode(), job.ReferenceNode());
+    }
+  }
+}
+
+template<typename RuleType>
+template<typename TreeType>
+void DistributedBinaryTraversal<RuleType>::ChildTraverse(
+    TreeType& queryNode,
+    TreeType& referenceNode)
+{
+  // Run the tree type's own DualTreeTraverser with the stored rule object.
+  typename TreeType::template DualTreeTraverser<RuleType> traverser(*rule);
+
+  traverser.Traverse(queryNode, referenceNode);
+}
+
+template<typename RuleType>
+template<typename TreeType>
+size_t DistributedBinaryTraversal<RuleType>::GetTarget(
+    TreeType& queryNode,
+    TreeType& referenceNode) const
+{
+  // We assemble the ID of the target process in a bitwise manner.  At any level
+  // of recursion, because this is a binary recursion, the query node may be
+  // either the left (L) child or the right (R) child, and the same applies to
+  // the reference node.  Thus the direction taken at a recursion has four
+  // possibilities: LL, LR, RL, and RR.  Take L = 0 and R = 1; now a single
+  // recursion can be represented as two bits (query bit, then reference bit).
+  // The highest-level recursion is the two most significant bits and the most
+  // recent recursion is the two least significant bits.  Thus, if the most
+  // recent recursion was RL and the one above it was LR, with nothing higher,
+  // index is LRRL -> 0110 -> 6 and the returned value is index + 1 = 7.  The
+  // walk stops as soon as either node has no parent, so unequal depths are
+  // only partially encoded.  NOTE(review): a root pair and all-L both give 1.
+  size_t index = 0;
+
+  TreeType* currentQuery = &queryNode;
+  TreeType* currentRef = &referenceNode;
+  size_t level = 0;
+  while (currentQuery->Parent() != NULL && currentRef->Parent() != NULL)
+  {
+    // Assemble this index.
+    size_t currentIndex = 0; // Assume LL, change if otherwise.
+    if (currentQuery->Parent()->Right() == currentQuery)
+      currentIndex += 2; // Now it's RL.
+    if (currentRef->Parent()->Right() == currentRef)
+      currentIndex++; // Now it's LR or RR.
+
+    // Place these two bits above the bits gathered on earlier iterations.
+    index += (currentIndex << (level * 2));
+    ++level;
+
+    currentQuery = currentQuery->Parent();
+    currentRef = currentRef->Parent();
+  }
+
+  return index + 1; // Index 0 is the root, so results start at 1.
+}
+
+} // namespace tree
+} // namespace mlpack
+
+#endif
diff --git a/src/mlpack/core/tree/binary_space_tree/mean_split.hpp b/src/mlpack/core/tree/binary_space_tree/mean_split.hpp
index 7d7e67a..ecad1ba 100644
--- a/src/mlpack/core/tree/binary_space_tree/mean_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/mean_split.hpp
@@ -10,7 +10,7 @@
 #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_MEAN_SPLIT_HPP
 
 #include <mlpack/core.hpp>
-#include "perform_split.hpp"
+#include <mlpack/core/tree/perform_split.hpp>
 
 namespace mlpack {
 namespace tree /** Trees and tree-building procedures. */ {
diff --git a/src/mlpack/core/tree/binary_space_tree/midpoint_split.hpp b/src/mlpack/core/tree/binary_space_tree/midpoint_split.hpp
index f898c42..01d8eeb 100644
--- a/src/mlpack/core/tree/binary_space_tree/midpoint_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/midpoint_split.hpp
@@ -11,7 +11,7 @@
 #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_MIDPOINT_SPLIT_HPP
 
 #include <mlpack/core.hpp>
-#include "perform_split.hpp"
+#include <mlpack/core/tree/perform_split.hpp>
 
 namespace mlpack {
 namespace tree /** Trees and tree-building procedures. */ {
diff --git a/src/mlpack/core/tree/binary_space_tree/rp_tree_max_split.hpp b/src/mlpack/core/tree/binary_space_tree/rp_tree_max_split.hpp
index 8bd513d..9d01b0d 100644
--- a/src/mlpack/core/tree/binary_space_tree/rp_tree_max_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/rp_tree_max_split.hpp
@@ -9,7 +9,7 @@
 #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_RP_TREE_MAX_SPLIT_HPP
 
 #include <mlpack/core.hpp>
-#include "perform_split.hpp"
+#include <mlpack/core/tree/perform_split.hpp>
 
 namespace mlpack {
 namespace tree /** Trees and tree-building procedures. */ {
diff --git a/src/mlpack/core/tree/binary_space_tree/rp_tree_mean_split.hpp b/src/mlpack/core/tree/binary_space_tree/rp_tree_mean_split.hpp
index 718b1de..332122b 100644
--- a/src/mlpack/core/tree/binary_space_tree/rp_tree_mean_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/rp_tree_mean_split.hpp
@@ -10,7 +10,7 @@
 
 #include <mlpack/core.hpp>
 #include "rp_tree_max_split.hpp"
-#include "perform_split.hpp"
+#include <mlpack/core/tree/perform_split.hpp>
 
 namespace mlpack {
 namespace tree /** Trees and tree-building procedures. */ {
diff --git a/src/mlpack/core/tree/binary_space_tree/vantage_point_split.hpp b/src/mlpack/core/tree/binary_space_tree/vantage_point_split.hpp
index c9a3006..5a16c45 100644
--- a/src/mlpack/core/tree/binary_space_tree/vantage_point_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/vantage_point_split.hpp
@@ -9,7 +9,7 @@
 #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_VANTAGE_POINT_SPLIT_HPP
 
 #include <mlpack/core.hpp>
-#include "perform_split.hpp"
+#include <mlpack/core/tree/perform_split.hpp>
 
 namespace mlpack {
 namespace tree /** Trees and tree-building procedures. */ {
diff --git a/src/mlpack/core/tree/binary_space_tree/perform_split.hpp b/src/mlpack/core/tree/perform_split.hpp
similarity index 90%
rename from src/mlpack/core/tree/binary_space_tree/perform_split.hpp
rename to src/mlpack/core/tree/perform_split.hpp
index acb8c90..ff14028 100644
--- a/src/mlpack/core/tree/binary_space_tree/perform_split.hpp
+++ b/src/mlpack/core/tree/perform_split.hpp
@@ -8,9 +8,8 @@
  * of the split column, and points from the right subtree are on the right side
  * of the split column.
  */
-
-#ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_PERFORM_SPLIT_HPP
-#define MLPACK_CORE_TREE_BINARY_SPACE_TREE_PERFORM_SPLIT_HPP
+#ifndef MLPACK_CORE_TREE_PERFORM_SPLIT_HPP
+#define MLPACK_CORE_TREE_PERFORM_SPLIT_HPP
 
 namespace mlpack {
 namespace tree /** Trees and tree-building procedures. */ {
@@ -41,12 +40,17 @@ size_t PerformSplit(MatType& data,
 
   // First half-iteration of the loop is out here because the termination
   // condition is in the middle.
-  while (SplitType::AssignToLeftNode(data.col(left), splitInfo) && (left <= right))
+  while ((left <= right) &&
+      (SplitType::AssignToLeftNode(data.col(left), splitInfo)))
     left++;
   while ((!SplitType::AssignToLeftNode(data.col(right), splitInfo)) &&
       (left <= right) && (right > 0))
     right--;
 
+  // Shortcut for when all points are on the right.
+  if (left == right && right == 0)
+    return left;
+
   while (left <= right)
   {
     // Swap columns.
@@ -102,12 +106,17 @@ size_t PerformSplit(MatType& data,
 
   // First half-iteration of the loop is out here because the termination
   // condition is in the middle.
-  while (SplitType::AssignToLeftNode(data.col(left), splitInfo) && (left <= right))
+  while ((left <= right) &&
+         (SplitType::AssignToLeftNode(data.col(left), splitInfo)))
     left++;
   while ((!SplitType::AssignToLeftNode(data.col(right), splitInfo)) &&
-      (left <= right) && (right > 0))
+         (left <= right) && (right > 0))
     right--;
 
+  // Shortcut for when all points are on the right.
+  if (left == right && right == 0)
+    return left;
+
   while (left <= right)
   {
     // Swap columns.
@@ -135,7 +144,6 @@ size_t PerformSplit(MatType& data,
   }
 
   Log::Assert(left == right + 1);
-
   return left;
 }
 




More information about the mlpack-git mailing list