[mlpack-svn] r16700 - mlpack/trunk/src/mlpack/methods/decision_stump

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Jun 23 20:03:25 EDT 2014


Author: rcurtin
Date: Mon Jun 23 20:03:24 2014
New Revision: 16700

Log:
First pass: comment standardization, fix header guard names, move .cpp to .hpp
because it contains only templated functions.


Added:
   mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
      - copied, changed from r16694, /mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.cpp
Removed:
   mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.cpp
Modified:
   mlpack/trunk/src/mlpack/methods/decision_stump/CMakeLists.txt
   mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp
   mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_main.cpp

Modified: mlpack/trunk/src/mlpack/methods/decision_stump/CMakeLists.txt
==============================================================================
--- mlpack/trunk/src/mlpack/methods/decision_stump/CMakeLists.txt	(original)
+++ mlpack/trunk/src/mlpack/methods/decision_stump/CMakeLists.txt	Mon Jun 23 20:03:24 2014
@@ -4,7 +4,7 @@
 # Anything not in this list will not be compiled into MLPACK.
 set(SOURCES
   decision_stump.hpp
-  decision_stump_impl.cpp
+  decision_stump_impl.hpp
 )
 
 # Add directory name to sources.

Modified: mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp	Mon Jun 23 20:03:24 2014
@@ -1,127 +1,124 @@
 /**
  * @file decision_stump.hpp
  * @author Udit Saxena
- * 
- * Defintion of decision stumps.
+ *
+ * Definition of decision stumps.
  */
-
-#ifndef _MLPACK_METHODS_DECISION_STUMP_HPP
-#define _MLPACK_METHODS_DECISION_STUMP_HPP
+#ifndef __MLPACK_METHODS_DECISION_STUMP_DECISION_STUMP_HPP
+#define __MLPACK_METHODS_DECISION_STUMP_DECISION_STUMP_HPP
 
 #include <mlpack/core.hpp>
 
 namespace mlpack {
 namespace decision_stump {
-/*
+
+/**
  * This class implements a decision stump. It constructs a single level
- * decision tree, i.e. a decision stump. It uses entropy to decided splitting
+ * decision tree, i.e., a decision stump. It uses entropy to decide splitting
  * ranges.
- *
  */
 template <typename MatType = arma::mat>
 class DecisionStump
 {
  public:
-  /*
-  Constructor. Train on the provided data. Generate a decision stump
-  from data. 
-
-  @param: data - Input, training data.
-  @param: labels - Labels of data.
-  @param: classes - number of distinct classes in labels.
-  @param: inpBucketSize - minimum size of bucket when splitting.
+  /**
+   * Constructor. Train on the provided data. Generate a decision stump from
+   * data.
+   *
+   * @param data Input, training data.
+   * @param labels Labels of training data.
+   * @param classes Number of distinct classes in labels.
+   * @param inpBucketSize Minimum size of bucket when splitting.
    */
   DecisionStump(const MatType& data,
                 const arma::Row<size_t>& labels,
-                const size_t classes, 
+                const size_t classes,
                 size_t inpBucketSize);
 
-  /*
-  Classification function. After training, classify test, and put the 
-  predicted classes in predictedLabels.
-
-  @param: test - testing data or data to classify. 
-  @param: predictedLabels - vector to store the predicted classes after
-                            classifying test
+  /**
+   * Classification function. After training, classify test, and put the
+   * predicted classes in predictedLabels.
+   *
+   * @param test Testing data or data to classify.
+   * @param predictedLabels Vector to store the predicted classes after
+   *     classifying test data.
    */
   void Classify(const MatType& test, arma::Row<size_t>& predictedLabels);
 
  private:
-  /* Stores the number of classes.*/
-  size_t numClass; 
-  
-  /* Stores the default class. Provided for handling missing attribute values.*/
+  //! Stores the number of classes.
+  size_t numClass;
+
+  //! Stores the default class. Provided for handling missing attribute values.
   size_t defaultClass;
-  
-  /* Stores the value of the attribute on which to split.*/
+
+  //! Stores the index of the attribute (row of data) on which to split.
   int splitCol;
-  
-  /* Flag value for distinct input class labels.*/
-  int oneClass; 
-  
-  /* Size of bucket while determining splitting criterion.*/
+
+  //! Set to 1 if all training labels are identical (only one class), else 0.
+  int oneClass;
+
+  //! Size of bucket while determining splitting criterion.
   size_t bucketSize;
-  
-  /* Stores the class labels for the input data*/
+
+  //! Stores the class labels for the input data.
   arma::Row<size_t> classLabels;
-  
-  /* Stores the splitting criterion after training.*/
+
+  //! Stores the splitting criterion after training.
   arma::mat split;
-  
-  /* 
-  Sets up attribute as if it were splitting on it and 
-  finds entropy when splitting on attribute.
-
-  @param: attribute - a row from the training data, which might be a
-                      candidate for the splitting attribute.
-  */
-  double SetupSplitAttribute(const arma::rowvec& attribute);
 
-  /* 
-  After having decided the attribute on which to split, 
-  train on that attribute.
+  /**
+   * Sets up attribute as if it were splitting on it and finds entropy when
+   * splitting on attribute.
+   *
+   * @param attribute A row from the training data, which might be a
+   *     candidate for the splitting attribute.
+   */
+  double SetupSplitAttribute(const arma::rowvec& attribute);
 
-  @param: attribute - attribute is the attribute decided by the constructor
-                      on which we now train the decision stump.
+  /**
+   * After having decided the attribute on which to split, train on that
+   * attribute.
+   *
+   * @param attribute The attribute decided by the constructor on which we now
+   *      train the decision stump.
    */
   template <typename rType> void TrainOnAtt(const arma::rowvec& attribute);
 
-  /* After the "split" matrix has been set up, 
-     merging ranges with identical class labels.
+  /**
+   * After the "split" matrix has been set up, merge ranges with identical class
+   * labels.
    */
   void MergeRanges();
 
-  /* 
-  Used to count the most frequently occurring element in subCols.
-
-  @param: subCols - the vector in which to find the most frequently 
-                    occurring element.  
+  /**
+   * Count the most frequently occurring element in subCols.
+   *
+   * @param subCols The vector in which to find the most frequently
+   *     occurring element.
    */
   template <typename rType> rType CountMostFreq(const arma::Row<rType>& subCols);
- 
-  /* 
-  Returns 1 if all the values of featureRow are not same.
-
-  @param: featureRow - the attribute which is checked so that it 
-                       does not have identical values. 
-  */
-  template <typename rType> int isDistinct(const arma::Row<rType>& featureRow);
 
-  /* 
-  Calculating Entropy of attribute.
+  /**
+   * Returns 1 if the values of featureRow are not all identical, else 0.
+   *
+   * @param featureRow The attribute which is checked for identical values.
+   */
+  template <typename rType> int isDistinct(const arma::Row<rType>& featureRow);
 
-  @param: attribute - the attribute of which we calculate the entropy.
-  @param: labels - corresponding labels of the attribute.
-  */
-  double CalculateEntropy(const arma::rowvec& attribute, 
+  /**
+   * Calculate the (negated, <= 0) weighted entropy of the given attribute.
+   *
+   * @param attribute The attribute of which we calculate the entropy.
+   * @param labels Corresponding labels of the attribute.
+   */
+  double CalculateEntropy(const arma::rowvec& attribute,
                           const arma::rowvec& labels);
-
-  
 };
 
-}; //namespace decision_stump
-}; //namespace mlpack
+}; // namespace decision_stump
+}; // namespace mlpack
 
-#include "decision_stump_impl.cpp"
+#include "decision_stump_impl.hpp"
 
-#endif
\ No newline at end of file
+#endif

Copied: mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp (from r16694, /mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.cpp)
==============================================================================
--- /mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.cpp	(original)
+++ mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp	Mon Jun 23 20:03:24 2014
@@ -1,11 +1,14 @@
 /**
  * @file decision_stump_impl.hpp
  * @author Udit Saxena
-**/
+ *
+ * Implementation of DecisionStump class.
+ */
 
-#ifndef _MLPACK_METHODS_DECISION_STUMP_IMPL_HPP
-#define _MLPACK_METHODS_DECISION_STUMP_IMPL_HPP
+#ifndef __MLPACK_METHODS_DECISION_STUMP_DECISION_STUMP_IMPL_HPP
+#define __MLPACK_METHODS_DECISION_STUMP_DECISION_STUMP_IMPL_HPP
 
+// In case it hasn't been included yet.
 #include "decision_stump.hpp"
 
 #include <set>
@@ -13,14 +16,14 @@
 
 namespace mlpack {
 namespace decision_stump {
-/*
-  Constructor. Train on the provided data. Generate a decision stump
-  from data. 
-
-  @param: data - Input, training data.
-  @param: labels - Labels of data.
-  @param: classes - number of distinct classes in labels.
-  @param: inpBucketSize - minimum size of bucket when splitting.
+
+/**
+ * Constructor. Train on the provided data. Generate a decision stump from data.
+ *
+ * @param data Input, training data.
+ * @param labels Labels of data.
+ * @param classes Number of distinct classes in labels.
+ * @param inpBucketSize Minimum size of bucket when splitting.
  */
 template<typename MatType>
 DecisionStump<MatType>::DecisionStump(const MatType& data,
@@ -31,99 +34,95 @@
   arma::Row<size_t> zLabels(labels.n_elem);
   zLabels.fill(0);
   classLabels = labels + zLabels;
-  
+
   numClass = classes;
   bucketSize = inpBucketSize;
 
   /* Check whether the input labels are not all identical. */
-  if ( !isDistinct<size_t>(classLabels) )
+  if (!isDistinct<size_t>(classLabels))
   {
-    // If the classLabels are all identical, 
-    // the default class is the only class set. 
+    // If the classLabels are all identical, the default class is the only
+    // class.
     oneClass = 1;
-    defaultClass = classLabels(0); 
+    defaultClass = classLabels(0);
   }
 
   else
   {
-    // If classLabels are not all identical
-    // proceed for training
-
+    // If classLabels are not all identical, proceed with training.
     oneClass = 0;
-    int bestAtt=-1,i;
-    double entropy,bestEntropy=DBL_MAX; 
+    int bestAtt = -1;
+    double entropy;
+    double bestEntropy = DBL_MAX;
 
-    // Set the default class to handle attribute values which are 
-    // not present in the training data. 
+    // Set the default class to handle attribute values which are not present in
+    // the training data.
     defaultClass = CountMostFreq<size_t>(classLabels);
 
-    for (i = 0;i < data.n_rows; i++)
+    for (int i = 0; i < data.n_rows; i++)
     {
-      // going through each attribute of data.
+      // Go through each attribute of the data.
       if (isDistinct<double>(data.row(i)))
       {
-        // for each attribute with non-identical values, 
-        // treat it as a potential splitting attribute
-        // and calculate entropy if split on it.
-        entropy=SetupSplitAttribute(data.row(i));
-    
-        // finding the attribute with the bestEntropy
-        // so that the gain is max.
+        // For each attribute with non-identical values, treat it as a potential
+        // splitting attribute and calculate entropy if split on it.
+        entropy = SetupSplitAttribute(data.row(i));
+
+        // Keep the attribute with the smallest returned value. NOTE(review):
+        // that value is negated entropy, so this maximizes entropy; check sign.
         if (entropy < bestEntropy)
         {
           bestAtt = i;
           bestEntropy = entropy;
         }
-
       }
     }
     splitCol = bestAtt;
 
-    // once the splitting column/attribute has been decided, 
-    // train on it.
+    // Once the splitting column/attribute has been decided, train on it.
     TrainOnAtt<double>(data.row(splitCol));
   }
 }
 
-/*
-  Classification function. After training, classify test, and put the 
-  predicted classes in predictedLabels.
-
-  @param: test - testing data or data to classify. 
-  @param: predictedLabels - vector to store the predicted classes after
-                            classifying test
+/**
+ * Classification function. After training, classify test, and put the predicted
+ * classes in predictedLabels.
+ *
+ * @param test Testing data or data to classify.
+ * @param predictedLabels Vector to store the predicted classes after
+ *      classifying test.
  */
 template<typename MatType>
 void DecisionStump<MatType>::Classify(const MatType& test,
                                       arma::Row<size_t>& predictedLabels)
 {
-  int i,j,flag;
+  int flag;
   double val;
-  if ( !oneClass )
+  if (!oneClass)
   {
-    for (i = 0; i < test.n_cols; i++)
+    for (int i = 0; i < test.n_cols; i++)
     {
-      j = 0;
+      int j = 0;
       flag = 0;
 
       val = test(splitCol,i);
       while ((j < split.n_rows) && (!flag))
       {
-        if(val < split(j,0) && (!j))
+        if (val < split(j, 0) && (!j))
         {
-          predictedLabels(i) = split(0,1);
+          predictedLabels(i) = split(0, 1);
           flag = 1;
         }
-        else if (val >= split(j,0))
+        else if (val >= split(j, 0))
         {
-          if(j == split.n_rows - 1)
+          if (j == split.n_rows - 1)
           {
             predictedLabels(i) = split(split.n_rows - 1, 1);
             flag = 1;
           }
-          else if (val < split(j+1,0))
+          else if (val < split(j + 1, 0))
           {
-            predictedLabels(i) = split(j,1);
+            predictedLabels(i) = split(j, 1);
             flag = 1;
           }
         }
@@ -133,18 +132,17 @@
   }
   else
   {
-    for (i = 0;i < test.n_cols;i++)
-      predictedLabels(i)=defaultClass;
+    for (int i = 0; i < test.n_cols; i++)
+      predictedLabels(i) = defaultClass;
   }
-
 }
 
-/* 
-  Sets up attribute as if it were splitting on it and 
-  finds entropy when splitting on attribute.
-
-  @param: attribute - a row from the training data, which might be a
-                      candidate for the splitting attribute.
+/**
+ * Sets up attribute as if it were splitting on it and finds entropy when
+ * splitting on attribute.
+ *
+ * @param attribute A row from the training data, which might be a candidate for
+ *      the splitting attribute.
  */
 template <typename MatType>
 double DecisionStump<MatType>::SetupSplitAttribute(const arma::rowvec& attribute)
@@ -152,18 +150,16 @@
   int i, count, begin, end;
   double entropy = 0.0;
 
-  // sorting the attribute, for calculating splitting ranges
+  // Sort the attribute in order to calculate splitting ranges.
   arma::rowvec sortedAtt = arma::sort(attribute);
 
-  // storing the indexes of the sorted attribute to build 
-  // a vector of sorted labels.
-  // this sort is stable.
+  // Store the indices of the sorted attribute to build a vector of sorted
+  // labels.  This sort is stable.
   arma::uvec sortedIndexAtt = arma::stable_sort_index(attribute.t());
 
-  // vector of sorted labels
   arma::Row<size_t> sortedLabels(attribute.n_elem);
   sortedLabels.fill(0);
-  
+
   for (i = 0; i < attribute.n_elem; i++)
     sortedLabels(i) = classLabels(sortedIndexAtt(i));
 
@@ -173,7 +169,8 @@
   i = 0;
   count = 0;
 
-  // this splits the sorted into buckets of size >= inpBucketSize
+  // This splits the sorted data into buckets of size greater than or equal
+  // to bucketSize.
   while (i < sortedLabels.n_elem)
   {
     count++;
@@ -195,14 +192,14 @@
       entropy += CalculateEntropy(subColAtts, subColLabels);
       i++;
     }
-    else if( sortedLabels(i) != sortedLabels(i + 1) )
+    else if (sortedLabels(i) != sortedLabels(i + 1))
     {
-      if (count < bucketSize) 
+      if (count < bucketSize)
       {
         begin = i - count + 1;
         end = begin + bucketSize - 1;
-        
-        if ( end > sortedLabels.n_elem - 1)
+
+        if (end > sortedLabels.n_elem - 1)
           end = sortedLabels.n_elem - 1;
       }
       else
@@ -221,12 +218,11 @@
 
       subColAtts = sortedAtt.cols(begin, end) + zSubColAtts;
 
-      // now using subColLabels and subColAtts to calculate entropuy
+      // Now use subColLabels and subColAtts to calculate entropy.
       entropy += CalculateEntropy(subColAtts, subColLabels);
 
       i = end + 1;
       count = 0;
-
     }
     else
       i++;
@@ -234,12 +230,12 @@
   return entropy;
 }
 
-/* 
-  After having decided the attribute on which to split, 
-  train on that attribute.
-
-  @param: attribute - attribute is the attribute decided by the constructor
-                      on which we now train the decision stump.
+/**
+ * After having decided the attribute on which to split, train on that
+ * attribute.
+ *
+ * @param attribute The attribute decided by the constructor on which we now
+ *      train the decision stump.
  */
 template <typename MatType>
 template <typename rType>
@@ -255,7 +251,7 @@
 
   for (i = 0; i < attribute.n_elem; i++)
     sortedLabels(i) = classLabels(sortedSplitIndexAtt(i));
-  
+
   arma::rowvec subCols;
   rType mostFreq;
   i = 0;
@@ -280,14 +276,15 @@
 
       i++;
     }
-    else if( sortedLabels(i) != sortedLabels(i + 1) )
+    else if (sortedLabels(i) != sortedLabels(i + 1))
     {
-      if (count < bucketSize) // test for differevalues of bucketSize, especially extreme cases. 
+      if (count < bucketSize)
       {
+        // Test for different values of bucketSize, especially extreme cases.
         begin = i - count + 1;
         end = begin + bucketSize - 1;
-        
-        if ( end > sortedLabels.n_elem - 1)
+
+        if (end > sortedLabels.n_elem - 1)
           end = sortedLabels.n_elem - 1;
       }
       else
@@ -300,12 +297,11 @@
 
       subCols = sortedLabels.cols(begin, end) + zSubCols;
 
-      // finding the most freq element in subCols so as to assign a label to the
-      // bucket of subCols
-
+      // Find the most frequent element in subCols so as to assign a label to
+      // the bucket of subCols.
       mostFreq = CountMostFreq<double>(subCols);
 
-      tempSplit << sortedSplitAtt(begin)<< mostFreq << arma::endr;
+      tempSplit << sortedSplitAtt(begin) << mostFreq << arma::endr;
       split = arma::join_cols(split, tempSplit);
 
       i = end + 1;
@@ -315,26 +311,25 @@
       i++;
   }
 
-  // now trimming the split matrix so that buckets one after the after 
-  // which point to the same classLabel are merged as one big bucket.
+  // Now trim the split matrix so that consecutive buckets which point to the
+  // same class label are merged into one big bucket.
   MergeRanges();
 }
 
-/* After the "split" matrix has been set up, 
-     merging ranges with identical class labels.
+/**
+ * After the "split" matrix has been set up, merge ranges with identical class
+ * labels.
  */
 template <typename MatType>
 void DecisionStump<MatType>::MergeRanges()
 {
-  int i;
-  for (i = 1;i < split.n_rows; i++)
+  for (int i = 1; i < split.n_rows; i++)
   {
-    if (split(i,1) == split(i-1,1))
+    if (split(i, 1) == split(i - 1, 1))
     {
-      // remove this row, as it has the same label as
-      // the previous bucket.
+      // Remove this row, as it has the same label as the previous bucket.
       split.shed_row(i);
-      // go back to previous row.
+      // Go back to previous row.
       i--;
     }
   }
@@ -344,26 +339,25 @@
 template <typename rType>
 rType DecisionStump<MatType>::CountMostFreq(const arma::Row<rType>& subCols)
 {
-  // sort subCols for easier processing.
+  // Sort subCols for easier processing.
   arma::Row<rType> sortCounts = arma::sort(subCols);
   rType element;
-  int count = 0, localCount = 0,i;
+  int count = 0, localCount = 0;
 
-  // an O(n) loop which counts the most frequent element in sortCounts
-  for (i = 0; i < sortCounts.n_elem ; ++i)
+  // An O(n) loop which counts the most frequent element in sortCounts
+  for (int i = 0; i < sortCounts.n_elem; ++i)
   {
     if (i == sortCounts.n_elem - 1)
     {
-      if (sortCounts(i-1) == sortCounts(i))
+      if (sortCounts(i - 1) == sortCounts(i))
       {
-        // element = sortCounts(i-1);
+        // element = sortCounts(i - 1);
         localCount++;
       }
-      else
-      if (localCount > count)
+      else if (localCount > count)
         count = localCount;
     }
-    else if (sortCounts(i) != sortCounts(i+1))
+    else if (sortCounts(i) != sortCounts(i + 1))
     {
       localCount = 0;
       count++;
@@ -374,7 +368,7 @@
       if (localCount > count)
       {
         count = localCount;
-        if(localCount == 1)
+        if (localCount == 1)
           element = sortCounts(i);
       }
     }
@@ -382,34 +376,32 @@
   return element;
 }
 
-/* 
-  Returns 1 if all the values of featureRow are not same.
-
-  @param: featureRow - the attribute which is checked so that it 
-                       does not have identical values. 
+/**
+ * Returns 1 if the values of featureRow are not all identical, else 0.
+ *
+ * @param featureRow The attribute which is checked for identical values.
  */
 template <typename MatType>
 template <typename rType>
 int DecisionStump<MatType>::isDistinct(const arma::Row<rType>& featureRow)
 {
-  if (featureRow.max()-featureRow.min() > 0)
+  if (featureRow.max() - featureRow.min() > 0)
     return 1;
   else
     return 0;
 }
 
-/* 
-  Calculating Entropy of attribute.
-
-  @param: attribute - the attribute of which we calculate the entropy.
-  @param: labels - corresponding labels of the attribute.
+/**
+ * Calculate the (negated, <= 0) weighted entropy of the given attribute.
+ *
+ * @param attribute The attribute for which we calculate the entropy.
+ * @param labels Corresponding labels of the attribute.
  */
 template<typename MatType>
 double DecisionStump<MatType>::CalculateEntropy(const arma::rowvec& attribute,
                                                 const arma::rowvec& labels)
 {
-  int i,j;
-  double entropy=0.0;
+  double entropy = 0.0;
 
   arma::rowvec uniqueAtt = arma::unique(attribute);
   arma::rowvec uniqueLabel = arma::unique(labels);
@@ -417,12 +409,12 @@
   numElem.fill(0);
   arma::Mat<size_t> entropyArray(uniqueAtt.n_elem,numClass);
   entropyArray.fill(0);
-  
-  // populating entropyArray and numElem, they are to be used as 
-  // helpers to calculate entropy
-  for (j = 0;j < uniqueAtt.n_elem; j++)
+
+  // Populate entropyArray and numElem; they are used as helpers to calculate
+  // entropy.
+  for (int j = 0; j < uniqueAtt.n_elem; j++)
   {
-    for (i = 0; i < attribute.n_elem; i++)
+    for (int i = 0; i < attribute.n_elem; i++)
     {
       if (uniqueAtt[j] == attribute[i])
       {
@@ -432,29 +424,23 @@
     }
   }
 
-  double p1, p2, p3;
-  for ( j = 0; j < uniqueAtt.size(); j++ )
+  for (int j = 0; j < uniqueAtt.size(); j++)
   {
-    p1 = ((double)numElem(j) / attribute.n_elem);
+    const double p1 = ((double) numElem(j) / attribute.n_elem);
 
-    for ( i = 0; i < numClass; i++)
+    for (int i = 0; i < numClass; i++)
     {
-      p2 = ((double)entropyArray(j,i) / numElem(j));
-      
-      if(p2 == 0)
-        p3 = 0;
-      else
-        p3 = (  p2 * log2(p2) );
+      const double p2 = ((double) entropyArray(j, i) / numElem(j));
+      const double p3 = (p2 == 0) ? 0 : p2 * log2(p2);
 
-      entropy+=( p1 * p3 );
+      entropy += p1 * p3;
     }
   }
 
   return entropy;
 }
 
-
 }; // namespace decision_stump
 }; // namespace mlpack
 
-#endif
\ No newline at end of file
+#endif

Modified: mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_main.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_main.cpp	(original)
+++ mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_main.cpp	Mon Jun 23 20:03:24 2014
@@ -2,7 +2,7 @@
  * @author: Udit Saxena
  * @file: decision_stump_main.cpp
  *
- *
+ * Main executable for the decision stump.
  */
 
 #include <mlpack/core.hpp>
@@ -52,13 +52,13 @@
     labelsIn = labelsIn.t();
 
   size_t inpBucketSize = CLI::GetParam<int>("bucket_size");
-  
+
   // normalize the labels
   data::NormalizeLabels(labelsIn.unsafe_col(0), labels, mappings);
 
   const size_t num_classes = CLI::GetParam<size_t>("num_classes");
   /*
-  Should number of classes be input or should it be 
+  Should number of classes be input or should it be
   derived from the labels row ?
   */
   const string testingDataFilename = CLI::GetParam<std::string>("test_file");
@@ -87,4 +87,4 @@
   // saving the predictedLabels in the transposed manner in output
 
   return 0;
-}
\ No newline at end of file
+}



More information about the mlpack-svn mailing list