[mlpack-git] master: Add better test for mini-batch SGD, and fix bug. Also remove debugging output. (0341d4d)

gitdub at mlpack.org gitdub at mlpack.org
Mon Feb 22 12:08:49 EST 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/efe49b78c20eea8df5f2d6b47c2024931cb88d8e...0341d4d82c030dd7bcf91f5dfe7b9e452b7b3cdc

>---------------------------------------------------------------

commit 0341d4d82c030dd7bcf91f5dfe7b9e452b7b3cdc
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Feb 22 09:07:14 2016 -0800

    Add better test for mini-batch SGD, and fix bug.
    Also remove debugging output.


>---------------------------------------------------------------

0341d4d82c030dd7bcf91f5dfe7b9e452b7b3cdc
 .../minibatch_sgd/minibatch_sgd_impl.hpp           | 42 +++++++++----
 src/mlpack/tests/minibatch_sgd_test.cpp            | 71 ++++++++++++++++++++++
 2 files changed, 102 insertions(+), 11 deletions(-)

diff --git a/src/mlpack/core/optimizers/minibatch_sgd/minibatch_sgd_impl.hpp b/src/mlpack/core/optimizers/minibatch_sgd/minibatch_sgd_impl.hpp
index cdcd744..fe7db5c 100644
--- a/src/mlpack/core/optimizers/minibatch_sgd/minibatch_sgd_impl.hpp
+++ b/src/mlpack/core/optimizers/minibatch_sgd/minibatch_sgd_impl.hpp
@@ -38,7 +38,6 @@ double MiniBatchSGD<DecomposableFunctionType>::Optimize(arma::mat& iterate)
   size_t numBatches = numFunctions / batchSize;
   if (numFunctions % batchSize != 0)
     ++numBatches; // Capture last few.
-  std::cout << "numBatches " << numBatches << ".\n";
 
   // This is only used if shuffle is true.
   arma::Col<size_t> visitationOrder;
@@ -63,7 +62,7 @@ double MiniBatchSGD<DecomposableFunctionType>::Optimize(arma::mat& iterate)
     if ((currentBatch % numBatches) == 0)
     {
       // Output current objective function.
-      std::cout << "Mini-batch SGD: iteration " << i << ", objective "
+      Log::Info << "Mini-batch SGD: iteration " << i << ", objective "
           << overallObjective << "." << std::endl;
 
       if (std::isnan(overallObjective) || std::isinf(overallObjective))
@@ -94,19 +93,27 @@ double MiniBatchSGD<DecomposableFunctionType>::Optimize(arma::mat& iterate)
-    const size_t offset = (shuffle) ? batchSize * visitationOrder[currentBatch]
-        : batchSize * currentBatch;
+    // visitationOrder is only meaningful when shuffle is true, so it must not
+    // be indexed otherwise; without shuffling, batches are visited in order.
+    const size_t batchIndex = (shuffle) ? visitationOrder[currentBatch]
+        : currentBatch;
+    const size_t offset = batchSize * batchIndex;
+
+    // Every batch holds batchSize functions except possibly the last one,
+    // which holds only the (numFunctions - offset) functions that remain.
+    const size_t currentBatchSize = std::min(batchSize, numFunctions - offset);
+
     function.Gradient(iterate, offset, gradient);
-    for (size_t j = 1; j < batchSize; ++j)
+    for (size_t j = 1; j < currentBatchSize; ++j)
     {
       arma::mat funcGradient;
       function.Gradient(iterate, offset + j, funcGradient);
       gradient += funcGradient;
     }
 
     // Now update the iterate.
-    iterate -= (stepSize / batchSize) * gradient;
+    iterate -= (stepSize / currentBatchSize) * gradient;
 
     // Add that to the overall objective function.
-    for (size_t j = 0; j < batchSize; ++j)
+    for (size_t j = 0; j < currentBatchSize; ++j)
       overallObjective += function.Evaluate(iterate, offset + j);
   }
 
   Log::Info << "Mini-batch SGD: maximum iterations (" << maxIterations << ") "
diff --git a/src/mlpack/tests/minibatch_sgd_test.cpp b/src/mlpack/tests/minibatch_sgd_test.cpp
index a05121a..972980b 100644
--- a/src/mlpack/tests/minibatch_sgd_test.cpp
+++ b/src/mlpack/tests/minibatch_sgd_test.cpp
@@ -10,6 +10,8 @@
 #include <mlpack/core/optimizers/lbfgs/test_functions.hpp>
 #include <mlpack/core/optimizers/sgd/test_function.hpp>
 
+#include <mlpack/methods/logistic_regression/logistic_regression.hpp>
+
 #include <boost/test/unit_test.hpp>
 #include "old_boost_test_definitions.hpp"
 
@@ -19,6 +21,9 @@ using namespace mlpack;
 using namespace mlpack::optimization;
 using namespace mlpack::optimization::test;
 
+using namespace mlpack::distribution;
+using namespace mlpack::regression;
+
 BOOST_AUTO_TEST_SUITE(MiniBatchSGDTest);
 
 /**
@@ -60,4 +65,70 @@ BOOST_AUTO_TEST_CASE(SimpleSGDTestFunction)
 }
 */
 
+/**
+ * Run mini-batch SGD on logistic regression with several batch sizes (some
+ * leave a smaller final batch) and make sure the results are acceptable.
+ */
+BOOST_AUTO_TEST_CASE(LogisticRegressionTest)
+{
+  // Generate a two-Gaussian dataset.
+  GaussianDistribution g1(arma::vec("1.0 1.0 1.0"), arma::eye<arma::mat>(3, 3));
+  GaussianDistribution g2(arma::vec("9.0 9.0 9.0"), arma::eye<arma::mat>(3, 3));
+
+  arma::mat data(3, 1000);
+  arma::Row<size_t> responses(1000);
+  for (size_t i = 0; i < 500; ++i)
+  {
+    data.col(i) = g1.Random();
+    responses[i] = 0;
+  }
+  for (size_t i = 500; i < 1000; ++i)
+  {
+    data.col(i) = g2.Random();
+    responses[i] = 1;
+  }
+
+  // Shuffle the dataset.
+  arma::uvec indices = arma::shuffle(arma::linspace<arma::uvec>(0,
+      data.n_cols - 1, data.n_cols));
+  arma::mat shuffledData(3, 1000);
+  arma::Row<size_t> shuffledResponses(1000);
+  for (size_t i = 0; i < data.n_cols; ++i)
+  {
+    shuffledData.col(i) = data.col(indices[i]);
+    shuffledResponses[i] = responses[indices[i]];
+  }
+
+  // Create a held-out test set drawn from the same two Gaussians.
+  arma::mat testData(3, 1000);
+  arma::Row<size_t> testResponses(1000);
+  for (size_t i = 0; i < 500; ++i)
+  {
+    testData.col(i) = g1.Random();
+    testResponses[i] = 0;
+  }
+  for (size_t i = 500; i < 1000; ++i)
+  {
+    testData.col(i) = g2.Random();
+    testResponses[i] = 1;
+  }
+
+  // Now run mini-batch SGD with batch sizes 5, 10, ..., 45.
+  for (size_t batchSize = 5; batchSize < 50; batchSize += 5)
+  {
+    LogisticRegression<> lr(shuffledData.n_rows, 0.5);
+
+    LogisticRegressionFunction<> lrf(shuffledData, shuffledResponses, 0.5);
+    MiniBatchSGD<LogisticRegressionFunction<>> mbsgd(lrf, batchSize);
+    lr.Train(mbsgd);
+
+    // Ensure that the training and test errors are close to zero.
+    const double acc = lr.ComputeAccuracy(shuffledData, shuffledResponses);
+    BOOST_REQUIRE_CLOSE(acc, 100.0, 0.3); // 0.3% error tolerance.
+
+    const double testAcc = lr.ComputeAccuracy(testData, testResponses);
+    BOOST_REQUIRE_CLOSE(testAcc, 100.0, 0.6); // 0.6% error tolerance.
+  }
+}
+
 BOOST_AUTO_TEST_SUITE_END();




More information about the mlpack-git mailing list