[mlpack-svn] r13211 - mlpack/trunk/src/mlpack/methods/det

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Jul 11 18:33:31 EDT 2012


Author: rcurtin
Date: 2012-07-11 18:33:31 -0400 (Wed, 11 Jul 2012)
New Revision: 13211

Modified:
   mlpack/trunk/src/mlpack/methods/det/dtree.hpp
   mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
Log:
First pass at overhauling FindSplit_().  Flatten the deep nesting of if/for
blocks by inverting if conditions.  Use the log of the negated error instead of
the actual error to help combat overflow issues.  Use double instead of cT
(long double).  Change the default eT to double.


Modified: mlpack/trunk/src/mlpack/methods/det/dtree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree.hpp	2012-07-11 22:21:40 UTC (rev 13210)
+++ mlpack/trunk/src/mlpack/methods/det/dtree.hpp	2012-07-11 22:33:31 UTC (rev 13211)
@@ -54,7 +54,7 @@
  * }
  * @endcode
  */
-template<typename eT = float,
+template<typename eT = double,
          typename cT = long double>
 class DTree
 {
@@ -147,7 +147,7 @@
 
   inline double LogNegativeError(size_t total_points);
 
-  bool FindSplit_(MatType* data,
+  bool FindSplit_(const arma::mat& data,
                   size_t* split_dim,
                   size_t* split_ind,
                   cT* left_error,

Modified: mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp	2012-07-11 22:21:40 UTC (rev 13210)
+++ mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp	2012-07-11 22:33:31 UTC (rev 13211)
@@ -26,11 +26,11 @@
          arma::accu(arma::log((*max_vals_) - (*min_vals_)));
 }
 
-// This function finds the best split with respect to the L2-error,
-// but trying all possible splits.  The dataset is the full data set but the
-// start_ and end_ are used to obtain the point in this node.
+// This function finds the best split with respect to the L2-error, by trying
+// all possible splits.  The dataset is the full data set but the start_ and
+// end_ are used to obtain the point in this node.
 template<typename eT, typename cT>
-bool DTree<eT, cT>::FindSplit_(MatType* data,
+bool DTree<eT, cT>::FindSplit_(const arma::mat& data,
                                size_t *split_dim,
                                size_t *split_ind,
                                cT *left_error,
@@ -38,12 +38,14 @@
                                size_t maxLeafSize,
                                size_t minLeafSize)
 {
-  assert(data->n_rows == max_vals_->n_elem);
-  assert(data->n_rows == min_vals_->n_elem);
+  // Ensure the dimensionality of the data is the same as the dimensionality of
+  // the bounding rectangle.
+  assert(data.n_rows == max_vals_->n_elem);
+  assert(data.n_rows == min_vals_->n_elem);
 
-  size_t total_n = data->n_cols, n_t = end_ - start_;
+  const size_t n_t = end_ - start_;
 
-  cT min_error = error_;
+  double min_error = std::log(-error_);
   bool some_split_found = false;
   size_t point_mass_in_dim = 0;
 
@@ -52,118 +54,116 @@
   {
     // Have to deal with REAL, INTEGER, NOMINAL data differently, so we have to
     // think of how to do that...
-    eT min = (*min_vals_)[dim], max = (*max_vals_)[dim];
+    const double min = (*min_vals_)[dim];
+    const double max = (*max_vals_)[dim];
 
-    // Check if there is any scope of splitting in this dimension.
-    if (max - min > 0.0) {
-      // Initializing all the stuff for this dimension.
-      bool dim_split_found = false;
-      cT min_dim_error = min_error, temp_lval = 0.0, temp_rval = 0.0;
-      size_t dim_split_ind = -1;
+    // If there is nothing to split in this dimension, move on.
+    if (max - min == 0.0)
+    {
+      ++point_mass_in_dim;
+      continue; // Skip to next dimension.
+    }
 
-      cT log_range_all_not_dim = 0;
-      for (size_t i = 0; i < max_vals_->n_elem; i++)
+    // Initializing all the stuff for this dimension.
+    bool dim_split_found = false;
+    // Take an error estimate for this dimension.
+    double min_dim_error = n_t / (max - min);
+    double temp_lval = 0.0;
+    double temp_rval = 0.0;
+    size_t dim_split_ind = -1;
+
+    // Find the log volume of all the other dimensions.
+    double log_range_all_not_dim = 0;
+    for (size_t i = 0; i < max_vals_->n_elem; ++i)
+    {
+      if (((*max_vals_)[i] - (*min_vals_)[i] > 0.0) && (i != dim))
       {
-        if ((*max_vals_)[i] -(*min_vals_)[i] > 0.0 && i != dim)
-        {
-          log_range_all_not_dim +=
-              (cT) std::log((*max_vals_)[i] - (*min_vals_)[i]);
-        }
+        log_range_all_not_dim += std::log((*max_vals_)[i] - (*min_vals_)[i]);
       }
+    }
 
-      assert(std::exp(log_range_all_not_dim) > 0);
+    // Get the values for the dimension.
+    arma::rowvec dim_val_vec = data.row(dim).subvec(start_, end_ - 1);
 
-      // Get the values for the dimension.
-      RowVecType dim_val_vec = data->row(dim).subvec(start_, end_ - 1);
+    // Sort the values in ascending order.
+    dim_val_vec = arma::sort(dim_val_vec);
 
-      // Sort the values in ascending order.
-      dim_val_vec = arma::sort(dim_val_vec);
+    // Get ready to go through the sorted list and compute error.
+    assert(dim_val_vec.n_elem > maxLeafSize);
 
-      // Get ready to go through the sorted list and compute error.
-      assert(dim_val_vec.n_elem > maxLeafSize);
+    // Enforce that the leaves have a minimum number of points to avoid
+    // spikes.  One way of doing this is to only consider splits resulting in
+    // sizes > some constant (minLeafSize).
+    size_t right_child_size;
 
-      // Enforce that the leaves have a minimum number of points to avoid
-      // spikes.  One way of doing this is to only consider splits resulting in
-      // sizes > some constant (minLeafSize).
-      size_t left_child_size = minLeafSize - 1, right_child_size;
+    // Find the best split for this dimension.  We need to figure out why
+    // there are spikes if this min_leaf_size is enforced here...
+    for (size_t i = minLeafSize - 1; i < dim_val_vec.n_elem - minLeafSize; ++i)
+    {
+      double split;
+      double lsplit = dim_val_vec[i];
+      double rsplit = dim_val_vec[i + 1];
 
-      // Find the best split for this dimension.  We need to figure out why
-      // there are spikes if this min_leaf_size is enforced here...
-      for (size_t i = minLeafSize -1; i < dim_val_vec.n_elem - minLeafSize;
-          ++i, ++left_child_size)
-      {
-        eT split, lsplit = dim_val_vec[i], rsplit = dim_val_vec[i + 1];
+      if (lsplit == rsplit)
+        continue; // We can't split here.
 
-        if (lsplit < rsplit)
-        {
-          // This makes sense for real continuous data.  This kinda corrupts the
-          // data and estimation if the data is ordinal
-          split = (lsplit + rsplit) / 2;
+      // This makes sense for real continuous data.  This kinda corrupts the
+      // data and estimation if the data is ordinal.
+      split = (lsplit + rsplit) / 2;
 
-          // Another way of picking split is using this:
-          //   split = left_split;
+      // Another way of picking split is using this:
+      //   split = left_split;
+      if ((split - min > 0.0) && (max - split > 0.0))
+      {
+        // Now we have to see if the error will be reduced.  Simple manipulation
+        // of the error function gives us the condition we must satisfy:
+        //   |t_l|^2 / V_l + |t_r|^2 / V_r  >= |t|^2 / (V_l + V_r)
+        // and because the volume is only dependent on the dimension we are
+        // splitting, we can assume V_l is just the range of the left and V_r is
+        // just the range of the right.
+        right_child_size = n_t - i - 1;
+        Log::Assert(right_child_size >= minLeafSize);
 
-          if (split - min > 0.0 && max - split > 0.0)
-          {
-            assert(std::exp(log_range_all_not_dim +
-                (cT) std::log(split - min)) > 0);
-            assert(std::exp(log_range_all_not_dim +
-                (cT) std::log(max - split)) > 0);
+        double negLeftError = std::pow(i + 1, 2.0) / (split - min);
+        double negRightError = std::pow(n_t - i - 1, 2.0) / (max - split);
 
-            cT temp_log_neg_l_error = 2 * std::log((cT) (i + 1) / (cT) total_n)
-                - (log_range_all_not_dim + (cT) std::log(split - min));
+        // If this is better, take it.
+        if ((negLeftError + negRightError) >= min_dim_error)
+        {
+          min_dim_error = negLeftError + negRightError;
+          temp_lval = negLeftError;
+          temp_rval = negRightError;
+          dim_split_ind = i;
+          dim_split_found = true;
+        }
+      }
+    }
 
-            assert(std::exp(temp_log_neg_l_error) > 0.0);
+    dim_val_vec.clear();
 
-            cT temp_l_error = -1.0 * std::exp(temp_log_neg_l_error);
+    double actualMinDimError = std::log(min_dim_error) -
+        2 * std::log(data.n_cols) - log_range_all_not_dim;
 
-            assert(std::abs(temp_l_error) < std::numeric_limits<cT>::max());
-
-            cT temp_log_neg_r_error = 2 * std::log((cT) (n_t - i - 1) /
-                (cT) total_n) - (log_range_all_not_dim +
-                (cT) std::log(max - split));
-
-            assert(std::exp(temp_log_neg_r_error) > 0.0);
-
-            right_child_size = n_t - i - 1;
-            assert(right_child_size >= minLeafSize);
-
-            cT temp_r_error = -1.0 * std::exp(temp_log_neg_r_error);
-
-            assert(std::abs(temp_r_error) < std::numeric_limits<cT>::max());
-
-            //if (temp_l + temp_r <= min_dim_error) {
-            // Why not just less than?
-            if (temp_l_error + temp_r_error < min_dim_error)
-            {
-              min_dim_error = temp_l_error + temp_r_error;
-              temp_lval = temp_l_error;
-              temp_rval = temp_r_error;
-              dim_split_ind = i;
-              dim_split_found = true;
-            } // end if improvement.
-          } // end if split - min > 0 & max - split > 0.
-        } // end if lsplit < rsplit instead of being equal.
-      } // end for loop over all splits in this dimension.
-
-      dim_val_vec.clear();
-
-      if ((min_dim_error < min_error) && dim_split_found)
-      {
-        min_error = min_dim_error;
-        *split_dim = dim;
-        *split_ind = dim_split_ind;
-        *left_error = temp_lval;
-        *right_error = temp_rval;
-        some_split_found = true;
-      } // end if better split found in this dimension.
-    }
-    else
+    if ((actualMinDimError > min_error) && dim_split_found)
     {
-      point_mass_in_dim++;
-    }
+      // Calculate actual error (in logspace) by adding terms back to our
+      // estimate.
+      min_error = actualMinDimError;
+      *split_dim = dim;
+      *split_ind = dim_split_ind;
+      *left_error = std::log(temp_lval) - 2 * std::log(data.n_cols) -
+          log_range_all_not_dim;
+      *right_error = std::log(temp_rval) - 2 * std::log(data.n_cols) -
+          log_range_all_not_dim;
+      some_split_found = true;
+    } // end if better split found in this dimension.
   }
 
+  // Map out of logspace.
+  min_error = -std::exp(min_error);
+  *left_error = -std::exp(*left_error);
+  *right_error = -std::exp(*right_error);
+
   return some_split_found;
 } // end FindSplit_
 
@@ -383,7 +383,7 @@
     // Find the split.
     size_t dim, split_ind;
     cT left_error, right_error;
-    if (FindSplit_(data, &dim, &split_ind, &left_error, &right_error,
+    if (FindSplit_(*data, &dim, &split_ind, &left_error, &right_error,
         maxLeafSize, minLeafSize))
     {
       // Move the data around for the children to have points in a node lie




More information about the mlpack-svn mailing list