[mlpack-git] master: Refactor DecisionStump significantly. (cca2005)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Mon Nov 30 17:24:11 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/10b9d45b806a3e879b0564d78ccb183ebc7051ba...31c557d9cc7e4da57fd8a246085c19e076d12271

>---------------------------------------------------------------

commit cca2005449af19053365a595d60089dac2a487a0
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sat Nov 21 01:17:02 2015 +0000

    Refactor DecisionStump significantly.
    
    Remove unnecessary template definitions, avoid copies, clean up API, and add a standalone public Train() function.


>---------------------------------------------------------------

cca2005449af19053365a595d60089dac2a487a0
 .../methods/decision_stump/decision_stump.hpp      | 100 ++++++------
 .../methods/decision_stump/decision_stump_impl.hpp | 167 ++++++++++-----------
 2 files changed, 139 insertions(+), 128 deletions(-)

diff --git a/src/mlpack/methods/decision_stump/decision_stump.hpp b/src/mlpack/methods/decision_stump/decision_stump.hpp
index 61eea06..10099b9 100644
--- a/src/mlpack/methods/decision_stump/decision_stump.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump.hpp
@@ -25,7 +25,7 @@ namespace decision_stump {
  *
  * @tparam MatType Type of matrix that is being used (sparse or dense).
  */
-template <typename MatType = arma::mat>
+template<typename MatType = arma::mat>
 class DecisionStump
 {
  public:
@@ -36,22 +36,12 @@ class DecisionStump
    * @param data Input, training data.
    * @param labels Labels of training data.
    * @param classes Number of distinct classes in labels.
-   * @param inpBucketSize Minimum size of bucket when splitting.
+   * @param bucketSize Minimum size of bucket when splitting.
    */
   DecisionStump(const MatType& data,
                 const arma::Row<size_t>& labels,
                 const size_t classes,
-                size_t inpBucketSize);
-
-  /**
-   * Classification function. After training, classify test, and put the
-   * predicted classes in predictedLabels.
-   *
-   * @param test Testing data or data to classify.
-   * @param predictedLabels Vector to store the predicted classes after
-   *     classifying test data.
-   */
-  void Classify(const MatType& test, arma::Row<size_t>& predictedLabels);
+                const size_t bucketSize);
 
   /**
    * Alternate constructor which copies parameters bucketSize and numClass from
@@ -61,19 +51,43 @@ class DecisionStump
    * @param other The other initiated Decision Stump object from
    *      which we copy the values.
    * @param data The data on which to train this object on.
-   * @param D Weight vector to use while training. For boosting purposes.
    * @param labels The labels of data.
-   * @param isWeight Whether we need to run a weighted Decision Stump.
+   * @param weights Weight vector to use while training. For boosting purposes.
    */
   DecisionStump(const DecisionStump<>& other,
                 const MatType& data,
                 const arma::Row<size_t>& labels,
                 const arma::rowvec& weights);
 
+  /**
+   * Train the decision stump on the given data.  This completely overwrites any
+   * previous training data, so after training the stump may be completely
+   * different.
+   *
+   * @param data Dataset to train on.
+   * @param labels Labels for each point in the dataset.
+   * @param classes Number of classes in the dataset.
+   * @param bucketSize Minimum size of bucket when splitting.
+   */
+  void Train(const MatType& data,
+             const arma::Row<size_t>& labels,
+             const size_t classes,
+             const size_t bucketSize);
+
+  /**
+   * Classification function. After training, classify test, and put the
+   * predicted classes in predictedLabels.
+   *
+   * @param test Testing data or data to classify.
+   * @param predictedLabels Vector to store the predicted classes after
+   *     classifying test data.
+   */
+  void Classify(const MatType& test, arma::Row<size_t>& predictedLabels);
+
   //! Access the splitting attribute.
-  int SplitAttribute() const { return splitAttribute; }
+  size_t SplitAttribute() const { return splitAttribute; }
   //! Modify the splitting attribute (be careful!).
-  int& SplitAttribute() { return splitAttribute; }
+  size_t& SplitAttribute() { return splitAttribute; }
 
   //! Access the splitting values.
   const arma::vec& Split() const { return split; }
@@ -86,18 +100,15 @@ class DecisionStump
   arma::Col<size_t>& BinLabels() { return binLabels; }
 
  private:
-  //! Stores the number of classes.
-  size_t numClasses;
-
-  //! Stores the value of the attribute on which to split.
-  int splitAttribute;
-
-  //! Size of bucket while determining splitting criterion.
+  //! The number of classes (we must store this for boosting).
+  size_t classes;
+  //! The minimum number of points in a bucket.
   size_t bucketSize;
 
+  //! Stores the value of the attribute on which to split.
+  size_t splitAttribute;
   //! Stores the splitting values after training.
   arma::vec split;
-
   //! Stores the labels for each splitting bin.
   arma::Col<size_t> binLabels;
 
@@ -107,9 +118,9 @@ class DecisionStump
    *
    * @param attribute A row from the training data, which might be a
    *     candidate for the splitting attribute.
-   * @param isWeight Whether we need to run a weighted Decision Stump.
+   * @tparam UseWeights Whether we need to run a weighted Decision Stump.
    */
-  template <bool isWeight>
+  template<bool UseWeights>
   double SetupSplitAttribute(const arma::rowvec& attribute,
                              const arma::Row<size_t>& labels,
                              const arma::rowvec& weightD);
@@ -121,8 +132,8 @@ class DecisionStump
    * @param attribute attribute is the attribute decided by the constructor
    *      on which we now train the decision stump.
    */
-  template <typename rType> void TrainOnAtt(const arma::rowvec& attribute,
-                                            const arma::Row<size_t>& labels);
+  void TrainOnAtt(const arma::rowvec& attribute,
+                  const arma::Row<size_t>& labels);
 
   /**
    * After the "split" matrix has been set up, merge ranges with identical class
@@ -136,42 +147,43 @@ class DecisionStump
    * @param subCols The vector in which to find the most frequently
    *     occurring element.
    */
-  template <typename rType> rType CountMostFreq(const arma::Row<rType>&
-      subCols);
+  template<typename VecType>
+  double CountMostFreq(const VecType& subCols);
 
   /**
    * Returns 1 if all the values of featureRow are not same.
    *
    * @param featureRow The attribute which is checked for identical values.
    */
-  template <typename rType> int IsDistinct(const arma::Row<rType>& featureRow);
+  template<typename VecType>
+  int IsDistinct(const VecType& featureRow);
 
   /**
    * Calculate the entropy of the given attribute.
    *
-   * @param attribute The attribute of which we calculate the entropy.
    * @param labels Corresponding labels of the attribute.
-   * @param isWeight Whether we need to run a weighted Decision Stump.
+   * @tparam UseWeights Whether each label is weighted by its entry in weights.
+   * @param weights Weights for this set of labels.
    */
-  template <typename LabelType, bool isWeight>
-  double CalculateEntropy(arma::subview_row<LabelType> labels, int begin,
-                          const arma::rowvec& tempD);
+  template<bool UseWeights, typename VecType, typename WeightVecType>
+  double CalculateEntropy(const VecType& labels,
+                          const WeightVecType& weights);
 
   /**
    * Train the decision stump on the given data and labels.
    *
    * @param data Dataset to train on.
    * @param labels Labels for dataset.
-   * @param isWeight Whether we need to run a weighted Decision Stump.
+   * @tparam UseWeights Whether we need to run a weighted Decision Stump.
    */
-  template <bool isWeight>
-  void Train(const MatType& data, const arma::Row<size_t>& labels,
-             const arma::rowvec& weightD);
-
+  template<bool UseWeights>
+  void Train(const MatType& data,
+             const arma::Row<size_t>& labels,
+             const arma::rowvec& weights);
 };
 
-}; // namespace decision_stump
-}; // namespace mlpack
+} // namespace decision_stump
+} // namespace mlpack
 
 #include "decision_stump_impl.hpp"
 
diff --git a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
index 4b5d987..c052063 100644
--- a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
@@ -20,20 +20,35 @@ namespace decision_stump {
  * @param data Input, training data.
  * @param labels Labels of data.
  * @param classes Number of distinct classes in labels.
- * @param inpBucketSize Minimum size of bucket when splitting.
+ * @param bucketSize Minimum size of bucket when splitting.
  */
 template<typename MatType>
 DecisionStump<MatType>::DecisionStump(const MatType& data,
                                       const arma::Row<size_t>& labels,
                                       const size_t classes,
-                                      size_t inpBucketSize)
+                                      const size_t bucketSize) :
+    classes(classes),
+    bucketSize(bucketSize)
 {
-  numClasses = classes;
-  bucketSize = inpBucketSize;
+  arma::rowvec weights;
+  Train<false>(data, labels, weights);
+}
 
-  arma::rowvec weightD;
+/**
+ * Train on the given data and labels.
+ */
+template<typename MatType>
+void DecisionStump<MatType>::Train(const MatType& data,
+                                   const arma::Row<size_t>& labels,
+                                   const size_t classes,
+                                   const size_t bucketSize)
+{
+  this->classes = classes;
+  this->bucketSize = bucketSize;
 
-  Train<false>(data, labels, weightD);
+  // Pass to unweighted training function.
+  arma::rowvec weights;
+  Train<false>(data, labels, weights);
 }
 
 /**
@@ -41,36 +56,38 @@ DecisionStump<MatType>::DecisionStump(const MatType& data,
  *
  * @param data Dataset to train on.
  * @param labels Labels for dataset.
- * @param isWeight Whether we need to run a weighted Decision Stump.
+ * @tparam UseWeights Whether we need to run a weighted Decision Stump.
  */
 template<typename MatType>
-template <bool isWeight>
-void DecisionStump<MatType>::Train(const MatType& data, const arma::Row<size_t>& labels,
-                                    const arma::rowvec& weightD)
+template<bool UseWeights>
+void DecisionStump<MatType>::Train(const MatType& data,
+                                   const arma::Row<size_t>& labels,
+                                   const arma::rowvec& weights)
 {
+  this->classes = classes;
+  this->bucketSize = bucketSize;
+
   // If classLabels are not all identical, proceed with training.
-  int bestAtt = 0;
+  size_t bestAtt = 0;
   double entropy;
-  const double rootEntropy = CalculateEntropy<size_t, isWeight>(
-      labels.subvec(0, labels.n_elem - 1), 0, weightD);
+  const double rootEntropy = CalculateEntropy<UseWeights>(labels, weights);
 
   double gain, bestGain = 0.0;
   for (size_t i = 0; i < data.n_rows; i++)
   {
     // Go through each attribute of the data.
-    if (IsDistinct<double>(data.row(i)))
+    if (IsDistinct(data.row(i)))
     {
       // For each attribute with non-identical values, treat it as a potential
       // splitting attribute and calculate entropy if split on it.
-      entropy = SetupSplitAttribute<isWeight>(data.row(i), labels, weightD);
+      entropy = SetupSplitAttribute<UseWeights>(data.row(i), labels, weights);
 
       gain = rootEntropy - entropy;
       // Find the attribute with the best entropy so that the gain is
       // maximized.
 
-      // if (entropy < bestEntropy)
-      // Instead of the above rule, we are maximizing gain, which was
-      // what is returned from SetupSplitAttribute.
+      // We are maximizing gain, which is computed from the entropy returned
+      // by SetupSplitAttribute().
       if (gain < bestGain)
       {
         bestAtt = i;
@@ -81,7 +98,7 @@ void DecisionStump<MatType>::Train(const MatType& data, const arma::Row<size_t>&
   splitAttribute = bestAtt;
 
   // Once the splitting column/attribute has been decided, train on it.
-  TrainOnAtt<double>(data.row(splitAttribute), labels);
+  TrainOnAtt(data.row(splitAttribute), labels);
 }
 
 /**
@@ -126,20 +143,16 @@ void DecisionStump<MatType>::Classify(const MatType& test,
  * @param data The data on which to train this object on.
  * @param D Weight vector to use while training. For boosting purposes.
  * @param labels The labels of data.
- * @param isWeight Whether we need to run a weighted Decision Stump.
+ * @param weights Weight vector to use while training. For boosting purposes.
  */
-template <typename MatType>
+template<typename MatType>
 DecisionStump<MatType>::DecisionStump(const DecisionStump<>& other,
                                       const MatType& data,
                                       const arma::Row<size_t>& labels,
-                                      const arma::rowvec& weights)
+                                      const arma::rowvec& weights) :
+    classes(other.classes),
+    bucketSize(other.bucketSize)
 {
-  numClasses = other.numClasses;
-  bucketSize = other.bucketSize;
-
-  // weightD = weights;
-  // tempD = weightD;
-
   Train<true>(data, labels, weights);
 }
 
@@ -149,14 +162,14 @@ DecisionStump<MatType>::DecisionStump(const DecisionStump<>& other,
  *
  * @param attribute A row from the training data, which might be a candidate for
  *      the splitting attribute.
- * @param isWeight Whether we need to run a weighted Decision Stump.
+ * @tparam UseWeights Whether we need to run a weighted Decision Stump.
  */
-template <typename MatType>
-template <bool isWeight>
+template<typename MatType>
+template<bool UseWeights>
 double DecisionStump<MatType>::SetupSplitAttribute(
     const arma::rowvec& attribute,
     const arma::Row<size_t>& labels,
-    const arma::rowvec& weightD)
+    const arma::rowvec& weights)
 {
   size_t i, count, begin, end;
   double entropy = 0.0;
@@ -169,23 +182,22 @@ double DecisionStump<MatType>::SetupSplitAttribute(
   arma::uvec sortedIndexAtt = arma::stable_sort_index(attribute.t());
 
   arma::Row<size_t> sortedLabels(attribute.n_elem);
-  sortedLabels.fill(0);
-
-  arma::rowvec tempD = arma::rowvec(weightD.n_cols);
+  arma::rowvec sortedWeights(attribute.n_elem);
 
   for (i = 0; i < attribute.n_elem; i++)
   {
     sortedLabels(i) = labels(sortedIndexAtt(i));
 
-    if(isWeight)
-      tempD(i) = weightD(sortedIndexAtt(i));
+    // Apply weights if necessary.
+    if (UseWeights)
+      sortedWeights(i) = weights(sortedIndexAtt(i));
   }
 
   i = 0;
   count = 0;
 
-  // This splits the sorted into buckets of size greater than or equal to
-  // inpBucketSize.
+  // This splits the sorted data into buckets of size greater than or equal to
+  // bucketSize.
   while (i < sortedLabels.n_elem)
   {
     count++;
@@ -199,8 +211,8 @@ double DecisionStump<MatType>::SetupSplitAttribute(
       // Use ratioEl to calculate the ratio of elements in this split.
       const double ratioEl = ((double) (end - begin + 1) / sortedLabels.n_elem);
 
-      entropy += ratioEl * CalculateEntropy<size_t, isWeight>(
-          sortedLabels.subvec(begin, end), begin, tempD);
+      entropy += ratioEl * CalculateEntropy<UseWeights>(
+          sortedLabels.subvec(begin, end), sortedWeights.subvec(begin, end));
       i++;
     }
     else if (sortedLabels(i) != sortedLabels(i + 1))
@@ -226,8 +238,8 @@ double DecisionStump<MatType>::SetupSplitAttribute(
       }
       const double ratioEl = ((double) (end - begin + 1) / sortedLabels.n_elem);
 
-      entropy += ratioEl * CalculateEntropy<size_t, isWeight>(
-          sortedLabels.subvec(begin, end), begin, tempD);
+      entropy += ratioEl * CalculateEntropy<UseWeights>(
+          sortedLabels.subvec(begin, end), sortedWeights.subvec(begin, end));
 
       i = end + 1;
       count = 0;
@@ -245,8 +257,7 @@ double DecisionStump<MatType>::SetupSplitAttribute(
  * @param attribute Attribute is the attribute decided by the constructor on
  *      which we now train the decision stump.
  */
-template <typename MatType>
-template <typename rType>
+template<typename MatType>
 void DecisionStump<MatType>::TrainOnAtt(const arma::rowvec& attribute,
                                         const arma::Row<size_t>& labels)
 {
@@ -256,14 +267,12 @@ void DecisionStump<MatType>::TrainOnAtt(const arma::rowvec& attribute,
   arma::uvec sortedSplitIndexAtt = arma::stable_sort_index(attribute.t());
   arma::Row<size_t> sortedLabels(attribute.n_elem);
   sortedLabels.fill(0);
-  arma::vec tempSplit;
-  arma::Row<size_t> tempLabel;
 
   for (i = 0; i < attribute.n_elem; i++)
     sortedLabels(i) = labels(sortedSplitIndexAtt(i));
 
   arma::rowvec subCols;
-  rType mostFreq;
+  double mostFreq;
   i = 0;
   count = 0;
   while (i < sortedLabels.n_elem)
@@ -274,12 +283,7 @@ void DecisionStump<MatType>::TrainOnAtt(const arma::rowvec& attribute,
       begin = i - count + 1;
       end = i;
 
-      arma::rowvec zSubCols((sortedLabels.cols(begin, end)).n_elem);
-      zSubCols.fill(0.0);
-
-      subCols = sortedLabels.cols(begin, end) + zSubCols;
-
-      mostFreq = CountMostFreq<double>(subCols);
+      mostFreq = CountMostFreq(sortedLabels.cols(begin, end));
 
       split.resize(split.n_elem + 1);
       split(split.n_elem - 1) = sortedSplitAtt(begin);
@@ -304,14 +308,10 @@ void DecisionStump<MatType>::TrainOnAtt(const arma::rowvec& attribute,
         begin = i - count + 1;
         end = i;
       }
-      arma::rowvec zSubCols((sortedLabels.cols(begin, end)).n_elem);
-      zSubCols.fill(0.0);
-
-      subCols = sortedLabels.cols(begin, end) + zSubCols;
 
       // Find the most frequent element in subCols so as to assign a label to
       // the bucket of subCols.
-      mostFreq = CountMostFreq<double>(subCols);
+      mostFreq = CountMostFreq(sortedLabels.cols(begin, end));
 
       split.resize(split.n_elem + 1);
       split(split.n_elem - 1) = sortedSplitAtt(begin);
@@ -334,7 +334,7 @@ void DecisionStump<MatType>::TrainOnAtt(const arma::rowvec& attribute,
  * After the "split" matrix has been set up, merge ranges with identical class
  * labels.
  */
-template <typename MatType>
+template<typename MatType>
 void DecisionStump<MatType>::MergeRanges()
 {
   for (size_t i = 1; i < split.n_rows; i++)
@@ -350,13 +350,13 @@ void DecisionStump<MatType>::MergeRanges()
   }
 }
 
-template <typename MatType>
-template <typename rType>
-rType DecisionStump<MatType>::CountMostFreq(const arma::Row<rType>& subCols)
+template<typename MatType>
+template<typename VecType>
+double DecisionStump<MatType>::CountMostFreq(const VecType& subCols)
 {
   // We'll create a map of elements and the number of times that each element is
   // seen.
-  std::map<rType, size_t> countMap;
+  std::map<double, size_t> countMap;
 
   for (size_t i = 0; i < subCols.n_elem; ++i)
   {
@@ -367,8 +367,8 @@ rType DecisionStump<MatType>::CountMostFreq(const arma::Row<rType>& subCols)
   }
 
   // Now find the maximum value.
-  typename std::map<rType, size_t>::iterator it = countMap.begin();
-  rType mostFreq = it->first;
+  typename std::map<double, size_t>::iterator it = countMap.begin();
+  double mostFreq = it->first;
   size_t mostFreqCount = it->second;
   while (it != countMap.end())
   {
@@ -385,15 +385,15 @@ rType DecisionStump<MatType>::CountMostFreq(const arma::Row<rType>& subCols)
 }
 
 /**
- * Returns 1 if all the values of featureRow are not same.
+ * Returns 1 if all the values of featureRow are not the same.
  *
  * @param featureRow The attribute which is checked for identical values.
  */
-template <typename MatType>
-template <typename rType>
-int DecisionStump<MatType>::IsDistinct(const arma::Row<rType>& featureRow)
+template<typename MatType>
+template<typename VecType>
+int DecisionStump<MatType>::IsDistinct(const VecType& featureRow)
 {
-  rType val = featureRow(0);
+  typename VecType::elem_type val = featureRow(0);
   for (size_t i = 1; i < featureRow.n_elem; ++i)
     if (val != featureRow(i))
       return 1;
@@ -405,34 +405,33 @@ int DecisionStump<MatType>::IsDistinct(const arma::Row<rType>& featureRow)
  *
  * @param attribute The attribute for which we calculate the entropy.
  * @param labels Corresponding labels of the attribute.
- * @param isWeight Whether we need to run a weighted Decision Stump.
+ * @tparam UseWeights Whether we need to run a weighted Decision Stump.
  */
 template<typename MatType>
-template<typename LabelType, bool isWeight>
+template<bool UseWeights, typename VecType, typename WeightVecType>
 double DecisionStump<MatType>::CalculateEntropy(
-    arma::subview_row<LabelType> labels,
-    int begin, const arma::rowvec& tempD)
+    const VecType& labels,
+    const WeightVecType& weights)
 {
   double entropy = 0.0;
   size_t j;
 
-  arma::Row<size_t> numElem(numClasses);
+  arma::Row<size_t> numElem(classes);
   numElem.fill(0);
 
   // Variable to accumulate the weight in this subview_row.
   double accWeight = 0.0;
   // Populate numElem; they are used as helpers to calculate entropy.
 
-  if (isWeight)
+  if (UseWeights)
   {
     for (j = 0; j < labels.n_elem; j++)
     {
-      numElem(labels(j)) += tempD(j + begin);
-      accWeight += tempD(j + begin);
+      numElem(labels(j)) += weights(j);
+      accWeight += weights(j);
     }
-      // numElem(labels(j))++;
 
-    for (j = 0; j < numClasses; j++)
+    for (j = 0; j < classes; j++)
     {
       const double p1 = ((double) numElem(j) / accWeight);
 
@@ -447,7 +446,7 @@ double DecisionStump<MatType>::CalculateEntropy(
     for (j = 0; j < labels.n_elem; j++)
       numElem(labels(j))++;
 
-    for (j = 0; j < numClasses; j++)
+    for (j = 0; j < classes; j++)
     {
       const double p1 = ((double) numElem(j) / labels.n_elem);
 
@@ -461,7 +460,7 @@ double DecisionStump<MatType>::CalculateEntropy(
   return entropy / std::log(2.0);
 }
 
-}; // namespace decision_stump
-}; // namespace mlpack
+} // namespace decision_stump
+} // namespace mlpack
 
 #endif



More information about the mlpack-git mailing list