[mlpack-svn] r17297 - mlpack/trunk/src/mlpack/tests

fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Nov 5 14:27:16 EST 2014


Author: rcurtin
Date: Wed Nov  5 14:27:15 2014
New Revision: 17297

Log:
Fix logistic regression tests by enforcing a tighter tolerance for SGD
convergence.  The shuffling added to SGD in r17196 also created situations where
SGD can terminate far too early, which made the two tests fail.  Tightening the
tolerance to 1e-10 appears to resolve the issue.
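
To illustrate why the tolerance matters, here is a minimal standalone sketch
(plain C++, no mlpack) of SGD on the same toy dataset.  It assumes a stopping
rule that compares the summed objective of two consecutive passes against the
tolerance, and it reshuffles the visitation order before every pass.  The names
Point, PointLoss, and SgdPasses, the fixed seed, and the exact form of the
penalty are hypothetical choices made for this example; this is not the mlpack
implementation.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <numeric>
#include <random>
#include <vector>

// One training point: two features and a 0/1 response (hypothetical type).
struct Point { double x1, x2, y; };

double Sigmoid(const double z) { return 1.0 / (1.0 + std::exp(-z)); }

// Regularized logistic loss of one point; p = {intercept, w1, w2}.
double PointLoss(const std::vector<double>& p, const Point& pt,
                 const double lambda, const std::size_t n)
{
  const double z = p[0] + p[1] * pt.x1 + p[2] * pt.x2;
  const double margin = (pt.y > 0.5) ? z : -z;
  const double nll = std::log1p(std::exp(-margin));
  return nll + 0.5 * lambda * (p[1] * p[1] + p[2] * p[2]) / n;
}

// Runs SGD and returns the number of full passes made before stopping.
// Assumed stopping rule: |previous pass objective - this pass objective|
// is below the tolerance, or the iteration budget is exhausted.
std::size_t SgdPasses(const double tolerance,
                      const std::size_t maxIterations = 100000,
                      const double stepSize = 0.01,
                      const double lambda = 0.001,
                      const unsigned seed = 42)
{
  assert(tolerance >= 0.0);

  // Same toy data as the tests: points (1,1), (2,2), (3,3); labels 1, 1, 0.
  const std::vector<Point> data = { {1, 1, 1}, {2, 2, 1}, {3, 3, 0} };
  const std::size_t n = data.size();

  std::vector<double> p(3, 0.0);
  std::vector<std::size_t> order(n);
  std::iota(order.begin(), order.end(), 0);
  std::mt19937 rng(seed); // Fixed seed so runs are comparable.

  double lastObjective = std::numeric_limits<double>::max();
  std::size_t passes = 0;
  for (std::size_t it = 0; it + n <= maxIterations; it += n)
  {
    // Visit the points in a fresh random order on every pass.
    std::shuffle(order.begin(), order.end(), rng);

    double objective = 0.0;
    for (const std::size_t i : order)
    {
      const Point& pt = data[i];
      const double err =
          Sigmoid(p[0] + p[1] * pt.x1 + p[2] * pt.x2) - pt.y;
      const double g1 = err * pt.x1 + lambda * p[1] / n;
      const double g2 = err * pt.x2 + lambda * p[2] / n;
      p[0] -= stepSize * err;
      p[1] -= stepSize * g1;
      p[2] -= stepSize * g2;
      objective += PointLoss(p, pt, lambda, n);
    }

    ++passes;
    if (std::abs(lastObjective - objective) < tolerance)
      break; // Declared converged.
    lastObjective = objective;
  }
  return passes;
}
```

With the same seed, SgdPasses(1e-5) can never run more passes than
SgdPasses(1e-10), because any pass that satisfies the tight tolerance also
satisfies the loose one.  Since the per-pass objective depends on the shuffled
order, two consecutive passes can happen to land close together long before the
parameters settle, which is the kind of early stop the tighter tolerance is
meant to avoid.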


Modified:
   mlpack/trunk/src/mlpack/tests/logistic_regression_test.cpp

Modified: mlpack/trunk/src/mlpack/tests/logistic_regression_test.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/tests/logistic_regression_test.cpp	(original)
+++ mlpack/trunk/src/mlpack/tests/logistic_regression_test.cpp	Wed Nov  5 14:27:15 2014
@@ -495,8 +495,11 @@
                  "1 2 3");
   arma::vec responses("1 1 0");
 
-  // Create a logistic regression object using SGD.
-  LogisticRegression<SGD> lr(data, responses);
+  // Create a logistic regression object using a custom SGD object with a much
+  // smaller tolerance.
+  LogisticRegressionFunction lrf(data, responses, 0.001);
+  SGD<LogisticRegressionFunction> sgd(lrf, 0.01, 100000, 1e-10);
+  LogisticRegression<SGD> lr(sgd);
 
   // Test sigmoid function.
   arma::vec sigmoids = 1 / (1 + arma::exp(-lr.Parameters()[0]
@@ -536,13 +539,17 @@
 // regularization.
 BOOST_AUTO_TEST_CASE(LogisticRegressionSGDRegularizationSimpleTest)
 {
+  math::RandomSeed(std::time(NULL));
   // Very simple fake dataset.
   arma::mat data("1 2 3;"
                  "1 2 3");
   arma::vec responses("1 1 0");
 
-  // Create a logistic regression object using SGD.
-  LogisticRegression<SGD> lr(data, responses, 0.001);
+  // Create a logistic regression object using custom SGD with a much smaller
+  // tolerance.
+  LogisticRegressionFunction lrf(data, responses, 0.001);
+  SGD<LogisticRegressionFunction> sgd(lrf, 0.01, 100000, 1e-10);
+  LogisticRegression<SGD> lr(sgd);
 
   // Test sigmoid function.
   arma::vec sigmoids = 1 / (1 + arma::exp(-lr.Parameters()[0]
@@ -582,7 +589,6 @@
 
   // Ensure that the error is close to zero.
   const double acc = lr.ComputeAccuracy(data, responses);
-
   BOOST_REQUIRE_CLOSE(acc, 100.0, 0.3); // 0.3% error tolerance.
 
   // Create a test set.


