[mlpack-svn] r11299 - mlpack/trunk/src/mlpack/core/optimizers/aug_lagrangian

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Jan 27 13:50:18 EST 2012


Author: rcurtin
Date: 2012-01-27 13:50:18 -0500 (Fri, 27 Jan 2012)
New Revision: 11299

Added:
   mlpack/trunk/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_function.hpp
   mlpack/trunk/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_function_impl.hpp
Modified:
   mlpack/trunk/src/mlpack/core/optimizers/aug_lagrangian/CMakeLists.txt
   mlpack/trunk/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian.hpp
   mlpack/trunk/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_impl.hpp
   mlpack/trunk/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_test_functions.cpp
   mlpack/trunk/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_test_functions.hpp
Log:
Refactor AugLagrangian for better use by other classes.  AugLagrangianFunction
is its own standalone class now to allow for easier specialization.  The tests
aren't quite done yet (the LRSDP Lovasz-Theta ones still fail and give
unnecessary debugging output) but that's the next thing to do.


Modified: mlpack/trunk/src/mlpack/core/optimizers/aug_lagrangian/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/core/optimizers/aug_lagrangian/CMakeLists.txt	2012-01-27 17:59:38 UTC (rev 11298)
+++ mlpack/trunk/src/mlpack/core/optimizers/aug_lagrangian/CMakeLists.txt	2012-01-27 18:50:18 UTC (rev 11299)
@@ -3,6 +3,8 @@
 set(SOURCES
   aug_lagrangian.hpp
   aug_lagrangian_impl.hpp
+  aug_lagrangian_function.hpp
+  aug_lagrangian_function_impl.hpp
   aug_lagrangian_test_functions.hpp
   aug_lagrangian_test_functions.cpp
 )

Modified: mlpack/trunk/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian.hpp	2012-01-27 17:59:38 UTC (rev 11298)
+++ mlpack/trunk/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian.hpp	2012-01-27 18:50:18 UTC (rev 11299)
@@ -77,47 +77,6 @@
   LagrangianFunction& function;
   //! Number of memory points for L-BFGS.
   size_t numBasis;
-
-  /**
-   * This is a utility class, which we will pass to L-BFGS during the
-   * optimization.  We use a utility class so that we do not have to expose
-   * Evaluate() and Gradient() to the AugLagrangian public interface; instead,
-   * with a private class, these methods are correctly protected (since they
-   * should not be being used anywhere else).
-   */
-  class AugLagrangianFunction
-  {
-   public:
-    AugLagrangianFunction(LagrangianFunction& functionIn,
-                          arma::vec& lambdaIn,
-                          double sigma);
-
-    double Evaluate(const arma::mat& coordinates);
-    void Gradient(const arma::mat& coordinates, arma::mat& gradient);
-
-    const arma::mat& GetInitialPoint() const;
-
-    //! Get the Lagrange multipliers.
-    const arma::vec& Lambda() const { return lambda; }
-    //! Modify the Lagrange multipliers.
-    arma::vec& Lambda() { return lambda; }
-
-    //! Get sigma.
-    double Sigma() const { return sigma; }
-    //! Modify sigma.
-    double& Sigma() { return sigma; }
-
-    //! Get the Lagrangian function.
-    const LagrangianFunction& Function() const { return function; }
-    //! Modify the Lagrangian function.
-    LagrangianFunction& Function() { return function; }
-
-   private:
-    arma::vec lambda;
-    double sigma;
-
-    LagrangianFunction& function;
-  };
 };
 
 }; // namespace optimization

Added: mlpack/trunk/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_function.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_function.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_function.hpp	2012-01-27 18:50:18 UTC (rev 11299)
@@ -0,0 +1,104 @@
+/**
+ * @file aug_lagrangian_function.hpp
+ * @author Ryan Curtin
+ *
+ * Contains a utility class for AugLagrangian.
+ */
+#ifndef __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_AUG_LAGRANGIAN_FUNCTION_HPP
+#define __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_AUG_LAGRANGIAN_FUNCTION_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace optimization {
+
+/**
+ * This is a utility class used by AugLagrangian, meant to wrap a
+ * LagrangianFunction into a function usable by a simple optimizer like L-BFGS.
+ * Given a LagrangianFunction which follows the format outlined in the
+ * documentation for AugLagrangian, this class provides Evaluate(), Gradient(),
+ * and GetInitialPoint() functions which allow this class to be used with a
+ * simple optimizer like L-BFGS.
+ *
+ * This class can be specialized for your particular implementation -- commonly,
+ * a faster method for computing the overall objective and gradient of the
+ * augmented Lagrangian function can be implemented than the naive, default
+ * implementation given.  Use class template specialization and re-implement all
+ * of the methods (unfortunately, C++ specialization rules mean you have to
+ * re-implement everything).
+ *
+ * @tparam LagrangianFunction Lagrangian function to be used.
+ */
+template<typename LagrangianFunction>
+class AugLagrangianFunction
+{
+ public:
+  /**
+   * Initialize the AugLagrangianFunction with the given LagrangianFunction,
+   * Lagrange multipliers, and initial penalty parameter.
+   *
+   * @param function Lagrangian function.
+   * @param lambda Initial Lagrange multipliers.
+   * @param sigma Initial penalty parameter.
+   */
+  AugLagrangianFunction(LagrangianFunction& function,
+                        const arma::vec& lambda,
+                        const double sigma);
+  /**
+   * Evaluate the objective function of the Augmented Lagrangian function, which
+   * is the standard Lagrangian function evaluation plus a penalty term, which
+   * penalizes unsatisfied constraints.
+   *
+   * @param coordinates Coordinates to evaluate function at.
+   * @return Objective function.
+   */
+  double Evaluate(const arma::mat& coordinates) const;
+
+  /**
+   * Evaluate the gradient of the Augmented Lagrangian function.
+   *
+   * @param coordinates Coordinates to evaluate gradient at.
+   * @param gradient Matrix to store gradient into.
+   */
+  void Gradient(const arma::mat& coordinates, arma::mat& gradient) const;
+
+  /**
+   * Get the initial point of the optimization (supplied by the
+   * LagrangianFunction).
+   *
+   * @return Initial point.
+   */
+  const arma::mat& GetInitialPoint() const;
+
+  //! Get the Lagrange multipliers.
+  const arma::vec& Lambda() const { return lambda; }
+  //! Modify the Lagrange multipliers.
+  arma::vec& Lambda() { return lambda; }
+
+  //! Get sigma (the penalty parameter).
+  double Sigma() const { return sigma; }
+  //! Modify sigma (the penalty parameter).
+  double& Sigma() { return sigma; }
+
+  //! Get the Lagrangian function.
+  const LagrangianFunction& Function() const { return function; }
+  //! Modify the Lagrangian function.
+  LagrangianFunction& Function() { return function; }
+
+ private:
+  //! The Lagrange multipliers.
+  arma::vec lambda;
+  //! The penalty parameter.
+  double sigma;
+
+  //! Instantiation of the function to be optimized.
+  LagrangianFunction& function;
+};
+
+}; // namespace optimization
+}; // namespace mlpack
+
+// Include basic implementation.
+#include "aug_lagrangian_function_impl.hpp"
+
+#endif // __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_AUG_LAGRANGIAN_FUNCTION_HPP

Added: mlpack/trunk/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_function_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_function_impl.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_function_impl.hpp	2012-01-27 18:50:18 UTC (rev 11299)
@@ -0,0 +1,89 @@
+/**
+ * @file aug_lagrangian_function_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Simple, naive implementation of AugLagrangianFunction.  Better
+ * specializations can probably be given in many cases, but this is the most
+ * general case.
+ */
+#ifndef __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_AUG_LAGRANGIAN_FUNCTION_IMPL_HPP
+#define __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_AUG_LAGRANGIAN_FUNCTION_IMPL_HPP
+
+// In case it hasn't been included.
+#include "aug_lagrangian_function.hpp"
+
+namespace mlpack {
+namespace optimization {
+
+// Initialize the AugLagrangianFunction.
+template<typename LagrangianFunction>
+AugLagrangianFunction<LagrangianFunction>::AugLagrangianFunction(
+    LagrangianFunction& function,
+    const arma::vec& lambda,
+    const double sigma) :
+    lambda(lambda),
+    sigma(sigma),
+    function(function)
+{
+  // Nothing else to do.
+}
+
+// Evaluate the AugLagrangianFunction at the given coordinates.
+template<typename LagrangianFunction>
+double AugLagrangianFunction<LagrangianFunction>::Evaluate(
+    const arma::mat& coordinates) const
+{
+  // The augmented Lagrangian is evaluated as
+  //    f(x) + {-lambda_i * c_i(x) + (sigma / 2) c_i(x)^2} for all constraints
+
+  // First get the function's objective value.
+  double objective = function.Evaluate(coordinates);
+
+  // Now loop for each constraint.
+  for (size_t i = 0; i < function.NumConstraints(); ++i)
+  {
+    double constraint = function.EvaluateConstraint(i, coordinates);
+
+    objective += (-lambda[i] * constraint) +
+        sigma * std::pow(constraint, 2) / 2;
+  }
+
+  return objective;
+}
+
+// Evaluate the gradient of the AugLagrangianFunction at the given coordinates.
+template<typename LagrangianFunction>
+void AugLagrangianFunction<LagrangianFunction>::Gradient(
+    const arma::mat& coordinates,
+    arma::mat& gradient) const
+{
+  // The augmented Lagrangian's gradient is evaluated as
+  // f'(x) + {(-lambda_i + sigma * c_i(x)) * c'_i(x)} for all constraints
+  gradient.zeros();
+  function.Gradient(coordinates, gradient);
+
+  arma::mat constraintGradient; // Temporary for constraint gradients.
+  for (size_t i = 0; i < function.NumConstraints(); i++)
+  {
+    function.GradientConstraint(i, coordinates, constraintGradient);
+
+    // Now calculate scaling factor and add to existing gradient.
+    arma::mat tmpGradient;
+    tmpGradient = (-lambda[i] + sigma *
+        function.EvaluateConstraint(i, coordinates)) * constraintGradient;
+    gradient += tmpGradient;
+  }
+}
+
+// Get the initial point.
+template<typename LagrangianFunction>
+const arma::mat& AugLagrangianFunction<LagrangianFunction>::GetInitialPoint()
+    const
+{
+  return function.GetInitialPoint();
+}
+
+}; // namespace optimization
+}; // namespace mlpack
+
+#endif

Modified: mlpack/trunk/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_impl.hpp	2012-01-27 17:59:38 UTC (rev 11298)
+++ mlpack/trunk/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_impl.hpp	2012-01-27 18:50:18 UTC (rev 11299)
@@ -10,6 +10,7 @@
 #define __MLPACK_CORE_OPTIMIZERS_AUG_LAGRANGIAN_AUG_LAGRANGIAN_IMPL_HPP
 
 #include <mlpack/core/optimizers/lbfgs/lbfgs.hpp>
+#include "aug_lagrangian_function.hpp"
 
 namespace mlpack {
 namespace optimization {
@@ -31,13 +32,15 @@
   // Choose initial lambda parameters (vector of zeros, for simplicity).
   arma::vec lambda(function.NumConstraints());
   lambda.ones();
+  lambda *= -1;
+  lambda[0] = -double(coordinates.n_cols);
   double penalty_threshold = DBL_MAX; // Ensure we update lambda immediately.
 
   // Track the last objective to compare for convergence.
   double last_objective = function.Evaluate(coordinates);
 
   // First, we create an instance of the utility function class.
-  AugLagrangianFunction f(function, lambda, sigma);
+  AugLagrangianFunction<LagrangianFunction> f(function, lambda, sigma);
 
   // First, calculate the current penalty.
   double penalty = 0;
@@ -52,13 +55,19 @@
   size_t it;
   for (it = 0; it != (maxIterations - 1); it++)
   {
-    Log::Info << "AugLagrangian on iteration " << it
+    Log::Warn << "AugLagrangian on iteration " << it
         << ", starting with objective "  << last_objective << "." << std::endl;
 
+ //   Log::Warn << coordinates << std::endl;
+
+//    Log::Warn << trans(coordinates) * coordinates << std::endl;
+
     // Use L-BFGS to optimize this function for the given lambda and sigma.
-    L_BFGS<AugLagrangianFunction> lbfgs(f, numBasis);
+    L_BFGS<AugLagrangianFunction<LagrangianFunction> >
+        lbfgs(f, numBasis, 1e-4, 0.9, 1e-10, 100, 1e-20, 1e20);
+
     if (!lbfgs.Optimize(0, coordinates))
-      Log::Info << "L-BFGS reported an error during optimization."
+      Log::Warn << "L-BFGS reported an error during optimization."
           << std::endl;
 
     // Check if we are done with the entire optimization (the threshold we are
@@ -76,11 +85,23 @@
     // First, calculate the current penalty.
     double penalty = 0;
     for (size_t i = 0; i < function.NumConstraints(); i++)
+    {
       penalty += std::pow(function.EvaluateConstraint(i, coordinates), 2);
+//      Log::Debug << "Constraint " << i << " is " <<
+//          function.EvaluateConstraint(i, coordinates) << std::endl;
+    }
 
-    Log::Info << "Penalty is " << penalty << " (threshold "
+    Log::Warn << "Penalty is " << penalty << " (threshold "
         << penalty_threshold << ")." << std::endl;
 
+    for (size_t i = 0; i < function.NumConstraints(); ++i)
+    {
+      arma::mat tmpgrad;
+      function.GradientConstraint(i, coordinates, tmpgrad);
+//      Log::Debug << "Gradient of constraint " << i << " is " << std::endl;
+//      Log::Debug << tmpgrad << std::endl;
+    }
+
     if (penalty < penalty_threshold) // We update lambda.
     {
       // We use the update: lambda{k + 1} = lambdak - sigma * c(coordinates),
@@ -93,7 +114,7 @@
       // penalty.  TODO: this factor should be a parameter (from CLI).  The
       // value of 0.25 is taken from Burer and Monteiro (2002).
       penalty_threshold = 0.25 * penalty;
-      Log::Info << "Lagrange multiplier estimates updated." << std::endl;
+      Log::Warn << "Lagrange multiplier estimates updated." << std::endl;
     }
     else
     {
@@ -102,84 +123,13 @@
       // (2002).
       sigma *= 10;
       f.Sigma() = sigma;
-      Log::Info << "Updated sigma to " << sigma << "." << std::endl;
+      Log::Warn << "Updated sigma to " << sigma << "." << std::endl;
     }
   }
 
   return false;
 }
 
-
-template<typename LagrangianFunction>
-AugLagrangian<LagrangianFunction>::AugLagrangianFunction::AugLagrangianFunction(
-      LagrangianFunction& functionIn, arma::vec& lambdaIn, double sigma) :
-    lambda(lambdaIn),
-    sigma(sigma),
-    function(functionIn)
-{
-  // Nothing to do.
-}
-
-template<typename LagrangianFunction>
-double AugLagrangian<LagrangianFunction>::AugLagrangianFunction::Evaluate(
-    const arma::mat& coordinates)
-{
-  // The augmented Lagrangian is evaluated as
-  //   f(x) + {-lambdai * c_i(x) + (sigma / 2) c_i(x)^2} for all constraints
-//  Log::Debug << "Evaluating augmented Lagrangian." << std::endl;
-  double objective = function.Evaluate(coordinates);
-
-  // Now loop over constraints.
-  for (size_t i = 0; i < function.NumConstraints(); i++)
-  {
-    double constraint = function.EvaluateConstraint(i, coordinates);
-    objective += (-lambda[i] * constraint) +
-        sigma * std::pow(constraint, 2) / 2;
-  }
-
-//  Log::Warn << "Overall objective is " << objective << "." << std::endl;
-
-  return objective;
-}
-
-template<typename LagrangianFunction>
-void AugLagrangian<LagrangianFunction>::AugLagrangianFunction::Gradient(
-    const arma::mat& coordinates, arma::mat& gradient)
-{
-  // The augmented Lagrangian's gradient is evaluated as
-  // f'(x) + {(-lambdai + sigma * c_i(x)) * c'_i(x)} for all constraints
-//  gradient.zeros();
-  function.Gradient(coordinates, gradient);
-//  Log::Debug << "Objective function gradient norm is "
-//      << arma::norm(gradient, 2) << "." << std::endl;
-//  std::cout << gradient << std::endl;
-
-  arma::mat constraint_gradient; // Temporary for constraint gradients.
-  for (size_t i = 0; i < function.NumConstraints(); i++)
-  {
-    function.GradientConstraint(i, coordinates, constraint_gradient);
-
-    // Now calculate scaling factor and add to existing gradient.
-    arma::mat tmp_gradient;
-    tmp_gradient = (-lambda[i] + sigma *
-        function.EvaluateConstraint(i, coordinates)) * constraint_gradient;
-//    Log::Debug << "Gradient for constraint " << i << " (with lambda = "
-//        << lambda[i] << ") is " << std::endl;
-//    std::cout << tmp_gradient;
-    gradient += tmp_gradient;
-  }
-//  Log::Debug << "Overall gradient norm is " << arma::norm(gradient, 2) << "."
-//      << std::endl;
-//  std::cout << gradient << std::endl;
-}
-
-template<typename LagrangianFunction>
-const arma::mat& AugLagrangian<LagrangianFunction>::AugLagrangianFunction::
-    GetInitialPoint() const
-{
-  return function.GetInitialPoint();
-}
-
 }; // namespace optimization
 }; // namespace mlpack
 

Modified: mlpack/trunk/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_test_functions.cpp
===================================================================
--- mlpack/trunk/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_test_functions.cpp	2012-01-27 17:59:38 UTC (rev 11298)
+++ mlpack/trunk/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_test_functions.cpp	2012-01-27 18:50:18 UTC (rev 11299)
@@ -45,7 +45,7 @@
   gradient[1] = 4 * coordinates[0] + 6 * coordinates[1];
 }
 
-double AugLagrangianTestFunction::EvaluateConstraint(size_t index,
+double AugLagrangianTestFunction::EvaluateConstraint(const size_t index,
     const arma::mat& coordinates)
 {
   // We return 0 if the index is wrong (not 0).
@@ -56,7 +56,8 @@
   return (coordinates[0] + coordinates[1] - 5);
 }
 
-void AugLagrangianTestFunction::GradientConstraint(size_t index,
+void AugLagrangianTestFunction::GradientConstraint(const size_t index,
+    const arma::mat& coordinates,
     arma::mat& gradient)
 {
   // If the user passed an invalid index (not 0), we will return a zero
@@ -108,7 +109,7 @@
   gradient[2] = 6 * (coordinates[2] + 3);
 }
 
-double GockenbachFunction::EvaluateConstraint(size_t index,
+double GockenbachFunction::EvaluateConstraint(const size_t index,
                                               const arma::mat& coordinates)
 {
   double constraint = 0;
@@ -131,7 +132,7 @@
   return constraint;
 }
 
-void GockenbachFunction::GradientConstraint(size_t index,
+void GockenbachFunction::GradientConstraint(const size_t index,
                                             const arma::mat& coordinates,
                                             arma::mat& gradient)
 {
@@ -181,10 +182,14 @@
 //  Log::Debug << "trans(coord) * coord:" << std::endl;
 //  std::cout << (trans(coordinates) * coordinates) << std::endl;
 
-  double obj = 0;
-  for (size_t i = 0; i < coordinates.n_cols; i++)
-    obj -= dot(coordinates.col(i), coordinates.col(i));
 
+  arma::mat x = trans(coordinates) * coordinates;
+  double obj = -accu(x);
+
+//  double obj = 0;
+//  for (size_t i = 0; i < coordinates.n_cols; i++)
+//    obj -= dot(coordinates.col(i), coordinates.col(i));
+
 //  Log::Debug << "Objective function is " << obj << "." << std::endl;
 
   return obj;
@@ -193,12 +198,73 @@
 void LovaszThetaSDP::Gradient(const arma::mat& coordinates,
                               arma::mat& gradient)
 {
+
+  // The gradient is equal to (2 S' R^T)^T, with R being coordinates.
+  // S' = C - sum_{i = 1}^{m} [ y_i - sigma (Tr(A_i * (R^T R)) - b_i)] * A_i
+  // We will calculate it in a not very smart way, but it should work.
+ // Log::Warn << "Using stupid specialization for gradient calculation!"
+ //    << std::endl;
+
+  // Initialize S' piece by piece.  It is of size n x n.
+  const size_t n = coordinates.n_cols;
+  arma::mat s(n, n);
+  s.ones();
+  s *= -1; // C = -ones().
+
+  for (size_t i = 0; i < NumConstraints(); ++i)
+  {
+    // Calculate [ y_i - sigma (Tr(A_i * (R^T R)) - b_i) ] * A_i.
+    // Result will be a matrix; inner result is a scalar.
+    if (i == 0)
+    {
+      // A_0 = I_n.  Hooray!  That's easy!  b_0 = 1.
+      double inner = -1 * double(n) - 0.5 *
+          (trace(trans(coordinates) * coordinates) - 1);
+
+      arma::mat zz = (inner * arma::eye<arma::mat>(n, n));
+
+//      Log::Debug << "Constraint " << i << " matrix to add is " << std::endl;
+//      Log::Debug << zz << std::endl;
+
+      s -= zz;
+    }
+    else
+    {
+      // Get edge so we can construct constraint A_i matrix.  b_i = 0.
+      arma::vec edge = edges.col(i - 1);
+
+      arma::mat a;
+      a.zeros(n, n);
+
+      // Only two nonzero entries.
+      a(edge[0], edge[1]) = 1;
+      a(edge[1], edge[0]) = 1;
+
+      double inner = (-1) - 0.5 *
+          (trace(a * (trans(coordinates) * coordinates)));
+
+      arma::mat zz = (inner * a);
+
+//      Log::Debug << "Constraint " << i << " matrix to add is " << std::endl;
+//      Log::Debug << zz << std::endl;
+
+      s -= zz;
+    }
+  }
+
+//  Log::Warn << "Calculated S is: " << std::endl << s << std::endl;
+
+  gradient = trans(2 * s * trans(coordinates));
+
+//  Log::Warn << "Calculated gradient is: " << std::endl << gradient << std::endl;
+
+
 //  Log::Debug << "Evaluating gradient. " << std::endl;
 
   // The gradient of -Tr(ones * X) is equal to -2 * ones * R
-  arma::mat ones;
-  ones.ones(coordinates.n_rows, coordinates.n_rows);
-  gradient = -2 * ones * coordinates;
+//  arma::mat ones;
+//  ones.ones(coordinates.n_rows, coordinates.n_rows);
+//  gradient = -2 * ones * coordinates;
 
 //  Log::Debug << "Done with gradient." << std::endl;
 //  std::cout << gradient;
@@ -210,14 +276,14 @@
   return edges.n_cols + 1;
 }
 
-double LovaszThetaSDP::EvaluateConstraint(size_t index,
+double LovaszThetaSDP::EvaluateConstraint(const size_t index,
                                           const arma::mat& coordinates)
 {
   if (index == 0) // This is the constraint Tr(X) = 1.
   {
     double sum = -1; // Tr(X) - 1 = 0, so we prefix the subtraction.
     for (size_t i = 0; i < coordinates.n_cols; i++)
-      sum += dot(coordinates.col(i), coordinates.col(i));
+      sum += std::abs(dot(coordinates.col(i), coordinates.col(i)));
 
 //    Log::Debug << "Constraint " << index << " evaluates to " << sum << std::endl;
     return sum;
@@ -230,17 +296,17 @@
 //    dot(coordinates.col(i), coordinates.col(j)) << "." << std::endl;
 
   // The constraint itself is X_ij, or (R^T R)_ij.
-  return dot(coordinates.col(i), coordinates.col(j));
+  return std::abs(dot(coordinates.col(i), coordinates.col(j)));
 }
 
-void LovaszThetaSDP::GradientConstraint(size_t index,
+void LovaszThetaSDP::GradientConstraint(const size_t index,
                                         const arma::mat& coordinates,
                                         arma::mat& gradient)
 {
 //  Log::Debug << "Gradient of constraint " << index << " is " << std::endl;
   if (index == 0) // This is the constraint Tr(X) = 1.
   {
-    gradient = 2 * coordinates; // d/dX (Tr(R^T R)) = 2 R.
+    gradient = 2 * coordinates; // d/dR (Tr(R R^T)) = 2 R.
 //    std::cout << gradient;
     return;
   }
@@ -256,7 +322,7 @@
   //   2 R_xj, y  = i, y != j
   //   2 R_xi, y != i, y  = j
   //   4 R_xy, y  = i, y  = j
-  // This results in the gradient matrix having two nonzero columns; for column
+  // This results in the gradient matrix having two nonzero rows; for row
   // i, the elements are R_nj, where n is the row; for column j, the elements
   // are R_ni.
   gradient.zeros(coordinates.n_rows, coordinates.n_cols);
@@ -310,5 +376,12 @@
     }
   }
 
+  Log::Debug << "Initial matrix " << std::endl << initialPoint << std::endl;
+
+  Log::Warn << "X " << std::endl << trans(initialPoint) * initialPoint
+      << std::endl;
+
+  Log::Warn << "accu " << accu(trans(initialPoint) * initialPoint) << std::endl;
+
   return initialPoint;
 }

Modified: mlpack/trunk/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_test_functions.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_test_functions.hpp	2012-01-27 17:59:38 UTC (rev 11298)
+++ mlpack/trunk/src/mlpack/core/optimizers/aug_lagrangian/aug_lagrangian_test_functions.hpp	2012-01-27 18:50:18 UTC (rev 11299)
@@ -31,10 +31,12 @@
 
   size_t NumConstraints() const { return 1; }
 
-  double EvaluateConstraint(size_t index, const arma::mat& coordinates);
-  void GradientConstraint(size_t index, arma::mat& gradient);
+  double EvaluateConstraint(const size_t index, const arma::mat& coordinates);
+  void GradientConstraint(const size_t index,
+                          const arma::mat& coordinates,
+                          arma::mat& gradient);
 
-  const arma::mat& GetInitialPoint() { return initialPoint; }
+  const arma::mat& GetInitialPoint() const { return initialPoint; }
 
  private:
   arma::mat initialPoint;
@@ -62,12 +64,12 @@
 
   size_t NumConstraints() const { return 2; };
 
-  double EvaluateConstraint(size_t index, const arma::mat& coordinates);
-  void GradientConstraint(size_t index,
+  double EvaluateConstraint(const size_t index, const arma::mat& coordinates);
+  void GradientConstraint(const size_t index,
                           const arma::mat& coordinates,
                           arma::mat& gradient);
 
-  const arma::mat& GetInitialPoint() { return initialPoint; }
+  const arma::mat& GetInitialPoint() const { return initialPoint; }
 
  private:
   arma::mat initialPoint;
@@ -115,16 +117,19 @@
 
   size_t NumConstraints() const;
 
-  double EvaluateConstraint(size_t index, const arma::mat& coordinates);
-  void GradientConstraint(size_t index,
+  double EvaluateConstraint(const size_t index, const arma::mat& coordinates);
+  void GradientConstraint(const size_t index,
                           const arma::mat& coordinates,
                           arma::mat& gradient);
 
   const arma::mat& GetInitialPoint();
 
+  const arma::mat& Edges() const { return edges; }
+  arma::mat&       Edges()       { return edges; }
+
  private:
   arma::mat edges;
-  int vertices;
+  size_t vertices;
 
   arma::mat initialPoint;
 };




More information about the mlpack-svn mailing list