[mlpack-git] master: Add better test for mini-batch SGD, and fix bug. Also remove debugging output. (0341d4d)
gitdub at mlpack.org
gitdub at mlpack.org
Mon Feb 22 12:08:49 EST 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/efe49b78c20eea8df5f2d6b47c2024931cb88d8e...0341d4d82c030dd7bcf91f5dfe7b9e452b7b3cdc
>---------------------------------------------------------------
commit 0341d4d82c030dd7bcf91f5dfe7b9e452b7b3cdc
Author: Ryan Curtin <ryan at ratml.org>
Date: Mon Feb 22 09:07:14 2016 -0800
Add better test for mini-batch SGD, and fix bug.
Also remove debugging output.
>---------------------------------------------------------------
0341d4d82c030dd7bcf91f5dfe7b9e452b7b3cdc
.../minibatch_sgd/minibatch_sgd_impl.hpp | 42 +++++++++----
src/mlpack/tests/minibatch_sgd_test.cpp | 71 ++++++++++++++++++++++
2 files changed, 102 insertions(+), 11 deletions(-)
diff --git a/src/mlpack/core/optimizers/minibatch_sgd/minibatch_sgd_impl.hpp b/src/mlpack/core/optimizers/minibatch_sgd/minibatch_sgd_impl.hpp
index cdcd744..fe7db5c 100644
--- a/src/mlpack/core/optimizers/minibatch_sgd/minibatch_sgd_impl.hpp
+++ b/src/mlpack/core/optimizers/minibatch_sgd/minibatch_sgd_impl.hpp
@@ -38,7 +38,6 @@ double MiniBatchSGD<DecomposableFunctionType>::Optimize(arma::mat& iterate)
size_t numBatches = numFunctions / batchSize;
if (numFunctions % batchSize != 0)
++numBatches; // Capture last few.
- std::cout << "numBatches " << numBatches << ".\n";
// This is only used if shuffle is true.
arma::Col<size_t> visitationOrder;
@@ -63,7 +62,7 @@ double MiniBatchSGD<DecomposableFunctionType>::Optimize(arma::mat& iterate)
if ((currentBatch % numBatches) == 0)
{
// Output current objective function.
- std::cout << "Mini-batch SGD: iteration " << i << ", objective "
+ Log::Info << "Mini-batch SGD: iteration " << i << ", objective "
<< overallObjective << "." << std::endl;
if (std::isnan(overallObjective) || std::isinf(overallObjective))
@@ -94,19 +93,40 @@ double MiniBatchSGD<DecomposableFunctionType>::Optimize(arma::mat& iterate)
const size_t offset = (shuffle) ? batchSize * visitationOrder[currentBatch]
: batchSize * currentBatch;
function.Gradient(iterate, offset, gradient);
- for (size_t j = 1; j < batchSize; ++j)
+ if (visitationOrder[currentBatch] != numBatches - 1)
{
- arma::mat funcGradient;
- function.Gradient(iterate, offset + j, funcGradient);
- gradient += funcGradient;
+ for (size_t j = 1; j < batchSize; ++j)
+ {
+ arma::mat funcGradient;
+ function.Gradient(iterate, offset + j, funcGradient);
+ gradient += funcGradient;
+ }
+
+ // Now update the iterate.
+ iterate -= (stepSize / batchSize) * gradient;
+
+ // Add that to the overall objective function.
+ for (size_t j = 0; j < batchSize; ++j)
+ overallObjective += function.Evaluate(iterate, offset + j);
}
+ else
+ {
+ // Handle last batch differently: it's not a full-size batch.
+ const size_t lastBatchSize = numFunctions - offset - 1;
+ for (size_t j = 1; j < lastBatchSize; ++j)
+ {
+ arma::mat funcGradient;
+ function.Gradient(iterate, offset + j, funcGradient);
+ gradient += funcGradient;
+ }
- // Now update the iterate.
- iterate -= (stepSize / batchSize) * gradient;
+ // Now update the iterate.
+ iterate -= (stepSize / lastBatchSize) * gradient;
- // Add that to the overall objective function.
- for (size_t j = 0; j < batchSize; ++j)
- overallObjective += function.Evaluate(iterate, offset + j);
+ // Add that to the overall objective function.
+ for (size_t j = 0; j < lastBatchSize; ++j)
+ overallObjective += function.Evaluate(iterate, offset + j);
+ }
}
Log::Info << "Mini-batch SGD: maximum iterations (" << maxIterations << ") "
diff --git a/src/mlpack/tests/minibatch_sgd_test.cpp b/src/mlpack/tests/minibatch_sgd_test.cpp
index a05121a..972980b 100644
--- a/src/mlpack/tests/minibatch_sgd_test.cpp
+++ b/src/mlpack/tests/minibatch_sgd_test.cpp
@@ -10,6 +10,8 @@
#include <mlpack/core/optimizers/lbfgs/test_functions.hpp>
#include <mlpack/core/optimizers/sgd/test_function.hpp>
+#include <mlpack/methods/logistic_regression/logistic_regression.hpp>
+
#include <boost/test/unit_test.hpp>
#include "old_boost_test_definitions.hpp"
@@ -19,6 +21,9 @@ using namespace mlpack;
using namespace mlpack::optimization;
using namespace mlpack::optimization::test;
+using namespace mlpack::distribution;
+using namespace mlpack::regression;
+
BOOST_AUTO_TEST_SUITE(MiniBatchSGDTest);
/**
@@ -60,4 +65,70 @@ BOOST_AUTO_TEST_CASE(SimpleSGDTestFunction)
}
*/
+/**
+ * Run mini-batch SGD on logistic regression over two well-separated Gaussians
+ * and require near-perfect accuracy on the training data and a fresh test set.
+ */
+BOOST_AUTO_TEST_CASE(LogisticRegressionTest)
+{
+  // Generate a two-Gaussian dataset: 500 points per class (labels 0 and 1).
+  GaussianDistribution g1(arma::vec("1.0 1.0 1.0"), arma::eye<arma::mat>(3, 3));
+  GaussianDistribution g2(arma::vec("9.0 9.0 9.0"), arma::eye<arma::mat>(3, 3));
+
+  arma::mat data(3, 1000);
+  arma::Row<size_t> responses(1000);
+  for (size_t i = 0; i < 500; ++i)
+  {
+    data.col(i) = g1.Random();
+    responses[i] = 0;
+  }
+  for (size_t i = 500; i < 1000; ++i) // Start at 500 so no column is skipped.
+  {
+    data.col(i) = g2.Random();
+    responses[i] = 1;
+  }
+
+  // Shuffle the dataset so that each mini-batch sees a mix of both classes.
+  arma::uvec indices = arma::shuffle(arma::linspace<arma::uvec>(0,
+      data.n_cols - 1, data.n_cols));
+  arma::mat shuffledData(3, 1000);
+  arma::Row<size_t> shuffledResponses(1000);
+  for (size_t i = 0; i < data.n_cols; ++i)
+  {
+    shuffledData.col(i) = data.col(indices[i]);
+    shuffledResponses[i] = responses[indices[i]];
+  }
+
+  // Create a test set: fresh draws from the same two Gaussians.
+  arma::mat testData(3, 1000);
+  arma::Row<size_t> testResponses(1000);
+  for (size_t i = 0; i < 500; ++i)
+  {
+    testData.col(i) = g1.Random();
+    testResponses[i] = 0;
+  }
+  for (size_t i = 500; i < 1000; ++i) // Start at 500 so no column is skipped.
+  {
+    testData.col(i) = g2.Random();
+    testResponses[i] = 1;
+  }
+
+  // Now run mini-batch SGD with batch sizes 5, 10, ..., 45.
+  for (size_t batchSize = 5; batchSize < 50; batchSize += 5)
+  {
+    LogisticRegression<> lr(shuffledData.n_rows, 0.5);
+
+    LogisticRegressionFunction<> lrf(shuffledData, shuffledResponses, 0.5);
+    MiniBatchSGD<LogisticRegressionFunction<>> mbsgd(lrf, batchSize);
+    lr.Train(mbsgd);
+
+    // Ensure that the error on the (unshuffled) training data is close to zero.
+    const double acc = lr.ComputeAccuracy(data, responses);
+    BOOST_REQUIRE_CLOSE(acc, 100.0, 0.3); // 0.3% error tolerance.
+
+    const double testAcc = lr.ComputeAccuracy(testData, testResponses);
+    BOOST_REQUIRE_CLOSE(testAcc, 100.0, 0.6); // 0.6% error tolerance.
+  }
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git
mailing list