[mlpack-svn] r16792 - mlpack/trunk/src/mlpack/methods/decision_stump

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Jul 9 12:49:44 EDT 2014


Author: rcurtin
Date: Wed Jul  9 12:49:43 2014
New Revision: 16792

Log:
Fix some formatting, fix backwards entropy splitting, add getters/setters, and
comment a little bit about the internal structure of the class.


Modified:
   mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp
   mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp

Modified: mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp	Wed Jul  9 12:49:43 2014
@@ -16,6 +16,14 @@
  * This class implements a decision stump. It constructs a single level
  * decision tree, i.e., a decision stump. It uses entropy to decide splitting
  * ranges.
+ *
+ * The stump is parameterized by a splitting attribute (the dimension on which
+ * points are split), a vector of bin split values, and a vector of labels for
+ * each bin.  Bin i is specified by the range [split[i], split[i + 1]).  The
+ * last bin has range up to \infty (split[i + 1] does not exist in that case).
+ * Points that are below the first bin will take the label of the first bin.
+ *
+ * @tparam MatType Type of matrix that is being used (sparse or dense).
  */
 template <typename MatType = arma::mat>
 class DecisionStump
@@ -45,13 +53,27 @@
    */
   void Classify(const MatType& test, arma::Row<size_t>& predictedLabels);
 
-  int splitCol;
+  //! Access the splitting attribute.
+  int SplitAttribute() const { return splitAttribute; }
+  //! Modify the splitting attribute (be careful!).
+  int& SplitAttribute() { return splitAttribute; }
+
+  //! Access the splitting values.
+  const arma::vec& Split() const { return split; }
+  //! Modify the splitting values (be careful!).
+  arma::vec& Split() { return split; }
+
+  //! Access the labels for each split bin.
+  const arma::Col<size_t> BinLabels() const { return binLabels; }
+  //! Modify the labels for each split bin (be careful!).
+  arma::Col<size_t>& BinLabels() { return binLabels; }
+
  private:
   //! Stores the number of classes.
   size_t numClass;
 
   //! Stores the value of the attribute on which to split.
-  // int splitCol;
+  int splitAttribute;
 
   //! Size of bucket while determining splitting criterion.
   size_t bucketSize;

Modified: mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp	Wed Jul  9 12:49:43 2014
@@ -37,11 +37,7 @@
   // If classLabels are not all identical, proceed with training.
   int bestAtt = -1;
   double entropy;
-  double bestEntropy = DBL_MAX;
-
-  // Set the default class to handle attribute values which are not present in
-  // the training data.
-  //defaultClass = CountMostFreq<size_t>(classLabels);
+  double bestEntropy = -DBL_MAX;
 
   for (int i = 0; i < data.n_rows; i++)
   {
@@ -52,37 +48,21 @@
       // splitting attribute and calculate entropy if split on it.
       entropy = SetupSplitAttribute(data.row(i), labels);
 
-      // Find the attribute with the bestEntropy so that the gain is
+      Log::Debug << "Entropy for attribute " << i << " is " << entropy << ".\n";
+
+      // Find the attribute with the best entropy so that the gain is
       // maximized.
-      if (entropy < bestEntropy)
+      if (entropy > bestEntropy)
       {
         bestAtt = i;
         bestEntropy = entropy;
       }
-
-      /* This section is commented out because I believe entropy calculation is
-       * wrong.  Entropy should only be 0 if there is only one class, in which
-       * case classification is perfect and we can take the shortcut below.
-
-      // If the entropy is 0, then all the labels are the same and we are done.
-      Log::Debug << "Entropy is " << entropy << "\n";
-      if (entropy == 0)
-      {
-        // Only one split element... there is no split at all, just one bin.
-        split.set_size(1);
-        binLabels.set_size(1);
-        split[0] = -DBL_MAX;
-        binLabels[0] = labels[0];
-        splitCol = 0; // It doesn't matter.
-        return;
-      }
-      */
     }
   }
-  splitCol = bestAtt;
+  splitAttribute = bestAtt;
 
   // Once the splitting column/attribute has been decided, train on it.
-  TrainOnAtt<double>(data.row(splitCol), labels);
+  TrainOnAtt<double>(data.row(splitAttribute), labels);
 }
 
 /**
@@ -103,7 +83,7 @@
     // Assume first that it falls into the first bin, then proceed through the
     // bins until it is known which bin it falls into.
     int bin = 0;
-    const double val = test(splitCol, i);
+    const double val = test(splitAttribute, i);
 
     while (bin < split.n_elem - 1)
     {
@@ -147,33 +127,34 @@
 
   i = 0;
   count = 0;
-  double ratioEl;
+
   // This splits the sorted into buckets of size greater than or equal to
   // inpBucketSize.
   while (i < sortedLabels.n_elem)
   {
     count++;
-    if (i == sortedLabels.n_elem - 1) 
+    if (i == sortedLabels.n_elem - 1)
     {
-      // if we're at the end, then don't worry about the bucket size
-      // just take this as the last bin.
+      // If we're at the end, then don't worry about the bucket size; just take
+      // this as the last bin.
       begin = i - count + 1;
       end = i;
-      
-      // using ratioEl to calculate the ratio of elements in this split.
-      ratioEl = ((double)(end - begin + 1)/sortedLabels.n_elem);
-      
-      entropy += ratioEl * CalculateEntropy<size_t>(sortedLabels.subvec(begin,end));
+
+      // Use ratioEl to calculate the ratio of elements in this split.
+      const double ratioEl = ((double) (end - begin + 1) / sortedLabels.n_elem);
+
+      entropy += ratioEl * CalculateEntropy<size_t>(
+          sortedLabels.subvec(begin, end));
       i++;
     }
     else if (sortedLabels(i) != sortedLabels(i + 1))
     {
-      // if we're not at the last element of sortedLabels, then check whether
+      // If we're not at the last element of sortedLabels, then check whether
       // count is less than the current bucket size.
       if (count < bucketSize)
       {
-        // if it is, then take the minimum bucket size anyways
-        // this is where the inpBucketSize comes into use
+        // If it is, then take the minimum bucket size anyways.
+        // This is where the inpBucketSize comes into use.
         // This makes sure there isn't a bucket for every change in labels.
         begin = i - count + 1;
         end = begin + bucketSize - 1;
@@ -183,13 +164,14 @@
       }
       else
       {
-        // if it is not, then take the bucket size as the value of count.
+        // If it is not, then take the bucket size as the value of count.
         begin = i - count + 1;
         end = i;
       }
-      ratioEl = ((double)(end - begin + 1)/sortedLabels.n_elem);
-    
-      entropy +=ratioEl * CalculateEntropy<size_t>(sortedLabels.subvec(begin,end));
+      const double ratioEl = ((double) (end - begin + 1) / sortedLabels.n_elem);
+
+      entropy += ratioEl * CalculateEntropy<size_t>(
+          sortedLabels.subvec(begin, end));
 
       i = end + 1;
       count = 0;
@@ -321,6 +303,9 @@
   rType element;
   int count = 0, localCount = 0;
 
+  if (sortCounts.n_elem == 1)
+    return sortCounts[0];
+
   // An O(n) loop which counts the most frequent element in sortCounts
   for (int i = 0; i < sortCounts.n_elem; ++i)
   {



More information about the mlpack-svn mailing list