[mlpack-svn] r16025 - mlpack/trunk/src/mlpack/tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Nov 13 13:20:27 EST 2013
Author: rcurtin
Date: Wed Nov 13 13:20:27 2013
New Revision: 16025
Log:
Add test for Gradient().
Modified:
mlpack/trunk/src/mlpack/tests/logistic_regression_test.cpp
Modified: mlpack/trunk/src/mlpack/tests/logistic_regression_test.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/tests/logistic_regression_test.cpp (original)
+++ mlpack/trunk/src/mlpack/tests/logistic_regression_test.cpp Wed Nov 13 13:20:27 2013
@@ -120,4 +120,54 @@
}
}
+/**
+ * Test the gradient of the LogisticRegressionFunction.
+ *
+ * The dataset has one point per column; the first row is fixed at 1 so that
+ * the first parameter acts as the intercept.  The three points are therefore
+ * (1, 1, 1), (1, 2, 2) and (1, 3, 3) with labels 1, 1 and 0.
+ *
+ * For parameters (200, -40, -40) the linear scores are 120, 40 and -40, which
+ * agree with the labels with overwhelming confidence, so the gradient must be
+ * (numerically) zero.  Each perturbation below makes one point badly
+ * misclassified, and the gradient must then point away from the perturbation
+ * (a positive component means that parameter should decrease, as for an
+ * objective being minimized).  Strict inequalities are used on purpose: an
+ * implementation that returns an all-zero gradient must not pass.
+ */
+BOOST_AUTO_TEST_CASE(LogisticRegressionFunctionGradient)
+{
+  // Very simple fake dataset.
+  arma::mat data("1 1 1;" // Fake row for intercept.
+                 "1 2 3;"
+                 "1 2 3");
+  arma::vec responses("1 1 0");
+
+  // Create a LogisticRegressionFunction.
+  LogisticRegressionFunction lrf(data, responses, 0.0 /* no regularization */);
+  arma::vec gradient;
+
+  // If the model is at the optimum, then the gradient should be zero.
+  lrf.Gradient(arma::vec("200 -40 -40"), gradient);
+
+  // Use an unsigned literal: n_elem is unsigned, and comparing it against a
+  // signed int inside BOOST_REQUIRE_EQUAL triggers -Wsign-compare.
+  BOOST_REQUIRE_EQUAL(gradient.n_elem, 3u);
+  BOOST_REQUIRE_SMALL(gradient[0], 1e-15);
+  BOOST_REQUIRE_SMALL(gradient[1], 1e-15);
+  BOOST_REQUIRE_SMALL(gradient[2], 1e-15);
+
+  // Make the two slope coefficients too large (scores become 160, 120, 80):
+  // the third point (label 0) is now predicted as 1, so both coefficients
+  // need to become smaller, i.e. their gradient components must be positive.
+  lrf.Gradient(arma::vec("200 -20 -20"), gradient);
+
+  // The actual values are less important; the gradient just needs to point
+  // the right way.  A zero component would mean the error went unnoticed.
+  BOOST_REQUIRE_EQUAL(gradient.n_elem, 3u);
+  BOOST_REQUIRE_GT(gradient[1], 0.0);
+  BOOST_REQUIRE_GT(gradient[2], 0.0);
+
+  // Make the two slope coefficients too small (scores become 80, -40, -160):
+  // the second point (label 1) is now predicted as 0, so both coefficients
+  // need to become larger, i.e. their gradient components must be negative.
+  lrf.Gradient(arma::vec("200 -60 -60"), gradient);
+
+  // The actual values are less important; the gradient just needs to point
+  // the right way.  A zero component would mean the error went unnoticed.
+  BOOST_REQUIRE_EQUAL(gradient.n_elem, 3u);
+  BOOST_REQUIRE_LT(gradient[1], 0.0);
+  BOOST_REQUIRE_LT(gradient[2], 0.0);
+
+  // Make the intercept too large (scores become 170, 90, 10): the third point
+  // (label 0) is now predicted as 1, so the intercept needs to become smaller,
+  // i.e. its gradient component must be positive.
+  lrf.Gradient(arma::vec("250 -40 -40"), gradient);
+
+  // The actual values are less important; the gradient just needs to point
+  // the right way.  A zero component would mean the error went unnoticed.
+  BOOST_REQUIRE_EQUAL(gradient.n_elem, 3u);
+  BOOST_REQUIRE_GT(gradient[0], 0.0);
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-svn
mailing list