[mlpack-git] master: - DTree class templated. (aa2ad99)
gitdub at mlpack.org
gitdub at mlpack.org
Tue Oct 18 05:43:35 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/94d14187222231ca29e4f6419c5999c660db4f8a...981ffa2d67d8fe38df6c699589005835fef710ea
>---------------------------------------------------------------
commit aa2ad993b56e379261ff7eecfebb666d80ee8e64
Author: theJonan <ivan at jonan.info>
Date: Thu Oct 13 16:10:43 2016 +0300
- DTree class templated.
>---------------------------------------------------------------
aa2ad993b56e379261ff7eecfebb666d80ee8e64
src/mlpack/methods/det/CMakeLists.txt | 2 +-
src/mlpack/methods/det/dtree.hpp | 98 +++++-------
.../methods/det/{dtree.cpp => dtree_impl.hpp} | 178 ++++++++++++---------
3 files changed, 141 insertions(+), 137 deletions(-)
diff --git a/src/mlpack/methods/det/CMakeLists.txt b/src/mlpack/methods/det/CMakeLists.txt
index 73b0474..66edbe6 100644
--- a/src/mlpack/methods/det/CMakeLists.txt
+++ b/src/mlpack/methods/det/CMakeLists.txt
@@ -4,7 +4,7 @@
set(SOURCES
# the DET class
dtree.hpp
- dtree.cpp
+ dtree_impl.hpp
# the util file
dt_utils.hpp
diff --git a/src/mlpack/methods/det/dtree.hpp b/src/mlpack/methods/det/dtree.hpp
index 8e8bbdf..f39b5bf 100644
--- a/src/mlpack/methods/det/dtree.hpp
+++ b/src/mlpack/methods/det/dtree.hpp
@@ -36,10 +36,18 @@ namespace det /** Density Estimation Trees */ {
* }
* @endcode
*/
+template <typename MatType,
+ typename VecType,
+ typename TagType = int>
class DTree
{
public:
/**
+ * The actual, underlying type we're working with
+ */
+ typedef typename MatType::elem_type ElemType;
+
+ /**
* Create an empty density estimation tree.
*/
DTree();
@@ -52,8 +60,8 @@ class DTree
* @param minVals Minimum values of the bounding box.
* @param totalPoints Total number of points in the dataset.
*/
- DTree(const arma::vec& maxVals,
- const arma::vec& minVals,
+ DTree(const VecType& maxVals,
+ const VecType& minVals,
const size_t totalPoints);
/**
@@ -64,7 +72,7 @@ class DTree
*
* @param data Dataset to build tree on.
*/
- DTree(arma::mat& data);
+ DTree(MatType& data);
/**
* Create a child node of a density estimation tree given the bounding box
@@ -78,8 +86,8 @@ class DTree
* @param end End of points represented by this node in the data matrix.
* @param error log-negative error of this node.
*/
- DTree(const arma::vec& maxVals,
- const arma::vec& minVals,
+ DTree(const VecType& maxVals,
+ const VecType& minVals,
const size_t start,
const size_t end,
const double logNegError);
@@ -95,8 +103,8 @@ class DTree
* @param start Start of points represented by this node in the data matrix.
* @param end End of points represented by this node in the data matrix.
*/
- DTree(const arma::vec& maxVals,
- const arma::vec& minVals,
+ DTree(const VecType& maxVals,
+ const VecType& minVals,
const size_t totalPoints,
const size_t start,
const size_t end);
@@ -114,7 +122,7 @@ class DTree
* @param maxLeafSize Maximum size of a leaf.
* @param minLeafSize Minimum size of a leaf.
*/
- double Grow(arma::mat& data,
+ double Grow(MatType& data,
arma::Col<size_t>& oldFromNew,
const bool useVolReg = false,
const size_t maxLeafSize = 10,
@@ -137,16 +145,7 @@ class DTree
*
* @param query Point to estimate density of.
*/
- double ComputeValue(const arma::vec& query) const;
-
- /**
- * Print the tree in a depth-first manner (this function is called
- * recursively).
- *
- * @param fp File to write the tree to.
- * @param level Level of the tree (should start at 0).
- */
- void WriteTree(FILE *fp, const size_t level = 0) const;
+ double ComputeValue(const VecType& query) const;
/**
* Index the buckets for possible usage later; this results in every leaf in
@@ -155,7 +154,7 @@ class DTree
*
* @param tag Tag for the next leaf; leave at 0 for the initial call.
*/
- int TagTree(const int tag = 0);
+ TagType TagTree(const TagType tag = 0);
/**
* Return the tag of the leaf containing the query. This is useful for
@@ -163,7 +162,7 @@ class DTree
*
* @param query Query to search for.
*/
- int FindBucket(const arma::vec& query) const;
+ TagType FindBucket(const VecType& query) const;
/**
* Compute the variable importance of each dimension in the learned tree.
@@ -183,7 +182,7 @@ class DTree
/**
* Return whether a query point is within the range of this node.
*/
- bool WithinRange(const arma::vec& query) const;
+ bool WithinRange(const VecType& query) const;
private:
// The indices in the complete set of points
@@ -208,7 +207,7 @@ class DTree
size_t splitDim;
//! The split value on the splitting dimension for this node.
- double splitValue;
+ ElemType splitValue;
//! log-negative-L2-error of the node.
double logNegError;
@@ -229,7 +228,7 @@ class DTree
double logVolume;
//! The tag for the leaf, used for hashing points.
- int bucketTag;
+ TagType bucketTag;
//! Upper part of alpha sum; used for pruning.
double alphaUpper;
@@ -247,7 +246,7 @@ class DTree
//! Return the split dimension of this node.
size_t SplitDim() const { return splitDim; }
//! Return the split value of this node.
- double SplitValue() const { return splitValue; }
+ ElemType SplitValue() const { return splitValue; }
//! Return the log negative error of this node.
double LogNegError() const { return logNegError; }
//! Return the log negative error of all descendants of this node.
@@ -267,51 +266,24 @@ class DTree
bool Root() const { return root; }
//! Return the upper part of the alpha sum.
double AlphaUpper() const { return alphaUpper; }
+ //! Return the current bucket's ID, if leaf, or -1 otherwise
+ TagType BucketTag() const { return subtreeLeaves == 1 ? bucketTag : -1; }
//! Return the maximum values.
- const arma::vec& MaxVals() const { return maxVals; }
+ const VecType& MaxVals() const { return maxVals; }
//! Modify the maximum values.
- arma::vec& MaxVals() { return maxVals; }
+ VecType& MaxVals() { return maxVals; }
//! Return the minimum values.
- const arma::vec& MinVals() const { return minVals; }
+ const VecType& MinVals() const { return minVals; }
//! Modify the minimum values.
- arma::vec& MinVals() { return minVals; }
+ VecType& MinVals() { return minVals; }
/**
* Serialize the density estimation tree.
*/
template<typename Archive>
- void Serialize(Archive& ar, const unsigned int /* version */)
- {
- using data::CreateNVP;
-
- ar & CreateNVP(start, "start");
- ar & CreateNVP(end, "end");
- ar & CreateNVP(maxVals, "maxVals");
- ar & CreateNVP(minVals, "minVals");
- ar & CreateNVP(splitDim, "splitDim");
- ar & CreateNVP(splitValue, "splitValue");
- ar & CreateNVP(logNegError, "logNegError");
- ar & CreateNVP(subtreeLeavesLogNegError, "subtreeLeavesLogNegError");
- ar & CreateNVP(subtreeLeaves, "subtreeLeaves");
- ar & CreateNVP(root, "root");
- ar & CreateNVP(ratio, "ratio");
- ar & CreateNVP(logVolume, "logVolume");
- ar & CreateNVP(bucketTag, "bucketTag");
- ar & CreateNVP(alphaUpper, "alphaUpper");
-
- if (Archive::is_loading::value)
- {
- if (left)
- delete left;
- if (right)
- delete right;
- }
-
- ar & CreateNVP(left, "left");
- ar & CreateNVP(right, "right");
- }
+ void Serialize(Archive& ar, const unsigned int /* version */);
private:
@@ -320,9 +292,9 @@ class DTree
/**
* Find the dimension to split on.
*/
- bool FindSplit(const arma::mat& data,
+ bool FindSplit(const MatType& data,
size_t& splitDim,
- double& splitValue,
+ ElemType& splitValue,
double& leftError,
double& rightError,
const size_t minLeafSize = 5) const;
@@ -330,9 +302,9 @@ class DTree
/**
* Split the data, returning the number of points left of the split.
*/
- size_t SplitData(arma::mat& data,
+ size_t SplitData(MatType& data,
const size_t splitDim,
- const double splitValue,
+ const ElemType splitValue,
arma::Col<size_t>& oldFromNew) const;
};
@@ -340,4 +312,6 @@ class DTree
} // namespace det
} // namespace mlpack
+#include "dtree_impl.hpp"
+
#endif // MLPACK_METHODS_DET_DTREE_HPP
diff --git a/src/mlpack/methods/det/dtree.cpp b/src/mlpack/methods/det/dtree_impl.hpp
similarity index 77%
rename from src/mlpack/methods/det/dtree.cpp
rename to src/mlpack/methods/det/dtree_impl.hpp
index c88d480..5ef3758 100644
--- a/src/mlpack/methods/det/dtree.cpp
+++ b/src/mlpack/methods/det/dtree_impl.hpp
@@ -12,7 +12,8 @@
using namespace mlpack;
using namespace det;
-DTree::DTree() :
+template <typename MatType, typename VecType, typename TagType>
+DTree<MatType, VecType, TagType>::DTree() :
start(0),
end(0),
splitDim(size_t(-1)),
@@ -31,9 +32,11 @@ DTree::DTree() :
// Root node initializers
-DTree::DTree(const arma::vec& maxVals,
- const arma::vec& minVals,
- const size_t totalPoints) :
+
+template <typename MatType, typename VecType, typename TagType>
+DTree<MatType, VecType, TagType>::DTree(const VecType& maxVals,
+ const VecType& minVals,
+ const size_t totalPoints) :
start(0),
end(totalPoints),
maxVals(maxVals),
@@ -52,7 +55,8 @@ DTree::DTree(const arma::vec& maxVals,
right(NULL)
{ /* Nothing to do. */ }
-DTree::DTree(arma::mat& data) :
+template <typename MatType, typename VecType, typename TagType>
+DTree<MatType, VecType, TagType>::DTree(MatType & data) :
start(0),
end(data.n_cols),
splitDim(size_t(-1)),
@@ -88,11 +92,12 @@ DTree::DTree(arma::mat& data) :
// Non-root node initializers
-DTree::DTree(const arma::vec& maxVals,
- const arma::vec& minVals,
- const size_t start,
- const size_t end,
- const double logNegError) :
+template <typename MatType, typename VecType, typename TagType>
+DTree<MatType, VecType, TagType>::DTree(const VecType& maxVals,
+ const VecType& minVals,
+ const size_t start,
+ const size_t end,
+ const double logNegError) :
start(start),
end(end),
maxVals(maxVals),
@@ -111,11 +116,12 @@ DTree::DTree(const arma::vec& maxVals,
right(NULL)
{ /* Nothing to do. */ }
-DTree::DTree(const arma::vec& maxVals,
- const arma::vec& minVals,
- const size_t totalPoints,
- const size_t start,
- const size_t end) :
+template <typename MatType, typename VecType, typename TagType>
+DTree<MatType, VecType, TagType>::DTree(const VecType& maxVals,
+ const VecType& minVals,
+ const size_t totalPoints,
+ const size_t start,
+ const size_t end) :
start(start),
end(end),
maxVals(maxVals),
@@ -134,7 +140,8 @@ DTree::DTree(const arma::vec& maxVals,
right(NULL)
{ /* Nothing to do. */ }
-DTree::~DTree()
+template <typename MatType, typename VecType, typename TagType>
+DTree<MatType, VecType, TagType>::~DTree()
{
delete left;
delete right;
@@ -142,7 +149,8 @@ DTree::~DTree()
// This function computes the log-l2-negative-error of a given node from the
// formula R(t) = log(|t|^2 / (N^2 V_t)).
-double DTree::LogNegativeError(const size_t totalPoints) const
+template <typename MatType, typename VecType, typename TagType>
+double DTree<MatType, VecType, TagType>::LogNegativeError(const size_t totalPoints) const
{
// log(-|t|^2 / (N^2 V_t)) = log(-1) + 2 log(|t|) - 2 log(N) - log(V_t).
double err = 2 * std::log((double) (end - start)) -
@@ -162,12 +170,13 @@ double DTree::LogNegativeError(const size_t totalPoints) const
// This function finds the best split with respect to the L2-error, by trying
// all possible splits. The dataset is the full data set but the start and
// end are used to obtain the point in this node.
-bool DTree::FindSplit(const arma::mat& data,
- size_t& splitDim,
- double& splitValue,
- double& leftError,
- double& rightError,
- const size_t minLeafSize) const
+template <typename MatType, typename VecType, typename TagType>
+bool DTree<MatType, VecType, TagType>::FindSplit(const MatType& data,
+ size_t& splitDim,
+ ElemType& splitValue,
+ double& leftError,
+ double& rightError,
+ const size_t minLeafSize) const
{
// Ensure the dimensionality of the data is the same as the dimensionality of
// the bounding rectangle.
@@ -180,12 +189,20 @@ bool DTree::FindSplit(const arma::mat& data,
bool splitFound = false;
// Loop through each dimension.
- for (size_t dim = 0; dim < maxVals.n_elem; dim++)
+#ifdef _WIN32
+ #pragma omp parallel for default(none) \
+ shared(testSize, cvData, prunedSequence, regularizationConstants, dataset)
+ for (intmax_t dim = 0; fold < (intmax_t) maxVals.n_elem; ++dim)
+#else
+ #pragma omp parallel for default(none) \
+ shared(testSize, cvData, prunedSequence, regularizationConstants, dataset)
+ for (size_t dim = 0; dim < maxVals.n_elem; ++dim)
+#endif
{
// Have to deal with REAL, INTEGER, NOMINAL data differently, so we have to
// think of how to do that...
- const double min = minVals[dim];
- const double max = maxVals[dim];
+ const ElemType min = minVals[dim];
+ const ElemType max = maxVals[dim];
// If there is nothing to split in this dimension, move on.
if (max - min == 0.0)
@@ -197,7 +214,7 @@ bool DTree::FindSplit(const arma::mat& data,
double minDimError = std::pow(points, 2.0) / (max - min);
double dimLeftError = 0.0; // For -Wuninitialized. These variables will
double dimRightError = 0.0; // always be set to something else before use.
- double dimSplitValue = 0.0;
+ ElemType dimSplitValue = 0.0;
// Find the log volume of all the other dimensions.
double volumeWithoutDim = logVolume - std::log(max - min);
@@ -214,7 +231,7 @@ bool DTree::FindSplit(const arma::mat& data,
{
// This makes sense for real continuous data. This kinda corrupts the
// data and estimation if the data is ordinal.
- const double split = (dimVec[i] + dimVec[i + 1]) / 2.0;
+ const ElemType split = (dimVec[i] + dimVec[i + 1]) / 2.0;
if (split == dimVec[i])
continue; // We can't split here (two points are the same).
@@ -269,10 +286,11 @@ bool DTree::FindSplit(const arma::mat& data,
return splitFound;
}
-size_t DTree::SplitData(arma::mat& data,
- const size_t splitDim,
- const double splitValue,
- arma::Col<size_t>& oldFromNew) const
+template <typename MatType, typename VecType, typename TagType>
+size_t DTree<MatType, VecType, TagType>::SplitData(MatType& data,
+ const size_t splitDim,
+ const double splitValue,
+ arma::Col<size_t>& oldFromNew) const
{
// Swap all columns such that any columns with value in dimension splitDim
// less than or equal to splitValue are on the left side, and all others are
@@ -303,11 +321,12 @@ size_t DTree::SplitData(arma::mat& data,
}
// Greedily expand the tree
-double DTree::Grow(arma::mat& data,
- arma::Col<size_t>& oldFromNew,
- const bool useVolReg,
- const size_t maxLeafSize,
- const size_t minLeafSize)
+template <typename MatType, typename VecType, typename TagType>
+double DTree<MatType, VecType, TagType>::Grow(MatType& data,
+ arma::Col<size_t>& oldFromNew,
+ const bool useVolReg,
+ const size_t maxLeafSize,
+ const size_t minLeafSize)
{
Log::Assert(data.n_rows == maxVals.n_elem);
Log::Assert(data.n_rows == minVals.n_elem);
@@ -450,10 +469,10 @@ double DTree::Grow(arma::mat& data,
}
-double DTree::PruneAndUpdate(const double oldAlpha,
- const size_t points,
- const bool useVolReg)
-
+template <typename MatType, typename VecType, typename TagType>
+double DTree<MatType, VecType, TagType>::PruneAndUpdate(const double oldAlpha,
+ const size_t points,
+ const bool useVolReg)
{
// Compute gT.
if (subtreeLeaves == 1) // If we are a leaf...
@@ -565,7 +584,8 @@ double DTree::PruneAndUpdate(const double oldAlpha,
//
// Future improvement: Open up the range with epsilons on both sides where
// epsilon depends on the density near the boundary.
-bool DTree::WithinRange(const arma::vec& query) const
+template <typename MatType, typename VecType, typename TagType>
+bool DTree<MatType, VecType, TagType>::WithinRange(const VecType& query) const
{
for (size_t i = 0; i < query.n_elem; ++i)
if ((query[i] < minVals[i]) || (query[i] > maxVals[i]))
@@ -575,7 +595,8 @@ bool DTree::WithinRange(const arma::vec& query) const
}
-double DTree::ComputeValue(const arma::vec& query) const
+template <typename MatType, typename VecType, typename TagType>
+double DTree<MatType, VecType, TagType>::ComputeValue(const VecType& query) const
{
Log::Assert(query.n_elem == maxVals.n_elem);
@@ -607,35 +628,9 @@ double DTree::ComputeValue(const arma::vec& query) const
}
-void DTree::WriteTree(FILE *fp, const size_t level) const
-{
- if (subtreeLeaves > 1)
- {
- fprintf(fp, "\n");
- for (size_t i = 0; i < level; ++i)
- fprintf(fp, "|\t");
- fprintf(fp, "Var. %zu > %lg", splitDim, splitValue);
-
- right->WriteTree(fp, level + 1);
-
- fprintf(fp, "\n");
- for (size_t i = 0; i < level; ++i)
- fprintf(fp, "|\t");
- fprintf(fp, "Var. %zu <= %lg ", splitDim, splitValue);
-
- left->WriteTree(fp, level);
- }
- else // If we are a leaf...
- {
- fprintf(fp, ": f(x)=%lg", std::exp(std::log(ratio) - logVolume));
- if (bucketTag != -1)
- fprintf(fp, " BT:%d", bucketTag);
- }
-}
-
-
// Index the buckets for possible usage later.
-int DTree::TagTree(const int tag)
+template <typename MatType, typename VecType, typename TagType>
+TagType DTree<MatType, VecType, TagType>::TagTree(const TagType tag)
{
if (subtreeLeaves == 1)
{
@@ -650,7 +645,8 @@ int DTree::TagTree(const int tag)
}
-int DTree::FindBucket(const arma::vec& query) const
+template <typename MatType, typename VecType, typename TagType>
+TagType DTree<MatType, VecType, TagType>::FindBucket(const VecType& query) const
{
Log::Assert(query.n_elem == maxVals.n_elem);
@@ -670,8 +666,8 @@ int DTree::FindBucket(const arma::vec& query) const
}
}
-
-void DTree::ComputeVariableImportance(arma::vec& importances) const
+template <typename MatType, typename VecType, typename TagType>
+void DTree<MatType, VecType, TagType>::ComputeVariableImportance(arma::vec& importances) const
{
// Clear and set to right size.
importances.zeros(maxVals.n_elem);
@@ -697,3 +693,37 @@ void DTree::ComputeVariableImportance(arma::vec& importances) const
nodes.push(curNode.Right());
}
}
+
+template <typename MatType, typename VecType, typename TagType>
+template <typename Archive>
+void DTree<MatType, VecType, TagType>::Serialize(Archive& ar, const unsigned int /* version */)
+{
+ using data::CreateNVP;
+
+ ar & CreateNVP(start, "start");
+ ar & CreateNVP(end, "end");
+ ar & CreateNVP(maxVals, "maxVals");
+ ar & CreateNVP(minVals, "minVals");
+ ar & CreateNVP(splitDim, "splitDim");
+ ar & CreateNVP(splitValue, "splitValue");
+ ar & CreateNVP(logNegError, "logNegError");
+ ar & CreateNVP(subtreeLeavesLogNegError, "subtreeLeavesLogNegError");
+ ar & CreateNVP(subtreeLeaves, "subtreeLeaves");
+ ar & CreateNVP(root, "root");
+ ar & CreateNVP(ratio, "ratio");
+ ar & CreateNVP(logVolume, "logVolume");
+ ar & CreateNVP(bucketTag, "bucketTag");
+ ar & CreateNVP(alphaUpper, "alphaUpper");
+
+ if (Archive::is_loading::value)
+ {
+ if (left)
+ delete left;
+ if (right)
+ delete right;
+ }
+
+ ar & CreateNVP(left, "left");
+ ar & CreateNVP(right, "right");
+}
+
More information about the mlpack-git mailing list