[mlpack-git] [mlpack/mlpack] Density Estimation Tree made sparse-enabled (#802)
Ryan Curtin
notifications at github.com
Wed Oct 19 17:36:32 EDT 2016
rcurtin commented on this pull request.
This looks great to me; thank you for taking the time to make these changes. This will be a nice improvement to mlpack's DET implementation. I have a few comments, so let me know what you think and we can go from there.
> @@ -12,6 +12,15 @@
template<typename Archive>
void serialize(Archive& ar, const unsigned int version);
+/**
+ * These let us refer to the proper vector / column / row types by
+ * specifying only the matrix type we want to use.
+ */
+
+typedef Col<elem_type> vec_type;
+typedef Col<elem_type> col_type;
+typedef Row<elem_type> row_type;
This is a nice idea, and we should consider submitting something like this upstream, or at least starting a discussion with the Armadillo maintainer.
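If changing Armadillo itself turns out to be a hard sell, one alternative is a small traits class that maps a matrix type to its column and row types from outside the matrix class. This is only a sketch, not something Armadillo or mlpack provides. `MatTraits` is a hypothetical name, and `DenseMat`/`SparseMat` and friends are hypothetical stand-ins so the example compiles without Armadillo. In real code the specializations would be written for `arma::Mat<eT>` (mapping to `arma::Col<eT>`/`arma::Row<eT>`) and `arma::SpMat<eT>` (mapping to `arma::SpCol<eT>`/`arma::SpRow<eT>`).

```cpp
#include <type_traits>

// Hypothetical stand-ins for arma::Mat/Col/Row and arma::SpMat/SpCol/SpRow,
// so this sketch compiles without Armadillo.
template<typename eT> struct DenseMat  { typedef eT elem_type; };
template<typename eT> struct DenseCol  { typedef eT elem_type; };
template<typename eT> struct DenseRow  { typedef eT elem_type; };
template<typename eT> struct SparseMat { typedef eT elem_type; };
template<typename eT> struct SparseCol { typedef eT elem_type; };
template<typename eT> struct SparseRow { typedef eT elem_type; };

// Maps a matrix type to its column and row types.  The primary template is
// declared but never defined, so an unsupported matrix type fails to compile.
template<typename MatType>
struct MatTraits;

template<typename eT>
struct MatTraits<DenseMat<eT>>
{
  typedef DenseCol<eT> col_type;
  typedef DenseRow<eT> row_type;
};

template<typename eT>
struct MatTraits<SparseMat<eT>>
{
  typedef SparseCol<eT> col_type;
  typedef SparseRow<eT> row_type;
};

// Generic code then needs only the matrix type, e.g.
//   typename MatTraits<MatType>::row_type dimVec = ...;
static_assert(std::is_same<MatTraits<DenseMat<double>>::col_type,
                           DenseCol<double>>::value,
              "dense matrices map to dense columns");
static_assert(std::is_same<MatTraits<SparseMat<double>>::row_type,
                           SparseRow<double>>::value,
              "sparse matrices map to sparse rows");
```

The upside is that nothing inside the matrix classes changes. The downside is one extra specialization per supported matrix type.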
> prunedSequence.push_back(treeSeq);
oldAlpha = alpha;
alpha = dtree.PruneAndUpdate(oldAlpha, dataset.n_cols, useVolumeReg);
// Some sanity checks. It seems that on some datasets, the error does not
// increase as the tree is pruned but instead stays the same---hence the
// "<=" in the final assert.
- Log::Assert((alpha < std::numeric_limits<double>::max()) ||
- (dtree.SubtreeLeaves() == 1));
+ Log::Assert((alpha < std::numeric_limits<double>::max()) || (dtree.SubtreeLeaves() == 1));
This line is over 80 characters; we should wrap it in accordance with the style guide:
https://github.com/mlpack/mlpack/wiki/DesignGuidelines
I think some other lines are now too long as well.
> cvDTree.PruneAndUpdate(cvOldAlpha, train.n_cols, useVolumeReg);
}
// Compute test values for this state of the tree.
double cvVal = 0.0;
for (size_t i = 0; i < test.n_cols; ++i)
{
- arma::vec testPoint = test.unsafe_col(i);
+ typename MatType::vec_type testPoint = test.unsafe_col(i);
cvVal += cvDTree.ComputeValue(testPoint);
Can we do `cvDTree.ComputeValue(test.col(i))` here? It would probably require templatizing `ComputeValue()` to accept arbitrary vector types. My concern is that sparse datasets don't have the `unsafe_col()` method.
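Here is a minimal sketch of what templatizing `ComputeValue()` could look like. The node layout (`splitDim`, `splitValue`, `density`, `left`, `right`) is a simplified, hypothetical stand-in for `DTree`, not mlpack's actual class. The "values `<=` the split go left" convention is also an assumption. The point is only that a template parameter `VecType` accepts any type with `operator[]`, so a column view does not have to be copied into a particular vector class first. Whether a given Armadillo view type supports the needed element access would still have to be checked.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical, simplified density-tree node; not mlpack's DTree.
struct Node
{
  std::size_t splitDim = 0;
  double splitValue = 0.0;
  double density = 0.0;  // Only meaningful at leaves.
  std::unique_ptr<Node> left, right;

  bool IsLeaf() const { return !left; }

  // Templated on the query type: any vector-like type with operator[] works,
  // without first copying it into one specific vector class.
  template<typename VecType>
  double ComputeValue(const VecType& query) const
  {
    if (IsLeaf())
      return density;

    assert(right && "internal nodes have two children");
    // Assumed convention: values <= splitValue go to the left child.
    return (query[splitDim] <= splitValue) ? left->ComputeValue(query)
                                           : right->ComputeValue(query);
  }
};
```

If `DTree::ComputeValue()` were changed the same way, the loop could call `cvDTree.ComputeValue(test.col(i))` directly for both dense and sparse test sets.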
>
const size_t points = end - start;
double minError = logNegError;
bool splitFound = false;
// Loop through each dimension.
- for (size_t dim = 0; dim < maxVals.n_elem; dim++)
+#ifdef _WIN32
+ #pragma omp parallel for default(shared)
+ for (intmax_t dim = 0; dim < (intmax_t) maxVals.n_elem; ++dim)
+#else
+ #pragma omp parallel for default(shared)
+ for (size_t dim = 0; dim < maxVals.n_elem; ++dim)
+#endif
{
// Have to deal with REAL, INTEGER, NOMINAL data differently, so we have to
// think of how to do that...
We can remove this comment now, I think. This doesn't really handle nominal data, but with your refactoring it does handle real and integer data. Handling nominal data in density estimation trees is not something that I think Pari's paper even talked about (although the extension should be straightforward... kind of), so I don't think we need to worry about that.
> if ((actualMinDimError > minError) && dimSplitFound)
{
- // Calculate actual error (in logspace) by adding terms back to our
- // estimate.
- minError = actualMinDimError;
- splitDim = dim;
- splitValue = dimSplitValue;
- leftError = std::log(dimLeftError) - 2 * std::log((double) data.n_cols)
- - volumeWithoutDim;
- rightError = std::log(dimRightError) - 2 * std::log((double) data.n_cols)
- - volumeWithoutDim;
- splitFound = true;
+ {
Why the extra braces?
> - dimVec = arma::sort(dimVec);
-
- // Find the best split for this dimension. We need to figure out why
- // there are spikes if this minLeafSize is enforced here...
- for (size_t i = minLeafSize - 1; i < dimVec.n_elem - minLeafSize; ++i)
+ // Get the values for splitting. The old implementation:
+ // dimVec = data.row(dim).subvec(start, end - 1);
+ // dimVec = arma::sort(dimVec);
+ // could be quite inefficient for sparse matrices, due to copy operations (3).
+ // This one has custom implementation for dense and sparse matrices.
+
+ std::vector<SplitItem> splitVec = details::ExtractSplits(data,
+ dim,
+ start,
+ end,
+ minLeafSize);
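For reference, here is a stdlib-only sketch of what the old dense path does. It copies one dimension's values for the points in `[start, end)`, sorts them, and walks the same index range as the removed loop (`i` from `minLeafSize - 1` up to, but not including, `n - minLeafSize`). The rest is illustrative:

- `CandidateSplits` is a hypothetical name, not the PR's `ExtractSplits`.
- The column-major `std::vector<double>` layout is assumed.
- The function returns split indices rather than split values.
- Skipping identical neighbors is an assumption about the loop body, which isn't shown here.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper mirroring the old dense path.  `data` holds an
// nDims x nPoints matrix in column-major order (one point per column).
// Returns indices i such that a split can be placed between sorted[i] and
// sorted[i + 1] with at least minLeafSize points on each side.
std::vector<std::size_t> CandidateSplits(const std::vector<double>& data,
                                         const std::size_t nDims,
                                         const std::size_t dim,
                                         const std::size_t start,
                                         const std::size_t end,
                                         const std::size_t minLeafSize)
{
  assert(minLeafSize >= 1 && start <= end && dim < nDims);

  // Copy this dimension's values for points [start, end) and sort them.
  std::vector<double> sorted;
  sorted.reserve(end - start);
  for (std::size_t p = start; p < end; ++p)
    sorted.push_back(data[p * nDims + dim]);
  std::sort(sorted.begin(), sorted.end());

  std::vector<std::size_t> splits;
  // Too few points to leave minLeafSize on both sides of any split (this
  // also avoids unsigned underflow in the loop bound below).
  if (sorted.size() < 2 * minLeafSize)
    return splits;

  for (std::size_t i = minLeafSize - 1; i < sorted.size() - minLeafSize; ++i)
  {
    // Identical neighbors cannot be separated by any split value.
    if (sorted[i] != sorted[i + 1])
      splits.push_back(i);
  }
  return splits;
}
```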
As far as I can tell, the reason for the `ExtractSplits` function is that the `sort()` method is not available for sparse matrices. Suppose that `sort()` did exist for sparse matrices (e.g. suppose I sat down and wrote it, which I might need to do shortly!). Then we could do this:
```
typename MatType::row_type dimVec = data.row(dim).subvec(start, end - 1);
dimVec = arma::sort(dimVec);

// Iterate over all values that are actually stored in memory.
typename MatType::row_type::const_row_col_iterator it;
for (it = dimVec.begin_row_col(); it != dimVec.end_row_col(); ++it)
{
  // Check the split to the left of the point that *it represents, if it
  // exists.
  if (it.col() > 0)
  {
    // do checking for split between dimVec[it.col() - 1] and dimVec[it.col()]...
  }

  // Check the split to the right if that position exists but is not stored
  // (an implicit zero in the sparse case); if it is stored, the next
  // iteration's left check covers it.  There's probably a cleaner way to
  // write this code.
  typename MatType::row_type::const_row_col_iterator it2 = it;
  ++it2;
  if ((it.col() + 1 < dimVec.n_elem) &&
      ((it2 == dimVec.end_row_col()) || (it2.col() != it.col() + 1)))
  {
    // do checking for split between dimVec[it.col()] and dimVec[it.col() + 1]...
  }
}
```
Note that the `row_col_iterator` will only "stop" at points that are actually represented in memory, so for sparse matrices it will skip over zero elements. I don't think the `row_col_iterator` is actually documented by Armadillo... I should submit a patch for that.
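To make the iteration pattern concrete without depending on Armadillo, here is a stdlib-only sketch. It works on a hypothetical sorted sparse row, stored as its length `n` plus the ascending column indices of its stored (nonzero) entries. It visits only stored positions and lists the boundaries `(c, c + 1)` that a split check would need to look at:

- the boundary to the left of each stored entry;
- the boundary to its right whenever the next position is an implicit zero.

The second rule matters because in a sorted row the negative values come before the zeros. The boundary between the last negative value and the first zero therefore sits to the right of a stored entry that is not necessarily the last one. `BoundariesToCheck` and `Boundary` are illustrative names only.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// A boundary between sorted positions c and c + 1.
typedef std::pair<std::size_t, std::size_t> Boundary;

// `storedCols` holds the ascending column indices of the stored (nonzero)
// entries of a sorted row of length n; all other positions are implicit
// zeros.  Returns each boundary exactly once, visiting only stored entries.
// Whether the values on either side differ is left to the split check.
std::vector<Boundary> BoundariesToCheck(
    const std::size_t n, const std::vector<std::size_t>& storedCols)
{
  std::vector<Boundary> result;
  for (std::size_t k = 0; k < storedCols.size(); ++k)
  {
    const std::size_t c = storedCols[k];
    assert(c < n);

    // Split to the left of this stored entry, if a position exists there.
    if (c > 0)
      result.push_back(Boundary(c - 1, c));

    // Split to the right, if that position exists and is not stored;
    // otherwise the next stored entry's left check already covers it.
    const bool nextIsStored = (k + 1 < storedCols.size()) &&
                              (storedCols[k + 1] == c + 1);
    if (c + 1 < n && !nextIsStored)
      result.push_back(Boundary(c, c + 1));
  }
  return result;
}
```

For a dense row every position is stored, so this reduces to checking each adjacent pair once.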
What do you think? Would this approach work? If so, I will implement `SpMat::sort()` (it should be pretty straightforward; I think I have a good idea of how to do it). That would let us avoid separate code for the dense and sparse cases (I like to push specialized code like that into Armadillo wherever possible).
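One way a sparse sort could work is sketched below. It shows the general idea only, not necessarily what an eventual `SpMat::sort()` would do. In ascending order, all negative values come first, then every implicit zero, then the positive values. So only the stored nonzeros need sorting: the negatives go to the first columns and the positives to the last ones, and the zeros never have to be materialized. The `(column, value)` pair representation and the name `SortSparseRow` are assumptions for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical coordinate representation of one sparse row: (column, value)
// pairs for the stored nonzero entries; every other column is zero.
typedef std::vector<std::pair<std::size_t, double>> SparseRowEntries;

// Returns the ascending sort of a sparse row of length n, still in sparse
// form.  Only the nonzeros are sorted; the zeros fill the gap implicitly.
SparseRowEntries SortSparseRow(const std::size_t n,
                               const SparseRowEntries& entries)
{
  assert(entries.size() <= n);
  std::vector<double> values;
  values.reserve(entries.size());
  for (const auto& e : entries)
  {
    assert(e.first < n && e.second != 0.0);
    values.push_back(e.second);
  }
  std::sort(values.begin(), values.end());

  // Negative values occupy columns 0 .. numNeg - 1.
  const std::size_t numNeg = static_cast<std::size_t>(
      std::lower_bound(values.begin(), values.end(), 0.0) - values.begin());
  // Positive values occupy the last (values.size() - numNeg) columns.
  const std::size_t firstPosCol = n - (values.size() - numNeg);

  SparseRowEntries sorted;
  sorted.reserve(values.size());
  for (std::size_t i = 0; i < values.size(); ++i)
  {
    const std::size_t col = (i < numNeg) ? i : firstPosCol + (i - numNeg);
    sorted.push_back(std::make_pair(col, values[i]));
  }
  return sorted;
}
```

This keeps the cost at O(nnz log nnz) in the number of nonzeros rather than depending on the full row length.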
--
Reply to this email directly or view it on GitHub:
https://github.com/mlpack/mlpack/pull/802#pullrequestreview-4957462