[mlpack-git] master: - Template fixes for ExtractSplits. (10812ca)
gitdub at mlpack.org
gitdub at mlpack.org
Thu Oct 20 06:28:17 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/94d14187222231ca29e4f6419c5999c660db4f8a...981ffa2d67d8fe38df6c699589005835fef710ea
>---------------------------------------------------------------
commit 10812ca06558e96878ef301db268e0bcd5883610
Author: theJonan <ivan at jonan.info>
Date: Thu Oct 20 13:28:17 2016 +0300
- Template fixes for ExtractSplits.
>---------------------------------------------------------------
10812ca06558e96878ef301db268e0bcd5883610
src/mlpack/methods/det/dtree_impl.hpp | 107 ++++++++++++++++++++++------------
1 file changed, 71 insertions(+), 36 deletions(-)
diff --git a/src/mlpack/methods/det/dtree_impl.hpp b/src/mlpack/methods/det/dtree_impl.hpp
index 7c86899..a3e1567 100644
--- a/src/mlpack/methods/det/dtree_impl.hpp
+++ b/src/mlpack/methods/det/dtree_impl.hpp
@@ -20,35 +20,66 @@ namespace details
{
/**
* This one sorts and scans the given per-dimension extract and puts all splits
- * in a vector, that can easily be iterated afterwards.
+ * in a vector, that can easily be iterated afterwards. General implementation.
*/
- template <typename MatType>
- void ExtractSplits(std::vector<
- std::pair<typename MatType::elem_type, size_t>>& splitVec,
+ template <typename ElemType, typename MatType>
+ void ExtractSplits(std::vector<std::pair<ElemType, size_t>>& splitVec,
const MatType& data,
size_t dim,
- size_t start,
- size_t end,
- size_t minLeafSize)
+ const size_t start,
+ const size_t end,
+ const size_t minLeafSize)
{
- typedef typename MatType::elem_type ElemType;
- typedef std::pair<ElemType, size_t> SplitItem;
- typename MatType::row_type dimVec = data(dim, arma::span(start, end - 1));
+ static_assert(
+ std::is_same<typename MatType::elem_type, ElemType>::value == true,
+ "The ElemType does not correspond to the matrix's element type."
+ );
- // We sort these, in-place (it's a copy of the data, anyways).
- std::sort(dimVec.begin(), dimVec.end());
+ typedef std::pair<ElemType, size_t> SplitItem;
+ const typename MatType::row_type dimVec =
+ arma::sort(data(dim, arma::span(start, end - 1)));
// Ensure the minimum leaf size on both sides. We need to figure out why
// there are spikes if this minLeafSize is enforced here...
for (size_t i = minLeafSize - 1; i < dimVec.n_elem - minLeafSize; ++i)
{
- // This makes sense for real continuous data. This kinda corrupts the
- // data and estimation if the data is ordinal.
+ // This makes sense for real continuous data. This kinda corrupts the
+ // data and estimation if the data is ordinal. Potentially we can fix
+ // that by taking into account ordinality later in the min/max update,
+ // but then we can end up with a zero-volume dimension. No good.
const ElemType split = (dimVec[i] + dimVec[i + 1]) / 2.0;
// Check if we can split here (two points are different)
if (split != dimVec[i])
- splitVec.push_back(SplitItem(split, i));
+ splitVec.push_back(SplitItem(split, i + 1));
+ }
+ }
+
+ // Now the custom arma::Mat implementation
+ template <typename ElemType>
+ void ExtractSplits(std::vector<std::pair<ElemType, size_t>>& splitVec,
+ const arma::Mat<ElemType>& data,
+ size_t dim,
+ const size_t start,
+ const size_t end,
+ const size_t minLeafSize)
+ {
+ typedef std::pair<ElemType, size_t> SplitItem;
+ arma::vec dimVec = data(dim, arma::span(start, end - 1));
+
+ // We sort these, in-place (it's a copy of the data, anyways).
+ std::sort(dimVec.begin(), dimVec.end());
+
+ for (size_t i = minLeafSize - 1; i < dimVec.n_elem - minLeafSize; ++i)
+ {
+ // This makes sense for real continuous data. This kinda corrupts the
+ // data and estimation if the data is ordinal. Potentially we can fix
+ // that by taking into account ordinality later in the min/max update,
+ // but then we can end up with a zero-volume dimension. No good.
+ const ElemType split = (dimVec[i] + dimVec[i + 1]) / 2.0;
+
+ if (split != dimVec[i])
+ splitVec.push_back(SplitItem(split, i + 1));
}
}
@@ -57,10 +88,13 @@ namespace details
void ExtractSplits(std::vector<std::pair<ElemType, size_t>>& splitVec,
const arma::SpMat<ElemType>& data,
size_t dim,
- size_t start,
- size_t end,
- size_t minLeafSize)
+ const size_t start,
+ const size_t end,
+ const size_t minLeafSize)
{
+ // It's common sense, but we also use it in a check later.
+ Log::Assert(minLeafSize > 0);
+
typedef std::pair<ElemType, size_t> SplitItem;
const size_t n_elem = end - start;
@@ -73,8 +107,9 @@ namespace details
// Now iterate over the values, taking account for the over-the-zeroes
// jump and construct the splits vector.
+ const size_t zeroes = n_elem - valsVec.size();
ElemType lastVal = -std::numeric_limits<ElemType>::max();
- size_t padding = 0, zeroes = n_elem - valsVec.size();
+ size_t padding = 0;
for (size_t i = 0; i < valsVec.size(); ++i)
{
@@ -82,10 +117,10 @@ namespace details
if (lastVal < ElemType(0) && newVal > ElemType(0) && zeroes > 0)
{
Log::Assert(padding == 0); // we should arrive here once!
- if (lastVal >= valsVec[0] && // i.e. we're not in the beginning
- i >= minLeafSize &&
- i <= n_elem - minLeafSize)
- splitVec.push_back(SplitItem(lastVal / 2.0, i - 1));
+
+ // the minLeafSize > 0 also guarantees we're not entering right at the start.
+ if (i >= minLeafSize && i <= n_elem - minLeafSize)
+ splitVec.push_back(SplitItem(lastVal / 2.0, i));
padding = zeroes;
lastVal = ElemType(0);
@@ -95,12 +130,14 @@ namespace details
if (i + padding >= minLeafSize && i + padding <= n_elem - minLeafSize)
{
// This makes sense for real continuous data. This kinda corrupts the
- // data and estimation if the data is ordinal.
+ // data and estimation if the data is ordinal. Potentially we can fix
+ // that by taking into account ordinality later in the min/max update,
+ // but then we can end up with a zero-volume dimension. No good.
const ElemType split = (lastVal + newVal) / 2.0;
// Check if we can split here (two points are different)
if (split != newVal)
- splitVec.push_back(SplitItem(split, i + padding - 1));
+ splitVec.push_back(SplitItem(split, i + padding));
}
lastVal = newVal;
@@ -287,7 +324,10 @@ bool DTree<MatType, TagType>::FindSplit(const MatType& data,
if (max - min == 0.0)
continue; // Skip to next dimension.
- // Initializing all the stuff for this dimension.
+ // Find the log volume of all the other dimensions.
+ const double volumeWithoutDim = logVolume - std::log(max - min);
+
+ // Initializing all other stuff for this dimension.
bool dimSplitFound = false;
// Take an error estimate for this dimension.
double minDimError = std::pow(points, 2.0) / (max - min);
@@ -295,9 +335,6 @@ bool DTree<MatType, TagType>::FindSplit(const MatType& data,
double dimRightError = 0.0; // always be set to something else before use.
ElemType dimSplitValue = 0.0;
- // Find the log volume of all the other dimensions.
- double volumeWithoutDim = logVolume - std::log(max - min);
-
// Get the values for splitting. The old implementation:
// dimVec = data.row(dim).subvec(start, end - 1);
// dimVec = arma::sort(dimVec);
@@ -305,7 +342,7 @@ bool DTree<MatType, TagType>::FindSplit(const MatType& data,
// This one has custom implementation for dense and sparse matrices.
std::vector<SplitItem> splitVec;
- details::ExtractSplits(splitVec, data, dim, start, end, minLeafSize);
+ details::ExtractSplits<ElemType>(splitVec, data, dim, start, end, minLeafSize);
// Iterate on all the splits for this dimension
for (typename std::vector<SplitItem>::iterator i = splitVec.begin();
@@ -321,7 +358,7 @@ bool DTree<MatType, TagType>::FindSplit(const MatType& data,
{
// Ensure that the right node will have at least the minimum number of
// points.
- Log::Assert((points - position - 1) >= minLeafSize);
+ Log::Assert((points - position) >= minLeafSize);
// Now we have to see if the error will be reduced. Simple manipulation
// of the error function gives us the condition we must satisfy:
@@ -329,10 +366,8 @@ bool DTree<MatType, TagType>::FindSplit(const MatType& data,
// and because the volume is only dependent on the dimension we are
// splitting, we can assume V_l is just the range of the left and V_r is
// just the range of the right.
- double negLeftError = std::pow(position + 1, 2.0)
- / (split - min);
- double negRightError = std::pow(points - position - 1, 2.0)
- / (max - split);
+ double negLeftError = std::pow(position, 2.0) / (split - min);
+ double negRightError = std::pow(points - position, 2.0) / (max - split);
// If this is better, take it.
if ((negLeftError + negRightError) >= minDimError)
@@ -346,7 +381,7 @@ bool DTree<MatType, TagType>::FindSplit(const MatType& data,
}
}
- double actualMinDimError = std::log(minDimError)
+ const double actualMinDimError = std::log(minDimError)
- 2 * std::log((double) data.n_cols)
- volumeWithoutDim;
More information about the mlpack-git
mailing list