[mlpack-git] master: Move PerformSplit and fix edge case bugs. (e140a47)
gitdub at mlpack.org
gitdub at mlpack.org
Thu Sep 29 11:49:16 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/9ef7339d40550a974b3939e9fcb966fac2c09065...ebdb5abeaa3fd621a06ae663862bb72df76d2b40
>---------------------------------------------------------------
commit e140a470c7feeeebd98fc0ad0543028cd34b4d3e
Author: Ryan Curtin <ryan at ratml.org>
Date: Thu Sep 29 11:49:16 2016 -0400
Move PerformSplit and fix edge case bugs.
Those bugs never occurred with BinarySpaceTree due to assumptions that
BinarySpaceTree can make, but they can arise in other situations.
>---------------------------------------------------------------
e140a470c7feeeebd98fc0ad0543028cd34b4d3e
src/mlpack/core/tree/CMakeLists.txt | 2 +-
.../distributed_binary_traversal_impl.hpp | 182 +++++++++++++++++++++
.../core/tree/binary_space_tree/mean_split.hpp | 2 +-
.../core/tree/binary_space_tree/midpoint_split.hpp | 2 +-
.../tree/binary_space_tree/rp_tree_max_split.hpp | 2 +-
.../tree/binary_space_tree/rp_tree_mean_split.hpp | 2 +-
.../tree/binary_space_tree/vantage_point_split.hpp | 2 +-
.../tree/{binary_space_tree => }/perform_split.hpp | 22 ++-
8 files changed, 203 insertions(+), 13 deletions(-)
diff --git a/src/mlpack/core/tree/CMakeLists.txt b/src/mlpack/core/tree/CMakeLists.txt
index 77e9026..adc5bae 100644
--- a/src/mlpack/core/tree/CMakeLists.txt
+++ b/src/mlpack/core/tree/CMakeLists.txt
@@ -15,7 +15,6 @@ set(SOURCES
binary_space_tree/mean_split_impl.hpp
binary_space_tree/midpoint_split.hpp
binary_space_tree/midpoint_split_impl.hpp
- binary_space_tree/perform_split.hpp
binary_space_tree/rp_tree_max_split.hpp
binary_space_tree/rp_tree_max_split_impl.hpp
binary_space_tree/rp_tree_mean_split.hpp
@@ -58,6 +57,7 @@ set(SOURCES
octree/dual_tree_traverser.hpp
octree/dual_tree_traverser_impl.hpp
octree/traits.hpp
+ perform_split.hpp
rectangle_tree.hpp
rectangle_tree/rectangle_tree.hpp
rectangle_tree/rectangle_tree_impl.hpp
diff --git a/src/mlpack/core/tree/binary_space_tree/distributed_binary_traversal_impl.hpp b/src/mlpack/core/tree/binary_space_tree/distributed_binary_traversal_impl.hpp
new file mode 100644
index 0000000..cccd703
--- /dev/null
+++ b/src/mlpack/core/tree/binary_space_tree/distributed_binary_traversal_impl.hpp
@@ -0,0 +1,182 @@
+/**
+ * @file distributed_binary_traversal_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Use MPI to perform a distributed traversal.
+ */
+#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_DISTRIBUTED_BINARY_TRAVERSAL_IMPL_HPP
+#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_DISTRIBUTED_BINARY_TRAVERSAL_IMPL_HPP
+
+#include "distributed_binary_traversal.hpp"
+#include "../binary_space_tree.hpp"
+#include "dual_tree_traverser.hpp"
+#include <boost/mpi.hpp>
+
+namespace mlpack {
+namespace tree {
+
+template<typename RuleType>
+DistributedBinaryTraversal<RuleType>::DistributedBinaryTraversal(
+ RuleType& rule) :
+ rule(&rule), // Keeps a pointer to the caller's rule object; no copy is made.
+ world()
+{
+ // Nothing to do: both members are set in the initializer list above.
+}
+
+template<typename RuleType>
+DistributedBinaryTraversal<RuleType>::DistributedBinaryTraversal() :
+ rule(NULL),
+ world()
+{
+ // We are an MPI child. We must receive and construct our own RuleType
+ // object, query tree, and reference tree. Once we have done that, we kick
+ // off the usual recursion, and when we're done, we send the results back.
+ typename RuleType::MPIWrapper wrapper;
+ Log::Info << "Process " << world.rank() << " is waiting for a message.\n";
+ Timer::Start("child_receive");
+ world.recv(0, 0, wrapper); // Blocking receive from rank 0 with tag 0.
+ Timer::Stop("child_receive");
+ Log::Info << "Process " << world.rank() << " has received a message.\n";
+
+
+ // We've now received our information. Start the recursion.
+ this->rule = wrapper.Rules(); // NOTE(review): rule is taken from the local wrapper; confirm what owns it once this constructor returns.
+ Timer::Start("child_traversal");
+ Traverse(*wrapper.QueryTree(), *wrapper.ReferenceTree()); // Dispatches on world.rank(); a non-zero rank ends up in ChildTraverse().
+ Timer::Stop("child_traversal");
+
+ // Now, we have to ship the neighbors and distances back to the master.
+ typename RuleType::MPIResultsWrapper resultsWrapper(rule->Neighbors(),
+ rule->Distances());
+ Log::Info << "Process " << world.rank() << " is sending results.\n";
+ Timer::Start("send_results");
+ world.send(0, 0, resultsWrapper); // Results go back to rank 0 with tag 0.
+ Timer::Stop("send_results");
+ Log::Info << "Process " << world.rank() << " is finished.\n";
+}
+
+template<typename RuleType>
+template<typename TreeType>
+void DistributedBinaryTraversal<RuleType>::Traverse(const size_t queryIndex,
+ TreeType& referenceNode)
+{
+ // NOTE(review): single-point traversal is not implemented; the body is empty and both parameters are unused.
+}
+
+template<typename RuleType>
+template<typename TreeType>
+void DistributedBinaryTraversal<RuleType>::Traverse(TreeType& queryNode,
+ TreeType& referenceNode)
+{
+ // Dispatch on the MPI rank of this process: rank 0 is the master and runs
+ // the master traversal; every other rank runs the child traversal.
+ if (world.rank() == 0)
+ {
+ // Start the traversal, and pass the work to the children.
+ MasterTraverse(queryNode, referenceNode);
+ }
+ else
+ {
+ ChildTraverse(queryNode, referenceNode); // Plain dual-tree traversal of this node pair on the local process.
+ }
+}
+
+template<typename RuleType>
+template<typename TreeType>
+void DistributedBinaryTraversal<RuleType>::MasterTraverse(
+ TreeType& queryNode,
+ TreeType& referenceNode)
+{
+ // A list of jobs to be done; each job is a (query node, reference node) pair.
+ std::queue<std::pair<TreeType*, TreeType*>> jobs;
+ jobs.push(&queryNode, &referenceNode); // NOTE(review): std::queue::push takes one value; this needs a std::pair (or emplace) to compile.
+
+ // A list of which nodes are busy and which aren't. NOTE(review): never read below.
+ std::vector<bool> busy(world.size() - 1, false);
+
+ while (!jobs.empty()) // NOTE(review): nothing in the loop pops from jobs, so this condition cannot become false.
+ {
+ // Find an unused worker (wait for a response).
+ RuleType::MPIResultType result; // NOTE(review): dependent type name; needs a leading `typename`.
+ boost::mpi::status status;
+ status = communicator.recv(boost::mpi::any_source, boost::mpi::any_tag, // NOTE(review): `communicator` is not declared here; the rest of this file uses the member `world`. The child constructor also starts with a blocking recv, so confirm which side sends first.
+ result);
+
+ // Immediately put that worker back to work on a new job.
+ RuleType::MPIWorkType work(queue.front().first, queue.front().second); // NOTE(review): `queue` is not a variable in scope (presumably `jobs` was meant); also needs `typename`.
+ communicator.send(status.source(), 0 /* zero tag */, work); // NOTE(review): same undeclared `communicator` as above.
+
+ if (result.tag() == 1) // Initialization tag; no data. NOTE(review): the body below reads data from result, so this test looks inverted; confirm.
+ {
+ // Now, look through the results to add new jobs.
+ const RuleType::MPIWorkType& job = jobs[status.source()]; // NOTE(review): std::queue has no operator[]; also needs `typename`.
+ for (size_t i = 0; i < result.NumNewTasks(); ++i)
+ jobs.push_back(result.NewTask(job.QueryNode(), job.ReferenceNode())); // NOTE(review): std::queue has no push_back; `i` is not used to select a task.
+
+ // And merge the results into the tree that we have.
+ for (size_t i = 0; i < result.NumTreeUpdates(); ++i)
+ result.MergeResult(job.QueryNode(), job.ReferenceNode()); // NOTE(review): `i` is unused, so every iteration makes the same call.
+ }
+ }
+}
+
+template<typename RuleType>
+template<typename TreeType>
+void DistributedBinaryTraversal<RuleType>::ChildTraverse(
+ TreeType& queryNode,
+ TreeType& referenceNode)
+{
+ // We'll just call out to the standard dual-tree traversal for a single node.
+ typename TreeType::template DualTreeTraverser<RuleType> traverser(*rule); // Dereferences rule, so it must be non-null by this point.
+
+ traverser.Traverse(queryNode, referenceNode);
+}
+
+template<typename RuleType>
+template<typename TreeType>
+size_t DistributedBinaryTraversal<RuleType>::GetTarget(
+ TreeType& queryNode,
+ TreeType& referenceNode) const
+{
+ // We assemble the ID of the target process in a bitwise manner. The leftmost
+ // combination maps to process 0. At any level of recursion, because this is
+ // a binary recursion, the query node may be either the left (L) child or the
+ // right (R) child, and the same applies to the reference node. Thus the
+ // direction we have gone at a recursion can have four possibilities: LL, LR,
+ // RL, and RR. Take L = 0 and R = 1; now a single recursion can be
+ // represented as two bits. The highest-level recursion will be the two most
+ // significant bits and the most recent recursion will be the two least
+ // significant bits. Thus, if the most recent recursion was RL and the
+ // higher-level recursion was LR, and there were no higher recursions than
+ // that, the index will be LRRL -> 0110 -> 6. If any recursion was not a dual
+ // recursion, undefined behavior will happen. It probably won't crash.
+ size_t index = 0;
+
+ TreeType* currentQuery = &queryNode;
+ TreeType* currentRef = &referenceNode;
+ size_t level = 0; // Number of parent links followed so far; level 0 is the most recent recursion.
+ while (currentQuery->Parent() != NULL && currentRef->Parent() != NULL) // Walks both nodes upward in lockstep; stops as soon as either has no parent.
+ {
+ // Assemble this index: the query side is the high bit, the reference side the low bit.
+ size_t currentIndex = 0; // Assume LL, change if otherwise.
+ if (currentQuery->Parent()->Right() == currentQuery)
+ currentIndex += 2; // Now it's RL.
+ if (currentRef->Parent()->Right() == currentRef)
+ currentIndex++; // Now it's LR or RR.
+
+ // Append this index.
+ index += (currentIndex << (level * 2)); // NOTE(review): the shift count grows by two per level with no bound; confirm tree depth keeps it below the width of size_t.
+ ++level;
+
+ currentQuery = currentQuery->Parent();
+ currentRef = currentRef->Parent();
+ }
+
+ return index + 1; // Index 0 is the root. NOTE(review): the root pair itself (loop never runs) and any all-left path of any depth also return 1, so results are not unique across depths.
+}
+
+} // namespace tree
+} // namespace mlpack
+
+#endif
diff --git a/src/mlpack/core/tree/binary_space_tree/mean_split.hpp b/src/mlpack/core/tree/binary_space_tree/mean_split.hpp
index 7d7e67a..ecad1ba 100644
--- a/src/mlpack/core/tree/binary_space_tree/mean_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/mean_split.hpp
@@ -10,7 +10,7 @@
#define MLPACK_CORE_TREE_BINARY_SPACE_TREE_MEAN_SPLIT_HPP
#include <mlpack/core.hpp>
-#include "perform_split.hpp"
+#include <mlpack/core/tree/perform_split.hpp>
namespace mlpack {
namespace tree /** Trees and tree-building procedures. */ {
diff --git a/src/mlpack/core/tree/binary_space_tree/midpoint_split.hpp b/src/mlpack/core/tree/binary_space_tree/midpoint_split.hpp
index f898c42..01d8eeb 100644
--- a/src/mlpack/core/tree/binary_space_tree/midpoint_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/midpoint_split.hpp
@@ -11,7 +11,7 @@
#define MLPACK_CORE_TREE_BINARY_SPACE_TREE_MIDPOINT_SPLIT_HPP
#include <mlpack/core.hpp>
-#include "perform_split.hpp"
+#include <mlpack/core/tree/perform_split.hpp>
namespace mlpack {
namespace tree /** Trees and tree-building procedures. */ {
diff --git a/src/mlpack/core/tree/binary_space_tree/rp_tree_max_split.hpp b/src/mlpack/core/tree/binary_space_tree/rp_tree_max_split.hpp
index 8bd513d..9d01b0d 100644
--- a/src/mlpack/core/tree/binary_space_tree/rp_tree_max_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/rp_tree_max_split.hpp
@@ -9,7 +9,7 @@
#define MLPACK_CORE_TREE_BINARY_SPACE_TREE_RP_TREE_MAX_SPLIT_HPP
#include <mlpack/core.hpp>
-#include "perform_split.hpp"
+#include <mlpack/core/tree/perform_split.hpp>
namespace mlpack {
namespace tree /** Trees and tree-building procedures. */ {
diff --git a/src/mlpack/core/tree/binary_space_tree/rp_tree_mean_split.hpp b/src/mlpack/core/tree/binary_space_tree/rp_tree_mean_split.hpp
index 718b1de..332122b 100644
--- a/src/mlpack/core/tree/binary_space_tree/rp_tree_mean_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/rp_tree_mean_split.hpp
@@ -10,7 +10,7 @@
#include <mlpack/core.hpp>
#include "rp_tree_max_split.hpp"
-#include "perform_split.hpp"
+#include <mlpack/core/tree/perform_split.hpp>
namespace mlpack {
namespace tree /** Trees and tree-building procedures. */ {
diff --git a/src/mlpack/core/tree/binary_space_tree/vantage_point_split.hpp b/src/mlpack/core/tree/binary_space_tree/vantage_point_split.hpp
index c9a3006..5a16c45 100644
--- a/src/mlpack/core/tree/binary_space_tree/vantage_point_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/vantage_point_split.hpp
@@ -9,7 +9,7 @@
#define MLPACK_CORE_TREE_BINARY_SPACE_TREE_VANTAGE_POINT_SPLIT_HPP
#include <mlpack/core.hpp>
-#include "perform_split.hpp"
+#include <mlpack/core/tree/perform_split.hpp>
namespace mlpack {
namespace tree /** Trees and tree-building procedures. */ {
diff --git a/src/mlpack/core/tree/binary_space_tree/perform_split.hpp b/src/mlpack/core/tree/perform_split.hpp
similarity index 90%
rename from src/mlpack/core/tree/binary_space_tree/perform_split.hpp
rename to src/mlpack/core/tree/perform_split.hpp
index acb8c90..ff14028 100644
--- a/src/mlpack/core/tree/binary_space_tree/perform_split.hpp
+++ b/src/mlpack/core/tree/perform_split.hpp
@@ -8,9 +8,8 @@
* of the split column, and points from the right subtree are on the right side
* of the split column.
*/
-
-#ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_PERFORM_SPLIT_HPP
-#define MLPACK_CORE_TREE_BINARY_SPACE_TREE_PERFORM_SPLIT_HPP
+#ifndef MLPACK_CORE_TREE_PERFORM_SPLIT_HPP
+#define MLPACK_CORE_TREE_PERFORM_SPLIT_HPP
namespace mlpack {
namespace tree /** Trees and tree-building procedures. */ {
@@ -41,12 +40,17 @@ size_t PerformSplit(MatType& data,
// First half-iteration of the loop is out here because the termination
// condition is in the middle.
- while (SplitType::AssignToLeftNode(data.col(left), splitInfo) && (left <= right))
+ while ((left <= right) &&
+ (SplitType::AssignToLeftNode(data.col(left), splitInfo)))
left++;
while ((!SplitType::AssignToLeftNode(data.col(right), splitInfo)) &&
(left <= right) && (right > 0))
right--;
+ // Shortcut for when all points are on the right.
+ if (left == right && right == 0)
+ return left;
+
while (left <= right)
{
// Swap columns.
@@ -102,12 +106,17 @@ size_t PerformSplit(MatType& data,
// First half-iteration of the loop is out here because the termination
// condition is in the middle.
- while (SplitType::AssignToLeftNode(data.col(left), splitInfo) && (left <= right))
+ while ((left <= right) &&
+ (SplitType::AssignToLeftNode(data.col(left), splitInfo)))
left++;
while ((!SplitType::AssignToLeftNode(data.col(right), splitInfo)) &&
- (left <= right) && (right > 0))
+ (left <= right) && (right > 0))
right--;
+ // Shortcut for when all points are on the right.
+ if (left == right && right == 0)
+ return left;
+
while (left <= right)
{
// Swap columns.
@@ -135,7 +144,6 @@ size_t PerformSplit(MatType& data,
}
Log::Assert(left == right + 1);
-
return left;
}
More information about the mlpack-git
mailing list