[mlpack-git] master: Reimplementing Adadelta with tests (88b9b85)

gitdub at mlpack.org gitdub at mlpack.org
Mon Mar 14 18:12:23 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/c0886a18f63c9335a0c39dcc34c27b8925dcb91b...b864df8cf10592b3874b079302774dbe7a4c1dbc

>---------------------------------------------------------------

commit 88b9b85143a7707a705dd928c49db2be5f313351
Author: vasanth kalingeri <vasanth.kalingeri at gmail.com>
Date:   Mon Mar 14 12:26:30 2016 +0530

    Reimplementing Adadelta with tests


>---------------------------------------------------------------

88b9b85143a7707a705dd928c49db2be5f313351
 .../{rmsprop => adadelta}/CMakeLists.txt           |   4 +-
 .../rmsprop.hpp => adadelta/ada_delta.hpp}         |  70 +++-----
 .../ada_delta_impl.hpp}                            |  54 ++++---
 src/mlpack/tests/ada_delta_test.cpp                | 177 +++++++++++++--------
 4 files changed, 167 insertions(+), 138 deletions(-)

diff --git a/src/mlpack/core/optimizers/rmsprop/CMakeLists.txt b/src/mlpack/core/optimizers/adadelta/CMakeLists.txt
similarity index 83%
copy from src/mlpack/core/optimizers/rmsprop/CMakeLists.txt
copy to src/mlpack/core/optimizers/adadelta/CMakeLists.txt
index 75c30c6..3cd516b 100644
--- a/src/mlpack/core/optimizers/rmsprop/CMakeLists.txt
+++ b/src/mlpack/core/optimizers/adadelta/CMakeLists.txt
@@ -1,6 +1,6 @@
 set(SOURCES
-  rmsprop.hpp
-  rmsprop_impl.hpp
+  ada_delta.hpp
+  ada_delta_impl.hpp
 )
 
 set(DIR_SRCS)
diff --git a/src/mlpack/core/optimizers/rmsprop/rmsprop.hpp b/src/mlpack/core/optimizers/adadelta/ada_delta.hpp
similarity index 69%
copy from src/mlpack/core/optimizers/rmsprop/rmsprop.hpp
copy to src/mlpack/core/optimizers/adadelta/ada_delta.hpp
index 690da6a..dbe5886 100644
--- a/src/mlpack/core/optimizers/rmsprop/rmsprop.hpp
+++ b/src/mlpack/core/optimizers/adadelta/ada_delta.hpp
@@ -1,13 +1,5 @@
-/**
- * @file rmsprop.hpp
- * @author Ryan Curtin
- * @author Marcus Edel
- *
- * RMSprop optimizer. RmsProp is an optimizer that utilizes the magnitude of
- * recent gradients to normalize the gradients.
- */
-#ifndef __MLPACK_CORE_OPTIMIZERS_RMSPROP_RMSPROP_HPP
-#define __MLPACK_CORE_OPTIMIZERS_RMSPROP_RMSPROP_HPP
+#ifndef __MLPACK_CORE_OPTIMIZERS_ADADELTA_ADA_DELTA_HPP
+#define __MLPACK_CORE_OPTIMIZERS_ADADELTA_ADA_DELTA_HPP
 
 #include <mlpack/core.hpp>
 
@@ -15,27 +7,25 @@ namespace mlpack {
 namespace optimization {
 
 /**
- * RMSprop is an optimizer that utilizes the magnitude of recent gradients to
- * normalize the gradients. In its basic form, given a step rate \f$ \gamma \f$
- * and a decay term \f$ \alpha \f$ we perform the following updates:
+ * Adadelta is an optimizer that uses two ideas to improve upon the two main
+ * drawbacks of the Adagrad method:
  *
- * \f{eqnarray*}{
- * r_t &=& (1 - \gamma) f'(\Delta_t)^2 + \gamma r_{t - 1} \\
- * v_{t + 1} &=& \frac{\alpha}{\sqrt{r_t}}f'(\Delta_t) \\
- * \Delta_{t + 1} &=& \Delta_t - v_{t + 1}
- * \f}
+ *  - Accumulate Over Window
+ *  - Correct Units with Hessian Approximation
  *
  * For more information, see the following.
  *
  * @code
- * @misc{tieleman2012,
- *   title={Lecture 6.5 - rmsprop, COURSERA: Neural Networks for Machine
- *   Learning},
- *   year={2012}
+ * @article{Zeiler2012,
+ *   author    = {Matthew D. Zeiler},
+ *   title     = {{ADADELTA:} An Adaptive Learning Rate Method},
+ *   journal   = {CoRR},
+ *   year      = {2012}
  * }
  * @endcode
  *
- * For RMSprop to work, a DecomposableFunctionType template parameter is
+ 
+ * For AdaDelta to work, a DecomposableFunctionType template parameter is
  * required. This class must implement the following function:
  *
  *   size_t NumFunctions();
@@ -56,11 +46,11 @@ namespace optimization {
  *     minimized.
  */
 template<typename DecomposableFunctionType>
-class RMSprop
+class AdaDelta
 {
  public:
   /**
-   * Construct the RMSprop optimizer with the given function and parameters. The
+   * Construct the AdaDelta optimizer with the given function and parameters. The
    * defaults here are not necessarily good for the given problem, so it is
    * suggested that the values used be tailored to the task at hand.  The
    * maximum number of iterations refers to the maximum number of points that
@@ -68,9 +58,7 @@ class RMSprop
    * equal one pass over the dataset).
    *
    * @param function Function to be optimized (minimized).
-   * @param stepSize Step size for each iteration.
-   * @param alpha Smoothing constant, similar to that used in AdaDelta and
-   *        momentum methods.
+   * @param rho Smoothing constant.
    * @param eps Value used to initialise the mean squared gradient parameter.
    * @param maxIterations Maximum number of iterations allowed (0 means no
    *        limit).
@@ -78,16 +66,15 @@ class RMSprop
    * @param shuffle If true, the function order is shuffled; otherwise, each
    *        function is visited in linear order.
    */
-  RMSprop(DecomposableFunctionType& function,
-      const double stepSize = 0.01,
-      const double alpha = 0.99,
-      const double eps = 1e-8,
+  AdaDelta(DecomposableFunctionType& function,
+      const double rho = 0.95,
+      const double eps = 1e-6,
       const size_t maxIterations = 100000,
       const double tolerance = 1e-5,
       const bool shuffle = true);
   
   /**
-   * Optimize the given function using RMSprop. The given starting point will be
+   * Optimize the given function using AdaDelta. The given starting point will be
    * modified to store the finishing point of the algorithm, and the final
    * objective value is returned.
    *
@@ -101,15 +88,10 @@ class RMSprop
   //! Modify the instantiated function.
   DecomposableFunctionType& Function() { return function; }
 
-  //! Get the step size.
-  double StepSize() const { return stepSize; }
-  //! Modify the step size.
-  double& StepSize() { return stepSize; }
-
   //! Get the smoothing parameter.
-  double Alpha() const { return alpha; }
+  double Rho() const { return rho; }
   //! Modify the smoothing parameter.
-  double& Alpha() { return alpha; }
+  double& Rho() { return rho; }
 
   //! Get the value used to initialise the mean squared gradient parameter.
   double Epsilon() const { return eps; }
@@ -135,11 +117,8 @@ class RMSprop
   //! The instantiated function.
   DecomposableFunctionType& function;
 
-  //! The step size for each example.
-  double stepSize;
-
   //! The smoothing parameter.
-  double alpha;
+  double rho;
 
   //! The value used to initialise the mean squared gradient parameter.
   double eps;
@@ -159,6 +138,7 @@ class RMSprop
 } // namespace mlpack
 
 // Include implementation.
-#include "rmsprop_impl.hpp"
+#include "ada_delta_impl.hpp"
 
 #endif
+
diff --git a/src/mlpack/core/optimizers/rmsprop/rmsprop_impl.hpp b/src/mlpack/core/optimizers/adadelta/ada_delta_impl.hpp
similarity index 71%
copy from src/mlpack/core/optimizers/rmsprop/rmsprop_impl.hpp
copy to src/mlpack/core/optimizers/adadelta/ada_delta_impl.hpp
index 539fa05..ac08a62 100644
--- a/src/mlpack/core/optimizers/rmsprop/rmsprop_impl.hpp
+++ b/src/mlpack/core/optimizers/adadelta/ada_delta_impl.hpp
@@ -1,30 +1,20 @@
-/**
- * @file rmsprop_impl.hpp
- * @author Ryan Curtin
- * @author Marcus Edel
- *
- * Implementation of the RMSprop optimizer.
- */
-#ifndef __MLPACK_CORE_OPTIMIZERS_RMSPROP_RMSPROP_IMPL_HPP
-#define __MLPACK_CORE_OPTIMIZERS_RMSPROP_RMSPROP_IMPL_HPP
-
-// In case it hasn't been included yet.
-#include "rmsprop.hpp"
+#ifndef __MLPACK_CORE_OPTIMIZERS_ADADELTA_ADA_DELTA_IMPL_HPP
+#define __MLPACK_CORE_OPTIMIZERS_ADADELTA_ADA_DELTA_IMPL_HPP
+
+#include "ada_delta.hpp"
 
 namespace mlpack {
 namespace optimization {
 
 template<typename DecomposableFunctionType>
-RMSprop<DecomposableFunctionType>::RMSprop(DecomposableFunctionType& function,
-                                           const double stepSize,
-                                           const double alpha,
+AdaDelta<DecomposableFunctionType>::AdaDelta(DecomposableFunctionType& function,
+                                           const double rho, 
                                            const double eps,
                                            const size_t maxIterations,
                                            const double tolerance,
                                            const bool shuffle) :
     function(function),
-    stepSize(stepSize),
-    alpha(alpha),
+    rho(rho),
     eps(eps),
     maxIterations(maxIterations),
     tolerance(tolerance),
@@ -33,7 +23,7 @@ RMSprop<DecomposableFunctionType>::RMSprop(DecomposableFunctionType& function,
 
 //! Optimize the function (minimize).
 template<typename DecomposableFunctionType>
-double RMSprop<DecomposableFunctionType>::Optimize(arma::mat& iterate)
+double AdaDelta<DecomposableFunctionType>::Optimize(arma::mat& iterate)
 {
   // Find the number of functions to use.
   const size_t numFunctions = function.NumFunctions();
@@ -60,18 +50,22 @@ double RMSprop<DecomposableFunctionType>::Optimize(arma::mat& iterate)
   arma::mat meanSquaredGradient = arma::zeros<arma::mat>(iterate.n_rows,
       iterate.n_cols);
 
+  // Leaky sum of squares of parameter updates.
+  arma::mat meanSquaredGradientDx = arma::zeros<arma::mat>(iterate.n_rows,
+      iterate.n_cols);
+      
   for (size_t i = 1; i != maxIterations; ++i, ++currentFunction)
   {
     // Is this iteration the start of a sequence?
     if ((currentFunction % numFunctions) == 0)
     {
       // Output current objective function.
-      Log::Info << "RMSprop: iteration " << i << ", objective "
+      Log::Info << "AdaDelta: iteration " << i << ", objective "
           << overallObjective << "." << std::endl;
 
       if (std::isnan(overallObjective) || std::isinf(overallObjective))
       {
-        Log::Warn << "RMSprop: converged to " << overallObjective
+        Log::Warn << "AdaDelta: converged to " << overallObjective
             << "; terminating with failure. Try a smaller step size?"
             << std::endl;
         return overallObjective;
@@ -79,7 +73,7 @@ double RMSprop<DecomposableFunctionType>::Optimize(arma::mat& iterate)
 
       if (std::abs(lastObjective - overallObjective) < tolerance)
       {
-        Log::Info << "RMSprop: minimized within tolerance " << tolerance << "; "
+        Log::Info << "AdaDelta: minimized within tolerance " << tolerance << "; "
             << "terminating optimization." << std::endl;
         return overallObjective;
       }
@@ -99,10 +93,18 @@ double RMSprop<DecomposableFunctionType>::Optimize(arma::mat& iterate)
     else
       function.Gradient(iterate, currentFunction, gradient);
       
-    // And update the iterate.
-    meanSquaredGradient *= alpha;
-    meanSquaredGradient += (1 - alpha) * (gradient % gradient);
-    iterate -= stepSize * gradient / (arma::sqrt(meanSquaredGradient) + eps);
+    // Accumulate gradient.
+    meanSquaredGradient *= rho;
+    meanSquaredGradient += (1 - rho) * (gradient % gradient);
+    arma::mat dx = arma::sqrt((meanSquaredGradientDx + eps) /
+        (meanSquaredGradient + eps)) % gradient;
+
+    // Accumulate updates.
+    meanSquaredGradientDx *= rho;
+    meanSquaredGradientDx += (1 - rho) * (dx % dx);
+
+    // Apply update.
+    iterate -= dx;
     
     // Now add that to the overall objective function.
     if (shuffle)
@@ -112,7 +114,7 @@ double RMSprop<DecomposableFunctionType>::Optimize(arma::mat& iterate)
       overallObjective += function.Evaluate(iterate, currentFunction);
   }
 
-  Log::Info << "RMSprop: maximum iterations (" << maxIterations << ") reached; "
+  Log::Info << "AdaDelta: maximum iterations (" << maxIterations << ") reached; "
       << "terminating optimization." << std::endl;
   // Calculate final objective.
   overallObjective = 0;
diff --git a/src/mlpack/tests/ada_delta_test.cpp b/src/mlpack/tests/ada_delta_test.cpp
index 961fede..01fdd63 100644
--- a/src/mlpack/tests/ada_delta_test.cpp
+++ b/src/mlpack/tests/ada_delta_test.cpp
@@ -2,28 +2,34 @@
  * @file ada_delta_test.cpp
  * @author Marcus Edel
  *
- * Tests the AdaDelta optimizer on a couple test models.
+ * Tests the AdaDelta optimizer.
  */
 #include <mlpack/core.hpp>
 
-#include <mlpack/methods/ann/activation_functions/logistic_function.hpp>
+#include <mlpack/core/optimizers/adadelta/ada_delta.hpp>
+#include <mlpack/core/optimizers/sgd/test_function.hpp>
 
-#include <mlpack/methods/ann/init_rules/random_init.hpp>
+#include <mlpack/methods/logistic_regression/logistic_regression.hpp>
 
+#include <mlpack/methods/ann/ffn.hpp>
+#include <mlpack/methods/ann/init_rules/random_init.hpp>
+#include <mlpack/methods/ann/performance_functions/mse_function.hpp>
+#include <mlpack/methods/ann/layer/binary_classification_layer.hpp>
 #include <mlpack/methods/ann/layer/bias_layer.hpp>
 #include <mlpack/methods/ann/layer/linear_layer.hpp>
 #include <mlpack/methods/ann/layer/base_layer.hpp>
-#include <mlpack/methods/ann/layer/one_hot_layer.hpp>
-
-#include <mlpack/methods/ann/trainer/trainer.hpp>
-#include <mlpack/methods/ann/ffn.hpp>
-#include <mlpack/methods/ann/performance_functions/mse_function.hpp>
-#include <mlpack/methods/ann/optimizer/ada_delta.hpp>
 
 #include <boost/test/unit_test.hpp>
 #include "old_boost_test_definitions.hpp"
 
+using namespace arma;
 using namespace mlpack;
+using namespace mlpack::optimization;
+using namespace mlpack::optimization::test;
+
+using namespace mlpack::distribution;
+using namespace mlpack::regression;
+
 using namespace mlpack::ann;
 
 BOOST_AUTO_TEST_SUITE(AdaDeltaTest);
@@ -33,81 +39,122 @@ BOOST_AUTO_TEST_SUITE(AdaDeltaTest);
  * iris data, the data set contains 3 classes. One class is linearly separable
  * from the other 2. The other two aren't linearly separable from each other.
  */
+
 BOOST_AUTO_TEST_CASE(SimpleAdaDeltaTestFunction)
 {
-  const size_t hiddenLayerSize = 10;
-  const size_t maxEpochs = 300;
+  SGDTestFunction f;
+  AdaDelta<SGDTestFunction> optimizer(f, 0.99, 1e-8, 5000000, 1e-9, true);
+
+  arma::mat coordinates = f.GetInitialPoint();
+  const double result = optimizer.Optimize(coordinates);
 
-  // Load the dataset.
-  arma::mat dataset, labels, labelsIdx;
-  data::Load("iris_train.csv", dataset, true);
-  data::Load("iris_train_labels.csv", labelsIdx, true);
+  BOOST_REQUIRE_LE(std::abs(result) - 1.0, 0.2);
+  BOOST_REQUIRE_SMALL(coordinates[0], 1e-3);
+  BOOST_REQUIRE_SMALL(coordinates[1], 1e-3);
+  BOOST_REQUIRE_SMALL(coordinates[2], 1e-3);
+}
 
-  // Create target matrix.
-  labels = arma::zeros<arma::mat>(labelsIdx.max() + 1, labelsIdx.n_cols);
-  for (size_t i = 0; i < labelsIdx.n_cols; i++)
-    labels(labelsIdx(0, i), i) = 1;
+/**
+ * Run AdaDelta on logistic regression and make sure the results are acceptable.
+ */
+BOOST_AUTO_TEST_CASE(LogisticRegressionTest)
+{
+  // Generate a two-Gaussian dataset.
+  GaussianDistribution g1(arma::vec("1.0 1.0 1.0"), arma::eye<arma::mat>(3, 3));
+  GaussianDistribution g2(arma::vec("9.0 9.0 9.0"), arma::eye<arma::mat>(3, 3));
+
+  arma::mat data(3, 1000);
+  arma::Row<size_t> responses(1000);
+  for (size_t i = 0; i < 500; ++i)
+  {
+    data.col(i) = g1.Random();
+    responses[i] = 0;
+  }
+  for (size_t i = 500; i < 1000; ++i)
+  {
+    data.col(i) = g2.Random();
+    responses[i] = 1;
+  }
 
-  // Construct a feed forward network using the specified parameters.
-  RandomInitialization randInit(0.1, 0.1);
+  // Shuffle the dataset.
+  arma::uvec indices = arma::shuffle(arma::linspace<arma::uvec>(0,
+      data.n_cols - 1, data.n_cols));
+  arma::mat shuffledData(3, 1000);
+  arma::Row<size_t> shuffledResponses(1000);
+  for (size_t i = 0; i < data.n_cols; ++i)
+  {
+    shuffledData.col(i) = data.col(indices[i]);
+    shuffledResponses[i] = responses[indices[i]];
+  }
 
-  LinearLayer<AdaDelta, RandomInitialization> inputLayer(dataset.n_rows,
-      hiddenLayerSize, randInit);
-  BiasLayer<AdaDelta, RandomInitialization> inputBiasLayer(hiddenLayerSize,
-      1, randInit);
-  BaseLayer<LogisticFunction> inputBaseLayer;
+  // Create a test set.
+  arma::mat testData(3, 1000);
+  arma::Row<size_t> testResponses(1000);
+  for (size_t i = 0; i < 500; ++i)
+  {
+    testData.col(i) = g1.Random();
+    testResponses[i] = 0;
+  }
+  for (size_t i = 500; i < 1000; ++i)
+  {
+    testData.col(i) = g2.Random();
+    testResponses[i] = 1;
+  }
 
-  LinearLayer<AdaDelta, RandomInitialization> hiddenLayer1(hiddenLayerSize,
-      labels.n_rows, randInit);
-  BiasLayer<AdaDelta, RandomInitialization> hiddenBiasLayer1(labels.n_rows,
-      1, randInit);
-  BaseLayer<LogisticFunction> outputLayer;
+  LogisticRegression<> lr(shuffledData.n_rows, 0.5);
 
-  OneHotLayer classOutputLayer;
+  LogisticRegressionFunction<> lrf(shuffledData, shuffledResponses, 0.5);
+  AdaDelta<LogisticRegressionFunction<> > AdaDelta(lrf);
+  lr.Train(AdaDelta);
 
-  auto modules = std::tie(inputLayer, inputBiasLayer, inputBaseLayer,
-                          hiddenLayer1, hiddenBiasLayer1, outputLayer);
+  // Ensure that the error is close to zero.
+  const double acc = lr.ComputeAccuracy(data, responses);
+  BOOST_REQUIRE_CLOSE(acc, 100.0, 0.3); // 0.3% error tolerance.
 
-  FFN<decltype(modules), OneHotLayer, MeanSquaredErrorFunction>
-      net(modules, classOutputLayer);
+  const double testAcc = lr.ComputeAccuracy(testData, testResponses);
+  BOOST_REQUIRE_CLOSE(testAcc, 100.0, 0.6); // 0.6% error tolerance.
+}
 
-  arma::mat prediction;
-  size_t error = 0;
+/**
+ * Run AdaDelta on a feedforward neural network and make sure the results are
+ * acceptable.
+ */
+BOOST_AUTO_TEST_CASE(FeedforwardTest)
+{
+  // Test on a non-linearly separable dataset (XOR).
+  arma::mat input, labels;
+  input << 0 << 1 << 1 << 0 << arma::endr
+        << 1 << 0 << 1 << 0 << arma::endr;
+  labels << 0 << 0 << 1 << 1;
 
-  // Evaluate the feed forward network.
-  for (size_t i = 0; i < dataset.n_cols; i++)
-  {
-    arma::mat input = dataset.unsafe_col(i);
-    net.Predict(input, prediction);
+  // Instantiate the first layer.
+  LinearLayer<> inputLayer(input.n_rows, 4);
+  BiasLayer<> biasLayer(4);
+  SigmoidLayer<> hiddenLayer0;
 
-    if (arma::sum(arma::sum(arma::abs(
-      prediction - labels.unsafe_col(i)))) == 0)
-      error++;
-  }
+  // Instantiate the second layer.
+  LinearLayer<> hiddenLayer1(4, labels.n_rows);
+  SigmoidLayer<> outputLayer;
 
-  // Check if the selected model isn't already optimized.
-  double classificationError = 1 - double(error) / dataset.n_cols;
-  BOOST_REQUIRE_GE(classificationError, 0.09);
+  // Instantiate the output layer.
+  BinaryClassificationLayer classOutputLayer;
 
-  // Train the feed forward network.
-  Trainer<decltype(net)> trainer(net, maxEpochs, 1, 0.01, false);
-  trainer.Train(dataset, labels, dataset, labels);
+  // Instantiate the feedforward network.
+  auto modules = std::tie(inputLayer, biasLayer, hiddenLayer0, hiddenLayer1,
+      outputLayer);
+  FFN<decltype(modules), decltype(classOutputLayer), RandomInitialization,
+      MeanSquaredErrorFunction> net(modules, classOutputLayer);
 
-  // Evaluate the feed forward network.
-  error = 0;
-  for (size_t i = 0; i < dataset.n_cols; i++)
-  {
-    arma::mat input = dataset.unsafe_col(i);
-    net.Predict(input, prediction);
+  AdaDelta<decltype(net)> opt(net, 0.88, 1e-15,
+      300 * input.n_cols, 1e-18);
 
-    if (arma::sum(arma::sum(arma::abs(
-      prediction - labels.unsafe_col(i)))) == 0)
-      error++;
-  }
+  net.Train(input, labels, opt);
 
-  classificationError = 1 - double(error) / dataset.n_cols;
+  arma::mat prediction;
+  net.Predict(input, prediction);
 
-  BOOST_REQUIRE_LE(classificationError, 0.09);
+  const bool b = arma::accu(prediction - labels) == 0;
+  BOOST_REQUIRE_EQUAL(b, true);
 }
 
 BOOST_AUTO_TEST_SUITE_END();




More information about the mlpack-git mailing list