[mlpack-svn] r16724 - mlpack/trunk/src/mlpack/methods/decision_stump

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Jun 26 19:21:08 EDT 2014


Author: rcurtin
Date: Thu Jun 26 19:21:07 2014
New Revision: 16724

Log:
Remove the oneClass and defaultClass variables.  There is a shortcut that can be taken when all the labels are the same, but it is left commented out for now because the Entropy() function does not appear to be working correctly.


Modified:
   mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp
   mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp

Modified: mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp	Thu Jun 26 19:21:07 2014
@@ -49,15 +49,9 @@
   //! Stores the number of classes.
   size_t numClass;
 
-  //! Stores the default class. Provided for handling missing attribute values.
-  size_t defaultClass;
-
   //! Stores the value of the attribute on which to split.
   int splitCol;
 
-  //! Flag value for distinct input class labels.
-  bool oneClass;
-
   //! Size of bucket while determining splitting criterion.
   size_t bucketSize;
 

Modified: mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp	Thu Jun 26 19:21:07 2014
@@ -34,48 +34,55 @@
   numClass = classes;
   bucketSize = inpBucketSize;
 
-  // Check whether the input labels are not all identical.
-  if (!isDistinct<size_t>(labels))
-  {
-    // If the labels are all identical, the default class is the only class.
-    oneClass = true;
-    defaultClass = labels(0);
-  }
-  else
-  {
-    // If labels are not all identical, proceed with training.
-    oneClass = false;
-    int bestAtt = -1;
-    double entropy;
-    double bestEntropy = DBL_MAX;
-
-    // Set the default class to handle attribute values which are not present in
-    // the training data.
-    defaultClass = CountMostFreq<size_t>(labels);
+  // If classLabels are not all identical, proceed with training.
+  int bestAtt = -1;
+  double entropy;
+  double bestEntropy = DBL_MAX;
+
+  // Set the default class to handle attribute values which are not present in
+  // the training data.
+  //defaultClass = CountMostFreq<size_t>(classLabels);
 
-    for (int i = 0; i < data.n_rows; i++)
+  for (int i = 0; i < data.n_rows; i++)
+  {
+    // Go through each attribute of the data.
+    if (isDistinct<double>(data.row(i)))
     {
-      // Go through each attribute of the data.
-      if (isDistinct<double>(data.row(i)))
+      // For each attribute with non-identical values, treat it as a potential
+      // splitting attribute and calculate entropy if split on it.
+      entropy = SetupSplitAttribute(data.row(i), labels);
+
+      // Find the attribute with the bestEntropy so that the gain is
+      // maximized.
+      if (entropy < bestEntropy)
       {
-        // For each attribute with non-identical values, treat it as a potential
-        // splitting attribute and calculate entropy if split on it.
-        entropy = SetupSplitAttribute(data.row(i), labels);
-
-        // Find the attribute with the bestEntropy so that the gain is
-        // maximized.
-        if (entropy < bestEntropy)
-        {
-          bestAtt = i;
-          bestEntropy = entropy;
-        }
+        bestAtt = i;
+        bestEntropy = entropy;
       }
-    }
-    splitCol = bestAtt;
 
-    // Once the splitting column/attribute has been decided, train on it.
-    TrainOnAtt<double>(data.row(splitCol), labels);
+      /* This section is commented out because I believe entropy calculation is
+       * wrong.  Entropy should only be 0 if there is only one class, in which
+       * case classification is perfect and we can take the shortcut below.
+
+      // If the entropy is 0, then all the labels are the same and we are done.
+      Log::Debug << "Entropy is " << entropy << "\n";
+      if (entropy == 0)
+      {
+        // Only one split element... there is no split at all, just one bin.
+        split.set_size(1);
+        binLabels.set_size(1);
+        split[0] = -DBL_MAX;
+        binLabels[0] = labels[0];
+        splitCol = 0; // It doesn't matter.
+        return;
+      }
+      */
+    }
   }
+  splitCol = bestAtt;
+
+  // Once the splitting column/attribute has been decided, train on it.
+  TrainOnAtt<double>(data.row(splitCol), labels);
 }
 
 /**
@@ -90,31 +97,23 @@
 void DecisionStump<MatType>::Classify(const MatType& test,
                                       arma::Row<size_t>& predictedLabels)
 {
-  if (!oneClass)
+  for (int i = 0; i < test.n_cols; i++)
   {
-    for (int i = 0; i < test.n_cols; i++)
-    {
-      // Determine which bin the test point falls into.
-      // Assume first that it falls into the first bin, then proceed through the
-      // bins until it is known which bin it falls into.
-      int bin = 0;
-      const double val = test(splitCol, i);
-
-      while (bin < split.n_elem - 1)
-      {
-        if (val < split(bin + 1))
-          break;
+    // Determine which bin the test point falls into.
+    // Assume first that it falls into the first bin, then proceed through the
+    // bins until it is known which bin it falls into.
+    int bin = 0;
+    const double val = test(splitCol, i);
 
-        ++bin;
-      }
+    while (bin < split.n_elem - 1)
+    {
+      if (val < split(bin + 1))
+        break;
 
-      predictedLabels(i) = binLabels(bin);
+      ++bin;
     }
-  }
-  else
-  {
-    for (int i = 0; i < test.n_cols; i++)
-      predictedLabels(i) = defaultClass;
+
+    predictedLabels(i) = binLabels(bin);
   }
 }
 
@@ -408,7 +407,7 @@
     {
       if (uniqueAtt[j] == attribute[i])
       {
-        entropyArray(j,labels(i))++;
+        entropyArray(j, labels(i))++;
         numElem(j)++;
       }
     }



More information about the mlpack-svn mailing list