[mlpack-svn] r16714 - mlpack/trunk/src/mlpack/methods/decision_stump

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Jun 25 15:27:19 EDT 2014


Author: rcurtin
Date: Wed Jun 25 15:27:19 2014
New Revision: 16714

Log:
Remove the classLabels member; the DecisionStump class does not need to store it, because the labels are now passed directly to SetupSplitAttribute() and TrainOnAtt().


Modified:
   mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp
   mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp

Modified: mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp	Wed Jun 25 15:27:19 2014
@@ -61,9 +61,6 @@
   //! Size of bucket while determining splitting criterion.
   size_t bucketSize;
 
-  //! Stores the class labels for the input data.
-  arma::Row<size_t> classLabels;
-
   //! Stores the splitting values after training.
   arma::vec split;
 
@@ -77,7 +74,8 @@
    * @param attribute A row from the training data, which might be a
    *     candidate for the splitting attribute.
    */
-  double SetupSplitAttribute(const arma::rowvec& attribute);
+  double SetupSplitAttribute(const arma::rowvec& attribute,
+                             const arma::Row<size_t>& labels);
 
   /**
    * After having decided the attribute on which to split, train on that
@@ -86,7 +84,8 @@
    * @param attribute attribute is the attribute decided by the constructor
    *      on which we now train the decision stump.
    */
-  template <typename rType> void TrainOnAtt(const arma::rowvec& attribute);
+  template <typename rType> void TrainOnAtt(const arma::rowvec& attribute,
+                                            const arma::Row<size_t>& labels);
 
   /**
    * After the "split" matrix has been set up, merge ranges with identical class

Modified: mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp	Wed Jun 25 15:27:19 2014
@@ -31,24 +31,19 @@
                                       const size_t classes,
                                       size_t inpBucketSize)
 {
-  arma::Row<size_t> zLabels(labels.n_elem);
-  zLabels.fill(0);
-  classLabels = labels + zLabels;
-
   numClass = classes;
   bucketSize = inpBucketSize;
 
   // Check whether the input labels are not all identical.
   if (!isDistinct<size_t>(labels))
   {
-    // If the classLabels are all identical, the default class is the only
-    // class.
+    // If the labels are all identical, the default class is the only class.
     oneClass = true;
-    defaultClass = classLabels(0);
+    defaultClass = labels(0);
   }
   else
   {
-    // If classLabels are not all identical, proceed with training.
+    // If labels are not all identical, proceed with training.
     oneClass = false;
     int bestAtt = -1;
     double entropy;
@@ -56,7 +51,7 @@
 
     // Set the default class to handle attribute values which are not present in
     // the training data.
-    defaultClass = CountMostFreq<size_t>(classLabels);
+    defaultClass = CountMostFreq<size_t>(labels);
 
     for (int i = 0; i < data.n_rows; i++)
     {
@@ -65,7 +60,7 @@
       {
         // For each attribute with non-identical values, treat it as a potential
         // splitting attribute and calculate entropy if split on it.
-        entropy = SetupSplitAttribute(data.row(i));
+        entropy = SetupSplitAttribute(data.row(i), labels);
 
         // Find the attribute with the bestEntropy so that the gain is
         // maximized.
@@ -79,7 +74,7 @@
     splitCol = bestAtt;
 
     // Once the splitting column/attribute has been decided, train on it.
-    TrainOnAtt<double>(data.row(splitCol));
+    TrainOnAtt<double>(data.row(splitCol), labels);
   }
 }
 
@@ -131,7 +126,9 @@
  *      the splitting attribute.
  */
 template <typename MatType>
-double DecisionStump<MatType>::SetupSplitAttribute(const arma::rowvec& attribute)
+double DecisionStump<MatType>::SetupSplitAttribute(
+    const arma::rowvec& attribute,
+    const arma::Row<size_t>& labels)
 {
   int i, count, begin, end;
   double entropy = 0.0;
@@ -147,7 +144,7 @@
   sortedLabels.fill(0);
 
   for (i = 0; i < attribute.n_elem; i++)
-    sortedLabels(i) = classLabels(sortedIndexAtt(i));
+    sortedLabels(i) = labels(sortedIndexAtt(i));
 
   arma::rowvec subColLabels;
   arma::rowvec subColAtts;
@@ -225,7 +222,8 @@
  */
 template <typename MatType>
 template <typename rType>
-void DecisionStump<MatType>::TrainOnAtt(const arma::rowvec& attribute)
+void DecisionStump<MatType>::TrainOnAtt(const arma::rowvec& attribute,
+                                        const arma::Row<size_t>& labels)
 {
   int i, count, begin, end;
 
@@ -237,7 +235,7 @@
   arma::Row<size_t> tempLabel;
 
   for (i = 0; i < attribute.n_elem; i++)
-    sortedLabels(i) = classLabels(sortedSplitIndexAtt(i));
+    sortedLabels(i) = labels(sortedSplitIndexAtt(i));
 
   arma::rowvec subCols;
   rType mostFreq;



More information about the mlpack-svn mailing list