[mlpack-git] master: - DET templating ready and tests passing. (45ff5ba)

gitdub at mlpack.org gitdub at mlpack.org
Tue Oct 18 05:43:36 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/94d14187222231ca29e4f6419c5999c660db4f8a...981ffa2d67d8fe38df6c699589005835fef710ea

>---------------------------------------------------------------

commit 45ff5baa3cf6606ccff5dddb3fbff55df261e9a3
Author: theJonan <ivan at jonan.info>
Date:   Fri Oct 14 23:06:19 2016 +0300

    - DET templating ready and tests passing.


>---------------------------------------------------------------

45ff5baa3cf6606ccff5dddb3fbff55df261e9a3
 src/mlpack/core/arma_extend/Mat_extra_bones.hpp   |   9 ++
 src/mlpack/core/arma_extend/SpMat_extra_bones.hpp |   8 +
 src/mlpack/methods/det/det_main.cpp               |   5 +-
 src/mlpack/methods/det/dt_utils.hpp               |  22 +--
 src/mlpack/methods/det/dt_utils_impl.hpp          |  22 +--
 src/mlpack/methods/det/dtree.hpp                  |   2 +-
 src/mlpack/methods/det/dtree_impl.hpp             | 187 +++++++++-------------
 src/mlpack/tests/det_test.cpp                     |  18 +--
 src/mlpack/tests/serialization_test.cpp           |   2 +-
 9 files changed, 126 insertions(+), 149 deletions(-)

diff --git a/src/mlpack/core/arma_extend/Mat_extra_bones.hpp b/src/mlpack/core/arma_extend/Mat_extra_bones.hpp
index e09f5f5..f4f25d6 100644
--- a/src/mlpack/core/arma_extend/Mat_extra_bones.hpp
+++ b/src/mlpack/core/arma_extend/Mat_extra_bones.hpp
@@ -12,6 +12,15 @@
 template<typename Archive>
 void serialize(Archive& ar, const unsigned int version);
 
+/**
+ * These typedefs let us refer to the proper vector / column / row types by
+ * specifying only the matrix type we want to use.
+ */
+
+typedef Col<elem_type>   vec_type;
+typedef Col<elem_type>   col_type;
+typedef Row<elem_type>   row_type;
+
 /*
  * Add row_col_iterator and row_col_const_iterator to arma::Mat.
  */
diff --git a/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp b/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp
index d3c18de..a5d274c 100644
--- a/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp
+++ b/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp
@@ -16,6 +16,14 @@
 template<typename Archive>
 void serialize(Archive& ar, const unsigned int version);
 
+/**
+ * These typedefs let us refer to the proper vector / column / row types by
+ * specifying only the matrix type we want to use.
+ */
+typedef SpCol<elem_type>   vec_type;
+typedef SpCol<elem_type>   col_type;
+typedef SpRow<elem_type>   row_type;
+
 /*
  * Extra functions for SpMat<eT>
  * Adding definition of row_col_iterator to generalize with Mat<eT>::row_col_iterator
diff --git a/src/mlpack/methods/det/det_main.cpp b/src/mlpack/methods/det/det_main.cpp
index 26b394d..16ffa25 100644
--- a/src/mlpack/methods/det/det_main.cpp
+++ b/src/mlpack/methods/det/det_main.cpp
@@ -101,7 +101,7 @@ int main(int argc, char *argv[])
         << "(-T) is not specified." << endl;
 
   // Are we training a DET or loading from file?
-  DTree<arma::mat, arma::vec, int>* tree;
+  DTree<arma::mat, int>* tree;
   if (CLI::HasParam("training_file"))
   {
     const string trainSetFile = CLI::GetParam<string>("training_file");
@@ -127,8 +127,7 @@ int main(int argc, char *argv[])
 
     // Obtain the optimal tree.
     Timer::Start("det_training");
-    tree = Trainer<arma::mat, arma::vec, int>(trainingData, folds, regularization, maxLeafSize,
-        minLeafSize, "");
+    tree = Trainer<arma::mat, int>(trainingData, folds, regularization, maxLeafSize, minLeafSize, "");
     Timer::Stop("det_training");
 
     // Compute training set estimates, if desired.
diff --git a/src/mlpack/methods/det/dt_utils.hpp b/src/mlpack/methods/det/dt_utils.hpp
index 067e6fe..3535eae 100644
--- a/src/mlpack/methods/det/dt_utils.hpp
+++ b/src/mlpack/methods/det/dt_utils.hpp
@@ -25,8 +25,8 @@ namespace det {
  * @param numClasses Number of classes in dataset.
  * @param leafClassMembershipFile Name of file to print to (optional).
  */
-template <typename MatType, typename VecType, typename TagType>
-void PrintLeafMembership(DTree<MatType, VecType, TagType>* dtree,
+template <typename MatType, typename TagType>
+void PrintLeafMembership(DTree<MatType, TagType>* dtree,
                          const MatType& data,
                          const arma::Mat<size_t>& labels,
                          const size_t numClasses,
@@ -40,8 +40,8 @@ void PrintLeafMembership(DTree<MatType, VecType, TagType>* dtree,
  * @param dtree Density tree to use.
  * @param viFile Name of file to print to (optional).
  */
-template <typename MatType, typename VecType, typename TagType>
-void PrintVariableImportance(const DTree<MatType, VecType, TagType>* dtree,
+template <typename MatType, typename TagType>
+void PrintVariableImportance(const DTree<MatType, TagType>* dtree,
                              const std::string viFile = "");
 
 /**
@@ -56,13 +56,13 @@ void PrintVariableImportance(const DTree<MatType, VecType, TagType>* dtree,
  * @param minLeafSize Minimum number of points allowed in a leaf.
  * @param unprunedTreeOutput Filename to print unpruned tree to (optional).
  */
-template <typename MatType, typename VecType, typename TagType>
-DTree<MatType, VecType, TagType>* Trainer(MatType& dataset,
-                                          const size_t folds,
-                                          const bool useVolumeReg = false,
-                                          const size_t maxLeafSize = 10,
-                                          const size_t minLeafSize = 5,
-                                          const std::string unprunedTreeOutput = "");
+template <typename MatType, typename TagType>
+DTree<MatType, TagType>* Trainer(MatType& dataset,
+                                 const size_t folds,
+                                 const bool useVolumeReg = false,
+                                 const size_t maxLeafSize = 10,
+                                 const size_t minLeafSize = 5,
+                                 const std::string unprunedTreeOutput = "");
 
 } // namespace det
 } // namespace mlpack
diff --git a/src/mlpack/methods/det/dt_utils_impl.hpp b/src/mlpack/methods/det/dt_utils_impl.hpp
index 4c057f7..cad5289 100644
--- a/src/mlpack/methods/det/dt_utils_impl.hpp
+++ b/src/mlpack/methods/det/dt_utils_impl.hpp
@@ -10,8 +10,8 @@
 using namespace mlpack;
 using namespace det;
 
-template <typename MatType, typename VecType, typename TagType>
-void mlpack::det::PrintLeafMembership(DTree<MatType, VecType, TagType>* dtree,
+template <typename MatType, typename TagType>
+void mlpack::det::PrintLeafMembership(DTree<MatType, TagType>* dtree,
                                       const MatType& data,
                                       const arma::Mat<size_t>& labels,
                                       const size_t numClasses,
@@ -25,7 +25,7 @@ void mlpack::det::PrintLeafMembership(DTree<MatType, VecType, TagType>* dtree,
 
   for (size_t i = 0; i < data.n_cols; i++)
   {
-    const VecType testPoint = data.unsafe_col(i);
+    const typename MatType::vec_type testPoint = data.unsafe_col(i);
     const TagType leafTag = dtree->FindBucket(testPoint);
     const size_t label = labels[i];
     table(leafTag, label) += 1;
@@ -58,8 +58,8 @@ void mlpack::det::PrintLeafMembership(DTree<MatType, VecType, TagType>* dtree,
   return;
 }
 
-template <typename MatType, typename VecType, typename TagType>
-void mlpack::det::PrintVariableImportance(const DTree<MatType, VecType, TagType>* dtree,
+template <typename MatType, typename TagType>
+void mlpack::det::PrintVariableImportance(const DTree<MatType, TagType>* dtree,
                                           const std::string viFile)
 {
   arma::vec imps;
@@ -97,8 +97,8 @@ void mlpack::det::PrintVariableImportance(const DTree<MatType, VecType, TagType>
 
 // This function trains the optimal decision tree using the given number of
 // folds.
-template <typename MatType, typename VecType, typename TagType>
-DTree<MatType, VecType, TagType>* mlpack::det::Trainer(MatType& dataset,
+template <typename MatType, typename TagType>
+DTree<MatType, TagType>* mlpack::det::Trainer(MatType& dataset,
                                                        const size_t folds,
                                                        const bool useVolumeReg,
                                                        const size_t maxLeafSize,
@@ -106,7 +106,7 @@ DTree<MatType, VecType, TagType>* mlpack::det::Trainer(MatType& dataset,
                                                        const std::string unprunedTreeOutput)
 {
   // Initialize the tree.
-  DTree<MatType, VecType, TagType> dtree(dataset);
+  DTree<MatType, TagType> dtree(dataset);
 
   // Prepare to grow the tree...
   arma::Col<size_t> oldFromNew(dataset.n_cols);
@@ -211,7 +211,7 @@ DTree<MatType, VecType, TagType>* mlpack::det::Trainer(MatType& dataset,
     }
 
     // Initialize the tree.
-    DTree<MatType, VecType, TagType> cvDTree(train);
+    DTree<MatType, TagType> cvDTree(train);
 
     // Getting ready to grow the tree...
     arma::Col<size_t> cvOldFromNew(train.n_cols);
@@ -251,7 +251,7 @@ DTree<MatType, VecType, TagType>* mlpack::det::Trainer(MatType& dataset,
     double cvVal = 0.0;
     for (size_t i = 0; i < test.n_cols; ++i)
     {
-      VecType testPoint = test.unsafe_col(i);
+      typename MatType::vec_type testPoint = test.unsafe_col(i);
       cvVal += cvDTree.ComputeValue(testPoint);
     }
 
@@ -283,7 +283,7 @@ DTree<MatType, VecType, TagType>* mlpack::det::Trainer(MatType& dataset,
   Log::Info << "Optimal alpha: " << optimalAlpha << "." << std::endl;
 
   // Initialize the tree.
-  DTree<MatType, VecType, TagType>* dtreeOpt = new DTree<MatType, VecType, TagType>(dataset);
+  DTree<MatType, TagType>* dtreeOpt = new DTree<MatType, TagType>(dataset);
 
   // Getting ready to grow the tree...
   for (size_t i = 0; i < oldFromNew.n_elem; i++)
diff --git a/src/mlpack/methods/det/dtree.hpp b/src/mlpack/methods/det/dtree.hpp
index f34750f..6234287 100644
--- a/src/mlpack/methods/det/dtree.hpp
+++ b/src/mlpack/methods/det/dtree.hpp
@@ -37,7 +37,6 @@ namespace det /** Density Estimation Trees */ {
  * @endcode
  */
 template <typename MatType,
-          typename VecType,
           typename TagType = int>
 class DTree
 {
@@ -46,6 +45,7 @@ class DTree
    * The actual, underlying type we're working with
    */
   typedef typename MatType::elem_type ElemType;
+  typedef typename MatType::vec_type  VecType;
   
   /**
    * Create an empty density estimation tree.
diff --git a/src/mlpack/methods/det/dtree_impl.hpp b/src/mlpack/methods/det/dtree_impl.hpp
index 2261bbf..b456024 100644
--- a/src/mlpack/methods/det/dtree_impl.hpp
+++ b/src/mlpack/methods/det/dtree_impl.hpp
@@ -32,63 +32,10 @@ namespace detail
     ElemType    splitVal;
     size_t      splitDimension;
   };
-  
-  /**
-   * We need that function, to be able to specialize it for sparse matrices
-   * in a way which is much faster then usual iteration.
-   */
-  template <typename MatType, typename VecType>
-  void ExtractMinMax(const MatType& data,
-                     VecType& minVals,
-                     VecType& maxVals)
-  {
-    // Initialize to first column; values will be overwritten if necessary.
-    maxVals = data.col(0);
-    minVals = data.col(0);
-    
-    // Loop over data to extract maximum and minimum values in each dimension.
-    for (size_t i = 1; i < data.n_cols; ++i)
-    {
-      for (size_t j = 0; j < data.n_rows; ++j)
-      {
-        if (data(j, i) > maxVals[j])
-          maxVals[j] = data(j, i);
-        if (data(j, i) < minVals[j])
-          minVals[j] = data(j, i);
-      }
-    }
-  }
-  
-  /**
-   * Here is the optimized specialization
-   */
-  template <typename ElemType>
-  void ExtractMinMax(const arma::SpMat<ElemType>& data,
-                     arma::SpCol<ElemType>& minVals,
-                     arma::SpCol<ElemType>& maxVals)
-  {
-    // Initialize to first column; values will be overwritten if necessary.
-    maxVals = data.col(0);
-    minVals = data.col(0);
-    
-    typename arma::sp_mat::iterator dataEnd = data.end();
-    
-    // Loop over data to extract maximum and minimum values in each dimension.
-    for (typename arma::sp_mat::iterator i = data.begin(); i != dataEnd; ++i)
-    {
-      size_t j = i.row();
-      if (i.col() == 0)
-        continue; // we've already taken these values.
-      else if (*i > maxVals[j])
-        maxVals[j] = *i;
-      else if (*i < minVals[j])
-        minVals[j] = *i;
-    }
-  }
 };
 
-template <typename MatType, typename VecType, typename TagType>
-DTree<MatType, VecType, TagType>::DTree() :
+template <typename MatType, typename TagType>
+DTree<MatType, TagType>::DTree() :
     start(0),
     end(0),
     splitDim(size_t(-1)),
@@ -108,10 +55,10 @@ DTree<MatType, VecType, TagType>::DTree() :
 
 // Root node initializers
 
-template <typename MatType, typename VecType, typename TagType>
-DTree<MatType, VecType, TagType>::DTree(const VecType& maxVals,
-                                        const VecType& minVals,
-                                        const size_t totalPoints) :
+template <typename MatType, typename TagType>
+DTree<MatType, TagType>::DTree(const VecType& maxVals,
+                               const VecType& minVals,
+                               const size_t totalPoints) :
     start(0),
     end(totalPoints),
     maxVals(maxVals),
@@ -130,8 +77,8 @@ DTree<MatType, VecType, TagType>::DTree(const VecType& maxVals,
     right(NULL)
 { /* Nothing to do. */ }
 
-template <typename MatType, typename VecType, typename TagType>
-DTree<MatType, VecType, TagType>::DTree(MatType & data) :
+template <typename MatType, typename TagType>
+DTree<MatType, TagType>::DTree(MatType & data) :
     start(0),
     end(data.n_cols),
     splitDim(size_t(-1)),
@@ -146,17 +93,31 @@ DTree<MatType, VecType, TagType>::DTree(MatType & data) :
     left(NULL),
     right(NULL)
 {
-  detail::ExtractMinMax(data, minVals, maxVals);
+  maxVals = data.col(0);
+  minVals = data.col(0);
+  
+  typename MatType::row_col_iterator dataEnd = data.end_row_col();
+  
+  // Loop over data to extract maximum and minimum values in each dimension.
+  for (typename MatType::row_col_iterator i = data.begin_row_col(); i != dataEnd; ++i)
+  {
+    size_t j = i.row();
+    if (*i > maxVals[j])
+      maxVals[j] = *i;
+    else if (*i < minVals[j])
+      minVals[j] = *i;
+  }
+
   logNegError = LogNegativeError(data.n_cols);
 }
 
 // Non-root node initializers
-template <typename MatType, typename VecType, typename TagType>
-DTree<MatType, VecType, TagType>::DTree(const VecType& maxVals,
-                                        const VecType& minVals,
-                                        const size_t start,
-                                        const size_t end,
-                                        const double logNegError) :
+template <typename MatType, typename TagType>
+DTree<MatType, TagType>::DTree(const VecType& maxVals,
+                               const VecType& minVals,
+                               const size_t start,
+                               const size_t end,
+                               const double logNegError) :
     start(start),
     end(end),
     maxVals(maxVals),
@@ -175,12 +136,12 @@ DTree<MatType, VecType, TagType>::DTree(const VecType& maxVals,
     right(NULL)
 { /* Nothing to do. */ }
 
-template <typename MatType, typename VecType, typename TagType>
-DTree<MatType, VecType, TagType>::DTree(const VecType& maxVals,
-                                        const VecType& minVals,
-                                        const size_t totalPoints,
-                                        const size_t start,
-                                        const size_t end) :
+template <typename MatType, typename TagType>
+DTree<MatType, TagType>::DTree(const VecType& maxVals,
+                               const VecType& minVals,
+                               const size_t totalPoints,
+                               const size_t start,
+                               const size_t end) :
     start(start),
     end(end),
     maxVals(maxVals),
@@ -199,8 +160,8 @@ DTree<MatType, VecType, TagType>::DTree(const VecType& maxVals,
     right(NULL)
 { /* Nothing to do. */ }
 
-template <typename MatType, typename VecType, typename TagType>
-DTree<MatType, VecType, TagType>::~DTree()
+template <typename MatType, typename TagType>
+DTree<MatType, TagType>::~DTree()
 {
   delete left;
   delete right;
@@ -208,8 +169,8 @@ DTree<MatType, VecType, TagType>::~DTree()
 
 // This function computes the log-l2-negative-error of a given node from the
 // formula R(t) = log(|t|^2 / (N^2 V_t)).
-template <typename MatType, typename VecType, typename TagType>
-double DTree<MatType, VecType, TagType>::LogNegativeError(const size_t totalPoints) const
+template <typename MatType, typename TagType>
+double DTree<MatType, TagType>::LogNegativeError(const size_t totalPoints) const
 {
   // log(-|t|^2 / (N^2 V_t)) = log(-1) + 2 log(|t|) - 2 log(N) - log(V_t).
   double err = 2 * std::log((double) (end - start)) -
@@ -230,13 +191,13 @@ double DTree<MatType, VecType, TagType>::LogNegativeError(const size_t totalPoin
 // This function finds the best split with respect to the L2-error, by trying
 // all possible splits.  The dataset is the full data set but the start and
 // end are used to obtain the point in this node.
-template <typename MatType, typename VecType, typename TagType>
-bool DTree<MatType, VecType, TagType>::FindSplit(const MatType& data,
-                                                 size_t& splitDim,
-                                                 ElemType& splitValue,
-                                                 double& leftError,
-                                                 double& rightError,
-                                                 const size_t minLeafSize) const
+template <typename MatType, typename TagType>
+bool DTree<MatType, TagType>::FindSplit(const MatType& data,
+                                        size_t& splitDim,
+                                        ElemType& splitValue,
+                                        double& leftError,
+                                        double& rightError,
+                                        const size_t minLeafSize) const
 {
   // Ensure the dimensionality of the data is the same as the dimensionality of
   // the bounding rectangle.
@@ -280,7 +241,7 @@ bool DTree<MatType, VecType, TagType>::FindSplit(const MatType& data,
     double volumeWithoutDim = logVolume - std::log(max - min);
 
     // Get the values for the dimension.
-    VecType dimVec = data.row(dim).subvec(start, end - 1);
+    typename MatType::row_type dimVec = data.row(dim).subvec(start, end - 1);
 
     // Sort the values in ascending order.
     dimVec = arma::sort(dimVec);
@@ -347,11 +308,11 @@ bool DTree<MatType, VecType, TagType>::FindSplit(const MatType& data,
   return splitFound;
 }
 
-template <typename MatType, typename VecType, typename TagType>
-size_t DTree<MatType, VecType, TagType>::SplitData(MatType& data,
-                                                   const size_t splitDim,
-                                                   const ElemType splitValue,
-                                                   arma::Col<size_t>& oldFromNew) const
+template <typename MatType, typename TagType>
+size_t DTree<MatType, TagType>::SplitData(MatType& data,
+                                          const size_t splitDim,
+                                          const ElemType splitValue,
+                                          arma::Col<size_t>& oldFromNew) const
 {
   // Swap all columns such that any columns with value in dimension splitDim
   // less than or equal to splitValue are on the left side, and all others are
@@ -382,12 +343,12 @@ size_t DTree<MatType, VecType, TagType>::SplitData(MatType& data,
 }
 
 // Greedily expand the tree
-template <typename MatType, typename VecType, typename TagType>
-double DTree<MatType, VecType, TagType>::Grow(MatType& data,
-                                              arma::Col<size_t>& oldFromNew,
-                                              const bool useVolReg,
-                                              const size_t maxLeafSize,
-                                              const size_t minLeafSize)
+template <typename MatType, typename TagType>
+double DTree<MatType, TagType>::Grow(MatType& data,
+                                     arma::Col<size_t>& oldFromNew,
+                                     const bool useVolReg,
+                                     const size_t maxLeafSize,
+                                     const size_t minLeafSize)
 {
   Log::Assert(data.n_rows == maxVals.n_elem);
   Log::Assert(data.n_rows == minVals.n_elem);
@@ -527,10 +488,10 @@ double DTree<MatType, VecType, TagType>::Grow(MatType& data,
 }
 
 
-template <typename MatType, typename VecType, typename TagType>
-double DTree<MatType, VecType, TagType>::PruneAndUpdate(const double oldAlpha,
-                                                        const size_t points,
-                                                        const bool useVolReg)
+template <typename MatType, typename TagType>
+double DTree<MatType, TagType>::PruneAndUpdate(const double oldAlpha,
+                                               const size_t points,
+                                               const bool useVolReg)
 {
   // Compute gT.
   if (subtreeLeaves == 1) // If we are a leaf...
@@ -642,8 +603,8 @@ double DTree<MatType, VecType, TagType>::PruneAndUpdate(const double oldAlpha,
 //
 // Future improvement: Open up the range with epsilons on both sides where
 // epsilon depends on the density near the boundary.
-template <typename MatType, typename VecType, typename TagType>
-bool DTree<MatType, VecType, TagType>::WithinRange(const VecType& query) const
+template <typename MatType, typename TagType>
+bool DTree<MatType, TagType>::WithinRange(const VecType& query) const
 {
   for (size_t i = 0; i < query.n_elem; ++i)
     if ((query[i] < minVals[i]) || (query[i] > maxVals[i]))
@@ -653,8 +614,8 @@ bool DTree<MatType, VecType, TagType>::WithinRange(const VecType& query) const
 }
 
 
-template <typename MatType, typename VecType, typename TagType>
-double DTree<MatType, VecType, TagType>::ComputeValue(const VecType& query) const
+template <typename MatType, typename TagType>
+double DTree<MatType, TagType>::ComputeValue(const VecType& query) const
 {
   Log::Assert(query.n_elem == maxVals.n_elem);
 
@@ -680,8 +641,8 @@ double DTree<MatType, VecType, TagType>::ComputeValue(const VecType& query) cons
 
 
 // Index the buckets for possible usage later.
-template <typename MatType, typename VecType, typename TagType>
-TagType DTree<MatType, VecType, TagType>::TagTree(const TagType& tag)
+template <typename MatType, typename TagType>
+TagType DTree<MatType, TagType>::TagTree(const TagType& tag)
 {
   if (subtreeLeaves == 1)
   {
@@ -696,8 +657,8 @@ TagType DTree<MatType, VecType, TagType>::TagTree(const TagType& tag)
 }
 
 
-template <typename MatType, typename VecType, typename TagType>
-TagType DTree<MatType, VecType, TagType>::FindBucket(const VecType& query) const
+template <typename MatType, typename TagType>
+TagType DTree<MatType, TagType>::FindBucket(const VecType& query) const
 {
   Log::Assert(query.n_elem == maxVals.n_elem);
 
@@ -712,8 +673,8 @@ TagType DTree<MatType, VecType, TagType>::FindBucket(const VecType& query) const
   }
 }
 
-template <typename MatType, typename VecType, typename TagType>
-void DTree<MatType, VecType, TagType>::ComputeVariableImportance(arma::vec& importances) const
+template <typename MatType, typename TagType>
+void DTree<MatType, TagType>::ComputeVariableImportance(arma::vec& importances) const
 {
   // Clear and set to right size.
   importances.zeros(maxVals.n_elem);
@@ -740,9 +701,9 @@ void DTree<MatType, VecType, TagType>::ComputeVariableImportance(arma::vec& impo
   }
 }
 
-template <typename MatType, typename VecType, typename TagType>
+template <typename MatType, typename TagType>
 template <typename Archive>
-void DTree<MatType, VecType, TagType>::Serialize(Archive& ar, const unsigned int /* version */)
+void DTree<MatType, TagType>::Serialize(Archive& ar, const unsigned int /* version */)
 {
   using data::CreateNVP;
   
diff --git a/src/mlpack/tests/det_test.cpp b/src/mlpack/tests/det_test.cpp
index 2b0ef37..3365984 100644
--- a/src/mlpack/tests/det_test.cpp
+++ b/src/mlpack/tests/det_test.cpp
@@ -42,7 +42,7 @@ BOOST_AUTO_TEST_CASE(TestGetMaxMinVals)
            << 5 << 0 << 1 << 7 << 1 << arma::endr
            << 5 << 6 << 7 << 1 << 8 << arma::endr;
 
-  DTree<arma::mat, arma::vec> tree(testData);
+  DTree<arma::mat> tree(testData);
 
   BOOST_REQUIRE_EQUAL(tree.maxVals[0], 7);
   BOOST_REQUIRE_EQUAL(tree.minVals[0], 3);
@@ -57,7 +57,7 @@ BOOST_AUTO_TEST_CASE(TestComputeNodeError)
   arma::vec maxVals("7 7 8");
   arma::vec minVals("3 0 1");
 
-  DTree<arma::mat, arma::vec> testDTree(maxVals, minVals, 5);
+  DTree<arma::mat> testDTree(maxVals, minVals, 5);
   double trueNodeError = -log(4.0) - log(7.0) - log(7.0);
 
   BOOST_REQUIRE_CLOSE((double) testDTree.logNegError, trueNodeError, 1e-10);
@@ -75,7 +75,7 @@ BOOST_AUTO_TEST_CASE(TestWithinRange)
   arma::vec maxVals("7 7 8");
   arma::vec minVals("3 0 1");
 
-  DTree<arma::mat, arma::vec> testDTree(maxVals, minVals, 5);
+  DTree<arma::mat> testDTree(maxVals, minVals, 5);
 
   arma::vec testQuery(3);
   testQuery << 4.5 << 2.5 << 2;
@@ -95,7 +95,7 @@ BOOST_AUTO_TEST_CASE(TestFindSplit)
            << 5 << 0 << 1 << 7 << 1 << arma::endr
            << 5 << 6 << 7 << 1 << 8 << arma::endr;
 
-  DTree<arma::mat, arma::vec> testDTree(testData);
+  DTree<arma::mat> testDTree(testData);
 
   size_t obDim, trueDim;
   double trueLeftError, obLeftError, trueRightError, obRightError, obSplit, trueSplit;
@@ -123,7 +123,7 @@ BOOST_AUTO_TEST_CASE(TestSplitData)
            << 5 << 0 << 1 << 7 << 1 << arma::endr
            << 5 << 6 << 7 << 1 << 8 << arma::endr;
 
-  DTree<arma::mat, arma::vec> testDTree(testData);
+  DTree<arma::mat> testDTree(testData);
 
   arma::Col<size_t> oTest(5);
   oTest << 1 << 2 << 3 << 4 << 5;
@@ -166,7 +166,7 @@ BOOST_AUTO_TEST_CASE(TestGrow)
   rlError = 2 * log(1.0 / 5.0) - (log(0.5) + log(4.0) + log(2.5));
   rrError = 2 * log(2.0 / 5.0) - (log(6.5) + log(4.0) + log(2.5));
 
-  DTree<arma::mat, arma::vec> testDTree(testData);
+  DTree<arma::mat> testDTree(testData);
   double alpha = testDTree.Grow(testData, oTest, false, 2, 1);
 
   BOOST_REQUIRE_EQUAL(oTest[0], 0);
@@ -219,7 +219,7 @@ BOOST_AUTO_TEST_CASE(TestPruneAndUpdate)
 
   arma::Col<size_t> oTest(5);
   oTest << 0 << 1 << 2 << 3 << 4;
-  DTree<arma::mat, arma::vec> testDTree(testData);
+  DTree<arma::mat> testDTree(testData);
   double alpha = testDTree.Grow(testData, oTest, false, 2, 1);
   alpha = testDTree.PruneAndUpdate(alpha, testData.n_cols, false);
 
@@ -252,7 +252,7 @@ BOOST_AUTO_TEST_CASE(TestComputeValue)
   arma::Col<size_t> oTest(5);
   oTest << 0 << 1 << 2 << 3 << 4;
 
-  DTree<arma::mat, arma::vec> testDTree(testData);
+  DTree<arma::mat> testDTree(testData);
   double alpha = testDTree.Grow(testData, oTest, false, 2, 1);
 
   double d1 = (2.0 / 5.0) / exp(log(4.0) + log(7.0) + log(4.5));
@@ -295,7 +295,7 @@ BOOST_AUTO_TEST_CASE(TestVariableImportance)
   arma::Col<size_t> oTest(5);
   oTest << 0 << 1 << 2 << 3 << 4;
 
-  DTree<arma::mat, arma::vec> testDTree(testData);
+  DTree<arma::mat> testDTree(testData);
   testDTree.Grow(testData, oTest, false, 2, 1);
 
   arma::vec imps;
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index d7ea76e..64769df 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -853,7 +853,7 @@ BOOST_AUTO_TEST_CASE(SoftmaxRegressionTest)
 BOOST_AUTO_TEST_CASE(DETTest)
 {
   using det::DTree;
-  typedef DTree<arma::mat, arma::vec>   DTreeX;
+  typedef DTree<arma::mat>   DTreeX;
 
   // Create a density estimation tree on a random dataset.
   arma::mat dataset = arma::randu<arma::mat>(25, 5000);




More information about the mlpack-git mailing list