[mlpack-git] master: Add Train() and a test for it. (da1c9a6)
gitdub at big.cc.gt.atl.ga.us
Mon Nov 30 17:24:40 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/10b9d45b806a3e879b0564d78ccb183ebc7051ba...31c557d9cc7e4da57fd8a246085c19e076d12271
>---------------------------------------------------------------
commit da1c9a6db792e55ea18549f073c8cd0ea2f61651
Author: Ryan Curtin <ryan at ratml.org>
Date: Mon Nov 30 21:33:07 2015 +0000
Add Train() and a test for it.
>---------------------------------------------------------------
da1c9a6db792e55ea18549f073c8cd0ea2f61651
src/mlpack/methods/adaboost/adaboost.hpp | 17 ++++++++
src/mlpack/methods/adaboost/adaboost_impl.hpp | 18 ++++++++-
src/mlpack/tests/adaboost_test.cpp | 58 +++++++++++++++++++++++++++
3 files changed, 92 insertions(+), 1 deletion(-)
diff --git a/src/mlpack/methods/adaboost/adaboost.hpp b/src/mlpack/methods/adaboost/adaboost.hpp
index 924e2e2..5ab00d6 100644
--- a/src/mlpack/methods/adaboost/adaboost.hpp
+++ b/src/mlpack/methods/adaboost/adaboost.hpp
@@ -104,6 +104,23 @@ class AdaBoost
double tolerance;
/**
+ * Train AdaBoost on the given dataset. This method takes an initialized
+ * WeakLearner; the parameters for this weak learner will be used to train
+ * each of the weak learners during AdaBoost training. Note that this will
+ * completely overwrite any model that has already been trained with this
+ * object.
+ *
+ * @param data Dataset to train on.
+ * @param labels Labels for each point in the dataset.
+ * @param learner Learner to use for training.
+ */
+ void Train(const MatType& data,
+ const arma::Row<size_t>& labels,
+ const WeakLearner& learner,
+ const size_t iterations = 100,
+ const double tolerance = 1e-6);
+
+ /**
* Classify the given test points.
*
* @param test Testing data.
diff --git a/src/mlpack/methods/adaboost/adaboost_impl.hpp b/src/mlpack/methods/adaboost/adaboost_impl.hpp
index 3fc8a21..4d2f8fc 100644
--- a/src/mlpack/methods/adaboost/adaboost_impl.hpp
+++ b/src/mlpack/methods/adaboost/adaboost_impl.hpp
@@ -43,9 +43,25 @@ AdaBoost<MatType, WeakLearner>::AdaBoost(
const size_t iterations,
const double tol)
{
+ Train(data, labels, other, iterations, tol);
+}
+
+// Train AdaBoost.
+template<typename MatType, typename WeakLearner>
+void AdaBoost<MatType, WeakLearner>::Train(
+ const MatType& data,
+ const arma::Row<size_t>& labels,
+ const WeakLearner& other,
+ const size_t iterations,
+ const double tolerance)
+{
+ // Clear information from previous runs.
+ wl.clear();
+ alpha.clear();
+
// Count the number of classes.
classes = (arma::max(labels) - arma::min(labels)) + 1;
- tolerance = tol;
+ this->tolerance = tolerance;
// crt is the cumulative rt value for terminating the optimization when rt is
// changing by less than the tolerance.
diff --git a/src/mlpack/tests/adaboost_test.cpp b/src/mlpack/tests/adaboost_test.cpp
index fa3999c..b8db20c 100644
--- a/src/mlpack/tests/adaboost_test.cpp
+++ b/src/mlpack/tests/adaboost_test.cpp
@@ -664,4 +664,62 @@ BOOST_AUTO_TEST_CASE(ClassifyTest_IRIS)
BOOST_REQUIRE(lError <= 0.30);
}
+/**
+ * Ensure that the Train() function works like it is supposed to, by building
+ * AdaBoost on one dataset and then re-training on another dataset.
+ */
+BOOST_AUTO_TEST_CASE(TrainTest)
+{
+ // First train on the iris dataset.
+ arma::mat inputData;
+ if (!data::Load("iris_train.csv", inputData))
+ BOOST_FAIL("Cannot load test dataset iris_train.csv!");
+
+ arma::Mat<size_t> labels;
+ if (!data::Load("iris_train_labels.csv",labels))
+ BOOST_FAIL("Cannot load labels for iris_train_labels.csv");
+
+ size_t perceptronIter = 800;
+ perceptron::Perceptron<> p(inputData, labels.row(0), max(labels.row(0)) + 1,
+ perceptronIter);
+
+ // Now train AdaBoost.
+ size_t iterations = 50;
+ double tolerance = 1e-10;
+ AdaBoost<> a(inputData, labels.row(0), p, iterations, tolerance);
+
+ // Now load another dataset...
+ if (!data::Load("vc2.txt", inputData))
+ BOOST_FAIL("Cannot load test dataset vc2.txt!");
+ if (!data::Load("vc2_labels.txt",labels))
+ BOOST_FAIL("Cannot load labels for vc2_labels.txt");
+
+ perceptron::Perceptron<> p2(inputData, labels.row(0), max(labels.row(0)) + 1,
+ perceptronIter);
+
+ a.Train(inputData, labels.row(0), p2, iterations, tolerance);
+
+ // Load test set to see if it trained on vc2 correctly.
+ arma::mat testData;
+ if (!data::Load("vc2_test.txt", testData))
+ BOOST_FAIL("Cannot load test dataset vc2_test.txt!");
+
+ arma::Mat<size_t> trueTestLabels;
+ if (!data::Load("vc2_test_labels.txt",trueTestLabels))
+ BOOST_FAIL("Cannot load labels for vc2_test_labels.txt");
+
+ // Define parameters for AdaBoost.
+ arma::Row<size_t> predictedLabels(testData.n_cols);
+ a.Classify(testData, predictedLabels);
+
+ int localError = 0;
+ for (size_t i = 0; i < trueTestLabels.n_cols; i++)
+ if (trueTestLabels(i) != predictedLabels(i))
+ localError++;
+
+ double lError = (double) localError / trueTestLabels.n_cols;
+
+ BOOST_REQUIRE(lError <= 0.30);
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git mailing list