[mlpack-git] master: - Fixes, based on PR's comments. - The fastest sparse-matrix utilization so far. (a60ae8a)

gitdub at mlpack.org gitdub at mlpack.org
Wed Oct 19 18:07:35 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/94d14187222231ca29e4f6419c5999c660db4f8a...981ffa2d67d8fe38df6c699589005835fef710ea

>---------------------------------------------------------------

commit a60ae8a4022ce1154c40b4f321bb7cab28a4663e
Author: theJonan <ivan at jonan.info>
Date:   Thu Oct 20 01:07:35 2016 +0300

    - Fixes based on PR review comments.
    - The fastest sparse-matrix utilization so far.


>---------------------------------------------------------------

a60ae8a4022ce1154c40b4f321bb7cab28a4663e
 src/mlpack/methods/det/CMakeLists.txt    |  4 +++
 src/mlpack/methods/det/dt_utils_impl.hpp | 18 +++++++---
 src/mlpack/methods/det/dtree_impl.hpp    | 62 +++++++++++++++++++-------------
 3 files changed, 55 insertions(+), 29 deletions(-)

diff --git a/src/mlpack/methods/det/CMakeLists.txt b/src/mlpack/methods/det/CMakeLists.txt
index 4dd3bc3..ced69a9 100644
--- a/src/mlpack/methods/det/CMakeLists.txt
+++ b/src/mlpack/methods/det/CMakeLists.txt
@@ -5,6 +5,10 @@ set(SOURCES
   # the DET class
   dtree.hpp
   dtree_impl.hpp
+
+  # Utility files
+  dt_utils.hpp
+  dt_utils_impl.hpp
 )
 
 # add directory name to sources
diff --git a/src/mlpack/methods/det/dt_utils_impl.hpp b/src/mlpack/methods/det/dt_utils_impl.hpp
index 6a798c4..191bce1 100644
--- a/src/mlpack/methods/det/dt_utils_impl.hpp
+++ b/src/mlpack/methods/det/dt_utils_impl.hpp
@@ -158,7 +158,8 @@ DTree<MatType, TagType>* mlpack::det::Trainer(MatType& dataset,
     // Some sanity checks.  It seems that on some datasets, the error does not
     // increase as the tree is pruned but instead stays the same---hence the
     // "<=" in the final assert.
-    Log::Assert((alpha < std::numeric_limits<double>::max()) || (dtree.SubtreeLeaves() == 1));
+    Log::Assert((alpha < std::numeric_limits<double>::max())
+                || (dtree.SubtreeLeaves() == 1));
     Log::Assert(alpha > oldAlpha);
     Log::Assert(dtree.SubtreeLeavesLogNegError() <= treeSeq.second);
   }
@@ -191,7 +192,8 @@ DTree<MatType, TagType>* mlpack::det::Trainer(MatType& dataset,
   {
     // Break up data into train and test sets.
     const size_t start = fold * testSize;
-    const size_t end = std::min((size_t) (fold + 1) * testSize, (size_t) cvData.n_cols);
+    const size_t end = std::min((size_t) (fold + 1)
+                                * testSize, (size_t) cvData.n_cols);
 
     MatType test = cvData.cols(start, end - 1);
     MatType train(cvData.n_rows, cvData.n_cols - test.n_cols);
@@ -242,7 +244,8 @@ DTree<MatType, TagType>* mlpack::det::Trainer(MatType& dataset,
       cvRegularizationConstants[i] += 2.0 * cvVal / (double) cvData.n_cols;
 
       // Determine the new alpha value and prune accordingly.
-      double cvOldAlpha = 0.5 * (prunedSequence[i + 1].first + prunedSequence[i + 2].first);
+      double cvOldAlpha = 0.5 * (prunedSequence[i + 1].first
+                                 + prunedSequence[i + 2].first);
       cvDTree.PruneAndUpdate(cvOldAlpha, train.n_cols, useVolumeReg);
     }
 
@@ -255,7 +258,8 @@ DTree<MatType, TagType>* mlpack::det::Trainer(MatType& dataset,
     }
 
     if (prunedSequence.size() > 2)
-      cvRegularizationConstants[prunedSequence.size() - 2] += 2.0 * cvVal / (double) cvData.n_cols;
+      cvRegularizationConstants[prunedSequence.size() - 2] += 2.0 * cvVal
+        / (double) cvData.n_cols;
 
     #pragma omp critical (DTreeCVUpdate)
     regularizationConstants += cvRegularizationConstants;
@@ -293,7 +297,11 @@ DTree<MatType, TagType>* mlpack::det::Trainer(MatType& dataset,
 
   // Grow the tree.
   oldAlpha = -DBL_MAX;
-  alpha = dtreeOpt->Grow(newDataset, oldFromNew, useVolumeReg, maxLeafSize, minLeafSize);
+  alpha = dtreeOpt->Grow(newDataset,
+                         oldFromNew,
+                         useVolumeReg,
+                         maxLeafSize,
+                         minLeafSize);
 
   // Prune with optimal alpha.
   while ((oldAlpha < optimalAlpha) && (dtreeOpt->SubtreeLeaves() > 1))
diff --git a/src/mlpack/methods/det/dtree_impl.hpp b/src/mlpack/methods/det/dtree_impl.hpp
index 1d2cb76..7c86899 100644
--- a/src/mlpack/methods/det/dtree_impl.hpp
+++ b/src/mlpack/methods/det/dtree_impl.hpp
@@ -23,7 +23,8 @@ namespace details
    * in a vector, that can easily be iterated afterwards.
    */
   template <typename MatType>
-  void ExtractSplits(std::vector<std::pair<typename MatType::elem_type, size_t>>& splitVec,
+  void ExtractSplits(std::vector<
+                      std::pair<typename MatType::elem_type, size_t>>& splitVec,
                      const MatType& data,
                      size_t dim,
                      size_t start,
@@ -90,7 +91,8 @@ namespace details
         lastVal = ElemType(0);
       }
       
-      if (i + padding >= minLeafSize && i + padding <= n_elem - minLeafSize)// the normal case
+      // the normal case
+      if (i + padding >= minLeafSize && i + padding <= n_elem - minLeafSize)
       {
         // This makes sense for real continuous data.  This kinda corrupts the
         // data and estimation if the data is ordinal.
@@ -278,8 +280,6 @@ bool DTree<MatType, TagType>::FindSplit(const MatType& data,
   for (size_t dim = 0; dim < maxVals.n_elem; ++dim)
 #endif
   {
-    // Have to deal with REAL, INTEGER, NOMINAL data differently, so we have to
-    // think of how to do that...
     const ElemType min = minVals[dim];
     const ElemType max = maxVals[dim];
 
@@ -329,8 +329,10 @@ bool DTree<MatType, TagType>::FindSplit(const MatType& data,
         // and because the volume is only dependent on the dimension we are
         // splitting, we can assume V_l is just the range of the left and V_r is
         // just the range of the right.
-        double negLeftError = std::pow(position + 1, 2.0) / (split - min);
-        double negRightError = std::pow(points - position - 1, 2.0) / (max - split);
+        double negLeftError = std::pow(position + 1, 2.0)
+          / (split - min);
+        double negRightError = std::pow(points - position - 1, 2.0)
+          / (max - split);
 
         // If this is better, take it.
         if ((negLeftError + negRightError) >= minDimError)
@@ -344,21 +346,23 @@ bool DTree<MatType, TagType>::FindSplit(const MatType& data,
       }
     }
 
-    double actualMinDimError = std::log(minDimError) - 2 * std::log((double) data.n_cols) - volumeWithoutDim;
+    double actualMinDimError = std::log(minDimError)
+      - 2 * std::log((double) data.n_cols)
+      - volumeWithoutDim;
 
 #pragma omp critical (DTreeFindUpdate)
     if ((actualMinDimError > minError) && dimSplitFound)
     {
-      {
-        // Calculate actual error (in logspace) by adding terms back to our
-        // estimate.
-        minError = actualMinDimError;
-        splitDim = dim;
-        splitValue = dimSplitValue;
-        leftError = std::log(dimLeftError) - 2 * std::log((double) data.n_cols) - volumeWithoutDim;
-        rightError = std::log(dimRightError) - 2 * std::log((double) data.n_cols) - volumeWithoutDim;
-        splitFound = true;
-      }
+      // Calculate actual error (in logspace) by adding terms back to our
+      // estimate.
+      minError = actualMinDimError;
+      splitDim = dim;
+      splitValue = dimSplitValue;
+      leftError = std::log(dimLeftError) - 2 * std::log((double) data.n_cols)
+        - volumeWithoutDim;
+      rightError = std::log(dimRightError) - 2 * std::log((double) data.n_cols)
+        - volumeWithoutDim;
+      splitFound = true;
     } // end if better split found in this dimension.
   }
 
@@ -451,8 +455,10 @@ double DTree<MatType, TagType>::Grow(MatType& data,
       left = new DTree(maxValsL, minValsL, start, splitIndex, leftError);
       right = new DTree(maxValsR, minValsR, splitIndex, end, rightError);
 
-      leftG = left->Grow(data, oldFromNew, useVolReg, maxLeafSize, minLeafSize);
-      rightG = right->Grow(data, oldFromNew, useVolReg, maxLeafSize, minLeafSize);
+      leftG = left->Grow(data, oldFromNew, useVolReg, maxLeafSize,
+                         minLeafSize);
+      rightG = right->Grow(data, oldFromNew, useVolReg, maxLeafSize,
+                           minLeafSize);
 
       // Store values of R(T~) and |T~|.
       subtreeLeaves = left->SubtreeLeaves() + right->SubtreeLeaves();
@@ -517,12 +523,15 @@ double DTree<MatType, TagType>::Grow(MatType& data,
 
     if (right->SubtreeLeaves() > 1)
     {
-      const double exponent = 2 * std::log((double) data.n_cols) + logVolume + right->AlphaUpper();
+      const double exponent = 2 * std::log((double) data.n_cols)
+        + logVolume
+        + right->AlphaUpper();
 
       tmpAlphaSum += std::exp(exponent);
     }
 
-    alphaUpper = std::log(tmpAlphaSum) - 2 * std::log((double) data.n_cols) - logVolume;
+    alphaUpper = std::log(tmpAlphaSum) - 2 * std::log((double) data.n_cols)
+      - logVolume;
 
     double gT;
     if (useVolReg)
@@ -689,7 +698,9 @@ double DTree<MatType, TagType>::ComputeValue(const VecType& query) const
   else
   {
     // Return either of the two children - left or right, depending on the splitValue
-    return (query[splitDim] <= splitValue) ? left->ComputeValue(query) : right->ComputeValue(query);
+    return (query[splitDim] <= splitValue) ?
+      left->ComputeValue(query) :
+      right->ComputeValue(query);
   }
 
   return 0.0;
@@ -725,12 +736,15 @@ TagType DTree<MatType, TagType>::FindBucket(const VecType& query) const
   else
   {
     // Return the tag from either of the two children - left or right.
-    return (query[splitDim] <= splitValue) ? left->FindBucket(query) : right->FindBucket(query);
+    return (query[splitDim] <= splitValue) ?
+      left->FindBucket(query) :
+      right->FindBucket(query);
   }
 }
 
 template <typename MatType, typename TagType>
-void DTree<MatType, TagType>::ComputeVariableImportance(arma::vec& importances) const
+void
+DTree<MatType, TagType>::ComputeVariableImportance(arma::vec& importances) const
 {
   // Clear and set to right size.
   importances.zeros(maxVals.n_elem);




More information about the mlpack-git mailing list