[mlpack-git] master: - DET templating ready and tests passing. (45ff5ba)
gitdub at mlpack.org
gitdub at mlpack.org
Tue Oct 18 05:43:36 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/94d14187222231ca29e4f6419c5999c660db4f8a...981ffa2d67d8fe38df6c699589005835fef710ea
>---------------------------------------------------------------
commit 45ff5baa3cf6606ccff5dddb3fbff55df261e9a3
Author: theJonan <ivan at jonan.info>
Date: Fri Oct 14 23:06:19 2016 +0300
- DET templating ready and tests passing.
>---------------------------------------------------------------
45ff5baa3cf6606ccff5dddb3fbff55df261e9a3
src/mlpack/core/arma_extend/Mat_extra_bones.hpp | 9 ++
src/mlpack/core/arma_extend/SpMat_extra_bones.hpp | 8 +
src/mlpack/methods/det/det_main.cpp | 5 +-
src/mlpack/methods/det/dt_utils.hpp | 22 +--
src/mlpack/methods/det/dt_utils_impl.hpp | 22 +--
src/mlpack/methods/det/dtree.hpp | 2 +-
src/mlpack/methods/det/dtree_impl.hpp | 187 +++++++++-------------
src/mlpack/tests/det_test.cpp | 18 +--
src/mlpack/tests/serialization_test.cpp | 2 +-
9 files changed, 126 insertions(+), 149 deletions(-)
diff --git a/src/mlpack/core/arma_extend/Mat_extra_bones.hpp b/src/mlpack/core/arma_extend/Mat_extra_bones.hpp
index e09f5f5..f4f25d6 100644
--- a/src/mlpack/core/arma_extend/Mat_extra_bones.hpp
+++ b/src/mlpack/core/arma_extend/Mat_extra_bones.hpp
@@ -12,6 +12,15 @@
template<typename Archive>
void serialize(Archive& ar, const unsigned int version);
+/**
+ * These typedefs let us refer to the proper vector / column / row types
+ * by specifying only the matrix type we want to use.
+ */
+
+typedef Col<elem_type> vec_type;
+typedef Col<elem_type> col_type;
+typedef Row<elem_type> row_type;
+
/*
* Add row_col_iterator and row_col_const_iterator to arma::Mat.
*/
diff --git a/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp b/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp
index d3c18de..a5d274c 100644
--- a/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp
+++ b/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp
@@ -16,6 +16,14 @@
template<typename Archive>
void serialize(Archive& ar, const unsigned int version);
+/**
+ * These typedefs let us refer to the proper vector / column / row types
+ * by specifying only the matrix type we want to use.
+ */
+typedef SpCol<elem_type> vec_type;
+typedef SpCol<elem_type> col_type;
+typedef SpRow<elem_type> row_type;
+
/*
* Extra functions for SpMat<eT>
* Adding definition of row_col_iterator to generalize with Mat<eT>::row_col_iterator
diff --git a/src/mlpack/methods/det/det_main.cpp b/src/mlpack/methods/det/det_main.cpp
index 26b394d..16ffa25 100644
--- a/src/mlpack/methods/det/det_main.cpp
+++ b/src/mlpack/methods/det/det_main.cpp
@@ -101,7 +101,7 @@ int main(int argc, char *argv[])
<< "(-T) is not specified." << endl;
// Are we training a DET or loading from file?
- DTree<arma::mat, arma::vec, int>* tree;
+ DTree<arma::mat, int>* tree;
if (CLI::HasParam("training_file"))
{
const string trainSetFile = CLI::GetParam<string>("training_file");
@@ -127,8 +127,7 @@ int main(int argc, char *argv[])
// Obtain the optimal tree.
Timer::Start("det_training");
- tree = Trainer<arma::mat, arma::vec, int>(trainingData, folds, regularization, maxLeafSize,
- minLeafSize, "");
+ tree = Trainer<arma::mat, int>(trainingData, folds, regularization, maxLeafSize, minLeafSize, "");
Timer::Stop("det_training");
// Compute training set estimates, if desired.
diff --git a/src/mlpack/methods/det/dt_utils.hpp b/src/mlpack/methods/det/dt_utils.hpp
index 067e6fe..3535eae 100644
--- a/src/mlpack/methods/det/dt_utils.hpp
+++ b/src/mlpack/methods/det/dt_utils.hpp
@@ -25,8 +25,8 @@ namespace det {
* @param numClasses Number of classes in dataset.
* @param leafClassMembershipFile Name of file to print to (optional).
*/
-template <typename MatType, typename VecType, typename TagType>
-void PrintLeafMembership(DTree<MatType, VecType, TagType>* dtree,
+template <typename MatType, typename TagType>
+void PrintLeafMembership(DTree<MatType, TagType>* dtree,
const MatType& data,
const arma::Mat<size_t>& labels,
const size_t numClasses,
@@ -40,8 +40,8 @@ void PrintLeafMembership(DTree<MatType, VecType, TagType>* dtree,
* @param dtree Density tree to use.
* @param viFile Name of file to print to (optional).
*/
-template <typename MatType, typename VecType, typename TagType>
-void PrintVariableImportance(const DTree<MatType, VecType, TagType>* dtree,
+template <typename MatType, typename TagType>
+void PrintVariableImportance(const DTree<MatType, TagType>* dtree,
const std::string viFile = "");
/**
@@ -56,13 +56,13 @@ void PrintVariableImportance(const DTree<MatType, VecType, TagType>* dtree,
* @param minLeafSize Minimum number of points allowed in a leaf.
* @param unprunedTreeOutput Filename to print unpruned tree to (optional).
*/
-template <typename MatType, typename VecType, typename TagType>
-DTree<MatType, VecType, TagType>* Trainer(MatType& dataset,
- const size_t folds,
- const bool useVolumeReg = false,
- const size_t maxLeafSize = 10,
- const size_t minLeafSize = 5,
- const std::string unprunedTreeOutput = "");
+template <typename MatType, typename TagType>
+DTree<MatType, TagType>* Trainer(MatType& dataset,
+ const size_t folds,
+ const bool useVolumeReg = false,
+ const size_t maxLeafSize = 10,
+ const size_t minLeafSize = 5,
+ const std::string unprunedTreeOutput = "");
} // namespace det
} // namespace mlpack
diff --git a/src/mlpack/methods/det/dt_utils_impl.hpp b/src/mlpack/methods/det/dt_utils_impl.hpp
index 4c057f7..cad5289 100644
--- a/src/mlpack/methods/det/dt_utils_impl.hpp
+++ b/src/mlpack/methods/det/dt_utils_impl.hpp
@@ -10,8 +10,8 @@
using namespace mlpack;
using namespace det;
-template <typename MatType, typename VecType, typename TagType>
-void mlpack::det::PrintLeafMembership(DTree<MatType, VecType, TagType>* dtree,
+template <typename MatType, typename TagType>
+void mlpack::det::PrintLeafMembership(DTree<MatType, TagType>* dtree,
const MatType& data,
const arma::Mat<size_t>& labels,
const size_t numClasses,
@@ -25,7 +25,7 @@ void mlpack::det::PrintLeafMembership(DTree<MatType, VecType, TagType>* dtree,
for (size_t i = 0; i < data.n_cols; i++)
{
- const VecType testPoint = data.unsafe_col(i);
+ const typename MatType::vec_type testPoint = data.unsafe_col(i);
const TagType leafTag = dtree->FindBucket(testPoint);
const size_t label = labels[i];
table(leafTag, label) += 1;
@@ -58,8 +58,8 @@ void mlpack::det::PrintLeafMembership(DTree<MatType, VecType, TagType>* dtree,
return;
}
-template <typename MatType, typename VecType, typename TagType>
-void mlpack::det::PrintVariableImportance(const DTree<MatType, VecType, TagType>* dtree,
+template <typename MatType, typename TagType>
+void mlpack::det::PrintVariableImportance(const DTree<MatType, TagType>* dtree,
const std::string viFile)
{
arma::vec imps;
@@ -97,8 +97,8 @@ void mlpack::det::PrintVariableImportance(const DTree<MatType, VecType, TagType>
// This function trains the optimal decision tree using the given number of
// folds.
-template <typename MatType, typename VecType, typename TagType>
-DTree<MatType, VecType, TagType>* mlpack::det::Trainer(MatType& dataset,
+template <typename MatType, typename TagType>
+DTree<MatType, TagType>* mlpack::det::Trainer(MatType& dataset,
const size_t folds,
const bool useVolumeReg,
const size_t maxLeafSize,
@@ -106,7 +106,7 @@ DTree<MatType, VecType, TagType>* mlpack::det::Trainer(MatType& dataset,
const std::string unprunedTreeOutput)
{
// Initialize the tree.
- DTree<MatType, VecType, TagType> dtree(dataset);
+ DTree<MatType, TagType> dtree(dataset);
// Prepare to grow the tree...
arma::Col<size_t> oldFromNew(dataset.n_cols);
@@ -211,7 +211,7 @@ DTree<MatType, VecType, TagType>* mlpack::det::Trainer(MatType& dataset,
}
// Initialize the tree.
- DTree<MatType, VecType, TagType> cvDTree(train);
+ DTree<MatType, TagType> cvDTree(train);
// Getting ready to grow the tree...
arma::Col<size_t> cvOldFromNew(train.n_cols);
@@ -251,7 +251,7 @@ DTree<MatType, VecType, TagType>* mlpack::det::Trainer(MatType& dataset,
double cvVal = 0.0;
for (size_t i = 0; i < test.n_cols; ++i)
{
- VecType testPoint = test.unsafe_col(i);
+ typename MatType::vec_type testPoint = test.unsafe_col(i);
cvVal += cvDTree.ComputeValue(testPoint);
}
@@ -283,7 +283,7 @@ DTree<MatType, VecType, TagType>* mlpack::det::Trainer(MatType& dataset,
Log::Info << "Optimal alpha: " << optimalAlpha << "." << std::endl;
// Initialize the tree.
- DTree<MatType, VecType, TagType>* dtreeOpt = new DTree<MatType, VecType, TagType>(dataset);
+ DTree<MatType, TagType>* dtreeOpt = new DTree<MatType, TagType>(dataset);
// Getting ready to grow the tree...
for (size_t i = 0; i < oldFromNew.n_elem; i++)
diff --git a/src/mlpack/methods/det/dtree.hpp b/src/mlpack/methods/det/dtree.hpp
index f34750f..6234287 100644
--- a/src/mlpack/methods/det/dtree.hpp
+++ b/src/mlpack/methods/det/dtree.hpp
@@ -37,7 +37,6 @@ namespace det /** Density Estimation Trees */ {
* @endcode
*/
template <typename MatType,
- typename VecType,
typename TagType = int>
class DTree
{
@@ -46,6 +45,7 @@ class DTree
* The actual, underlying type we're working with
*/
typedef typename MatType::elem_type ElemType;
+ typedef typename MatType::vec_type VecType;
/**
* Create an empty density estimation tree.
diff --git a/src/mlpack/methods/det/dtree_impl.hpp b/src/mlpack/methods/det/dtree_impl.hpp
index 2261bbf..b456024 100644
--- a/src/mlpack/methods/det/dtree_impl.hpp
+++ b/src/mlpack/methods/det/dtree_impl.hpp
@@ -32,63 +32,10 @@ namespace detail
ElemType splitVal;
size_t splitDimension;
};
-
- /**
- * We need that function, to be able to specialize it for sparse matrices
- * in a way which is much faster then usual iteration.
- */
- template <typename MatType, typename VecType>
- void ExtractMinMax(const MatType& data,
- VecType& minVals,
- VecType& maxVals)
- {
- // Initialize to first column; values will be overwritten if necessary.
- maxVals = data.col(0);
- minVals = data.col(0);
-
- // Loop over data to extract maximum and minimum values in each dimension.
- for (size_t i = 1; i < data.n_cols; ++i)
- {
- for (size_t j = 0; j < data.n_rows; ++j)
- {
- if (data(j, i) > maxVals[j])
- maxVals[j] = data(j, i);
- if (data(j, i) < minVals[j])
- minVals[j] = data(j, i);
- }
- }
- }
-
- /**
- * Here is the optimized specialization
- */
- template <typename ElemType>
- void ExtractMinMax(const arma::SpMat<ElemType>& data,
- arma::SpCol<ElemType>& minVals,
- arma::SpCol<ElemType>& maxVals)
- {
- // Initialize to first column; values will be overwritten if necessary.
- maxVals = data.col(0);
- minVals = data.col(0);
-
- typename arma::sp_mat::iterator dataEnd = data.end();
-
- // Loop over data to extract maximum and minimum values in each dimension.
- for (typename arma::sp_mat::iterator i = data.begin(); i != dataEnd; ++i)
- {
- size_t j = i.row();
- if (i.col() == 0)
- continue; // we've already taken these values.
- else if (*i > maxVals[j])
- maxVals[j] = *i;
- else if (*i < minVals[j])
- minVals[j] = *i;
- }
- }
};
-template <typename MatType, typename VecType, typename TagType>
-DTree<MatType, VecType, TagType>::DTree() :
+template <typename MatType, typename TagType>
+DTree<MatType, TagType>::DTree() :
start(0),
end(0),
splitDim(size_t(-1)),
@@ -108,10 +55,10 @@ DTree<MatType, VecType, TagType>::DTree() :
// Root node initializers
-template <typename MatType, typename VecType, typename TagType>
-DTree<MatType, VecType, TagType>::DTree(const VecType& maxVals,
- const VecType& minVals,
- const size_t totalPoints) :
+template <typename MatType, typename TagType>
+DTree<MatType, TagType>::DTree(const VecType& maxVals,
+ const VecType& minVals,
+ const size_t totalPoints) :
start(0),
end(totalPoints),
maxVals(maxVals),
@@ -130,8 +77,8 @@ DTree<MatType, VecType, TagType>::DTree(const VecType& maxVals,
right(NULL)
{ /* Nothing to do. */ }
-template <typename MatType, typename VecType, typename TagType>
-DTree<MatType, VecType, TagType>::DTree(MatType & data) :
+template <typename MatType, typename TagType>
+DTree<MatType, TagType>::DTree(MatType & data) :
start(0),
end(data.n_cols),
splitDim(size_t(-1)),
@@ -146,17 +93,31 @@ DTree<MatType, VecType, TagType>::DTree(MatType & data) :
left(NULL),
right(NULL)
{
- detail::ExtractMinMax(data, minVals, maxVals);
+ maxVals = data.col(0);
+ minVals = data.col(0);
+
+ typename MatType::row_col_iterator dataEnd = data.end_row_col();
+
+ // Loop over data to extract maximum and minimum values in each dimension.
+ for (typename MatType::row_col_iterator i = data.begin_row_col(); i != dataEnd; ++i)
+ {
+ size_t j = i.row();
+ if (*i > maxVals[j])
+ maxVals[j] = *i;
+ else if (*i < minVals[j])
+ minVals[j] = *i;
+ }
+
logNegError = LogNegativeError(data.n_cols);
}
// Non-root node initializers
-template <typename MatType, typename VecType, typename TagType>
-DTree<MatType, VecType, TagType>::DTree(const VecType& maxVals,
- const VecType& minVals,
- const size_t start,
- const size_t end,
- const double logNegError) :
+template <typename MatType, typename TagType>
+DTree<MatType, TagType>::DTree(const VecType& maxVals,
+ const VecType& minVals,
+ const size_t start,
+ const size_t end,
+ const double logNegError) :
start(start),
end(end),
maxVals(maxVals),
@@ -175,12 +136,12 @@ DTree<MatType, VecType, TagType>::DTree(const VecType& maxVals,
right(NULL)
{ /* Nothing to do. */ }
-template <typename MatType, typename VecType, typename TagType>
-DTree<MatType, VecType, TagType>::DTree(const VecType& maxVals,
- const VecType& minVals,
- const size_t totalPoints,
- const size_t start,
- const size_t end) :
+template <typename MatType, typename TagType>
+DTree<MatType, TagType>::DTree(const VecType& maxVals,
+ const VecType& minVals,
+ const size_t totalPoints,
+ const size_t start,
+ const size_t end) :
start(start),
end(end),
maxVals(maxVals),
@@ -199,8 +160,8 @@ DTree<MatType, VecType, TagType>::DTree(const VecType& maxVals,
right(NULL)
{ /* Nothing to do. */ }
-template <typename MatType, typename VecType, typename TagType>
-DTree<MatType, VecType, TagType>::~DTree()
+template <typename MatType, typename TagType>
+DTree<MatType, TagType>::~DTree()
{
delete left;
delete right;
@@ -208,8 +169,8 @@ DTree<MatType, VecType, TagType>::~DTree()
// This function computes the log-l2-negative-error of a given node from the
// formula R(t) = log(|t|^2 / (N^2 V_t)).
-template <typename MatType, typename VecType, typename TagType>
-double DTree<MatType, VecType, TagType>::LogNegativeError(const size_t totalPoints) const
+template <typename MatType, typename TagType>
+double DTree<MatType, TagType>::LogNegativeError(const size_t totalPoints) const
{
// log(-|t|^2 / (N^2 V_t)) = log(-1) + 2 log(|t|) - 2 log(N) - log(V_t).
double err = 2 * std::log((double) (end - start)) -
@@ -230,13 +191,13 @@ double DTree<MatType, VecType, TagType>::LogNegativeError(const size_t totalPoin
// This function finds the best split with respect to the L2-error, by trying
// all possible splits. The dataset is the full data set but the start and
// end are used to obtain the point in this node.
-template <typename MatType, typename VecType, typename TagType>
-bool DTree<MatType, VecType, TagType>::FindSplit(const MatType& data,
- size_t& splitDim,
- ElemType& splitValue,
- double& leftError,
- double& rightError,
- const size_t minLeafSize) const
+template <typename MatType, typename TagType>
+bool DTree<MatType, TagType>::FindSplit(const MatType& data,
+ size_t& splitDim,
+ ElemType& splitValue,
+ double& leftError,
+ double& rightError,
+ const size_t minLeafSize) const
{
// Ensure the dimensionality of the data is the same as the dimensionality of
// the bounding rectangle.
@@ -280,7 +241,7 @@ bool DTree<MatType, VecType, TagType>::FindSplit(const MatType& data,
double volumeWithoutDim = logVolume - std::log(max - min);
// Get the values for the dimension.
- VecType dimVec = data.row(dim).subvec(start, end - 1);
+ typename MatType::row_type dimVec = data.row(dim).subvec(start, end - 1);
// Sort the values in ascending order.
dimVec = arma::sort(dimVec);
@@ -347,11 +308,11 @@ bool DTree<MatType, VecType, TagType>::FindSplit(const MatType& data,
return splitFound;
}
-template <typename MatType, typename VecType, typename TagType>
-size_t DTree<MatType, VecType, TagType>::SplitData(MatType& data,
- const size_t splitDim,
- const ElemType splitValue,
- arma::Col<size_t>& oldFromNew) const
+template <typename MatType, typename TagType>
+size_t DTree<MatType, TagType>::SplitData(MatType& data,
+ const size_t splitDim,
+ const ElemType splitValue,
+ arma::Col<size_t>& oldFromNew) const
{
// Swap all columns such that any columns with value in dimension splitDim
// less than or equal to splitValue are on the left side, and all others are
@@ -382,12 +343,12 @@ size_t DTree<MatType, VecType, TagType>::SplitData(MatType& data,
}
// Greedily expand the tree
-template <typename MatType, typename VecType, typename TagType>
-double DTree<MatType, VecType, TagType>::Grow(MatType& data,
- arma::Col<size_t>& oldFromNew,
- const bool useVolReg,
- const size_t maxLeafSize,
- const size_t minLeafSize)
+template <typename MatType, typename TagType>
+double DTree<MatType, TagType>::Grow(MatType& data,
+ arma::Col<size_t>& oldFromNew,
+ const bool useVolReg,
+ const size_t maxLeafSize,
+ const size_t minLeafSize)
{
Log::Assert(data.n_rows == maxVals.n_elem);
Log::Assert(data.n_rows == minVals.n_elem);
@@ -527,10 +488,10 @@ double DTree<MatType, VecType, TagType>::Grow(MatType& data,
}
-template <typename MatType, typename VecType, typename TagType>
-double DTree<MatType, VecType, TagType>::PruneAndUpdate(const double oldAlpha,
- const size_t points,
- const bool useVolReg)
+template <typename MatType, typename TagType>
+double DTree<MatType, TagType>::PruneAndUpdate(const double oldAlpha,
+ const size_t points,
+ const bool useVolReg)
{
// Compute gT.
if (subtreeLeaves == 1) // If we are a leaf...
@@ -642,8 +603,8 @@ double DTree<MatType, VecType, TagType>::PruneAndUpdate(const double oldAlpha,
//
// Future improvement: Open up the range with epsilons on both sides where
// epsilon depends on the density near the boundary.
-template <typename MatType, typename VecType, typename TagType>
-bool DTree<MatType, VecType, TagType>::WithinRange(const VecType& query) const
+template <typename MatType, typename TagType>
+bool DTree<MatType, TagType>::WithinRange(const VecType& query) const
{
for (size_t i = 0; i < query.n_elem; ++i)
if ((query[i] < minVals[i]) || (query[i] > maxVals[i]))
@@ -653,8 +614,8 @@ bool DTree<MatType, VecType, TagType>::WithinRange(const VecType& query) const
}
-template <typename MatType, typename VecType, typename TagType>
-double DTree<MatType, VecType, TagType>::ComputeValue(const VecType& query) const
+template <typename MatType, typename TagType>
+double DTree<MatType, TagType>::ComputeValue(const VecType& query) const
{
Log::Assert(query.n_elem == maxVals.n_elem);
@@ -680,8 +641,8 @@ double DTree<MatType, VecType, TagType>::ComputeValue(const VecType& query) cons
// Index the buckets for possible usage later.
-template <typename MatType, typename VecType, typename TagType>
-TagType DTree<MatType, VecType, TagType>::TagTree(const TagType& tag)
+template <typename MatType, typename TagType>
+TagType DTree<MatType, TagType>::TagTree(const TagType& tag)
{
if (subtreeLeaves == 1)
{
@@ -696,8 +657,8 @@ TagType DTree<MatType, VecType, TagType>::TagTree(const TagType& tag)
}
-template <typename MatType, typename VecType, typename TagType>
-TagType DTree<MatType, VecType, TagType>::FindBucket(const VecType& query) const
+template <typename MatType, typename TagType>
+TagType DTree<MatType, TagType>::FindBucket(const VecType& query) const
{
Log::Assert(query.n_elem == maxVals.n_elem);
@@ -712,8 +673,8 @@ TagType DTree<MatType, VecType, TagType>::FindBucket(const VecType& query) const
}
}
-template <typename MatType, typename VecType, typename TagType>
-void DTree<MatType, VecType, TagType>::ComputeVariableImportance(arma::vec& importances) const
+template <typename MatType, typename TagType>
+void DTree<MatType, TagType>::ComputeVariableImportance(arma::vec& importances) const
{
// Clear and set to right size.
importances.zeros(maxVals.n_elem);
@@ -740,9 +701,9 @@ void DTree<MatType, VecType, TagType>::ComputeVariableImportance(arma::vec& impo
}
}
-template <typename MatType, typename VecType, typename TagType>
+template <typename MatType, typename TagType>
template <typename Archive>
-void DTree<MatType, VecType, TagType>::Serialize(Archive& ar, const unsigned int /* version */)
+void DTree<MatType, TagType>::Serialize(Archive& ar, const unsigned int /* version */)
{
using data::CreateNVP;
diff --git a/src/mlpack/tests/det_test.cpp b/src/mlpack/tests/det_test.cpp
index 2b0ef37..3365984 100644
--- a/src/mlpack/tests/det_test.cpp
+++ b/src/mlpack/tests/det_test.cpp
@@ -42,7 +42,7 @@ BOOST_AUTO_TEST_CASE(TestGetMaxMinVals)
<< 5 << 0 << 1 << 7 << 1 << arma::endr
<< 5 << 6 << 7 << 1 << 8 << arma::endr;
- DTree<arma::mat, arma::vec> tree(testData);
+ DTree<arma::mat> tree(testData);
BOOST_REQUIRE_EQUAL(tree.maxVals[0], 7);
BOOST_REQUIRE_EQUAL(tree.minVals[0], 3);
@@ -57,7 +57,7 @@ BOOST_AUTO_TEST_CASE(TestComputeNodeError)
arma::vec maxVals("7 7 8");
arma::vec minVals("3 0 1");
- DTree<arma::mat, arma::vec> testDTree(maxVals, minVals, 5);
+ DTree<arma::mat> testDTree(maxVals, minVals, 5);
double trueNodeError = -log(4.0) - log(7.0) - log(7.0);
BOOST_REQUIRE_CLOSE((double) testDTree.logNegError, trueNodeError, 1e-10);
@@ -75,7 +75,7 @@ BOOST_AUTO_TEST_CASE(TestWithinRange)
arma::vec maxVals("7 7 8");
arma::vec minVals("3 0 1");
- DTree<arma::mat, arma::vec> testDTree(maxVals, minVals, 5);
+ DTree<arma::mat> testDTree(maxVals, minVals, 5);
arma::vec testQuery(3);
testQuery << 4.5 << 2.5 << 2;
@@ -95,7 +95,7 @@ BOOST_AUTO_TEST_CASE(TestFindSplit)
<< 5 << 0 << 1 << 7 << 1 << arma::endr
<< 5 << 6 << 7 << 1 << 8 << arma::endr;
- DTree<arma::mat, arma::vec> testDTree(testData);
+ DTree<arma::mat> testDTree(testData);
size_t obDim, trueDim;
double trueLeftError, obLeftError, trueRightError, obRightError, obSplit, trueSplit;
@@ -123,7 +123,7 @@ BOOST_AUTO_TEST_CASE(TestSplitData)
<< 5 << 0 << 1 << 7 << 1 << arma::endr
<< 5 << 6 << 7 << 1 << 8 << arma::endr;
- DTree<arma::mat, arma::vec> testDTree(testData);
+ DTree<arma::mat> testDTree(testData);
arma::Col<size_t> oTest(5);
oTest << 1 << 2 << 3 << 4 << 5;
@@ -166,7 +166,7 @@ BOOST_AUTO_TEST_CASE(TestGrow)
rlError = 2 * log(1.0 / 5.0) - (log(0.5) + log(4.0) + log(2.5));
rrError = 2 * log(2.0 / 5.0) - (log(6.5) + log(4.0) + log(2.5));
- DTree<arma::mat, arma::vec> testDTree(testData);
+ DTree<arma::mat> testDTree(testData);
double alpha = testDTree.Grow(testData, oTest, false, 2, 1);
BOOST_REQUIRE_EQUAL(oTest[0], 0);
@@ -219,7 +219,7 @@ BOOST_AUTO_TEST_CASE(TestPruneAndUpdate)
arma::Col<size_t> oTest(5);
oTest << 0 << 1 << 2 << 3 << 4;
- DTree<arma::mat, arma::vec> testDTree(testData);
+ DTree<arma::mat> testDTree(testData);
double alpha = testDTree.Grow(testData, oTest, false, 2, 1);
alpha = testDTree.PruneAndUpdate(alpha, testData.n_cols, false);
@@ -252,7 +252,7 @@ BOOST_AUTO_TEST_CASE(TestComputeValue)
arma::Col<size_t> oTest(5);
oTest << 0 << 1 << 2 << 3 << 4;
- DTree<arma::mat, arma::vec> testDTree(testData);
+ DTree<arma::mat> testDTree(testData);
double alpha = testDTree.Grow(testData, oTest, false, 2, 1);
double d1 = (2.0 / 5.0) / exp(log(4.0) + log(7.0) + log(4.5));
@@ -295,7 +295,7 @@ BOOST_AUTO_TEST_CASE(TestVariableImportance)
arma::Col<size_t> oTest(5);
oTest << 0 << 1 << 2 << 3 << 4;
- DTree<arma::mat, arma::vec> testDTree(testData);
+ DTree<arma::mat> testDTree(testData);
testDTree.Grow(testData, oTest, false, 2, 1);
arma::vec imps;
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index d7ea76e..64769df 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -853,7 +853,7 @@ BOOST_AUTO_TEST_CASE(SoftmaxRegressionTest)
BOOST_AUTO_TEST_CASE(DETTest)
{
using det::DTree;
- typedef DTree<arma::mat, arma::vec> DTreeX;
+ typedef DTree<arma::mat> DTreeX;
// Create a density estimation tree on a random dataset.
arma::mat dataset = arma::randu<arma::mat>(25, 5000);
More information about the mlpack-git mailing list