[mlpack-git] master: - Template fixes for ExtractSplits. (10812ca)

gitdub at mlpack.org gitdub at mlpack.org
Thu Oct 20 06:28:17 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/94d14187222231ca29e4f6419c5999c660db4f8a...981ffa2d67d8fe38df6c699589005835fef710ea

>---------------------------------------------------------------

commit 10812ca06558e96878ef301db268e0bcd5883610
Author: theJonan <ivan at jonan.info>
Date:   Thu Oct 20 13:28:17 2016 +0300

    - Template fixes for ExtractSplits.


>---------------------------------------------------------------

10812ca06558e96878ef301db268e0bcd5883610
 src/mlpack/methods/det/dtree_impl.hpp | 107 ++++++++++++++++++++++------------
 1 file changed, 71 insertions(+), 36 deletions(-)

diff --git a/src/mlpack/methods/det/dtree_impl.hpp b/src/mlpack/methods/det/dtree_impl.hpp
index 7c86899..a3e1567 100644
--- a/src/mlpack/methods/det/dtree_impl.hpp
+++ b/src/mlpack/methods/det/dtree_impl.hpp
@@ -20,35 +20,66 @@ namespace details
 {
   /**
    * This one sorts and scand the given per-dimension extract and puts all splits
-   * in a vector, that can easily be iterated afterwards.
+   * in a vector that can easily be iterated afterwards. General implementation.
    */
-  template <typename MatType>
-  void ExtractSplits(std::vector<
-                      std::pair<typename MatType::elem_type, size_t>>& splitVec,
+  template <typename ElemType, typename MatType>
+  void ExtractSplits(std::vector<std::pair<ElemType, size_t>>& splitVec,
                      const MatType& data,
                      size_t dim,
-                     size_t start,
-                     size_t end,
-                     size_t minLeafSize)
+                     const size_t start,
+                     const size_t end,
+                     const size_t minLeafSize)
   {
-    typedef typename MatType::elem_type ElemType;
-    typedef std::pair<ElemType, size_t> SplitItem;
-    typename MatType::row_type dimVec = data(dim, arma::span(start, end - 1));
+    static_assert(
+      std::is_same<typename MatType::elem_type, ElemType>::value == true,
+      "The ElemType does not correspond to the matrix's element type."
+                  );
     
-    // We sort these, in-place (it's a copy of the data, anyways).
-    std::sort(dimVec.begin(), dimVec.end());
+    typedef std::pair<ElemType, size_t> SplitItem;
+    const typename MatType::row_type dimVec =
+      arma::sort(data(dim, arma::span(start, end - 1)));
     
     // Ensure the minimum leaf size on both sides. We need to figure out why
     // there are spikes if this minLeafSize is enforced here...
     for (size_t i = minLeafSize - 1; i < dimVec.n_elem - minLeafSize; ++i)
     {
-      // This makes sense for real continuous data.  This kinda corrupts the
-      // data and estimation if the data is ordinal.
+      // This makes sense for real continuous data. This kinda corrupts the
+      // data and estimation if the data is ordinal. Potentially we can fix
+      // that by taking into account ordinality later in the min/max update,
+      // but then we can end up with a zero-volume dimension. No good.
       const ElemType split = (dimVec[i] + dimVec[i + 1]) / 2.0;
       
       // Check if we can split here (two points are different)
       if (split != dimVec[i])
-        splitVec.push_back(SplitItem(split, i));
+        splitVec.push_back(SplitItem(split, i + 1));
+    }
+  }
+
+  // Overload for dense arma::Mat matrices.
+  template <typename ElemType>
+  void ExtractSplits(std::vector<std::pair<ElemType, size_t>>& splitVec,
+                     const arma::Mat<ElemType>& data,
+                     size_t dim,
+                     const size_t start,
+                     const size_t end,
+                     const size_t minLeafSize)
+  {
+    typedef std::pair<ElemType, size_t> SplitItem;
+    arma::vec dimVec = data(dim, arma::span(start, end - 1));
+    
+    // We sort these, in-place (it's a copy of the data, anyways).
+    std::sort(dimVec.begin(), dimVec.end());
+    
+    for (size_t i = minLeafSize - 1; i < dimVec.n_elem - minLeafSize; ++i)
+    {
+      // This makes sense for real continuous data. This kinda corrupts the
+      // data and estimation if the data is ordinal. Potentially we can fix
+      // that by taking into account ordinality later in the min/max update,
+      // but then we can end up with a zero-volume dimension. No good.
+      const ElemType split = (dimVec[i] + dimVec[i + 1]) / 2.0;
+      
+      if (split != dimVec[i])
+        splitVec.push_back(SplitItem(split, i + 1));
     }
   }
   
@@ -57,10 +88,13 @@ namespace details
   void ExtractSplits(std::vector<std::pair<ElemType, size_t>>& splitVec,
                      const arma::SpMat<ElemType>& data,
                      size_t dim,
-                     size_t start,
-                     size_t end,
-                     size_t minLeafSize)
+                     const size_t start,
+                     const size_t end,
+                     const size_t minLeafSize)
   {
+    // A positive minLeafSize is required; the zero-crossing check below relies on it.
+    Log::Assert(minLeafSize > 0);
+    
     typedef std::pair<ElemType, size_t> SplitItem;
     const size_t n_elem = end - start;
     
@@ -73,8 +107,9 @@ namespace details
 
     // Now iterate over the values, taking account for the over-the-zeroes
     // jump and construct the splits vector.
+    const size_t zeroes = n_elem - valsVec.size();
     ElemType lastVal = -std::numeric_limits<ElemType>::max();
-    size_t padding = 0, zeroes = n_elem - valsVec.size();
+    size_t padding = 0;
 
     for (size_t i = 0; i < valsVec.size(); ++i)
     {
@@ -82,10 +117,10 @@ namespace details
       if (lastVal < ElemType(0) && newVal > ElemType(0) && zeroes > 0)
       {
         Log::Assert(padding == 0); // we should arrive here once!
-        if (lastVal >= valsVec[0] && // i.e. we're not in the beginning
-            i >= minLeafSize &&
-            i <= n_elem - minLeafSize)
-          splitVec.push_back(SplitItem(lastVal / 2.0, i - 1));
+
+        // Since minLeafSize > 0, this check also guarantees we are not at the very start.
+        if (i >= minLeafSize && i <= n_elem - minLeafSize)
+          splitVec.push_back(SplitItem(lastVal / 2.0, i));
 
         padding = zeroes;
         lastVal = ElemType(0);
@@ -95,12 +130,14 @@ namespace details
       if (i + padding >= minLeafSize && i + padding <= n_elem - minLeafSize)
       {
         // This makes sense for real continuous data.  This kinda corrupts the
-        // data and estimation if the data is ordinal.
+        // data and estimation if the data is ordinal. Potentially we can fix
+        // that by taking into account ordinality later in the min/max update,
+        // but then we can end up with a zero-volume dimension. No good.
         const ElemType split = (lastVal + newVal) / 2.0;
         
         // Check if we can split here (two points are different)
         if (split != newVal)
-          splitVec.push_back(SplitItem(split, i + padding - 1));
+          splitVec.push_back(SplitItem(split, i + padding));
       }
       
       lastVal = newVal;
@@ -287,7 +324,10 @@ bool DTree<MatType, TagType>::FindSplit(const MatType& data,
     if (max - min == 0.0)
       continue; // Skip to next dimension.
 
-    // Initializing all the stuff for this dimension.
+    // Find the log volume of all the other dimensions.
+    const double volumeWithoutDim = logVolume - std::log(max - min);
+    
+    // Initializing all other stuff for this dimension.
     bool dimSplitFound = false;
     // Take an error estimate for this dimension.
     double minDimError = std::pow(points, 2.0) / (max - min);
@@ -295,9 +335,6 @@ bool DTree<MatType, TagType>::FindSplit(const MatType& data,
     double dimRightError = 0.0; // always be set to something else before use.
     ElemType dimSplitValue = 0.0;
 
-    // Find the log volume of all the other dimensions.
-    double volumeWithoutDim = logVolume - std::log(max - min);
-
     // Get the values for splitting. The old implementation:
     //   dimVec = data.row(dim).subvec(start, end - 1);
     //   dimVec = arma::sort(dimVec);
@@ -305,7 +342,7 @@ bool DTree<MatType, TagType>::FindSplit(const MatType& data,
     // This one has custom implementation for dense and sparse matrices.
 
     std::vector<SplitItem> splitVec;
-    details::ExtractSplits(splitVec, data, dim, start, end, minLeafSize);
+    details::ExtractSplits<ElemType>(splitVec, data, dim, start, end, minLeafSize);
     
     // Iterate on all the splits for this dimension
     for (typename std::vector<SplitItem>::iterator i = splitVec.begin();
@@ -321,7 +358,7 @@ bool DTree<MatType, TagType>::FindSplit(const MatType& data,
       {
         // Ensure that the right node will have at least the minimum number of
         // points.
-        Log::Assert((points - position - 1) >= minLeafSize);
+        Log::Assert((points - position) >= minLeafSize);
 
         // Now we have to see if the error will be reduced.  Simple manipulation
         // of the error function gives us the condition we must satisfy:
@@ -329,10 +366,8 @@ bool DTree<MatType, TagType>::FindSplit(const MatType& data,
         // and because the volume is only dependent on the dimension we are
         // splitting, we can assume V_l is just the range of the left and V_r is
         // just the range of the right.
-        double negLeftError = std::pow(position + 1, 2.0)
-          / (split - min);
-        double negRightError = std::pow(points - position - 1, 2.0)
-          / (max - split);
+        double negLeftError = std::pow(position, 2.0) / (split - min);
+        double negRightError = std::pow(points - position, 2.0) / (max - split);
 
         // If this is better, take it.
         if ((negLeftError + negRightError) >= minDimError)
@@ -346,7 +381,7 @@ bool DTree<MatType, TagType>::FindSplit(const MatType& data,
       }
     }
 
-    double actualMinDimError = std::log(minDimError)
+    const double actualMinDimError = std::log(minDimError)
       - 2 * std::log((double) data.n_cols)
       - volumeWithoutDim;
 




More information about the mlpack-git mailing list