[mlpack-git] master: Allow mini-batch SGD optimizer. (6e8e873)

gitdub at mlpack.org
Mon Feb 22 13:32:56 EST 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/29dc196be362c8af44f36f331ac719d5a0e34acd...f3d692c0124e9667076b97318f6d64661015d368

>---------------------------------------------------------------

commit 6e8e87311b91eaf8b5fdc805726dc0639f083c90
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Feb 22 10:32:34 2016 -0800

    Allow mini-batch SGD optimizer.


>---------------------------------------------------------------

6e8e87311b91eaf8b5fdc805726dc0639f083c90
 src/mlpack/methods/nca/nca_main.cpp | 94 +++++++++++++++++++++++++------------
 1 file changed, 63 insertions(+), 31 deletions(-)

diff --git a/src/mlpack/methods/nca/nca_main.cpp b/src/mlpack/methods/nca/nca_main.cpp
index 0da3b56..de16626 100644
--- a/src/mlpack/methods/nca/nca_main.cpp
+++ b/src/mlpack/methods/nca/nca_main.cpp
@@ -10,6 +10,7 @@
 #include "nca.hpp"
 
 #include <mlpack/core/optimizers/lbfgs/lbfgs.hpp>
+#include <mlpack/core/optimizers/minibatch_sgd/minibatch_sgd.hpp>
 
 // Define parameters.
 PROGRAM_INFO("Neighborhood Components Analysis (NCA)",
@@ -19,18 +20,18 @@ PROGRAM_INFO("Neighborhood Components Analysis (NCA)",
     "by scaling the dimensions.  The method is nonparametric, and does not "
     "require a value of k.  It works by using stochastic (\"soft\") neighbor "
     "assignments and using optimization techniques over the gradient of the "
-    "accuracy of the neighbor assignments.\n"
-    "\n"
+    "accuracy of the neighbor assignments."
+    "\n\n"
     "To work, this algorithm needs labeled data.  It can be given as the last "
     "row of the input dataset (--input_file), or alternatively in a separate "
-    "file (--labels_file).\n"
-    "\n"
-    "This implementation of NCA uses either stochastic gradient descent or the "
-    "L_BFGS optimizer.  Both of these optimizers do not guarantee global "
-    "convergence for a nonconvex objective function (NCA's objective function "
-    "is nonconvex), so the final results could depend on the random seed or "
-    "other optimizer parameters.\n"
-    "\n"
+    "file (--labels_file)."
+    "\n\n"
+    "This implementation of NCA uses stochastic gradient descent, mini-batch "
+    "stochastic gradient descent, or the L_BFGS optimizer.  These optimizers do"
+    " not guarantee global convergence for a nonconvex objective function "
+    "(NCA's objective function is nonconvex), so the final results could depend"
+    " on the random seed or other optimizer parameters."
+    "\n\n"
     "Stochastic gradient descent, specified by --optimizer \"sgd\", depends "
     "primarily on two parameters: the step size (--step_size) and the maximum "
     "number of iterations (--max_iterations).  In addition, a normalized "
@@ -44,10 +45,18 @@ PROGRAM_INFO("Neighborhood Components Analysis (NCA)",
     "the maximum iterations to a large number and allow SGD to find a minimum, "
     "or set the maximum iterations to 0 (allowing infinite iterations) and set "
     "the tolerance (--tolerance) to define the maximum allowed difference "
-    "between objectives for SGD to terminate.  Be careful -- setting the "
+    "between objectives for SGD to terminate.  Be careful---setting the "
     "tolerance instead of the maximum iterations can take a very long time and "
-    "may actually never converge due to the properties of the SGD optimizer.\n"
-    "\n"
+    "may actually never converge due to the properties of the SGD optimizer. "
+    "Note that a single iteration of SGD refers to a single point, so to take "
+    "a single pass over the dataset, set --max_iterations equal to the number "
+    "of points in the dataset."
+    "\n\n"
+    "The mini-batch SGD optimizer, specified by --optimizer \"minibatch-sgd\", "
+    "has the same parameters as SGD, but the batch size may also be specified "
+    "with the --batch_size (-b) option.  Each iteration of mini-batch SGD "
+    "refers to a single mini-batch."
+    "\n\n"
     "The L-BFGS optimizer, specified by --optimizer \"lbfgs\", uses a "
     "back-tracking line search algorithm to minimize a function.  The "
     "following parameters are used by L-BFGS: --num_basis (specifies the number"
@@ -57,15 +66,16 @@ PROGRAM_INFO("Neighborhood Components Analysis (NCA)",
     "--max_step (which both refer to the line search routine).  For more "
     "details on the L-BFGS optimizer, consult either the mlpack L-BFGS "
     "documentation (in lbfgs.hpp) or the vast set of published literature on "
-    "L-BFGS.\n"
-    "\n"
+    "L-BFGS."
+    "\n\n"
     "By default, the SGD optimizer is used.");
 
 PARAM_STRING_REQ("input_file", "Input dataset to run NCA on.", "i");
 PARAM_STRING_REQ("output_file", "Output file for learned distance matrix.",
     "o");
 PARAM_STRING("labels_file", "File of labels for input dataset.", "l", "");
-PARAM_STRING("optimizer", "Optimizer to use; \"sgd\" or \"lbfgs\".", "O", "sgd");
+PARAM_STRING("optimizer", "Optimizer to use; 'sgd', 'minibatch-sgd', or "
+    "'lbfgs'.", "O", "sgd");
 
 PARAM_FLAG("normalize", "Use a normalized starting point for optimization. This"
     " is useful for when points are far apart, or when SGD is returning NaN.",
@@ -79,7 +89,8 @@ PARAM_DOUBLE("tolerance", "Maximum tolerance for termination of SGD or L-BFGS.",
 PARAM_DOUBLE("step_size", "Step size for stochastic gradient descent (alpha).",
     "a", 0.01);
 PARAM_FLAG("linear_scan", "Don't shuffle the order in which data points are "
-    "visited for SGD.", "L");
+    "visited for SGD or mini-batch SGD.", "L");
+PARAM_INT("batch_size", "Batch size for mini-batch SGD.", "b", 50);
 
 PARAM_INT("num_basis", "Number of memory points to be stored for L-BFGS.", "B",
     5);
@@ -92,7 +103,6 @@ PARAM_DOUBLE("max_step", "Maximum step of line search for L-BFGS.", "M", 1e20);
 
 PARAM_INT("seed", "Random seed.  If 0, 'std::time(NULL)' is used.", "s", 0);
 
-
 using namespace mlpack;
 using namespace mlpack::nca;
 using namespace mlpack::metric;
@@ -115,47 +125,56 @@ int main(int argc, char* argv[])
 
   const string optimizerType = CLI::GetParam<string>("optimizer");
 
-  if ((optimizerType != "sgd") && (optimizerType != "lbfgs"))
+  if ((optimizerType != "sgd") && (optimizerType != "lbfgs") &&
+      (optimizerType != "minibatch-sgd"))
   {
     Log::Fatal << "Optimizer type '" << optimizerType << "' unknown; must be "
-        << "'sgd' or 'lbfgs'!" << std::endl;
+        << "'sgd', 'minibatch-sgd', or 'lbfgs'!" << endl;
   }
 
   // Warn on unused parameters.
-  if (optimizerType == "sgd")
+  if (optimizerType == "sgd" || optimizerType == "minibatch-sgd")
   {
     if (CLI::HasParam("num_basis"))
       Log::Warn << "Parameter --num_basis ignored (not using 'lbfgs' "
-          << "optimizer)." << std::endl;
+          << "optimizer)." << endl;
 
     if (CLI::HasParam("armijo_constant"))
       Log::Warn << "Parameter --armijo_constant ignored (not using 'lbfgs' "
-          << "optimizer)." << std::endl;
+          << "optimizer)." << endl;
 
     if (CLI::HasParam("wolfe"))
       Log::Warn << "Parameter --wolfe ignored (not using 'lbfgs' optimizer).\n";
 
     if (CLI::HasParam("max_line_search_trials"))
       Log::Warn << "Parameter --max_line_search_trials ignored (not using "
-          << "'lbfgs' optimizer." << std::endl;
+          << "'lbfgs' optimizer." << endl;
 
     if (CLI::HasParam("min_step"))
       Log::Warn << "Parameter --min_step ignored (not using 'lbfgs' optimizer)."
-          << std::endl;
+          << endl;
 
     if (CLI::HasParam("max_step"))
       Log::Warn << "Parameter --max_step ignored (not using 'lbfgs' optimizer)."
-          << std::endl;
+          << endl;
+
+    if (optimizerType == "sgd" && CLI::HasParam("batch_size"))
+      Log::Warn << "Parameter --batch_size ignored (not using 'minibatch-sgd' "
+          << "optimizer." << endl;
   }
   else if (optimizerType == "lbfgs")
   {
     if (CLI::HasParam("step_size"))
-      Log::Warn << "Parameter --step_size ignored (not using 'sgd' optimizer)."
-          << std::endl;
+      Log::Warn << "Parameter --step_size ignored (not using 'sgd' or "
+          << "'minibatch-sgd' optimizer)." << endl;
 
     if (CLI::HasParam("linear_scan"))
-      Log::Warn << "Parameter --linear_scan ignored (not using 'sgd' "
-          << "optimizer)." << std::endl;
+      Log::Warn << "Parameter --linear_scan ignored (not using 'sgd' or "
+          << "'minibatch-sgd' optimizer)." << endl;
+
+    if (CLI::HasParam("batch_size"))
+      Log::Warn << "Parameter --batch_size ignored (not using 'minibatch-sgd' "
+          << "optimizer)." << endl;
   }
 
   const double stepSize = CLI::GetParam<double>("step_size");
@@ -169,6 +188,7 @@ int main(int argc, char* argv[])
   const int maxLineSearchTrials = CLI::GetParam<int>("max_line_search_trials");
   const double minStep = CLI::GetParam<double>("min_step");
   const double maxStep = CLI::GetParam<double>("max_step");
+  const size_t batchSize = (size_t) CLI::GetParam<int>("batch_size");
 
   // Load data.
   arma::mat data;
@@ -188,6 +208,7 @@ int main(int argc, char* argv[])
   }
   else
   {
+    Log::Info << "Using last column of input dataset as labels." << endl;
     for (size_t i = 0; i < data.n_cols; i++)
       rawLabels[i] = (int) data(data.n_rows - 1, i);
 
@@ -212,7 +233,7 @@ int main(int argc, char* argv[])
 
     distance = diagmat(1.0 / ranges);
     Log::Info << "Using normalized starting point for optimization."
-        << std::endl;
+        << endl;
   }
   else
   {
@@ -244,6 +265,17 @@ int main(int argc, char* argv[])
 
     nca.LearnDistance(distance);
   }
+  else if (optimizerType == "minibatch-sgd")
+  {
+    NCA<LMetric<2>, MiniBatchSGD> nca(data, labels);
+    nca.Optimizer().StepSize() = stepSize;
+    nca.Optimizer().MaxIterations() = maxIterations;
+    nca.Optimizer().Tolerance() = tolerance;
+    nca.Optimizer().Shuffle() = shuffle;
+    nca.Optimizer().BatchSize() = batchSize;
+
+    nca.LearnDistance(distance);
+  }
 
   // Save the output.
   data::Save(CLI::GetParam<string>("output_file"), distance, true);
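
For reference, the new optimizer can also be driven directly from C++ through
the same accessors this diff exercises.  Below is a minimal sketch, not part of
the commit: the filenames and the label-extraction step are placeholder
assumptions, and the parameter values are illustrative (0.01 and 50 are the
--step_size and --batch_size defaults visible above); the NCA/MiniBatchSGD
calls themselves mirror nca_main.cpp.

#include <mlpack/core.hpp>
#include <mlpack/methods/nca/nca.hpp>
#include <mlpack/core/optimizers/minibatch_sgd/minibatch_sgd.hpp>

using namespace mlpack;
using namespace mlpack::nca;
using namespace mlpack::metric;
using namespace mlpack::optimization;

int main()
{
  // Load a dataset; "dataset.csv" is a placeholder filename.
  arma::mat data;
  data::Load("dataset.csv", data, true);

  // Assume the labels sit in the last row of the loaded matrix, as with the
  // command-line program when no --labels_file is given.
  arma::Col<size_t> labels = arma::conv_to<arma::Col<size_t>>::from(
      data.row(data.n_rows - 1).t());
  data.shed_row(data.n_rows - 1);

  // Construct NCA with the mini-batch SGD optimizer and configure it through
  // the Optimizer() accessor, exactly as the diff does.
  NCA<LMetric<2>, MiniBatchSGD> nca(data, labels);
  nca.Optimizer().StepSize() = 0.01;       // --step_size default.
  nca.Optimizer().MaxIterations() = 100000; // Illustrative value.
  nca.Optimizer().Tolerance() = 1e-7;       // Illustrative value.
  nca.Optimizer().Shuffle() = true;
  nca.Optimizer().BatchSize() = 50;         // --batch_size default.

  // Learn the distance matrix, starting from the identity.
  arma::mat distance = arma::eye<arma::mat>(data.n_rows, data.n_rows);
  nca.LearnDistance(distance);

  data::Save("distance.csv", distance, true);
}

Note that each mini-batch SGD iteration consumes one batch rather than one
point, so a single pass over N points takes about N / batch_size iterations;
with plain SGD, --max_iterations should instead equal the number of points to
take one pass.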



