[mlpack-git] master: Add minibatch SGD implementation. (2c476c2)
gitdub at mlpack.org
gitdub at mlpack.org
Fri Feb 12 17:46:57 EST 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/ec35c2d4bc564e1431abf7c9aa14737c1d40328b...46e0a233c29ae638ba60eb224826168516c0ec4e
>---------------------------------------------------------------
commit 2c476c24c2a81e95d3a974d1932d6a3d66874476
Author: Ryan Curtin <ryan at ratml.org>
Date: Fri Feb 12 14:46:57 2016 -0800
Add minibatch SGD implementation.
>---------------------------------------------------------------
2c476c24c2a81e95d3a974d1932d6a3d66874476
src/mlpack/core/optimizers/CMakeLists.txt | 1 +
.../{sa => minibatch_sgd}/CMakeLists.txt | 5 +-
.../optimizers/minibatch_sgd/minibatch_sgd.hpp | 162 +++++++++++++++++++++
.../minibatch_sgd/minibatch_sgd_impl.hpp | 126 ++++++++++++++++
4 files changed, 291 insertions(+), 3 deletions(-)
diff --git a/src/mlpack/core/optimizers/CMakeLists.txt b/src/mlpack/core/optimizers/CMakeLists.txt
index 25a0c8a..13731a6 100644
--- a/src/mlpack/core/optimizers/CMakeLists.txt
+++ b/src/mlpack/core/optimizers/CMakeLists.txt
@@ -1,6 +1,7 @@
set(DIRS
aug_lagrangian
lbfgs
+ minibatch_sgd
sa
sdp
sgd
diff --git a/src/mlpack/core/optimizers/sa/CMakeLists.txt b/src/mlpack/core/optimizers/minibatch_sgd/CMakeLists.txt
similarity index 79%
copy from src/mlpack/core/optimizers/sa/CMakeLists.txt
copy to src/mlpack/core/optimizers/minibatch_sgd/CMakeLists.txt
index 7ee164a..e88c3ed 100644
--- a/src/mlpack/core/optimizers/sa/CMakeLists.txt
+++ b/src/mlpack/core/optimizers/minibatch_sgd/CMakeLists.txt
@@ -1,7 +1,6 @@
set(SOURCES
- sa.hpp
- sa_impl.hpp
- exponential_schedule.hpp
+ minibatch_sgd.hpp
+ minibatch_sgd_impl.hpp
)
set(DIR_SRCS)
diff --git a/src/mlpack/core/optimizers/minibatch_sgd/minibatch_sgd.hpp b/src/mlpack/core/optimizers/minibatch_sgd/minibatch_sgd.hpp
new file mode 100644
index 0000000..910a405
--- /dev/null
+++ b/src/mlpack/core/optimizers/minibatch_sgd/minibatch_sgd.hpp
@@ -0,0 +1,162 @@
+/**
+ * @file minibatch_sgd.hpp
+ * @author Ryan Curtin
+ *
+ * Mini-batch Stochastic Gradient Descent (SGD).
+ */
+#ifndef __MLPACK_CORE_OPTIMIZERS_MINIBATCH_SGD_MINIBATCH_SGD_HPP
+#define __MLPACK_CORE_OPTIMIZERS_MINIBATCH_SGD_MINIBATCH_SGD_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace optimization {
+
+/**
+ * Mini-batch Stochastic Gradient Descent is a technique for minimizing a
+ * function which can be expressed as a sum of other functions.  That is,
+ * suppose we have
+ *
+ * \f[
+ * f(A) = \sum_{i = 0}^{n - 1} f_i(A)
+ * \f]
+ *
+ * and our task is to minimize \f$ f(A) \f$ with respect to \f$ A \f$.
+ * Mini-batch SGD iterates over batches of functions
+ * \f$ \{ f_{i0}(A), f_{i1}(A), \ldots, f_{i(m - 1)}(A) \} \f$ for some batch
+ * size \f$ m \f$, producing the following update scheme:
+ *
+ * \f[
+ * A_{j + 1} = A_j - \frac{\alpha}{m} \left( \sum_{k = 0}^{m - 1}
+ *     \nabla f_{ik}(A_j) \right)
+ * \f]
+ *
+ * where \f$ \alpha \f$ is a parameter which specifies the step size; that is,
+ * each update moves against the average gradient of the mini-batch.  The
+ * mini-batches are visited either sequentially or in a random order that is
+ * reshuffled before every pass over the data.  The algorithm continues until
+ * \f$ j \f$ reaches the maximum number of iterations---or until a full pass of
+ * updates through all of the mini-batches changes the objective by less than
+ * a tolerance \f$ \epsilon \f$.
+ *
+ * The parameter \f$ \epsilon \f$ is specified by the tolerance parameter to the
+ * constructor, as is the maximum number of iterations specified by the
+ * maxIterations parameter.
+ *
+ * This class is useful for data-dependent functions whose objective function
+ * can be expressed as a sum of objective functions operating on an individual
+ * point.  Then, mini-batch SGD considers the gradient of the objective function
+ * operating on an individual mini-batch of points in its update of \f$ A \f$.
+ *
+ * For mini-batch SGD to work, a DecomposableFunctionType template parameter is
+ * required.  This class must implement the following functions:
+ *
+ *   size_t NumFunctions();
+ *   double Evaluate(const arma::mat& coordinates, const size_t i);
+ *   void Gradient(const arma::mat& coordinates,
+ *                 const size_t i,
+ *                 arma::mat& gradient);
+ *
+ * NumFunctions() should return the number of functions, and in the other two
+ * functions, the parameter i refers to which individual function (or gradient)
+ * is being evaluated.  So, for the case of a data-dependent function, such as
+ * NCA (see mlpack::nca::NCA), NumFunctions() should return the number of points
+ * in the dataset, and Evaluate(coordinates, 0) will evaluate the objective
+ * function on the first point in the dataset (presumably, the dataset is held
+ * internally in the DecomposableFunctionType).  Gradient() is expected to
+ * overwrite its output argument completely: the optimizer does not zero the
+ * gradient matrix before passing it in.
+ *
+ * @tparam DecomposableFunctionType Decomposable objective function type to be
+ *     minimized.
+ */
+template<typename DecomposableFunctionType>
+class MiniBatchSGD
+{
+ public:
+  /**
+   * Construct the MiniBatchSGD optimizer with the given function and
+   * parameters.  The defaults here are not necessarily good for the given
+   * problem, so it is suggested that the values used be tailored for the task
+   * at hand.  The maximum number of iterations refers to the maximum number of
+   * mini-batches that are processed.
+   *
+   * @param function Function to be optimized (minimized).
+   * @param batchSize Size of each mini-batch.
+   * @param stepSize Step size for each iteration.
+   * @param maxIterations Maximum number of iterations allowed (0 means no
+   *     limit).
+   * @param tolerance Maximum absolute tolerance to terminate algorithm.
+   * @param shuffle If true, the mini-batch order is shuffled; otherwise, each
+   *     mini-batch is visited in linear order.
+   */
+  MiniBatchSGD(DecomposableFunctionType& function,
+               const size_t batchSize = 1000,
+               const double stepSize = 0.01,
+               const size_t maxIterations = 100000,
+               const double tolerance = 1e-5,
+               const bool shuffle = true);
+
+  /**
+   * Optimize the given function using mini-batch SGD.  The given starting point
+   * will be modified to store the finishing point of the algorithm, and the
+   * final objective value is returned.
+   *
+   * @param iterate Starting point (will be modified).
+   * @return Objective value of the final point.
+   */
+  double Optimize(arma::mat& iterate);
+
+  //! Get the instantiated function to be optimized.
+  const DecomposableFunctionType& Function() const { return function; }
+  //! Modify the instantiated function.
+  DecomposableFunctionType& Function() { return function; }
+
+  //! Get the batch size.
+  size_t BatchSize() const { return batchSize; }
+  //! Modify the batch size.
+  size_t& BatchSize() { return batchSize; }
+
+  //! Get the step size.
+  double StepSize() const { return stepSize; }
+  //! Modify the step size.
+  double& StepSize() { return stepSize; }
+
+  //! Get the maximum number of iterations (0 indicates no limit).
+  size_t MaxIterations() const { return maxIterations; }
+  //! Modify the maximum number of iterations (0 indicates no limit).
+  size_t& MaxIterations() { return maxIterations; }
+
+  //! Get the tolerance for termination.
+  double Tolerance() const { return tolerance; }
+  //! Modify the tolerance for termination.
+  double& Tolerance() { return tolerance; }
+
+  //! Get whether or not the individual functions are shuffled.
+  bool Shuffle() const { return shuffle; }
+  //! Modify whether or not the individual functions are shuffled.
+  bool& Shuffle() { return shuffle; }
+
+ private:
+  //! The instantiated function.
+  DecomposableFunctionType& function;
+
+  //! The size of each mini-batch.
+  size_t batchSize;
+
+  //! The step size used for each mini-batch update.
+  double stepSize;
+
+  //! The maximum number of allowed iterations.
+  size_t maxIterations;
+
+  //! The tolerance for termination.
+  double tolerance;
+
+  //! Controls whether or not the mini-batch visitation order is shuffled when
+  //! iterating.
+  bool shuffle;
+};
+
+} // namespace optimization
+} // namespace mlpack
+
+// Include implementation.
+#include "minibatch_sgd_impl.hpp"
+
+#endif
diff --git a/src/mlpack/core/optimizers/minibatch_sgd/minibatch_sgd_impl.hpp b/src/mlpack/core/optimizers/minibatch_sgd/minibatch_sgd_impl.hpp
new file mode 100644
index 0000000..cdcd744
--- /dev/null
+++ b/src/mlpack/core/optimizers/minibatch_sgd/minibatch_sgd_impl.hpp
@@ -0,0 +1,126 @@
+/**
+ * @file minibatch_sgd_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of mini-batch SGD.
+ */
+#ifndef __MLPACK_CORE_OPTIMIZERS_MINIBATCH_SGD_MINIBATCH_SGD_IMPL_HPP
+#define __MLPACK_CORE_OPTIMIZERS_MINIBATCH_SGD_MINIBATCH_SGD_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "minibatch_sgd.hpp"
+
+namespace mlpack {
+namespace optimization {
+
+/**
+ * Bind the optimizer to the given function and record the optimization
+ * parameters.  No other work happens at construction time; everything is
+ * deferred to Optimize().
+ */
+template<typename DecomposableFunctionType>
+MiniBatchSGD<DecomposableFunctionType>::MiniBatchSGD(
+    DecomposableFunctionType& function,
+    const size_t batchSize,
+    const double stepSize,
+    const size_t maxIterations,
+    const double tolerance,
+    const bool shuffle)
+  : function(function)
+  , batchSize(batchSize)
+  , stepSize(stepSize)
+  , maxIterations(maxIterations)
+  , tolerance(tolerance)
+  , shuffle(shuffle)
+{
+  // Every member is fully initialized by the list above.
+}
+
+/**
+ * Optimize the function (minimize) with mini-batch SGD.
+ *
+ * The functions 0 .. NumFunctions() - 1 are split into consecutive mini-batches
+ * of batchSize functions; when NumFunctions() is not a multiple of batchSize the
+ * final mini-batch is smaller and only covers the remaining functions.  Each
+ * update steps against the average gradient of the mini-batch being visited.
+ *
+ * @param iterate Starting point; overwritten with the final point.
+ * @return Objective value at the end of optimization (the full-sum objective
+ *     when the iteration limit is hit, otherwise the objective accumulated
+ *     over the last completed pass).
+ */
+template<typename DecomposableFunctionType>
+double MiniBatchSGD<DecomposableFunctionType>::Optimize(arma::mat& iterate)
+{
+  // A zero batch size would make the batch-count computation below divide by
+  // zero, so reject it up front.
+  if (batchSize == 0)
+  {
+    Log::Fatal << "MiniBatchSGD::Optimize(): batch size must be positive!"
+        << std::endl;
+  }
+
+  // Find the number of functions.  With no functions at all there is nothing
+  // to optimize (and numBatches would be zero, breaking the modulo below); the
+  // objective of an empty sum is 0 and the iterate is left untouched.
+  const size_t numFunctions = function.NumFunctions();
+  if (numFunctions == 0)
+  {
+    Log::Warn << "Mini-batch SGD: function has no terms; nothing to optimize."
+        << std::endl;
+    return 0.0;
+  }
+
+  size_t numBatches = numFunctions / batchSize;
+  if (numFunctions % batchSize != 0)
+    ++numBatches; // Capture last few (this final batch is smaller).
+  Log::Info << "Mini-batch SGD: " << numBatches << " mini-batches." << std::endl;
+
+  // This is only used if shuffle is true.
+  arma::Col<size_t> visitationOrder;
+  if (shuffle)
+    visitationOrder = arma::shuffle(arma::linspace<arma::Col<size_t>>(0,
+        (numBatches - 1), numBatches));
+
+  // To keep track of where we are and how things are going.
+  size_t currentBatch = 0;
+  double overallObjective = 0;
+  double lastObjective = DBL_MAX;
+
+  // Calculate the first objective function.
+  for (size_t i = 0; i < numFunctions; ++i)
+    overallObjective += function.Evaluate(iterate, i);
+
+  // Now iterate!  Both gradient buffers are reused across mini-batches instead
+  // of being reallocated for every function.
+  arma::mat gradient(iterate.n_rows, iterate.n_cols);
+  arma::mat funcGradient;
+  for (size_t i = 1; i != maxIterations; ++i, ++currentBatch)
+  {
+    // Is this iteration the start of a sequence?
+    if ((currentBatch % numBatches) == 0)
+    {
+      // Output current objective function.
+      Log::Info << "Mini-batch SGD: iteration " << i << ", objective "
+          << overallObjective << "." << std::endl;
+
+      if (std::isnan(overallObjective) || std::isinf(overallObjective))
+      {
+        Log::Warn << "Mini-batch SGD: converged to " << overallObjective
+            << "; terminating with failure.  Try a smaller step size?"
+            << std::endl;
+        return overallObjective;
+      }
+
+      if (std::abs(lastObjective - overallObjective) < tolerance)
+      {
+        Log::Info << "Mini-batch SGD: minimized within tolerance " << tolerance
+            << "; terminating optimization." << std::endl;
+        return overallObjective;
+      }
+
+      // Reset the counter variables.
+      lastObjective = overallObjective;
+      overallObjective = 0;
+      currentBatch = 0;
+
+      if (shuffle)
+        visitationOrder = arma::shuffle(visitationOrder);
+    }
+
+    // Locate this mini-batch.  The batch starting at 'offset' may be the final,
+    // partial one, so clamp its size to avoid indexing past the last function.
+    const size_t batchIndex = (shuffle) ? visitationOrder[currentBatch]
+                                        : currentBatch;
+    const size_t offset = batchSize * batchIndex;
+    const size_t currentBatchSize = (offset + batchSize > numFunctions)
+        ? (numFunctions - offset) : batchSize;
+
+    // Evaluate the gradient for this mini-batch.
+    function.Gradient(iterate, offset, gradient);
+    for (size_t j = 1; j < currentBatchSize; ++j)
+    {
+      function.Gradient(iterate, offset + j, funcGradient);
+      gradient += funcGradient;
+    }
+
+    // Now update the iterate using the average gradient over the functions
+    // actually in this mini-batch.
+    iterate -= (stepSize / currentBatchSize) * gradient;
+
+    // Add that to the overall objective function.
+    for (size_t j = 0; j < currentBatchSize; ++j)
+      overallObjective += function.Evaluate(iterate, offset + j);
+  }
+
+  Log::Info << "Mini-batch SGD: maximum iterations (" << maxIterations << ") "
+      << "reached; terminating optimization." << std::endl;
+
+  // Calculate final objective.
+  overallObjective = 0;
+  for (size_t i = 0; i < numFunctions; ++i)
+    overallObjective += function.Evaluate(iterate, i);
+
+  return overallObjective;
+}
+
+} // namespace optimization
+} // namespace mlpack
+
+#endif
More information about the mlpack-git
mailing list