[mlpack-git] master, mlpack-1.0.x: Incorporate patch from yashdv to move splitting procedure to a different class. (a2feea8)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:45:27 EST 2015
Repository : https://github.com/mlpack/mlpack
On branches: master,mlpack-1.0.x
Link : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40
>---------------------------------------------------------------
commit a2feea8ccd1db9244a426aa007e63205acc4aa75
Author: Yash Vadalia <yashdv at gmail.com>
Date: Mon Mar 24 18:52:43 2014 +0000
Incorporate patch from yashdv to move splitting procedure to a different class.
>---------------------------------------------------------------
a2feea8ccd1db9244a426aa007e63205acc4aa75
src/mlpack/core/tree/CMakeLists.txt | 2 +
.../tree/binary_space_tree/binary_space_tree.hpp | 30 +-
.../binary_space_tree/binary_space_tree_impl.hpp | 375 +++++++++------------
.../tree/binary_space_tree/dual_tree_traverser.hpp | 8 +-
.../binary_space_tree/dual_tree_traverser_impl.hpp | 19 +-
.../core/tree/binary_space_tree/mean_split.hpp | 118 +++++++
.../tree/binary_space_tree/mean_split_impl.hpp | 195 +++++++++++
.../binary_space_tree/single_tree_traverser.hpp | 8 +-
.../single_tree_traverser_impl.hpp | 17 +-
9 files changed, 527 insertions(+), 245 deletions(-)
diff --git a/src/mlpack/core/tree/CMakeLists.txt b/src/mlpack/core/tree/CMakeLists.txt
index b090514..85c1d8d 100644
--- a/src/mlpack/core/tree/CMakeLists.txt
+++ b/src/mlpack/core/tree/CMakeLists.txt
@@ -7,6 +7,8 @@ set(SOURCES
binary_space_tree/binary_space_tree_impl.hpp
binary_space_tree/dual_tree_traverser.hpp
binary_space_tree/dual_tree_traverser_impl.hpp
+ binary_space_tree/mean_split.hpp
+ binary_space_tree/mean_split_impl.hpp
binary_space_tree/single_tree_traverser.hpp
binary_space_tree/single_tree_traverser_impl.hpp
binary_space_tree/traits.hpp
diff --git a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
index b169d86..183c711 100644
--- a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
@@ -7,6 +7,7 @@
#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_BINARY_SPACE_TREE_HPP
#include <mlpack/core.hpp>
+#include "mean_split.hpp"
#include "../statistic.hpp"
@@ -30,10 +31,15 @@ namespace tree /** Trees and tree-building procedures. */ {
* bounds/.
* @tparam StatisticType Extra data contained in the node. See statistic.hpp
* for the necessary skeleton interface.
+ * @tparam MatType The dataset class.
+ * @tparam SplitType The class that partitions the dataset/points at a
+ * particular node into two parts. Its definition decides the way this split
+ * is done.
*/
template<typename BoundType,
typename StatisticType = EmptyStatistic,
- typename MatType = arma::mat>
+ typename MatType = arma::mat,
+ typename SplitType = MeanSplit<BoundType, MatType> >
class BinarySpaceTree
{
private:
@@ -463,28 +469,6 @@ class BinarySpaceTree
*/
void SplitNode(MatType& data, std::vector<size_t>& oldFromNew);
- /**
- * Find the index to split on for this node, given that we are splitting in
- * the given split dimension on the specified split value.
- *
- * @param data Dataset which we are using.
- * @param splitDim Dimension of dataset to split on.
- * @param splitVal Value to split on, in the given split dimension.
- */
- size_t GetSplitIndex(MatType& data, int splitDim, double splitVal);
-
- /**
- * Find the index to split on for this node, given that we are splitting in
- * the given split dimension on the specified split value. Also returns a
- * list of the changed indices.
- *
- * @param data Dataset which we are using.
- * @param splitDim Dimension of dataset to split on.
- * @param splitVal Value to split on, in the given split dimension.
- * @param oldFromNew Vector holding permuted indices.
- */
- size_t GetSplitIndex(MatType& data, int splitDim, double splitVal,
- std::vector<size_t>& oldFromNew);
public:
/**
* Returns a string representation of this object.
diff --git a/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp b/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
index 26f2336..25ed4f5 100644
--- a/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
@@ -18,8 +18,11 @@ namespace tree {
// Each of these overloads is kept as a separate function to keep the overhead
// from the two std::vectors out, if possible.
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
MatType& data,
const size_t leafSize) :
left(NULL),
@@ -39,8 +42,11 @@ BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
stat = StatisticType(*this);
}
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
MatType& data,
std::vector<size_t>& oldFromNew,
const size_t leafSize) :
@@ -66,8 +72,11 @@ BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
stat = StatisticType(*this);
}
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
MatType& data,
std::vector<size_t>& oldFromNew,
std::vector<size_t>& newFromOld,
@@ -99,8 +108,11 @@ BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
newFromOld[oldFromNew[i]] = i;
}
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
MatType& data,
const size_t begin,
const size_t count,
@@ -122,8 +134,11 @@ BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
stat = StatisticType(*this);
}
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
MatType& data,
const size_t begin,
const size_t count,
@@ -150,8 +165,11 @@ BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
stat = StatisticType(*this);
}
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
MatType& data,
const size_t begin,
const size_t count,
@@ -203,8 +221,11 @@ BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree() :
* Create a binary space tree by copying the other tree. Be careful! This can
* take a long time and use a lot of memory.
*/
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
const BinarySpaceTree& other) :
left(NULL),
right(NULL),
@@ -238,8 +259,12 @@ BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
* destructors in turn. This will invalidate any pointers or references to any
* nodes which are children of this one.
*/
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::~BinarySpaceTree()
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+ ~BinarySpaceTree()
{
if (left)
delete left;
@@ -258,9 +283,12 @@ BinarySpaceTree<BoundType, StatisticType, MatType>::~BinarySpaceTree()
* @param queryCount The Count() of the node to find.
* @return The found node, or NULL if nothing is found.
*/
-template<typename BoundType, typename StatisticType, typename MatType>
-const BinarySpaceTree<BoundType, StatisticType, MatType>*
-BinarySpaceTree<BoundType, StatisticType, MatType>::FindByBeginCount(
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
+const BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>*
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::FindByBeginCount(
size_t queryBegin,
size_t queryCount) const
{
@@ -288,9 +316,12 @@ BinarySpaceTree<BoundType, StatisticType, MatType>::FindByBeginCount(
* @param queryCount the Count() of the node to find
* @return the found node, or NULL
*/
-template<typename BoundType, typename StatisticType, typename MatType>
-BinarySpaceTree<BoundType, StatisticType, MatType>*
-BinarySpaceTree<BoundType, StatisticType, MatType>::FindByBeginCount(
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>*
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::FindByBeginCount(
const size_t queryBegin,
const size_t queryCount)
{
@@ -309,8 +340,11 @@ BinarySpaceTree<BoundType, StatisticType, MatType>::FindByBeginCount(
return NULL;
}
-template<typename BoundType, typename StatisticType, typename MatType>
-size_t BinarySpaceTree<BoundType, StatisticType, MatType>::ExtendTree(
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
+size_t BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::ExtendTree(
size_t level)
{
--level;
@@ -340,16 +374,24 @@ size_t BinarySpaceTree<BoundType, StatisticType, MatType>::ExtendTree(
* to avoid exceeding the stack limit
*/
-template<typename BoundType, typename StatisticType, typename MatType>
-size_t BinarySpaceTree<BoundType, StatisticType, MatType>::TreeSize() const
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
+size_t BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+ TreeSize() const
{
// Recursively count the nodes on each side of the tree. The plus one is
// because we have to count this node, too.
return 1 + (left ? left->TreeSize() : 0) + (right ? right->TreeSize() : 0);
}
-template<typename BoundType, typename StatisticType, typename MatType>
-size_t BinarySpaceTree<BoundType, StatisticType, MatType>::TreeDepth() const
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
+size_t BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+ TreeDepth() const
{
// Recursively count the depth on each side of the tree. The plus one is
// because we have to count this node, too.
@@ -357,8 +399,12 @@ size_t BinarySpaceTree<BoundType, StatisticType, MatType>::TreeDepth() const
(right ? right->TreeDepth() : 0));
}
-template<typename BoundType, typename StatisticType, typename MatType>
-inline bool BinarySpaceTree<BoundType, StatisticType, MatType>::IsLeaf() const
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
+inline bool BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+ IsLeaf() const
{
return !left;
}
@@ -366,9 +412,12 @@ inline bool BinarySpaceTree<BoundType, StatisticType, MatType>::IsLeaf() const
/**
* Returns the number of children in this node.
*/
-template<typename BoundType, typename StatisticType, typename MatType>
-inline size_t
- BinarySpaceTree<BoundType, StatisticType, MatType>::NumChildren() const
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
+inline size_t BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+ NumChildren() const
{
if (left && right)
return 2;
@@ -382,8 +431,11 @@ inline size_t
* Return a bound on the furthest point in the node from the centroid. This
* returns 0 unless the node is a leaf.
*/
-template<typename BoundType, typename StatisticType, typename MatType>
-inline double BinarySpaceTree<BoundType, StatisticType, MatType>::
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
+inline double BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
FurthestPointDistance() const
{
if (IsLeaf())
@@ -401,8 +453,11 @@ inline double BinarySpaceTree<BoundType, StatisticType, MatType>::
* furthest descendant distance may be less than what this method returns (but
* it will never be greater than this).
*/
-template<typename BoundType, typename StatisticType, typename MatType>
-inline double BinarySpaceTree<BoundType, StatisticType, MatType>::
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
+inline double BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
FurthestDescendantDistance() const
{
return furthestDescendantDistance;
@@ -411,10 +466,13 @@ inline double BinarySpaceTree<BoundType, StatisticType, MatType>::
/**
* Return the specified child.
*/
-template<typename BoundType, typename StatisticType, typename MatType>
-inline BinarySpaceTree<BoundType, StatisticType, MatType>&
- BinarySpaceTree<BoundType, StatisticType, MatType>::Child(
- const size_t child) const
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
+inline BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>&
+ BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+ Child(const size_t child) const
{
if (child == 0)
return *left;
@@ -425,9 +483,12 @@ inline BinarySpaceTree<BoundType, StatisticType, MatType>&
/**
* Return the number of points contained in this node.
*/
-template<typename BoundType, typename StatisticType, typename MatType>
-inline size_t
-BinarySpaceTree<BoundType, StatisticType, MatType>::NumPoints() const
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
+inline size_t BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+ NumPoints() const
{
if (left)
return 0;
@@ -438,9 +499,12 @@ BinarySpaceTree<BoundType, StatisticType, MatType>::NumPoints() const
/**
* Return the number of descendants contained in the node.
*/
-template<typename BoundType, typename StatisticType, typename MatType>
-inline size_t
-BinarySpaceTree<BoundType, StatisticType, MatType>::NumDescendants() const
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
+inline size_t BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+ NumDescendants() const
{
return count;
}
@@ -448,10 +512,12 @@ BinarySpaceTree<BoundType, StatisticType, MatType>::NumDescendants() const
/**
* Return the index of a particular descendant contained in this node.
*/
-template<typename BoundType, typename StatisticType, typename MatType>
-inline size_t
-BinarySpaceTree<BoundType, StatisticType, MatType>::Descendant(
- const size_t index) const
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
+inline size_t BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+ Descendant(const size_t index) const
{
return (begin + index);
}
@@ -459,10 +525,12 @@ BinarySpaceTree<BoundType, StatisticType, MatType>::Descendant(
/**
* Return the index of a particular point contained in this node.
*/
-template<typename BoundType, typename StatisticType, typename MatType>
-inline size_t
-BinarySpaceTree<BoundType, StatisticType, MatType>::Point(const size_t index)
- const
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
+inline size_t BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+ Point(const size_t index) const
{
return (begin + index);
}
@@ -470,14 +538,21 @@ BinarySpaceTree<BoundType, StatisticType, MatType>::Point(const size_t index)
/**
* Gets the index one beyond the last index in the series.
*/
-template<typename BoundType, typename StatisticType, typename MatType>
-inline size_t BinarySpaceTree<BoundType, StatisticType, MatType>::End() const
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
+inline size_t BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+ End() const
{
return begin + count;
}
-template<typename BoundType, typename StatisticType, typename MatType>
-void BinarySpaceTree<BoundType, StatisticType, MatType>::SplitNode(
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
+void BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::SplitNode(
MatType& data)
{
// We need to expand the bounds of this node properly.
@@ -490,35 +565,21 @@ void BinarySpaceTree<BoundType, StatisticType, MatType>::SplitNode(
if (count <= leafSize)
return; // We can't split this.
- // Figure out which dimension to split on.
- size_t splitDim = data.n_rows; // Indicate invalid by maxDim + 1.
- double maxWidth = -1;
-
- // Find the split dimension.
- for (size_t d = 0; d < data.n_rows; d++)
- {
- double width = bound[d].Width();
-
- if (width > maxWidth)
- {
- maxWidth = width;
- splitDim = d;
- }
- }
- splitDimension = splitDim;
+ // splitCol denotes the two partitions of the dataset after the split. The
+ // points on its left go to the left child and the others go to the right
+ // child.
+ size_t splitCol;
- // Split in the middle of that dimension.
- double splitVal = bound[splitDim].Mid();
+ // Split the node. The elements of 'data' are reordered by the splitting
+ // algorithm. This function call updates splitDimension and splitCol.
+ const bool split = SplitType::SplitNode(bound, data, begin, count,
+ splitDimension, splitCol);
- if (maxWidth == 0) // All these points are the same. We can't split.
+ // The node may not always be split. For instance, if all the points are the
+ // same, we can't split them.
+ if (!split)
return;
- // Perform the actual splitting. This will order the dataset such that points
- // with value in dimension split_dim less than or equal to splitVal are on
- // the left of splitCol, and points with value in dimension splitDim greater
- // than splitVal are on the right side of splitCol.
- size_t splitCol = GetSplitIndex(data, splitDim, splitVal);
-
// Now that we know the split column, we will recursively split the children
// by calling their constructors (which perform this splitting process).
left = new BinarySpaceTree<BoundType, StatisticType, MatType>(data, begin,
@@ -541,8 +602,11 @@ void BinarySpaceTree<BoundType, StatisticType, MatType>::SplitNode(
right->ParentDistance() = rightParentDistance;
}
-template<typename BoundType, typename StatisticType, typename MatType>
-void BinarySpaceTree<BoundType, StatisticType, MatType>::SplitNode(
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
+void BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::SplitNode(
MatType& data,
std::vector<size_t>& oldFromNew)
{
@@ -557,35 +621,22 @@ void BinarySpaceTree<BoundType, StatisticType, MatType>::SplitNode(
if (count <= leafSize)
return; // We can't split this.
- // Figure out which dimension to split on.
- size_t splitDim = data.n_rows; // Indicate invalid by max_dim + 1.
- double maxWidth = -1;
-
- // Find the split dimension.
- for (size_t d = 0; d < data.n_rows; d++)
- {
- double width = bound[d].Width();
-
- if (width > maxWidth)
- {
- maxWidth = width;
- splitDim = d;
- }
- }
- splitDimension = splitDim;
+ // splitCol denotes the two partitions of the dataset after the split. The
+ // points on its left go to the left child and the others go to the right
+ // child.
+ size_t splitCol;
- // Split in the middle of that dimension.
- double splitVal = bound[splitDim].Mid();
+ // Split the node. The elements of 'data' are reordered by the splitting
+ // algorithm. This function call updates splitDimension, splitCol and
+ // oldFromNew.
+ const bool split = SplitType::SplitNode(bound, data, begin, count,
+ splitDimension, splitCol, oldFromNew);
- if (maxWidth == 0) // All these points are the same. We can't split.
+ // The node may not always be split. For instance, if all the points are the
+ // same, we can't split them.
+ if (!split)
return;
- // Perform the actual splitting. This will order the dataset such that points
- // with value in dimension split_dim less than or equal to splitVal are on
- // the left of splitCol, and points with value in dimension splitDim greater
- // than splitVal are on the right side of splitCol.
- size_t splitCol = GetSplitIndex(data, splitDim, splitVal, oldFromNew);
-
// Now that we know the split column, we will recursively split the children
// by calling their constructors (which perform this splitting process).
left = new BinarySpaceTree<BoundType, StatisticType, MatType>(data, begin,
@@ -608,105 +659,15 @@ void BinarySpaceTree<BoundType, StatisticType, MatType>::SplitNode(
right->ParentDistance() = rightParentDistance;
}
-template<typename BoundType, typename StatisticType, typename MatType>
-size_t BinarySpaceTree<BoundType, StatisticType, MatType>::GetSplitIndex(
- MatType& data,
- int splitDim,
- double splitVal)
-{
- // This method modifies the input dataset. We loop both from the left and
- // right sides of the points contained in this node. The points less than
- // split_val should be on the left side of the matrix, and the points greater
- // than split_val should be on the right side of the matrix.
- size_t left = begin;
- size_t right = begin + count - 1;
-
- // First half-iteration of the loop is out here because the termination
- // condition is in the middle.
- while ((data(splitDim, left) < splitVal) && (left <= right))
- left++;
- while ((data(splitDim, right) >= splitVal) && (left <= right))
- right--;
-
- while (left <= right)
- {
- // Swap columns.
- data.swap_cols(left, right);
-
- // See how many points on the left are correct. When they are correct,
- // increase the left counter accordingly. When we encounter one that isn't
- // correct, stop. We will switch it later.
- while ((data(splitDim, left) < splitVal) && (left <= right))
- left++;
-
- // Now see how many points on the right are correct. When they are correct,
- // decrease the right counter accordingly. When we encounter one that isn't
- // correct, stop. We will switch it with the wrong point we found in the
- // previous loop.
- while ((data(splitDim, right) >= splitVal) && (left <= right))
- right--;
- }
-
- Log::Assert(left == right + 1);
-
- return left;
-}
-
-template<typename BoundType, typename StatisticType, typename MatType>
-size_t BinarySpaceTree<BoundType, StatisticType, MatType>::GetSplitIndex(
- MatType& data,
- int splitDim,
- double splitVal,
- std::vector<size_t>& oldFromNew)
-{
- // This method modifies the input dataset. We loop both from the left and
- // right sides of the points contained in this node. The points less than
- // split_val should be on the left side of the matrix, and the points greater
- // than split_val should be on the right side of the matrix.
- size_t left = begin;
- size_t right = begin + count - 1;
-
- // First half-iteration of the loop is out here because the termination
- // condition is in the middle.
- while ((data(splitDim, left) < splitVal) && (left <= right))
- left++;
- while ((data(splitDim, right) >= splitVal) && (left <= right))
- right--;
-
- while (left <= right)
- {
- // Swap columns.
- data.swap_cols(left, right);
-
- // Update the indices for what we changed.
- size_t t = oldFromNew[left];
- oldFromNew[left] = oldFromNew[right];
- oldFromNew[right] = t;
-
- // See how many points on the left are correct. When they are correct,
- // increase the left counter accordingly. When we encounter one that isn't
- // correct, stop. We will switch it later.
- while ((data(splitDim, left) < splitVal) && (left <= right))
- left++;
-
- // Now see how many points on the right are correct. When they are correct,
- // decrease the right counter accordingly. When we encounter one that isn't
- // correct, stop. We will switch it with the wrong point we found in the
- // previous loop.
- while ((data(splitDim, right) >= splitVal) && (left <= right))
- right--;
- }
-
- Log::Assert(left == right + 1);
-
- return left;
-}
-
/**
* Returns a string representation of this object.
*/
-template<typename BoundType, typename StatisticType, typename MatType>
-std::string BinarySpaceTree<BoundType, StatisticType, MatType>::ToString() const
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
+std::string BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+ ToString() const
{
std::ostringstream convert;
convert << "BinarySpaceTree [" << this << "]" << std::endl;
diff --git a/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp b/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp
index 2d3a0c1..7cd1871 100644
--- a/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp
@@ -17,9 +17,13 @@
namespace mlpack {
namespace tree {
-template<typename BoundType, typename StatisticType, typename MatType>
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
template<typename RuleType>
-class BinarySpaceTree<BoundType, StatisticType, MatType>::DualTreeTraverser
+class BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+ DualTreeTraverser
{
public:
/**
diff --git a/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp
index d590e93..1a1d353 100644
--- a/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp
@@ -15,9 +15,12 @@
namespace mlpack {
namespace tree {
-template<typename BoundType, typename StatisticType, typename MatType>
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
template<typename RuleType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
DualTreeTraverser<RuleType>::DualTreeTraverser(RuleType& rule) :
rule(rule),
numPrunes(0),
@@ -26,12 +29,16 @@ DualTreeTraverser<RuleType>::DualTreeTraverser(RuleType& rule) :
numBaseCases(0)
{ /* Nothing to do. */ }
-template<typename BoundType, typename StatisticType, typename MatType>
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
template<typename RuleType>
-void BinarySpaceTree<BoundType, StatisticType, MatType>::
+void BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
DualTreeTraverser<RuleType>::Traverse(
- BinarySpaceTree<BoundType, StatisticType, MatType>& queryNode,
- BinarySpaceTree<BoundType, StatisticType, MatType>& referenceNode)
+ BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>& queryNode,
+ BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>&
+ referenceNode)
{
// Increment the visit counter.
++numVisited;
diff --git a/src/mlpack/core/tree/binary_space_tree/mean_split.hpp b/src/mlpack/core/tree/binary_space_tree/mean_split.hpp
new file mode 100644
index 0000000..c6e3c04
--- /dev/null
+++ b/src/mlpack/core/tree/binary_space_tree/mean_split.hpp
@@ -0,0 +1,118 @@
+/**
+ * @file mean_split.hpp
+ * @author Yash Vadalia
+ * @author Ryan Curtin
+ *
+ * Definition of MeanSplit, a class that splits a binary space partitioning tree
+ * node into two parts at the midpoint of its bound in its widest dimension.
+ */
+#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_MEAN_SPLIT_HPP
+#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_MEAN_SPLIT_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace tree /** Trees and tree-building procedures. */ {
+
+/**
+ * Splits a binary space partitioning tree node into its left and right children.
+ * The split is done in the dimension that has the maximum width. The points are
+ * divided into two parts at the midpoint of the node's bound in this dimension.
+ */
+template<typename BoundType, typename MatType = arma::mat>
+class MeanSplit
+{
+ public:
+ /**
+ * Split the node at the midpoint of its bound in the dimension with maximum
+ * width. Returns true if the node was split, false if the maximum width is 0.
+ *
+ * @param bound The bound used for this node.
+ * @param data The dataset used by the binary space tree.
+ * @param begin Index of the starting point in the dataset that belongs to
+ * this node.
+ * @param count Number of points in this node.
+ * @param splitDimension This will be filled with the dimension the node is to
+ * be split on.
+ * @param splitCol The index at which the dataset is divided into two parts
+ * after the rearrangement.
+ */
+ static bool SplitNode(const BoundType& bound,
+ MatType& data,
+ const size_t begin,
+ const size_t count,
+ size_t& splitDimension,
+ size_t& splitCol);
+
+ /**
+ * Split the node at the midpoint of its bound in the dimension with maximum
+ * width, also updating oldFromNew. Returns whether the node was split.
+ *
+ * @param bound The bound used for this node.
+ * @param data The dataset used by the binary space tree.
+ * @param begin Index of the starting point in the dataset that belongs to
+ * this node.
+ * @param count Number of points in this node.
+ * @param splitDimension This will be filled with the dimension the node is
+ * to be split on.
+ * @param splitCol The index at which the dataset is divided into two parts
+ * after the rearrangement.
+ * @param oldFromNew Vector mapping each new point position to its old one;
+ * it is updated as the points of this node are reordered.
+ */
+ static bool SplitNode(const BoundType& bound,
+ MatType& data,
+ const size_t begin,
+ const size_t count,
+ size_t& splitDimension,
+ size_t& splitCol,
+ std::vector<size_t>& oldFromNew);
+
+ private:
+ /**
+ * Reorder the points of this node in data into two parts and return the
+ * column index (splitCol) that separates them.
+ *
+ * @param data The dataset used by the binary space tree.
+ * @param begin Index of the starting point in the dataset that belongs to
+ * this node.
+ * @param count Number of points in this node.
+ * @param splitDimension The dimension to split the node on.
+ * @param splitVal The split in dimension splitDimension is based on this
+ * value.
+ */
+ static size_t PerformSplit(MatType& data,
+ const size_t begin,
+ const size_t count,
+ const size_t splitDimension,
+ const double splitVal);
+
+ /**
+ * Reorder the points of this node in data into two parts and return the
+ * column index that separates them, updating oldFromNew to match.
+ *
+ * @param data The dataset used by the binary space tree.
+ * @param begin Index of the starting point in the dataset that belongs to
+ * this node.
+ * @param count Number of points in this node.
+ * @param splitDimension The dimension to split the node on.
+ * @param splitVal The split in dimension splitDimension is based on this
+ * value.
+ * @param oldFromNew Vector mapping each new point position to its old one;
+ * it is updated as the points of this node are reordered.
+ */
+ static size_t PerformSplit(MatType& data,
+ const size_t begin,
+ const size_t count,
+ const size_t splitDimension,
+ const double splitVal,
+ std::vector<size_t>& oldFromNew);
+};
+
+}; // namespace tree
+}; // namespace mlpack
+
+// Include implementation.
+#include "mean_split_impl.hpp"
+
+#endif
diff --git a/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp b/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp
new file mode 100644
index 0000000..cf0616e
--- /dev/null
+++ b/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp
@@ -0,0 +1,195 @@
+/**
+ * @file mean_split_impl.hpp
+ * @author Yash Vadalia
+ * @author Ryan Curtin
+ *
+ * Implementation of the MeanSplit class, which splits binary space tree nodes.
+ */
+#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_MEAN_SPLIT_IMPL_HPP
+#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_MEAN_SPLIT_IMPL_HPP
+
+#include "mean_split.hpp"
+
+namespace mlpack {
+namespace tree {
+
+template<typename BoundType, typename MatType>
+bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
+                                              MatType& data,
+                                              const size_t begin,
+                                              const size_t count,
+                                              size_t& splitDimension,
+                                              size_t& splitCol)
+{
+  splitDimension = data.n_rows; // Indicate invalid.
+  double maxWidth = -1;
+
+  // Find the split dimension: the one in which the bound is widest.
+  for (size_t d = 0; d < data.n_rows; d++)
+  {
+    const double width = bound[d].Width();
+
+    if (width > maxWidth)
+    {
+      maxWidth = width;
+      splitDimension = d;
+    }
+  }
+
+  // maxWidth == 0: all points are the same.  maxWidth < 0: no dimension was
+  // chosen (e.g. data.n_rows == 0), so splitDimension is invalid.  No split.
+  if (maxWidth <= 0)
+    return false;
+
+  // Split in the middle of that dimension.
+  const double splitVal = bound[splitDimension].Mid();
+
+  // Reorder the node: values < splitVal in splitDimension go to
+  // [begin, splitCol), values >= splitVal go to [splitCol, begin + count).
+  splitCol = PerformSplit(data, begin, count, splitDimension, splitVal);
+
+  return true;
+}
+
+template<typename BoundType, typename MatType>
+bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
+                                              MatType& data,
+                                              const size_t begin,
+                                              const size_t count,
+                                              size_t& splitDimension,
+                                              size_t& splitCol,
+                                              std::vector<size_t>& oldFromNew)
+{
+  splitDimension = data.n_rows; // Indicate invalid.
+  double maxWidth = -1;
+
+  // Find the split dimension: the one in which the bound is widest.
+  for (size_t d = 0; d < data.n_rows; d++)
+  {
+    const double width = bound[d].Width();
+
+    if (width > maxWidth)
+    {
+      maxWidth = width;
+      splitDimension = d;
+    }
+  }
+
+  // maxWidth == 0: all points are the same.  maxWidth < 0: no dimension was
+  // chosen (e.g. data.n_rows == 0), so splitDimension is invalid.  No split.
+  if (maxWidth <= 0)
+    return false;
+
+  // Split in the middle of that dimension.
+  const double splitVal = bound[splitDimension].Mid();
+
+  // Reorder the node (and oldFromNew): values < splitVal in splitDimension go
+  // to [begin, splitCol), values >= splitVal go to [splitCol, begin + count).
+  splitCol = PerformSplit(data, begin, count, splitDimension, splitVal,
+                          oldFromNew);
+
+  return true;
+}
+
+template<typename BoundType, typename MatType>
+size_t MeanSplit<BoundType, MatType>::
+    PerformSplit(MatType& data,
+                 const size_t begin,
+                 const size_t count,
+                 const size_t splitDimension,
+                 const double splitVal)
+{
+  // This method modifies the input dataset.  We scan inward from both ends of
+  // the columns owned by this node, [begin, begin + count).  Points with value
+  // in dimension splitDimension less than splitVal belong on the left side;
+  // points with value greater than or equal to splitVal belong on the right.
+  //
+  // 'right' is one past the last column not yet known to belong on the right,
+  // so it never wraps around below 'begin' (even when begin == 0 and every
+  // point belongs on the right, or when count == 0), and each loop tests its
+  // bounds before touching data, so no column outside of the node is read.
+  size_t left = begin;
+  size_t right = begin + count;
+
+  while (true)
+  {
+    // See how many points on the left are correct.  When we encounter one
+    // that isn't correct, stop; it will be swapped below.
+    while ((left < right) && (data(splitDimension, left) < splitVal))
+      ++left;
+
+    // Now see how many points on the right are correct.  When we encounter
+    // one that isn't correct, stop; it will be swapped with the wrong point
+    // found by the previous loop.
+    while ((left < right) && (data(splitDimension, right - 1) >= splitVal))
+      --right;
+
+    // Every point has been examined; [begin, left) is the left side.
+    if (left == right)
+      break;
+
+    // Here data(splitDimension, left) >= splitVal and
+    // data(splitDimension, right - 1) < splitVal, so exchange them.
+    data.swap_cols(left, right - 1);
+  }
+
+  // First right-side column; begin + count if no point belongs on the right.
+  return left;
+}
+
+template<typename BoundType, typename MatType>
+size_t MeanSplit<BoundType, MatType>::
+    PerformSplit(MatType& data,
+                 const size_t begin,
+                 const size_t count,
+                 const size_t splitDimension,
+                 const double splitVal,
+                 std::vector<size_t>& oldFromNew)
+{
+  // This method modifies the input dataset.  We scan inward from both ends of
+  // the columns owned by this node, [begin, begin + count).  Points with value
+  // in dimension splitDimension less than splitVal belong on the left side;
+  // points with value greater than or equal to splitVal belong on the right.
+  //
+  // 'right' is one past the last column not yet known to belong on the right,
+  // so it never wraps around below 'begin' (even when begin == 0 and every
+  // point belongs on the right, or when count == 0), and each loop tests its
+  // bounds before touching data, so no column outside of the node is read.
+  size_t left = begin;
+  size_t right = begin + count;
+
+  while (true)
+  {
+    // See how many points on the left are correct.  When we encounter one
+    // that isn't correct, stop; it will be swapped below.
+    while ((left < right) && (data(splitDimension, left) < splitVal))
+      ++left;
+
+    // Now see how many points on the right are correct.  When we encounter
+    // one that isn't correct, stop; it will be swapped with the wrong point
+    // found by the previous loop.
+    while ((left < right) && (data(splitDimension, right - 1) >= splitVal))
+      --right;
+
+    // Every point has been examined; [begin, left) is the left side.
+    if (left == right)
+      break;
+
+    // Here data(splitDimension, left) >= splitVal and
+    // data(splitDimension, right - 1) < splitVal, so exchange them.
+    data.swap_cols(left, right - 1);
+
+    // Update the indices for what we changed.
+    const size_t oldIndex = oldFromNew[left];
+    oldFromNew[left] = oldFromNew[right - 1];
+    oldFromNew[right - 1] = oldIndex;
+  }
+
+  // First right-side column; begin + count if no point belongs on the right.
+  return left;
+}
+
+}; // namespace tree
+}; // namespace mlpack
+
+#endif
diff --git a/src/mlpack/core/tree/binary_space_tree/single_tree_traverser.hpp b/src/mlpack/core/tree/binary_space_tree/single_tree_traverser.hpp
index a9d0035..69f2a04 100644
--- a/src/mlpack/core/tree/binary_space_tree/single_tree_traverser.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/single_tree_traverser.hpp
@@ -16,9 +16,13 @@
namespace mlpack {
namespace tree {
-template<typename BoundType, typename StatisticType, typename MatType>
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
template<typename RuleType>
-class BinarySpaceTree<BoundType, StatisticType, MatType>::SingleTreeTraverser
+class BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+ SingleTreeTraverser
{
public:
/**
diff --git a/src/mlpack/core/tree/binary_space_tree/single_tree_traverser_impl.hpp b/src/mlpack/core/tree/binary_space_tree/single_tree_traverser_impl.hpp
index 04af321..7b559e0 100644
--- a/src/mlpack/core/tree/binary_space_tree/single_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/single_tree_traverser_impl.hpp
@@ -17,20 +17,27 @@
namespace mlpack {
namespace tree {
-template<typename BoundType, typename StatisticType, typename MatType>
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType> // Class that splits tree nodes (e.g. MeanSplit).
template<typename RuleType>
-BinarySpaceTree<BoundType, StatisticType, MatType>::
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
SingleTreeTraverser<RuleType>::SingleTreeTraverser(RuleType& rule) :
rule(rule),
numPrunes(0)
{ /* Nothing to do. */ }
-template<typename BoundType, typename StatisticType, typename MatType>
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
template<typename RuleType>
-void BinarySpaceTree<BoundType, StatisticType, MatType>::
+void BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
SingleTreeTraverser<RuleType>::Traverse(
const size_t queryIndex,
- BinarySpaceTree<BoundType, StatisticType, MatType>& referenceNode)
+ BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>&
+ referenceNode)
{
// If we are a leaf, run the base case as necessary.
if (referenceNode.IsLeaf())
More information about the mlpack-git
mailing list