[mlpack-svn] r12382 - mlpack/trunk/src/mlpack/tests

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Apr 13 23:51:43 EDT 2012


Author: rcurtin
Date: 2012-04-13 23:51:43 -0400 (Fri, 13 Apr 2012)
New Revision: 12382

Modified:
   mlpack/trunk/src/mlpack/tests/sparse_coding_test.cpp
Log:
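Update tests to reflect new method names (MatD() -> Dictionary(), MatZ() -> Codes(), SetData(X) -> Data() = X); include mlpack/core.hpp instead of armadillo directly, and clean up comment/brace style.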


Modified: mlpack/trunk/src/mlpack/tests/sparse_coding_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/sparse_coding_test.cpp	2012-04-14 03:51:36 UTC (rev 12381)
+++ mlpack/trunk/src/mlpack/tests/sparse_coding_test.cpp	2012-04-14 03:51:43 UTC (rev 12382)
@@ -7,8 +7,7 @@
 // Note: We don't use BOOST_REQUIRE_CLOSE in the code below because we need
 // to use FPC_WEAK, and it's not at all intuitive how to do that.
 
-
-#include <armadillo>
+#include <mlpack/core.hpp>
 #include <mlpack/methods/sparse_coding/sparse_coding.hpp>
 
 #include <boost/test/unit_test.hpp>
@@ -20,7 +19,6 @@
 
 BOOST_AUTO_TEST_SUITE(SparseCodingTest);
 
-
 void VerifyCorrectness(vec beta, vec errCorr, double lambda)
 {
   const double tol = 1e-12;
@@ -29,23 +27,22 @@
   {
     if (beta(j) == 0)
     {
-      // make sure that errCorr(j) <= lambda
+      // Make sure that errCorr(j) <= lambda.
       BOOST_REQUIRE_SMALL(std::max(fabs(errCorr(j)) - lambda, 0.0), tol);
     }
     else if (beta(j) < 0)
     {
-      // make sure that errCorr(j) == lambda
+      // Make sure that errCorr(j) == lambda.
       BOOST_REQUIRE_SMALL(errCorr(j) - lambda, tol);
     }
-    else
-    { // beta(j) > 0
-      // make sure that errCorr(j) == -lambda
+    else // beta(j) > 0.
+    {
+      // Make sure that errCorr(j) == -lambda.
       BOOST_REQUIRE_SMALL(errCorr(j) + lambda, tol);
     }
   }
 }
 
-
 BOOST_AUTO_TEST_CASE(SparseCodingTestCodingStepLasso)
 {
   double lambda1 = 0.1;
@@ -54,24 +51,23 @@
   mat X;
   X.load("mnist_first250_training_4s_and_9s.arm");
   uword nPoints = X.n_cols;
-  
-  // normalize each point since these are images
-  for(uword i = 0; i < nPoints; i++) {
+
+  // Normalize each point since these are images.
+  for (uword i = 0; i < nPoints; ++i)
     X.col(i) /= norm(X.col(i), 2);
-  }  
 
   SparseCoding sc(X, nAtoms, lambda1);
   sc.DataDependentRandomInitDictionary();
-  sc.OptimizeCode();  
+  sc.OptimizeCode();
 
-  mat D = sc.MatD();
-  mat Z = sc.MatZ();
-  
-  for(uword i = 0; i < nPoints; i++) {
+  mat D = sc.Dictionary();
+  mat Z = sc.Codes();
+
+  for (uword i = 0; i < nPoints; ++i)
+  {
     vec errCorr = trans(D) * (D * Z.unsafe_col(i) - X.unsafe_col(i));
     VerifyCorrectness(Z.unsafe_col(i), errCorr, lambda1);
   }
-  
 }
 
 BOOST_AUTO_TEST_CASE(SparseCodingTestCodingStepElasticNet)
@@ -83,24 +79,23 @@
   mat X;
   X.load("mnist_first250_training_4s_and_9s.arm");
   uword nPoints = X.n_cols;
-  
-  // normalize each point since these are images
-  for(uword i = 0; i < nPoints; i++) {
+
+  // Normalize each point since these are images.
+  for (uword i = 0; i < nPoints; ++i)
     X.col(i) /= norm(X.col(i), 2);
-  }  
 
   SparseCoding sc(X, nAtoms, lambda1, lambda2);
   sc.DataDependentRandomInitDictionary();
-  sc.OptimizeCode();  
+  sc.OptimizeCode();
 
-  mat D = sc.MatD();
-  mat Z = sc.MatZ();
-  
-  for(uword i = 0; i < nPoints; i++) {
-    vec errCorr = 
-      (trans(D) * D + lambda2 *
+  mat D = sc.Dictionary();
+  mat Z = sc.Codes();
+
+  for(uword i = 0; i < nPoints; ++i)
+  {
+    vec errCorr = (trans(D) * D + lambda2 *
        eye(nAtoms, nAtoms)) * Z.unsafe_col(i) - trans(D) * X.unsafe_col(i);
-    
+
     VerifyCorrectness(Z.unsafe_col(i), errCorr, lambda1);
   }
 }
@@ -115,31 +110,29 @@
   mat X;
   X.load("mnist_first250_training_4s_and_9s.arm");
   uword nPoints = X.n_cols;
-  
-  // normalize each point since these are images
-  for(uword i = 0; i < nPoints; i++) {
+
+  // Normalize each point since these are images.
+  for(uword i = 0; i < nPoints; ++i)
     X.col(i) /= norm(X.col(i), 2);
-  }  
 
   SparseCoding sc(X, nAtoms, lambda1);
   sc.DataDependentRandomInitDictionary();
-  sc.OptimizeCode();  
-  
-  mat D = sc.MatD();
-  mat Z = sc.MatZ();
-  
+  sc.OptimizeCode();
+
+  mat D = sc.Dictionary();
+  mat Z = sc.Codes();
+
   X = D * Z;
-  
-  sc.SetData(X);
+
+  sc.Data() = X;
   sc.DataDependentRandomInitDictionary();
 
   uvec adjacencies = find(Z);
   sc.OptimizeDictionary(adjacencies);
-  
-  mat D_hat = sc.MatD();
 
+  mat D_hat = sc.Dictionary();
+
   BOOST_REQUIRE_SMALL(norm(D - D_hat, "fro"), tol);
-  
 }
 
 /*




More information about the mlpack-svn mailing list