[mlpack-git] master, mlpack-1.0.x: For clarity, use separate split and binLabels objects instead of storing the label in the split matrix. No casting is necessary anymore. (8e76f8f)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:49:39 EST 2015


Repository : https://github.com/mlpack/mlpack

On branches: master,mlpack-1.0.x
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit 8e76f8f451c4f36317ef367a25708d7dda250b16
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Jun 25 00:22:09 2014 +0000

    For clarity, use separate split and binLabels objects instead of storing the label in the split matrix.  No casting is necessary anymore.


>---------------------------------------------------------------

8e76f8f451c4f36317ef367a25708d7dda250b16
 .../methods/decision_stump/decision_stump.hpp      |  7 ++-
 .../methods/decision_stump/decision_stump_impl.hpp | 52 ++++++++++------------
 2 files changed, 29 insertions(+), 30 deletions(-)

diff --git a/src/mlpack/methods/decision_stump/decision_stump.hpp b/src/mlpack/methods/decision_stump/decision_stump.hpp
index 5db64be..e1bec19 100644
--- a/src/mlpack/methods/decision_stump/decision_stump.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump.hpp
@@ -64,8 +64,11 @@ class DecisionStump
   //! Stores the class labels for the input data.
   arma::Row<size_t> classLabels;
 
-  //! Stores the splitting criterion after training.
-  arma::mat split;
+  //! Stores the splitting values after training.
+  arma::vec split;
+
+  //! Stores the labels for each splitting bin.
+  arma::Col<size_t> binLabels;
 
   /**
    * Sets up attribute as if it were splitting on it and finds entropy when
diff --git a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
index 2c757ca..b3d4075 100644
--- a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
@@ -39,7 +39,7 @@ DecisionStump<MatType>::DecisionStump(const MatType& data,
   bucketSize = inpBucketSize;
 
   // Check whether the input labels are not all identical.
-  if (!isDistinct<size_t>(classLabels))
+  if (!isDistinct<size_t>(labels))
   {
     // If the classLabels are all identical, the default class is the only
     // class.
@@ -99,31 +99,21 @@ void DecisionStump<MatType>::Classify(const MatType& test,
   {
     for (int i = 0; i < test.n_cols; i++)
     {
-      int j = 0;
-
+      // Determine which bin the test point falls into.
+      // Assume first that it falls into the first bin, then proceed through the
+      // bins until it is known which bin it falls into.
+      int bin = 0;
       const double val = test(splitCol, i);
-      while (j < split.n_rows)
+
+      while (bin < split.n_elem - 1)
       {
-        if (val < split(j, 0) && (!j))
-        {
-          predictedLabels(i) = split(0, 1);
+        if (val < split(bin + 1))
           break;
-        }
-        else if (val >= split(j, 0))
-        {
-          if (j == split.n_rows - 1)
-          {
-            predictedLabels(i) = split(split.n_rows - 1, 1);
-            break;
-          }
-          else if (val < split(j + 1, 0))
-          {
-            predictedLabels(i) = split(j, 1);
-            break;
-          }
-        }
-        j++;
+
+        ++bin;
       }
+
+      predictedLabels(i) = binLabels(bin);
     }
   }
   else
@@ -243,7 +233,8 @@ void DecisionStump<MatType>::TrainOnAtt(const arma::rowvec& attribute)
   arma::uvec sortedSplitIndexAtt = arma::stable_sort_index(attribute.t());
   arma::Row<size_t> sortedLabels(attribute.n_elem);
   sortedLabels.fill(0);
-  arma::mat tempSplit;
+  arma::vec tempSplit;
+  arma::Row<size_t> tempLabel;
 
   for (i = 0; i < attribute.n_elem; i++)
     sortedLabels(i) = classLabels(sortedSplitIndexAtt(i));
@@ -267,8 +258,10 @@ void DecisionStump<MatType>::TrainOnAtt(const arma::rowvec& attribute)
 
       mostFreq = CountMostFreq<double>(subCols);
 
-      tempSplit << sortedSplitAtt(begin)<< mostFreq << arma::endr;
-      split = arma::join_cols(split, tempSplit);
+      split.resize(split.n_elem + 1);
+      split(split.n_elem - 1) = sortedSplitAtt(begin);
+      binLabels.resize(binLabels.n_elem + 1);
+      binLabels(binLabels.n_elem - 1) = mostFreq;
 
       i++;
     }
@@ -297,8 +290,10 @@ void DecisionStump<MatType>::TrainOnAtt(const arma::rowvec& attribute)
       // the bucket of subCols.
       mostFreq = CountMostFreq<double>(subCols);
 
-      tempSplit << sortedSplitAtt(begin) << mostFreq << arma::endr;
-      split = arma::join_cols(split, tempSplit);
+      split.resize(split.n_elem + 1);
+      split(split.n_elem - 1) = sortedSplitAtt(begin);
+      binLabels.resize(binLabels.n_elem + 1);
+      binLabels(binLabels.n_elem - 1) = mostFreq;
 
       i = end + 1;
       count = 0;
@@ -321,9 +316,10 @@ void DecisionStump<MatType>::MergeRanges()
 {
   for (int i = 1; i < split.n_rows; i++)
   {
-    if (split(i, 1) == split(i - 1, 1))
+    if (binLabels(i) == binLabels(i - 1))
     {
       // Remove this row, as it has the same label as the previous bucket.
+      binLabels.shed_row(i);
       split.shed_row(i);
       // Go back to previous row.
       i--;



More information about the mlpack-git mailing list