[mlpack-git] master: Decision Stumps modified, along with adding Classify() function to AdaBoost. Other minor changes (renaming). (067da88)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:57:59 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit 067da88391a0138f73b6241573ebd9d1ab140abe
Author: Udit Saxena <saxena.udit at gmail.com>
Date:   Sat Aug 16 16:00:06 2014 +0000

    Decision Stumps modified, along with adding Classify() function to AdaBoost. Other minor changes (renaming).


>---------------------------------------------------------------

067da88391a0138f73b6241573ebd9d1ab140abe
 src/mlpack/methods/adaboost/adaboost.hpp           | 16 +++--
 src/mlpack/methods/adaboost/adaboost_impl.hpp      | 71 +++++++++++++++----
 src/mlpack/methods/adaboost/adaboost_main.cpp      | 21 ++++--
 .../methods/decision_stump/decision_stump.hpp      | 19 +++--
 .../methods/decision_stump/decision_stump_impl.hpp | 82 ++++++++++++++--------
 .../methods/decision_stump/decision_stump_main.cpp |  3 +-
 src/mlpack/tests/adaboost_test.cpp                 | 61 ++++++++--------
 src/mlpack/tests/decision_stump_test.cpp           |  0
 8 files changed, 188 insertions(+), 85 deletions(-)

diff --git a/src/mlpack/methods/adaboost/adaboost.hpp b/src/mlpack/methods/adaboost/adaboost.hpp
index b013355..58ee336 100644
--- a/src/mlpack/methods/adaboost/adaboost.hpp
+++ b/src/mlpack/methods/adaboost/adaboost.hpp
@@ -32,11 +32,11 @@ namespace adaboost {
 
 template<typename MatType = arma::mat,
          typename WeakLearner = mlpack::perceptron::Perceptron<> >
-class Adaboost
+class AdaBoost
 {
  public:
   /**
-   * Constructor. Currently runs the Adaboost.mh algorithm.
+   * Constructor. Currently runs the AdaBoost.mh algorithm.
    *
    * @param data Input data.
    * @param labels Corresponding labels.
@@ -44,7 +44,7 @@ class Adaboost
    * @param tol The tolerance for change in values of rt.
    * @param other Weak Learner, which has been initialized already.
    */
-  Adaboost(const MatType& data,
+  AdaBoost(const MatType& data,
            const arma::Row<size_t>& labels,
            const int iterations,
            const double tol,
@@ -59,6 +59,8 @@ class Adaboost
   // The tolerance for change in rt and when to stop.
   double tolerance;
 
+  void Classify(const MatType& test, arma::Row<size_t>& predictedLabels);
+
 private:
   /**
    *  This function helps in building the Weight Distribution matrix
@@ -72,8 +74,14 @@ private:
    */
   void BuildWeightMatrix(const arma::mat& D, arma::rowvec& weights);
 
+  size_t numClasses;
+  
+  std::vector<WeakLearner> wl;
+  std::vector<double> alpha;
+  std::vector<double> z;
+
   
-}; // class Adaboost
+}; // class AdaBoost
 
 } // namespace adaboost
 } // namespace mlpack
diff --git a/src/mlpack/methods/adaboost/adaboost_impl.hpp b/src/mlpack/methods/adaboost/adaboost_impl.hpp
index 70cc9a7..187ce2f 100644
--- a/src/mlpack/methods/adaboost/adaboost_impl.hpp
+++ b/src/mlpack/methods/adaboost/adaboost_impl.hpp
@@ -2,7 +2,7 @@
  * @file adaboost_impl.hpp
  * @author Udit Saxena
  *
- * Implementation of the Adaboost class.
+ * Implementation of the AdaBoost class.
  *
  * @code
  * @article{schapire1999improved,
@@ -27,7 +27,7 @@ namespace mlpack {
 namespace adaboost {
 
 /**
- *  Constructor. Currently runs the Adaboost.mh algorithm
+ *  Constructor. Currently runs the AdaBoost.mh algorithm
  *
  *  @param data Input data
  *  @param labels Corresponding labels
@@ -35,7 +35,7 @@ namespace adaboost {
  *  @param other Weak Learner, which has been initialized already
  */
 template<typename MatType, typename WeakLearner>
-Adaboost<MatType, WeakLearner>::Adaboost(
+AdaBoost<MatType, WeakLearner>::AdaBoost(
     const MatType& data,
     const arma::Row<size_t>& labels,
     const int iterations,
@@ -43,7 +43,7 @@ Adaboost<MatType, WeakLearner>::Adaboost(
     const WeakLearner& other)
 {
   // Count the number of classes.
-  const size_t numClasses = (arma::max(labels) - arma::min(labels)) + 1;
+  numClasses = (arma::max(labels) - arma::min(labels)) + 1;
   tolerance = tol;
 
   double rt, crt, alphat = 0.0, zt;
@@ -97,6 +97,7 @@ Adaboost<MatType, WeakLearner>::Adaboost(
     // Build the weight vectors
     BuildWeightMatrix(D, weights);
 
+    // std::cout<<"Just about to call the weak leaerner. \n";
     // call the other weak learner and train the labels.
     WeakLearner w(other, tempData, weights, labels);
     w.Classify(tempData, predictedLabels);
@@ -110,14 +111,16 @@ Adaboost<MatType, WeakLearner>::Adaboost(
     {
       if (predictedLabels(j) == labels(j))
       {
-        for (int k = 0;k < numClasses; k++) 
-          rt += D(j,k);
+        // for (int k = 0;k < numClasses; k++) 
+        //   rt += D(j,k);
+        rt += arma::accu(D.row(j));
       }
 
       else
       {
-        for (int k = 0;k < numClasses; k++)
-          rt -= D(j,k); 
+        // for (int k = 0;k < numClasses; k++)
+        //   rt -= D(j,k); 
+        rt -= arma::accu(D.row(j));
       }
     }
     // end calculation of rt
@@ -136,6 +139,9 @@ Adaboost<MatType, WeakLearner>::Adaboost(
     alphat = 0.5 * log((1 + rt) / (1 - rt));
     // end calculation of alphat
 
+    alpha.push_back(alphat);
+    wl.push_back(w);
+
     // now start modifying weights
     for (int j = 0;j < D.n_rows; j++)
     {
@@ -178,17 +184,20 @@ Adaboost<MatType, WeakLearner>::Adaboost(
 
     // Accumulating the value of zt for the Hamming Loss bound.
     ztAccumulator *= zt;
+    z.push_back(zt);
   }
 
   // Iterations are over, now build a strong hypothesis
   // from a weighted combination of these weak hypotheses.
 
-  arma::rowvec tempSumFinalH;
+  // std::cout<<"Just about to look at the final hypo.\n";
+  arma::colvec tempSumFinalH;
   arma::uword max_index;
+  arma::mat sfh = sumFinalH.t();
   
-  for (int i = 0;i < sumFinalH.n_rows; i++)
+  for (int i = 0;i < sfh.n_cols; i++)
   {
-    tempSumFinalH = sumFinalH.row(i);
+    tempSumFinalH = sfh.col(i);
     tempSumFinalH.max(max_index);
     finalH(i) = max_index;
   }
@@ -196,6 +205,44 @@ Adaboost<MatType, WeakLearner>::Adaboost(
 }
 
 /**
+ *
+ */
+ template <typename MatType, typename WeakLearner>
+ void AdaBoost<MatType, WeakLearner>::Classify(
+                                      const MatType& test, 
+                                      arma::Row<size_t>& predictedLabels
+                                      )
+ {
+  arma::Row<size_t> tempPredictedLabels(predictedLabels.n_cols);
+  arma::mat cMatrix(test.n_cols, numClasses);
+
+  cMatrix.fill(0.0);
+  predictedLabels.fill(0);
+
+  for(int i = 0;i < wl.size();i++)
+  {
+    wl[i].Classify(test,tempPredictedLabels);
+
+    for(int j = 0;j < tempPredictedLabels.n_cols; j++)
+    {
+      cMatrix(j,tempPredictedLabels(j)) += (alpha[i] * tempPredictedLabels(j)); 
+    }
+    
+  }
+
+  arma::rowvec cMRow;
+  arma::uword max_index;
+
+  for(int i = 0;i < predictedLabels.n_cols;i++)
+  {
+    cMRow = cMatrix.row(i);
+    cMRow.max(max_index);
+    predictedLabels(i) = max_index;
+  }
+
+ }
+
+/**
  *  This function helps in building the Weight Distribution matrix
  *  which is updated during every iteration. It calculates the
  *  "difficulty" in classifying a point by adding the weights for all
@@ -206,7 +253,7 @@ Adaboost<MatType, WeakLearner>::Adaboost(
  *  @param weights The output weight vector.
  */
 template <typename MatType, typename WeakLearner>
-void Adaboost<MatType, WeakLearner>::BuildWeightMatrix(
+void AdaBoost<MatType, WeakLearner>::BuildWeightMatrix(
     const arma::mat& D,
     arma::rowvec& weights)
 {
diff --git a/src/mlpack/methods/adaboost/adaboost_main.cpp b/src/mlpack/methods/adaboost/adaboost_main.cpp
index d6a1c12..ecc1cbc 100644
--- a/src/mlpack/methods/adaboost/adaboost_main.cpp
+++ b/src/mlpack/methods/adaboost/adaboost_main.cpp
@@ -2,7 +2,7 @@
  * @file: adaboost_main.cpp
  * @author: Udit Saxena
  *
- * Implementation of the Adaboost main file
+ * Implementation of the AdaBoost main file
  *
  *  @code
  *  @article{Schapire:1999:IBA:337859.337870,
@@ -37,8 +37,8 @@ using namespace std;
 using namespace arma;
 using namespace mlpack::adaboost;
 
-PROGRAM_INFO("Adaboost","This program implements the Adaboost (or Adaptive Boost)"
- " algorithm. The variant of Adaboost implemented here is Adaboost.mh. It uses a"
+PROGRAM_INFO("AdaBoost","This program implements the AdaBoost (or Adaptive Boost)"
+ " algorithm. The variant of AdaBoost implemented here is AdaBoost.mh. It uses a"
  " weak learner, either of Decision Stumps or a Perceptron, and over many"
  " iterations, creates a strong learner. It runs these iterations till a tolerance"
  " value is crossed for change in the value of rt."
@@ -64,7 +64,7 @@ PARAM_STRING("output", "The file in which the predicted labels for the test set"
     " will be written.", "o", "output.csv");
 PARAM_INT("iterations","The maximum number of boosting iterations "
   "to be run", "i", 1000);
-PARAM_INT_REQ("classes","The number of classes in the input label set.","c");
+// PARAM_INT("classes","The number of classes in the input label set.","c");
 PARAM_DOUBLE("tolerance","The tolerance for change in values of rt","e",1e-10);
 
 int main(int argc, char *argv[])
@@ -129,8 +129,19 @@ int main(int argc, char *argv[])
   perceptron::Perceptron<> p(trainingData, labels.t(), iter);
   
   Timer::Start("Training");
-  Adaboost<> a(trainingData, labels.t(), iterations, tolerance, p);
+  AdaBoost<> a(trainingData, labels.t(), iterations, tolerance, p);
   Timer::Stop("Training");
 
+  Row<size_t> predictedLabels(testingData.n_cols);
+  Timer::Start("testing");
+  a.Classify(testingData, predictedLabels);
+  Timer::Stop("testing");
+
+  vec results;
+  data::RevertLabels(predictedLabels.t(), mappings, results);
+
+  // Save the predicted labels in a transposed form as output.
+  const string outputFilename = CLI::GetParam<string>("output_file");
+  data::Save(outputFilename, results, true, false);
   return 0;
 }
\ No newline at end of file
diff --git a/src/mlpack/methods/decision_stump/decision_stump.hpp b/src/mlpack/methods/decision_stump/decision_stump.hpp
index 895fd4a..de1418a 100644
--- a/src/mlpack/methods/decision_stump/decision_stump.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump.hpp
@@ -63,11 +63,13 @@ class DecisionStump
    * @param data The data on which to train this object on.
    * @param D Weight vector to use while training. For boosting purposes.
    * @param labels The labels of data.
+   * @param isWeight Whether we need to run a weighted Decision Stump.
    */
   DecisionStump(const DecisionStump<>& other,
                 const MatType& data,
                 const arma::rowvec& weights,
-                const arma::Row<size_t>& labels);
+                const arma::Row<size_t>& labels
+                );
 
   //! Access the splitting attribute.
   int SplitAttribute() const { return splitAttribute; }
@@ -106,9 +108,12 @@ class DecisionStump
    *
    * @param attribute A row from the training data, which might be a
    *     candidate for the splitting attribute.
+   * @param isWeight Whether we need to run a weighted Decision Stump.
    */
+  template <typename W>
   double SetupSplitAttribute(const arma::rowvec& attribute,
-                             const arma::Row<size_t>& labels);
+                             const arma::Row<size_t>& labels,
+                             W isWeight);
 
   /**
    * After having decided the attribute on which to split, train on that
@@ -147,17 +152,21 @@ class DecisionStump
    *
    * @param attribute The attribute of which we calculate the entropy.
    * @param labels Corresponding labels of the attribute.
+   * @param isWeight Whether we need to run a weighted Decision Stump.
    */
-  template <typename LabelType>
-  double CalculateEntropy(arma::subview_row<LabelType> labels, int begin);
+  template <typename LabelType, typename W>
+  double CalculateEntropy(arma::subview_row<LabelType> labels, int begin,
+                          W isWeight);
 
   /**
    * Train the decision stump on the given data and labels.
    *
    * @param data Dataset to train on.
    * @param labels Labels for dataset.
+   * @param isWeight Whether we need to run a weighted Decision Stump.
    */
-  void Train(const MatType& data, const arma::Row<size_t>& labels);
+  template <typename W>
+  void Train(const MatType& data, const arma::Row<size_t>& labels, W isWeight);
 
   //! To store the weight vectors for boosting purposes.
   arma::rowvec weightD;
diff --git a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
index 348ab9a..e3b5824 100644
--- a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
@@ -30,22 +30,27 @@ DecisionStump<MatType>::DecisionStump(const MatType& data,
 {
   numClass = classes;
   bucketSize = inpBucketSize;
+  const bool isWeight = false;
 
-  weightD = arma::rowvec(data.n_cols);
-  weightD.fill(1.0);
-  tempD = weightD;
-
-  Train(data, labels);
+  Train<bool>(data, labels, isWeight);
 }
 
+/**
+ * Train the decision stump on the given data and labels.
+ *
+ * @param data Dataset to train on.
+ * @param labels Labels for dataset.
+ * @param isWeight Whether we need to run a weighted Decision Stump.
+ */
 template<typename MatType>
-void DecisionStump<MatType>::Train(const MatType& data, const arma::Row<size_t>& labels)
+template <typename W>
+void DecisionStump<MatType>::Train(const MatType& data, const arma::Row<size_t>& labels, W isWeight)
 {
   // If classLabels are not all identical, proceed with training.
   int bestAtt = 0;
   double entropy;
-  const double rootEntropy = CalculateEntropy<size_t>(
-      labels.subvec(0, labels.n_elem - 1), 0);
+  const double rootEntropy = CalculateEntropy<size_t, W>(
+      labels.subvec(0, labels.n_elem - 1), 0, isWeight);
 
   double gain, bestGain = 0.0;
   for (int i = 0; i < data.n_rows; i++)
@@ -55,7 +60,7 @@ void DecisionStump<MatType>::Train(const MatType& data, const arma::Row<size_t>&
     {
       // For each attribute with non-identical values, treat it as a potential
       // splitting attribute and calculate entropy if split on it.
-      entropy = SetupSplitAttribute(data.row(i), labels);
+      entropy = SetupSplitAttribute<W>(data.row(i), labels, isWeight);
 
       gain = rootEntropy - entropy;
       // Find the attribute with the best entropy so that the gain is
@@ -119,6 +124,7 @@ void DecisionStump<MatType>::Classify(const MatType& test,
  * @param data The data on which to train this object on.
  * @param D Weight vector to use while training. For boosting purposes.
  * @param labels The labels of data.
+ * @param isWeight Whether we need to run a weighted Decision Stump.
  */
 template <typename MatType>
 DecisionStump<MatType>::DecisionStump(
@@ -133,8 +139,8 @@ DecisionStump<MatType>::DecisionStump(
 
   weightD = weights;
   tempD = weightD;
-
-  Train(data, labels);
+  const bool isWeight = true;
+  Train<bool>(data, labels, isWeight);
 }
 
 /**
@@ -143,11 +149,14 @@ DecisionStump<MatType>::DecisionStump(
  *
  * @param attribute A row from the training data, which might be a candidate for
  *      the splitting attribute.
+ * @param isWeight Whether we need to run a weighted Decision Stump.
  */
 template <typename MatType>
+template <typename W>
 double DecisionStump<MatType>::SetupSplitAttribute(
     const arma::rowvec& attribute,
-    const arma::Row<size_t>& labels)
+    const arma::Row<size_t>& labels,
+    W isWeight)
 {
   int i, count, begin, end;
   double entropy = 0.0;
@@ -167,7 +176,9 @@ double DecisionStump<MatType>::SetupSplitAttribute(
   for (i = 0; i < attribute.n_elem; i++)
   {
     sortedLabels(i) = labels(sortedIndexAtt(i));
-    tempD(i) = weightD(sortedIndexAtt(i));
+    
+    if(isWeight)
+      tempD(i) = weightD(sortedIndexAtt(i));
   }
 
   i = 0;
@@ -188,8 +199,8 @@ double DecisionStump<MatType>::SetupSplitAttribute(
       // Use ratioEl to calculate the ratio of elements in this split.
       const double ratioEl = ((double) (end - begin + 1) / sortedLabels.n_elem);
 
-      entropy += ratioEl * CalculateEntropy<size_t>(
-          sortedLabels.subvec(begin, end), begin);
+      entropy += ratioEl * CalculateEntropy<size_t, W>(
+          sortedLabels.subvec(begin, end), begin, isWeight);
       i++;
     }
     else if (sortedLabels(i) != sortedLabels(i + 1))
@@ -215,8 +226,8 @@ double DecisionStump<MatType>::SetupSplitAttribute(
       }
       const double ratioEl = ((double) (end - begin + 1) / sortedLabels.n_elem);
 
-      entropy += ratioEl * CalculateEntropy<size_t>(
-          sortedLabels.subvec(begin, end), begin);
+      entropy += ratioEl * CalculateEntropy<size_t, W>(
+          sortedLabels.subvec(begin, end), begin, isWeight);
 
       i = end + 1;
       count = 0;
@@ -404,12 +415,13 @@ int DecisionStump<MatType>::IsDistinct(const arma::Row<rType>& featureRow)
  *
  * @param attribute The attribute for which we calculate the entropy.
  * @param labels Corresponding labels of the attribute.
+ * @param isWeight Whether we need to run a weighted Decision Stump.
  */
 template<typename MatType>
-template<typename LabelType>
+template<typename LabelType, typename W>
 double DecisionStump<MatType>::CalculateEntropy(
     arma::subview_row<LabelType> labels,
-    int begin)
+    int begin, W isWeight)
 {
   double entropy = 0.0;
   size_t j;
@@ -421,20 +433,34 @@ double DecisionStump<MatType>::CalculateEntropy(
   double accWeight = 0.0;
   // Populate numElem; they are used as helpers to calculate entropy.
 
-  for (j = 0; j < labels.n_elem; j++)
+  if(isWeight)
   {
-    numElem(labels(j)) += tempD(j + begin);
-    accWeight += tempD(j + begin);
-  }
-    // numElem(labels(j))++;
+    for (j = 0; j < labels.n_elem; j++)
+    {
+      numElem(labels(j)) += tempD(j + begin);
+      accWeight += tempD(j + begin);
+    }
+      // numElem(labels(j))++;
 
-  for (j = 0; j < numClass; j++)
-  {
-    const double p1 = ((double) numElem(j) / accWeight);
+    for (j = 0; j < numClass; j++)
+    {
+      const double p1 = ((double) numElem(j) / accWeight);
 
-    entropy += (p1 == 0) ? 0 : p1 * log2(p1);
+      entropy += (p1 == 0) ? 0 : p1 * log2(p1);
+    }
   }
+  else
+  {
+    for (j = 0; j < labels.n_elem; j++)
+      numElem(labels(j))++;
 
+    for (j = 0; j < numClass; j++)
+    {
+      const double p1 = ((double) numElem(j) / labels.n_elem);
+
+      entropy += (p1 == 0) ? 0 : p1 * log2(p1);
+    }
+  }
   return entropy;
 }
 
diff --git a/src/mlpack/methods/decision_stump/decision_stump_main.cpp b/src/mlpack/methods/decision_stump/decision_stump_main.cpp
index f6d6053..48ad4e3 100644
--- a/src/mlpack/methods/decision_stump/decision_stump_main.cpp
+++ b/src/mlpack/methods/decision_stump/decision_stump_main.cpp
@@ -95,7 +95,8 @@ int main(int argc, char *argv[])
         << ")!" << std::endl;
 
   Timer::Start("training");
-  DecisionStump<> ds(trainingData, labels.t(), numClasses, inpBucketSize);
+  DecisionStump<> ds(trainingData, labels.t(), numClasses,
+                     inpBucketSize);
   Timer::Stop("training");
 
   Row<size_t> predictedLabels(testingData.n_cols);
diff --git a/src/mlpack/tests/adaboost_test.cpp b/src/mlpack/tests/adaboost_test.cpp
index 704f3d0..3abba11 100644
--- a/src/mlpack/tests/adaboost_test.cpp
+++ b/src/mlpack/tests/adaboost_test.cpp
@@ -1,8 +1,8 @@
 /**
- * @file Adaboost_test.cpp
+ * @file AdaBoost_test.cpp
  * @author Udit Saxena
  *
- * Tests for Adaboost class.
+ * Tests for AdaBoost class.
  */
 
 #include <mlpack/core.hpp>
@@ -15,10 +15,10 @@ using namespace mlpack;
 using namespace arma;
 using namespace mlpack::adaboost;
 
-BOOST_AUTO_TEST_SUITE(AdaboostTest);
+BOOST_AUTO_TEST_SUITE(AdaBoostTest);
 
 /**
- *  This test case runs the Adaboost.mh algorithm on the UCI Iris dataset.
+ *  This test case runs the AdaBoost.mh algorithm on the UCI Iris dataset.
  *  It checks whether the hamming loss breaches the upperbound, which
  *  is provided by ztAccumulator.
  */
@@ -45,7 +45,8 @@ BOOST_AUTO_TEST_CASE(HammingLossBoundIris)
   // Define parameters for the adaboost
   int iterations = 100;
   double tolerance = 1e-10;
-  Adaboost<> a(inputData, labels.row(0), iterations, tolerance, p);
+  AdaBoost<> a(inputData, labels.row(0), iterations, tolerance, p);
+
   int countError = 0;
   for (size_t i = 0; i < labels.n_cols; i++)
     if(labels(i) != a.finalHypothesis(i))
@@ -56,7 +57,7 @@ BOOST_AUTO_TEST_CASE(HammingLossBoundIris)
 }
 
 /**
- *  This test case runs the Adaboost.mh algorithm on the UCI Iris dataset.
+ *  This test case runs the AdaBoost.mh algorithm on the UCI Iris dataset.
  *  It checks if the error returned by running a single instance of the 
  *  weak learner is worse than running the boosted weak learner using 
  *  adaboost.
@@ -92,7 +93,7 @@ BOOST_AUTO_TEST_CASE(WeakLearnerErrorIris)
   // Define parameters for the adaboost
   int iterations = 100;
   double tolerance = 1e-10;
-  Adaboost<> a(inputData, labels.row(0), iterations, tolerance, p);
+  AdaBoost<> a(inputData, labels.row(0), iterations, tolerance, p);
   int countError = 0;
   for (size_t i = 0; i < labels.n_cols; i++)
     if(labels(i) != a.finalHypothesis(i))
@@ -103,7 +104,7 @@ BOOST_AUTO_TEST_CASE(WeakLearnerErrorIris)
 }
 
 /**
- *  This test case runs the Adaboost.mh algorithm on the UCI Vertebral 
+ *  This test case runs the AdaBoost.mh algorithm on the UCI Vertebral 
  *  Column dataset.
  *  It checks whether the hamming loss breaches the upperbound, which
  *  is provided by ztAccumulator.
@@ -131,7 +132,7 @@ BOOST_AUTO_TEST_CASE(HammingLossBoundVertebralColumn)
   // Define parameters for the adaboost
   int iterations = 50;
   double tolerance = 1e-10;
-  Adaboost<> a(inputData, labels.row(0), iterations, tolerance, p);
+  AdaBoost<> a(inputData, labels.row(0), iterations, tolerance, p);
   int countError = 0;
   for (size_t i = 0; i < labels.n_cols; i++)
     if(labels(i) != a.finalHypothesis(i))
@@ -142,7 +143,7 @@ BOOST_AUTO_TEST_CASE(HammingLossBoundVertebralColumn)
 }
 
 /**
- *  This test case runs the Adaboost.mh algorithm on the UCI Vertebral 
+ *  This test case runs the AdaBoost.mh algorithm on the UCI Vertebral 
  *  Column dataset.
  *  It checks if the error returned by running a single instance of the 
  *  weak learner is worse than running the boosted weak learner using 
@@ -179,7 +180,7 @@ BOOST_AUTO_TEST_CASE(WeakLearnerErrorVertebralColumn)
   // Define parameters for the adaboost
   int iterations = 50;
   double tolerance = 1e-10;
-  Adaboost<> a(inputData, labels.row(0), iterations, tolerance, p);
+  AdaBoost<> a(inputData, labels.row(0), iterations, tolerance, p);
   int countError = 0;
   for (size_t i = 0; i < labels.n_cols; i++)
     if(labels(i) != a.finalHypothesis(i))
@@ -190,7 +191,7 @@ BOOST_AUTO_TEST_CASE(WeakLearnerErrorVertebralColumn)
 }
 
 /**
- *  This test case runs the Adaboost.mh algorithm on non-linearly 
+ *  This test case runs the AdaBoost.mh algorithm on non-linearly 
  *  separable dataset. 
  *  It checks whether the hamming loss breaches the upperbound, which
  *  is provided by ztAccumulator.
@@ -218,7 +219,7 @@ BOOST_AUTO_TEST_CASE(HammingLossBoundNonLinearSepData)
   // Define parameters for the adaboost
   int iterations = 50;
   double tolerance = 1e-10;
-  Adaboost<> a(inputData, labels.row(0), iterations, tolerance, p);
+  AdaBoost<> a(inputData, labels.row(0), iterations, tolerance, p);
   int countError = 0;
   for (size_t i = 0; i < labels.n_cols; i++)
     if(labels(i) != a.finalHypothesis(i))
@@ -229,7 +230,7 @@ BOOST_AUTO_TEST_CASE(HammingLossBoundNonLinearSepData)
 }
 
 /**
- *  This test case runs the Adaboost.mh algorithm on a non-linearly 
+ *  This test case runs the AdaBoost.mh algorithm on a non-linearly 
  *  separable dataset. 
  *  It checks if the error returned by running a single instance of the 
  *  weak learner is worse than running the boosted weak learner using 
@@ -266,7 +267,7 @@ BOOST_AUTO_TEST_CASE(WeakLearnerErrorNonLinearSepData)
   // Define parameters for the adaboost
   int iterations = 50;
   double tolerance = 1e-10;
-  Adaboost<> a(inputData, labels.row(0), iterations, tolerance, p);
+  AdaBoost<> a(inputData, labels.row(0), iterations, tolerance, p);
   int countError = 0;
   for (size_t i = 0; i < labels.n_cols; i++)
     if(labels(i) != a.finalHypothesis(i))
@@ -277,7 +278,7 @@ BOOST_AUTO_TEST_CASE(WeakLearnerErrorNonLinearSepData)
 }
 
 /**
- *  This test case runs the Adaboost.mh algorithm on the UCI Iris dataset.
+ *  This test case runs the AdaBoost.mh algorithm on the UCI Iris dataset.
  *  It checks whether the hamming loss breaches the upperbound, which
  *  is provided by ztAccumulator.
  *  This is for the weak learner: Decision Stumps.
@@ -307,7 +308,7 @@ BOOST_AUTO_TEST_CASE(HammingLossIris_DS)
   int iterations = 50;
   double tolerance = 1e-10;
   
-  Adaboost<arma::mat, mlpack::decision_stump::DecisionStump<> > a(inputData, 
+  AdaBoost<arma::mat, mlpack::decision_stump::DecisionStump<> > a(inputData, 
           labels.row(0), iterations, tolerance, ds);
   int countError = 0;
   for (size_t i = 0; i < labels.n_cols; i++)
@@ -319,7 +320,7 @@ BOOST_AUTO_TEST_CASE(HammingLossIris_DS)
 }
 
 /**
- *  This test case runs the Adaboost.mh algorithm on a non-linearly 
+ *  This test case runs the AdaBoost.mh algorithm on a non-linearly 
  *  separable dataset. 
  *  It checks if the error returned by running a single instance of the 
  *  weak learner is worse than running the boosted weak learner using 
@@ -360,7 +361,7 @@ BOOST_AUTO_TEST_CASE(WeakLearnerErrorIris_DS)
   int iterations = 50;
   double tolerance = 1e-10;
   
-  Adaboost<arma::mat, mlpack::decision_stump::DecisionStump<> > a(inputData, 
+  AdaBoost<arma::mat, mlpack::decision_stump::DecisionStump<> > a(inputData, 
            labels.row(0), iterations, tolerance, ds);
   int countError = 0;
   for (size_t i = 0; i < labels.n_cols; i++)
@@ -371,7 +372,7 @@ BOOST_AUTO_TEST_CASE(WeakLearnerErrorIris_DS)
   BOOST_REQUIRE(error <= weakLearnerErrorRate);
 }
 /**
- *  This test case runs the Adaboost.mh algorithm on the UCI Vertebral 
+ *  This test case runs the AdaBoost.mh algorithm on the UCI Vertebral 
  *  Column dataset.
  *  It checks if the error returned by running a single instance of the 
  *  weak learner is worse than running the boosted weak learner using 
@@ -404,7 +405,7 @@ BOOST_AUTO_TEST_CASE(HammingLossBoundVertebralColumn_DS)
   int iterations = 50;
   double tolerance = 1e-10;
   
-  Adaboost<arma::mat, mlpack::decision_stump::DecisionStump<> > a(inputData,
+  AdaBoost<arma::mat, mlpack::decision_stump::DecisionStump<> > a(inputData,
            labels.row(0), iterations, tolerance, ds);
   int countError = 0;
   for (size_t i = 0; i < labels.n_cols; i++)
@@ -416,7 +417,7 @@ BOOST_AUTO_TEST_CASE(HammingLossBoundVertebralColumn_DS)
 }
 
 /**
- *  This test case runs the Adaboost.mh algorithm on the UCI Vertebral 
+ *  This test case runs the AdaBoost.mh algorithm on the UCI Vertebral 
  *  Column dataset.
  *  It checks if the error returned by running a single instance of the 
  *  weak learner is worse than running the boosted weak learner using 
@@ -456,7 +457,7 @@ BOOST_AUTO_TEST_CASE(WeakLearnerErrorVertebralColumn_DS)
   // Define parameters for the adaboost
   int iterations = 50;
   double tolerance = 1e-10;
-  Adaboost<arma::mat, mlpack::decision_stump::DecisionStump<> > a(inputData, 
+  AdaBoost<arma::mat, mlpack::decision_stump::DecisionStump<> > a(inputData, 
            labels.row(0), iterations, tolerance, ds);
   int countError = 0;
   for (size_t i = 0; i < labels.n_cols; i++)
@@ -467,7 +468,7 @@ BOOST_AUTO_TEST_CASE(WeakLearnerErrorVertebralColumn_DS)
   BOOST_REQUIRE(error <= weakLearnerErrorRate);
 }
 /**
- *  This test case runs the Adaboost.mh algorithm on non-linearly 
+ *  This test case runs the AdaBoost.mh algorithm on non-linearly 
  *  separable dataset. 
  *  It checks whether the hamming loss breaches the upperbound, which
  *  is provided by ztAccumulator.
@@ -499,7 +500,7 @@ BOOST_AUTO_TEST_CASE(HammingLossBoundNonLinearSepData_DS)
   int iterations = 50;
   double tolerance = 1e-10;
   
-  Adaboost<arma::mat, mlpack::decision_stump::DecisionStump<> > a(inputData, 
+  AdaBoost<arma::mat, mlpack::decision_stump::DecisionStump<> > a(inputData, 
            labels.row(0), iterations, tolerance, ds);
   int countError = 0;
   for (size_t i = 0; i < labels.n_cols; i++)
@@ -511,7 +512,7 @@ BOOST_AUTO_TEST_CASE(HammingLossBoundNonLinearSepData_DS)
 }
 
 /**
- *  This test case runs the Adaboost.mh algorithm on a non-linearly 
+ *  This test case runs the AdaBoost.mh algorithm on a non-linearly 
  *  separable dataset. 
  *  It checks if the error returned by running a single instance of the 
  *  weak learner is worse than running the boosted weak learner using 
@@ -535,7 +536,7 @@ BOOST_AUTO_TEST_CASE(WeakLearnerErrorNonLinearSepData_DS)
   // Define your own weak learner, Decision Stump in this case.
 
   const size_t numClasses = 2;
-  const size_t inpBucketSize = 6;
+  const size_t inpBucketSize = 3;
 
   arma::Row<size_t> dsPrediction(labels.n_cols);
 
@@ -549,10 +550,10 @@ BOOST_AUTO_TEST_CASE(WeakLearnerErrorNonLinearSepData_DS)
   double weakLearnerErrorRate = (double) countWeakLearnerError / labels.n_cols;
   
   // Define parameters for the adaboost
-  int iterations = 50;
-  double tolerance = 1e-10;
+  int iterations = 500;
+  double tolerance = 1e-23;
   
-  Adaboost<arma::mat, mlpack::decision_stump::DecisionStump<> > a(inputData, 
+  AdaBoost<arma::mat, mlpack::decision_stump::DecisionStump<> > a(inputData, 
            labels.row(0), iterations, tolerance, ds);
   int countError = 0;
   for (size_t i = 0; i < labels.n_cols; i++)



More information about the mlpack-git mailing list