[mlpack-git] master: Add Train() and a test for it. (da1c9a6)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Mon Nov 30 17:24:40 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/10b9d45b806a3e879b0564d78ccb183ebc7051ba...31c557d9cc7e4da57fd8a246085c19e076d12271

>---------------------------------------------------------------

commit da1c9a6db792e55ea18549f073c8cd0ea2f61651
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Nov 30 21:33:07 2015 +0000

    Add Train() and a test for it.


>---------------------------------------------------------------

da1c9a6db792e55ea18549f073c8cd0ea2f61651
 src/mlpack/methods/adaboost/adaboost.hpp      | 17 ++++++++
 src/mlpack/methods/adaboost/adaboost_impl.hpp | 18 ++++++++-
 src/mlpack/tests/adaboost_test.cpp            | 58 +++++++++++++++++++++++++++
 3 files changed, 92 insertions(+), 1 deletion(-)

diff --git a/src/mlpack/methods/adaboost/adaboost.hpp b/src/mlpack/methods/adaboost/adaboost.hpp
index 924e2e2..5ab00d6 100644
--- a/src/mlpack/methods/adaboost/adaboost.hpp
+++ b/src/mlpack/methods/adaboost/adaboost.hpp
@@ -104,6 +104,23 @@ class AdaBoost
   double tolerance;
 
   /**
+   * Train AdaBoost on the given dataset.  This method takes an initialized
+   * WeakLearner; the parameters for this weak learner will be used to train
+   * each of the weak learners during AdaBoost training.  Note that this will
+   * completely overwrite any model that has already been trained with this
+   * object.
+   *
+   * @param data Dataset to train on.
+   * @param labels Labels for each point in the dataset.
+   * @param learner Learner to use for training.
+   */
+  void Train(const MatType& data,
+             const arma::Row<size_t>& labels,
+             const WeakLearner& learner,
+             const size_t iterations = 100,
+             const double tolerance = 1e-6);
+
+  /**
    * Classify the given test points.
    *
    * @param test Testing data.
diff --git a/src/mlpack/methods/adaboost/adaboost_impl.hpp b/src/mlpack/methods/adaboost/adaboost_impl.hpp
index 3fc8a21..4d2f8fc 100644
--- a/src/mlpack/methods/adaboost/adaboost_impl.hpp
+++ b/src/mlpack/methods/adaboost/adaboost_impl.hpp
@@ -43,9 +43,25 @@ AdaBoost<MatType, WeakLearner>::AdaBoost(
     const size_t iterations,
     const double tol)
 {
+  Train(data, labels, other, iterations, tol);
+}
+
+// Train AdaBoost.
+template<typename MatType, typename WeakLearner>
+void AdaBoost<MatType, WeakLearner>::Train(
+    const MatType& data,
+    const arma::Row<size_t>& labels,
+    const WeakLearner& other,
+    const size_t iterations,
+    const double tolerance)
+{
+  // Clear information from previous runs.
+  wl.clear();
+  alpha.clear();
+
   // Count the number of classes.
   classes = (arma::max(labels) - arma::min(labels)) + 1;
-  tolerance = tol;
+  this->tolerance = tolerance;
 
   // crt is the cumulative rt value for terminating the optimization when rt is
   // changing by less than the tolerance.
diff --git a/src/mlpack/tests/adaboost_test.cpp b/src/mlpack/tests/adaboost_test.cpp
index fa3999c..b8db20c 100644
--- a/src/mlpack/tests/adaboost_test.cpp
+++ b/src/mlpack/tests/adaboost_test.cpp
@@ -664,4 +664,62 @@ BOOST_AUTO_TEST_CASE(ClassifyTest_IRIS)
   BOOST_REQUIRE(lError <= 0.30);
 }
 
+/**
+ * Ensure that the Train() function works like it is supposed to, by building
+ * AdaBoost on one dataset and then re-training on another dataset.
+ */
+BOOST_AUTO_TEST_CASE(TrainTest)
+{
+  // First train on the iris dataset.
+  arma::mat inputData;
+  if (!data::Load("iris_train.csv", inputData))
+    BOOST_FAIL("Cannot load test dataset iris_train.csv!");
+
+  arma::Mat<size_t> labels;
+  if (!data::Load("iris_train_labels.csv",labels))
+    BOOST_FAIL("Cannot load labels for iris_train_labels.csv");
+
+  size_t perceptronIter = 800;
+  perceptron::Perceptron<> p(inputData, labels.row(0), max(labels.row(0)) + 1,
+      perceptronIter);
+
+  // Now train AdaBoost.
+  size_t iterations = 50;
+  double tolerance = 1e-10;
+  AdaBoost<> a(inputData, labels.row(0), p, iterations, tolerance);
+
+  // Now load another dataset...
+  if (!data::Load("vc2.txt", inputData))
+    BOOST_FAIL("Cannot load test dataset vc2.txt!");
+  if (!data::Load("vc2_labels.txt",labels))
+    BOOST_FAIL("Cannot load labels for vc2_labels.txt");
+
+  perceptron::Perceptron<> p2(inputData, labels.row(0), max(labels.row(0)) + 1,
+      perceptronIter);
+
+  a.Train(inputData, labels.row(0), p2, iterations, tolerance);
+
+  // Load test set to see if it trained on vc2 correctly.
+  arma::mat testData;
+  if (!data::Load("vc2_test.txt", testData))
+    BOOST_FAIL("Cannot load test dataset vc2_test.txt!");
+
+  arma::Mat<size_t> trueTestLabels;
+  if (!data::Load("vc2_test_labels.txt",trueTestLabels))
+    BOOST_FAIL("Cannot load labels for vc2_test_labels.txt");
+
+  // Define parameters for AdaBoost.
+  arma::Row<size_t> predictedLabels(testData.n_cols);
+  a.Classify(testData, predictedLabels);
+
+  int localError = 0;
+  for (size_t i = 0; i < trueTestLabels.n_cols; i++)
+    if (trueTestLabels(i) != predictedLabels(i))
+      localError++;
+
+  double lError = (double) localError / trueTestLabels.n_cols;
+
+  BOOST_REQUIRE(lError <= 0.30);
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list