[mlpack-git] master: Removed an overload of PerformSplit() for the UB tree. Added a function that implements the default binary split behaviour. Other minor fixes. (9bcd066)

gitdub at mlpack.org gitdub at mlpack.org
Fri Aug 26 15:24:08 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/1797a49c8f76d65814fec4a122d0d2fea01fc2d9...9e5cd0ac9c5cde9ac141bc84e7327bd11e19d42e

>---------------------------------------------------------------

commit 9bcd066e7a32021fd36778398703bf87552e5c31
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date:   Fri Aug 26 22:24:08 2016 +0300

    Removed an overload of PerformSplit() for the UB tree. Added a function that implements the default binary split behaviour. Other minor fixes.


>---------------------------------------------------------------

9bcd066e7a32021fd36778398703bf87552e5c31
 src/mlpack/core/tree/CMakeLists.txt                |   1 +
 src/mlpack/core/tree/address.hpp                   |  30 ++--
 .../tree/binary_space_tree/binary_space_tree.hpp   |  80 -----------
 .../binary_space_tree/binary_space_tree_impl.hpp   | 156 +--------------------
 .../core/tree/binary_space_tree/mean_split.hpp     |  51 ++++++-
 .../core/tree/binary_space_tree/midpoint_split.hpp |  51 ++++++-
 .../tree/binary_space_tree/rp_tree_max_split.hpp   |  50 ++++++-
 .../tree/binary_space_tree/rp_tree_mean_split.hpp  |  50 ++++++-
 .../core/tree/binary_space_tree/ub_tree_split.hpp  |  10 +-
 .../tree/binary_space_tree/ub_tree_split_impl.hpp  |  10 +-
 .../tree/binary_space_tree/vantage_point_split.hpp |  44 ++++++
 src/mlpack/core/tree/cellbound.hpp                 |   7 +-
 src/mlpack/core/tree/cellbound_impl.hpp            |   2 +-
 13 files changed, 254 insertions(+), 288 deletions(-)

diff --git a/src/mlpack/core/tree/CMakeLists.txt b/src/mlpack/core/tree/CMakeLists.txt
index 3b05381..0e3c998 100644
--- a/src/mlpack/core/tree/CMakeLists.txt
+++ b/src/mlpack/core/tree/CMakeLists.txt
@@ -15,6 +15,7 @@ set(SOURCES
   binary_space_tree/mean_split_impl.hpp
   binary_space_tree/midpoint_split.hpp
   binary_space_tree/midpoint_split_impl.hpp
+  binary_space_tree/perform_split.hpp
   binary_space_tree/rp_tree_max_split.hpp
   binary_space_tree/rp_tree_max_split_impl.hpp
   binary_space_tree/rp_tree_mean_split.hpp
diff --git a/src/mlpack/core/tree/address.hpp b/src/mlpack/core/tree/address.hpp
index 4d4520c..72202e3 100644
--- a/src/mlpack/core/tree/address.hpp
+++ b/src/mlpack/core/tree/address.hpp
@@ -37,6 +37,16 @@ namespace addr {
  * variables should be equal-sized and the type of the address should correspond
  * to the type of the vector.
  *
+ * The function maps each floating point coordinate to an equal-sized unsigned
+ * integer datatype in such a way that the transform preserves the ordering
+ * (i.e. lower coordinates correspond to lower integers). To achieve this, the
+ * mapping stores the exponent and then the mantissa of each floating point
+ * value, so that the exponent occupies the more significant bits. For
+ * negative numbers the resulting integer value is inverted.
+ * In the multi-dimensional case, after the representation is transformed,
+ * the bits of the new representation are interleaved across all the elements
+ * in the address vector.
+ *
  * @param address The resulting address.
  * @param point The point that is being translated to the address.
  */
@@ -122,6 +132,8 @@ void PointToAddress(AddressType& address, const VecType& point)
  * variables should be equal-sized and the type of the address should correspond
  * to the type of the vector.
  *
+ * This function performs the inverse of the transform done by PointToAddress().
+ *
  * @param address An address to translate.
  * @param point The point that corresponds to the address.
  */
@@ -201,9 +213,9 @@ void AddressToPoint(VecType& point, const AddressType& address)
 template<typename AddressType1, typename AddressType2>
 int CompareAddresses(const AddressType1& addr1, const AddressType2& addr2)
 {
-  static_assert(sizeof(typename AddressType1::elem_type) ==
-      sizeof(typename AddressType2::elem_type), "We aren't able to compare "
-      "adresses of distinct sizes");
+  static_assert(std::is_same<typename AddressType1::elem_type,
+      typename AddressType2::elem_type>::value == true, "We aren't able to "
+      "compare adresses of distinct types");
 
   assert(addr1.n_elem == addr2.n_elem);
 
@@ -225,13 +237,13 @@ template<typename AddressType1, typename AddressType2, typename AddressType3>
 bool Contains(const AddressType1& address, const AddressType2& loBound,
                      const AddressType3& hiBound)
 {
-  static_assert(sizeof(typename AddressType1::elem_type) ==
-      sizeof(typename AddressType2::elem_type), "We aren't able to compare "
-      "adresses of distinct sizes");
+  static_assert(std::is_same<typename AddressType1::elem_type,
+      typename AddressType2::elem_type>::value == true, "We aren't able to "
+      "compare adresses of distinct types");
 
-  static_assert(sizeof(typename AddressType1::elem_type) ==
-      sizeof(typename AddressType3::elem_type), "We aren't able to compare "
-      "adresses of distinct sizes");
+  static_assert(std::is_same<typename AddressType1::elem_type,
+      typename AddressType3::elem_type>::value == true, "We aren't able to "
+      "compare adresses of distinct types");
 
   assert(address.n_elem == loBound.n_elem);
   assert(address.n_elem == hiBound.n_elem);
diff --git a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
index bcbbbac..b024ee5 100644
--- a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
@@ -478,86 +478,6 @@ class BinarySpaceTree
                  SplitType<BoundType<MetricType>, MatType>& splitter);
 
   /**
-   * Perform the split process according to the information about the
-   * split.
-   *
-   * @param bound The bound used for this node.
-   * @param data The dataset used by the binary space tree.
-   * @param begin Index of the starting point in the dataset that belongs to
-   *    this node.
-   * @param count Number of points in this node.
-   * @param splitInfo The information about the split.
-   */
-  template<typename SplitInfo>
-  size_t PerformSplit(MatType& data,
-                      const size_t begin,
-                      const size_t count,
-                      const SplitInfo& splitInfo);
-
-  /**
-   * Perform the split process according to the information about the split and
-   * return a list of changed indices.
-   *
-   * @param bound The bound used for this node.
-   * @param data The dataset used by the binary space tree.
-   * @param begin Index of the starting point in the dataset that belongs to
-   *    this node.
-   * @param count Number of points in this node.
-   * @param splitInfo The information about the split.
-   * @param oldFromNew Vector which will be filled with the old positions for
-   *    each new point.
-   */
-  template<typename SplitInfo>
-  size_t PerformSplit(MatType& data,
-                      const size_t begin,
-                      const size_t count,
-                      const SplitInfo& splitInfo,
-                      std::vector<size_t>& oldFromNew);
-
-  /**
-   * An overload for the universal B tree. For the first time the function
-   * rearranges the whole dataset. Next time the function only returns the split
-   * column.
-   *
-   * @param bound The bound used for this node.
-   * @param data The dataset used by the binary space tree.
-   * @param begin Index of the starting point in the dataset that belongs to
-   *    this node.
-   * @param count Number of points in this node.
-   * @param splitInfo The information about the split.
-   * @param oldFromNew Vector which will be filled with the old positions for
-   *    each new point.
-   */
-  size_t PerformSplit(
-      MatType& data,
-      const size_t begin,
-      const size_t count,
-      const typename UBTreeSplit<BoundType<MetricType>,
-                                 MatType>::SplitInfo& splitInfo);
-
-  /**
-   * An overload for the universal B tree. For the first time the function
-   * rearranges the whole dataset. Next time the function only returns the split
-   * column.
-   *
-   * @param bound The bound used for this node.
-   * @param data The dataset used by the binary space tree.
-   * @param begin Index of the starting point in the dataset that belongs to
-   *    this node.
-   * @param count Number of points in this node.
-   * @param splitInfo The information about the split.
-   * @param oldFromNew Vector which will be filled with the old positions for
-   *    each new point.
-   */
-  size_t PerformSplit(
-      MatType& data,
-      const size_t begin,
-      const size_t count,
-      const typename UBTreeSplit<BoundType<MetricType>,
-                                 MatType>::SplitInfo& splitInfo,
-      std::vector<size_t>& oldFromNew);
-
-  /**
    * Update the bound of the current node. This method does not take into
    * account bound-specific properties.
    *
diff --git a/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp b/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
index f2d96e7..7301a97 100644
--- a/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
@@ -667,7 +667,7 @@ void BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
   // Perform the actual splitting.  This will order the dataset such that
   // points that belong to the left subtree are on the left of splitCol, and
   // points from the right subtree are on the right side of splitCol.
-  splitCol = PerformSplit(*dataset, begin, count, splitInfo);
+  splitCol = splitter.PerformSplit(*dataset, begin, count, splitInfo);
 
   assert(splitCol > begin);
   assert(splitCol < begin + count);
@@ -733,7 +733,8 @@ SplitNode(std::vector<size_t>& oldFromNew,
   // Perform the actual splitting.  This will order the dataset such that
   // points that belong to the left subtree are on the left of splitCol, and
   // points from the right subtree are on the right side of splitCol.
-  splitCol = PerformSplit(*dataset, begin, count, splitInfo, oldFromNew);
+  splitCol = splitter.PerformSplit(*dataset, begin, count, splitInfo,
+      oldFromNew);
 
   assert(splitCol > begin);
   assert(splitCol < begin + count);
@@ -765,157 +766,6 @@ template<typename MetricType,
          template<typename BoundMetricType, typename...> class BoundType,
          template<typename SplitBoundType, typename SplitMatType>
              class SplitType>
-template<typename SplitInfo>
-size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
-    SplitType>::PerformSplit(MatType& data,
-                             const size_t begin,
-                             const size_t count,
-                             const SplitInfo& splitInfo)
-{
-  // This method modifies the input dataset.  We loop both from the left and
-  // right sides of the points contained in this node.  The points less than
-  // splitVal should be on the left side of the matrix, and the points greater
-  // than splitVal should be on the right side of the matrix.
-  size_t left = begin;
-  size_t right = begin + count - 1;
-
-  // First half-iteration of the loop is out here because the termination
-  // condition is in the middle.
-  while (Split::AssignToLeftNode(data.col(left), splitInfo) && (left <= right))
-    left++;
-  while ((!Split::AssignToLeftNode(data.col(right), splitInfo)) &&
-      (left <= right) && (right > 0))
-    right--;
-
-  while (left <= right)
-  {
-    // Swap columns.
-    data.swap_cols(left, right);
-
-    // See how many points on the left are correct.  When they are correct,
-    // increase the left counter accordingly.  When we encounter one that isn't
-    // correct, stop.  We will switch it later.
-    while (Split::AssignToLeftNode(data.col(left), splitInfo) &&
-        (left <= right))
-      left++;
-
-    // Now see how many points on the right are correct.  When they are correct,
-    // decrease the right counter accordingly.  When we encounter one that isn't
-    // correct, stop.  We will switch it with the wrong point we found in the
-    // previous loop.
-    while ((!Split::AssignToLeftNode(data.col(right), splitInfo)) &&
-        (left <= right))
-      right--;
-  }
-
-  Log::Assert(left == right + 1);
-
-  return left;
-}
-
-template<typename MetricType,
-         typename StatisticType,
-         typename MatType,
-         template<typename BoundMetricType, typename...> class BoundType,
-         template<typename SplitBoundType, typename SplitMatType>
-             class SplitType>
-template<typename SplitInfo>
-size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
-    SplitType>::PerformSplit(MatType& data,
-                             const size_t begin,
-                             const size_t count,
-                             const SplitInfo& splitInfo,
-                             std::vector<size_t>& oldFromNew)
-{
-  // This method modifies the input dataset.  We loop both from the left and
-  // right sides of the points contained in this node.  The points less than
-  // splitVal should be on the left side of the matrix, and the points greater
-  // than splitVal should be on the right side of the matrix.
-  size_t left = begin;
-  size_t right = begin + count - 1;
-
-  // First half-iteration of the loop is out here because the termination
-  // condition is in the middle.
-  while (Split::AssignToLeftNode(data.col(left), splitInfo) && (left <= right))
-    left++;
-  while ((!Split::AssignToLeftNode(data.col(right), splitInfo)) &&
-      (left <= right) && (right > 0))
-    right--;
-
-  while (left <= right)
-  {
-    // Swap columns.
-    data.swap_cols(left, right);
-
-    // Update the indices for what we changed.
-    size_t t = oldFromNew[left];
-    oldFromNew[left] = oldFromNew[right];
-    oldFromNew[right] = t;
-
-    // See how many points on the left are correct.  When they are correct,
-    // increase the left counter accordingly.  When we encounter one that isn't
-    // correct, stop.  We will switch it later.
-    while (Split::AssignToLeftNode(data.col(left), splitInfo) &&
-        (left <= right))
-      left++;
-
-    // Now see how many points on the right are correct.  When they are correct,
-    // decrease the right counter accordingly.  When we encounter one that isn't
-    // correct, stop.  We will switch it with the wrong point we found in the
-    // previous loop.
-    while ((!Split::AssignToLeftNode(data.col(right), splitInfo)) &&
-        (left <= right))
-      right--;
-  }
-
-  Log::Assert(left == right + 1);
-
-  return left;
-}
-
-template<typename MetricType,
-         typename StatisticType,
-         typename MatType,
-         template<typename BoundMetricType, typename...> class BoundType,
-         template<typename SplitBoundType, typename SplitMatType>
-             class SplitType>
-size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
-    SplitType>::PerformSplit(
-    MatType& data,
-    const size_t begin,
-    const size_t count,
-    const typename UBTreeSplit<BoundType<MetricType>,
-                               MatType>::SplitInfo& splitInfo)
-{
-  return SplitType<BoundType<MetricType>, MatType>::PerformSplit(data, begin,
-      count, splitInfo);
-}
-
-template<typename MetricType,
-         typename StatisticType,
-         typename MatType,
-         template<typename BoundMetricType, typename...> class BoundType,
-         template<typename SplitBoundType, typename SplitMatType>
-             class SplitType>
-size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
-    SplitType>::PerformSplit(
-    MatType& data,
-    const size_t begin,
-    const size_t count,
-    const typename UBTreeSplit<BoundType<MetricType>,
-                               MatType>::SplitInfo& splitInfo,
-    std::vector<size_t>& oldFromNew)
-{
-  return SplitType<BoundType<MetricType>, MatType>::PerformSplit(data, begin,
-      count, splitInfo, oldFromNew);
-}
-
-template<typename MetricType,
-         typename StatisticType,
-         typename MatType,
-         template<typename BoundMetricType, typename...> class BoundType,
-         template<typename SplitBoundType, typename SplitMatType>
-             class SplitType>
 template<typename BoundType2>
 void BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
 UpdateBound(BoundType2& boundToUpdate)
diff --git a/src/mlpack/core/tree/binary_space_tree/mean_split.hpp b/src/mlpack/core/tree/binary_space_tree/mean_split.hpp
index d3979d8..4c07627 100644
--- a/src/mlpack/core/tree/binary_space_tree/mean_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/mean_split.hpp
@@ -10,6 +10,7 @@
 #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_MEAN_SPLIT_HPP
 
 #include <mlpack/core.hpp>
+#include "perform_split.hpp"
 
 namespace mlpack {
 namespace tree /** Trees and tree-building procedures. */ {
@@ -23,13 +24,6 @@ template<typename BoundType, typename MatType = arma::mat>
 class MeanSplit
 {
  public:
-  /**
-   * Indicates that this class does not perform the actual splitting i.e.
-   * it does not reorder the dataset. If this variable is false, the class
-   * finds the partition and reorders the dataset.
-   */
-  static constexpr bool NeedRearrangeDataset = true;
-
   //! An information about the partition.
   struct SplitInfo
   {
@@ -59,6 +53,49 @@ class MeanSplit
                         SplitInfo& splitInfo);
 
   /**
+   * Perform the split process according to the information about the
+   * split.
+   *
+   * @param data The dataset used by the binary space tree; the points of this
+   *    node may be reordered.
+   * @param begin Index of the starting point in the dataset that belongs to
+   *    this node.
+   * @param count Number of points in this node.
+   * @param splitInfo The information about the split.
+   */
+  static size_t PerformSplit(MatType& data,
+                             const size_t begin,
+                             const size_t count,
+                             const SplitInfo& splitInfo)
+  {
+    return split::PerformSplit<MatType, MeanSplit>(data, begin, count,
+        splitInfo);
+  }
+
+  /**
+   * Perform the split process according to the information about the split and
+   * record the changed indices in oldFromNew.
+   *
+   * @param data The dataset used by the binary space tree; the points of this
+   *    node may be reordered.
+   * @param begin Index of the starting point in the dataset that belongs to
+   *    this node.
+   * @param count Number of points in this node.
+   * @param splitInfo The information about the split.
+   * @param oldFromNew Vector which will be filled with the old positions for
+   *    each new point.
+   */
+  static size_t PerformSplit(MatType& data,
+                             const size_t begin,
+                             const size_t count,
+                             const SplitInfo& splitInfo,
+                             std::vector<size_t>& oldFromNew)
+  {
+    return split::PerformSplit<MatType, MeanSplit>(data, begin, count,
+        splitInfo, oldFromNew);
+  }
+
+  /**
    * Indicates that a point should be assigned to the left subtree.
    *
    * @param point The point that is being assigned.
diff --git a/src/mlpack/core/tree/binary_space_tree/midpoint_split.hpp b/src/mlpack/core/tree/binary_space_tree/midpoint_split.hpp
index 522d298..66c3043 100644
--- a/src/mlpack/core/tree/binary_space_tree/midpoint_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/midpoint_split.hpp
@@ -11,6 +11,7 @@
 #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_MIDPOINT_SPLIT_HPP
 
 #include <mlpack/core.hpp>
+#include "perform_split.hpp"
 
 namespace mlpack {
 namespace tree /** Trees and tree-building procedures. */ {
@@ -24,13 +25,6 @@ template<typename BoundType, typename MatType = arma::mat>
 class MidpointSplit
 {
  public:
-  /**
-   * Indicates that this class does not perform the actual splitting i.e.
-   * it does not reorder the dataset. If this variable is false, the class
-   * finds the partition and reorders the dataset.
-   */
-  static constexpr bool NeedRearrangeDataset = true;
-
   //! A struct that contains an information about the split.
   struct SplitInfo
   {
@@ -59,6 +53,49 @@ class MidpointSplit
                         SplitInfo& splitInfo);
 
   /**
+   * Perform the split process according to the information about the
+   * split.
+   *
+   * @param data The dataset used by the binary space tree; the points of this
+   *    node may be reordered.
+   * @param begin Index of the starting point in the dataset that belongs to
+   *    this node.
+   * @param count Number of points in this node.
+   * @param splitInfo The information about the split.
+   */
+  static size_t PerformSplit(MatType& data,
+                             const size_t begin,
+                             const size_t count,
+                             const SplitInfo& splitInfo)
+  {
+    return split::PerformSplit<MatType, MidpointSplit>(data, begin, count,
+        splitInfo);
+  }
+
+  /**
+   * Perform the split process according to the information about the split and
+   * record the changed indices in oldFromNew.
+   *
+   * @param data The dataset used by the binary space tree; the points of this
+   *    node may be reordered.
+   * @param begin Index of the starting point in the dataset that belongs to
+   *    this node.
+   * @param count Number of points in this node.
+   * @param splitInfo The information about the split.
+   * @param oldFromNew Vector which will be filled with the old positions for
+   *    each new point.
+   */
+  static size_t PerformSplit(MatType& data,
+                             const size_t begin,
+                             const size_t count,
+                             const SplitInfo& splitInfo,
+                             std::vector<size_t>& oldFromNew)
+  {
+    return split::PerformSplit<MatType, MidpointSplit>(data, begin, count,
+        splitInfo, oldFromNew);
+  }
+
+  /**
    * Indicates that a point should be assigned to the left subtree.
    *
    * @param point The point that is being assigned.
diff --git a/src/mlpack/core/tree/binary_space_tree/rp_tree_max_split.hpp b/src/mlpack/core/tree/binary_space_tree/rp_tree_max_split.hpp
index f49e2a2..d556cb5 100644
--- a/src/mlpack/core/tree/binary_space_tree/rp_tree_max_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/rp_tree_max_split.hpp
@@ -9,6 +9,7 @@
 #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_RP_TREE_MAX_SPLIT_HPP
 
 #include <mlpack/core.hpp>
+#include "perform_split.hpp"
 
 namespace mlpack {
 namespace tree /** Trees and tree-building procedures. */ {
@@ -28,12 +29,6 @@ class RPTreeMaxSplit
  public:
   //! The element type held by the matrix type.
   typedef typename MatType::elem_type ElemType;
-  /**
-   * Indicates that this class does not perform the actual splitting i.e.
-   * it does not reorder the dataset. If this variable is false, the class
-   * finds the partition and reorders the dataset.
-   */
-  static constexpr bool NeedRearrangeDataset = true;
   //! An information about the partition.
   struct SplitInfo
   {
@@ -61,6 +56,49 @@ class RPTreeMaxSplit
                         SplitInfo& splitInfo);
 
   /**
+   * Perform the split process according to the information about the
+   * split.
+   *
+   * @param data The dataset used by the binary space tree; the points of this
+   *    node may be reordered.
+   * @param begin Index of the starting point in the dataset that belongs to
+   *    this node.
+   * @param count Number of points in this node.
+   * @param splitInfo The information about the split.
+   */
+  static size_t PerformSplit(MatType& data,
+                             const size_t begin,
+                             const size_t count,
+                             const SplitInfo& splitInfo)
+  {
+    return split::PerformSplit<MatType, RPTreeMaxSplit>(data, begin, count,
+        splitInfo);
+  }
+
+  /**
+   * Perform the split process according to the information about the split and
+   * record the changed indices in oldFromNew.
+   *
+   * @param data The dataset used by the binary space tree; the points of this
+   *    node may be reordered.
+   * @param begin Index of the starting point in the dataset that belongs to
+   *    this node.
+   * @param count Number of points in this node.
+   * @param splitInfo The information about the split.
+   * @param oldFromNew Vector which will be filled with the old positions for
+   *    each new point.
+   */
+  static size_t PerformSplit(MatType& data,
+                             const size_t begin,
+                             const size_t count,
+                             const SplitInfo& splitInfo,
+                             std::vector<size_t>& oldFromNew)
+  {
+    return split::PerformSplit<MatType, RPTreeMaxSplit>(data, begin, count,
+        splitInfo, oldFromNew);
+  }
+
+  /**
    * Indicates that a point should be assigned to the left subtree.
    *
    * @param point The point that is being assigned.
diff --git a/src/mlpack/core/tree/binary_space_tree/rp_tree_mean_split.hpp b/src/mlpack/core/tree/binary_space_tree/rp_tree_mean_split.hpp
index f54645c..64bb10b 100644
--- a/src/mlpack/core/tree/binary_space_tree/rp_tree_mean_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/rp_tree_mean_split.hpp
@@ -10,6 +10,7 @@
 
 #include <mlpack/core.hpp>
 #include "rp_tree_max_split.hpp"
+#include "perform_split.hpp"
 
 namespace mlpack {
 namespace tree /** Trees and tree-building procedures. */ {
@@ -28,12 +29,6 @@ class RPTreeMeanSplit
  public:
   //! The element type held by the matrix type.
   typedef typename MatType::elem_type ElemType;
-  /**
-   * Indicates that this class does not perform the actual splitting i.e.
-   * it does not reorder the dataset. If this variable is false, the class
-   * finds the partition and reorders the dataset.
-   */
-  static constexpr bool NeedRearrangeDataset = true;
   //! An information about the partition.
   struct SplitInfo
   {
@@ -67,6 +62,49 @@ class RPTreeMeanSplit
                         SplitInfo& splitInfo);
 
   /**
+   * Perform the split process according to the information about the
+   * split.
+   *
+   * @param data The dataset used by the binary space tree; the points of this
+   *    node may be reordered.
+   * @param begin Index of the starting point in the dataset that belongs to
+   *    this node.
+   * @param count Number of points in this node.
+   * @param splitInfo The information about the split.
+   */
+  static size_t PerformSplit(MatType& data,
+                             const size_t begin,
+                             const size_t count,
+                             const SplitInfo& splitInfo)
+  {
+    return split::PerformSplit<MatType, RPTreeMeanSplit>(data, begin, count,
+        splitInfo);
+  }
+
+  /**
+   * Perform the split process according to the information about the split and
+   * record the changed indices in oldFromNew.
+   *
+   * @param data The dataset used by the binary space tree; the points of this
+   *    node may be reordered.
+   * @param begin Index of the starting point in the dataset that belongs to
+   *    this node.
+   * @param count Number of points in this node.
+   * @param splitInfo The information about the split.
+   * @param oldFromNew Vector which will be filled with the old positions for
+   *    each new point.
+   */
+  static size_t PerformSplit(MatType& data,
+                             const size_t begin,
+                             const size_t count,
+                             const SplitInfo& splitInfo,
+                             std::vector<size_t>& oldFromNew)
+  {
+    return split::PerformSplit<MatType, RPTreeMeanSplit>(data, begin, count,
+        splitInfo, oldFromNew);
+  }
+
+  /**
    * Indicates that a point should be assigned to the left subtree.
    *
    * @param point The point that is being assigned.
diff --git a/src/mlpack/core/tree/binary_space_tree/ub_tree_split.hpp b/src/mlpack/core/tree/binary_space_tree/ub_tree_split.hpp
index f7291ca..c0f71b7 100644
--- a/src/mlpack/core/tree/binary_space_tree/ub_tree_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/ub_tree_split.hpp
@@ -23,13 +23,6 @@ class UBTreeSplit
                                     uint32_t,
                                     uint64_t>::type AddressElemType;
 
-  /**
-   * This class performs the actual splitting i.e. it reorders the dataset.
-   * This variable should be equal to true if the class does not perform the
-   * actual splitting.
-   */
-  static constexpr bool NeedRearrangeDataset = false;
-
   //! An information about the partition.
   struct SplitInfo
   {
@@ -69,7 +62,8 @@ class UBTreeSplit
                              const SplitInfo& splitInfo);
 
   /**
-   * Rearrange the dataset according to the addresses.
+   * Rearrange the dataset according to the addresses and return the list
+   * of changed indices.
    *
    * @param data The dataset used by the binary space tree.
    * @param begin Index of the starting point in the dataset that belongs to
diff --git a/src/mlpack/core/tree/binary_space_tree/ub_tree_split_impl.hpp b/src/mlpack/core/tree/binary_space_tree/ub_tree_split_impl.hpp
index c3bc7e1..51beba8 100644
--- a/src/mlpack/core/tree/binary_space_tree/ub_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/ub_tree_split_impl.hpp
@@ -204,18 +204,14 @@ size_t UBTreeSplit<BoundType, MatType>::PerformSplit(
   if (splitInfo.addresses)
   {
     std::vector<size_t> newFromOld(data.n_cols);
-    std::vector<size_t> newToOld(data.n_cols);
 
     for (size_t i = 0; i < splitInfo.addresses->size(); i++)
-    {
       newFromOld[i] = i;
-      newToOld[i] = i;
-    }
 
     for (size_t i = 0; i < splitInfo.addresses->size(); i++)
     {
       size_t index = (*splitInfo.addresses)[i].second;
-      size_t oldI = newToOld[i];
+      size_t oldI = oldFromNew[i];
       size_t newIndex = newFromOld[index];
 
       data.swap_cols(i, newFromOld[index]);
@@ -224,10 +220,6 @@ size_t UBTreeSplit<BoundType, MatType>::PerformSplit(
       newFromOld[index] = i;
       newFromOld[oldI] = tmp;
 
-      tmp = newToOld[i];
-      newToOld[i] = newToOld[newIndex];
-      newToOld[newIndex] = tmp;
-
       tmp = oldFromNew[i];
       oldFromNew[i] = oldFromNew[newIndex];
       oldFromNew[newIndex] = tmp;
diff --git a/src/mlpack/core/tree/binary_space_tree/vantage_point_split.hpp b/src/mlpack/core/tree/binary_space_tree/vantage_point_split.hpp
index 015d2d8..ef85ec7 100644
--- a/src/mlpack/core/tree/binary_space_tree/vantage_point_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/vantage_point_split.hpp
@@ -9,6 +9,7 @@
 #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_VANTAGE_POINT_SPLIT_HPP
 
 #include <mlpack/core.hpp>
+#include "perform_split.hpp"
 
 namespace mlpack {
 namespace tree /** Trees and tree-building procedures. */ {
@@ -65,6 +66,49 @@ class VantagePointSplit
                         SplitInfo& splitInfo);
 
   /**
+   * Perform the split process according to the information about the
+   * split.
+   *
+   * @param data The dataset used by the binary space tree; the points of this
+   *    node may be reordered.
+   * @param begin Index of the starting point in the dataset that belongs to
+   *    this node.
+   * @param count Number of points in this node.
+   * @param splitInfo The information about the split.
+   */
+  static size_t PerformSplit(MatType& data,
+                             const size_t begin,
+                             const size_t count,
+                             const SplitInfo& splitInfo)
+  {
+    return split::PerformSplit<MatType, VantagePointSplit>(data, begin, count,
+        splitInfo);
+  }
+
+  /**
+   * Perform the split process according to the information about the split and
+   * record the changed indices in oldFromNew.
+   *
+   * @param data The dataset used by the binary space tree; the points of this
+   *    node may be reordered.
+   * @param begin Index of the starting point in the dataset that belongs to
+   *    this node.
+   * @param count Number of points in this node.
+   * @param splitInfo The information about the split.
+   * @param oldFromNew Vector which will be filled with the old positions for
+   *    each new point.
+   */
+  static size_t PerformSplit(MatType& data,
+                             const size_t begin,
+                             const size_t count,
+                             const SplitInfo& splitInfo,
+                             std::vector<size_t>& oldFromNew)
+  {
+    return split::PerformSplit<MatType, VantagePointSplit>(data, begin, count,
+        splitInfo, oldFromNew);
+  }
+
+  /**
    * Indicates that a point should be assigned to the left subtree.
    * This method returns true if a point should be assigned to the left subtree,
    * i.e., if the distance from the point to the vantage point is less then the
diff --git a/src/mlpack/core/tree/cellbound.hpp b/src/mlpack/core/tree/cellbound.hpp
index a653c6e..8841f8b 100644
--- a/src/mlpack/core/tree/cellbound.hpp
+++ b/src/mlpack/core/tree/cellbound.hpp
@@ -270,7 +270,8 @@ class CellBound
                 const arma::Col<ElemType>& hiCorner,
                 const MatType& data);
   /**
-   * Initialize all subrectangles that touches the lower address.
+   * Initialize all subrectangles that touch the lower address. This function
+   * should be called before InitLowerBound().
    *
    * @param numEqualBits The number of equal leading bits of the lower address
    * and the high address.
@@ -278,8 +279,10 @@ class CellBound
    */
   template<typename MatType>
   void InitHighBound(size_t numEqualBits, const MatType& data);
+
   /**
-   * Initialize all subrectangles that touches the high address.
+   * Initialize all subrectangles that touch the high address. This function
+   * should be called after InitHighBound().
    *
    * @param numEqualBits The number of equal leading bits of the lower address
    * and the high address.
diff --git a/src/mlpack/core/tree/cellbound_impl.hpp b/src/mlpack/core/tree/cellbound_impl.hpp
index ebb9a34..ba18379 100644
--- a/src/mlpack/core/tree/cellbound_impl.hpp
+++ b/src/mlpack/core/tree/cellbound_impl.hpp
@@ -328,7 +328,7 @@ void CellBound<MetricType, ElemType>::InitLowerBound(size_t numEqualBits,
 
     // We ran out of the limit of hyperrectangles. In that case we enlare
     // the last hyperrectangle.
-    if (numCorners >= maxNumBounds / 2)
+    if (numCorners >= maxNumBounds - numBounds)
       tmpLoAddress[row] &= ~((AddressElemType) 1 << bit);
   }
 




More information about the mlpack-git mailing list