[mlpack-git] master, mlpack-1.0.x: Initial commit of simulated annealing optimizer from Zhihao Lou. (b617669)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:50:20 EST 2015


Repository : https://github.com/mlpack/mlpack

On branches: master,mlpack-1.0.x
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit b6176694307cbe52924508965cbe9fa1e3a38b5c
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Jul 2 12:36:20 2014 +0000

    Initial commit of simulated annealing optimizer from Zhihao Lou.


>---------------------------------------------------------------

b6176694307cbe52924508965cbe9fa1e3a38b5c
 src/mlpack/core/optimizers/CMakeLists.txt          |   1 +
 .../core/optimizers/{lbfgs => sa}/CMakeLists.txt   |   9 +-
 .../core/optimizers/sa/exponential_schedule.hpp    |  52 +++++
 .../core/optimizers/sa/laplace_distribution.cpp    |  22 ++
 .../core/optimizers/sa/laplace_distribution.hpp    |  34 ++++
 src/mlpack/core/optimizers/sa/sa.hpp               | 175 ++++++++++++++++
 src/mlpack/core/optimizers/sa/sa_impl.hpp          | 225 +++++++++++++++++++++
 src/mlpack/tests/CMakeLists.txt                    |   1 +
 src/mlpack/tests/sa_test.cpp                       |  46 +++++
 9 files changed, 561 insertions(+), 4 deletions(-)

diff --git a/src/mlpack/core/optimizers/CMakeLists.txt b/src/mlpack/core/optimizers/CMakeLists.txt
index 34032b0..1243586 100644
--- a/src/mlpack/core/optimizers/CMakeLists.txt
+++ b/src/mlpack/core/optimizers/CMakeLists.txt
@@ -2,6 +2,7 @@ set(DIRS
   aug_lagrangian
   lbfgs
   lrsdp
+  sa
   sgd
 )
 
diff --git a/src/mlpack/core/optimizers/lbfgs/CMakeLists.txt b/src/mlpack/core/optimizers/sa/CMakeLists.txt
similarity index 64%
copy from src/mlpack/core/optimizers/lbfgs/CMakeLists.txt
copy to src/mlpack/core/optimizers/sa/CMakeLists.txt
index a41d21d..088abec 100644
--- a/src/mlpack/core/optimizers/lbfgs/CMakeLists.txt
+++ b/src/mlpack/core/optimizers/sa/CMakeLists.txt
@@ -1,8 +1,9 @@
 set(SOURCES
-  lbfgs_impl.hpp
-  lbfgs.hpp
-  test_functions.hpp
-  test_functions.cpp
+  sa.hpp
+  sa_impl.hpp
+  laplace_distribution.hpp
+  laplace_distribution.cpp
+  exponential_schedule.hpp
 )
 
 set(DIR_SRCS)
diff --git a/src/mlpack/core/optimizers/sa/exponential_schedule.hpp b/src/mlpack/core/optimizers/sa/exponential_schedule.hpp
new file mode 100644
index 0000000..c3d6019
--- /dev/null
+++ b/src/mlpack/core/optimizers/sa/exponential_schedule.hpp
@@ -0,0 +1,52 @@
+/*
+ * @file exponential_schedule.hpp
+ * @author Zhihao Lou
+ *
+ * Exponential (geometric) cooling schedule used in SA
+ */
+
+#ifndef __MLPACK_CORE_OPTIMIZERS_SA_EXPONENTIAL_SCHEDULE_HPP
+#define __MLPACK_CORE_OPTIMIZERS_SA_EXPONENTIAL_SCHEDULE_HPP
+
+namespace mlpack {
+namespace optimization {
+
+/* 
+ * The exponential cooling schedule cools the temperature T at every step
+ * \f[
+ * T_{n+1}=(1-\lambda)T_{n}
+ * \f]
+ * where \f$ 0<\lambda<1 \f$ is the cooling speed. The smaller \f$ \lambda \f$
+ * is, the slower the cooling speed, and the better the final result will be. Some
+ * literature uses \f$ \alpha=(1-\lambda) \f$ instead. In practice, \f$ \alpha \f$
+ * is very close to 1 and will be awkward to input (e.g. alpha=0.999999 vs
+ * lambda=1e-6).
+ */
+class ExponentialSchedule // Geometric cooling: every step multiplies T by (1 - lambda).
+{
+ public:
+  /*
+   * Construct the ExponentialSchedule with the given cooling speed.
+   * The argument is stored as-is; no range check is performed here.
+   * @param lambda Cooling speed; each step scales the temperature by (1 - lambda).
+   */
+  ExponentialSchedule(const double lambda = 0.001) : lambda(lambda){};
+  // The second, unnamed argument of nextTemperature() is ignored by this schedule.
+  //! Returns the next temperature, (1 - lambda) * currentTemperate.
+  double nextTemperature(const double currentTemperate, const double)
+  {return (1-lambda) * currentTemperate;}
+
+  //! Get the cooling speed lambda.
+  double Lambda() const {return lambda;}
+  //! Modify the cooling speed lambda (returns a mutable reference).
+  double& Lambda() {return lambda;}
+ private:
+  double lambda; //!< Cooling speed used by nextTemperature().
+
+
+};
+
+}; // namespace optimization
+}; // namespace mlpack
+
+#endif
diff --git a/src/mlpack/core/optimizers/sa/laplace_distribution.cpp b/src/mlpack/core/optimizers/sa/laplace_distribution.cpp
new file mode 100644
index 0000000..69d96cb
--- /dev/null
+++ b/src/mlpack/core/optimizers/sa/laplace_distribution.cpp
@@ -0,0 +1,22 @@
+/*
+ * @file laplace_distribution.cpp
+ * @author Zhihao Lou
+ *
+ * Implementation of Laplace distribution
+ */
+
+#include <mlpack/core.hpp>
+#include "laplace_distribution.hpp"
+using namespace mlpack;
+using namespace mlpack::optimization;
+double LaplaceDistribution::operator () (const double param) // Draw one zero-mean sample; param is the scale.
+{
+  // Rescale one math::Random() draw so that it is centered on zero.
+  double unif = 2.0 * math::Random() - 1.0; // NOTE(review): lies in [-1, 1) only if math::Random() is in [0, 1) - confirm.
+  // Transform the centered draw into a Laplace variate with mean 0:
+  // x = - param * sign(unif) * log(1 - |unif|)
+  if (unif < 0) // Here sign(unif) = -1 and |unif| = -unif, so the formula becomes param * log(1 + unif).
+      return (param * std::log(1 + unif)); // NOTE(review): evaluates to -infinity if unif can be exactly -1.
+  else
+      return (-1.0 * param * std::log(1 - unif)); // Non-negative branch: sign(unif) = +1, |unif| = unif.
+}
diff --git a/src/mlpack/core/optimizers/sa/laplace_distribution.hpp b/src/mlpack/core/optimizers/sa/laplace_distribution.hpp
new file mode 100644
index 0000000..15a52a3
--- /dev/null
+++ b/src/mlpack/core/optimizers/sa/laplace_distribution.hpp
@@ -0,0 +1,34 @@
+/*
+ * @file laplace_distribution.hpp
+ * @author Zhihao Lou
+ *
+ * Laplace (double exponential) distribution used in SA
+ */
+
+#ifndef __MLPACK_CORE_OPTIMIZER_SA_LAPLACE_DISTRIBUTION_HPP
+#define __MLPACK_CORE_OPTIMIZER_SA_LAPLACE_DISTRIBUTION_HPP
+
+namespace mlpack {
+namespace optimization {
+
+/* 
+ * The Laplace distribution centered at 0 has pdf
+ * \f[
+ * f(x|\theta) = \frac{1}{2\theta}\exp\left(-\frac{|x|}{\theta}\right)
+ * \f]
+ * given scale parameter \f$\theta\f$.
+ */
+class LaplaceDistribution // Move-generation distribution; declares no data members.
+{
+ public:
+  //! Nothing to do for the constructor: there is no state to initialize.
+  LaplaceDistribution(){}
+  //! Return a random value from the zero-centered Laplace distribution with scale param.
+  double operator () (const double param);
+  // operator() is only declared here; it is defined in laplace_distribution.cpp.
+};
+
+}; // namespace optimization
+}; // namespace mlpack
+
+#endif
diff --git a/src/mlpack/core/optimizers/sa/sa.hpp b/src/mlpack/core/optimizers/sa/sa.hpp
new file mode 100644
index 0000000..902e380
--- /dev/null
+++ b/src/mlpack/core/optimizers/sa/sa.hpp
@@ -0,0 +1,175 @@
+/*
+ * @file sa.hpp
+ * @author Zhihao Lou
+ *
+ * Simulated Annealing (SA).
+ */
+#ifndef __MLPACK_CORE_OPTIMIZERS_SA_SA_HPP
+#define __MLPACK_CORE_OPTIMIZERS_SA_SA_HPP
+
+namespace mlpack {
+namespace optimization {
+
+/**
+ * Simulated Annealing is a stochastic optimization algorithm which is able to
+ * deliver near-optimal results quickly without knowing the gradient of the
+ * function being optimized. It has unique hill climbing capability that makes
+ * it less vulnerable to local minima.  This implementation uses exponential
+ * cooling schedule and feedback move control by default, but the cooling
+ * schedule can be changed via a template parameter.
+ *
+ * The algorithm keeps the temperature at initial temperature for initMove
+ * steps to get rid of the dependency of initial condition. After that, it
+ * cools every step until the system is considered frozen or maxIterations is
+ * reached.
+ *
+ * At each step, SA perturbs only one parameter. One pass in which SA has
+ * perturbed every parameter of the problem is called a sweep. Every moveCtrlSweep
+ * sweeps, the algorithm applies feedback move control to change the average move
+ * size depending on the responsiveness of each parameter. The parameter gain
+ * controls the proportion of the feedback control.
+ *
+ * The system is considered "frozen" when its score fails to change by more than
+ * tolerance for maxToleranceSweep consecutive sweeps.
+ *
+ * For SA to work, a function must implement the following methods:
+ *   double Evaluate(const arma::mat& coordinates);
+ *   arma::mat& GetInitialPoint();
+ *
+ * In addition, a move generation distribution with overloaded operator():
+ *   double operator () (const double param);
+ * which returns a random value from the distribution given parameter param,
+ * and a cooling schedule with method:
+ *   double nextTemperature(const double currentTemperature, const double currentValue);
+ * which returns the next temperature given current temperature and the value
+ * of the function being optimized.
+ *
+ * @tparam FunctionType objective function type to be minimized.
+ * @tparam MoveDistributionType distribution type for move generation
+ * @tparam CoolingScheduleType type for cooling schedule
+ */
+template<typename FunctionType, typename MoveDistributionType, typename CoolingScheduleType>
+class SA
+{
+ public:
+  /*
+   * Construct the SA optimizer with the given function and parameters.
+   * The function, move distribution and cooling schedule are held by reference.
+   * @param function Function to be minimized.
+   * @param moveDistribution Distribution for move generation
+   * @param coolingSchedule Cooling schedule
+   * @param initT Initial temperature.
+   * @param initMoves Iterations without changing temperature.
+   * @param moveCtrlSweep Sweeps per move control.
+   * @param tolerance Tolerance to consider system frozen.
+   * @param maxToleranceSweep Maximum sweeps below tolerance to consider system frozen.
+   * @param maxMoveCoef Maximum move size.
+   * @param initMoveCoef Initial move size.
+   * @param gain Proportional control in feedback move control.
+   * @param maxIterations Maximum number of iterations allowed (0 indicates no limit).
+   */
+  SA(FunctionType& function,
+     MoveDistributionType& moveDistribution,
+     CoolingScheduleType& coolingSchedule,
+     const double initT = 10000.,
+     const size_t initMoves = 1000,
+     const size_t moveCtrlSweep = 100,
+     const double tolerance = 1e-5,
+     const size_t maxToleranceSweep = 3,
+     const double maxMoveCoef = 20,
+     const double initMoveCoef = 0.3,
+     const double gain = 0.3,
+     const size_t maxIterations = 1000000);
+  /*
+   * Optimize the given function using simulated annealing. The given starting
+   * point will be modified to store the finishing point of the algorithm, and
+   * the final objective value is returned.
+   *
+   * @param iterate Starting point (will be modified).
+   * @return Objective value of the final point.
+   */
+  double Optimize(arma::mat& iterate);
+
+  //! Get the instantiated function to be optimized.
+  const FunctionType& Function() const {return function;}
+  //! Modify the instantiated function.
+  FunctionType& Function() {return function;}
+
+  //! Get the temperature (the current value, which Optimize() updates while cooling).
+  double Temperature() const {return T;}
+  //! Modify the temperature.
+  double& Temperature() {return T;}
+
+  //! Get the initial moves.
+  size_t InitMoves() const {return initMoves;}
+  //! Modify the initial moves.
+  size_t& InitMoves() {return initMoves;}
+
+  //! Get sweeps per move control.
+  size_t MoveCtrlSweep() const {return moveCtrlSweep;}
+  //! Modify sweeps per move control.
+  size_t& MoveCtrlSweep() {return moveCtrlSweep;}
+
+  //! Get the tolerance.
+  double Tolerance() const {return tolerance;}
+  //! Modify the tolerance.
+  double& Tolerance() {return tolerance;}
+
+  //! Get the maxToleranceSweep.
+  size_t MaxToleranceSweep() const {return maxToleranceSweep;}
+  //! Modify the maxToleranceSweep.
+  size_t& MaxToleranceSweep() {return maxToleranceSweep;}
+
+  //! Get the gain.
+  double Gain() const {return gain;}
+  //! Modify the gain.
+  double& Gain() {return gain;}
+
+  //! Get the maxIterations.
+  size_t MaxIterations() const {return maxIterations;}
+  //! Modify the maxIterations.
+  size_t& MaxIterations() {return maxIterations;}
+
+  //! Get the maximum move size of each parameter (returned by value, i.e. a copy).
+  arma::mat MaxMove() const {return maxMove;}
+  //! Modify the maximum move size of each parameter.
+  arma::mat& MaxMove() {return maxMove;}
+
+  //! Get the move size of each parameter (returned by value, i.e. a copy).
+  arma::mat MoveSize() const {return moveSize;}
+  //! Modify the move size of each parameter.
+  arma::mat& MoveSize() {return moveSize;}
+  //! Return a textual description of the optimizer and its settings.
+  std::string ToString() const;
+ private:
+  FunctionType &function; // Objective being minimized (held by reference).
+  MoveDistributionType &moveDistribution; // Called with a move size to draw a step.
+  CoolingScheduleType &coolingSchedule; // Supplies nextTemperature().
+  double T; // Current temperature; starts at initT.
+  size_t initMoves; // Moves made before any cooling takes place.
+  size_t moveCtrlSweep; // Number of sweeps between MoveControl() calls.
+  double tolerance; // Energy change below which a move counts towards "frozen".
+  size_t maxToleranceSweep; // Sweeps below tolerance needed to stop.
+  double gain; // Proportional gain used by MoveControl().
+  size_t maxIterations; // Bound on the cooling loop in Optimize().
+  arma::mat maxMove; // Per-parameter upper bound on moveSize (filled with maxMoveCoef).
+  arma::mat moveSize; // Per-parameter move size (filled with initMoveCoef, adapted later).
+
+
+  // The following bookkeeping is reset at the start of every Optimize() call.
+  arma::mat accept; // Accepted-move count per parameter since the last MoveControl().
+  double energy; // Objective value at the current iterate.
+  size_t idx; // Linear index of the next parameter to perturb.
+  size_t nVars; // Number of parameters (rows * cols of the initial point).
+  size_t sweepCounter; // Completed sweeps since the last MoveControl().
+
+  void GenerateMove(arma::mat& iterate); // Propose one single-parameter move and accept or reject it.
+  void MoveControl(size_t nMoves); // Adapt moveSize from the acceptance counts over nMoves moves.
+};
+
+}; // namespace optimization
+}; // namespace mlpack
+
+#include "sa_impl.hpp"
+
+#endif
diff --git a/src/mlpack/core/optimizers/sa/sa_impl.hpp b/src/mlpack/core/optimizers/sa/sa_impl.hpp
new file mode 100644
index 0000000..d72fcc9
--- /dev/null
+++ b/src/mlpack/core/optimizers/sa/sa_impl.hpp
@@ -0,0 +1,225 @@
+/*
+ * @file sa_impl.hpp
+ * @author Zhihao Lou
+ *
+ * The implementation of the SA optimizer.
+ */
+#ifndef __MLPACK_CORE_OPTIMIZERS_SA_SA_IMPL_HPP
+#define __MLPACK_CORE_OPTIMIZERS_SA_SA_IMPL_HPP
+
+namespace mlpack {
+namespace optimization {
+
+template<
+    typename FunctionType,
+    typename MoveDistributionType,
+    typename CoolingScheduleType
+>
+SA<FunctionType, MoveDistributionType, CoolingScheduleType>::SA( // Store the settings and size the per-parameter matrices.
+    FunctionType& function,
+    MoveDistributionType& moveDistribution,
+    CoolingScheduleType& coolingSchedule,
+    const double initT,
+    const size_t initMoves,
+    const size_t moveCtrlSweep,
+    const double tolerance,
+    const size_t maxToleranceSweep,
+    const double maxMoveCoef,
+    const double initMoveCoef,
+    const double gain,
+    const size_t maxIterations) :
+    function(function), // Bound by reference, as are the next two members.
+    moveDistribution(moveDistribution),
+    coolingSchedule(coolingSchedule),
+    T(initT), // The temperature starts at initT.
+    initMoves(initMoves),
+    moveCtrlSweep(moveCtrlSweep),
+    tolerance(tolerance),
+    maxToleranceSweep(maxToleranceSweep),
+    gain(gain),
+    maxIterations(maxIterations)
+{
+  const size_t rows = function.GetInitialPoint().n_rows; // The matrices below take the shape
+  const size_t cols = function.GetInitialPoint().n_cols; // of the function's initial point.
+  // maxMoveCoef and initMoveCoef are not stored; they only seed these matrices.
+  maxMove.set_size(rows, cols);
+  maxMove.fill(maxMoveCoef);
+  moveSize.set_size(rows, cols);
+  moveSize.fill(initMoveCoef);
+  accept.zeros(rows, cols); // energy, idx, nVars and sweepCounter are not set here; Optimize() assigns them.
+}
+
+//! Minimize the function by simulated annealing; iterate is updated in place and the final objective is returned.
+template<
+    typename FunctionType,
+    typename MoveDistributionType,
+    typename CoolingScheduleType
+>
+double SA<FunctionType, MoveDistributionType, CoolingScheduleType>::Optimize(
+    arma::mat &iterate)
+{
+  const size_t rows = function.GetInitialPoint().n_rows;
+  const size_t cols = function.GetInitialPoint().n_cols;
+  // NOTE(review): the size is taken from the initial point; iterate is assumed to have the same shape.
+  size_t i;
+  size_t frozenCount = 0; // Consecutive moves whose energy change stayed below tolerance.
+  energy = function.Evaluate(iterate);
+  double oldEnergy = energy; // Kept as double: an integer copy would truncate the objective value.
+  math::RandomSeed(std::time(NULL)); // Reseeds mlpack's generator from the wall clock on every call.
+  // Reset the bookkeeping shared with GenerateMove() and MoveControl().
+  nVars = rows * cols;
+  idx = 0;
+  sweepCounter = 0;
+  accept.zeros();
+
+  // Initial moves at constant temperature, to get rid of dependency on the initial state.
+  for (i = 0; i < initMoves; ++i)
+    GenerateMove(iterate);
+
+  // Iterating and cooling; maxIterations == 0 means no iteration limit.
+  for (i = 0; maxIterations == 0 || i != maxIterations; ++i)
+  {
+    oldEnergy = energy;
+    GenerateMove(iterate);
+    T = coolingSchedule.nextTemperature(T, energy);
+
+    // Determine if the optimization has entered (or continues to be in) a
+    // frozen state.
+    if (std::abs(energy - oldEnergy) < tolerance)
+      ++frozenCount;
+    else
+      frozenCount = 0;
+
+    // Terminate, if possible: one sweep is nVars moves, hence the product.
+    if (frozenCount >= maxToleranceSweep * nVars)
+    {
+      Log::Debug << "SA: minimized within tolerance " << tolerance << " for "
+          << maxToleranceSweep << " sweeps after " << i << " iterations; "
+          << "terminating optimization." << std::endl;
+      return energy;
+    }
+  }
+
+  Log::Debug << "SA: maximum iterations (" << maxIterations << ") reached; "
+      << "terminating optimization." << std::endl;
+  return energy;
+}
+
+/**
+ * GenerateMove proposes a move on element iterate(idx), and determines
+ * if that move is acceptable or not according to the Metropolis criterion.
+ * After that it increments idx so next call will make a move on next
+ * parameters. When all elements of the state has been moved (a sweep), it
+ * resets idx and increments sweepCounter. When sweepCounter reaches
+ * moveCtrlSweep, it performs moveControl and resets sweepCounter.
+ */
+template<
+    typename FunctionType,
+    typename MoveDistributionType,
+    typename CoolingScheduleType
+>
+void SA<FunctionType, MoveDistributionType, CoolingScheduleType>::GenerateMove(
+    arma::mat& iterate)
+{
+  double prevEnergy = energy; // Saved so that a rejected move can be undone.
+  double prevValue = iterate(idx);
+  double move = moveDistribution(moveSize(idx)); // Step drawn using this parameter's current move size.
+  iterate(idx) += move;
+  energy = function.Evaluate(iterate); // Objective at the proposed point.
+  // According to Metropolis criterion, accept the move with probability
+  // min{1, exp(-(E_new - E_old) / T)}.
+  double xi = math::Random(); // Random draw compared against the acceptance criterion.
+  double delta = energy - prevEnergy;
+  double criterion = std::exp(-delta / T);
+  if (delta <= 0. || criterion > xi) // Moves that do not increase the energy are always accepted.
+  {
+    accept(idx) += 1.; // Acceptance count later consumed by MoveControl().
+  }
+  else // Reject the move; restore previous state.
+  {
+    iterate(idx) = prevValue;
+    energy = prevEnergy;
+  }
+  // Advance to the next parameter, wrapping around at the end of a sweep.
+  ++idx;
+  if (idx == nVars) // Finished with a sweep.
+  {
+    idx = 0;
+    ++sweepCounter;
+  }
+  // This check runs on every call, not only right after a sweep completes.
+  if (sweepCounter == moveCtrlSweep) // Do MoveControl().
+  {
+    MoveControl(moveCtrlSweep); // Each parameter was moved once per sweep, i.e. moveCtrlSweep times.
+    sweepCounter = 0;
+  }
+}
+
+/*
+ * MoveControl() uses a proportional feedback control to determine the size
+ * parameter to pass to the move generation distribution. The target of such
+ * move control is to make the acceptance ratio, accept/nMoves, be as close to
+ * 0.44 as possible. Generally speaking, the larger the move size is, the larger
+ * the function value change of the move will be, and less likely such move will
+ * be accepted by the Metropolis criterion. Thus, the move size is controlled by
+ *
+ * log(moveSize) = log(moveSize) + gain * (accept/nMoves - target)
+ *
+ * For more theory and the mysterious 0.44 value, see Jimmy K.-C. Lam and
+ * Jean-Marc Delosme. `An efficient simulated annealing schedule: derivation'.
+ * Technical Report 8816, Yale University, 1988
+ */
+template<
+    typename FunctionType,
+    typename MoveDistributionType,
+    typename CoolingScheduleType
+>
+void SA<FunctionType, MoveDistributionType, CoolingScheduleType>::MoveControl(
+    size_t nMoves)
+{
+  arma::mat target; // Target acceptance ratio, one entry per parameter.
+  target.copy_size(accept);
+  target.fill(0.44);
+  moveSize = arma::log(moveSize); // The update is applied in log space, then mapped back.
+  moveSize += gain * (accept / (double) nMoves - target); // accept / nMoves is the observed acceptance ratio.
+  moveSize = arma::exp(moveSize);
+  // Clamp every move size to its entry in maxMove.
+  // To avoid the use of element-wise arma::min(), which is only available in
+  // Armadillo after v3.930, we use a for loop here instead.
+  for (size_t i = 0; i < nVars; ++i)
+    moveSize(i) = (moveSize(i) > maxMove(i)) ? maxMove(i) : moveSize(i);
+  // Start counting acceptances afresh for the next control period.
+  accept.zeros();
+}
+
+template<
+    typename FunctionType,
+    typename MoveDistributionType,
+    typename CoolingScheduleType
+>
+std::string SA<FunctionType, MoveDistributionType, CoolingScheduleType>::
+ToString() const // Build a multi-line textual description of the optimizer and return it.
+{
+  std::ostringstream convert;
+  convert << "SA [" << this << "]" << std::endl; // The header line carries this object's address.
+  convert << "  Function:" << std::endl;
+  convert << util::Indent(function.ToString(), 2); // Requires FunctionType to provide ToString().
+  convert << "  Move Distribution:" << std::endl;
+  convert << util::Indent(moveDistribution.ToString(), 2); // NOTE(review): LaplaceDistribution in this commit declares no ToString().
+  convert << "  Cooling Schedule:" << std::endl;
+  convert << util::Indent(coolingSchedule.ToString(), 2); // NOTE(review): ExponentialSchedule in this commit declares no ToString() either.
+  convert << "  Temperature: " << T << std::endl; // Current temperature, not necessarily the initial one.
+  convert << "  Initial moves: " << initMoves << std::endl;
+  convert << "  Sweeps per move control: " << moveCtrlSweep << std::endl;
+  convert << "  Tolerance: " << tolerance << std::endl;
+  convert << "  Maximum sweeps below tolerance: " << maxToleranceSweep
+      << std::endl;
+  convert << "  Move control gain: " << gain << std::endl;
+  convert << "  Maximum iterations: " << maxIterations << std::endl; // maxMove and moveSize are not printed.
+  return convert.str();
+}
+
+}; // namespace optimization
+}; // namespace mlpack
+
+#endif
diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt
index c49b70f..3175331 100644
--- a/src/mlpack/tests/CMakeLists.txt
+++ b/src/mlpack/tests/CMakeLists.txt
@@ -39,6 +39,7 @@ add_executable(mlpack_test
   radical_test.cpp
   range_search_test.cpp
   rectangle_tree_test.cpp
+  sa_test.cpp
   save_restore_utility_test.cpp
   sgd_test.cpp
   sort_policy_test.cpp
diff --git a/src/mlpack/tests/sa_test.cpp b/src/mlpack/tests/sa_test.cpp
new file mode 100644
index 0000000..97304b9
--- /dev/null
+++ b/src/mlpack/tests/sa_test.cpp
@@ -0,0 +1,46 @@
+/*
+ * @file sa_test.cpp
+ * @author Zhihao Lou
+ *
+ * Test file for SA (simulated annealing).
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/core/optimizers/sa/sa.hpp>
+#include <mlpack/core/optimizers/sa/exponential_schedule.hpp>
+#include <mlpack/core/optimizers/sa/laplace_distribution.hpp>
+#include <mlpack/core/optimizers/lbfgs/test_functions.hpp>
+
+#include <mlpack/core/metrics/ip_metric.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
+#include <mlpack/core/metrics/mahalanobis_distance.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace std;
+using namespace arma;
+using namespace mlpack;
+using namespace mlpack::optimization;
+using namespace mlpack::optimization::test;
+using namespace mlpack::metric;
+
+BOOST_AUTO_TEST_SUITE(SATest);
+
+BOOST_AUTO_TEST_CASE(GeneralizedRosenbrockTest) // Runs SA on a 50-dimensional GeneralizedRosenbrockFunction.
+{
+  size_t dim = 50;
+  GeneralizedRosenbrockFunction f(dim);
+  // Constructor arguments after the schedule, in order: initT, initMoves, moveCtrlSweep, tolerance,
+  LaplaceDistribution moveDist; // maxToleranceSweep, maxMoveCoef, initMoveCoef, gain, maxIterations.
+  ExponentialSchedule schedule(1e-5); // Cooling speed lambda = 1e-5.
+  SA<GeneralizedRosenbrockFunction, LaplaceDistribution, ExponentialSchedule> 
+      sa(f, moveDist, schedule, 1000.,1000, 100, 1e-9, 3, 20, 0.3, 0.3, 10000000);
+  arma::mat coordinates = f.GetInitialPoint(); // Start from the function's own initial point.
+  double result = sa.Optimize(coordinates);
+  // The final objective must be near zero and every coordinate near 1.
+  BOOST_REQUIRE_SMALL(result, 1e-6);
+  for (size_t j = 0; j < dim; ++j)
+      BOOST_REQUIRE_CLOSE(coordinates[j], (double) 1.0, 1e-2); // NOTE(review): Boost.Test reads this tolerance as a percentage - confirm.
+}
+
+BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list