[mlpack-git] master: - More sparse-matrix migration steps. (e4a9be0)
gitdub at mlpack.org
gitdub at mlpack.org
Tue Nov 1 15:22:53 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/94d14187222231ca29e4f6419c5999c660db4f8a...981ffa2d67d8fe38df6c699589005835fef710ea
>---------------------------------------------------------------
commit e4a9be06c6733590d896bdbdb1b63d30b67ef4f7
Author: theJonan <ivan at jonan.info>
Date: Mon Oct 17 14:30:48 2016 +0300
- More sparse-matrix migration steps.
>---------------------------------------------------------------
e4a9be06c6733590d896bdbdb1b63d30b67ef4f7
src/mlpack/methods/det/dtree.hpp | 4 +-
src/mlpack/methods/det/dtree_impl.hpp | 100 ++++++++++++++++++++++++++++++----
2 files changed, 92 insertions(+), 12 deletions(-)
diff --git a/src/mlpack/methods/det/dtree.hpp b/src/mlpack/methods/det/dtree.hpp
index 6234287..9a2b175 100644
--- a/src/mlpack/methods/det/dtree.hpp
+++ b/src/mlpack/methods/det/dtree.hpp
@@ -199,9 +199,9 @@ class DTree
size_t end;
//! Upper half of bounding box for this node.
- arma::vec maxVals;
+ VecType maxVals;
//! Lower half of bounding box for this node.
- arma::vec minVals;
+ VecType minVals;
//! The splitting dimension for this node.
size_t splitDim;
diff --git a/src/mlpack/methods/det/dtree_impl.hpp b/src/mlpack/methods/det/dtree_impl.hpp
index b456024..b683ffd 100644
--- a/src/mlpack/methods/det/dtree_impl.hpp
+++ b/src/mlpack/methods/det/dtree_impl.hpp
@@ -8,6 +8,8 @@
*/
#include "dtree.hpp"
#include <stack>
+#include <vector>
+#include <algorithm>
#include <mlpack/core/tree/perform_split.hpp>
using namespace mlpack;
@@ -32,6 +34,68 @@ namespace detail
ElemType splitVal;
size_t splitDimension;
};
+
+ /**
+ * Copy row `dim` of `data`, restricted to columns [start, end), and sort that
+ * copy in ascending order; the sorted row is returned by value. This replaces
+ * data.row(dim).subvec(start, end - 1) followed by arma::sort(), which the
+ * author notes made several (3) vector copies - costly, especially for sparse data.
+ * NOTE(review): only this generic overload is active; the sparse one below is commented out.
+ */
+ template <typename MatType>
+ typename MatType::row_type ExtractSortedRow(const MatType& data,
+ size_t dim,
+ size_t start,
+ size_t end)
+ {
+ typedef typename MatType::elem_type ElemType; // NOTE(review): unused in this overload.
+
+ assert(start < end); // Non-empty range required; NOTE(review): <cassert> must come via other includes.
+
+ typename MatType::row_type dimVec = data(dim, arma::span(start, end - 1)); // Copy of columns [start, end) of row `dim`.
+ std::sort(dimVec.begin(), dimVec.end()); // Ascending, in place on the copy; NOTE(review): confirm for sparse row_type.
+ return dimVec;
+ }
+
+// template <typename ElemType>
+// std::vector<SortedItem<ElemType> > ExtractSortedRow(const arma::SpMat<ElemType>& data,
+// size_t dim,
+// size_t start,
+// size_t end,
+// size_t padding)
+// {
+// typedef SortedItem<ElemType> SortedType;
+//
+// assert(padding > 0);
+// assert(start < end);
+//
+// arma::SpRow<ElemType> dimVec = data(dim, arma::span(start, end - 1));
+// typename arma::SpRow<ElemType>::iterator dimVecEnd = dimVec.end();
+//
+// // Build the vector to be sorted with values in ascending order.
+// std::vector<SortedType> sortedDim = std::vector<SortedType>();
+// sortedDim.reserve(dimVec.n_elem / 2);
+//
+// // Prepare these for the iteration.
+// ElemType lastVal = 0;
+// --padding;
+//
+// // Iterate over the row and put only different values, also skipping
+// // `padding` number of elements from both sides of the row.
+// for (typename MatType::row_col_iterator di = dimVec.begin_row_col(); di != dimVecEnd; ++di)
+// {
+// if (di.col() < padding || di.col() >= dimVec.n_elem - padding)
+// continue;
+// // if (*di == lastVal && sortedDim.size() > 0)
+// // continue;
+//
+// sortedDim.push_back(lastVal = *di);
+// }
+//
+// std::sort(sortedDim.begin(), sortedDim.end());
+// return sortedDim;
+// }
+
};
template <typename MatType, typename TagType>
@@ -240,21 +304,37 @@ bool DTree<MatType, TagType>::FindSplit(const MatType& data,
// Find the log volume of all the other dimensions.
double volumeWithoutDim = logVolume - std::log(max - min);
- // Get the values for the dimension.
- typename MatType::row_type dimVec = data.row(dim).subvec(start, end - 1);
+ // Get a sorted version of the dimension in interest,
+ // from the given samples range. This is the most expensive step.
+ typename MatType::row_type dimVec = detail::ExtractSortedRow(data,
+ dim,
+ start,
+ end);
- // Sort the values in ascending order.
- dimVec = arma::sort(dimVec);
+ typename MatType::row_col_iterator dimVecEnd = dimVec.end_row_col();
+ typename MatType::row_col_iterator dI = dimVec.begin_row_col();
// Find the best split for this dimension. We need to figure out why
// there are spikes if this minLeafSize is enforced here...
- for (size_t i = minLeafSize - 1; i < dimVec.n_elem - minLeafSize; ++i)
+ for (;;)
{
+ const size_t position = dI.col();
+
+ if (position >= dimVec.n_cols - minLeafSize)
+ break;
+ if (position < minLeafSize - 1)
+ continue;
+
+ ElemType split = *dI;
+ if (++dI == dimVecEnd)
+ break; // This means we have same values till the end => No split.
+
// This makes sense for real continuous data. This kinda corrupts the
// data and estimation if the data is ordinal.
- const ElemType split = (dimVec[i] + dimVec[i + 1]) / 2.0;
+ split += *dI;
+ split /= 2.0;
- if (split == dimVec[i])
+ if (split == *dI)
continue; // We can't split here (two points are the same).
// Another way of picking split is using this:
@@ -263,7 +343,7 @@ bool DTree<MatType, TagType>::FindSplit(const MatType& data,
{
// Ensure that the right node will have at least the minimum number of
// points.
- Log::Assert((points - i - 1) >= minLeafSize);
+ Log::Assert((points - position - 1) >= minLeafSize);
// Now we have to see if the error will be reduced. Simple manipulation
// of the error function gives us the condition we must satisfy:
@@ -271,8 +351,8 @@ bool DTree<MatType, TagType>::FindSplit(const MatType& data,
// and because the volume is only dependent on the dimension we are
// splitting, we can assume V_l is just the range of the left and V_r is
// just the range of the right.
- double negLeftError = std::pow(i + 1, 2.0) / (split - min);
- double negRightError = std::pow(points - i - 1, 2.0) / (max - split);
+ double negLeftError = std::pow(position + 1, 2.0) / (split - min);
+ double negRightError = std::pow(points - position - 1, 2.0) / (max - split);
// If this is better, take it.
if ((negLeftError + negRightError) >= minDimError)
More information about the mlpack-git
mailing list