[mlpack-git] master: make the SDP objective matrix type a template parameter (fca638d)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 22:14:05 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit fca638d2434faaada728a79f33bdde6f802e54d7
Author: Stephen Tu <stephent at berkeley.edu>
Date:   Wed Jan 21 14:49:38 2015 -0800

    make the SDP objective matrix type a template parameter


>---------------------------------------------------------------

fca638d2434faaada728a79f33bdde6f802e54d7
 src/mlpack/core/math/CMakeLists.txt                |   1 +
 src/mlpack/core/math/lin_alg.cpp                   |  10 +-
 src/mlpack/core/math/lin_alg.hpp                   |  14 +--
 src/mlpack/core/math/lin_alg_impl.hpp              |  23 ++++
 src/mlpack/core/optimizers/sdp/CMakeLists.txt      |   6 +-
 src/mlpack/core/optimizers/sdp/primal_dual.hpp     |  32 ++++--
 .../sdp/{primal_dual.cpp => primal_dual_impl.hpp}  | 126 ++++++++++-----------
 src/mlpack/core/optimizers/sdp/sdp.hpp             |  46 +++-----
 .../core/optimizers/sdp/{sdp.cpp => sdp_impl.hpp}  |  22 ++--
 src/mlpack/tests/lin_alg_test.cpp                  |  24 ++--
 src/mlpack/tests/sdp_primal_dual_test.cpp          |  78 ++++++-------
 11 files changed, 195 insertions(+), 187 deletions(-)

diff --git a/src/mlpack/core/math/CMakeLists.txt b/src/mlpack/core/math/CMakeLists.txt
index d90d31f..50e6635 100644
--- a/src/mlpack/core/math/CMakeLists.txt
+++ b/src/mlpack/core/math/CMakeLists.txt
@@ -3,6 +3,7 @@
 set(SOURCES
   clamp.hpp
   lin_alg.hpp
+  lin_alg_impl.hpp
   lin_alg.cpp
   random.hpp
   random.cpp
diff --git a/src/mlpack/core/math/lin_alg.cpp b/src/mlpack/core/math/lin_alg.cpp
index cf68aa4..21c9e1c 100644
--- a/src/mlpack/core/math/lin_alg.cpp
+++ b/src/mlpack/core/math/lin_alg.cpp
@@ -229,7 +229,7 @@ void mlpack::math::Svec(const arma::mat& input, arma::vec& output)
   }
 }
 
-void mlpack::math::Svec(const arma::sp_mat& input, arma::sp_mat& output)
+void mlpack::math::Svec(const arma::sp_mat& input, arma::sp_vec& output)
 {
   const size_t n = input.n_rows;
   const size_t n2bar = n * (n + 1) / 2;
@@ -270,14 +270,6 @@ void mlpack::math::Smat(const arma::vec& input, arma::mat& output)
   }
 }
 
-size_t mlpack::math::SvecIndex(size_t i, size_t j, size_t n)
-{
-  if (i > j)
-    std::swap(i, j);
-  return (j-i) + (n*(n+1) - (n-i)*(n-i+1))/2;
-}
-
-
 void mlpack::math::SymKronId(const arma::mat& A, arma::mat& op)
 {
   // TODO(stephentu): there's probably an easier way to build this operator
diff --git a/src/mlpack/core/math/lin_alg.hpp b/src/mlpack/core/math/lin_alg.hpp
index 0e0998e..b1fab56 100644
--- a/src/mlpack/core/math/lin_alg.hpp
+++ b/src/mlpack/core/math/lin_alg.hpp
@@ -87,14 +87,13 @@ void RemoveRows(const arma::mat& input,
  */
 void Svec(const arma::mat& input, arma::vec& output);

-/**
- * Svec for sparse matrices.
- * NOTE: armadillo doesn't have an sp_vec type, so the output type is sp_mat.
- *
- * @param input sparse A symmetric matrix
- * @param output
- */
-void Svec(const arma::sp_mat& input, arma::sp_mat& output);
+/**
+ * Svec for sparse matrices; the output is a sparse column vector.
+ *
+ * @param input A sparse symmetric matrix.
+ * @param output Sparse column vector that receives Svec(input).
+ */
+void Svec(const arma::sp_mat& input, arma::sp_vec& output);

 /**
  * The inverse of Svec. That is, Smat(Svec(A)) == A.
@@ -112,7 +111,7 @@ void Smat(const arma::vec& input, arma::mat& output);
  * @param j
  * @param n
  */
-size_t SvecIndex(size_t i, size_t j, size_t n);
+inline size_t SvecIndex(size_t i, size_t j, size_t n);

 /**
  * If A is a symmetric matrix, then SymKronId returns an operator Op such that
@@ -129,4 +128,7 @@ void SymKronId(const arma::mat& A, arma::mat& op);
 }; // namespace math
 }; // namespace mlpack

+// Include the implementations of the inline functions declared above.
+#include "lin_alg_impl.hpp"
+
 #endif // __MLPACK_CORE_MATH_LIN_ALG_HPP
diff --git a/src/mlpack/core/math/lin_alg_impl.hpp b/src/mlpack/core/math/lin_alg_impl.hpp
new file mode 100644
index 0000000..049118e
--- /dev/null
+++ b/src/mlpack/core/math/lin_alg_impl.hpp
@@ -0,0 +1,27 @@
+/**
+ * @file lin_alg_impl.hpp
+ * @author Stephen Tu
+ *
+ * Implementation of the inline functions declared in lin_alg.hpp.
+ */
+#ifndef __MLPACK_CORE_MATH_LIN_ALG_IMPL_HPP
+#define __MLPACK_CORE_MATH_LIN_ALG_IMPL_HPP
+
+#include "lin_alg.hpp"
+
+namespace mlpack {
+namespace math {
+
+inline size_t SvecIndex(size_t i, size_t j, size_t n)
+{
+  // (i, j) and (j, i) refer to the same entry of a symmetric matrix, so
+  // normalize to i <= j before computing the offset.
+  if (i > j)
+    std::swap(i, j);
+  return (j-i) + (n*(n+1) - (n-i)*(n-i+1))/2;
+}
+
+} // namespace math
+} // namespace mlpack
+
+#endif // __MLPACK_CORE_MATH_LIN_ALG_IMPL_HPP
diff --git a/src/mlpack/core/optimizers/sdp/CMakeLists.txt b/src/mlpack/core/optimizers/sdp/CMakeLists.txt
index cba5d07..266d514 100644
--- a/src/mlpack/core/optimizers/sdp/CMakeLists.txt
+++ b/src/mlpack/core/optimizers/sdp/CMakeLists.txt
@@ -1,8 +1,8 @@
 set(SOURCES
-  sdp.hpp
-  sdp.cpp
   primal_dual.hpp
-  primal_dual.cpp
+  primal_dual_impl.hpp
+  sdp.hpp
+  sdp_impl.hpp
 )
 
 set(DIR_SRCS)
diff --git a/src/mlpack/core/optimizers/sdp/primal_dual.hpp b/src/mlpack/core/optimizers/sdp/primal_dual.hpp
index 9480b20..8b110d1 100644
--- a/src/mlpack/core/optimizers/sdp/primal_dual.hpp
+++ b/src/mlpack/core/optimizers/sdp/primal_dual.hpp
@@ -12,16 +12,23 @@
 namespace mlpack {
 namespace optimization {

+/**
+ * Primal-dual interior point solver for semidefinite programs.
+ *
+ * @tparam SDPType Type of the SDP to solve, e.g. SDP<arma::mat> or
+ *     SDP<arma::sp_mat>; it must expose an objective_matrix_type typedef.
+ */
+template <typename SDPType>
 class PrimalDualSolver {
  public:

-  PrimalDualSolver(const SDP& sdp);
+  PrimalDualSolver(const SDPType& sdp);

-  PrimalDualSolver(const SDP& sdp,
-                   const arma::mat& X0,
-                   const arma::vec& ysparse0,
-                   const arma::vec& ydense0,
-                   const arma::mat& Z0);
+  PrimalDualSolver(const SDPType& sdp,
+                   const arma::mat& initialX,
+                   const arma::vec& initialYsparse,
+                   const arma::vec& initialYdense,
+                   const arma::mat& initialZ);

   std::pair<bool, double>
   Optimize(arma::mat& X,
@@ -37,19 +44,26 @@ class PrimalDualSolver {
     return Optimize(X, ysparse, ydense, Z);
   }

+  //! Return the SDP that this solver operates on.
+  const SDPType& Sdp() const { return sdp; }
+
   double& Tau() { return tau; }
+
   double& NormXzTol() { return normXzTol; }
+
   double& PrimalInfeasTol() { return primalInfeasTol; }
+
   double& DualInfeasTol() { return dualInfeasTol; }
+
   size_t& MaxIterations() { return maxIterations; }

  private:
-  SDP sdp;
+  SDPType sdp;

-  arma::mat X0;
-  arma::vec ysparse0;
-  arma::vec ydense0;
-  arma::mat Z0;
+  arma::mat initialX;
+  arma::vec initialYsparse;
+  arma::vec initialYdense;
+  arma::mat initialZ;

   double tau;
   double normXzTol;
@@ -63,4 +77,7 @@ class PrimalDualSolver {
 } // namespace optimization
 } // namespace mlpack

+// Include implementation.
+#include "primal_dual_impl.hpp"
+
 #endif
diff --git a/src/mlpack/core/optimizers/sdp/primal_dual.cpp b/src/mlpack/core/optimizers/sdp/primal_dual_impl.hpp
similarity index 76%
rename from src/mlpack/core/optimizers/sdp/primal_dual.cpp
rename to src/mlpack/core/optimizers/sdp/primal_dual_impl.hpp
index 36d32ca..481ee67 100644
--- a/src/mlpack/core/optimizers/sdp/primal_dual.cpp
+++ b/src/mlpack/core/optimizers/sdp/primal_dual_impl.hpp
@@ -1,9 +1,10 @@
 /**
- * @file primal_dual.cpp
+ * @file primal_dual_impl.hpp
  * @author Stephen Tu
  *
- * Contains an implementation of the "XZ+ZX" primal-dual IP method presented
- * and analyzed in:
+ * Contains an implementation of the "XZ+ZX" primal-dual infeasible interior
+ * point method with a Mehrotra predictor-corrector update step presented and
+ * analyzed in:
  *
  *   Primal-dual interior-point methods for semidefinite programming:
  *   Convergence rates, stability and numerical results.
@@ -16,18 +17,21 @@
  * Note there are many optimizations that still need to be implemented. See the
  * code comments for more details.
  */
+#ifndef __MLPACK_CORE_OPTIMIZERS_SDP_PRIMAL_DUAL_IMPL_HPP
+#define __MLPACK_CORE_OPTIMIZERS_SDP_PRIMAL_DUAL_IMPL_HPP
 
 #include "primal_dual.hpp"
 
 namespace mlpack {
 namespace optimization {
 
-PrimalDualSolver::PrimalDualSolver(const SDP& sdp)
+template <typename SDPType>
+PrimalDualSolver<SDPType>::PrimalDualSolver(const SDPType& sdp)
   : sdp(sdp),
-    X0(arma::eye<arma::mat>(sdp.N(), sdp.N())),
-    ysparse0(arma::ones<arma::vec>(sdp.NumSparseConstraints())),
-    ydense0(arma::ones<arma::vec>(sdp.NumDenseConstraints())),
-    Z0(arma::eye<arma::mat>(sdp.N(), sdp.N())),
+    initialX(arma::eye<arma::mat>(sdp.N(), sdp.N())),
+    initialYsparse(arma::ones<arma::vec>(sdp.NumSparseConstraints())),
+    initialYdense(arma::ones<arma::vec>(sdp.NumDenseConstraints())),
+    initialZ(arma::eye<arma::mat>(sdp.N(), sdp.N())),
     tau(0.99),
     normXzTol(1e-7),
     primalInfeasTol(1e-7),
@@ -37,16 +41,17 @@ PrimalDualSolver::PrimalDualSolver(const SDP& sdp)
 
 }
 
-PrimalDualSolver::PrimalDualSolver(const SDP& sdp,
-                                   const arma::mat& X0,
-                                   const arma::vec& ysparse0,
-                                   const arma::vec& ydense0,
-                                   const arma::mat& Z0)
+template <typename SDPType>
+PrimalDualSolver<SDPType>::PrimalDualSolver(const SDPType& sdp,
+                                            const arma::mat& initialX,
+                                            const arma::vec& initialYsparse,
+                                            const arma::vec& initialYdense,
+                                            const arma::mat& initialZ)
   : sdp(sdp),
-    X0(X0),
-    ysparse0(ysparse0),
-    ydense0(ydense0),
-    Z0(Z0),
+    initialX(initialX),
+    initialYsparse(initialYsparse),
+    initialYdense(initialYdense),
+    initialZ(initialZ),
     tau(0.99),
     normXzTol(1e-7),
     primalInfeasTol(1e-7),
@@ -55,34 +60,34 @@ PrimalDualSolver::PrimalDualSolver(const SDP& sdp,
 {
   arma::mat tmp;
 
-  if (X0.n_rows != sdp.N() || X0.n_cols != sdp.N())
+  if (initialX.n_rows != sdp.N() || initialX.n_cols != sdp.N())
     Log::Fatal << "PrimalDualSolver::PrimalDualSolver(): "
-      << "X0 needs to be square n x n matrix"
+      << "initialX needs to be square n x n matrix"
       << std::endl;
 
-  if (!arma::chol(tmp, X0))
+  if (!arma::chol(tmp, initialX))
     Log::Fatal << "PrimalDualSolver::PrimalDualSolver(): "
-      << "X0 needs to be symmetric positive definite"
+      << "initialX needs to be symmetric positive definite"
       << std::endl;
 
-  if (ysparse0.n_elem != sdp.NumSparseConstraints())
+  if (initialYsparse.n_elem != sdp.NumSparseConstraints())
     Log::Fatal << "PrimalDualSolver::PrimalDualSolver(): "
-      << "ysparse0 needs to have the same length as the number of sparse constraints"
+      << "initialYsparse needs to have the same length as the number of sparse constraints"
       << std::endl;
 
-  if (ydense0.n_elem != sdp.NumDenseConstraints())
+  if (initialYdense.n_elem != sdp.NumDenseConstraints())
     Log::Fatal << "PrimalDualSolver::PrimalDualSolver(): "
-      << "ydense0 needs to have the same length as the number of dense constraints"
+      << "initialYdense needs to have the same length as the number of dense constraints"
       << std::endl;
 
-  if (Z0.n_rows != sdp.N() || Z0.n_cols != sdp.N())
+  if (initialZ.n_rows != sdp.N() || initialZ.n_cols != sdp.N())
     Log::Fatal << "PrimalDualSolver::PrimalDualSolver(): "
-      << "Z0 needs to be square n x n matrix"
+      << "initialZ needs to be square n x n matrix"
       << std::endl;
 
-  if (!arma::chol(tmp, Z0))
+  if (!arma::chol(tmp, initialZ))
     Log::Fatal << "PrimalDualSolver::PrimalDualSolver(): "
-      << "Z0 needs to be symmetric positive definite"
+      << "initialZ needs to be symmetric positive definite"
       << std::endl;
 }
 
@@ -91,8 +96,8 @@ AlphaHat(const arma::mat& A, const arma::mat& dA)
 {
   // note: arma::chol(A) returns an upper triangular matrix (instead of the
   // usual lower triangular)
-  const arma::mat L = arma::trimatl(arma::chol(A).t());
-  const arma::mat Linv = L.i();
+  const arma::mat L = arma::chol(A).t();
+  const arma::mat Linv = arma::inv(arma::trimatl(L));
   const arma::vec evals = arma::eig_sym(-Linv * dA * Linv.t());
   const double alphahatinv = evals(evals.n_elem - 1);
   return 1. / alphahatinv;
@@ -171,11 +176,21 @@ SolveKKTSystem(const arma::sp_mat& Asparse,
   dsz = rd - Asparse.t() * dysparse - Adense.t() * dydense;
 }
 
+namespace private_ {
+
+// TODO(stephentu): should we move this somewhere more general
+template <typename T> struct vectype { };
+template <typename eT> struct vectype<arma::Mat<eT>> { typedef arma::Col<eT> type; };
+template <typename eT> struct vectype<arma::SpMat<eT>> { typedef arma::SpCol<eT> type; };
+
+} // namespace private_
+
+template <typename SDPType>
 std::pair<bool, double>
-PrimalDualSolver::Optimize(arma::mat& X,
-                           arma::vec& ysparse,
-                           arma::vec& ydense,
-                           arma::mat& Z)
+PrimalDualSolver<SDPType>::Optimize(arma::mat& X,
+                                    arma::vec& ysparse,
+                                    arma::vec& ydense,
+                                    arma::mat& Z)
 {
   // TODO(stephentu): We need a method which deals with the case when the Ais
   // are not linearly independent.
@@ -184,12 +199,12 @@ PrimalDualSolver::Optimize(arma::mat& X,
   const size_t n2bar = sdp.N2bar();
 
   arma::sp_mat Asparse(sdp.NumSparseConstraints(), n2bar);
-  arma::sp_mat Aisparse;
+  arma::sp_vec Aisparse;
 
   for (size_t i = 0; i < sdp.NumSparseConstraints(); i++)
   {
     math::Svec(sdp.SparseA()[i], Aisparse);
-    Asparse.row(i) = Aisparse.col(0).t();
+    Asparse.row(i) = Aisparse.t();
   }
 
   arma::mat Adense(sdp.NumDenseConstraints(), n2bar);
@@ -200,18 +215,13 @@ PrimalDualSolver::Optimize(arma::mat& X,
     Adense.row(i) = Aidense.t();
   }
 
-  arma::sp_mat scsparse;
-  if (sdp.HasSparseObjective())
-    math::Svec(sdp.SparseC(), scsparse);
+  typename private_::vectype<typename SDPType::objective_matrix_type>::type sc;
+  math::Svec(sdp.C(), sc);
 
-  arma::vec scdense;
-  if (sdp.HasDenseObjective())
-    math::Svec(sdp.DenseC(), scdense);
-
-  X = X0;
-  ysparse = ysparse0;
-  ydense = ydense0;
-  Z = Z0;
+  X = initialX;
+  ysparse = initialYsparse;
+  ydense = initialYdense;
+  Z = initialZ;
 
   arma::vec sx, sz, dysparse, dydense, dsx, dsz;
   arma::mat dX, dZ;
@@ -240,11 +250,7 @@ PrimalDualSolver::Optimize(arma::mat& X,
       rp(arma::span(sdp.NumSparseConstraints(), sdp.NumConstraints() - 1)) =
           sdp.DenseB() - Adense * sx;
 
-    rd = - sz - Asparse.t() * ysparse - Adense.t() * ydense;
-    if (sdp.HasSparseObjective())
-      rd += scsparse.col(0);
-    if (sdp.HasDenseObjective())
-      rd += scdense;
+    rd = sc - sz - Asparse.t() * ysparse - Adense.t() * ydense;
 
     math::SymKronId(X, F);
 
@@ -336,11 +342,7 @@ PrimalDualSolver::Optimize(arma::mat& X,
         sparse_primal_infeas * sparse_primal_infeas +
         dense_primal_infeas * dense_primal_infeas);
 
-    primal_obj = 0.;
-    if (sdp.HasSparseObjective())
-      primal_obj += arma::dot(sdp.SparseC(), X);
-    if (sdp.HasDenseObjective())
-      primal_obj += arma::dot(sdp.DenseC(), X);
+    primal_obj = arma::dot(sdp.C(), X);
 
     const double dual_obj =
       arma::dot(sdp.SparseB(), ysparse) +
@@ -350,11 +352,7 @@ PrimalDualSolver::Optimize(arma::mat& X,
 
     // TODO(stephentu): this dual check is quite expensive,
     // maybe make it optional?
-    DualCheck = Z;
-    if (sdp.HasSparseObjective())
-      DualCheck -= sdp.SparseC();
-    if (sdp.HasDenseObjective())
-      DualCheck -= sdp.DenseC();
+    DualCheck = Z - sdp.C();
     for (size_t i = 0; i < sdp.NumSparseConstraints(); i++)
       DualCheck += ysparse(i) * sdp.SparseA()[i];
     for (size_t i = 0; i < sdp.NumDenseConstraints(); i++)
@@ -384,3 +382,5 @@ PrimalDualSolver::Optimize(arma::mat& X,

 } // namespace optimization
 } // namespace mlpack
+
+#endif // __MLPACK_CORE_OPTIMIZERS_SDP_PRIMAL_DUAL_IMPL_HPP
diff --git a/src/mlpack/core/optimizers/sdp/sdp.hpp b/src/mlpack/core/optimizers/sdp/sdp.hpp
index 4d57dab..c0de4f3 100644
--- a/src/mlpack/core/optimizers/sdp/sdp.hpp
+++ b/src/mlpack/core/optimizers/sdp/sdp.hpp
@@ -13,11 +13,20 @@ namespace optimization {

 /**
  * Specify an SDP in primal form
+ *
+ *     min    dot(C, X)
+ *     s.t.   dot(Ai, X) = bi, i=1,...,m, X >= 0 (X positive semidefinite)
+ *
+ * @tparam ObjectiveMatrixType Should be either arma::mat or arma::sp_mat
  */
+template <typename ObjectiveMatrixType>
 class SDP
 {
  public:

+  //! Type of the objective matrix C (arma::mat or arma::sp_mat).
+  typedef ObjectiveMatrixType objective_matrix_type;
+
   SDP(const size_t n,
       const size_t numSparseConstraints,
       const size_t numDenseConstraints);
@@ -32,23 +41,11 @@ class SDP

   size_t NumConstraints() const { return sparseB.n_elem + denseB.n_elem; }

-  //! Return the sparse objective function matrix (sparseC).
-  const arma::sp_mat& SparseC() const { return sparseC; }
-
-  //! Modify the sparse objective function matrix (sparseC).
-  arma::sp_mat& SparseC() {
-    hasModifiedSparseObjective = true;
-    return sparseC;
-  }
-
-  //! Return the dense objective function matrix (denseC).
-  const arma::mat& DenseC() const { return denseC; }
+  //! Modify the objective function matrix (c).
+  ObjectiveMatrixType& C() { return c; }

-  //! Modify the dense objective function matrix (denseC).
-  arma::mat& DenseC() {
-    hasModifiedDenseObjective = true;
-    return denseC;
-  }
+  //! Return the objective function matrix (c).
+  const ObjectiveMatrixType& C() const { return c; }

   //! Return the vector of sparse A matrices (which correspond to the sparse
   // constraints).
@@ -76,10 +73,6 @@ class SDP
   //! Modify the vector of dense B values.
   arma::vec& DenseB() { return denseB; }

-  bool HasSparseObjective() const { return hasModifiedSparseObjective; }
-
-  bool HasDenseObjective() const { return hasModifiedDenseObjective; }
-
   /**
    * Check whether or not the constraint matrices are linearly independent.
    *
@@ -92,17 +85,8 @@ class SDP
   //! Dimension of the objective variable.
   size_t n;

-  //! Sparse objective function matrix c.
-  arma::sp_mat sparseC;
-
-  //! Dense objective function matrix c.
-  arma::mat denseC;
-
-  //! If false, sparseC is zero
-  bool hasModifiedSparseObjective;
-
-  //! If false, denseC is zero
-  bool hasModifiedDenseObjective;
+  //! Objective function matrix c.
+  ObjectiveMatrixType c;

   //! A_i for each sparse constraint.
   std::vector<arma::sp_mat> sparseA;
@@ -118,4 +102,7 @@ class SDP
 } // namespace optimization
 } // namespace mlpack

+// Include implementation.
+#include "sdp_impl.hpp"
+
 #endif
diff --git a/src/mlpack/core/optimizers/sdp/sdp.cpp b/src/mlpack/core/optimizers/sdp/sdp_impl.hpp
similarity index 64%
rename from src/mlpack/core/optimizers/sdp/sdp.cpp
rename to src/mlpack/core/optimizers/sdp/sdp_impl.hpp
index 984ea7e..1da53f4 100644
--- a/src/mlpack/core/optimizers/sdp/sdp.cpp
+++ b/src/mlpack/core/optimizers/sdp/sdp_impl.hpp
@@ -1,31 +1,34 @@
 /**
- * @file sdp.cpp
+ * @file sdp_impl.hpp
  * @author Stephen Tu
  *
  */
+#ifndef __MLPACK_CORE_OPTIMIZERS_SDP_SDP_IMPL_HPP
+#define __MLPACK_CORE_OPTIMIZERS_SDP_SDP_IMPL_HPP

 #include "sdp.hpp"

 namespace mlpack {
 namespace optimization {

-SDP::SDP(const size_t n,
-         const size_t numSparseConstraints,
-         const size_t numDenseConstraints) :
+template <typename ObjectiveMatrixType>
+SDP<ObjectiveMatrixType>::SDP(const size_t n,
+                              const size_t numSparseConstraints,
+                              const size_t numDenseConstraints) :
     n(n),
-    sparseC(n, n),
-    denseC(n, n),
-    hasModifiedSparseObjective(false),
-    hasModifiedDenseObjective(false),
+    c(n, n),
     sparseA(numSparseConstraints),
     sparseB(numSparseConstraints),
     denseA(numDenseConstraints),
     denseB(numDenseConstraints)
 {
-  denseC.zeros();
+  // A sized dense arma::mat may hold uninitialized memory (depending on the
+  // Armadillo version), so zero the objective as denseC.zeros() used to.
+  c.zeros();
 }

-bool SDP::HasLinearlyIndependentConstraints() const
+template <typename ObjectiveMatrixType>
+bool SDP<ObjectiveMatrixType>::HasLinearlyIndependentConstraints() const
 {
   // Very inefficient, should only be used for testing/debugging

@@ -53,3 +56,5 @@ bool SDP::HasLinearlyIndependentConstraints() const

 } // namespace optimization
 } // namespace mlpack
+
+#endif // __MLPACK_CORE_OPTIMIZERS_SDP_SDP_IMPL_HPP
diff --git a/src/mlpack/tests/lin_alg_test.cpp b/src/mlpack/tests/lin_alg_test.cpp
index ebea714..afb8876 100644
--- a/src/mlpack/tests/lin_alg_test.cpp
+++ b/src/mlpack/tests/lin_alg_test.cpp
@@ -194,12 +194,11 @@ BOOST_AUTO_TEST_CASE(TestSvecSmat)
 
   arma::vec sx;
   Svec(X, sx);
-  const double sq2 = sqrt(2.);
   BOOST_REQUIRE_CLOSE(sx(0), 0, 1e-7);
-  BOOST_REQUIRE_CLOSE(sx(1), sq2 * 1., 1e-7);
-  BOOST_REQUIRE_CLOSE(sx(2), sq2 * 2., 1e-7);
+  BOOST_REQUIRE_CLOSE(sx(1), M_SQRT2 * 1., 1e-7);
+  BOOST_REQUIRE_CLOSE(sx(2), M_SQRT2 * 2., 1e-7);
   BOOST_REQUIRE_CLOSE(sx(3), 3., 1e-7);
-  BOOST_REQUIRE_CLOSE(sx(4), sq2 * 4., 1e-7);
+  BOOST_REQUIRE_CLOSE(sx(4), M_SQRT2 * 4., 1e-7);
   BOOST_REQUIRE_CLOSE(sx(5), 5., 1e-7);
 
   arma::mat Xtest;
@@ -217,19 +216,18 @@ BOOST_AUTO_TEST_CASE(TestSparseSvec)
   X.zeros(3, 3);
   X(1, 0) = X(0, 1) = 1;
 
-  arma::sp_mat sx;
+  arma::sp_vec sx;
   Svec(X, sx);
 
-  const double sq2 = sqrt(2.);
-  const double v0 = sx(0, 0);
-  const double v1 = sx(1, 0);
-  const double v2 = sx(2, 0);
-  const double v3 = sx(3, 0);
-  const double v4 = sx(4, 0);
-  const double v5 = sx(5, 0);
+  const double v0 = sx(0);
+  const double v1 = sx(1);
+  const double v2 = sx(2);
+  const double v3 = sx(3);
+  const double v4 = sx(4);
+  const double v5 = sx(5);
 
   BOOST_REQUIRE_CLOSE(v0, 0, 1e-7);
-  BOOST_REQUIRE_CLOSE(v1, sq2 * 1., 1e-7);
+  BOOST_REQUIRE_CLOSE(v1, M_SQRT2 * 1., 1e-7);
   BOOST_REQUIRE_CLOSE(v2, 0, 1e-7);
   BOOST_REQUIRE_CLOSE(v3, 0, 1e-7);
   BOOST_REQUIRE_CLOSE(v4, 0, 1e-7);
diff --git a/src/mlpack/tests/sdp_primal_dual_test.cpp b/src/mlpack/tests/sdp_primal_dual_test.cpp
index 2d1b013..1af340c 100644
--- a/src/mlpack/tests/sdp_primal_dual_test.cpp
+++ b/src/mlpack/tests/sdp_primal_dual_test.cpp
@@ -116,12 +116,12 @@ class UndirectedGraph
   size_t numVertices;
 };
 
-static inline SDP
+static inline SDP<arma::sp_mat>
 ConstructMaxCutSDPFromGraph(const UndirectedGraph& g)
 {
-  SDP sdp(g.NumVertices(), g.NumVertices(), 0);
-  g.Laplacian(sdp.SparseC());
-  sdp.SparseC() *= -1;
+  SDP<arma::sp_mat> sdp(g.NumVertices(), g.NumVertices(), 0);
+  g.Laplacian(sdp.C());
+  sdp.C() *= -1;
   for (size_t i = 0; i < g.NumVertices(); i++)
   {
     sdp.SparseA()[i].zeros(g.NumVertices(), g.NumVertices());
@@ -131,12 +131,12 @@ ConstructMaxCutSDPFromGraph(const UndirectedGraph& g)
   return sdp;
 }
 
-static inline SDP
+static inline SDP<arma::mat>
 ConstructLovaszThetaSDPFromGraph(const UndirectedGraph& g)
 {
-  SDP sdp(g.NumVertices(), g.NumEdges() + 1, 0);
-  sdp.DenseC().ones();
-  sdp.DenseC() *= -1.;
+  SDP<arma::mat> sdp(g.NumVertices(), g.NumEdges() + 1, 0);
+  sdp.C().ones();
+  sdp.C() *= -1.;
   sdp.SparseA()[0].eye(g.NumVertices(), g.NumVertices());
   for (size_t i = 0; i < g.NumEdges(); i++)
   {
@@ -149,7 +149,7 @@ ConstructLovaszThetaSDPFromGraph(const UndirectedGraph& g)
   return sdp;
 }
 
-// TODO: does arma have a builtin way to do this?
+// TODO(stephentu): does arma have a builtin way to do this?
 static inline arma::mat
 Diag(const arma::vec& diag)
 {
@@ -162,15 +162,15 @@ Diag(const arma::vec& diag)
   return ret;
 }
 
-static inline SDP
+static inline SDP<arma::sp_mat>
 ConstructMaxCutSDPFromLaplacian(const std::string& laplacianFilename)
 {
   arma::mat laplacian;
   data::Load(laplacianFilename, laplacian, true, false);
   if (laplacian.n_rows != laplacian.n_cols)
     Log::Fatal << "laplacian not square" << std::endl;
-  SDP sdp(laplacian.n_rows, laplacian.n_rows, 0);
-  sdp.SparseC() = -arma::sp_mat(laplacian);
+  SDP<arma::sp_mat> sdp(laplacian.n_rows, laplacian.n_rows, 0);
+  sdp.C() = -arma::sp_mat(laplacian);
   for (size_t i = 0; i < laplacian.n_rows; i++)
   {
     sdp.SparseA()[i].zeros(laplacian.n_rows, laplacian.n_rows);
@@ -183,7 +183,7 @@ ConstructMaxCutSDPFromLaplacian(const std::string& laplacianFilename)
 
 BOOST_AUTO_TEST_SUITE(SdpPrimalDualTest);
 
-static void SolveMaxCutFeasibleSDP(const SDP& sdp)
+static void SolveMaxCutFeasibleSDP(const SDP<arma::sp_mat>& sdp)
 {
   arma::mat X0, Z0;
   arma::vec ysparse0, ydense0;
@@ -191,10 +191,10 @@ static void SolveMaxCutFeasibleSDP(const SDP& sdp)
 
   // strictly feasible starting point
   X0.eye(sdp.N(), sdp.N());
-  ysparse0 = -1.1 * arma::vec(arma::sum(arma::abs(sdp.SparseC()), 0).t());
-  Z0 = -Diag(ysparse0) + sdp.SparseC();
+  ysparse0 = -1.1 * arma::vec(arma::sum(arma::abs(sdp.C()), 0).t());
+  Z0 = -Diag(ysparse0) + sdp.C();
 
-  PrimalDualSolver solver(sdp, X0, ysparse0, ydense0, Z0);
+  PrimalDualSolver<SDP<arma::sp_mat>> solver(sdp, X0, ysparse0, ydense0, Z0);
 
   arma::mat X, Z;
   arma::vec ysparse, ydense;
@@ -202,7 +202,7 @@ static void SolveMaxCutFeasibleSDP(const SDP& sdp)
   BOOST_REQUIRE(p.first);
 }
 
-static void SolveMaxCutPositiveSDP(const SDP& sdp)
+static void SolveMaxCutPositiveSDP(const SDP<arma::sp_mat>& sdp)
 {
   arma::mat X0, Z0;
   arma::vec ysparse0, ydense0;
@@ -216,7 +216,7 @@ static void SolveMaxCutPositiveSDP(const SDP& sdp)
   ysparse0 = arma::randu<arma::vec>(sdp.NumSparseConstraints());
   Z0.eye(sdp.N(), sdp.N());
 
-  PrimalDualSolver solver(sdp, X0, ysparse0, ydense0, Z0);
+  PrimalDualSolver<SDP<arma::sp_mat>> solver(sdp, X0, ysparse0, ydense0, Z0);
 
   arma::mat X, Z;
   arma::vec ysparse, ydense;
@@ -226,7 +226,7 @@ static void SolveMaxCutPositiveSDP(const SDP& sdp)
 
 BOOST_AUTO_TEST_CASE(SmallMaxCutSdp)
 {
-  SDP sdp = ConstructMaxCutSDPFromLaplacian("r10.txt");
+  auto sdp = ConstructMaxCutSDPFromLaplacian("r10.txt");
   SolveMaxCutFeasibleSDP(sdp);
   SolveMaxCutPositiveSDP(sdp);
 
@@ -241,9 +241,9 @@ BOOST_AUTO_TEST_CASE(SmallLovaszThetaSdp)
 {
   UndirectedGraph g;
   UndirectedGraph::LoadFromEdges(g, "johnson8-4-4.csv", true);
-  SDP sdp = ConstructLovaszThetaSDPFromGraph(g);
+  auto sdp = ConstructLovaszThetaSDPFromGraph(g);
 
-  PrimalDualSolver solver(sdp);
+  PrimalDualSolver<SDP<arma::mat>> solver(sdp);
 
   arma::mat X, Z;
   arma::vec ysparse, ydense;
@@ -277,7 +277,7 @@ BlockDiag(const std::vector<arma::sp_mat>& blocks)
   return ret;
 }
 
-static inline SDP
+static inline SDP<arma::sp_mat>
 ConstructLogChebychevApproxSdp(const arma::mat& A, const arma::vec& b)
 {
   if (A.n_rows != b.n_elem)
@@ -292,8 +292,8 @@ ConstructLogChebychevApproxSdp(const arma::mat& A, const arma::vec& b)
   cblock(1, 2) = cblock(2, 1) = 1.;
   const arma::sp_mat C = RepeatBlockDiag(cblock, p);
 
-  SDP sdp(C.n_rows, k + 1, 0);
-  sdp.SparseC() = C;
+  SDP<arma::sp_mat> sdp(C.n_rows, k + 1, 0);
+  sdp.C() = C;
   sdp.SparseB().zeros();
   sdp.SparseB()[0] = -1;
 
@@ -372,8 +372,8 @@ BOOST_AUTO_TEST_CASE(LogChebychevApproxSdp)
   const size_t k0 = 10;
   const arma::mat A0 = RandomFullRowRankMatrix(p0, k0);
   const arma::vec b0 = arma::randu<arma::vec>(p0);
-  const SDP sdp0 = ConstructLogChebychevApproxSdp(A0, b0);
-  PrimalDualSolver solver0(sdp0);
+  const auto sdp0 = ConstructLogChebychevApproxSdp(A0, b0);
+  PrimalDualSolver<SDP<arma::sp_mat>> solver0(sdp0);
   arma::mat X0, Z0;
   arma::vec ysparse0, ydense0;
   const auto stat0 = solver0.Optimize(X0, ysparse0, ydense0, Z0);
@@ -383,8 +383,8 @@ BOOST_AUTO_TEST_CASE(LogChebychevApproxSdp)
   const size_t k1 = 5;
   const arma::mat A1 = RandomFullRowRankMatrix(p1, k1);
   const arma::vec b1 = arma::randu<arma::vec>(p1);
-  const SDP sdp1 = ConstructLogChebychevApproxSdp(A1, b1);
-  PrimalDualSolver solver1(sdp1);
+  const auto sdp1 = ConstructLogChebychevApproxSdp(A1, b1);
+  PrimalDualSolver<SDP<arma::sp_mat>> solver1(sdp1);
   arma::mat X1, Z1;
   arma::vec ysparse1, ydense1;
   const auto stat1 = solver1.Optimize(X1, ysparse1, ydense1, Z1);
@@ -445,7 +445,7 @@ BOOST_AUTO_TEST_CASE(CorrelationCoeffToySdp)
 
   std::vector<arma::sp_mat> ais({A0, A1, A2, A3, A4, A5, A6});
 
-  SDP sdp(7, 7 + 4 + 4 + 4 + 3 + 2 + 1, 0);
+  SDP<arma::sp_mat> sdp(7, 7 + 4 + 4 + 4 + 3 + 2 + 1, 0);
 
   for (size_t j = 0; j < 3; j++)
   {
@@ -489,10 +489,10 @@ BOOST_AUTO_TEST_CASE(CorrelationCoeffToySdp)
 
   sdp.SparseB()[5] = 1.; sdp.SparseB()[6] = 0.8;
 
-  sdp.SparseC().zeros();
-  sdp.SparseC()(0, 2) = sdp.SparseC()(2, 0) = 1.;
+  sdp.C().zeros();
+  sdp.C()(0, 2) = sdp.C()(2, 0) = 1.;
 
-  PrimalDualSolver solver(sdp);
+  PrimalDualSolver<SDP<arma::sp_mat>> solver(sdp);
   arma::mat X, Z;
   arma::vec ysparse, ydense;
   const auto p = solver.Optimize(X, ysparse, ydense, Z);
@@ -511,8 +511,8 @@ BOOST_AUTO_TEST_CASE(CorrelationCoeffToySdp)
 // * @param origData origDim x numPoints
 // * @param numNeighbors
 // */
-//static inline SDP ConstructMvuSDP(const arma::mat& origData,
-//                                  size_t numNeighbors)
+//static inline SDP<arma::sp_mat> ConstructMvuSDP(const arma::mat& origData,
+//                                                size_t numNeighbors)
 //{
 //  const size_t numPoints = origData.n_cols;
 //
@@ -523,9 +523,9 @@ BOOST_AUTO_TEST_CASE(CorrelationCoeffToySdp)
 //  AllkNN allknn(origData);
 //  allknn.Search(numNeighbors, neighbors, distances);
 //
-//  SDP sdp(numPoints, numNeighbors * numPoints, 1);
-//  sdp.SparseC().eye(numPoints, numPoints);
-//  sdp.SparseC() *= -1;
+//  SDP<arma::sp_mat> sdp(numPoints, numNeighbors * numPoints, 1);
+//  sdp.C().eye(numPoints, numPoints);
+//  sdp.C() *= -1;
 //  sdp.DenseA()[0].ones(numPoints, numPoints);
 //  sdp.DenseB()[0] = 0;
 //
@@ -579,9 +579,9 @@ BOOST_AUTO_TEST_CASE(CorrelationCoeffToySdp)
 //    origData.col(i) = arma::normalise(gauss.Random());
 //  }
 //
-//  SDP sdp = ConstructMvuSDP(origData, 5);
+//  auto sdp = ConstructMvuSDP(origData, 5);
 //
-//  PrimalDualSolver solver(sdp);
+//  PrimalDualSolver<SDP<arma::sp_mat>> solver(sdp);
 //  arma::mat X, Z;
 //  arma::vec ysparse, ydense;
 //  const auto p = solver.Optimize(X, ysparse, ydense, Z);



More information about the mlpack-git mailing list