[mlpack-svn] r10472 - mlpack/trunk/src/mlpack/core/optimizers/lbfgs

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Nov 30 15:42:29 EST 2011


Author: rcurtin
Date: 2011-11-30 15:42:29 -0500 (Wed, 30 Nov 2011)
New Revision: 10472

Modified:
   mlpack/trunk/src/mlpack/core/optimizers/lbfgs/lbfgs.hpp
   mlpack/trunk/src/mlpack/core/optimizers/lbfgs/lbfgs_impl.hpp
   mlpack/trunk/src/mlpack/core/optimizers/lbfgs/test_functions.cpp
   mlpack/trunk/src/mlpack/core/optimizers/lbfgs/test_functions.hpp
Log:
Revamp L-BFGS code to remove CLI from its insides.  Naming convention fixes,
const fixes, the works.


Modified: mlpack/trunk/src/mlpack/core/optimizers/lbfgs/lbfgs.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/optimizers/lbfgs/lbfgs.hpp	2011-11-30 19:52:15 UTC (rev 10471)
+++ mlpack/trunk/src/mlpack/core/optimizers/lbfgs/lbfgs.hpp	2011-11-30 20:42:29 UTC (rev 10472)
@@ -37,10 +37,26 @@
    * Initialize the L-BFGS object.  Copy the function we will be optimizing
    * and set the size of the memory for the algorithm.
    *
-   * @param function_in Instance of function to be optimized
-   * @param num_basis Number of memory points to be stored
+   * @param function Instance of function to be optimized
+   * @param numBasis Number of memory points to be stored
+   * @param armijoConstant Controls the accuracy of the line search routine for
+   *     determining the Armijo condition.
+   * @param wolfe Parameter for detecting the Wolfe condition.
+   * @param minGradientNorm Minimum gradient norm required to continue the
+   *     optimization.
+   * @param maxLineSearchTrials The maximum number of trials for the line search
+   *     (before giving up).
+   * @param minStep The minimum step of the line search.
+   * @param maxStep The maximum step of the line search.
    */
-  L_BFGS(FunctionType& function_in, int num_basis);
+  L_BFGS(const FunctionType& function,
+         const size_t numBasis,
+         const double armijoConstant = 1e-4,
+         const double wolfe = 0.9,
+         const double minGradientNorm = 1e-10,
+         const size_t maxLineSearchTrials = 50,
+         const double minStep = 1e-20,
+         const double maxStep = 1e20);
 
   /**
    * Return the point where the lowest function value has been found.
@@ -48,7 +64,7 @@
    * @return arma::vec representing the point and a double with the function
    *     value at that point.
    */
-  const std::pair<arma::mat, double>& min_point_iterate() const;
+  const std::pair<arma::mat, double>& MinPointIterate() const;
 
   /**
    * Use L-BFGS to optimize the given function, starting at the given iterate
@@ -56,27 +72,39 @@
    * iterations.  The given starting point will be modified to store the
    * finishing point of the algorithm.
    *
-   * @param num_iterations Maximum number of iterations to perform
+   * @param maxIterations Maximum number of iterations to perform
    * @param iterate Starting point (will be modified)
    */
-  bool Optimize(int num_iterations, arma::mat& iterate);
+  bool Optimize(const size_t maxIterations, arma::mat& iterate);
 
  private:
   //! Internal copy of the function we are optimizing.
-  FunctionType function_;
+  FunctionType function;
 
   //! Position of the new iterate.
-  arma::mat new_iterate_tmp_;
+  arma::mat newIterateTmp;
   //! Stores all the s matrices in memory.
-  arma::cube s_lbfgs_;
+  arma::cube s;
   //! Stores all the y matrices in memory.
-  arma::cube y_lbfgs_;
+  arma::cube y;
 
   //! Size of memory for this L-BFGS optimizer.
-  int num_basis_;
+  size_t numBasis;
+  //! Parameter for determining the Armijo condition.
+  double armijoConstant;
+  //! Parameter for detecting the Wolfe condition.
+  double wolfe;
+  //! Minimum gradient norm required to continue the optimization.
+  double minGradientNorm;
+  //! Maximum number of trials for the line search.
+  size_t maxLineSearchTrials;
+  //! Minimum step of the line search.
+  double minStep;
+  //! Maximum step of the line search.
+  double maxStep;
 
   //! Best point found so far.
-  std::pair<arma::mat, double> min_point_iterate_;
+  std::pair<arma::mat, double> minPointIterate;
 
   /**
    * Evaluate the function at the given iterate point and store the result if it
@@ -84,7 +112,7 @@
    *
    * @return The value of the function
    */
-  double Evaluate_(const arma::mat& iterate);
+  double Evaluate(const arma::mat& iterate);
 
   /**
    * Calculate the scaling factor gamma which is used to scale the Hessian
@@ -93,34 +121,35 @@
    *
    * @return The calculated scaling factor
    */
-  double ChooseScalingFactor_(int iteration_num, const arma::mat& gradient);
+  double ChooseScalingFactor(const size_t iterationNum,
+                             const arma::mat& gradient);
 
   /**
    * Check to make sure that the norm of the gradient is not smaller than 1e-5.
    * Currently that value is not configurable.
    *
-   * @return (norm < 1e-5)
+   * @return (norm < minGradientNorm)
    */
-  bool GradientNormTooSmall_(const arma::mat& gradient);
+  bool GradientNormTooSmall(const arma::mat& gradient);
 
   /**
    * Perform a back-tracking line search along the search direction to
    * calculate a step size satisfying the Wolfe conditions.  The parameter
    * iterate will be modified if the method is successful.
    *
-   * @param function_value Value of the function at the initial point
+   * @param functionValue Value of the function at the initial point
    * @param iterate The initial point to begin the line search from
    * @param gradient The gradient at the initial point
-   * @param search_direction A vector specifying the search direction
-   * @param step_size Variable the calculated step size will be stored in
+   * @param searchDirection A vector specifying the search direction
+   * @param stepSize Variable the calculated step size will be stored in
    *
    * @return false if no step size is suitable, true otherwise.
    */
-  bool LineSearch_(double& function_value,
-                   arma::mat& iterate,
-                   arma::mat& gradient,
-                   const arma::mat& search_direction,
-                   double& step_size);
+  bool LineSearch(double& functionValue,
+                  arma::mat& iterate,
+                  arma::mat& gradient,
+                  const arma::mat& searchDirection,
+                  double& stepSize);
 
   /**
    * Find the L-BFGS search direction.
@@ -130,27 +159,27 @@
    * @param scaling_factor Scaling factor to use (see ChooseScalingFactor_())
    * @param search_direction Vector to store search direction in
    */
-  void SearchDirection_(const arma::mat& gradient,
-                        int iteration_num,
-                        double scaling_factor,
-                        arma::mat& search_direction);
+  void SearchDirection(const arma::mat& gradient,
+                       const size_t iterationNum,
+                       const double scalingFactor,
+                       arma::mat& searchDirection);
 
   /**
-   * Update the vectors y_bfgs_ and s_bfgs_, which store the differences
+   * Update the s and y matrices, which store the differences
    * between the iterate and old iterate and the differences between the
    * gradient and the old gradient, respectively.
    *
-   * @param iteration_num Iteration number
+   * @param iterationNum Iteration number
    * @param iterate Current point
-   * @param old_iterate Point at last iteration
+   * @param oldIterate Point at last iteration
    * @param gradient Gradient at current point (iterate)
-   * @param old_gradient Gradient at last iteration point (old_iterate)
+   * @param oldGradient Gradient at last iteration point (oldIterate)
    */
-  void UpdateBasisSet_(int iteration_num,
-                       const arma::mat& iterate,
-                       const arma::mat& old_iterate,
-                       const arma::mat& gradient,
-                       const arma::mat& old_gradient);
+  void UpdateBasisSet(const size_t iterationNum,
+                      const arma::mat& iterate,
+                      const arma::mat& oldIterate,
+                      const arma::mat& gradient,
+                      const arma::mat& oldGradient);
 };
 
 }; // namespace optimization

Modified: mlpack/trunk/src/mlpack/core/optimizers/lbfgs/lbfgs_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/optimizers/lbfgs/lbfgs_impl.hpp	2011-11-30 19:52:15 UTC (rev 10471)
+++ mlpack/trunk/src/mlpack/core/optimizers/lbfgs/lbfgs_impl.hpp	2011-11-30 20:42:29 UTC (rev 10472)
@@ -18,19 +18,19 @@
  * @return The value of the function
  */
 template<typename FunctionType>
-double L_BFGS<FunctionType>::Evaluate_(const arma::mat& iterate)
+double L_BFGS<FunctionType>::Evaluate(const arma::mat& iterate)
 {
   // Evaluate the function and keep track of the minimum function
   // value encountered during the optimization.
-  double function_value = function_.Evaluate(iterate);
+  double functionValue = function.Evaluate(iterate);
 
-  if (function_value < min_point_iterate_.second)
+  if (functionValue < minPointIterate.second)
   {
-    min_point_iterate_.first = iterate;
-    min_point_iterate_.second = function_value;
+    minPointIterate.first = iterate;
+    minPointIterate.second = functionValue;
   }
 
-  return function_value;
+  return functionValue;
 }
 
 /**
@@ -40,115 +40,111 @@
  * @return The calculated scaling factor
  */
 template<typename FunctionType>
-double L_BFGS<FunctionType>::ChooseScalingFactor_(int iteration_num,
-                                                  const arma::mat& gradient)
+double L_BFGS<FunctionType>::ChooseScalingFactor(const size_t iterationNum,
+                                                 const arma::mat& gradient)
 {
-  double scaling_factor = 1.0;
-  if (iteration_num > 0)
+  double scalingFactor = 1.0;
+  if (iterationNum > 0)
   {
-    int previous_pos = (iteration_num - 1) % num_basis_;
+    int previousPos = (iterationNum - 1) % numBasis;
     // Get s and y matrices once instead of multiple times.
-    arma::mat& s_col = s_lbfgs_.slice(previous_pos);
-    arma::mat& y_col = y_lbfgs_.slice(previous_pos);
-    scaling_factor = dot(s_col, y_col) / dot(y_col, y_col);
+    arma::mat& sMat = s.slice(previousPos);
+    arma::mat& yMat = y.slice(previousPos);
+    scalingFactor = dot(sMat, yMat) / dot(yMat, yMat);
   }
   else
   {
-    scaling_factor = 1.0 / sqrt(dot(gradient, gradient));
+    scalingFactor = 1.0 / sqrt(dot(gradient, gradient));
   }
 
-  return scaling_factor;
+  return scalingFactor;
 }
 
 /**
  * Check to make sure that the norm of the gradient is not smaller than 1e-10.
  * Currently that value is not configurable.
  *
- * @return (norm < lbfgs/min_gradient_norm)
+ * @return (norm < minGradientNorm)
  */
 template<typename FunctionType>
-bool L_BFGS<FunctionType>::GradientNormTooSmall_(const arma::mat& gradient)
+bool L_BFGS<FunctionType>::GradientNormTooSmall(const arma::mat& gradient)
 {
   double norm = arma::norm(gradient, 2);
 
-  return (norm < CLI::GetParam<double>("lbfgs/min_gradient_norm"));
+  return (norm < minGradientNorm);
 }
 
 /**
  * Perform a back-tracking line search along the search direction to calculate a
  * step size satisfying the Wolfe conditions.
  *
- * @param function_value Value of the function at the initial point
+ * @param functionValue Value of the function at the initial point
  * @param iterate The initial point to begin the line search from
  * @param gradient The gradient at the initial point
- * @param search_direction A vector specifying the search direction
- * @param step_size Variable the calculated step size will be stored in
+ * @param searchDirection A vector specifying the search direction
+ * @param stepSize Variable the calculated step size will be stored in
  *
  * @return false if no step size is suitable, true otherwise.
  */
 template<typename FunctionType>
-bool L_BFGS<FunctionType>::LineSearch_(double& function_value,
-                                       arma::mat& iterate,
-                                       arma::mat& gradient,
-                                       const arma::mat& search_direction,
-                                       double& step_size)
+bool L_BFGS<FunctionType>::LineSearch(double& functionValue,
+                                      arma::mat& iterate,
+                                      arma::mat& gradient,
+                                      const arma::mat& searchDirection,
+                                      double& stepSize)
 {
   // The initial linear term approximation in the direction of the
   // search direction.
-  double initial_search_direction_dot_gradient =
-      arma::dot(gradient, search_direction);
+  double initialSearchDirectionDotGradient =
+      arma::dot(gradient, searchDirection);
 
   // If it is not a descent direction, just report failure.
-  if (initial_search_direction_dot_gradient > 0.0)
+  if (initialSearchDirectionDotGradient > 0.0)
     return false;
 
   // Save the initial function value.
-  double initial_function_value = function_value;
+  double initialFunctionValue = functionValue;
 
   // Unit linear approximation to the decrease in function value.
-  double linear_approx_function_value_decrease =
-      CLI::GetParam<double>("lbfgs/armijo_constant") *
-      initial_search_direction_dot_gradient;
+  double linearApproxFunctionValueDecrease = armijoConstant *
+      initialSearchDirectionDotGradient;
 
   // The number of iteration in the search.
-  int num_iterations = 0;
+  size_t numIterations = 0;
 
   // Armijo step size scaling factor for increase and decrease.
   const double inc = 2.1;
   const double dec = 0.5;
   double width = 0;
 
-  while(true)
+  while (true)
   {
     // Perform a step and evaluate the gradient and the function values at that
     // point.
-    new_iterate_tmp_ = iterate;
-    new_iterate_tmp_ += step_size * search_direction;
-    function_value = Evaluate_(new_iterate_tmp_);
-    function_.Gradient(new_iterate_tmp_, gradient);
-    num_iterations++;
+    newIterateTmp = iterate;
+    newIterateTmp += stepSize * searchDirection;
+    functionValue = Evaluate(newIterateTmp);
+    function.Gradient(newIterateTmp, gradient);
+    numIterations++;
 
-    if (function_value > initial_function_value + step_size *
-        linear_approx_function_value_decrease)
+    if (functionValue > initialFunctionValue + stepSize *
+        linearApproxFunctionValueDecrease)
     {
       width = dec;
     }
     else
     {
       // Check Wolfe's condition.
-      double search_direction_dot_gradient =
-          arma::dot(gradient, search_direction);
-      double wolfe = CLI::GetParam<double>("lbfgs/wolfe");
+      double searchDirectionDotGradient = arma::dot(gradient, searchDirection);
 
-      if(search_direction_dot_gradient < wolfe *
-          initial_search_direction_dot_gradient)
+      if(searchDirectionDotGradient < wolfe * initialSearchDirectionDotGradient)
       {
         width = inc;
       }
       else
       {
-        if (search_direction_dot_gradient > -wolfe *
-            initial_search_direction_dot_gradient)
+        if (searchDirectionDotGradient > -wolfe *
+            initialSearchDirectionDotGradient)
         {
           width = dec;
         }
@@ -161,19 +157,18 @@
 
     // Terminate when the step size gets too small or too big or it
     // exceeds the max number of iterations.
-    if ((step_size < CLI::GetParam<double>("lbfgs/min_step")) ||
-        (step_size > CLI::GetParam<double>("lbfgs/max_step")) ||
-        (num_iterations >= CLI::GetParam<int>("lbfgs/max_line_search_trials")))
+    if ((stepSize < minStep) || (stepSize > maxStep) ||
+        (numIterations >= maxLineSearchTrials))
     {
       return false;
     }
 
     // Scale the step size.
-    step_size *= width;
+    stepSize *= width;
   }
 
   // Move to the new iterate.
-  iterate = new_iterate_tmp_;
+  iterate = newIterateTmp;
   return true;
 }
 
@@ -181,102 +176,124 @@
  * Find the L_BFGS search direction.
  *
  * @param gradient The gradient at the current point
- * @param iteration_num The iteration number
- * @param scaling_factor Scaling factor to use (see ChooseScalingFactor_())
- * @param search_direction Vector to store search direction in
+ * @param iterationNum The iteration number
+ * @param scalingFactor Scaling factor to use (see ChooseScalingFactor())
+ * @param searchDirection Vector to store search direction in
  */
 template<typename FunctionType>
-void L_BFGS<FunctionType>::SearchDirection_(const arma::mat& gradient,
-                                            int iteration_num,
-                                            double scaling_factor,
-                                            arma::mat& search_direction)
+void L_BFGS<FunctionType>::SearchDirection(const arma::mat& gradient,
+                                           const size_t iterationNum,
+                                           const double scalingFactor,
+                                           arma::mat& searchDirection)
 {
-  arma::mat q = gradient;
+  // Start from this point.
+  searchDirection = gradient;
 
   // See "A Recursive Formula to Compute H * g" in "Updating quasi-Newton
   // matrices with limited storage" (Nocedal, 1980).
 
   // Temporary variables.
-  arma::vec rho(num_basis_);
-  arma::vec alpha(num_basis_);
+  arma::vec rho(numBasis);
+  arma::vec alpha(numBasis);
 
-  int limit = std::max(iteration_num - num_basis_, 0);
-  for (int i = iteration_num - 1; i >= limit; i--)
+  size_t limit = (numBasis > iterationNum) ? 0 : (iterationNum - numBasis);
+  for (size_t i = iterationNum; i != limit; i--)
   {
-    int translated_position = i % num_basis_;
-    rho[iteration_num - i - 1] = 1.0 / arma::dot(
-        y_lbfgs_.slice(translated_position),
-        s_lbfgs_.slice(translated_position));
-    alpha[iteration_num - i - 1] = rho[iteration_num - i - 1] *
-        arma::dot(s_lbfgs_.slice(translated_position), q);
-    q -= alpha[iteration_num - i - 1] * y_lbfgs_.slice(translated_position);
+    int translatedPosition = (i + (numBasis - 1)) % numBasis;
+    rho[iterationNum - i] = 1.0 / arma::dot(y.slice(translatedPosition),
+                                            s.slice(translatedPosition));
+    alpha[iterationNum - i] = rho[iterationNum - i] *
+        arma::dot(s.slice(translatedPosition), searchDirection);
+    searchDirection -= alpha[iterationNum - i] * y.slice(translatedPosition);
   }
 
-  search_direction = scaling_factor * q;
+  searchDirection *= scalingFactor;
 
-  for (int i = limit; i <= iteration_num - 1; i++)
+  for (size_t i = limit; i < iterationNum; i++)
   {
-    int translated_position = i % num_basis_;
-    double beta = rho[iteration_num - i - 1] *
-        arma::dot(y_lbfgs_.slice(translated_position), search_direction);
-    search_direction += (alpha[iteration_num - i - 1] - beta) *
-        s_lbfgs_.slice(translated_position);
+    int translatedPosition = i % numBasis;
+    double beta = rho[iterationNum - i - 1] *
+        arma::dot(y.slice(translatedPosition), searchDirection);
+    searchDirection += (alpha[iterationNum - i - 1] - beta) *
+        s.slice(translatedPosition);
   }
 
   // Negate the search direction so that it is a descent direction.
-  search_direction *= -1;
+  searchDirection *= -1;
 }
 
 /**
- * Update the vectors y_bfgs_ and s_bfgs_, which store the differences between
+ * Update the s and y matrices, which store the differences between
  * the iterate and old iterate and the differences between the gradient and the
  * old gradient, respectively.
  *
- * @param iteration_num Iteration number
+ * @param iterationNum Iteration number
  * @param iterate Current point
- * @param old_iterate Point at last iteration
+ * @param oldIterate Point at last iteration
  * @param gradient Gradient at current point (iterate)
- * @param old_gradient Gradient at last iteration point (old_iterate)
+ * @param oldGradient Gradient at last iteration point (oldIterate)
  */
 template<typename FunctionType>
-void L_BFGS<FunctionType>::UpdateBasisSet_(int iteration_num,
-                                           const arma::mat& iterate,
-                                           const arma::mat& old_iterate,
-                                           const arma::mat& gradient,
-                                           const arma::mat& old_gradient)
+void L_BFGS<FunctionType>::UpdateBasisSet(const size_t iterationNum,
+                                          const arma::mat& iterate,
+                                          const arma::mat& oldIterate,
+                                          const arma::mat& gradient,
+                                          const arma::mat& oldGradient)
 {
   // Overwrite a certain position instead of pushing everything in the vector
-  // back one position
-  int overwrite_pos = iteration_num % num_basis_;
-  s_lbfgs_.slice(overwrite_pos) = iterate - old_iterate;
-  y_lbfgs_.slice(overwrite_pos) = gradient - old_gradient;
+  // back one position.
+  int overwritePos = iterationNum % numBasis;
+  s.slice(overwritePos) = iterate - oldIterate;
+  y.slice(overwritePos) = gradient - oldGradient;
 }
 
-/***
+/**
  * Initialize the L_BFGS object.  Copy the function we will be optimizing and
  * set the size of the memory for the algorithm.
  *
- * @param function_in Instance of function to be optimized
- * @param num_basis Number of memory points to be stored
+ * @param function Instance of function to be optimized
+ * @param numBasis Number of memory points to be stored
+ * @param armijoConstant Controls the accuracy of the line search routine for
+ *     determining the Armijo condition.
+ * @param wolfe Parameter for detecting the Wolfe condition.
+ * @param minGradientNorm Minimum gradient norm required to continue the
+ *     optimization.
+ * @param maxLineSearchTrials The maximum number of trials for the line search
+ *     (before giving up).
+ * @param minStep The minimum step of the line search.
+ * @param maxStep The maximum step of the line search.
  */
 template<typename FunctionType>
-L_BFGS<FunctionType>::L_BFGS(FunctionType& function_in, int num_basis) :
-  function_(function_in)
+L_BFGS<FunctionType>::L_BFGS(const FunctionType& function,
+                             const size_t numBasis,
+                             const double armijoConstant,
+                             const double wolfe,
+                             const double minGradientNorm,
+                             const size_t maxLineSearchTrials,
+                             const double minStep,
+                             const double maxStep) :
+    function(function),
+    numBasis(numBasis),
+    armijoConstant(armijoConstant),
+    wolfe(wolfe),
+    minGradientNorm(minGradientNorm),
+    maxLineSearchTrials(maxLineSearchTrials),
+    minStep(minStep),
+    maxStep(maxStep)
 {
   // Get the dimensions of the coordinates of the function; GetInitialPoint()
   // might return an arma::vec, but that's okay because then n_cols will simply
   // be 1.
-  int rows = function_.GetInitialPoint().n_rows;
-  int cols = function_.GetInitialPoint().n_cols;
+  int rows = function.GetInitialPoint().n_rows;
+  int cols = function.GetInitialPoint().n_cols;
 
-  new_iterate_tmp_.set_size(rows, cols);
-  s_lbfgs_.set_size(rows, cols, num_basis);
-  y_lbfgs_.set_size(rows, cols, num_basis);
-  num_basis_ = num_basis;
+  newIterateTmp.set_size(rows, cols);
+  s.set_size(rows, cols, numBasis);
+  y.set_size(rows, cols, numBasis);
 
   // Allocate the pair holding the min iterate information.
-  min_point_iterate_.first.zeros(rows, cols);
-  min_point_iterate_.second = std::numeric_limits<double>::max();
+  minPointIterate.first.zeros(rows, cols);
+  minPointIterate.second = std::numeric_limits<double>::max();
 }
 
 /**
@@ -286,10 +303,10 @@
  *     value at that point.
  */
 template<typename FunctionType>
-const std::pair<arma::mat, double>&
-L_BFGS<FunctionType>::min_point_iterate() const
+inline const std::pair<arma::mat, double>&
+L_BFGS<FunctionType>::MinPointIterate() const
 {
-  return min_point_iterate_;
+  return minPointIterate;
 }
 
 /**
@@ -298,47 +315,48 @@
  * The given starting point will be modified to store the finishing point of the
  * algorithm.
  *
- * @param num_iterations Maximum number of iterations to perform
+ * @param numIterations Maximum number of iterations to perform
  * @param iterate Starting point (will be modified)
  */
 template<typename FunctionType>
-bool L_BFGS<FunctionType>::Optimize(int num_iterations, arma::mat& iterate)
+bool L_BFGS<FunctionType>::Optimize(const size_t numIterations,
+                                    arma::mat& iterate)
 {
   // The old iterate to be saved.
-  arma::mat old_iterate;
-  old_iterate.zeros(iterate.n_rows, iterate.n_cols);
+  arma::mat oldIterate;
+  oldIterate.zeros(iterate.n_rows, iterate.n_cols);
 
   // Whether to optimize until convergence.
-  bool optimize_until_convergence = (num_iterations <= 0);
+  bool optimizeUntilConvergence = (numIterations == 0);
 
   // The initial function value.
-  double function_value = Evaluate_(iterate);
+  double functionValue = Evaluate(iterate);
 
   // The gradient: the current and the old.
   arma::mat gradient;
-  arma::mat old_gradient;
+  arma::mat oldGradient;
   gradient.zeros(iterate.n_rows, iterate.n_cols);
-  old_gradient.zeros(iterate.n_rows, iterate.n_cols);
+  oldGradient.zeros(iterate.n_rows, iterate.n_cols);
 
   // The search direction.
-  arma::mat search_direction;
-  search_direction.zeros(iterate.n_rows, iterate.n_cols);
+  arma::mat searchDirection;
+  searchDirection.zeros(iterate.n_rows, iterate.n_cols);
 
   // The initial gradient value.
-  function_.Gradient(iterate, gradient);
+  function.Gradient(iterate, gradient);
 
   // The flag denoting whether or not the optimization has been successful.
   bool success = false;
 
-  // The main optimization loop.
-  for (int it_num = 0; optimize_until_convergence || it_num < num_iterations;
-       it_num++)
+  // The main optimization loop.  If numIterations is 0, run until convergence.
+  for (size_t itNum = 0; optimizeUntilConvergence || (itNum != numIterations);
+       itNum++)
   {
-    Log::Debug << "L-BFGS iteration " << it_num << "; objective " <<
-        function_.Evaluate(iterate) << "." << std::endl;
+    Log::Debug << "L-BFGS iteration " << itNum << "; objective " <<
+        function.Evaluate(iterate) << "." << std::endl;
 
     // Break when the norm of the gradient becomes too small.
-    if(GradientNormTooSmall_(gradient))
+    if(GradientNormTooSmall(gradient))
     {
       success = true; // We have found the minimum.
       Log::Debug << "L-BFGS gradient norm too small (terminating)."
@@ -347,28 +365,28 @@
     }
 
     // Choose the scaling factor.
-    double scaling_factor = ChooseScalingFactor_(it_num, gradient);
+    double scalingFactor = ChooseScalingFactor(itNum, gradient);
 
     // Build an approximation to the Hessian and choose the search
     // direction for the current iteration.
-    SearchDirection_(gradient, it_num, scaling_factor, search_direction);
+    SearchDirection(gradient, itNum, scalingFactor, searchDirection);
 
     // Save the old iterate and the gradient before stepping.
-    old_iterate = iterate;
-    old_gradient = gradient;
+    oldIterate = iterate;
+    oldGradient = gradient;
 
     // Do a line search and take a step.
-    double step_size = 1.0;
-    success = LineSearch_(function_value, iterate, gradient, search_direction,
-        step_size);
+    double stepSize = 1.0;
+    success = LineSearch(functionValue, iterate, gradient, searchDirection,
+        stepSize);
 
     if (!success)
       break; // The line search failed; nothing else to try.
 
     // Overwrite an old basis set.
-    UpdateBasisSet_(it_num, iterate, old_iterate, gradient, old_gradient);
+    UpdateBasisSet(itNum, iterate, oldIterate, gradient, oldGradient);
 
-  } // end of the optimization loop.
+  } // End of the optimization loop.
 
   return success;
 }

Modified: mlpack/trunk/src/mlpack/core/optimizers/lbfgs/test_functions.cpp
===================================================================
--- mlpack/trunk/src/mlpack/core/optimizers/lbfgs/test_functions.cpp	2011-11-30 19:52:15 UTC (rev 10471)
+++ mlpack/trunk/src/mlpack/core/optimizers/lbfgs/test_functions.cpp	2011-11-30 20:42:29 UTC (rev 10472)
@@ -50,7 +50,7 @@
   gradient[1] = 200 * (x2 - std::pow(x1, 2));
 }
 
-const arma::mat& RosenbrockFunction::GetInitialPoint()
+const arma::mat& RosenbrockFunction::GetInitialPoint() const
 {
   return initial_point;
 }
@@ -114,7 +114,7 @@
       (1 / 5) * (x2 - x4);
 }
 
-const arma::mat& WoodFunction::GetInitialPoint()
+const arma::mat& WoodFunction::GetInitialPoint() const
 {
   return initial_point;
 }
@@ -170,7 +170,7 @@
       std::pow(coordinates[n - 2], 2));
 }
 
-const arma::mat& GeneralizedRosenbrockFunction::GetInitialPoint()
+const arma::mat& GeneralizedRosenbrockFunction::GetInitialPoint() const
 {
   return initial_point;
 }
@@ -215,7 +215,7 @@
   gradient.col(1) = gwf;
 }
 
-const arma::mat& RosenbrockWoodFunction::GetInitialPoint()
+const arma::mat& RosenbrockWoodFunction::GetInitialPoint() const
 {
   return initial_point;
 }

Modified: mlpack/trunk/src/mlpack/core/optimizers/lbfgs/test_functions.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/optimizers/lbfgs/test_functions.hpp	2011-11-30 19:52:15 UTC (rev 10471)
+++ mlpack/trunk/src/mlpack/core/optimizers/lbfgs/test_functions.hpp	2011-11-30 20:42:29 UTC (rev 10472)
@@ -53,7 +53,7 @@
   double Evaluate(const arma::mat& coordinates);
   void Gradient(const arma::mat& coordinates, arma::mat& gradient);
 
-  const arma::mat& GetInitialPoint();
+  const arma::mat& GetInitialPoint() const;
 
  private:
   arma::mat initial_point;
@@ -82,7 +82,7 @@
   double Evaluate(const arma::mat& coordinates);
   void Gradient(const arma::mat& coordinates, arma::mat& gradient);
 
-  const arma::mat& GetInitialPoint();
+  const arma::mat& GetInitialPoint() const;
 
  private:
   arma::mat initial_point;
@@ -112,7 +112,7 @@
   double Evaluate(const arma::mat& coordinates);
   void Gradient(const arma::mat& coordinates, arma::mat& gradient);
 
-  const arma::mat& GetInitialPoint();
+  const arma::mat& GetInitialPoint() const;
 
  private:
   arma::mat initial_point;
@@ -132,7 +132,7 @@
   double Evaluate(const arma::mat& coordinates);
   void Gradient(const arma::mat& coordinates, arma::mat& gradient);
 
-  const arma::mat& GetInitialPoint();
+  const arma::mat& GetInitialPoint() const;
 
  private:
   arma::mat initial_point;




More information about the mlpack-svn mailing list