[mlpack-git] master: Add matrix completion test case (db1d7dc)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Mon Jan 12 17:31:51 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/2f5aa71b3e3b9f14ed825a20a69759471edffb31...d81511957f5fcc5e46cee49c6103e3ae7ac1d45d

>---------------------------------------------------------------

commit db1d7dc240b286f5c34f4b07a077bbb7ba9dfe4a
Author: Stephen Tu <tu.stephenl at gmail.com>
Date:   Sun Jan 4 13:41:00 2015 -0800

    Add matrix completion test case


>---------------------------------------------------------------

db1d7dc240b286f5c34f4b07a077bbb7ba9dfe4a
 .../matrix_completion/matrix_completion.hpp        | 16 +++++--
 .../matrix_completion/matrix_completion_impl.hpp   | 49 ++++++++++++++++------
 src/mlpack/tests/data/completion_X.csv             | 20 +++++++++
 src/mlpack/tests/data/completion_indices.csv       |  2 +
 src/mlpack/tests/matrix_completion_test.cpp        | 31 ++++++++++++++
 5 files changed, 101 insertions(+), 17 deletions(-)

diff --git a/src/mlpack/methods/matrix_completion/matrix_completion.hpp b/src/mlpack/methods/matrix_completion/matrix_completion.hpp
index 0353459..d04321b 100644
--- a/src/mlpack/methods/matrix_completion/matrix_completion.hpp
+++ b/src/mlpack/methods/matrix_completion/matrix_completion.hpp
@@ -18,31 +18,39 @@ class MatrixCompletion
 public:
   MatrixCompletion(const size_t m,
                    const size_t n,
-                   const arma::mat& entries,
+                   const arma::umat& indices,
+                   const arma::vec& values,
                    const size_t r);
 
   MatrixCompletion(const size_t m,
                    const size_t n,
-                   const arma::mat& entries,
+                   const arma::umat& indices,
+                   const arma::vec& values,
                    const arma::mat& initialPoint);
 
   MatrixCompletion(const size_t m,
                    const size_t n,
-                   const arma::mat& entries);
+                   const arma::umat& indices,
+                   const arma::vec& values);
 
   void Recover();
 
+  const optimization::LRSDP& Sdp() const { return sdp; }
+  optimization::LRSDP& Sdp() { return sdp; }
+
   const arma::mat& Recovered() const { return recovered; }
   arma::mat& Recovered() { return recovered; }
 
 private:
   size_t m;
   size_t n;
-  arma::mat entries;
+  arma::umat indices;
+  arma::mat values;
 
   optimization::LRSDP sdp;
   arma::mat recovered;
 
+  void checkValues();
   void initSdp();
 
   static size_t
diff --git a/src/mlpack/methods/matrix_completion/matrix_completion_impl.hpp b/src/mlpack/methods/matrix_completion/matrix_completion_impl.hpp
index 1a8fdad..1ceb5c9 100644
--- a/src/mlpack/methods/matrix_completion/matrix_completion_impl.hpp
+++ b/src/mlpack/methods/matrix_completion/matrix_completion_impl.hpp
@@ -12,43 +12,64 @@ namespace matrix_completion {
 
 MatrixCompletion::MatrixCompletion(const size_t m,
                                    const size_t n,
-                                   const arma::mat& entries,
+                                   const arma::umat& indices,
+                                   const arma::vec& values,
                                    const size_t r)
-  : m(m), n(n), entries(entries),
-    sdp(entries.n_cols, 0, CreateInitialPoint(m, n, r))
+  : m(m), n(n), indices(indices), values(values),
+    sdp(indices.n_cols, 0, CreateInitialPoint(m, n, r))
 {
+  checkValues();
   initSdp();
 }
 
 MatrixCompletion::MatrixCompletion(const size_t m,
                                    const size_t n,
-                                   const arma::mat& entries,
+                                   const arma::umat& indices,
+                                   const arma::vec& values,
                                    const arma::mat& initialPoint)
-  : m(m), n(n), entries(entries),
-    sdp(entries.n_cols, 0, initialPoint)
+  : m(m), n(n), indices(indices), values(values),
+    sdp(indices.n_cols, 0, initialPoint)
 {
+  checkValues();
   initSdp();
 }
 
 MatrixCompletion::MatrixCompletion(const size_t m,
                                    const size_t n,
-                                   const arma::mat& entries)
-  : m(m), n(n), entries(entries),
-    sdp(entries.n_cols, 0, CreateInitialPoint(m, n, DefaultRank(m, n, entries.n_cols)))
+                                   const arma::umat& indices,
+                                   const arma::vec& values)
+  : m(m), n(n), indices(indices), values(values),
+    sdp(indices.n_cols, 0, CreateInitialPoint(m, n, DefaultRank(m, n, indices.n_cols)))
 {
+  checkValues();
   initSdp();
 }
 
+void MatrixCompletion::checkValues()
+{
+  if (indices.n_rows != 2)
+    Log::Fatal << "indices.n_rows != 2" << std::endl;
+
+  if (indices.n_cols != values.n_elem)
+    Log::Fatal << "indices.n_cols != values.n_elem" << std::endl;
+
+  for (size_t i = 0; i < values.n_elem; i++)
+  {
+    if (indices(0, i) >= m || indices(1, i) >= n)
+      Log::Fatal << "index out of bounds" << std::endl;
+  }
+}
+
 void MatrixCompletion::initSdp()
 {
   sdp.SparseC().eye(m + n, m + n);
-  sdp.SparseB() = 2. * entries.row(2);
-  const size_t p = entries.n_cols;
+  sdp.SparseB() = 2. * values;
+  const size_t p = indices.n_cols;
   for (size_t i = 0; i < p; i++)
   {
     sdp.SparseA()[i].zeros(m + n, m + n);
-    sdp.SparseA()[i](entries(0, i), m + entries(1, i)) = 1.;
-    sdp.SparseA()[i](m + entries(1, i), entries(0, i)) = 1.;
+    sdp.SparseA()[i](indices(0, i), m + indices(1, i)) = 1.;
+    sdp.SparseA()[i](m + indices(1, i), indices(0, i)) = 1.;
   }
 }
 
@@ -56,6 +77,8 @@ void MatrixCompletion::Recover()
 {
   recovered = sdp.Function().GetInitialPoint();
   sdp.Optimize(recovered);
+  recovered = recovered * trans(recovered);
+  recovered = recovered(arma::span(0, m - 1), arma::span(m, m + n - 1));
 }
 
 size_t MatrixCompletion::DefaultRank(const size_t m,
diff --git a/src/mlpack/tests/data/completion_X.csv b/src/mlpack/tests/data/completion_X.csv
new file mode 100644
index 0000000..7c0414b
--- /dev/null
+++ b/src/mlpack/tests/data/completion_X.csv
@@ -0,0 +1,20 @@
+2.268917959069277346e-01,1.201769815832157140e-01,2.791472705838512480e-01,4.129696002411505917e-01,7.921637265182315257e-01,5.032788432358387132e-01,3.835664173525583087e-01,8.451234748542018060e-01,3.228975933850825042e-01,3.677241101206029650e-01,3.115613861833520515e-01,2.843637987382546251e-01,8.263612223976396498e-01,2.578882500552053259e-01,3.735987908409875713e-01,4.096958161076692528e-01,4.836063397606226166e-01,1.840499653987820172e-01,5.339870859567330541e-01,6.943075938475559150e-01
+1.815514517683395057e-01,1.481438141299214750e-01,3.246365312873157327e-01,6.180958286371067700e-01,8.395422972369482872e-01,7.283799838784086322e-01,3.862362541987202080e-01,9.634855656153386017e-01,3.472450609549778133e-01,2.884095444455062607e-01,4.367424157124643602e-01,4.721652983026459749e-01,8.968181674236590517e-01,3.975936290126179662e-01,3.850861370247186111e-01,4.765619276085552025e-01,4.977252301323637496e-01,1.081737498640471112e-01,7.862189295351967866e-01,9.150673041547885411e-01
+4.485725462770236982e-01,6.210431906279154646e-01,8.079162416271070679e-01,7.479165249287595962e-01,9.660638138607120506e-01,6.775024396804142368e-01,6.268219497800838758e-01,1.488129887707428711e+00,4.081229070480736354e-01,9.937536925587704406e-01,9.942593357890014971e-01,9.277558021971723523e-01,9.732006968930746460e-01,4.241106289124610962e-01,5.442549648603086654e-01,4.212243188620748424e-01,5.900020369048364355e-01,2.652600104280545823e-01,8.213977178710969440e-01,9.447098806592240106e-01
+3.702964231051512067e-01,3.938319501534544576e-01,6.395889777312965263e-01,8.706414932452468669e-01,1.207707943992026856e+00,9.457458390400318438e-01,6.340944313462977266e-01,1.528845711341907831e+00,5.020885255631167832e-01,7.048629386590496981e-01,8.098977940228742067e-01,8.049719509373847171e-01,1.262115877670982211e+00,5.351700897498899989e-01,5.974078234042232705e-01,6.255845921380651653e-01,7.258190154091623825e-01,2.294328483941174279e-01,1.055931543348610946e+00,1.238692223038866835e+00
+3.526016035061635967e-01,2.703621037792456328e-01,4.439314100431043242e-01,4.229666868320861584e-01,9.031752423357746196e-01,4.597582533544301575e-01,5.007159852721363436e-01,1.059317307971314293e+00,3.691519879382180180e-01,6.584508547254678268e-01,4.731792376876083339e-01,3.876907164298784014e-01,9.169708773125819778e-01,2.428012038917182203e-01,4.609839449546466184e-01,4.136321988942701644e-01,5.616141572314510277e-01,2.821873896562767658e-01,5.105189321111793799e-01,6.990347228681695890e-01
+4.776614239431458131e-01,3.967295410837491931e-01,5.957997622485082800e-01,4.518458514484933186e-01,1.063991127879416743e+00,4.539370653399564248e-01,6.300085470808846067e-01,1.298949699351007459e+00,4.349946655237709736e-01,9.295503644763772888e-01,6.180235338336663720e-01,4.782498767941839590e-01,1.063025575275893919e+00,2.430658821032348205e-01,5.652089609470616338e-01,4.509587467231170876e-01,6.692198367772439482e-01,3.844592684576819686e-01,5.207452078366189285e-01,7.502845369622384020e-01
+5.056457998237853246e-01,5.532375286855771845e-01,7.272338358680021653e-01,4.881611398595514517e-01,9.503154304545203823e-01,4.119876009143376150e-01,6.322290100374869937e-01,1.346923545831661517e+00,3.940692886592210842e-01,1.071956010788616576e+00,8.011397166848406304e-01,6.554305764513597143e-01,9.325968522063492472e-01,2.505047222933129425e-01,5.435261528354763394e-01,3.651050090405553172e-01,5.993729646038712389e-01,3.697943299987761501e-01,5.141127848547054624e-01,6.928395068906985088e-01
+8.064260552339440613e-01,9.177319992214625355e-01,1.273416261287542817e+00,1.173494511088804204e+00,1.848081436294184865e+00,1.126013492925364012e+00,1.127445641813147770e+00,2.552912276563224836e+00,7.695028070582169422e-01,1.677483037275186994e+00,1.496622542088449359e+00,1.342675624209609841e+00,1.865740751901193484e+00,6.657448925587777788e-01,1.000914478221888837e+00,8.186106152427350402e-01,1.138850827545863087e+00,5.458898643619807256e-01,1.326285146704942131e+00,1.629686184016899242e+00
+1.738422620439678545e-01,1.567384743183915241e-01,2.784381394849713143e-01,4.002518394974652338e-01,5.979703370306685972e-01,4.505220850103247576e-01,3.029192179481221547e-01,7.169537418545651741e-01,2.471528659888790891e-01,3.132373072381882118e-01,3.462479895999485691e-01,3.421883485660967561e-01,6.262180569679969278e-01,2.483049716330763546e-01,2.896125897456474640e-01,3.130639915304364029e-01,3.603263077700779027e-01,1.161228340704642814e-01,4.951422090051941227e-01,5.939392287737175202e-01
+4.039456635170143861e-01,5.176489546874343262e-01,6.792690958633013087e-01,5.895486598390766719e-01,8.467228200114418346e-01,5.310638428784866250e-01,5.488301219725905122e-01,1.263870612002495175e+00,3.555434049210253189e-01,8.788683417928322417e-01,8.123858974263451493e-01,7.357475027678473944e-01,8.479815776774547453e-01,3.287810257594642183e-01,4.765747311001282127e-01,3.595012751203302548e-01,5.216642260308997914e-01,2.574900849029022609e-01,6.448484722778824452e-01,7.686516817055363271e-01
+3.940846782559529471e-01,4.147717414961812343e-01,5.314428096115817457e-01,2.745963014161701965e-01,6.611575129806924744e-01,1.995841559432231793e-01,4.627733071586718205e-01,9.432515422337117705e-01,2.729784988062347595e-01,8.394228484686327985e-01,5.583141288527634361e-01,4.237136582360849779e-01,6.359725511002094489e-01,1.245788131178181823e-01,3.907073643697869247e-01,2.273802394140331529e-01,4.240316145344125709e-01,3.018514857677208063e-01,2.675618098487977381e-01,4.053152472613138491e-01
+6.181385940128862888e-01,7.753320651824442411e-01,9.493655495231986263e-01,5.533211270559912354e-01,9.819006196623043525e-01,3.896925857346590361e-01,7.305044367182026432e-01,1.572348101912398377e+00,4.118710806852865502e-01,1.382958817823869868e+00,1.065249233800557027e+00,8.787280785934115102e-01,9.413207627638698893e-01,2.665065130425555151e-01,6.045051118209724406e-01,3.284817436254397482e-01,6.242767966384996647e-01,4.288307785111500348e-01,5.341110254630515586e-01,7.002714360078681199e-01
+2.929236407853995083e-01,5.375045292961507792e-01,5.849436198864613745e-01,3.276623462636701345e-01,2.917804957829630008e-01,1.628580143777714484e-01,3.192001003892946942e-01,7.647813248733912328e-01,1.318063667665049510e-01,7.601259858819182647e-01,7.165563717624559015e-01,6.415369614491145400e-01,2.580284780706327830e-01,1.537451077116215314e-01,2.365621509923022669e-01,4.824290586088247901e-02,1.852122752468948197e-01,1.511663393741338457e-01,2.752925054650718351e-01,2.767917790417023793e-01
+3.036686150026711162e-01,5.088741548959666572e-01,7.019940518817437969e-01,9.187845699751328032e-01,9.295555377881823400e-01,9.172063736459384886e-01,5.341178316295921435e-01,1.415286958013770136e+00,3.963752712126933364e-01,6.737014883809999688e-01,9.567201408953059927e-01,9.937306044691960327e-01,9.754041175353654891e-01,5.599205367827400082e-01,4.858583602521173161e-01,4.860159183826575635e-01,5.462939906370443133e-01,1.197171183316589060e-01,1.065523125996545639e+00,1.143011359415453176e+00
+3.422631357449678613e-01,2.213860487484220485e-01,3.979823446324060043e-01,3.932591826509272415e-01,9.171120351530952508e-01,4.496671551666838740e-01,4.920501282725772141e-01,1.021922881227457225e+00,3.730116119067053981e-01,6.140735702178622413e-01,4.075089008512043365e-01,3.210010083282257565e-01,9.338141519905384103e-01,2.280255878839305139e-01,4.589387293709081828e-01,4.264042465506259205e-01,5.710287168605510111e-01,2.866053736369176264e-01,4.882095437920391245e-01,6.897026618730657255e-01
+3.780667996848537937e-01,6.806705533316103063e-01,8.522610306419799997e-01,8.982008697364650684e-01,8.542117439745799956e-01,8.082942431138352424e-01,5.711940870538670945e-01,1.490272384107881765e+00,3.692378767891994484e-01,9.050920875655098508e-01,1.128811782008320996e+00,1.126253668216393278e+00,8.734109392420458251e-01,5.244718079861424576e-01,4.911232686296309047e-01,3.963609896800371346e-01,5.069895362553684670e-01,1.574863079920674758e-01,9.856003900484041624e-01,1.036848281384411319e+00
+6.883324258806501339e-01,7.107449492784763123e-01,1.003308451906025134e+00,8.577879360163622735e-01,1.539453231544298806e+00,8.264850387549846911e-01,9.359281739938727362e-01,2.048856709250512775e+00,6.369817749657655792e-01,1.403509931769793440e+00,1.133924660048421185e+00,9.738173329176448290e-01,1.545846618401569561e+00,4.773398860971982405e-01,8.317210303940278937e-01,6.659385898396117387e-01,9.566742015114944042e-01,4.982074781111152095e-01,9.700659879672112185e-01,1.251533822910914751e+00
+2.193289792584471876e-01,1.444168029996975577e-01,2.818065587621303436e-01,3.633923749930468849e-01,6.918207240471885289e-01,4.249170950590445139e-01,3.510302821331614531e-01,7.711702625950600520e-01,2.827324571482761617e-01,3.794669403393211571e-01,3.149359926256778408e-01,2.816037458866646470e-01,7.164755829946196641e-01,2.219594975519232627e-01,3.351744749340607177e-01,3.465339819716781733e-01,4.239023072800830461e-01,1.733166634953264840e-01,4.584510381735480977e-01,5.963850938307371230e-01
+3.809063081518008209e-01,5.465817697684511867e-01,7.721669308951958355e-01,9.586052775992502939e-01,1.106283291516237721e+00,9.653566779438295109e-01,6.326368609364459994e-01,1.598030251157036119e+00,4.672374614612428978e-01,8.109833578557517075e-01,1.013156428293742195e+00,1.020910772735745153e+00,1.151273032905090199e+00,5.791976528722865281e-01,5.762718957159478350e-01,5.599879712525445141e-01,6.592452753751245620e-01,1.919744367288593878e-01,1.115927899774691978e+00,1.240457852259422511e+00
+4.097349117171018573e-01,5.057586327159691031e-01,7.362059580237217959e-01,8.695357873373452673e-01,1.148209772988772137e+00,8.875464877592552648e-01,6.510232800554107691e-01,1.572878003935865632e+00,4.806516459123344465e-01,8.401188797848546663e-01,9.279061656787667145e-01,9.039974376823688429e-01,1.186938454467031834e+00,5.209607410316934306e-01,5.947537097079912405e-01,5.660933502038212328e-01,6.923799436450011058e-01,2.429177052172657447e-01,1.018626457793137208e+00,1.178166114246190999e+00
diff --git a/src/mlpack/tests/data/completion_indices.csv b/src/mlpack/tests/data/completion_indices.csv
new file mode 100644
index 0000000..65077d0
--- /dev/null
+++ b/src/mlpack/tests/data/completion_indices.csv
@@ -0,0 +1,2 @@
+17,1,0,0,11,9,7,15,6,4,17,5,9,8,4,17,13,19,4,16,7,11,9,14,4,6,12,0,13,16,18,18,19,6,15,0,4,8,10,19,12,5,3,0,4,9,0,16,9,16,18,18,9,6,18,0,5,3,7,14,4,12,13,2,14,17,17,9,4,9,15,13,17,12,16,14,3,9,19,11,4,6,16,11,19,8,7,0,19,8,0,2,12,5,11,16,19,15,3,2,14,7,0,15,0,12,6,13,15,15,10,13,5,15,5,13,4,17,9,6,18,10,16,15,7,7,14,1,4,10,2,4,6,6,4,4,15,18,13,10,3,5,10,4,1,5,5,18,10,10,3,19,10,2,12,10,2,17,8,13,3,15,11,5,2,1,18,15,13,2,17,6,0,0,7,5,17,14,1,6,12,11,16,8,16,5,18,1,12,13,1,12,11,2,0,12,19,10,6,2,3,14,11,13,8,10,2,0,5,15,3,0,2,10,17,5,10,0,18,5,7,12,0,17,16,3,16,9,10,12,7,14,16,7,11,1,18,15,9,8
+1,11,1,6,6,4,19,9,6,5,5,12,11,2,3,15,1,1,19,10,3,19,0,15,11,12,4,7,8,6,15,12,14,5,6,18,15,13,3,5,9,0,0,19,4,10,4,15,6,16,5,0,5,11,17,16,10,13,9,4,18,17,0,17,7,17,6,3,12,9,19,12,7,12,3,9,8,17,11,10,6,2,9,13,19,9,1,10,6,15,8,9,2,14,7,13,2,0,2,1,18,18,3,17,5,11,18,6,16,3,17,9,3,2,5,18,8,0,15,16,8,11,2,12,5,6,3,16,2,12,4,14,17,9,9,1,4,11,10,16,7,1,19,0,9,13,7,14,7,9,10,12,8,8,16,6,12,11,6,3,1,15,4,8,7,6,7,11,13,3,16,15,15,13,0,19,18,14,17,14,13,14,12,3,11,4,13,19,15,15,5,0,18,15,11,6,10,15,13,13,11,5,1,19,11,1,19,12,11,18,6,14,10,13,10,17,5,17,9,18,14,5,0,9,7,3,18,2,4,14,12,10,17,11,9,3,19,14,14,8
diff --git a/src/mlpack/tests/matrix_completion_test.cpp b/src/mlpack/tests/matrix_completion_test.cpp
index b325696..6392c9d 100644
--- a/src/mlpack/tests/matrix_completion_test.cpp
+++ b/src/mlpack/tests/matrix_completion_test.cpp
@@ -15,4 +15,35 @@ using namespace mlpack::matrix_completion;
 
 BOOST_AUTO_TEST_SUITE(MatrixCompletionTest);
 
+BOOST_AUTO_TEST_CASE(GaussianMatrixCompletionSDP)
+{
+  arma::mat Xorig, values;
+  arma::umat indices;
+
+  data::Load("completion_X.csv", Xorig, true, false);
+  data::Load("completion_indices.csv", indices, true, false);
+
+  values.set_size(indices.n_cols);
+  for (size_t i = 0; i < indices.n_cols; ++i)
+  {
+    values(i) = Xorig(indices(0, i), indices(1, i));
+  }
+
+  MatrixCompletion mc(Xorig.n_rows, Xorig.n_cols, indices, values);
+  mc.Recover();
+
+  const double err =
+    arma::norm(Xorig - mc.Recovered(), "fro") /
+    arma::norm(Xorig, "fro");
+  BOOST_REQUIRE_SMALL(err, 1e-5);
+
+  for (size_t i = 0; i < indices.n_cols; ++i)
+  {
+    BOOST_REQUIRE_CLOSE(
+      mc.Recovered()(indices(0, i), indices(1, i)),
+      Xorig(indices(0, i), indices(1, i)),
+      1e-5);
+  }
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list