[mlpack-git] master: Refactor tree-building to use PerformSplit. (e3c593c)

gitdub at mlpack.org gitdub at mlpack.org
Thu Sep 29 11:49:53 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/9ef7339d40550a974b3939e9fcb966fac2c09065...ebdb5abeaa3fd621a06ae663862bb72df76d2b40

>---------------------------------------------------------------

commit e3c593c47c3da6945b0b60ce9ce3bab5de9e09d5
Author: Ryan Curtin <ryan at ratml.org>
Date:   Thu Sep 29 11:49:53 2016 -0400

    Refactor tree-building to use PerformSplit.
    
    In some cases (like the LCDM dataset) this can speed up tree building by up to
    three orders of magnitude.


>---------------------------------------------------------------

e3c593c47c3da6945b0b60ce9ce3bab5de9e09d5
 src/mlpack/core/tree/octree/octree.hpp      |  20 +++
 src/mlpack/core/tree/octree/octree_impl.hpp | 205 ++++++++++++++++++----------
 2 files changed, 150 insertions(+), 75 deletions(-)

diff --git a/src/mlpack/core/tree/octree/octree.hpp b/src/mlpack/core/tree/octree/octree.hpp
index 5cdecfc..3402915 100644
--- a/src/mlpack/core/tree/octree/octree.hpp
+++ b/src/mlpack/core/tree/octree/octree.hpp
@@ -419,6 +419,26 @@ class Octree
                  const double width,
                  std::vector<size_t>& oldFromNew,
                  const size_t maxLeafSize);
+
+  /** Split criterion used with split::PerformSplit() when halving a node:
+   * AssignToLeftNode() is true iff point[d] < center[d], so a point lying
+   * exactly on the center in dimension d is assigned to the right. */
+  struct SplitInfo
+  {
+    //! Create the SplitInfo object (c is stored by reference, not copied).
+    SplitInfo(const size_t d, const arma::vec& c) : d(d), center(c) {}
+
+    //! The dimension we are splitting on.
+    size_t d;
+    //! The center of the node; a reference, so it must outlive this object.
+    const arma::vec& center;
+
+    template<typename VecType>
+    static bool AssignToLeftNode(const VecType& point, const SplitInfo& s)
+    {
+      return point[s.d] < s.center[s.d];
+    }
+  };
 };
 
 } // namespace tree
diff --git a/src/mlpack/core/tree/octree/octree_impl.hpp b/src/mlpack/core/tree/octree/octree_impl.hpp
index a366bf8..2e58191 100644
--- a/src/mlpack/core/tree/octree/octree_impl.hpp
+++ b/src/mlpack/core/tree/octree/octree_impl.hpp
@@ -8,6 +8,8 @@
 #define MLPACK_CORE_TREE_OCTREE_OCTREE_IMPL_HPP
 
 #include "octree.hpp"
+#include <mlpack/core/tree/perform_split.hpp>
+#include <stack>
 
 namespace mlpack {
 namespace tree {
@@ -710,48 +712,77 @@ void Octree<MetricType, StatisticType, MatType>::SplitNode(
   if (count <= maxLeafSize)
     return;
 
-  // We must split the dataset by sequentially creating each of the children.
-  // We do this in two steps: first we make a pass to count the number of points
-  // that will fall into each child; then in the second pass we rearrange the
-  // points and create the children.
-  arma::Col<size_t> childCounts(std::pow(2, dataset->n_rows),
-      arma::fill::zeros);
-  arma::Col<size_t> assignments(count, arma::fill::zeros);
-
-  // First pass: calculate number of points in each child, and find child
-  // assignments for each point.
-  for (size_t i = 0; i < count; ++i)
+  // Index of the first point of each child, plus a final end-of-range sentinel.
+  arma::Col<size_t> childBegins(std::pow(2, dataset->n_rows) + 1);
+  childBegins[0] = begin;
+  childBegins[childBegins.n_elem - 1] = begin + count;
+
+  // One level of half-splits per dimension, last dimension first.  Each tuple
+  // is { dim, begin, count, leftChildIndex }.  NOTE(review): include <tuple>.
+  std::stack<std::tuple<size_t, size_t, size_t, size_t>> stack;
+  stack.push(std::tuple<size_t, size_t, size_t, size_t>(dataset->n_rows - 1,
+      begin, count, 0));
+
+  while (!stack.empty())
   {
-    for (size_t d = 0; d < dataset->n_rows; ++d)
+    std::tuple<size_t, size_t, size_t, size_t> t = stack.top();
+    stack.pop();
+
+    const size_t d = std::get<0>(t);
+    const size_t childBegin = std::get<1>(t);
+    const size_t childCount = std::get<2>(t);
+    const size_t leftChildIndex = std::get<3>(t);
+
+    // Perform a "half-split" along dimension d over children [leftChildIndex,
+    // leftChildIndex + 2^(d + 1)): points of children below rightChildIndex
+    // (= leftChildIndex + 2^d) go to the left side, the rest to the right.
+    // Ties presumably go right (SplitInfo::AssignToLeftNode compares with <).
+    SplitInfo s(d, center);
+    const size_t firstRight = split::PerformSplit<MatType, SplitInfo>(*dataset,
+        childBegin, childCount, s);
+
+    // Record where the right half begins.  The left half's first index is
+    // already set (childBegins[0] at the root, otherwise by an earlier split).
+    const size_t rightChildIndex = leftChildIndex + std::pow(2, d);
+    childBegins[rightChildIndex] = firstRight;
+
+    // Unless this was dimension 0, descend into each non-empty half.
+    if (d != 0)
     {
-      // We are guaranteed that the points fall within 'width / 2' of the center
-      // in each dimension, so we just need to check which side of the center
-      // the points fall on.  The last dimension represents the most significant
-      // bit in the assignment; the bit is '1' if it falls to the right of the
-      // center.
-      if ((*dataset)(d, begin + i) > center(d))
-        assignments(i) |= (1 << d);
+      if (firstRight > childBegin)
+      {
+        stack.push(std::tuple<size_t, size_t, size_t, size_t>(d - 1, childBegin,
+            firstRight - childBegin, leftChildIndex));
+      }
+      else
+      {
+        // Empty left half: its remaining children begin where it begins.
+        for (size_t c = leftChildIndex + 1; c < rightChildIndex; ++c)
+          childBegins[c] = childBegins[leftChildIndex];
+      }
+
+      if (firstRight < childBegin + childCount)
+      {
+        stack.push(std::tuple<size_t, size_t, size_t, size_t>(d - 1, firstRight,
+            childCount - (firstRight - childBegin), rightChildIndex));
+      }
+      else
+      {
+        // Empty right half: its remaining children begin where it begins.
+        for (size_t c = rightChildIndex + 1;
+             c < rightChildIndex + (rightChildIndex - leftChildIndex); ++c)
+          childBegins[c] = childBegins[rightChildIndex];
+      }
     }
-
-    childCounts(assignments(i))++;
   }
 
-  // Sort all of the points so we know where to copy them.
-  arma::uvec ordering = arma::stable_sort_index(assignments, "ascend");
-
-  // This strategy may copy the matrix during the computation, but that isn't
-  // really a problem.  We use non-contiguous submatrix views to extract the
-  // columns in the correct order.
-  dataset->cols(begin, begin + count - 1) = dataset->cols(begin + ordering);
-
   // Now that the dataset is reordered, we can create the children.
-  size_t childBegin = begin;
   arma::vec childCenter(center.n_elem);
   const double childWidth = width / 2.0;
-  for (size_t i = 0; i < childCounts.n_elem; ++i)
+  for (size_t i = 0; i < childBegins.n_elem - 1; ++i)
   {
     // If the child has no points, don't create it.
-    if (childCounts[i] == 0)
+    if (childBegins[i + 1] - childBegins[i] == 0)
       continue;
 
     // Create the correct center.
@@ -764,10 +795,9 @@ void Octree<MetricType, StatisticType, MatType>::SplitNode(
         childCenter[d] = center[d] + childWidth;
     }
 
-    children.push_back(new Octree(this, childBegin, childCounts[i], childCenter,
-        childWidth, maxLeafSize));
-
-    childBegin += childCounts[i];
+    children.push_back(new Octree(this, childBegins[i],
+        childBegins[i + 1] - childBegins[i], childCenter, childWidth,
+        maxLeafSize));
   }
 }
 
@@ -784,51 +814,77 @@ void Octree<MetricType, StatisticType, MatType>::SplitNode(
   if (count <= maxLeafSize)
     return;
 
-  // We must split the dataset by sequentially creating each of the children.
-  // We do this in two steps: first we make a pass to count the number of points
-  // that will fall into each child; then in the second pass we rearrange the
-  // points and create the children.
-  arma::Col<size_t> childCounts(std::pow(2, dataset->n_rows),
-      arma::fill::zeros);
-  arma::Col<size_t> assignments(count, arma::fill::zeros);
-
-  // First pass: calculate number of points in each child, and find child
-  // assignments for each point.
-  for (size_t i = 0; i < count; ++i)
+  // Index of the first point of each child, plus a final end-of-range sentinel.
+  arma::Col<size_t> childBegins(std::pow(2, dataset->n_rows) + 1);
+  childBegins[0] = begin;
+  childBegins[childBegins.n_elem - 1] = begin + count;
+
+  // One level of half-splits per dimension, last dimension first.  Each tuple
+  // is { dim, begin, count, leftChildIndex }.  NOTE(review): include <tuple>.
+  std::stack<std::tuple<size_t, size_t, size_t, size_t>> stack;
+  stack.push(std::tuple<size_t, size_t, size_t, size_t>(dataset->n_rows - 1,
+      begin, count, 0));
+
+  while (!stack.empty())
   {
-    for (size_t d = 0; d < dataset->n_rows; ++d)
+    std::tuple<size_t, size_t, size_t, size_t> t = stack.top();
+    stack.pop();
+
+    const size_t d = std::get<0>(t);
+    const size_t childBegin = std::get<1>(t);
+    const size_t childCount = std::get<2>(t);
+    const size_t leftChildIndex = std::get<3>(t);
+
+    // Perform a "half-split" along dimension d over children [leftChildIndex,
+    // leftChildIndex + 2^(d + 1)): points of children below rightChildIndex
+    // (= leftChildIndex + 2^d) go to the left side, the rest to the right.
+    // Ties presumably go right (SplitInfo::AssignToLeftNode compares with <).
+    SplitInfo s(d, center);
+    const size_t firstRight = split::PerformSplit<MatType, SplitInfo>(*dataset,
+        childBegin, childCount, s, oldFromNew);
+
+    // Record where the right half begins.  The left half's first index is
+    // already set (childBegins[0] at the root, otherwise by an earlier split).
+    const size_t rightChildIndex = leftChildIndex + std::pow(2, d);
+    childBegins[rightChildIndex] = firstRight;
+
+    // Unless this was dimension 0, descend into each non-empty half.
+    if (d != 0)
     {
-      // We are guaranteed that the points fall within 'width / 2' of the center
-      // in each dimension, so we just need to check which side of the center
-      // the points fall on.  The last dimension represents the most significant
-      // bit in the assignment; the bit is '1' if it falls to the right of the
-      // center.
-      if ((*dataset)(d, begin + i) > center(d))
-        assignments(i) |= (1 << d);
+      if (firstRight > childBegin)
+      {
+        stack.push(std::tuple<size_t, size_t, size_t, size_t>(d - 1, childBegin,
+            firstRight - childBegin, leftChildIndex));
+      }
+      else
+      {
+        // Empty left half: its remaining children begin where it begins.
+        for (size_t c = leftChildIndex + 1; c < rightChildIndex; ++c)
+          childBegins[c] = childBegins[leftChildIndex];
+      }
+
+      if (firstRight < childBegin + childCount)
+      {
+        stack.push(std::tuple<size_t, size_t, size_t, size_t>(d - 1, firstRight,
+            childCount - (firstRight - childBegin), rightChildIndex));
+      }
+      else
+      {
+        // Empty right half: its remaining children begin where it begins.
+        for (size_t c = rightChildIndex + 1;
+             c < rightChildIndex + (rightChildIndex - leftChildIndex); ++c)
+          childBegins[c] = childBegins[rightChildIndex];
+      }
     }
-
-    childCounts(assignments(i))++;
   }
 
-  // Sort all of the points so we know where to copy them.
-  arma::uvec ordering = arma::stable_sort_index(assignments, "ascend");
-
-  // This strategy may copy the matrix during the computation, but that isn't
-  // really a problem.  We use non-contiguous submatrix views to extract the
-  // columns in the correct order.
-  dataset->cols(begin, begin + count - 1) = dataset->cols(begin + ordering);
-  std::vector<size_t> oldFromNewCopy(oldFromNew); // We need the old indices.
-  for (size_t i = 0; i < count; ++i)
-    oldFromNew[i + begin] = oldFromNewCopy[ordering[i] + begin];
-
   // Now that the dataset is reordered, we can create the children.
-  size_t childBegin = begin;
   arma::vec childCenter(center.n_elem);
   const double childWidth = width / 2.0;
-  for (size_t i = 0; i < childCounts.n_elem; ++i)
+  for (size_t i = 0; i < childBegins.n_elem - 1; ++i)
   {
     // If the child has no points, don't create it.
-    if (childCounts[i] == 0)
+    if (childBegins[i + 1] - childBegins[i] == 0)
       continue;
 
     // Create the correct center.
@@ -841,10 +897,9 @@ void Octree<MetricType, StatisticType, MatType>::SplitNode(
         childCenter[d] = center[d] + childWidth;
     }
 
-    children.push_back(new Octree(this, childBegin, childCounts[i], oldFromNew,
-        childCenter, childWidth, maxLeafSize));
-
-    childBegin += childCounts[i];
+    children.push_back(new Octree(this, childBegins[i],
+        childBegins[i + 1] - childBegins[i], oldFromNew, childCenter,
+        childWidth, maxLeafSize));
   }
 }
 




More information about the mlpack-git mailing list