[mlpack-svn] r16373 - in mlpack/trunk/src/mlpack/core/tree: . binary_space_tree

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Mar 24 14:52:43 EDT 2014


Author: rcurtin
Date: Mon Mar 24 14:52:43 2014
New Revision: 16373

Log:
Incorporate patch from yashdv to move splitting procedure to a different class.


Added:
   mlpack/trunk/src/mlpack/core/tree/binary_space_tree/mean_split.hpp
   mlpack/trunk/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp
Modified:
   mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt
   mlpack/trunk/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
   mlpack/trunk/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
   mlpack/trunk/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp
   mlpack/trunk/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp
   mlpack/trunk/src/mlpack/core/tree/binary_space_tree/single_tree_traverser.hpp
   mlpack/trunk/src/mlpack/core/tree/binary_space_tree/single_tree_traverser_impl.hpp

Modified: mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt	(original)
+++ mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt	Mon Mar 24 14:52:43 2014
@@ -7,6 +7,8 @@
   binary_space_tree/binary_space_tree_impl.hpp
   binary_space_tree/dual_tree_traverser.hpp
   binary_space_tree/dual_tree_traverser_impl.hpp
+  binary_space_tree/mean_split.hpp
+  binary_space_tree/mean_split_impl.hpp
   binary_space_tree/single_tree_traverser.hpp
   binary_space_tree/single_tree_traverser_impl.hpp
   binary_space_tree/traits.hpp

Modified: mlpack/trunk/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp	Mon Mar 24 14:52:43 2014
@@ -7,6 +7,7 @@
 #define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_BINARY_SPACE_TREE_HPP
 
 #include <mlpack/core.hpp>
+#include "mean_split.hpp"
 
 #include "../statistic.hpp"
 
@@ -30,10 +31,15 @@
  *     bounds/.
  * @tparam StatisticType Extra data contained in the node.  See statistic.hpp
  *     for the necessary skeleton interface.
+ * @tparam MatType The dataset class.
+ * @tparam SplitType The class that partitions the dataset/points at a
+ *     particular node into two parts. Its definition decides the way this split
+ *     is done.
  */
 template<typename BoundType,
          typename StatisticType = EmptyStatistic,
-         typename MatType = arma::mat>
+         typename MatType = arma::mat,
+         typename SplitType = MeanSplit<BoundType, MatType> >
 class BinarySpaceTree
 {
  private:
@@ -463,28 +469,6 @@
    */
   void SplitNode(MatType& data, std::vector<size_t>& oldFromNew);
 
-  /**
-   * Find the index to split on for this node, given that we are splitting in
-   * the given split dimension on the specified split value.
-   *
-   * @param data Dataset which we are using.
-   * @param splitDim Dimension of dataset to split on.
-   * @param splitVal Value to split on, in the given split dimension.
-   */
-  size_t GetSplitIndex(MatType& data, int splitDim, double splitVal);
-
-  /**
-   * Find the index to split on for this node, given that we are splitting in
-   * the given split dimension on the specified split value.  Also returns a
-   * list of the changed indices.
-   *
-   * @param data Dataset which we are using.
-   * @param splitDim Dimension of dataset to split on.
-   * @param splitVal Value to split on, in the given split dimension.
-   * @param oldFromNew Vector holding permuted indices.
-   */
-  size_t GetSplitIndex(MatType& data, int splitDim, double splitVal,
-      std::vector<size_t>& oldFromNew);
  public:
   /**
    * Returns a string representation of this object.

Modified: mlpack/trunk/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp	Mon Mar 24 14:52:43 2014
@@ -18,8 +18,11 @@
 
 // Each of these overloads is kept as a separate function to keep the overhead
 // from the two std::vectors out, if possible.
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
     MatType& data,
     const size_t leafSize) :
     left(NULL),
@@ -39,8 +42,11 @@
   stat = StatisticType(*this);
 }
 
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
     MatType& data,
     std::vector<size_t>& oldFromNew,
     const size_t leafSize) :
@@ -66,8 +72,11 @@
   stat = StatisticType(*this);
 }
 
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
     MatType& data,
     std::vector<size_t>& oldFromNew,
     std::vector<size_t>& newFromOld,
@@ -99,8 +108,11 @@
     newFromOld[oldFromNew[i]] = i;
 }
 
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
     MatType& data,
     const size_t begin,
     const size_t count,
@@ -122,8 +134,11 @@
   stat = StatisticType(*this);
 }
 
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
     MatType& data,
     const size_t begin,
     const size_t count,
@@ -150,8 +165,11 @@
   stat = StatisticType(*this);
 }
 
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
     MatType& data,
     const size_t begin,
     const size_t count,
@@ -203,8 +221,11 @@
  * Create a binary space tree by copying the other tree.  Be careful!  This can
  * take a long time and use a lot of memory.
  */
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
     const BinarySpaceTree& other) :
     left(NULL),
     right(NULL),
@@ -238,8 +259,12 @@
  * destructors in turn.  This will invalidate any pointers or references to any
  * nodes which are children of this one.
  */
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::~BinarySpaceTree()
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+  ~BinarySpaceTree()
 {
   if (left)
     delete left;
@@ -258,9 +283,12 @@
  * @param queryCount The Count() of the node to find.
  * @return The found node, or NULL if nothing is found.
  */
-template<typename BoundType, typename StatisticType, typename MatType>
-const BinarySpaceTree<BoundType, StatisticType, MatType>*
-BinarySpaceTree<BoundType, StatisticType, MatType>::FindByBeginCount(
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+const BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>*
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::FindByBeginCount(
     size_t queryBegin,
     size_t queryCount) const
 {
@@ -288,9 +316,12 @@
  * @param queryCount the Count() of the node to find
  * @return the found node, or NULL
  */
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>*
-BinarySpaceTree<BoundType, StatisticType, MatType>::FindByBeginCount(
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>*
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::FindByBeginCount(
     const size_t queryBegin,
     const size_t queryCount)
 {
@@ -309,8 +340,11 @@
     return NULL;
 }
 
-template<typename BoundType, typename StatisticType, typename MatType>
-size_t BinarySpaceTree<BoundType, StatisticType, MatType>::ExtendTree(
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+size_t BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::ExtendTree(
     size_t level)
 {
   --level;
@@ -340,16 +374,24 @@
  *     to avoid exceeding the stack limit
  */
 
-template<typename BoundType, typename StatisticType, typename MatType>
-size_t BinarySpaceTree<BoundType, StatisticType, MatType>::TreeSize() const
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+size_t BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+    TreeSize() const
 {
   // Recursively count the nodes on each side of the tree.  The plus one is
   // because we have to count this node, too.
   return 1 + (left ? left->TreeSize() : 0) + (right ? right->TreeSize() : 0);
 }
 
-template<typename BoundType, typename StatisticType, typename MatType>
-size_t BinarySpaceTree<BoundType, StatisticType, MatType>::TreeDepth() const
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+size_t BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+    TreeDepth() const
 {
   // Recursively count the depth on each side of the tree.  The plus one is
   // because we have to count this node, too.
@@ -357,8 +399,12 @@
                       (right ? right->TreeDepth() : 0));
 }
 
-template<typename BoundType, typename StatisticType, typename MatType>
-inline bool BinarySpaceTree<BoundType, StatisticType, MatType>::IsLeaf() const
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+inline bool BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+    IsLeaf() const
 {
   return !left;
 }
@@ -366,9 +412,12 @@
 /**
  * Returns the number of children in this node.
  */
-template<typename BoundType, typename StatisticType, typename MatType>
-inline size_t
-    BinarySpaceTree<BoundType, StatisticType, MatType>::NumChildren() const
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+inline size_t BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+    NumChildren() const
 {
   if (left && right)
     return 2;
@@ -382,8 +431,11 @@
  * Return a bound on the furthest point in the node from the centroid.  This
  * returns 0 unless the node is a leaf.
  */
-template<typename BoundType, typename StatisticType, typename MatType>
-inline double BinarySpaceTree<BoundType, StatisticType, MatType>::
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+inline double BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
     FurthestPointDistance() const
 {
   if (IsLeaf())
@@ -401,8 +453,11 @@
  * furthest descendant distance may be less than what this method returns (but
  * it will never be greater than this).
  */
-template<typename BoundType, typename StatisticType, typename MatType>
-inline double BinarySpaceTree<BoundType, StatisticType, MatType>::
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+inline double BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
     FurthestDescendantDistance() const
 {
   return furthestDescendantDistance;
@@ -411,10 +466,13 @@
 /**
  * Return the specified child.
  */
-template<typename BoundType, typename StatisticType, typename MatType>
-inline BinarySpaceTree<BoundType, StatisticType, MatType>&
-    BinarySpaceTree<BoundType, StatisticType, MatType>::Child(
-    const size_t child) const
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+inline BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>&
+    BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+        Child(const size_t child) const
 {
   if (child == 0)
     return *left;
@@ -425,9 +483,12 @@
 /**
  * Return the number of points contained in this node.
  */
-template<typename BoundType, typename StatisticType, typename MatType>
-inline size_t
-BinarySpaceTree<BoundType, StatisticType, MatType>::NumPoints() const
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+inline size_t BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+    NumPoints() const
 {
   if (left)
     return 0;
@@ -438,9 +499,12 @@
 /**
  * Return the number of descendants contained in the node.
  */
-template<typename BoundType, typename StatisticType, typename MatType>
-inline size_t
-BinarySpaceTree<BoundType, StatisticType, MatType>::NumDescendants() const
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+inline size_t BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+    NumDescendants() const
 {
   return count;
 }
@@ -448,10 +512,12 @@
 /**
  * Return the index of a particular descendant contained in this node.
  */
-template<typename BoundType, typename StatisticType, typename MatType>
-inline size_t
-BinarySpaceTree<BoundType, StatisticType, MatType>::Descendant(
-    const size_t index) const
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+inline size_t BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+    Descendant(const size_t index) const
 {
   return (begin + index);
 }
@@ -459,10 +525,12 @@
 /**
  * Return the index of a particular point contained in this node.
  */
-template<typename BoundType, typename StatisticType, typename MatType>
-inline size_t
-BinarySpaceTree<BoundType, StatisticType, MatType>::Point(const size_t index)
-    const
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+inline size_t BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+    Point(const size_t index) const
 {
   return (begin + index);
 }
@@ -470,14 +538,21 @@
 /**
  * Gets the index one beyond the last index in the series.
  */
-template<typename BoundType, typename StatisticType, typename MatType>
-inline size_t BinarySpaceTree<BoundType, StatisticType, MatType>::End() const
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+inline size_t BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+    End() const
 {
   return begin + count;
 }
 
-template<typename BoundType, typename StatisticType, typename MatType>
-void BinarySpaceTree<BoundType, StatisticType, MatType>::SplitNode(
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+void BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::SplitNode(
     MatType& data)
 {
   // We need to expand the bounds of this node properly.
@@ -490,35 +565,21 @@
   if (count <= leafSize)
     return; // We can't split this.
 
-  // Figure out which dimension to split on.
-  size_t splitDim = data.n_rows; // Indicate invalid by maxDim + 1.
-  double maxWidth = -1;
-
-  // Find the split dimension.
-  for (size_t d = 0; d < data.n_rows; d++)
-  {
-    double width = bound[d].Width();
-
-    if (width > maxWidth)
-    {
-      maxWidth = width;
-      splitDim = d;
-    }
-  }
-  splitDimension = splitDim;
-
-  // Split in the middle of that dimension.
-  double splitVal = bound[splitDim].Mid();
-
-  if (maxWidth == 0) // All these points are the same.  We can't split.
+  // splitCol is the index that separates the two partitions after the split:
+  // points in [begin, splitCol) go to the left child and the others go to the
+  // right child.
+  size_t splitCol;
+
+  // Split the node. The elements of 'data' are reordered by the splitting
+  // algorithm. This function call updates splitDimension and splitCol.
+  const bool split = SplitType::SplitNode(bound, data, begin, count,
+      splitDimension, splitCol);
+
+  // The node may not be always split. For instance, if all the points are the
+  // same, we can't split them.
+  if (!split)
     return;
 
-  // Perform the actual splitting.  This will order the dataset such that points
-  // with value in dimension split_dim less than or equal to splitVal are on
-  // the left of splitCol, and points with value in dimension splitDim greater
-  // than splitVal are on the right side of splitCol.
-  size_t splitCol = GetSplitIndex(data, splitDim, splitVal);
-
-  // Now that we know the split column, we will recursively split the children
-  // by calling their constructors (which perform this splitting process).
-  left = new BinarySpaceTree<BoundType, StatisticType, MatType>(data, begin,
+  // Recursively split the children by calling their constructors; using the
+  // injected class name gives them this node's template arguments (SplitType).
+  left = new BinarySpaceTree(data, begin,
@@ -541,8 +602,11 @@
   right->ParentDistance() = rightParentDistance;
 }
 
-template<typename BoundType, typename StatisticType, typename MatType>
-void BinarySpaceTree<BoundType, StatisticType, MatType>::SplitNode(
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+void BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::SplitNode(
     MatType& data,
     std::vector<size_t>& oldFromNew)
 {
@@ -557,35 +621,22 @@
   if (count <= leafSize)
     return; // We can't split this.
 
-  // Figure out which dimension to split on.
-  size_t splitDim = data.n_rows; // Indicate invalid by max_dim + 1.
-  double maxWidth = -1;
-
-  // Find the split dimension.
-  for (size_t d = 0; d < data.n_rows; d++)
-  {
-    double width = bound[d].Width();
-
-    if (width > maxWidth)
-    {
-      maxWidth = width;
-      splitDim = d;
-    }
-  }
-  splitDimension = splitDim;
-
-  // Split in the middle of that dimension.
-  double splitVal = bound[splitDim].Mid();
-
-  if (maxWidth == 0) // All these points are the same.  We can't split.
+  // splitCol is the index that separates the two partitions after the split:
+  // points in [begin, splitCol) go to the left child and the others go to the
+  // right child.
+  size_t splitCol;
+
+  // Split the node. The elements of 'data' are reordered by the splitting
+  // algorithm. This function call updates splitDimension, splitCol and
+  // oldFromNew.
+  const bool split = SplitType::SplitNode(bound, data, begin, count,
+      splitDimension, splitCol, oldFromNew);
+
+  // The node may not be always split. For instance, if all the points are the
+  // same, we can't split them.
+  if (!split)
     return;
 
-  // Perform the actual splitting.  This will order the dataset such that points
-  // with value in dimension split_dim less than or equal to splitVal are on
-  // the left of splitCol, and points with value in dimension splitDim greater
-  // than splitVal are on the right side of splitCol.
-  size_t splitCol = GetSplitIndex(data, splitDim, splitVal, oldFromNew);
-
-  // Now that we know the split column, we will recursively split the children
-  // by calling their constructors (which perform this splitting process).
-  left = new BinarySpaceTree<BoundType, StatisticType, MatType>(data, begin,
+  // Recursively split the children by calling their constructors; using the
+  // injected class name gives them this node's template arguments (SplitType).
+  left = new BinarySpaceTree(data, begin,
@@ -608,105 +659,15 @@
   right->ParentDistance() = rightParentDistance;
 }
 
-template<typename BoundType, typename StatisticType, typename MatType>
-size_t BinarySpaceTree<BoundType, StatisticType, MatType>::GetSplitIndex(
-    MatType& data,
-    int splitDim,
-    double splitVal)
-{
-  // This method modifies the input dataset.  We loop both from the left and
-  // right sides of the points contained in this node.  The points less than
-  // split_val should be on the left side of the matrix, and the points greater
-  // than split_val should be on the right side of the matrix.
-  size_t left = begin;
-  size_t right = begin + count - 1;
-
-  // First half-iteration of the loop is out here because the termination
-  // condition is in the middle.
-  while ((data(splitDim, left) < splitVal) && (left <= right))
-    left++;
-  while ((data(splitDim, right) >= splitVal) && (left <= right))
-    right--;
-
-  while (left <= right)
-  {
-    // Swap columns.
-    data.swap_cols(left, right);
-
-    // See how many points on the left are correct.  When they are correct,
-    // increase the left counter accordingly.  When we encounter one that isn't
-    // correct, stop.  We will switch it later.
-    while ((data(splitDim, left) < splitVal) && (left <= right))
-      left++;
-
-    // Now see how many points on the right are correct.  When they are correct,
-    // decrease the right counter accordingly.  When we encounter one that isn't
-    // correct, stop.  We will switch it with the wrong point we found in the
-    // previous loop.
-    while ((data(splitDim, right) >= splitVal) && (left <= right))
-      right--;
-  }
-
-  Log::Assert(left == right + 1);
-
-  return left;
-}
-
-template<typename BoundType, typename StatisticType, typename MatType>
-size_t BinarySpaceTree<BoundType, StatisticType, MatType>::GetSplitIndex(
-    MatType& data,
-    int splitDim,
-    double splitVal,
-    std::vector<size_t>& oldFromNew)
-{
-  // This method modifies the input dataset.  We loop both from the left and
-  // right sides of the points contained in this node.  The points less than
-  // split_val should be on the left side of the matrix, and the points greater
-  // than split_val should be on the right side of the matrix.
-  size_t left = begin;
-  size_t right = begin + count - 1;
-
-  // First half-iteration of the loop is out here because the termination
-  // condition is in the middle.
-  while ((data(splitDim, left) < splitVal) && (left <= right))
-    left++;
-  while ((data(splitDim, right) >= splitVal) && (left <= right))
-    right--;
-
-  while (left <= right)
-  {
-    // Swap columns.
-    data.swap_cols(left, right);
-
-    // Update the indices for what we changed.
-    size_t t = oldFromNew[left];
-    oldFromNew[left] = oldFromNew[right];
-    oldFromNew[right] = t;
-
-    // See how many points on the left are correct.  When they are correct,
-    // increase the left counter accordingly.  When we encounter one that isn't
-    // correct, stop.  We will switch it later.
-    while ((data(splitDim, left) < splitVal) && (left <= right))
-      left++;
-
-    // Now see how many points on the right are correct.  When they are correct,
-    // decrease the right counter accordingly.  When we encounter one that isn't
-    // correct, stop.  We will switch it with the wrong point we found in the
-    // previous loop.
-    while ((data(splitDim, right) >= splitVal) && (left <= right))
-      right--;
-  }
-
-  Log::Assert(left == right + 1);
-
-  return left;
-}
-
 /**
  * Returns a string representation of this object.
  */
-template<typename BoundType, typename StatisticType, typename MatType>
-std::string BinarySpaceTree<BoundType, StatisticType, MatType>::ToString() const
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+std::string BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+    ToString() const
 {
   std::ostringstream convert;
   convert << "BinarySpaceTree [" << this << "]" << std::endl;

Modified: mlpack/trunk/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp	Mon Mar 24 14:52:43 2014
@@ -17,9 +17,13 @@
 namespace mlpack {
 namespace tree {
 
-template<typename BoundType, typename StatisticType, typename MatType>
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
 template<typename RuleType>
-class BinarySpaceTree<BoundType, StatisticType, MatType>::DualTreeTraverser
+class BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+    DualTreeTraverser
 {
  public:
   /**

Modified: mlpack/trunk/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp	Mon Mar 24 14:52:43 2014
@@ -15,9 +15,12 @@
 namespace mlpack {
 namespace tree {
 
-template<typename BoundType, typename StatisticType, typename MatType>
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
 template<typename RuleType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
 DualTreeTraverser<RuleType>::DualTreeTraverser(RuleType& rule) :
     rule(rule),
     numPrunes(0),
@@ -26,12 +29,16 @@
     numBaseCases(0)
 { /* Nothing to do. */ }
 
-template<typename BoundType, typename StatisticType, typename MatType>
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
 template<typename RuleType>
-void BinarySpaceTree<BoundType, StatisticType, MatType>::
+void BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
 DualTreeTraverser<RuleType>::Traverse(
-    BinarySpaceTree<BoundType, StatisticType, MatType>& queryNode,
-    BinarySpaceTree<BoundType, StatisticType, MatType>& referenceNode)
+    BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>& queryNode,
+    BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>&
+        referenceNode)
 {
   // Increment the visit counter.
   ++numVisited;

Added: mlpack/trunk/src/mlpack/core/tree/binary_space_tree/mean_split.hpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/core/tree/binary_space_tree/mean_split.hpp	Mon Mar 24 14:52:43 2014
@@ -0,0 +1,118 @@
+/**
+ * @file mean_split.hpp
+ * @author Yash Vadalia
+ * @author Ryan Curtin
+ *
+ * Definition of MeanSplit, a class that splits a binary space partitioning tree
+ * node into two parts at the midpoint of its bound in the widest dimension.
+ */
+#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_MEAN_SPLIT_HPP
+#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_MEAN_SPLIT_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace tree /** Trees and tree-building procedures. */ {
+
+/**
+ * Splits a binary space partitioning tree node into a left and a right child.
+ * The split dimension is the one in which the node's bound is widest, and the
+ * split value is the midpoint of the bound in that dimension (bound[d].Mid()).
+ */
+template<typename BoundType, typename MatType = arma::mat>
+class MeanSplit
+{
+ public:
+  /**
+   * Split the node at the midpoint of its bound in the widest dimension.
+   *
+   * @param bound The bound used for this node.
+   * @param data The dataset used by the binary space tree; its columns are
+   *    reordered by the split.
+   * @param begin Index of the first point of this node in the dataset.
+   * @param count Number of points in this node.
+   * @param splitDimension Filled with the dimension the node is split on.
+   * @param splitCol Filled with the split index: columns [begin, splitCol) go
+   *    to the left child and [splitCol, begin + count) to the right child.
+   * @return false (data left unchanged) if the bound has zero width in every
+   *    dimension; true once the points have been split.
+   */
+  static bool SplitNode(const BoundType& bound,
+                        MatType& data,
+                        const size_t begin,
+                        const size_t count,
+                        size_t& splitDimension,
+                        size_t& splitCol);
+
+  /**
+   * Split the node at the midpoint of its bound in the widest dimension, and
+   * keep oldFromNew consistent with the resulting reordering of points.
+   *
+   * @param bound The bound used for this node.
+   * @param data The dataset used by the binary space tree; its columns are
+   *    reordered by the split.
+   * @param begin Index of the first point of this node in the dataset.
+   * @param count Number of points in this node.
+   * @param splitDimension Filled with the dimension the node is split on.
+   * @param splitCol Filled with the split index: columns [begin, splitCol) go
+   *    to the left child and [splitCol, begin + count) to the right child.
+   * @param oldFromNew Maps each new point index to its old index; updated so
+   *    that it stays consistent with the reordering of data.
+   * @return true if the node was split; false if it could not be split.
+   */
+  static bool SplitNode(const BoundType& bound,
+                        MatType& data,
+                        const size_t begin,
+                        const size_t count,
+                        size_t& splitDimension,
+                        size_t& splitCol,
+                        std::vector<size_t>& oldFromNew);
+
+ private:
+  /**
+   * Reorder the points of this node into two parts, one on either side of the
+   * returned split column.
+   *
+   * @param data The dataset used by the binary space tree.
+   * @param begin Index of the first point of this node in the dataset.
+   * @param count Number of points in this node.
+   * @param splitDimension The dimension to split the node on.
+   * @param splitVal Value in dimension splitDimension that the points are
+   *    compared against.
+   * @return The index of the first column of the right part (splitCol).
+   */
+  static size_t PerformSplit(MatType& data,
+                             const size_t begin,
+                             const size_t count,
+                             const size_t splitDimension,
+                             const double splitVal);
+
+  /**
+   * Reorder the points of this node into two parts, one on either side of the
+   * returned split column, keeping oldFromNew consistent with the reordering.
+   *
+   * @param data The dataset used by the binary space tree.
+   * @param begin Index of the first point of this node in the dataset.
+   * @param count Number of points in this node.
+   * @param splitDimension The dimension to split the node on.
+   * @param splitVal Value in dimension splitDimension that the points are
+   *    compared against.
+   * @param oldFromNew Maps each new point index to its old index; updated so
+   *    that it stays consistent with the reordering of data.
+   * @return The index of the first column of the right part (splitCol).
+   */
+  static size_t PerformSplit(MatType& data,
+                             const size_t begin,
+                             const size_t count,
+                             const size_t splitDimension,
+                             const double splitVal,
+                             std::vector<size_t>& oldFromNew);
+};
+
+}; // namespace tree
+}; // namespace mlpack
+
+// Include implementation.
+#include "mean_split_impl.hpp"
+
+#endif

Added: mlpack/trunk/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp	Mon Mar 24 14:52:43 2014
@@ -0,0 +1,195 @@
+/**
+ * @file mean_split_impl.hpp
+ * @author Yash Vadalia
+ * @author Ryan Curtin
+ *
+ * Implementation of the MeanSplit class, which splits nodes of a binary space
+ * partitioning tree.
+ */
+#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_MEAN_SPLIT_IMPL_HPP
+#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_MEAN_SPLIT_IMPL_HPP
+
+#include "mean_split.hpp"
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * Split the node holding points [begin, begin + count) of data, whose bound is
+ * given, at the midpoint of the bound's widest dimension.  The columns of data
+ * in that range are reordered in place.
+ *
+ * @return false if the node cannot be split (splitCol is left untouched);
+ *    otherwise true, with splitDimension and splitCol set.
+ */
+template<typename BoundType, typename MatType>
+bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
+                                              MatType& data,
+                                              const size_t begin,
+                                              const size_t count,
+                                              size_t& splitDimension,
+                                              size_t& splitCol)
+{
+  // Find the dimension in which the bound is widest.  If no dimension gets
+  // selected, splitDimension stays at data.n_rows, which marks it as invalid.
+  splitDimension = data.n_rows;
+  double maxWidth = -1;
+
+  for (size_t d = 0; d < data.n_rows; d++)
+  {
+    const double width = bound[d].Width();
+
+    if (width > maxWidth)
+    {
+      maxWidth = width;
+      splitDimension = d;
+    }
+  }
+
+  // A non-positive maxWidth has two causes.  Either every point is identical
+  // in every dimension, or no dimension was selected at all (no rows, or every
+  // width was NaN).  In the second case splitDimension is still the invalid
+  // marker and bound[splitDimension] would be out of range.  Either way the
+  // node cannot be split.
+  if (maxWidth <= 0)
+    return false;
+
+  // Split in the middle of that dimension.
+  const double splitVal = bound[splitDimension].Mid();
+
+  // Perform the actual splitting.  This orders the dataset so that points with
+  // value in dimension splitDimension less than splitVal lie to the left of
+  // splitCol, and all other points lie at or to the right of splitCol.
+  splitCol = PerformSplit(data, begin, count, splitDimension, splitVal);
+
+  return true;
+}
+
+/**
+ * Split the node holding points [begin, begin + count) of data, whose bound is
+ * given, at the midpoint of the bound's widest dimension.  The columns of data
+ * in that range are reordered in place, and oldFromNew is passed through to
+ * PerformSplit, which permutes it along with those columns.
+ *
+ * @return false if the node cannot be split (splitCol and oldFromNew are left
+ *    untouched); otherwise true, with splitDimension and splitCol set.
+ */
+template<typename BoundType, typename MatType>
+bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
+                                              MatType& data,
+                                              const size_t begin,
+                                              const size_t count,
+                                              size_t& splitDimension,
+                                              size_t& splitCol,
+                                              std::vector<size_t>& oldFromNew)
+{
+  // Find the dimension in which the bound is widest.  If no dimension gets
+  // selected, splitDimension stays at data.n_rows, which marks it as invalid.
+  splitDimension = data.n_rows;
+  double maxWidth = -1;
+
+  for (size_t d = 0; d < data.n_rows; d++)
+  {
+    const double width = bound[d].Width();
+
+    if (width > maxWidth)
+    {
+      maxWidth = width;
+      splitDimension = d;
+    }
+  }
+
+  // A non-positive maxWidth has two causes.  Either every point is identical
+  // in every dimension, or no dimension was selected at all (no rows, or every
+  // width was NaN).  In the second case splitDimension is still the invalid
+  // marker and bound[splitDimension] would be out of range.  Either way the
+  // node cannot be split.
+  if (maxWidth <= 0)
+    return false;
+
+  // Split in the middle of that dimension.
+  const double splitVal = bound[splitDimension].Mid();
+
+  // Perform the actual splitting.  This orders the dataset so that points with
+  // value in dimension splitDimension less than splitVal lie to the left of
+  // splitCol, and all other points lie at or to the right of splitCol.
+  splitCol = PerformSplit(data, begin, count, splitDimension, splitVal,
+      oldFromNew);
+
+  return true;
+}
+
+template<typename BoundType, typename MatType>
+size_t MeanSplit<BoundType, MatType>::
+    PerformSplit(MatType& data,
+                 const size_t begin,
+                 const size_t count,
+                 const size_t splitDimension,
+                 const double splitVal)
+{
+  // This method modifies the input dataset in place.  Columns in
+  // [begin, begin + count) are partitioned so that every point whose value in
+  // dimension splitDimension is less than splitVal comes first.  All other
+  // points follow: values >= splitVal, and also NaN, which is never less than
+  // anything.  The returned index is the first column of that second group.
+  //
+  // 'right' is kept one past the last unexamined column, so it can never
+  // underflow (for example when begin == 0 or count == 0).  Every bounds check
+  // happens before the matrix is indexed, so no column outside the node is
+  // ever read.
+  size_t left = begin;
+  size_t right = begin + count;
+
+  while (true)
+  {
+    // Skip the points on the left that are already on the correct side.
+    while ((left < right) && (data(splitDimension, left) < splitVal))
+      ++left;
+
+    // Skip the points on the right that are already on the correct side.
+    while ((left < right) && !(data(splitDimension, right - 1) < splitVal))
+      --right;
+
+    // Every point has been placed.
+    if (left >= right)
+      break;
+
+    // Column 'left' belongs on the right and column 'right - 1' belongs on the
+    // left, and they are distinct columns.  Swapping them places both
+    // correctly, so both cursors can move past them.
+    data.swap_cols(left, right - 1);
+    ++left;
+    --right;
+  }
+
+  Log::Assert(left == right);
+
+  return left;
+}
+
+template<typename BoundType, typename MatType>
+size_t MeanSplit<BoundType, MatType>::
+    PerformSplit(MatType& data,
+                 const size_t begin,
+                 const size_t count,
+                 const size_t splitDimension,
+                 const double splitVal,
+                 std::vector<size_t>& oldFromNew)
+{
+  // This method modifies the input dataset in place.  Columns in
+  // [begin, begin + count) are partitioned so that every point whose value in
+  // dimension splitDimension is less than splitVal comes first.  All other
+  // points follow: values >= splitVal, and also NaN, which is never less than
+  // anything.  The returned index is the first column of that second group.
+  // Every column swap is mirrored in oldFromNew, which is not resized, so it
+  // must already hold at least begin + count entries.
+  //
+  // 'right' is kept one past the last unexamined column, so it can never
+  // underflow (for example when begin == 0 or count == 0).  Every bounds check
+  // happens before the matrix is indexed, so no column outside the node is
+  // ever read.
+  size_t left = begin;
+  size_t right = begin + count;
+
+  while (true)
+  {
+    // Skip the points on the left that are already on the correct side.
+    while ((left < right) && (data(splitDimension, left) < splitVal))
+      ++left;
+
+    // Skip the points on the right that are already on the correct side.
+    while ((left < right) && !(data(splitDimension, right - 1) < splitVal))
+      --right;
+
+    // Every point has been placed.
+    if (left >= right)
+      break;
+
+    // Column 'left' belongs on the right and column 'right - 1' belongs on the
+    // left, and they are distinct columns.  Swapping them places both
+    // correctly, so both cursors can move past them.
+    data.swap_cols(left, right - 1);
+
+    // Keep the index mapping in step with the column swap.
+    const size_t t = oldFromNew[left];
+    oldFromNew[left] = oldFromNew[right - 1];
+    oldFromNew[right - 1] = t;
+
+    ++left;
+    --right;
+  }
+
+  Log::Assert(left == right);
+
+  return left;
+}
+
+}; // namespace tree
+}; // namespace mlpack
+
+#endif

Modified: mlpack/trunk/src/mlpack/core/tree/binary_space_tree/single_tree_traverser.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/binary_space_tree/single_tree_traverser.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/tree/binary_space_tree/single_tree_traverser.hpp	Mon Mar 24 14:52:43 2014
@@ -16,9 +16,13 @@
 namespace mlpack {
 namespace tree {
 
-template<typename BoundType, typename StatisticType, typename MatType>
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
 template<typename RuleType>
-class BinarySpaceTree<BoundType, StatisticType, MatType>::SingleTreeTraverser
+class BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+    SingleTreeTraverser
 {
  public:
   /**

Modified: mlpack/trunk/src/mlpack/core/tree/binary_space_tree/single_tree_traverser_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/binary_space_tree/single_tree_traverser_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/tree/binary_space_tree/single_tree_traverser_impl.hpp	Mon Mar 24 14:52:43 2014
@@ -17,20 +17,27 @@
 namespace mlpack {
 namespace tree {
 
+// Construct the single-tree traverser for the given rule; the prune counter
+// starts at zero.
-template<typename BoundType, typename StatisticType, typename MatType>
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
 template<typename RuleType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
 SingleTreeTraverser<RuleType>::SingleTreeTraverser(RuleType& rule) :
     rule(rule),
     numPrunes(0)
 { /* Nothing to do. */ }
 
-template<typename BoundType, typename StatisticType, typename MatType>
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
 template<typename RuleType>
-void BinarySpaceTree<BoundType, StatisticType, MatType>::
+void BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
 SingleTreeTraverser<RuleType>::Traverse(
     const size_t queryIndex,
-    BinarySpaceTree<BoundType, StatisticType, MatType>& referenceNode)
+    BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>&
+        referenceNode)
 {
   // If we are a leaf, run the base case as necessary.
   if (referenceNode.IsLeaf())



More information about the mlpack-svn mailing list