[mlpack-git] master, mlpack-1.0.x: Incorporate patch from yashdv to move splitting procedure to a different class. (a2feea8)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:45:27 EST 2015


Repository : https://github.com/mlpack/mlpack

On branches: master,mlpack-1.0.x
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit a2feea8ccd1db9244a426aa007e63205acc4aa75
Author: Yash Vadalia <yashdv at gmail.com>
Date:   Mon Mar 24 18:52:43 2014 +0000

    Incorporate patch from yashdv to move splitting procedure to a different class.


>---------------------------------------------------------------

a2feea8ccd1db9244a426aa007e63205acc4aa75
 src/mlpack/core/tree/CMakeLists.txt                |   2 +
 .../tree/binary_space_tree/binary_space_tree.hpp   |  30 +-
 .../binary_space_tree/binary_space_tree_impl.hpp   | 375 +++++++++------------
 .../tree/binary_space_tree/dual_tree_traverser.hpp |   8 +-
 .../binary_space_tree/dual_tree_traverser_impl.hpp |  19 +-
 .../core/tree/binary_space_tree/mean_split.hpp     | 118 +++++++
 .../tree/binary_space_tree/mean_split_impl.hpp     | 195 +++++++++++
 .../binary_space_tree/single_tree_traverser.hpp    |   8 +-
 .../single_tree_traverser_impl.hpp                 |  17 +-
 9 files changed, 527 insertions(+), 245 deletions(-)

diff --git a/src/mlpack/core/tree/CMakeLists.txt b/src/mlpack/core/tree/CMakeLists.txt
index b090514..85c1d8d 100644
--- a/src/mlpack/core/tree/CMakeLists.txt
+++ b/src/mlpack/core/tree/CMakeLists.txt
@@ -7,6 +7,8 @@ set(SOURCES
   binary_space_tree/binary_space_tree_impl.hpp
   binary_space_tree/dual_tree_traverser.hpp
   binary_space_tree/dual_tree_traverser_impl.hpp
+  binary_space_tree/mean_split.hpp
+  binary_space_tree/mean_split_impl.hpp
   binary_space_tree/single_tree_traverser.hpp
   binary_space_tree/single_tree_traverser_impl.hpp
   binary_space_tree/traits.hpp
diff --git a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
index b169d86..183c711 100644
--- a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
@@ -7,6 +7,7 @@
 #define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_BINARY_SPACE_TREE_HPP
 
 #include <mlpack/core.hpp>
+#include "mean_split.hpp"
 
 #include "../statistic.hpp"
 
@@ -30,10 +31,15 @@ namespace tree /** Trees and tree-building procedures. */ {
  *     bounds/.
  * @tparam StatisticType Extra data contained in the node.  See statistic.hpp
  *     for the necessary skeleton interface.
+ * @tparam MatType The dataset class.
+ * @tparam SplitType The class that partitions the dataset/points at a
+ *     particular node into two parts. Its definition decides the way this split
+ *     is done.
  */
 template<typename BoundType,
          typename StatisticType = EmptyStatistic,
-         typename MatType = arma::mat>
+         typename MatType = arma::mat,
+         typename SplitType = MeanSplit<BoundType, MatType> >
 class BinarySpaceTree
 {
  private:
@@ -463,28 +469,6 @@ class BinarySpaceTree
    */
   void SplitNode(MatType& data, std::vector<size_t>& oldFromNew);
 
-  /**
-   * Find the index to split on for this node, given that we are splitting in
-   * the given split dimension on the specified split value.
-   *
-   * @param data Dataset which we are using.
-   * @param splitDim Dimension of dataset to split on.
-   * @param splitVal Value to split on, in the given split dimension.
-   */
-  size_t GetSplitIndex(MatType& data, int splitDim, double splitVal);
-
-  /**
-   * Find the index to split on for this node, given that we are splitting in
-   * the given split dimension on the specified split value.  Also returns a
-   * list of the changed indices.
-   *
-   * @param data Dataset which we are using.
-   * @param splitDim Dimension of dataset to split on.
-   * @param splitVal Value to split on, in the given split dimension.
-   * @param oldFromNew Vector holding permuted indices.
-   */
-  size_t GetSplitIndex(MatType& data, int splitDim, double splitVal,
-      std::vector<size_t>& oldFromNew);
  public:
   /**
    * Returns a string representation of this object.
diff --git a/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp b/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
index 26f2336..25ed4f5 100644
--- a/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
@@ -18,8 +18,11 @@ namespace tree {
 
 // Each of these overloads is kept as a separate function to keep the overhead
 // from the two std::vectors out, if possible.
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
     MatType& data,
     const size_t leafSize) :
     left(NULL),
@@ -39,8 +42,11 @@ BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
   stat = StatisticType(*this);
 }
 
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
     MatType& data,
     std::vector<size_t>& oldFromNew,
     const size_t leafSize) :
@@ -66,8 +72,11 @@ BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
   stat = StatisticType(*this);
 }
 
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
     MatType& data,
     std::vector<size_t>& oldFromNew,
     std::vector<size_t>& newFromOld,
@@ -99,8 +108,11 @@ BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
     newFromOld[oldFromNew[i]] = i;
 }
 
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
     MatType& data,
     const size_t begin,
     const size_t count,
@@ -122,8 +134,11 @@ BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
   stat = StatisticType(*this);
 }
 
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
     MatType& data,
     const size_t begin,
     const size_t count,
@@ -150,8 +165,11 @@ BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
   stat = StatisticType(*this);
 }
 
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
     MatType& data,
     const size_t begin,
     const size_t count,
@@ -203,8 +221,11 @@ BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree() :
  * Create a binary space tree by copying the other tree.  Be careful!  This can
  * take a long time and use a lot of memory.
  */
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
     const BinarySpaceTree& other) :
     left(NULL),
     right(NULL),
@@ -238,8 +259,12 @@ BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
  * destructors in turn.  This will invalidate any pointers or references to any
  * nodes which are children of this one.
  */
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::~BinarySpaceTree()
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+  ~BinarySpaceTree()
 {
   if (left)
     delete left;
@@ -258,9 +283,12 @@ BinarySpaceTree<BoundType, StatisticType, MatType>::~BinarySpaceTree()
  * @param queryCount The Count() of the node to find.
  * @return The found node, or NULL if nothing is found.
  */
-template<typename BoundType, typename StatisticType, typename MatType>
-const BinarySpaceTree<BoundType, StatisticType, MatType>*
-BinarySpaceTree<BoundType, StatisticType, MatType>::FindByBeginCount(
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+const BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>*
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::FindByBeginCount(
     size_t queryBegin,
     size_t queryCount) const
 {
@@ -288,9 +316,12 @@ BinarySpaceTree<BoundType, StatisticType, MatType>::FindByBeginCount(
  * @param queryCount the Count() of the node to find
  * @return the found node, or NULL
  */
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>*
-BinarySpaceTree<BoundType, StatisticType, MatType>::FindByBeginCount(
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>*
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::FindByBeginCount(
     const size_t queryBegin,
     const size_t queryCount)
 {
@@ -309,8 +340,11 @@ BinarySpaceTree<BoundType, StatisticType, MatType>::FindByBeginCount(
     return NULL;
 }
 
-template<typename BoundType, typename StatisticType, typename MatType>
-size_t BinarySpaceTree<BoundType, StatisticType, MatType>::ExtendTree(
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+size_t BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::ExtendTree(
     size_t level)
 {
   --level;
@@ -340,16 +374,24 @@ size_t BinarySpaceTree<BoundType, StatisticType, MatType>::ExtendTree(
  *     to avoid exceeding the stack limit
  */
 
-template<typename BoundType, typename StatisticType, typename MatType>
-size_t BinarySpaceTree<BoundType, StatisticType, MatType>::TreeSize() const
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+size_t BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+    TreeSize() const
 {
   // Recursively count the nodes on each side of the tree.  The plus one is
   // because we have to count this node, too.
   return 1 + (left ? left->TreeSize() : 0) + (right ? right->TreeSize() : 0);
 }
 
-template<typename BoundType, typename StatisticType, typename MatType>
-size_t BinarySpaceTree<BoundType, StatisticType, MatType>::TreeDepth() const
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+size_t BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+    TreeDepth() const
 {
   // Recursively count the depth on each side of the tree.  The plus one is
   // because we have to count this node, too.
@@ -357,8 +399,12 @@ size_t BinarySpaceTree<BoundType, StatisticType, MatType>::TreeDepth() const
                       (right ? right->TreeDepth() : 0));
 }
 
-template<typename BoundType, typename StatisticType, typename MatType>
-inline bool BinarySpaceTree<BoundType, StatisticType, MatType>::IsLeaf() const
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+inline bool BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+    IsLeaf() const
 {
   return !left;
 }
@@ -366,9 +412,12 @@ inline bool BinarySpaceTree<BoundType, StatisticType, MatType>::IsLeaf() const
 /**
  * Returns the number of children in this node.
  */
-template<typename BoundType, typename StatisticType, typename MatType>
-inline size_t
-    BinarySpaceTree<BoundType, StatisticType, MatType>::NumChildren() const
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+inline size_t BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+    NumChildren() const
 {
   if (left && right)
     return 2;
@@ -382,8 +431,11 @@ inline size_t
  * Return a bound on the furthest point in the node from the centroid.  This
  * returns 0 unless the node is a leaf.
  */
-template<typename BoundType, typename StatisticType, typename MatType>
-inline double BinarySpaceTree<BoundType, StatisticType, MatType>::
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+inline double BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
     FurthestPointDistance() const
 {
   if (IsLeaf())
@@ -401,8 +453,11 @@ inline double BinarySpaceTree<BoundType, StatisticType, MatType>::
  * furthest descendant distance may be less than what this method returns (but
  * it will never be greater than this).
  */
-template<typename BoundType, typename StatisticType, typename MatType>
-inline double BinarySpaceTree<BoundType, StatisticType, MatType>::
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+inline double BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
     FurthestDescendantDistance() const
 {
   return furthestDescendantDistance;
@@ -411,10 +466,13 @@ inline double BinarySpaceTree<BoundType, StatisticType, MatType>::
 /**
  * Return the specified child.
  */
-template<typename BoundType, typename StatisticType, typename MatType>
-inline BinarySpaceTree<BoundType, StatisticType, MatType>&
-    BinarySpaceTree<BoundType, StatisticType, MatType>::Child(
-    const size_t child) const
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+inline BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>&
+    BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+        Child(const size_t child) const
 {
   if (child == 0)
     return *left;
@@ -425,9 +483,12 @@ inline BinarySpaceTree<BoundType, StatisticType, MatType>&
 /**
  * Return the number of points contained in this node.
  */
-template<typename BoundType, typename StatisticType, typename MatType>
-inline size_t
-BinarySpaceTree<BoundType, StatisticType, MatType>::NumPoints() const
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+inline size_t BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+    NumPoints() const
 {
   if (left)
     return 0;
@@ -438,9 +499,12 @@ BinarySpaceTree<BoundType, StatisticType, MatType>::NumPoints() const
 /**
  * Return the number of descendants contained in the node.
  */
-template<typename BoundType, typename StatisticType, typename MatType>
-inline size_t
-BinarySpaceTree<BoundType, StatisticType, MatType>::NumDescendants() const
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+inline size_t BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+    NumDescendants() const
 {
   return count;
 }
@@ -448,10 +512,12 @@ BinarySpaceTree<BoundType, StatisticType, MatType>::NumDescendants() const
 /**
  * Return the index of a particular descendant contained in this node.
  */
-template<typename BoundType, typename StatisticType, typename MatType>
-inline size_t
-BinarySpaceTree<BoundType, StatisticType, MatType>::Descendant(
-    const size_t index) const
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+inline size_t BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+    Descendant(const size_t index) const
 {
   return (begin + index);
 }
@@ -459,10 +525,12 @@ BinarySpaceTree<BoundType, StatisticType, MatType>::Descendant(
 /**
  * Return the index of a particular point contained in this node.
  */
-template<typename BoundType, typename StatisticType, typename MatType>
-inline size_t
-BinarySpaceTree<BoundType, StatisticType, MatType>::Point(const size_t index)
-    const
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+inline size_t BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+    Point(const size_t index) const
 {
   return (begin + index);
 }
@@ -470,14 +538,21 @@ BinarySpaceTree<BoundType, StatisticType, MatType>::Point(const size_t index)
 /**
  * Gets the index one beyond the last index in the series.
  */
-template<typename BoundType, typename StatisticType, typename MatType>
-inline size_t BinarySpaceTree<BoundType, StatisticType, MatType>::End() const
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+inline size_t BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+    End() const
 {
   return begin + count;
 }
 
-template<typename BoundType, typename StatisticType, typename MatType>
-void BinarySpaceTree<BoundType, StatisticType, MatType>::SplitNode(
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+void BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::SplitNode(
     MatType& data)
 {
   // We need to expand the bounds of this node properly.
@@ -490,35 +565,21 @@ void BinarySpaceTree<BoundType, StatisticType, MatType>::SplitNode(
   if (count <= leafSize)
     return; // We can't split this.
 
-  // Figure out which dimension to split on.
-  size_t splitDim = data.n_rows; // Indicate invalid by maxDim + 1.
-  double maxWidth = -1;
-
-  // Find the split dimension.
-  for (size_t d = 0; d < data.n_rows; d++)
-  {
-    double width = bound[d].Width();
-
-    if (width > maxWidth)
-    {
-      maxWidth = width;
-      splitDim = d;
-    }
-  }
-  splitDimension = splitDim;
+  // splitCol denotes the two partitions of the dataset after the split. The
+  // points on its left go to the left child and the others go to the right
+  // child.
+  size_t splitCol;
 
-  // Split in the middle of that dimension.
-  double splitVal = bound[splitDim].Mid();
+  // Split the node. The elements of 'data' are reordered by the splitting
+  // algorithm. This function call updates splitDimension and splitCol.
+  const bool split = SplitType::SplitNode(bound, data, begin, count,
+      splitDimension, splitCol);
 
-  if (maxWidth == 0) // All these points are the same.  We can't split.
+  // The node may not be always split. For instance, if all the points are the
+  // same, we can't split them.
+  if (!split)
     return;
 
-  // Perform the actual splitting.  This will order the dataset such that points
-  // with value in dimension split_dim less than or equal to splitVal are on
-  // the left of splitCol, and points with value in dimension splitDim greater
-  // than splitVal are on the right side of splitCol.
-  size_t splitCol = GetSplitIndex(data, splitDim, splitVal);
-
   // Now that we know the split column, we will recursively split the children
   // by calling their constructors (which perform this splitting process).
   left = new BinarySpaceTree<BoundType, StatisticType, MatType>(data, begin,
@@ -541,8 +602,11 @@ void BinarySpaceTree<BoundType, StatisticType, MatType>::SplitNode(
   right->ParentDistance() = rightParentDistance;
 }
 
-template<typename BoundType, typename StatisticType, typename MatType>
-void BinarySpaceTree<BoundType, StatisticType, MatType>::SplitNode(
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+void BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::SplitNode(
     MatType& data,
     std::vector<size_t>& oldFromNew)
 {
@@ -557,35 +621,22 @@ void BinarySpaceTree<BoundType, StatisticType, MatType>::SplitNode(
   if (count <= leafSize)
     return; // We can't split this.
 
-  // Figure out which dimension to split on.
-  size_t splitDim = data.n_rows; // Indicate invalid by max_dim + 1.
-  double maxWidth = -1;
-
-  // Find the split dimension.
-  for (size_t d = 0; d < data.n_rows; d++)
-  {
-    double width = bound[d].Width();
-
-    if (width > maxWidth)
-    {
-      maxWidth = width;
-      splitDim = d;
-    }
-  }
-  splitDimension = splitDim;
+  // splitCol denotes the two partitions of the dataset after the split. The
+  // points on its left go to the left child and the others go to the right
+  // child.
+  size_t splitCol;
 
-  // Split in the middle of that dimension.
-  double splitVal = bound[splitDim].Mid();
+  // Split the node. The elements of 'data' are reordered by the splitting
+  // algorithm. This function call updates splitDimension, splitCol and
+  // oldFromNew.
+  const bool split = SplitType::SplitNode(bound, data, begin, count,
+      splitDimension, splitCol, oldFromNew);
 
-  if (maxWidth == 0) // All these points are the same.  We can't split.
+  // The node may not be always split. For instance, if all the points are the
+  // same, we can't split them.
+  if (!split)
     return;
 
-  // Perform the actual splitting.  This will order the dataset such that points
-  // with value in dimension split_dim less than or equal to splitVal are on
-  // the left of splitCol, and points with value in dimension splitDim greater
-  // than splitVal are on the right side of splitCol.
-  size_t splitCol = GetSplitIndex(data, splitDim, splitVal, oldFromNew);
-
   // Now that we know the split column, we will recursively split the children
   // by calling their constructors (which perform this splitting process).
   left = new BinarySpaceTree<BoundType, StatisticType, MatType>(data, begin,
@@ -608,105 +659,15 @@ void BinarySpaceTree<BoundType, StatisticType, MatType>::SplitNode(
   right->ParentDistance() = rightParentDistance;
 }
 
-template<typename BoundType, typename StatisticType, typename MatType>
-size_t BinarySpaceTree<BoundType, StatisticType, MatType>::GetSplitIndex(
-    MatType& data,
-    int splitDim,
-    double splitVal)
-{
-  // This method modifies the input dataset.  We loop both from the left and
-  // right sides of the points contained in this node.  The points less than
-  // split_val should be on the left side of the matrix, and the points greater
-  // than split_val should be on the right side of the matrix.
-  size_t left = begin;
-  size_t right = begin + count - 1;
-
-  // First half-iteration of the loop is out here because the termination
-  // condition is in the middle.
-  while ((data(splitDim, left) < splitVal) && (left <= right))
-    left++;
-  while ((data(splitDim, right) >= splitVal) && (left <= right))
-    right--;
-
-  while (left <= right)
-  {
-    // Swap columns.
-    data.swap_cols(left, right);
-
-    // See how many points on the left are correct.  When they are correct,
-    // increase the left counter accordingly.  When we encounter one that isn't
-    // correct, stop.  We will switch it later.
-    while ((data(splitDim, left) < splitVal) && (left <= right))
-      left++;
-
-    // Now see how many points on the right are correct.  When they are correct,
-    // decrease the right counter accordingly.  When we encounter one that isn't
-    // correct, stop.  We will switch it with the wrong point we found in the
-    // previous loop.
-    while ((data(splitDim, right) >= splitVal) && (left <= right))
-      right--;
-  }
-
-  Log::Assert(left == right + 1);
-
-  return left;
-}
-
-template<typename BoundType, typename StatisticType, typename MatType>
-size_t BinarySpaceTree<BoundType, StatisticType, MatType>::GetSplitIndex(
-    MatType& data,
-    int splitDim,
-    double splitVal,
-    std::vector<size_t>& oldFromNew)
-{
-  // This method modifies the input dataset.  We loop both from the left and
-  // right sides of the points contained in this node.  The points less than
-  // split_val should be on the left side of the matrix, and the points greater
-  // than split_val should be on the right side of the matrix.
-  size_t left = begin;
-  size_t right = begin + count - 1;
-
-  // First half-iteration of the loop is out here because the termination
-  // condition is in the middle.
-  while ((data(splitDim, left) < splitVal) && (left <= right))
-    left++;
-  while ((data(splitDim, right) >= splitVal) && (left <= right))
-    right--;
-
-  while (left <= right)
-  {
-    // Swap columns.
-    data.swap_cols(left, right);
-
-    // Update the indices for what we changed.
-    size_t t = oldFromNew[left];
-    oldFromNew[left] = oldFromNew[right];
-    oldFromNew[right] = t;
-
-    // See how many points on the left are correct.  When they are correct,
-    // increase the left counter accordingly.  When we encounter one that isn't
-    // correct, stop.  We will switch it later.
-    while ((data(splitDim, left) < splitVal) && (left <= right))
-      left++;
-
-    // Now see how many points on the right are correct.  When they are correct,
-    // decrease the right counter accordingly.  When we encounter one that isn't
-    // correct, stop.  We will switch it with the wrong point we found in the
-    // previous loop.
-    while ((data(splitDim, right) >= splitVal) && (left <= right))
-      right--;
-  }
-
-  Log::Assert(left == right + 1);
-
-  return left;
-}
-
 /**
  * Returns a string representation of this object.
  */
-template<typename BoundType, typename StatisticType, typename MatType>
-std::string BinarySpaceTree<BoundType, StatisticType, MatType>::ToString() const
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+std::string BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+    ToString() const
 {
   std::ostringstream convert;
   convert << "BinarySpaceTree [" << this << "]" << std::endl;
diff --git a/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp b/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp
index 2d3a0c1..7cd1871 100644
--- a/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp
@@ -17,9 +17,13 @@
 namespace mlpack {
 namespace tree {
 
-template<typename BoundType, typename StatisticType, typename MatType>
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
 template<typename RuleType>
-class BinarySpaceTree<BoundType, StatisticType, MatType>::DualTreeTraverser
+class BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+    DualTreeTraverser
 {
  public:
   /**
diff --git a/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp
index d590e93..1a1d353 100644
--- a/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp
@@ -15,9 +15,12 @@
 namespace mlpack {
 namespace tree {
 
-template<typename BoundType, typename StatisticType, typename MatType>
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
 template<typename RuleType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
 DualTreeTraverser<RuleType>::DualTreeTraverser(RuleType& rule) :
     rule(rule),
     numPrunes(0),
@@ -26,12 +29,16 @@ DualTreeTraverser<RuleType>::DualTreeTraverser(RuleType& rule) :
     numBaseCases(0)
 { /* Nothing to do. */ }
 
-template<typename BoundType, typename StatisticType, typename MatType>
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
 template<typename RuleType>
-void BinarySpaceTree<BoundType, StatisticType, MatType>::
+void BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
 DualTreeTraverser<RuleType>::Traverse(
-    BinarySpaceTree<BoundType, StatisticType, MatType>& queryNode,
-    BinarySpaceTree<BoundType, StatisticType, MatType>& referenceNode)
+    BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>& queryNode,
+    BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>&
+        referenceNode)
 {
   // Increment the visit counter.
   ++numVisited;
diff --git a/src/mlpack/core/tree/binary_space_tree/mean_split.hpp b/src/mlpack/core/tree/binary_space_tree/mean_split.hpp
new file mode 100644
index 0000000..c6e3c04
--- /dev/null
+++ b/src/mlpack/core/tree/binary_space_tree/mean_split.hpp
@@ -0,0 +1,118 @@
+/**
+ * @file mean_split.hpp
+ * @author Yash Vadalia
+ * @author Ryan Curtin
+ *
+ * Definition of MeanSplit, a class that splits a binary space partitioning tree
+ * node into two parts at the middle of the node's bound in the dimension of
+ * maximum width.
+ */
+#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_MEAN_SPLIT_HPP
+#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_MEAN_SPLIT_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace tree /** Trees and tree-building procedures. */ {
+
+/**
+ * A binary space partitioning tree node is split into its left and right child.
+ * The split is done in the dimension in which the node's bound has the maximum
+ * width, at the middle of the bound's range in that dimension.  Note that,
+ * despite the class name, the split value is this midpoint and not the
+ * arithmetic mean of the points.  Points whose value in the split dimension is
+ * less than the split value are placed in the left part, and the remaining
+ * points are placed in the right part.
+ *
+ * @tparam BoundType Type of the node's bound; bound[d] must provide Width()
+ *     and Mid() for every dimension d.
+ * @tparam MatType Type of the dataset.
+ */
+template<typename BoundType, typename MatType = arma::mat>
+class MeanSplit
+{
+ public:
+  /**
+   * Split the node at the middle of the bound in the dimension with maximum
+   * width, reordering the node's columns of the dataset in place.
+   *
+   * @param bound The bound used for this node.
+   * @param data The dataset used by the binary space tree; its columns in
+   *    [begin, begin + count) may be reordered.
+   * @param begin Index of the starting point in the dataset that belongs to
+   *    this node.
+   * @param count Number of points in this node.
+   * @param splitDimension This will be filled with the dimension the node is to
+   *    be split on.
+   * @param splitCol The index at which the dataset is divided into two parts
+   *    after the rearrangement.  Only set when true is returned.
+   * @return false if the node cannot be split because the bound has no width
+   *    in any dimension (all points are identical); true otherwise.
+   */
+  static bool SplitNode(const BoundType& bound,
+                        MatType& data,
+                        const size_t begin,
+                        const size_t count,
+                        size_t& splitDimension,
+                        size_t& splitCol);
+
+  /**
+   * Split the node at the middle of the bound in the dimension with maximum
+   * width, reordering the node's columns of the dataset in place and keeping
+   * track of where each column came from.
+   *
+   * @param bound The bound used for this node.
+   * @param data The dataset used by the binary space tree; its columns in
+   *    [begin, begin + count) may be reordered.
+   * @param begin Index of the starting point in the dataset that belongs to
+   *    this node.
+   * @param count Number of points in this node.
+   * @param splitDimension This will be filled with the dimension the node is
+   *    to be split on.
+   * @param splitCol The index at which the dataset is divided into two parts
+   *    after the rearrangement.  Only set when true is returned.
+   * @param oldFromNew Mapping from each column's new position to its original
+   *    position.  It must already hold an entry for every column in
+   *    [begin, begin + count); entries are swapped together with the columns
+   *    they describe.
+   * @return false if the node cannot be split because the bound has no width
+   *    in any dimension (all points are identical); true otherwise.
+   */
+  static bool SplitNode(const BoundType& bound,
+                        MatType& data,
+                        const size_t begin,
+                        const size_t count,
+                        size_t& splitDimension,
+                        size_t& splitCol,
+                        std::vector<size_t>& oldFromNew);
+
+ private:
+  /**
+   * Reorder the node's columns of the dataset into two parts so that points
+   * whose value in dimension splitDimension is less than splitVal come first,
+   * followed by all remaining points.
+   *
+   * @param data The dataset used by the binary space tree.
+   * @param begin Index of the starting point in the dataset that belongs to
+   *    this node.
+   * @param count Number of points in this node.
+   * @param splitDimension The dimension to split the node on.
+   * @param splitVal The split in dimension splitDimension is based on this
+   *    value.
+   * @return The index of the first column of the right part; columns in
+   *    [begin, return value) hold the points less than splitVal.
+   */
+  static size_t PerformSplit(MatType& data,
+                             const size_t begin,
+                             const size_t count,
+                             const size_t splitDimension,
+                             const double splitVal);
+
+  /**
+   * Reorder the node's columns of the dataset into two parts so that points
+   * whose value in dimension splitDimension is less than splitVal come first,
+   * followed by all remaining points.  Every column swap is mirrored in
+   * oldFromNew.
+   *
+   * @param data The dataset used by the binary space tree.
+   * @param begin Index of the starting point in the dataset that belongs to
+   *    this node.
+   * @param count Number of points in this node.
+   * @param splitDimension The dimension to split the node on.
+   * @param splitVal The split in dimension splitDimension is based on this
+   *    value.
+   * @param oldFromNew Mapping from each column's new position to its original
+   *    position; entries are swapped together with the columns they describe.
+   * @return The index of the first column of the right part; columns in
+   *    [begin, return value) hold the points less than splitVal.
+   */
+  static size_t PerformSplit(MatType& data,
+                             const size_t begin,
+                             const size_t count,
+                             const size_t splitDimension,
+                             const double splitVal,
+                             std::vector<size_t>& oldFromNew);
+};
+
+}; // namespace tree
+}; // namespace mlpack
+
+// Include implementation.
+#include "mean_split_impl.hpp"
+
+#endif
diff --git a/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp b/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp
new file mode 100644
index 0000000..cf0616e
--- /dev/null
+++ b/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp
@@ -0,0 +1,195 @@
+/**
+ * @file mean_split_impl.hpp
+ * @author Yash Vadalia
+ * @author Ryan Curtin
+ *
+ * Implementation of the MeanSplit class, which splits a node of a binary space
+ * partitioning tree into two children.
+ */
+#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_MEAN_SPLIT_IMPL_HPP
+#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_MEAN_SPLIT_IMPL_HPP
+
+#include "mean_split.hpp"
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * Split the node at the middle of the bound in its widest dimension.  Returns
+ * false, without touching the data or splitCol, when no dimension has positive
+ * width.
+ */
+template<typename BoundType, typename MatType>
+bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
+                                              MatType& data,
+                                              const size_t begin,
+                                              const size_t count,
+                                              size_t& splitDimension,
+                                              size_t& splitCol)
+{
+  // Find the dimension in which the bound is widest.  splitDimension starts
+  // out as data.n_rows, an out-of-range value, in case no dimension is chosen.
+  splitDimension = data.n_rows; // Indicate invalid.
+  double maxWidth = -1;
+
+  for (size_t d = 0; d < data.n_rows; d++)
+  {
+    const double width = bound[d].Width();
+
+    if (width > maxWidth)
+    {
+      maxWidth = width;
+      splitDimension = d;
+    }
+  }
+
+  // Without a dimension of positive width we cannot split: either all of the
+  // points are the same, or the data has no dimensions at all.  In the latter
+  // case splitDimension is still invalid and must not be used to index the
+  // bound, which is why this test is <= 0 rather than == 0.
+  if (maxWidth <= 0)
+    return false;
+
+  // Split in the middle of that dimension.
+  const double splitVal = bound[splitDimension].Mid();
+
+  // Perform the actual splitting.  This will order the dataset such that points
+  // with value in dimension splitDimension strictly less than splitVal are on
+  // the left of splitCol, and all remaining points (including those equal to
+  // splitVal) are at or to the right of splitCol.
+  splitCol = PerformSplit(data, begin, count, splitDimension, splitVal);
+
+  return true;
+}
+
+/**
+ * Split the node at the middle of the bound in its widest dimension, mirroring
+ * every column move in oldFromNew.  Returns false, without touching the data,
+ * splitCol or oldFromNew, when no dimension has positive width.
+ */
+template<typename BoundType, typename MatType>
+bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
+                                              MatType& data,
+                                              const size_t begin,
+                                              const size_t count,
+                                              size_t& splitDimension,
+                                              size_t& splitCol,
+                                              std::vector<size_t>& oldFromNew)
+{
+  // Find the dimension in which the bound is widest.  splitDimension starts
+  // out as data.n_rows, an out-of-range value, in case no dimension is chosen.
+  splitDimension = data.n_rows; // Indicate invalid.
+  double maxWidth = -1;
+
+  for (size_t d = 0; d < data.n_rows; d++)
+  {
+    const double width = bound[d].Width();
+
+    if (width > maxWidth)
+    {
+      maxWidth = width;
+      splitDimension = d;
+    }
+  }
+
+  // Without a dimension of positive width we cannot split: either all of the
+  // points are the same, or the data has no dimensions at all.  In the latter
+  // case splitDimension is still invalid and must not be used to index the
+  // bound, which is why this test is <= 0 rather than == 0.
+  if (maxWidth <= 0)
+    return false;
+
+  // Split in the middle of that dimension.
+  const double splitVal = bound[splitDimension].Mid();
+
+  // Perform the actual splitting.  This will order the dataset such that points
+  // with value in dimension splitDimension strictly less than splitVal are on
+  // the left of splitCol, and all remaining points (including those equal to
+  // splitVal) are at or to the right of splitCol.
+  splitCol = PerformSplit(data, begin, count, splitDimension, splitVal,
+      oldFromNew);
+
+  return true;
+}
+
+/**
+ * Partition columns [begin, begin + count) of data in place so that points
+ * whose value in dimension splitDimension is less than splitVal come first.
+ * Returns the index of the first column of the right part.
+ */
+template<typename BoundType, typename MatType>
+size_t MeanSplit<BoundType, MatType>::
+    PerformSplit(MatType& data,
+                 const size_t begin,
+                 const size_t count,
+                 const size_t splitDimension,
+                 const double splitVal)
+{
+  // This method modifies the input dataset.  We scan inward from both ends of
+  // the node's columns.  Points less than splitVal belong on the left side of
+  // the matrix; every other point belongs on the right side.
+  //
+  // The columns not yet classified are the half-open range [left, right).
+  // Because the range is half-open, the bounds are always checked before a
+  // column is read, and right never has to step below begin.  With an
+  // inclusive right index, that step would wrap around when begin == 0.  The
+  // half-open range also handles count == 0.
+  size_t left = begin;
+  size_t right = begin + count;
+
+  while (true)
+  {
+    // Skip points on the left that are already on the correct side.
+    while ((left < right) && (data(splitDimension, left) < splitVal))
+      ++left;
+
+    // Skip points on the right that are already on the correct side.  The
+    // test is !(x < splitVal) rather than (x >= splitVal) so that NaN values
+    // go to the right.  Every point is then classified and each swap makes
+    // progress.
+    while ((left < right) && !(data(splitDimension, right - 1) < splitVal))
+      --right;
+
+    // The two scans met, so every column has been classified.
+    if (left == right)
+      break;
+
+    // Column left belongs on the right and column right - 1 belongs on the
+    // left; exchange them.  The scans above will step past both.
+    data.swap_cols(left, right - 1);
+  }
+
+  return left;
+}
+
+/**
+ * Partition columns [begin, begin + count) of data in place so that points
+ * whose value in dimension splitDimension is less than splitVal come first,
+ * swapping the matching oldFromNew entries alongside every column swap.
+ * Returns the index of the first column of the right part.
+ */
+template<typename BoundType, typename MatType>
+size_t MeanSplit<BoundType, MatType>::
+    PerformSplit(MatType& data,
+                 const size_t begin,
+                 const size_t count,
+                 const size_t splitDimension,
+                 const double splitVal,
+                 std::vector<size_t>& oldFromNew)
+{
+  // This method modifies the input dataset.  We scan inward from both ends of
+  // the node's columns.  Points less than splitVal belong on the left side of
+  // the matrix; every other point belongs on the right side.
+  //
+  // The columns not yet classified are the half-open range [left, right).
+  // Because the range is half-open, the bounds are always checked before a
+  // column is read, and right never has to step below begin.  With an
+  // inclusive right index, that step would wrap around when begin == 0.  The
+  // half-open range also handles count == 0.
+  size_t left = begin;
+  size_t right = begin + count;
+
+  while (true)
+  {
+    // Skip points on the left that are already on the correct side.
+    while ((left < right) && (data(splitDimension, left) < splitVal))
+      ++left;
+
+    // Skip points on the right that are already on the correct side.  The
+    // test is !(x < splitVal) rather than (x >= splitVal) so that NaN values
+    // go to the right.  Every point is then classified and each swap makes
+    // progress.
+    while ((left < right) && !(data(splitDimension, right - 1) < splitVal))
+      --right;
+
+    // The two scans met, so every column has been classified.
+    if (left == right)
+      break;
+
+    // Column left belongs on the right and column right - 1 belongs on the
+    // left; exchange them.  The scans above will step past both.
+    data.swap_cols(left, right - 1);
+
+    // Keep the index mapping in step with the columns we just moved.
+    const size_t t = oldFromNew[left];
+    oldFromNew[left] = oldFromNew[right - 1];
+    oldFromNew[right - 1] = t;
+  }
+
+  return left;
+}
+
+}; // namespace tree
+}; // namespace mlpack
+
+#endif
diff --git a/src/mlpack/core/tree/binary_space_tree/single_tree_traverser.hpp b/src/mlpack/core/tree/binary_space_tree/single_tree_traverser.hpp
index a9d0035..69f2a04 100644
--- a/src/mlpack/core/tree/binary_space_tree/single_tree_traverser.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/single_tree_traverser.hpp
@@ -16,9 +16,13 @@
 namespace mlpack {
 namespace tree {
 
-template<typename BoundType, typename StatisticType, typename MatType>
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
 template<typename RuleType>
-class BinarySpaceTree<BoundType, StatisticType, MatType>::SingleTreeTraverser
+class BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+    SingleTreeTraverser
 {
  public:
   /**
diff --git a/src/mlpack/core/tree/binary_space_tree/single_tree_traverser_impl.hpp b/src/mlpack/core/tree/binary_space_tree/single_tree_traverser_impl.hpp
index 04af321..7b559e0 100644
--- a/src/mlpack/core/tree/binary_space_tree/single_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/single_tree_traverser_impl.hpp
@@ -17,20 +17,27 @@
 namespace mlpack {
 namespace tree {
 
-template<typename BoundType, typename StatisticType, typename MatType>
+/**
+ * Construct a SingleTreeTraverser for the given rule.  The rule member is
+ * initialized from the argument and the prune counter starts at zero.
+ */
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
 template<typename RuleType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
 SingleTreeTraverser<RuleType>::SingleTreeTraverser(RuleType& rule) :
     rule(rule),
     numPrunes(0)
 { /* Nothing to do. */ }
 
-template<typename BoundType, typename StatisticType, typename MatType>
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
 template<typename RuleType>
-void BinarySpaceTree<BoundType, StatisticType, MatType>::
+void BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
 SingleTreeTraverser<RuleType>::Traverse(
     const size_t queryIndex,
-    BinarySpaceTree<BoundType, StatisticType, MatType>& referenceNode)
+    BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>&
+        referenceNode)
 {
   // If we are a leaf, run the base case as necessary.
   if (referenceNode.IsLeaf())



More information about the mlpack-git mailing list