[mlpack-svn] r16708 - mlpack/trunk/src/mlpack/methods/decision_stump

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Jun 24 20:22:10 EDT 2014


Author: rcurtin
Date: Tue Jun 24 20:22:09 2014
New Revision: 16708

Log:
For clarity, use separate split and binLabels objects instead of storing the label in the split matrix.  No casting is necessary anymore.


Modified:
   mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp
   mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp

Modified: mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp	Tue Jun 24 20:22:09 2014
@@ -64,8 +64,11 @@
   //! Stores the class labels for the input data.
   arma::Row<size_t> classLabels;
 
-  //! Stores the splitting criterion after training.
-  arma::mat split;
+  //! Stores the splitting values after training.
+  arma::vec split;
+
+  //! Stores the labels for each splitting bin.
+  arma::Col<size_t> binLabels;
 
   /**
    * Sets up attribute as if it were splitting on it and finds entropy when

Modified: mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp	Tue Jun 24 20:22:09 2014
@@ -39,7 +39,7 @@
   bucketSize = inpBucketSize;
 
   // Check whether the input labels are not all identical.
-  if (!isDistinct<size_t>(classLabels))
+  if (!isDistinct<size_t>(labels))
   {
     // If the classLabels are all identical, the default class is the only
     // class.
@@ -99,31 +99,21 @@
   {
     for (int i = 0; i < test.n_cols; i++)
     {
-      int j = 0;
-
+      // Determine which bin the test point falls into.
+      // Assume first that it falls into the first bin, then proceed through the
+      // bins until it is known which bin it falls into.
+      int bin = 0;
       const double val = test(splitCol, i);
-      while (j < split.n_rows)
+
+      while (bin < split.n_elem - 1)
       {
-        if (val < split(j, 0) && (!j))
-        {
-          predictedLabels(i) = split(0, 1);
+        if (val < split(bin + 1))
           break;
-        }
-        else if (val >= split(j, 0))
-        {
-          if (j == split.n_rows - 1)
-          {
-            predictedLabels(i) = split(split.n_rows - 1, 1);
-            break;
-          }
-          else if (val < split(j + 1, 0))
-          {
-            predictedLabels(i) = split(j, 1);
-            break;
-          }
-        }
-        j++;
+
+        ++bin;
       }
+
+      predictedLabels(i) = binLabels(bin);
     }
   }
   else
@@ -243,7 +233,8 @@
   arma::uvec sortedSplitIndexAtt = arma::stable_sort_index(attribute.t());
   arma::Row<size_t> sortedLabels(attribute.n_elem);
   sortedLabels.fill(0);
-  arma::mat tempSplit;
+  arma::vec tempSplit;
+  arma::Row<size_t> tempLabel;
 
   for (i = 0; i < attribute.n_elem; i++)
     sortedLabels(i) = classLabels(sortedSplitIndexAtt(i));
@@ -267,8 +258,10 @@
 
       mostFreq = CountMostFreq<double>(subCols);
 
-      tempSplit << sortedSplitAtt(begin)<< mostFreq << arma::endr;
-      split = arma::join_cols(split, tempSplit);
+      split.resize(split.n_elem + 1);
+      split(split.n_elem - 1) = sortedSplitAtt(begin);
+      binLabels.resize(binLabels.n_elem + 1);
+      binLabels(binLabels.n_elem - 1) = mostFreq;
 
       i++;
     }
@@ -297,8 +290,10 @@
       // the bucket of subCols.
       mostFreq = CountMostFreq<double>(subCols);
 
-      tempSplit << sortedSplitAtt(begin) << mostFreq << arma::endr;
-      split = arma::join_cols(split, tempSplit);
+      split.resize(split.n_elem + 1);
+      split(split.n_elem - 1) = sortedSplitAtt(begin);
+      binLabels.resize(binLabels.n_elem + 1);
+      binLabels(binLabels.n_elem - 1) = mostFreq;
 
       i = end + 1;
       count = 0;
@@ -321,9 +316,10 @@
 {
   for (int i = 1; i < split.n_rows; i++)
   {
-    if (split(i, 1) == split(i - 1, 1))
+    if (binLabels(i) == binLabels(i - 1))
     {
       // Remove this row, as it has the same label as the previous bucket.
+      binLabels.shed_row(i);
       split.shed_row(i);
       // Go back to previous row.
       i--;



More information about the mlpack-svn mailing list