[mlpack-git] master: WIP: first cut at sparse LR-SDP solver (31e30b0)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 22:08:24 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit 31e30b03d30d2bcb36f3007a0976110ea1639a45
Author: Stephen Tu <stephent at berkeley.edu>
Date:   Mon Dec 22 10:17:26 2014 +0800

    WIP: first cut at sparse LR-SDP solver


>---------------------------------------------------------------

31e30b03d30d2bcb36f3007a0976110ea1639a45
 src/mlpack/core/optimizers/lrsdp/lrsdp.cpp         |   5 +-
 src/mlpack/core/optimizers/lrsdp/lrsdp.hpp         |  77 ++++-----
 .../core/optimizers/lrsdp/lrsdp_function.cpp       | 178 +++++++++++----------
 .../core/optimizers/lrsdp/lrsdp_function.hpp       | 114 +++++++++----
 src/mlpack/tests/lrsdp_test.cpp                    |  33 ++--
 src/mlpack/tests/to_string_test.cpp                |   2 +-
 6 files changed, 234 insertions(+), 175 deletions(-)

diff --git a/src/mlpack/core/optimizers/lrsdp/lrsdp.cpp b/src/mlpack/core/optimizers/lrsdp/lrsdp.cpp
index 9019ded..636224e 100644
--- a/src/mlpack/core/optimizers/lrsdp/lrsdp.cpp
+++ b/src/mlpack/core/optimizers/lrsdp/lrsdp.cpp
@@ -11,9 +11,10 @@ using namespace mlpack;
 using namespace mlpack::optimization;
 using namespace std;
 
-LRSDP::LRSDP(const size_t numConstraints,
+LRSDP::LRSDP(const size_t numSparseConstraints,
+             const size_t numDenseConstraints,
              const arma::mat& initialPoint) :
-    function(numConstraints, initialPoint),
+    function(numSparseConstraints, numDenseConstraints, initialPoint),
     augLag(function)
 { }
 
diff --git a/src/mlpack/core/optimizers/lrsdp/lrsdp.hpp b/src/mlpack/core/optimizers/lrsdp/lrsdp.hpp
index 9c4e8f5..81bb779 100644
--- a/src/mlpack/core/optimizers/lrsdp/lrsdp.hpp
+++ b/src/mlpack/core/optimizers/lrsdp/lrsdp.hpp
@@ -26,28 +26,15 @@ class LRSDP
  public:
   /**
    * Create an LRSDP to be optimized.  The solution will end up being a matrix
-   * of size (rank) x (rows).  To construct each constraint and the objective
+   * of size (rows) x (rank).  To construct each constraint and the objective
    * function, use the functions A(), B(), and C() to set them correctly.
    *
    * @param numConstraints Number of constraints in the problem.
-   * @param rank Rank of the solution (<= rows).
-   * @param rows Number of rows in the solution.
-   */
-  LRSDP(const size_t numConstraints,
-        const arma::mat& initialPoint);
-
-  /**
-   * Create an LRSDP to be optimized, passing in an already-created
-   * AugLagrangian object.  The given initial point should be set to the size
-   * (rows) x (rank), where (rank) is the reduced rank of the problem.
-   *
-   * @param numConstraints Number of constraints in the problem.
    * @param initialPoint Initial point of the optimization.
-   * @param auglag Pre-initialized AugLagrangian<LRSDP> object.
     */
-  LRSDP(const size_t numConstraints,
-        const arma::mat& initialPoint,
-        AugLagrangian<LRSDPFunction>& augLagrangian);
+  LRSDP(const size_t numSparseConstraints,
+        const size_t numDenseConstraints,
+        const arma::mat& initialPoint);
 
   /**
    * Optimize the LRSDP and return the final objective value.  The given
@@ -57,25 +44,43 @@ class LRSDP
     */
   double Optimize(arma::mat& coordinates);
 
-  //! Return the objective function matrix (C).
-  const arma::mat& C() const { return function.C(); }
-  //! Modify the objective function matrix (C).
-  arma::mat& C() { return function.C(); }
-
-  //! Return the vector of A matrices (which correspond to the constraints).
-  const std::vector<arma::mat>& A() const { return function.A(); }
-  //! Modify the veector of A matrices (which correspond to the constraints).
-  std::vector<arma::mat>& A() { return function.A(); }
-
-  //! Return the vector of modes for the A matrices.
-  const arma::uvec& AModes() const { return function.AModes(); }
-  //! Modify the vector of modes for the A matrices.
-  arma::uvec& AModes() { return function.AModes(); }
-
-  //! Return the vector of B values.
-  const arma::vec& B() const { return function.B(); }
-  //! Modify the vector of B values.
-  arma::vec& B() { return function.B(); }
+  //! Return the sparse objective function matrix (C_sparse).
+  inline const arma::sp_mat& C_sparse() const { return function.C_sparse(); }
+
+  //! Modify the sparse objective matrix (C_sparse); calling this enables it.
+  inline arma::sp_mat& C_sparse() { return function.C_sparse(); }
+
+  //! Return the dense objective function matrix (C_dense).
+  inline const arma::mat& C_dense() const { return function.C_dense(); }
+
+  //! Modify the dense objective matrix (C_dense); calling this enables it.
+  inline arma::mat& C_dense() { return function.C_dense(); }
+
+  //! Return the vector of sparse A matrices (which correspond to the sparse
+  //! constraints).
+  inline const std::vector<arma::sp_mat>& A_sparse() const { return function.A_sparse(); }
+
+  //! Modify the vector of sparse A matrices (which correspond to the sparse
+  //! constraints).
+  inline std::vector<arma::sp_mat>& A_sparse() { return function.A_sparse(); }
+
+  //! Return the vector of dense A matrices (which correspond to the dense
+  //! constraints).
+  inline const std::vector<arma::mat>& A_dense() const { return function.A_dense(); }
+
+  //! Modify the vector of dense A matrices (which correspond to the dense
+  //! constraints).
+  inline std::vector<arma::mat>& A_dense() { return function.A_dense(); }
+
+  //! Return the vector of sparse B values.
+  inline const arma::vec& B_sparse() const { return function.B_sparse(); }
+  //! Modify the vector of sparse B values.
+  inline arma::vec& B_sparse() { return function.B_sparse(); }
+
+  //! Return the vector of dense B values.
+  inline const arma::vec& B_dense() const { return function.B_dense(); }
+  //! Modify the vector of dense B values.
+  inline arma::vec& B_dense() { return function.B_dense(); }
 
   //! Return the function to be optimized.
   const LRSDPFunction& Function() const { return function; }
diff --git a/src/mlpack/core/optimizers/lrsdp/lrsdp_function.cpp b/src/mlpack/core/optimizers/lrsdp/lrsdp_function.cpp
index cca0c8e..bb4eded 100644
--- a/src/mlpack/core/optimizers/lrsdp/lrsdp_function.cpp
+++ b/src/mlpack/core/optimizers/lrsdp/lrsdp_function.cpp
@@ -10,14 +10,24 @@
 
 using namespace mlpack;
 using namespace mlpack::optimization;
+using namespace std;
 
-LRSDPFunction::LRSDPFunction(const size_t numConstraints,
+LRSDPFunction::LRSDPFunction(const size_t numSparseConstraints,
+                             const size_t numDenseConstraints,
                              const arma::mat& initialPoint):
-    a(numConstraints),
-    b(numConstraints),
-    initialPoint(initialPoint),
-    aModes(numConstraints)
-{ }
+    c_sparse(initialPoint.n_rows, initialPoint.n_rows),
+    c_dense(initialPoint.n_rows, initialPoint.n_rows, arma::fill::zeros),
+    hasModifiedSparseObjective(false),
+    hasModifiedDenseObjective(false),
+    a_sparse(numSparseConstraints),
+    b_sparse(numSparseConstraints),
+    a_dense(numDenseConstraints),
+    b_dense(numDenseConstraints),
+    initialPoint(initialPoint)
+{
+  if (initialPoint.n_rows < initialPoint.n_cols)  // rank r (n_cols) may not exceed n (n_rows)
+    throw invalid_argument("initialPoint n_cols > n_rows");
+}
 
 double LRSDPFunction::Evaluate(const arma::mat& coordinates) const
 {
@@ -34,17 +44,11 @@ void LRSDPFunction::Gradient(const arma::mat& /* coordinates */,
 double LRSDPFunction::EvaluateConstraint(const size_t index,
                                   const arma::mat& coordinates) const
 {
-  arma::mat rrt = coordinates * trans(coordinates);
-  if (aModes[index] == 0)
-    return trace(a[index] * rrt) - b[index];
-  else
-  {
-    double value = -b[index];
-    for (size_t i = 0; i < a[index].n_cols; ++i)
-      value += a[index](2, i) * rrt(a[index](0, i), a[index](1, i));
-
-    return value;
-  }
+  const arma::mat rrt = coordinates * trans(coordinates);  // R R^T
+  if (index < NumSparseConstraints())  // sparse constraints are numbered first
+    return trace(a_sparse[index] * rrt) - b_sparse[index];
+  const size_t index1 = index - NumSparseConstraints();  // position among the dense constraints
+  return trace(a_dense[index1] * rrt) - b_dense[index1];
 }
 
 void LRSDPFunction::GradientConstraint(const size_t /* index */,
@@ -58,18 +62,53 @@ void LRSDPFunction::GradientConstraint(const size_t /* index */,
 // Return a string representation of the object.
 std::string LRSDPFunction::ToString() const
 {
-  std::stringstream convert;
+  std::ostringstream convert;
   convert << "LRSDPFunction [" << this << "]" << std::endl;
-  convert << "  Number of constraints: " << a.size() << std::endl;
-  convert << "  Constraint matrix (A_i) size: " << initialPoint.n_rows << "x"
+  convert << "  Number of constraints: " << NumConstraints() << std::endl;
+  convert << "  Problem size: n=" << initialPoint.n_rows << ", r="
       << initialPoint.n_cols << std::endl;
-  convert << "  A_i modes: " << aModes.t();
-  convert << "  Constraint b_i values: " << b.t();
-  convert << "  Objective matrix (C) size: " << c.n_rows << "x" << c.n_cols
-      << std::endl;
+  convert << "  Sparse Constraint b_i values: " << b_sparse.t();
+  convert << "  Dense Constraint b_i values: " << b_dense.t();
   return convert.str();
 }
 
+template <typename MatrixType>
+static inline void  // objective += sum_i (sigma / 2) c_i^2 - lambda_i c_i, c_i = Tr(A_i R R^T) - b_i
+updateObjective(double &objective,
+                const arma::mat &rrt,
+                const std::vector<MatrixType> &ais,
+                const arma::vec &bis,
+                const arma::vec &lambda,
+                size_t lambda_offset,
+                double sigma)
+{
+  for (size_t i = 0; i < bis.n_elem; ++i)  // count by b_i, like NumSparseConstraints() and the lambda offsets
+  {
+    // Constraint violation c_i = Tr(A_i * (R R^T)) - b_i; at() throws if A_i is missing.
+    double constraint = trace(ais.at(i) * rrt) - bis[i];
+    objective -= (lambda(lambda_offset + i) * constraint);
+    objective += (sigma / 2.) * constraint * constraint;
+  }
+}
+// s -= sum_i y_i A_i, with y_i = lambda_i - sigma * (Tr(A_i R R^T) - b_i).
+template <typename MatrixType>
+static inline void
+updateGradient(arma::mat &s,
+               const arma::mat &rrt,
+               const std::vector<MatrixType> &ais,
+               const arma::vec &bis,
+               const arma::vec &lambda,
+               size_t lambda_offset,
+               double sigma)
+{
+  for (size_t i = 0; i < bis.n_elem; ++i)  // count by b_i, like NumSparseConstraints() and the lambda offsets
+  {
+    const double constraint = trace(ais.at(i) * rrt) - bis[i];
+    const double y = lambda(lambda_offset + i) - sigma * constraint;
+    s -= y * ais.at(i);
+  }
+}
+
 namespace mlpack {
 namespace optimization {
 
@@ -84,32 +123,27 @@ double AugLagrangianFunction<LRSDPFunction>::Evaluate(
   //     (sigma / 2) * sum_{i = 1}^{m} (Tr(A_i * (R R^T)) - b_i)^2
 
   // Let's start with the objective: Tr(C * (R R^T)).
-  // Simple, possibly slow solution.
-  arma::mat rrt = coordinates * trans(coordinates);
-  double objective = trace(function.C() * rrt);
+  // Simple, possibly slow solution; see below for an optimization opportunity.
+  //
+  // TODO: Note that Tr(C^T * (R R^T)) = Tr((C R)^T * R), so computing C * R
+  // first and then taking the trace inner product should be more memory
+  // efficient (it never forms the n x n matrix R R^T).
+  //
+  // Similarly for the constraints, computing A_i * R first should be cheaper.
+  const arma::mat rrt = coordinates * trans(coordinates);
+  double objective = 0.;
+  if (function.hasSparseObjective())  // an unset C_sparse is zero, so skip it
+    objective += trace(function.C_sparse() * rrt);
+  if (function.hasDenseObjective())  // an unset C_dense is zero, so skip it
+    objective += trace(function.C_dense() * rrt);
 
   // Now each constraint.
-  for (size_t i = 0; i < function.B().n_elem; ++i)
-  {
-    // Take the trace subtracted by the b_i.
-    double constraint = -function.B()[i];
-
-    if (function.AModes()[i] == 0)
-    {
-      constraint += trace(function.A()[i] * rrt);
-    }
-    else
-    {
-      for (size_t j = 0; j < function.A()[i].n_cols; ++j)
-      {
-        constraint += function.A()[i](2, j) *
-            rrt(function.A()[i](0, j), function.A()[i](1, j));
-      }
-    }
-
-    objective -= (lambda[i] * constraint);
-    objective += (sigma / 2) * std::pow(constraint, 2.0);
-  }
+  updateObjective(  // sparse constraints use lambda entries starting at 0
+      objective, rrt, function.A_sparse(), function.B_sparse(),
+      lambda, 0, sigma);
+  updateObjective(  // dense constraints follow the sparse ones in lambda
+      objective, rrt, function.A_dense(), function.B_dense(),
+      lambda, function.NumSparseConstraints(), sigma);
 
   return objective;
 }
@@ -125,45 +159,23 @@ void AugLagrangianFunction<LRSDPFunction>::Gradient(
   //   with
   // S' = C - sum_{i = 1}^{m} y'_i A_i
   // y'_i = y_i - sigma * (Trace(A_i * (R R^T)) - b_i)
-  arma::mat rrt = coordinates * trans(coordinates);
-  arma::mat s = function.C();
+  const arma::mat rrt = coordinates * trans(coordinates);
+  arma::mat s(function.n(), function.n(), arma::fill::zeros);  // S' starts at 0; C terms added below
 
-  for (size_t i = 0; i < function.B().n_elem; ++i)
-  {
-    double constraint = -function.B()[i];
-
-    if (function.AModes()[i] == 0)
-    {
-      constraint += trace(function.A()[i] * rrt);
-    }
-    else
-    {
-      for (size_t j = 0; j < function.A()[i].n_cols; ++j)
-      {
-        constraint += function.A()[i](2, j) *
-            rrt(function.A()[i](0, j), function.A()[i](1, j));
-      }
-    }
-
-    double y = lambda[i] - sigma * constraint;
-
-    if (function.AModes()[i] == 0)
-    {
-      s -= (y * function.A()[i]);
-    }
-    else
-    {
-      // We only need to subtract the entries which could be modified.
-      for (size_t j = 0; j < function.A()[i].n_cols; ++j)
-      {
-        s(function.A()[i](0, j), function.A()[i](1, j)) -= y;
-      }
-    }
-  }
+  if (function.hasSparseObjective())  // an unset C_sparse is zero, so skip it
+    s += function.C_sparse();
+  if (function.hasDenseObjective())  // an unset C_dense is zero, so skip it
+    s += function.C_dense();
+
+  updateGradient(  // sparse constraints use lambda entries starting at 0
+      s, rrt, function.A_sparse(), function.B_sparse(),
+      lambda, 0, sigma);
+  updateGradient(  // dense constraints follow the sparse ones in lambda
+      s, rrt, function.A_dense(), function.B_dense(),
+      lambda, function.NumSparseConstraints(), sigma);
 
   gradient = 2 * s * coordinates;
 }
 
 }; // namespace optimization
 }; // namespace mlpack
-
diff --git a/src/mlpack/core/optimizers/lrsdp/lrsdp_function.hpp b/src/mlpack/core/optimizers/lrsdp/lrsdp_function.hpp
index 9554899..2cf09a3 100644
--- a/src/mlpack/core/optimizers/lrsdp/lrsdp_function.hpp
+++ b/src/mlpack/core/optimizers/lrsdp/lrsdp_function.hpp
@@ -22,10 +22,13 @@ class LRSDPFunction
  public:
   /**
    * Construct the LRSDPFunction with the given initial point and number of
-   * constraints.  Set the A, B, and C matrices for each constraint using the
-   * A(), B(), and C() functions.
+   * constraints. initialPoint.n_cols is the rank and must not exceed n_rows.
+   *
+   * Set the constraints with the A_x() and B_x() functions and the objective
+   * with the C_x() functions, for x in {sparse, dense}.
    */
-  LRSDPFunction(const size_t numConstraints,
+  LRSDPFunction(const size_t numSparseConstraints,
+                const size_t numDenseConstraints,
                 const arma::mat& initialPoint);
 
   /**
@@ -53,47 +56,98 @@ class LRSDPFunction
                           const arma::mat& coordinates,
                           arma::mat& gradient) const;
 
-  //! Get the number of constraints in the LRSDP.
-  size_t NumConstraints() const { return b.n_elem; }
+  //! Get the number of sparse constraints in the LRSDP.
+  inline size_t NumSparseConstraints() const { return b_sparse.n_elem; }
+
+  //! Get the number of dense constraints in the LRSDP.
+  inline size_t NumDenseConstraints() const { return b_dense.n_elem; }
+
+  //! Get the total number of constraints in the LRSDP.
+  inline size_t NumConstraints() const {
+    return NumSparseConstraints() + NumDenseConstraints();
+  }
 
   //! Get the initial point of the LRSDP.
-  const arma::mat& GetInitialPoint() const { return initialPoint; }
+  inline const arma::mat& GetInitialPoint() const { return initialPoint; }
+
+  inline size_t n() const { return initialPoint.n_rows; }  //!< Problem size n (rows of the initial point); C is n x n.
+
+  //! Return the sparse objective function matrix (C_sparse).
+  inline const arma::sp_mat& C_sparse() const { return c_sparse; }
+
+  //! Modify the sparse objective matrix (C_sparse); any call marks it as set.
+  inline arma::sp_mat& C_sparse() {
+    hasModifiedSparseObjective = true;
+    return c_sparse;
+  }
+
+  //! Return the dense objective function matrix (C_dense).
+  inline const arma::mat& C_dense() const { return c_dense; }
+
+  //! Modify the dense objective matrix (C_dense); any call marks it as set.
+  inline arma::mat& C_dense() {
+    hasModifiedDenseObjective = true;
+    return c_dense;
+  }
+
+  //! Return the vector of sparse A matrices (which correspond to the sparse
+  //! constraints).
+  inline const std::vector<arma::sp_mat>& A_sparse() const { return a_sparse; }
 
-  //! Return the objective function matrix (C).
-  const arma::mat& C() const { return c; }
-  //! Modify the objective function matrix (C).
-  arma::mat& C() { return c; }
+  //! Modify the vector of sparse A matrices (which correspond to the sparse
+  //! constraints).
+  inline std::vector<arma::sp_mat>& A_sparse() { return a_sparse; }
 
-  //! Return the vector of A matrices (which correspond to the constraints).
-  const std::vector<arma::mat>& A() const { return a; }
-  //! Modify the veector of A matrices (which correspond to the constraints).
-  std::vector<arma::mat>& A() { return a; }
+  //! Return the vector of dense A matrices (which correspond to the dense
+  //! constraints).
+  inline const std::vector<arma::mat>& A_dense() const { return a_dense; }
 
-  //! Return the vector of modes for the A matrices.
-  const arma::uvec& AModes() const { return aModes; }
-  //! Modify the vector of modes for the A matrices.
-  arma::uvec& AModes() { return aModes; }
+  //! Modify the vector of dense A matrices (which correspond to the dense
+  //! constraints).
+  inline std::vector<arma::mat>& A_dense() { return a_dense; }
 
-  //! Return the vector of B values.
-  const arma::vec& B() const { return b; }
-  //! Modify the vector of B values.
-  arma::vec& B() { return b; }
+  //! Return the vector of sparse B values.
+  inline const arma::vec& B_sparse() const { return b_sparse; }
+  //! Modify the vector of sparse B values.
+  inline arma::vec& B_sparse() { return b_sparse; }
+
+  //! Return the vector of dense B values.
+  inline const arma::vec& B_dense() const { return b_dense; }
+  //! Modify the vector of dense B values.
+  inline arma::vec& B_dense() { return b_dense; }
+
+  inline bool hasSparseObjective() const { return hasModifiedSparseObjective; }  //!< True once non-const C_sparse() was called.
+
+  inline bool hasDenseObjective() const { return hasModifiedDenseObjective; }  //!< True once non-const C_dense() was called.
 
   //! Return string representation of object.
   std::string ToString() const;
 
  private:
-  //! Objective function matrix c.
-  arma::mat c;
-  //! A_i for each constraint.
-  std::vector<arma::mat> a;
-  //! b_i for each constraint.
-  arma::vec b;
+  //! Sparse objective function matrix c.
+  arma::sp_mat c_sparse;
+
+  //! Dense objective function matrix c.
+  arma::mat c_dense;
+
+  //! If false, c_sparse is treated as zero and skipped in the objective.
+  bool hasModifiedSparseObjective;
+
+  //! If false, c_dense is treated as zero and skipped in the objective.
+  bool hasModifiedDenseObjective;
+
+  //! A_i for each sparse constraint.
+  std::vector<arma::sp_mat> a_sparse;
+  //! b_i for each sparse constraint.
+  arma::vec b_sparse;
+
+  //! A_i for each dense constraint.
+  std::vector<arma::mat> a_dense;
+  //! b_i for each dense constraint.
+  arma::vec b_dense;
 
   //! Initial point.
   arma::mat initialPoint;
-  //! 1 if entries in matrix, 0 for normal.
-  arma::uvec aModes;
 };
 
 // Declare specializations in lrsdp_function.cpp.
diff --git a/src/mlpack/tests/lrsdp_test.cpp b/src/mlpack/tests/lrsdp_test.cpp
index 0a7a874..00ba285 100644
--- a/src/mlpack/tests/lrsdp_test.cpp
+++ b/src/mlpack/tests/lrsdp_test.cpp
@@ -58,35 +58,22 @@ void setupLovaszTheta(const arma::mat& edges,
   const size_t vertices = max(max(edges)) + 1;
 
   // C = -(e e^T) = -ones().
-  lovasz.C().ones(vertices, vertices);
-  lovasz.C() *= -1;
+  lovasz.C_dense().ones(vertices, vertices);
+  lovasz.C_dense() *= -1;
 
   // b_0 = 1; else = 0.
-  lovasz.B().zeros(edges.n_cols);
-  lovasz.B()[0] = 1;
-
-  // All of the matrices will just contain coordinates because they are
-  // super-sparse (two entries each).  Except for A_0, which is I_n.
-  lovasz.AModes().ones();
-  lovasz.AModes()[0] = 0;
+  lovasz.B_sparse().zeros(edges.n_cols);
+  lovasz.B_sparse()[0] = 1;
 
   // A_0 = I_n.
-  lovasz.A()[0].eye(vertices, vertices);
+  lovasz.A_sparse()[0].eye(vertices, vertices);
 
-  // A_ij only has ones at (i, j) and (j, i) and 1 elsewhere.
+  // A_ij only has ones at (i, j) and (j, i) and 0 elsewhere.  NOTE(review): with A_0 that makes edges.n_cols + 1 A matrices, but B_sparse() is sized edges.n_cols above; confirm.
   for (size_t i = 0; i < edges.n_cols; ++i)
   {
-    arma::mat a(3, 2);
-
-    a(0, 0) = edges(0, i);
-    a(1, 0) = edges(1, i);
-    a(2, 0) = 1;
-
-    a(0, 1) = edges(1, i);
-    a(1, 1) = edges(0, i);
-    a(2, 1) = 1;
-
-    lovasz.A()[i + 1] = a;
+    lovasz.A_sparse()[i + 1].zeros(vertices, vertices);
+    lovasz.A_sparse()[i + 1](edges(0, i), edges(1, i)) = 1.;
+    lovasz.A_sparse()[i + 1](edges(1, i), edges(0, i)) = 1.;
   }
 
   // Set the Lagrange multipliers right.
@@ -110,7 +97,7 @@ BOOST_AUTO_TEST_CASE(Johnson844LovaszThetaSDP)
 
   createLovaszThetaInitialPoint(edges, coordinates);
 
-  LRSDP lovasz(edges.n_cols + 1, coordinates);
+  LRSDP lovasz(edges.n_cols + 1, 0, coordinates);
 
   setupLovaszTheta(edges, lovasz);
 
diff --git a/src/mlpack/tests/to_string_test.cpp b/src/mlpack/tests/to_string_test.cpp
index 8365fc3..bac9a27 100644
--- a/src/mlpack/tests/to_string_test.cpp
+++ b/src/mlpack/tests/to_string_test.cpp
@@ -256,7 +256,7 @@ BOOST_AUTO_TEST_CASE(LRSDPString)
   arma::mat c(40, 40);
   c.randn();
   const size_t b=3;
-  mlpack::optimization::LRSDP d(b,c);
+  mlpack::optimization::LRSDP d(b,b,c);
   Log::Debug << d;
   testOstream << d;
   std::string s = d.ToString();



More information about the mlpack-git mailing list