[mlpack-svn] r10713 - mlpack/trunk/src/mlpack/core/tree

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Dec 12 04:49:26 EST 2011


Author: rcurtin
Date: 2011-12-12 04:49:25 -0500 (Mon, 12 Dec 2011)
New Revision: 10713

Modified:
   mlpack/trunk/src/mlpack/core/tree/binary_space_tree.hpp
   mlpack/trunk/src/mlpack/core/tree/binary_space_tree_impl.hpp
Log:
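Templatize the matrix type used by BinarySpaceTree: add a MatType template parameter (defaulting to arma::mat) and use it in place of the hard-coded arma::mat in the constructors, SplitNode(), and GetSplitIndex().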


Modified: mlpack/trunk/src/mlpack/core/tree/binary_space_tree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/binary_space_tree.hpp	2011-12-12 09:32:31 UTC (rev 10712)
+++ mlpack/trunk/src/mlpack/core/tree/binary_space_tree.hpp	2011-12-12 09:49:25 UTC (rev 10713)
@@ -40,7 +40,8 @@
  *     for the necessary skeleton interface.
  */
 template<typename BoundType,
-         typename StatisticType = EmptyStatistic>
+         typename StatisticType = EmptyStatistic,
+         typename MatType = arma::mat>
 class BinarySpaceTree
 {
  private:
@@ -69,7 +70,7 @@
    * @param data Dataset to create tree from.  This will be modified!
    * @param leafSize Size of each leaf in the tree.
    */
-  BinarySpaceTree(arma::mat& data, const size_t leafSize = 20);
+  BinarySpaceTree(MatType& data, const size_t leafSize = 20);
 
   /**
    * Construct this as the root node of a binary space tree using the given
@@ -81,7 +82,7 @@
    *     each new point.
    * @param leafSize Size of each leaf in the tree.
    */
-  BinarySpaceTree(arma::mat& data,
+  BinarySpaceTree(MatType& data,
                   std::vector<size_t>& oldFromNew,
                   const size_t leafSize = 20);
 
@@ -98,7 +99,7 @@
    *     each old point.
    * @param leafSize Size of each leaf in the tree.
    */
-  BinarySpaceTree(arma::mat& data,
+  BinarySpaceTree(MatType& data,
                   std::vector<size_t>& oldFromNew,
                   std::vector<size_t>& newFromOld,
                   const size_t leafSize = 20);
@@ -114,7 +115,7 @@
    * @param count Number of points to use to construct tree.
    * @param leafSize Size of each leaf in the tree.
    */
-  BinarySpaceTree(arma::mat& data,
+  BinarySpaceTree(MatType& data,
                   const size_t begin,
                   const size_t count,
                   const size_t leafSize = 20);
@@ -137,7 +138,7 @@
    *     each new point.
    * @param leafSize Size of each leaf in the tree.
    */
-  BinarySpaceTree(arma::mat& data,
+  BinarySpaceTree(MatType& data,
                   const size_t begin,
                   const size_t count,
                   std::vector<size_t>& oldFromNew,
@@ -164,7 +165,7 @@
    *     each old point.
    * @param leafSize Size of each leaf in the tree.
    */
-  BinarySpaceTree(arma::mat& data,
+  BinarySpaceTree(MatType& data,
                   const size_t begin,
                   const size_t count,
                   std::vector<size_t>& oldFromNew,
@@ -293,7 +294,7 @@
    *
    * @param data Dataset which we are using.
    */
-  void SplitNode(arma::mat& data);
+  void SplitNode(MatType& data);
 
   /**
    * Splits the current node, assigning its left and right children recursively.
@@ -302,7 +303,7 @@
    * @param data Dataset which we are using.
    * @param oldFromNew Vector holding permuted indices.
    */
-  void SplitNode(arma::mat& data, std::vector<size_t>& oldFromNew);
+  void SplitNode(MatType& data, std::vector<size_t>& oldFromNew);
 
   /**
    * Find the index to split on for this node, given that we are splitting in
@@ -312,7 +313,7 @@
    * @param splitDim Dimension of dataset to split on.
    * @param splitVal Value to split on, in the given split dimension.
    */
-  size_t GetSplitIndex(arma::mat& data, int splitDim, double splitVal);
+  size_t GetSplitIndex(MatType& data, int splitDim, double splitVal);
 
   /**
    * Find the index to split on for this node, given that we are splitting in
@@ -324,7 +325,7 @@
    * @param splitVal Value to split on, in the given split dimension.
    * @param oldFromNew Vector holding permuted indices.
    */
-  size_t GetSplitIndex(arma::mat& data, int splitDim, double splitVal,
+  size_t GetSplitIndex(MatType& data, int splitDim, double splitVal,
       std::vector<size_t>& oldFromNew);
 };
 

Modified: mlpack/trunk/src/mlpack/core/tree/binary_space_tree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/binary_space_tree_impl.hpp	2011-12-12 09:32:31 UTC (rev 10712)
+++ mlpack/trunk/src/mlpack/core/tree/binary_space_tree_impl.hpp	2011-12-12 09:49:25 UTC (rev 10713)
@@ -17,9 +17,9 @@
 
 // Each of these overloads is kept as a separate function to keep the overhead
 // from the two std::vectors out, if possible.
-template<typename BoundType, typename StatisticType>
-BinarySpaceTree<BoundType, StatisticType>::BinarySpaceTree(
-    arma::mat& data,
+template<typename BoundType, typename StatisticType, typename MatType>
+BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+    MatType& data,
     const size_t leafSize) :
     left(NULL),
     right(NULL),
@@ -33,9 +33,9 @@
   SplitNode(data);
 }
 
-template<typename BoundType, typename StatisticType>
-BinarySpaceTree<BoundType, StatisticType>::BinarySpaceTree(
-    arma::mat& data,
+template<typename BoundType, typename StatisticType, typename MatType>
+BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+    MatType& data,
     std::vector<size_t>& oldFromNew,
     const size_t leafSize) :
     left(NULL),
@@ -55,9 +55,9 @@
   SplitNode(data, oldFromNew);
 }
 
-template<typename BoundType, typename StatisticType>
-BinarySpaceTree<BoundType, StatisticType>::BinarySpaceTree(
-    arma::mat& data,
+template<typename BoundType, typename StatisticType, typename MatType>
+BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+    MatType& data,
     std::vector<size_t>& oldFromNew,
     std::vector<size_t>& newFromOld,
     const size_t leafSize) :
@@ -83,9 +83,9 @@
     newFromOld[oldFromNew[i]] = i;
 }
 
-template<typename BoundType, typename StatisticType>
-BinarySpaceTree<BoundType, StatisticType>::BinarySpaceTree(
-    arma::mat& data,
+template<typename BoundType, typename StatisticType, typename MatType>
+BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+    MatType& data,
     const size_t begin,
     const size_t count,
     const size_t leafSize) :
@@ -101,9 +101,9 @@
   SplitNode(data);
 }
 
-template<typename BoundType, typename StatisticType>
-BinarySpaceTree<BoundType, StatisticType>::BinarySpaceTree(
-    arma::mat& data,
+template<typename BoundType, typename StatisticType, typename MatType>
+BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+    MatType& data,
     const size_t begin,
     const size_t count,
     std::vector<size_t>& oldFromNew,
@@ -124,9 +124,9 @@
   SplitNode(data, oldFromNew);
 }
 
-template<typename BoundType, typename StatisticType>
-BinarySpaceTree<BoundType, StatisticType>::BinarySpaceTree(
-    arma::mat& data,
+template<typename BoundType, typename StatisticType, typename MatType>
+BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree(
+    MatType& data,
     const size_t begin,
     const size_t count,
     std::vector<size_t>& oldFromNew,
@@ -153,8 +153,8 @@
     newFromOld[oldFromNew[i]] = i;
 }
 
-template<typename BoundType, typename StatisticType>
-BinarySpaceTree<BoundType, StatisticType>::BinarySpaceTree() :
+template<typename BoundType, typename StatisticType, typename MatType>
+BinarySpaceTree<BoundType, StatisticType, MatType>::BinarySpaceTree() :
     left(NULL),
     right(NULL),
     begin(0),
@@ -171,8 +171,8 @@
  * destructors in turn.  This will invalidate any pointers or references to any
  * nodes which are children of this one.
  */
-template<typename BoundType, typename StatisticType>
-BinarySpaceTree<BoundType, StatisticType>::~BinarySpaceTree()
+template<typename BoundType, typename StatisticType, typename MatType>
+BinarySpaceTree<BoundType, StatisticType, MatType>::~BinarySpaceTree()
 {
   if (left)
     delete left;
@@ -191,10 +191,11 @@
  * @param queryCount The Count() of the node to find.
  * @return The found node, or NULL if nothing is found.
  */
-template<typename BoundType, typename StatisticType>
-const BinarySpaceTree<BoundType, StatisticType>*
-BinarySpaceTree<BoundType, StatisticType>::FindByBeginCount(size_t queryBegin,
-                                                    size_t queryCount) const
+template<typename BoundType, typename StatisticType, typename MatType>
+const BinarySpaceTree<BoundType, StatisticType, MatType>*
+BinarySpaceTree<BoundType, StatisticType, MatType>::FindByBeginCount(
+    size_t queryBegin,
+    size_t queryCount) const
 {
   Log::Assert(queryBegin >= begin);
   Log::Assert(queryCount <= count);
@@ -220,9 +221,9 @@
  * @param queryCount the Count() of the node to find
  * @return the found node, or NULL
  */
-template<typename BoundType, typename StatisticType>
-BinarySpaceTree<BoundType, StatisticType>*
-BinarySpaceTree<BoundType, StatisticType>::FindByBeginCount(
+template<typename BoundType, typename StatisticType, typename MatType>
+BinarySpaceTree<BoundType, StatisticType, MatType>*
+BinarySpaceTree<BoundType, StatisticType, MatType>::FindByBeginCount(
     const size_t queryBegin,
     const size_t queryCount)
 {
@@ -241,8 +242,9 @@
     return NULL;
 }
 
-template<typename Bound, typename Statistic>
-size_t BinarySpaceTree<Bound, Statistic>::ExtendTree(size_t level)
+template<typename BoundType, typename StatisticType, typename MatType>
+size_t BinarySpaceTree<BoundType, StatisticType, MatType>::ExtendTree(
+    size_t level)
 {
   --level;
   // Return the number of nodes duplicated.
@@ -271,16 +273,16 @@
  *     to avoid exceeding the stack limit
  */
 
-template<typename Bound, typename Statistic>
-size_t BinarySpaceTree<Bound, Statistic>::TreeSize() const
+template<typename BoundType, typename StatisticType, typename MatType>
+size_t BinarySpaceTree<BoundType, StatisticType, MatType>::TreeSize() const
 {
   // Recursively count the nodes on each side of the tree.  The plus one is
   // because we have to count this node, too.
   return 1 + (left ? left->TreeSize() : 0) + (right ? right->TreeSize() : 0);
 }
 
-template<typename Bound, typename Statistic>
-size_t BinarySpaceTree<Bound, Statistic>::TreeDepth() const
+template<typename BoundType, typename StatisticType, typename MatType>
+size_t BinarySpaceTree<BoundType, StatisticType, MatType>::TreeDepth() const
 {
   // Recursively count the depth on each side of the tree.  The plus one is
   // because we have to count this node, too.
@@ -288,33 +290,33 @@
                       (right ? right->TreeDepth() : 0));
 }
 
-template<typename BoundType, typename StatisticType>
-inline const BoundType& BinarySpaceTree<BoundType, StatisticType>::Bound() const
+template<typename BoundType, typename StatisticType, typename MatType>
+inline const BoundType& BinarySpaceTree<BoundType, StatisticType, MatType>::Bound() const
 {
   return bound;
 }
 
-template<typename BoundType, typename StatisticType>
-inline BoundType& BinarySpaceTree<BoundType, StatisticType>::Bound()
+template<typename BoundType, typename StatisticType, typename MatType>
+inline BoundType& BinarySpaceTree<BoundType, StatisticType, MatType>::Bound()
 {
   return bound;
 }
 
-template<typename BoundType, typename StatisticType>
-inline const StatisticType& BinarySpaceTree<BoundType, StatisticType>::Stat()
+template<typename BoundType, typename StatisticType, typename MatType>
+inline const StatisticType& BinarySpaceTree<BoundType, StatisticType, MatType>::Stat()
     const
 {
   return stat;
 }
 
-template<typename BoundType, typename StatisticType>
-inline StatisticType& BinarySpaceTree<BoundType, StatisticType>::Stat()
+template<typename BoundType, typename StatisticType, typename MatType>
+inline StatisticType& BinarySpaceTree<BoundType, StatisticType, MatType>::Stat()
 {
   return stat;
 }
 
-template<typename BoundType, typename StatisticType>
-inline bool BinarySpaceTree<BoundType, StatisticType>::IsLeaf() const
+template<typename BoundType, typename StatisticType, typename MatType>
+inline bool BinarySpaceTree<BoundType, StatisticType, MatType>::IsLeaf() const
 {
   return !left;
 }
@@ -322,9 +324,9 @@
 /**
  * Gets the left branch of the tree.
  */
-template<typename BoundType, typename StatisticType>
-inline BinarySpaceTree<BoundType, StatisticType>*
-BinarySpaceTree<BoundType, StatisticType>::Left() const
+template<typename BoundType, typename StatisticType, typename MatType>
+inline BinarySpaceTree<BoundType, StatisticType, MatType>*
+BinarySpaceTree<BoundType, StatisticType, MatType>::Left() const
 {
   return left;
 }
@@ -332,9 +334,9 @@
 /**
  * Gets the right branch.
  */
-template<typename BoundType, typename StatisticType>
-inline BinarySpaceTree<BoundType, StatisticType>*
-BinarySpaceTree<BoundType, StatisticType>::Right() const
+template<typename BoundType, typename StatisticType, typename MatType>
+inline BinarySpaceTree<BoundType, StatisticType, MatType>*
+BinarySpaceTree<BoundType, StatisticType, MatType>::Right() const
 {
   return right;
 }
@@ -342,8 +344,8 @@
 /**
  * Gets the index of the begin point of this subset.
  */
-template<typename BoundType, typename StatisticType>
-inline size_t BinarySpaceTree<BoundType, StatisticType>::Begin() const
+template<typename BoundType, typename StatisticType, typename MatType>
+inline size_t BinarySpaceTree<BoundType, StatisticType, MatType>::Begin() const
 {
   return begin;
 }
@@ -351,8 +353,8 @@
 /**
  * Gets the index one beyond the last index in the series.
  */
-template<typename BoundType, typename StatisticType>
-inline size_t BinarySpaceTree<BoundType, StatisticType>::End() const
+template<typename BoundType, typename StatisticType, typename MatType>
+inline size_t BinarySpaceTree<BoundType, StatisticType, MatType>::End() const
 {
   return begin + count;
 }
@@ -360,14 +362,14 @@
 /**
  * Gets the number of points in this subset.
  */
-template<typename BoundType, typename StatisticType>
-inline size_t BinarySpaceTree<BoundType, StatisticType>::Count() const
+template<typename BoundType, typename StatisticType, typename MatType>
+inline size_t BinarySpaceTree<BoundType, StatisticType, MatType>::Count() const
 {
   return count;
 }
 
-template<typename BoundType, typename StatisticType>
-void BinarySpaceTree<BoundType, StatisticType>::SplitNode(arma::mat& data)
+template<typename BoundType, typename StatisticType, typename MatType>
+void BinarySpaceTree<BoundType, StatisticType, MatType>::SplitNode(MatType& data)
 {
   // This should be a single function for Bound.
   // We need to expand the bounds of this node properly.
@@ -408,15 +410,15 @@
 
   // Now that we know the split column, we will recursively split the children
   // by calling their constructors (which perform this splitting process).
-  left = new BinarySpaceTree<BoundType, StatisticType>(data, begin,
+  left = new BinarySpaceTree<BoundType, StatisticType, MatType>(data, begin,
       splitCol - begin, leafSize);
-  right = new BinarySpaceTree<BoundType, StatisticType>(data, splitCol,
+  right = new BinarySpaceTree<BoundType, StatisticType, MatType>(data, splitCol,
       begin + count - splitCol, leafSize);
 }
 
-template<typename BoundType, typename StatisticType>
-void BinarySpaceTree<BoundType, StatisticType>::SplitNode(
-    arma::mat& data,
+template<typename BoundType, typename StatisticType, typename MatType>
+void BinarySpaceTree<BoundType, StatisticType, MatType>::SplitNode(
+    MatType& data,
     std::vector<size_t>& oldFromNew)
 {
   // This should be a single function for Bound.
@@ -458,15 +460,15 @@
 
   // Now that we know the split column, we will recursively split the children
   // by calling their constructors (which perform this splitting process).
-  left = new BinarySpaceTree<BoundType, StatisticType>(data, begin,
+  left = new BinarySpaceTree<BoundType, StatisticType, MatType>(data, begin,
       splitCol - begin, oldFromNew, leafSize);
-  right = new BinarySpaceTree<BoundType, StatisticType>(data, splitCol,
+  right = new BinarySpaceTree<BoundType, StatisticType, MatType>(data, splitCol,
       begin + count - splitCol, oldFromNew, leafSize);
 }
 
-template<typename BoundType, typename StatisticType>
-size_t BinarySpaceTree<BoundType, StatisticType>::GetSplitIndex(
-    arma::mat& data,
+template<typename BoundType, typename StatisticType, typename MatType>
+size_t BinarySpaceTree<BoundType, StatisticType, MatType>::GetSplitIndex(
+    MatType& data,
     int splitDim,
     double splitVal)
 {
@@ -508,9 +510,9 @@
   return left;
 }
 
-template<typename BoundType, typename StatisticType>
-size_t BinarySpaceTree<BoundType, StatisticType>::GetSplitIndex(
-    arma::mat& data,
+template<typename BoundType, typename StatisticType, typename MatType>
+size_t BinarySpaceTree<BoundType, StatisticType, MatType>::GetSplitIndex(
+    MatType& data,
     int splitDim,
     double splitVal,
     std::vector<size_t>& oldFromNew)




More information about the mlpack-svn mailing list