[mlpack-git] master: Add test for new constructor, Train(), and fix compilation error. (050bc70)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Tue Sep 29 09:33:58 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/cbeb3ea17262b7c5115247dc217e316c529249b7...f85a9b22f3ce56143943a2488c05c2810d6b2bf3

>---------------------------------------------------------------

commit 050bc7007fde3e634229b6b53d4bffbd1699d5e7
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Sep 28 17:05:49 2015 -0400

    Add test for new constructor, Train(), and fix compilation error.


>---------------------------------------------------------------

050bc7007fde3e634229b6b53d4bffbd1699d5e7
 .../softmax_regression/softmax_regression_impl.hpp |  4 +-
 src/mlpack/tests/softmax_regression_test.cpp       | 57 ++++++++++++++++++++++
 2 files changed, 60 insertions(+), 1 deletion(-)

diff --git a/src/mlpack/methods/softmax_regression/softmax_regression_impl.hpp b/src/mlpack/methods/softmax_regression/softmax_regression_impl.hpp
index 913b6e1..46afc6c 100644
--- a/src/mlpack/methods/softmax_regression/softmax_regression_impl.hpp
+++ b/src/mlpack/methods/softmax_regression/softmax_regression_impl.hpp
@@ -23,7 +23,9 @@ SoftmaxRegression(const size_t inputSize,
     lambda(0.0001),
     fitIntercept(fitIntercept)
 {
-  SoftmaxRegressionFunction regressor(arma::mat(), 1,
+  arma::mat tmp;
+  arma::vec tmplabels;
+  SoftmaxRegressionFunction regressor(tmp, tmplabels,
                                       inputSize, numClasses,
                                       lambda, fitIntercept);
   parameters = regressor.GetInitialPoint();
diff --git a/src/mlpack/tests/softmax_regression_test.cpp b/src/mlpack/tests/softmax_regression_test.cpp
index 8e14c51..18bf3f5 100644
--- a/src/mlpack/tests/softmax_regression_test.cpp
+++ b/src/mlpack/tests/softmax_regression_test.cpp
@@ -13,6 +13,7 @@
 using namespace mlpack;
 using namespace mlpack::regression;
 using namespace mlpack::distribution;
+using namespace mlpack::optimization;
 
 BOOST_AUTO_TEST_SUITE(SoftmaxRegressionTest);
 
@@ -342,4 +343,60 @@ BOOST_AUTO_TEST_CASE(SoftmaxRegressionMultipleClasses)
   BOOST_REQUIRE_CLOSE(testAcc, 100.0, 2.0);
 }
 
+BOOST_AUTO_TEST_CASE(SoftmaxRegressionTrainTest)
+{
+  // A model trained with Train() must end up with the same parameters as one
+  // trained in the constructor; labels: first 500 points 0, the rest 1.
+  arma::mat dataset = arma::randu<arma::mat>(5, 1000);
+  arma::vec labels(1000);
+  for (size_t i = 0; i < 500; ++i)
+    labels[i] = 0.0;
+  for (size_t i = 500; i < 1000; ++i)
+    labels[i] = 1.0;
+  SoftmaxRegression<> sr(dataset, labels, dataset.n_rows, 2);
+  SoftmaxRegression<> sr2(dataset.n_rows, 2);
+  sr2.Train(dataset, labels, 2);
+  // NOTE(review): may flake if each run uses its own random starting point.
+  // Same shape, then elementwise: absolute check near 0, else % tolerance.
+  BOOST_REQUIRE_EQUAL(sr.Parameters().n_rows, sr2.Parameters().n_rows);
+  BOOST_REQUIRE_EQUAL(sr.Parameters().n_cols, sr2.Parameters().n_cols);
+  for (size_t i = 0; i < sr.Parameters().n_elem; ++i)
+  {
+    if (std::abs(sr.Parameters()[i]) < 1e-5)
+      BOOST_REQUIRE_SMALL(sr2.Parameters()[i], 1e-5);
+    else
+      BOOST_REQUIRE_CLOSE(sr.Parameters()[i], sr2.Parameters()[i], 1e-5);
+  }
+}
+
+BOOST_AUTO_TEST_CASE(SoftmaxRegressionOptimizerTrainTest)
+{
+  // Like SoftmaxRegressionTrainTest, but both models train from an optimizer.
+  arma::mat dataset = arma::randu<arma::mat>(5, 1000);
+  arma::vec labels(1000);
+  for (size_t i = 0; i < 500; ++i)
+    labels[i] = 0.0;
+  for (size_t i = 500; i < 1000; ++i)
+    labels[i] = 1.0;
+  // Objective over the same data with lambda = 0.01 and an intercept term.
+  SoftmaxRegressionFunction srf(dataset, labels, dataset.n_rows, 2, 0.01, true);
+  L_BFGS<SoftmaxRegressionFunction> lbfgs(srf);
+  // NOTE(review): lbfgs is shared by both models; confirm it keeps no state.
+  SoftmaxRegression<> sr(lbfgs);
+  SoftmaxRegression<> sr2(dataset.n_rows, 2);
+  sr2.Train(lbfgs);
+  // NOTE(review): confirm Train(lbfgs) also adopts srf's fitIntercept/lambda.
+  // Same shape, then elementwise: absolute check near 0, else % tolerance.
+  BOOST_REQUIRE_EQUAL(sr.Parameters().n_rows, sr2.Parameters().n_rows);
+  BOOST_REQUIRE_EQUAL(sr.Parameters().n_cols, sr2.Parameters().n_cols);
+  for (size_t i = 0; i < sr.Parameters().n_elem; ++i)
+  {
+    if (std::abs(sr.Parameters()[i]) < 1e-5)
+      BOOST_REQUIRE_SMALL(sr2.Parameters()[i], 1e-5);
+    else
+      BOOST_REQUIRE_CLOSE(sr.Parameters()[i], sr2.Parameters()[i], 1e-5);
+  }
+}
+
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list