[mlpack-svn] r10713 - mlpack/trunk/src/mlpack/core/tree
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Dec 12 04:49:26 EST 2011
Author: rcurtin
Date: 2011-12-12 04:49:25 -0500 (Mon, 12 Dec 2011)
New Revision: 10713
Modified:
mlpack/trunk/src/mlpack/core/tree/binary_space_tree.hpp
mlpack/trunk/src/mlpack/core/tree/binary_space_tree_impl.hpp
Log:
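Templatize BinarySpaceTree on the matrix type: add a MatType template parameter (default arma::mat) and use it in place of arma::mat throughout.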
Modified: mlpack/trunk/src/mlpack/core/tree/binary_space_tree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/binary_space_tree.hpp 2011-12-12 09:32:31 UTC (rev 10712)
+++ mlpack/trunk/src/mlpack/core/tree/binary_space_tree.hpp 2011-12-12 09:49:25 UTC (rev 10713)
@@ -40,7 +40,8 @@
* for the necessary skeleton interface.
*/
template<typename BoundType,
- typename StatisticType = EmptyStatistic>
+ typename StatisticType = EmptyStatistic,
+ typename MatType = arma::mat>
class BinarySpaceTree
{
private:
@@ -69,7 +70,7 @@
* @param data Dataset to create tree from. This will be modified!
* @param leafSize Size of each leaf in the tree.
*/
- BinarySpaceTree(arma::mat& data, const size_t leafSize = 20);
+ BinarySpaceTree(MatType& data, const size_t leafSize = 20);
/**
* Construct this as the root node of a binary space tree using the given
@@ -81,7 +82,7 @@
* each new point.
* @param leafSize Size of each leaf in the tree.
*/
- BinarySpaceTree(arma::mat& data,
+ BinarySpaceTree(MatType& data,
std::vector<size_t>& oldFromNew,
const size_t leafSize = 20);
@@ -98,7 +99,7 @@
* each old point.
* @param leafSize Size of each leaf in the tree.
*/
- BinarySpaceTree(arma::mat& data,
+ BinarySpaceTree(MatType& data,
std::vector<size_t>& oldFromNew,
std::vector<size_t>& newFromOld,
const size_t leafSize = 20);
@@ -114,7 +115,7 @@
* @param count Number of points to use to construct tree.
* @param leafSize Size of each leaf in the tree.
*/
- BinarySpaceTree(arma::mat& data,
+ BinarySpaceTree(MatType& data,
const size_t begin,
const size_t count,
const size_t leafSize = 20);
@@ -137,7 +138,7 @@
* each new point.
* @param leafSize Size of each leaf in the tree.
*/
- BinarySpaceTree(arma::mat& data,
+ BinarySpaceTree(MatType& data,
const size_t begin,
const size_t count,
std::vector<size_t>& oldFromNew,
@@ -164,7 +165,7 @@
* each old point.
* @param leafSize Size of each leaf in the tree.
*/
- BinarySpaceTree(arma::mat& data,
+ BinarySpaceTree(MatType& data,
const size_t begin,
const size_t count,
std::vector<size_t>& oldFromNew,
@@ -293,7 +294,7 @@
*
* @param data Dataset which we are using.
*/
- void SplitNode(arma::mat& data);
+ void SplitNode(MatType& data);
/**
* Splits the current node, assigning its left and right children recursively.
@@ -302,7 +303,7 @@
* @param data Dataset which we are using.
* @param oldFromNew Vector holding permuted indices.
*/
- void SplitNode(arma::mat& data, std::vector<size_t>& oldFromNew);
+ void SplitNode(MatType& data, std::vector<size_t>& oldFromNew);
/**
* Find the index to split on for this node, given that we are splitting in
@@ -312,7 +313,7 @@
* @param splitDim Dimension of dataset to split on.
* @param splitVal Value to split on, in the given split dimension.
*/
- size_t GetSplitIndex(arma::mat& data, int splitDim, double splitVal);
+ size_t GetSplitIndex(MatType& data, int splitDim, double splitVal);
/**
* Find the index to split on for this node, given that we are splitting in
@@ -324,7 +325,7 @@
* @param splitVal Value to split on, in the given split dimension.
* @param oldFromNew Vector holding permuted indices.
*/
- size_t GetSplitIndex(arma::mat& data, int splitDim, double splitVal,
+ size_t GetSplitIndex(MatType& data, int splitDim, double splitVal,
std::vector<size_t>& oldFromNew);
};
Modified: mlpack/trunk/src/mlpack/core/tree/binary_space_tree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/binary_space_tree_impl.hpp 2011-12-12 09:32:31 UTC (rev 10712)
+++ mlpack/trunk/src/mlpack/core/tree/binary_space_tree_impl.hpp 2011-12-12 09:49:25 UTC (rev 10713)
@@ -17,9 +17,9 @@
// Each of these overloads is kept as a separate function to keep the overhead
// from the two std::vectors out, if possible.
-template<typename BoundType, typename StatisticType>
-BinarySpaceTree<BoundType, StatisticType>::BinarySpaceTree(
- arma::mat& data,
+template<typename BoundType, typename StatisticType, typename MatType>
+BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+ MatType& data,
const size_t leafSize) :
left(NULL),
right(NULL),
@@ -33,9 +33,9 @@
SplitNode(data);
}
-template<typename BoundType, typename StatisticType>
-BinarySpaceTree<BoundType, StatisticType>::BinarySpaceTree(
- arma::mat& data,
+template<typename BoundType, typename StatisticType, typename MatType>
+BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+ MatType& data,
std::vector<size_t>& oldFromNew,
const size_t leafSize) :
left(NULL),
@@ -55,9 +55,9 @@
SplitNode(data, oldFromNew);
}
-template<typename BoundType, typename StatisticType>
-BinarySpaceTree<BoundType, StatisticType>::BinarySpaceTree(
- arma::mat& data,
+template<typename BoundType, typename StatisticType, typename MatType>
+BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+ MatType& data,
std::vector<size_t>& oldFromNew,
std::vector<size_t>& newFromOld,
const size_t leafSize) :
@@ -83,9 +83,9 @@
newFromOld[oldFromNew[i]] = i;
}
-template<typename BoundType, typename StatisticType>
-BinarySpaceTree<BoundType, StatisticType>::BinarySpaceTree(
- arma::mat& data,
+template<typename BoundType, typename StatisticType, typename MatType>
+BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+ MatType& data,
const size_t begin,
const size_t count,
const size_t leafSize) :
@@ -101,9 +101,9 @@
SplitNode(data);
}
-template<typename BoundType, typename StatisticType>
-BinarySpaceTree<BoundType, StatisticType>::BinarySpaceTree(
- arma::mat& data,
+template<typename BoundType, typename StatisticType, typename MatType>
+BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+ MatType& data,
const size_t begin,
const size_t count,
std::vector<size_t>& oldFromNew,
@@ -124,9 +124,9 @@
SplitNode(data, oldFromNew);
}
-template<typename BoundType, typename StatisticType>
-BinarySpaceTree<BoundType, StatisticType>::BinarySpaceTree(
- arma::mat& data,
+template<typename BoundType, typename StatisticType, typename MatType>
+BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+ MatType& data,
const size_t begin,
const size_t count,
std::vector<size_t>& oldFromNew,
@@ -153,8 +153,8 @@
newFromOld[oldFromNew[i]] = i;
}
-template<typename BoundType, typename StatisticType>
-BinarySpaceTree<BoundType, StatisticType>::BinarySpaceTree() :
+template<typename BoundType, typename StatisticType, typename MatType>
+BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree() :
left(NULL),
right(NULL),
begin(0),
@@ -171,8 +171,8 @@
* destructors in turn. This will invalidate any pointers or references to any
* nodes which are children of this one.
*/
-template<typename BoundType, typename StatisticType>
-BinarySpaceTree<BoundType, StatisticType>::~BinarySpaceTree()
+template<typename BoundType, typename StatisticType, typename MatType>
+BinarySpaceTree<BoundType, StatisticType, MatType>::~BinarySpaceTree()
{
if (left)
delete left;
@@ -191,10 +191,11 @@
* @param queryCount The Count() of the node to find.
* @return The found node, or NULL if nothing is found.
*/
-template<typename BoundType, typename StatisticType>
-const BinarySpaceTree<BoundType, StatisticType>*
-BinarySpaceTree<BoundType, StatisticType>::FindByBeginCount(size_t queryBegin,
- size_t queryCount) const
+template<typename BoundType, typename StatisticType, typename MatType>
+const BinarySpaceTree<BoundType, StatisticType, MatType>*
+BinarySpaceTree<BoundType, StatisticType, MatType>::FindByBeginCount(
+ size_t queryBegin,
+ size_t queryCount) const
{
Log::Assert(queryBegin >= begin);
Log::Assert(queryCount <= count);
@@ -220,9 +221,9 @@
* @param queryCount the Count() of the node to find
* @return the found node, or NULL
*/
-template<typename BoundType, typename StatisticType>
-BinarySpaceTree<BoundType, StatisticType>*
-BinarySpaceTree<BoundType, StatisticType>::FindByBeginCount(
+template<typename BoundType, typename StatisticType, typename MatType>
+BinarySpaceTree<BoundType, StatisticType, MatType>*
+BinarySpaceTree<BoundType, StatisticType, MatType>::FindByBeginCount(
const size_t queryBegin,
const size_t queryCount)
{
@@ -241,8 +242,9 @@
return NULL;
}
-template<typename Bound, typename Statistic>
-size_t BinarySpaceTree<Bound, Statistic>::ExtendTree(size_t level)
+template<typename BoundType, typename StatisticType, typename MatType>
+size_t BinarySpaceTree<BoundType, StatisticType, MatType>::ExtendTree(
+ size_t level)
{
--level;
// Return the number of nodes duplicated.
@@ -271,16 +273,16 @@
* to avoid exceeding the stack limit
*/
-template<typename Bound, typename Statistic>
-size_t BinarySpaceTree<Bound, Statistic>::TreeSize() const
+template<typename BoundType, typename StatisticType, typename MatType>
+size_t BinarySpaceTree<BoundType, StatisticType, MatType>::TreeSize() const
{
// Recursively count the nodes on each side of the tree. The plus one is
// because we have to count this node, too.
return 1 + (left ? left->TreeSize() : 0) + (right ? right->TreeSize() : 0);
}
-template<typename Bound, typename Statistic>
-size_t BinarySpaceTree<Bound, Statistic>::TreeDepth() const
+template<typename BoundType, typename StatisticType, typename MatType>
+size_t BinarySpaceTree<BoundType, StatisticType, MatType>::TreeDepth() const
{
// Recursively count the depth on each side of the tree. The plus one is
// because we have to count this node, too.
@@ -288,33 +290,33 @@
(right ? right->TreeDepth() : 0));
}
-template<typename BoundType, typename StatisticType>
-inline const BoundType& BinarySpaceTree<BoundType, StatisticType>::Bound() const
+template<typename BoundType, typename StatisticType, typename MatType>
+inline const BoundType& BinarySpaceTree<BoundType, StatisticType, MatType>::Bound() const
{
return bound;
}
-template<typename BoundType, typename StatisticType>
-inline BoundType& BinarySpaceTree<BoundType, StatisticType>::Bound()
+template<typename BoundType, typename StatisticType, typename MatType>
+inline BoundType& BinarySpaceTree<BoundType, StatisticType, MatType>::Bound()
{
return bound;
}
-template<typename BoundType, typename StatisticType>
-inline const StatisticType& BinarySpaceTree<BoundType, StatisticType>::Stat()
+template<typename BoundType, typename StatisticType, typename MatType>
+inline const StatisticType& BinarySpaceTree<BoundType, StatisticType, MatType>::Stat()
const
{
return stat;
}
-template<typename BoundType, typename StatisticType>
-inline StatisticType& BinarySpaceTree<BoundType, StatisticType>::Stat()
+template<typename BoundType, typename StatisticType, typename MatType>
+inline StatisticType& BinarySpaceTree<BoundType, StatisticType, MatType>::Stat()
{
return stat;
}
-template<typename BoundType, typename StatisticType>
-inline bool BinarySpaceTree<BoundType, StatisticType>::IsLeaf() const
+template<typename BoundType, typename StatisticType, typename MatType>
+inline bool BinarySpaceTree<BoundType, StatisticType, MatType>::IsLeaf() const
{
return !left;
}
@@ -322,9 +324,9 @@
/**
* Gets the left branch of the tree.
*/
-template<typename BoundType, typename StatisticType>
-inline BinarySpaceTree<BoundType, StatisticType>*
-BinarySpaceTree<BoundType, StatisticType>::Left() const
+template<typename BoundType, typename StatisticType, typename MatType>
+inline BinarySpaceTree<BoundType, StatisticType, MatType>*
+BinarySpaceTree<BoundType, StatisticType, MatType>::Left() const
{
return left;
}
@@ -332,9 +334,9 @@
/**
* Gets the right branch.
*/
-template<typename BoundType, typename StatisticType>
-inline BinarySpaceTree<BoundType, StatisticType>*
-BinarySpaceTree<BoundType, StatisticType>::Right() const
+template<typename BoundType, typename StatisticType, typename MatType>
+inline BinarySpaceTree<BoundType, StatisticType, MatType>*
+BinarySpaceTree<BoundType, StatisticType, MatType>::Right() const
{
return right;
}
@@ -342,8 +344,8 @@
/**
* Gets the index of the begin point of this subset.
*/
-template<typename BoundType, typename StatisticType>
-inline size_t BinarySpaceTree<BoundType, StatisticType>::Begin() const
+template<typename BoundType, typename StatisticType, typename MatType>
+inline size_t BinarySpaceTree<BoundType, StatisticType, MatType>::Begin() const
{
return begin;
}
@@ -351,8 +353,8 @@
/**
* Gets the index one beyond the last index in the series.
*/
-template<typename BoundType, typename StatisticType>
-inline size_t BinarySpaceTree<BoundType, StatisticType>::End() const
+template<typename BoundType, typename StatisticType, typename MatType>
+inline size_t BinarySpaceTree<BoundType, StatisticType, MatType>::End() const
{
return begin + count;
}
@@ -360,14 +362,14 @@
/**
* Gets the number of points in this subset.
*/
-template<typename BoundType, typename StatisticType>
-inline size_t BinarySpaceTree<BoundType, StatisticType>::Count() const
+template<typename BoundType, typename StatisticType, typename MatType>
+inline size_t BinarySpaceTree<BoundType, StatisticType, MatType>::Count() const
{
return count;
}
-template<typename BoundType, typename StatisticType>
-void BinarySpaceTree<BoundType, StatisticType>::SplitNode(arma::mat& data)
+template<typename BoundType, typename StatisticType, typename MatType>
+void BinarySpaceTree<BoundType, StatisticType, MatType>::SplitNode(MatType& data)
{
// This should be a single function for Bound.
// We need to expand the bounds of this node properly.
@@ -408,15 +410,15 @@
// Now that we know the split column, we will recursively split the children
// by calling their constructors (which perform this splitting process).
- left = new BinarySpaceTree<BoundType, StatisticType>(data, begin,
+ left = new BinarySpaceTree<BoundType, StatisticType, MatType>(data, begin,
splitCol - begin, leafSize);
- right = new BinarySpaceTree<BoundType, StatisticType>(data, splitCol,
+ right = new BinarySpaceTree<BoundType, StatisticType, MatType>(data, splitCol,
begin + count - splitCol, leafSize);
}
-template<typename BoundType, typename StatisticType>
-void BinarySpaceTree<BoundType, StatisticType>::SplitNode(
- arma::mat& data,
+template<typename BoundType, typename StatisticType, typename MatType>
+void BinarySpaceTree<BoundType, StatisticType, MatType>::SplitNode(
+ MatType& data,
std::vector<size_t>& oldFromNew)
{
// This should be a single function for Bound.
@@ -458,15 +460,15 @@
// Now that we know the split column, we will recursively split the children
// by calling their constructors (which perform this splitting process).
- left = new BinarySpaceTree<BoundType, StatisticType>(data, begin,
+ left = new BinarySpaceTree<BoundType, StatisticType, MatType>(data, begin,
splitCol - begin, oldFromNew, leafSize);
- right = new BinarySpaceTree<BoundType, StatisticType>(data, splitCol,
+ right = new BinarySpaceTree<BoundType, StatisticType, MatType>(data, splitCol,
begin + count - splitCol, oldFromNew, leafSize);
}
-template<typename BoundType, typename StatisticType>
-size_t BinarySpaceTree<BoundType, StatisticType>::GetSplitIndex(
- arma::mat& data,
+template<typename BoundType, typename StatisticType, typename MatType>
+size_t BinarySpaceTree<BoundType, StatisticType, MatType>::GetSplitIndex(
+ MatType& data,
int splitDim,
double splitVal)
{
@@ -508,9 +510,9 @@
return left;
}
-template<typename BoundType, typename StatisticType>
-size_t BinarySpaceTree<BoundType, StatisticType>::GetSplitIndex(
- arma::mat& data,
+template<typename BoundType, typename StatisticType, typename MatType>
+size_t BinarySpaceTree<BoundType, StatisticType, MatType>::GetSplitIndex(
+ MatType& data,
int splitDim,
double splitVal,
std::vector<size_t>& oldFromNew)
More information about the mlpack-svn mailing list