[mlpack-svn] r16025 - mlpack/trunk/src/mlpack/tests

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Nov 13 13:20:27 EST 2013


Author: rcurtin
Date: Wed Nov 13 13:20:27 2013
New Revision: 16025

Log:
Add test for Gradient().


Modified:
   mlpack/trunk/src/mlpack/tests/logistic_regression_test.cpp

Modified: mlpack/trunk/src/mlpack/tests/logistic_regression_test.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/tests/logistic_regression_test.cpp	(original)
+++ mlpack/trunk/src/mlpack/tests/logistic_regression_test.cpp	Wed Nov 13 13:20:27 2013
@@ -120,4 +120,65 @@
   }
 }
 
+/**
+ * Test the gradient of the LogisticRegressionFunction.
+ *
+ * The dataset is linearly separable, so the unregularized objective has no
+ * finite minimizer; at the parameters [200 -40 -40], however, the logits are
+ * 120, 40 and -40, so every prediction matches its response to within 1e-17
+ * and the gradient is numerically zero.  The remaining checks perturb the
+ * parameters away from that point and verify that the gradient is strictly
+ * nonzero and signed so that a descent step moves back toward it.  Strict
+ * comparisons ensure that a Gradient() which always returns zero cannot pass.
+ */
+BOOST_AUTO_TEST_CASE(LogisticRegressionFunctionGradient)
+{
+  // Very simple fake dataset with one point per column; the first row is a
+  // constant 1 so that the first parameter acts as the intercept.
+  arma::mat data("1 1 1;" // Fake row for intercept.
+                 "1 2 3;"
+                 "1 2 3");
+  arma::vec responses("1 1 0");
+
+  // Create a LogisticRegressionFunction.
+  LogisticRegressionFunction lrf(data, responses, 0.0 /* no regularization */);
+  arma::vec gradient;
+
+  // At [200 -40 -40] the gradient should vanish (to within double precision).
+  lrf.Gradient(arma::vec("200 -40 -40"), gradient);
+
+  BOOST_REQUIRE_EQUAL(gradient.n_elem, 3);
+  BOOST_REQUIRE_SMALL(gradient[0], 1e-15);
+  BOOST_REQUIRE_SMALL(gradient[1], 1e-15);
+  BOOST_REQUIRE_SMALL(gradient[2], 1e-15);
+
+  // Make the two feature weights too large: the third point (response 0) now
+  // has logit 80, so the weights need to become smaller and the gradient must
+  // be strictly positive in those components.
+  lrf.Gradient(arma::vec("200 -20 -20"), gradient);
+
+  // The actual values are less important; the gradient just needs to point the
+  // right way.
+  BOOST_REQUIRE_EQUAL(gradient.n_elem, 3);
+  BOOST_REQUIRE_GT(gradient[1], 0.0);
+  BOOST_REQUIRE_GT(gradient[2], 0.0);
+
+  // Make the two feature weights too small: the second point (response 1) now
+  // has logit -40, so the weights need to become larger and the gradient must
+  // be strictly negative in those components.
+  lrf.Gradient(arma::vec("200 -60 -60"), gradient);
+
+  BOOST_REQUIRE_EQUAL(gradient.n_elem, 3);
+  BOOST_REQUIRE_LT(gradient[1], 0.0);
+  BOOST_REQUIRE_LT(gradient[2], 0.0);
+
+  // Make the intercept too large: the third point (response 0) now has logit
+  // 10, so the intercept needs to become smaller and its gradient component
+  // must be strictly positive.
+  lrf.Gradient(arma::vec("250 -40 -40"), gradient);
+
+  BOOST_REQUIRE_EQUAL(gradient.n_elem, 3);
+  BOOST_REQUIRE_GT(gradient[0], 0.0);
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-svn mailing list