[mlpack-git] master: Refactor BinarySpaceTree to use instantiated splitter. This isn't currently user-passable, but may be later. (e603caa)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Sun May 3 19:34:28 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/174d2de995a3fe343cd92d158730f3afa03e622d...076156df78e26ba87012f2b5fbc6d45e84da918b
>---------------------------------------------------------------
commit e603caa2af2034c255f35936f3e9ace690fce0f2
Author: Ryan Curtin <ryan at ratml.org>
Date: Sun May 3 18:41:35 2015 -0400
Refactor BinarySpaceTree to use instantiated splitter.
The splitter instance isn't currently user-passable, but it may be exposed to users later.
>---------------------------------------------------------------
e603caa2af2034c255f35936f3e9ace690fce0f2
.../tree/binary_space_tree/binary_space_tree.hpp | 9 ++++--
.../binary_space_tree/binary_space_tree_impl.hpp | 36 +++++++++++++---------
2 files changed, 29 insertions(+), 16 deletions(-)
diff --git a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
index 9af121b..cbd6da7 100644
--- a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
@@ -141,6 +141,7 @@ class BinarySpaceTree
BinarySpaceTree(MatType& data,
const size_t begin,
const size_t count,
+ SplitType& splitter,
BinarySpaceTree* parent = NULL,
const size_t maxLeafSize = 20);
@@ -166,6 +167,7 @@ class BinarySpaceTree
const size_t begin,
const size_t count,
std::vector<size_t>& oldFromNew,
+ SplitType& splitter,
BinarySpaceTree* parent = NULL,
const size_t maxLeafSize = 20);
@@ -195,6 +197,7 @@ class BinarySpaceTree
const size_t count,
std::vector<size_t>& oldFromNew,
std::vector<size_t>& newFromOld,
+ SplitType& splitter,
BinarySpaceTree* parent = NULL,
const size_t maxLeafSize = 20);
@@ -429,7 +432,8 @@ class BinarySpaceTree
* @param maxLeafSize Maximum number of points held in a leaf.
*/
void SplitNode(MatType& data,
- const size_t maxLeafSize);
+ const size_t maxLeafSize,
+ SplitType& splitter);
/**
* Splits the current node, assigning its left and right children recursively.
@@ -441,7 +445,8 @@ class BinarySpaceTree
*/
void SplitNode(MatType& data,
std::vector<size_t>& oldFromNew,
- const size_t maxLeafSize);
+ const size_t maxLeafSize,
+ SplitType& splitter);
public:
/**
diff --git a/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp b/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
index 9764578..6cdc9c6 100644
--- a/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
@@ -35,7 +35,8 @@ BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
dataset(data)
{
// Do the actual splitting of this node.
- SplitNode(data, maxLeafSize);
+ SplitType splitter;
+ SplitNode(data, maxLeafSize, splitter);
// Create the statistic depending on if we are a leaf or not.
stat = StatisticType(*this);
@@ -64,7 +65,8 @@ BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
oldFromNew[i] = i; // Fill with unharmed indices.
// Now do the actual splitting.
- SplitNode(data, oldFromNew, maxLeafSize);
+ SplitType splitter;
+ SplitNode(data, oldFromNew, maxLeafSize, splitter);
// Create the statistic depending on if we are a leaf or not.
stat = StatisticType(*this);
@@ -94,7 +96,8 @@ BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
oldFromNew[i] = i; // Fill with unharmed indices.
// Now do the actual splitting.
- SplitNode(data, oldFromNew, maxLeafSize);
+ SplitType splitter;
+ SplitNode(data, oldFromNew, maxLeafSize, splitter);
// Create the statistic depending on if we are a leaf or not.
stat = StatisticType(*this);
@@ -113,6 +116,7 @@ BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
MatType& data,
const size_t begin,
const size_t count,
+ SplitType& splitter,
BinarySpaceTree* parent,
const size_t maxLeafSize) :
left(NULL),
@@ -124,7 +128,7 @@ BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
dataset(data)
{
// Perform the actual splitting.
- SplitNode(data, maxLeafSize);
+ SplitNode(data, maxLeafSize, splitter);
// Create the statistic depending on if we are a leaf or not.
stat = StatisticType(*this);
@@ -139,6 +143,7 @@ BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
const size_t begin,
const size_t count,
std::vector<size_t>& oldFromNew,
+ SplitType& splitter,
BinarySpaceTree* parent,
const size_t maxLeafSize) :
left(NULL),
@@ -154,7 +159,7 @@ BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
assert(oldFromNew.size() == data.n_cols);
// Perform the actual splitting.
- SplitNode(data, oldFromNew, maxLeafSize);
+ SplitNode(data, oldFromNew, maxLeafSize, splitter);
// Create the statistic depending on if we are a leaf or not.
stat = StatisticType(*this);
@@ -170,6 +175,7 @@ BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
const size_t count,
std::vector<size_t>& oldFromNew,
std::vector<size_t>& newFromOld,
+ SplitType& splitter,
BinarySpaceTree* parent,
const size_t maxLeafSize) :
left(NULL),
@@ -185,7 +191,7 @@ BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
Log::Assert(oldFromNew.size() == data.n_cols);
// Perform the actual splitting.
- SplitNode(data, oldFromNew, maxLeafSize);
+ SplitNode(data, oldFromNew, maxLeafSize, splitter);
// Create the statistic depending on if we are a leaf or not.
stat = StatisticType(*this);
@@ -515,7 +521,8 @@ template<typename BoundType,
typename SplitType>
void BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::SplitNode(
MatType& data,
- const size_t maxLeafSize)
+ const size_t maxLeafSize,
+ SplitType& splitter)
{
// We need to expand the bounds of this node properly.
bound |= data.cols(begin, begin + count - 1);
@@ -534,7 +541,7 @@ void BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::SplitNode(
// Split the node. The elements of 'data' are reordered by the splitting
// algorithm. This function call updates splitCol.
- const bool split = SplitType::SplitNode(bound, data, begin, count, splitCol);
+ const bool split = splitter.SplitNode(bound, data, begin, count, splitCol);
// The node may not be always split. For instance, if all the points are the
// same, we can't split them.
@@ -544,9 +551,9 @@ void BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::SplitNode(
// Now that we know the split column, we will recursively split the children
// by calling their constructors (which perform this splitting process).
left = new BinarySpaceTree<BoundType, StatisticType, MatType>(data, begin,
- splitCol - begin, this, maxLeafSize);
+ splitCol - begin, splitter, this, maxLeafSize);
right = new BinarySpaceTree<BoundType, StatisticType, MatType>(data, splitCol,
- begin + count - splitCol, this, maxLeafSize);
+ begin + count - splitCol, splitter, this, maxLeafSize);
// Calculate parent distances for those two nodes.
arma::vec centroid, leftCentroid, rightCentroid;
@@ -570,7 +577,8 @@ template<typename BoundType,
void BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::SplitNode(
MatType& data,
std::vector<size_t>& oldFromNew,
- const size_t maxLeafSize)
+ const size_t maxLeafSize,
+ SplitType& splitter)
{
// This should be a single function for Bound.
// We need to expand the bounds of this node properly.
@@ -590,7 +598,7 @@ void BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::SplitNode(
// Split the node. The elements of 'data' are reordered by the splitting
// algorithm. This function call updates splitCol and oldFromNew.
- const bool split = SplitType::SplitNode(bound, data, begin, count, splitCol,
+ const bool split = splitter.SplitNode(bound, data, begin, count, splitCol,
oldFromNew);
// The node may not be always split. For instance, if all the points are the
@@ -601,9 +609,9 @@ void BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::SplitNode(
// Now that we know the split column, we will recursively split the children
// by calling their constructors (which perform this splitting process).
left = new BinarySpaceTree<BoundType, StatisticType, MatType>(data, begin,
- splitCol - begin, oldFromNew, this, maxLeafSize);
+ splitCol - begin, oldFromNew, splitter, this, maxLeafSize);
right = new BinarySpaceTree<BoundType, StatisticType, MatType>(data, splitCol,
- begin + count - splitCol, oldFromNew, this, maxLeafSize);
+ begin + count - splitCol, oldFromNew, splitter, this, maxLeafSize);
// Calculate parent distances for those two nodes.
arma::vec centroid, leftCentroid, rightCentroid;
More information about the mlpack-git
mailing list