[mlpack-svn] r13232 - mlpack/trunk/src/mlpack/methods/det

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Sun Jul 15 22:53:18 EDT 2012


Author: rcurtin
Date: 2012-07-15 22:53:18 -0400 (Sun, 15 Jul 2012)
New Revision: 13232

Modified:
   mlpack/trunk/src/mlpack/methods/det/dtree.hpp
   mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
Log:
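Clean up FindSplit() and SplitData(), and reduce their APIs to only what is necessary: FindSplit() now reports the split value instead of an index into the sorted dimension, and SplitData() (formerly SplitData_()) takes that split value and returns the index of the first point on the right side.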


Modified: mlpack/trunk/src/mlpack/methods/det/dtree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree.hpp	2012-07-16 00:57:59 UTC (rev 13231)
+++ mlpack/trunk/src/mlpack/methods/det/dtree.hpp	2012-07-16 02:53:18 UTC (rev 13232)
@@ -148,20 +148,17 @@
   inline double LogNegativeError(const size_t total_points) const;
 
   bool FindSplit(const arma::mat& data,
-                 size_t& split_dim,
-                 size_t& split_ind,
-                 double& left_error,
-                 double& right_error,
+                 size_t& splitDim,
+                 double& splitValue,
+                 double& leftError,
+                 double& rightError,
                  const size_t maxLeafSize = 10,
                  const size_t minLeafSize = 5) const;
 
-  void SplitData_(MatType* data,
-                  size_t split_dim,
-                  size_t split_ind,
-                  arma::Col<size_t>* old_from_new,
-                  eT* split_val,
-                  eT* lsplit_val,
-                  eT* rsplit_val);
+  size_t SplitData(arma::mat& data,
+                   const size_t splitDim,
+                   const double splitValue,
+                   arma::Col<size_t>& oldFromNew) const;
 
   void GetMaxMinVals_(MatType* data,
                       VecType* max_vals,

Modified: mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp	2012-07-16 00:57:59 UTC (rev 13231)
+++ mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp	2012-07-16 02:53:18 UTC (rev 13232)
@@ -31,7 +31,7 @@
 template<typename eT, typename cT>
 bool DTree<eT, cT>::FindSplit(const arma::mat& data,
                               size_t& splitDim,
-                              size_t& splitIndex,
+                              double& splitValue,
                               double& leftError,
                               double& rightError,
                               const size_t maxLeafSize,
@@ -63,9 +63,9 @@
     bool dimSplitFound = false;
     // Take an error estimate for this dimension.
     double minDimError = points / (max - min);
-    double dimLeftError = 0.0;
-    double dimRightError = 0.0;
-    size_t dimSplitIndex = -1;
+    double dimLeftError;
+    double dimRightError;
+    double dimSplitValue;
 
     // Find the log volume of all the other dimensions.
     double volumeWithoutDim = 0;
@@ -90,17 +90,13 @@
     // there are spikes if this min_leaf_size is enforced here...
     for (size_t i = minLeafSize - 1; i < dimVec.n_elem - minLeafSize; ++i)
     {
-      double split;
-      double lsplit = dimVec[i];
-      double rsplit = dimVec[i + 1];
-
-      if (lsplit == rsplit)
-        continue; // We can't split here.
-
       // This makes sense for real continuous data.  This kinda corrupts the
       // data and estimation if the data is ordinal.
-      split = (lsplit + rsplit) / 2;
+      const double split = (dimVec[i] + dimVec[i + 1]) / 2.0;
 
+      if (split == dimVec[i])
+        continue; // We can't split here (two points are the same).
+
       // Another way of picking split is using this:
       //   split = left_split;
       if ((split - min > 0.0) && (max - split > 0.0))
@@ -124,14 +120,12 @@
           minDimError = negLeftError + negRightError;
           dimLeftError = negLeftError;
           dimRightError = negRightError;
-          dimSplitIndex = i;
+          dimSplitValue = split;
           dimSplitFound = true;
         }
       }
     }
 
-    dimVec.clear();
-
     double actualMinDimError = std::log(minDimError) - 2 * std::log(data.n_cols)
         - volumeWithoutDim;
 
@@ -141,7 +135,7 @@
       // estimate.
       minError = actualMinDimError;
       splitDim = dim;
-      splitIndex = dimSplitIndex;
+      splitValue = dimSplitValue;
       leftError = std::log(dimLeftError) - 2 * std::log(data.n_cols) -
           volumeWithoutDim;
       rightError = std::log(dimRightError) - 2 * std::log(data.n_cols) -
@@ -159,59 +153,37 @@
 }
 
 template<typename eT, typename cT>
-void DTree<eT, cT>::SplitData_(MatType* data,
-                               size_t split_dim,
-                               size_t split_ind,
-                               arma::Col<size_t> *old_from_new,
-                               eT *split_val,
-                               eT *lsplit_val,
-                               eT *rsplit_val)
+size_t DTree<eT, cT>::SplitData(arma::mat& data,
+                                const size_t splitDim,
+                                const double splitValue,
+                                arma::Col<size_t>& oldFromNew) const
 {
-  // Get the values for the split dimension.
-  RowVecType dimVec = data->row(split_dim).subvec(start_, end_ - 1);
-
-  // Sort the values.
-  dimVec = arma::sort(dimVec);
-
-  *lsplit_val = dimVec[split_ind];
-  *rsplit_val = dimVec[split_ind + 1];
-  *split_val = (*lsplit_val + *rsplit_val) / 2 ;
-
-  std::vector<bool> left_membership;
-  left_membership.reserve(end_ - start_);
-
-  for (size_t i = start_; i < end_; ++i)
-  {
-    if ((*data)(split_dim, i) > *split_val)
-      left_membership[i - start_] = false;
-    else
-      left_membership[i - start_] = true;
-  }
-
-  size_t left_ind = start_, right_ind = end_ - 1;
+  // Swap all columns such that any columns with value in dimension splitDim
+  // less than or equal to splitValue are on the left side, and all others are
+  // on the right side.  A similar sort to this is also performed in
+  // BinarySpaceTree construction (its comments are more detailed).
+  size_t left = start_;
+  size_t right = end_ - 1;
   for (;;)
   {
-    while (left_membership[left_ind - start_] && (left_ind <= right_ind))
-      left_ind++;
+    while (data(splitDim, left) <= splitValue)
+      ++left;
+    while (data(splitDim, right) > splitValue)
+      --right;
 
-    while (!left_membership[right_ind - start_] && (left_ind <= right_ind))
-      right_ind--;
-
-    if (left_ind > right_ind)
+    if (left > right)
       break;
 
+    data.swap_cols(left, right);
 
-    data->swap_cols(left_ind, right_ind);
-    bool tmp = left_membership[left_ind - start_];
-    left_membership[left_ind - start_] = left_membership[right_ind - start_];
-    left_membership[right_ind - start_] = tmp;
-
-    size_t t = (*old_from_new)[left_ind];
-    (*old_from_new)[left_ind] = (*old_from_new)[right_ind];
-    (*old_from_new)[right_ind] = t;
+    // Store the mapping from old to new.
+    const size_t tmp = oldFromNew[left];
+    oldFromNew[left] = oldFromNew[right];
+    oldFromNew[right] = tmp;
   }
 
-  assert(left_ind == right_ind + 1);
+  // This now refers to the first index of the "right" side.
+  return left;
 }
 
 
@@ -371,16 +343,16 @@
   if ((size_t) (end_ - start_) > maxLeafSize) {
 
     // Find the split.
-    size_t dim, split_ind;
+    size_t dim;
+    double splitValue;
     double left_error, right_error;
-    if (FindSplit(*data, dim, split_ind, left_error, right_error, maxLeafSize,
+    if (FindSplit(*data, dim, splitValue, left_error, right_error, maxLeafSize,
         minLeafSize))
     {
       // Move the data around for the children to have points in a node lie
       // contiguously (to increase efficiency during the training).
-      eT split_val, lsplit_val, rsplit_val;
-      SplitData_(data, dim, split_ind, old_from_new, &split_val, &lsplit_val,
-          &rsplit_val);
+      const size_t splitIndex = SplitData(*data, dim, splitValue,
+          *old_from_new);
 
       // Make max and min vals for the children.
       VecType* max_vals_l = new VecType(*max_vals_);
@@ -388,17 +360,17 @@
       VecType* min_vals_l = new VecType(*min_vals_);
       VecType* min_vals_r = new VecType(*min_vals_);
 
-      (*max_vals_l)[dim] = split_val;
-      (*min_vals_r)[dim] = split_val;
+      (*max_vals_l)[dim] = splitValue;
+      (*min_vals_r)[dim] = splitValue;
 
       // Store split dim and split val in the node.
-      split_value_ = split_val;
+      split_value_ = splitValue;
       split_dim_ = dim;
 
       // Recursively grow the children.
-      left_ = new DTree(max_vals_l, min_vals_l, start_, start_ + split_ind + 1,
+      left_ = new DTree(max_vals_l, min_vals_l, start_, splitIndex,
           (cT) left_error);
-      right_ = new DTree(max_vals_r, min_vals_r, start_ + split_ind + 1, end_,
+      right_ = new DTree(max_vals_r, min_vals_r, splitIndex, end_,
           (cT) right_error);
 
       left_g = left_->Grow(data, old_from_new, useVolReg, maxLeafSize,




More information about the mlpack-svn mailing list