[mlpack-git] master: Add test for new constructor, Train(), and fix compilation error. (050bc70)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Tue Sep 29 09:33:58 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/cbeb3ea17262b7c5115247dc217e316c529249b7...f85a9b22f3ce56143943a2488c05c2810d6b2bf3
>---------------------------------------------------------------
commit 050bc7007fde3e634229b6b53d4bffbd1699d5e7
Author: Ryan Curtin <ryan at ratml.org>
Date: Mon Sep 28 17:05:49 2015 -0400
Add test for new constructor, Train(), and fix compilation error.
>---------------------------------------------------------------
050bc7007fde3e634229b6b53d4bffbd1699d5e7
.../softmax_regression/softmax_regression_impl.hpp | 4 +-
src/mlpack/tests/softmax_regression_test.cpp | 57 ++++++++++++++++++++++
2 files changed, 60 insertions(+), 1 deletion(-)
diff --git a/src/mlpack/methods/softmax_regression/softmax_regression_impl.hpp b/src/mlpack/methods/softmax_regression/softmax_regression_impl.hpp
index 913b6e1..46afc6c 100644
--- a/src/mlpack/methods/softmax_regression/softmax_regression_impl.hpp
+++ b/src/mlpack/methods/softmax_regression/softmax_regression_impl.hpp
@@ -23,7 +23,9 @@ SoftmaxRegression(const size_t inputSize,
lambda(0.0001),
fitIntercept(fitIntercept)
{
- SoftmaxRegressionFunction regressor(arma::mat(), 1,
+ arma::mat tmp;
+ arma::vec tmplabels;
+ SoftmaxRegressionFunction regressor(tmp, tmplabels,
inputSize, numClasses,
lambda, fitIntercept);
parameters = regressor.GetInitialPoint();
diff --git a/src/mlpack/tests/softmax_regression_test.cpp b/src/mlpack/tests/softmax_regression_test.cpp
index 8e14c51..18bf3f5 100644
--- a/src/mlpack/tests/softmax_regression_test.cpp
+++ b/src/mlpack/tests/softmax_regression_test.cpp
@@ -13,6 +13,7 @@
using namespace mlpack;
using namespace mlpack::regression;
using namespace mlpack::distribution;
+using namespace mlpack::optimization;
BOOST_AUTO_TEST_SUITE(SoftmaxRegressionTest);
@@ -342,4 +343,60 @@ BOOST_AUTO_TEST_CASE(SoftmaxRegressionMultipleClasses)
BOOST_REQUIRE_CLOSE(testAcc, 100.0, 2.0);
}
+BOOST_AUTO_TEST_CASE(SoftmaxRegressionTrainTest)
+{
+ // Make sure a SoftmaxRegression object trained with Train() operates the same
+ // as a SoftmaxRegression object trained in the constructor.
+ arma::mat dataset = arma::randu<arma::mat>(5, 1000);
+ arma::vec labels(1000);
+ for (size_t i = 0; i < 500; ++i)
+ labels[i] = 0.0;
+ for (size_t i = 500; i < 1000; ++i)
+ labels[i] = 1.0;
+ SoftmaxRegression<> sr(dataset, labels, dataset.n_rows, 2);
+ SoftmaxRegression<> sr2(dataset.n_rows, 2);
+ sr2.Train(dataset, labels, 2);
+
+ // Ensure that the parameters are the same.
+ BOOST_REQUIRE_EQUAL(sr.Parameters().n_rows, sr2.Parameters().n_rows);
+ BOOST_REQUIRE_EQUAL(sr.Parameters().n_cols, sr2.Parameters().n_cols);
+ for (size_t i = 0; i < sr.Parameters().n_elem; ++i)
+ {
+ if (std::abs(sr.Parameters()[i]) < 1e-5)
+ BOOST_REQUIRE_SMALL(sr2.Parameters()[i], 1e-5);
+ else
+ BOOST_REQUIRE_CLOSE(sr.Parameters()[i], sr2.Parameters()[i], 1e-5);
+ }
+}
+
+BOOST_AUTO_TEST_CASE(SoftmaxRegressionOptimizerTrainTest)
+{
+ // The same as the previous test, just passing in an instantiated optimizer.
+ arma::mat dataset = arma::randu<arma::mat>(5, 1000);
+ arma::vec labels(1000);
+ for (size_t i = 0; i < 500; ++i)
+ labels[i] = 0.0;
+ for (size_t i = 500; i < 1000; ++i)
+ labels[i] = 1.0;
+
+ SoftmaxRegressionFunction srf(dataset, labels, dataset.n_rows, 2, 0.01, true);
+ L_BFGS<SoftmaxRegressionFunction> lbfgs(srf);
+
+ SoftmaxRegression<> sr(lbfgs);
+ SoftmaxRegression<> sr2(dataset.n_rows, 2);
+ sr2.Train(lbfgs);
+
+ // Ensure that the parameters are the same.
+ BOOST_REQUIRE_EQUAL(sr.Parameters().n_rows, sr2.Parameters().n_rows);
+ BOOST_REQUIRE_EQUAL(sr.Parameters().n_cols, sr2.Parameters().n_cols);
+ for (size_t i = 0; i < sr.Parameters().n_elem; ++i)
+ {
+ if (std::abs(sr.Parameters()[i]) < 1e-5)
+ BOOST_REQUIRE_SMALL(sr2.Parameters()[i], 1e-5);
+ else
+ BOOST_REQUIRE_CLOSE(sr.Parameters()[i], sr2.Parameters()[i], 1e-5);
+ }
+}
+
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git mailing list