[mlpack-svn] r16024 - mlpack/trunk/src/mlpack/tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Nov 13 13:10:52 EST 2013
Author: rcurtin
Date: Wed Nov 13 13:10:51 2013
New Revision: 16024
Log:
Add test for regularization of objective function.
Modified:
mlpack/trunk/src/mlpack/tests/logistic_regression_test.cpp
Modified: mlpack/trunk/src/mlpack/tests/logistic_regression_test.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/tests/logistic_regression_test.cpp (original)
+++ mlpack/trunk/src/mlpack/tests/logistic_regression_test.cpp Wed Nov 13 13:10:51 2013
@@ -80,4 +80,44 @@
}
}
+/**
+ * Test regularization for the LogisticRegressionFunction Evaluate() function:
+ * a penalty of lambda must shift the objective by 0.5 * lambda * ||p[1:]||^2.
+ */
+BOOST_AUTO_TEST_CASE(LogisticRegressionFunctionRegularizationEvaluate)
+{
+ const size_t points = 5000;
+ const size_t dimension = 25;
+ const size_t trials = 10;
+ const double smallLambda = 0.5;
+ const double bigLambda = 20.0;
+
+ // Create a random dataset with random 0/1 responses.
+ arma::mat data;
+ data.randu(dimension, points);
+ arma::vec responses(points);
+ for (size_t i = 0; i < points; ++i)
+ responses[i] = math::RandInt(0, 2);
+
+ LogisticRegressionFunction lrfNoReg(data, responses, 0.0);
+ LogisticRegressionFunction lrfSmallReg(data, responses, smallLambda);
+ LogisticRegressionFunction lrfBigReg(data, responses, bigLambda);
+
+ for (size_t i = 0; i < trials; ++i)
+ {
+ arma::vec parameters(dimension);
+ parameters.randu();
+
+ // Squared l2-norm of all parameters except the first, which is unpenalized.
+ const arma::vec penalized = parameters.subvec(1, dimension - 1);
+ const double sqNorm = arma::dot(penalized, penalized);
+
+ // NOTE(review): the expected sign assumes Evaluate() subtracts the penalty.
+ const double noRegValue = lrfNoReg.Evaluate(parameters);
+ BOOST_REQUIRE_CLOSE(noRegValue - 0.5 * smallLambda * sqNorm,
+ lrfSmallReg.Evaluate(parameters), 1e-5);
+ BOOST_REQUIRE_CLOSE(noRegValue - 0.5 * bigLambda * sqNorm,
+ lrfBigReg.Evaluate(parameters), 1e-5);
+ }
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-svn mailing list