[mlpack-svn] r12382 - mlpack/trunk/src/mlpack/tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Apr 13 23:51:43 EDT 2012
Author: rcurtin
Date: 2012-04-13 23:51:43 -0400 (Fri, 13 Apr 2012)
New Revision: 12382
Modified:
mlpack/trunk/src/mlpack/tests/sparse_coding_test.cpp
Log:
Update tests to reflect new method names.
Modified: mlpack/trunk/src/mlpack/tests/sparse_coding_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/sparse_coding_test.cpp 2012-04-14 03:51:36 UTC (rev 12381)
+++ mlpack/trunk/src/mlpack/tests/sparse_coding_test.cpp 2012-04-14 03:51:43 UTC (rev 12382)
@@ -7,8 +7,7 @@
// Note: We don't use BOOST_REQUIRE_CLOSE in the code below because we need
// to use FPC_WEAK, and it's not at all intuitive how to do that.
-
-#include <armadillo>
+#include <mlpack/core.hpp>
#include <mlpack/methods/sparse_coding/sparse_coding.hpp>
#include <boost/test/unit_test.hpp>
@@ -20,7 +19,6 @@
BOOST_AUTO_TEST_SUITE(SparseCodingTest);
-
void VerifyCorrectness(vec beta, vec errCorr, double lambda)
{
const double tol = 1e-12;
@@ -29,23 +27,22 @@
{
if (beta(j) == 0)
{
- // make sure that errCorr(j) <= lambda
+ // Make sure that errCorr(j) <= lambda.
BOOST_REQUIRE_SMALL(std::max(fabs(errCorr(j)) - lambda, 0.0), tol);
}
else if (beta(j) < 0)
{
- // make sure that errCorr(j) == lambda
+ // Make sure that errCorr(j) == lambda.
BOOST_REQUIRE_SMALL(errCorr(j) - lambda, tol);
}
- else
- { // beta(j) > 0
- // make sure that errCorr(j) == -lambda
+ else // beta(j) > 0.
+ {
+ // Make sure that errCorr(j) == -lambda.
BOOST_REQUIRE_SMALL(errCorr(j) + lambda, tol);
}
}
}
-
BOOST_AUTO_TEST_CASE(SparseCodingTestCodingStepLasso)
{
double lambda1 = 0.1;
@@ -54,24 +51,23 @@
mat X;
X.load("mnist_first250_training_4s_and_9s.arm");
uword nPoints = X.n_cols;
-
- // normalize each point since these are images
- for(uword i = 0; i < nPoints; i++) {
+
+ // Normalize each point since these are images.
+ for (uword i = 0; i < nPoints; ++i)
X.col(i) /= norm(X.col(i), 2);
- }
SparseCoding sc(X, nAtoms, lambda1);
sc.DataDependentRandomInitDictionary();
- sc.OptimizeCode();
+ sc.OptimizeCode();
- mat D = sc.MatD();
- mat Z = sc.MatZ();
-
- for(uword i = 0; i < nPoints; i++) {
+ mat D = sc.Dictionary();
+ mat Z = sc.Codes();
+
+ for (uword i = 0; i < nPoints; ++i)
+ {
vec errCorr = trans(D) * (D * Z.unsafe_col(i) - X.unsafe_col(i));
VerifyCorrectness(Z.unsafe_col(i), errCorr, lambda1);
}
-
}
BOOST_AUTO_TEST_CASE(SparseCodingTestCodingStepElasticNet)
@@ -83,24 +79,23 @@
mat X;
X.load("mnist_first250_training_4s_and_9s.arm");
uword nPoints = X.n_cols;
-
- // normalize each point since these are images
- for(uword i = 0; i < nPoints; i++) {
+
+ // Normalize each point since these are images.
+ for (uword i = 0; i < nPoints; ++i)
X.col(i) /= norm(X.col(i), 2);
- }
SparseCoding sc(X, nAtoms, lambda1, lambda2);
sc.DataDependentRandomInitDictionary();
- sc.OptimizeCode();
+ sc.OptimizeCode();
- mat D = sc.MatD();
- mat Z = sc.MatZ();
-
- for(uword i = 0; i < nPoints; i++) {
- vec errCorr =
- (trans(D) * D + lambda2 *
+ mat D = sc.Dictionary();
+ mat Z = sc.Codes();
+
+ for(uword i = 0; i < nPoints; ++i)
+ {
+ vec errCorr = (trans(D) * D + lambda2 *
eye(nAtoms, nAtoms)) * Z.unsafe_col(i) - trans(D) * X.unsafe_col(i);
-
+
VerifyCorrectness(Z.unsafe_col(i), errCorr, lambda1);
}
}
@@ -115,31 +110,29 @@
mat X;
X.load("mnist_first250_training_4s_and_9s.arm");
uword nPoints = X.n_cols;
-
- // normalize each point since these are images
- for(uword i = 0; i < nPoints; i++) {
+
+ // Normalize each point since these are images.
+ for(uword i = 0; i < nPoints; ++i)
X.col(i) /= norm(X.col(i), 2);
- }
SparseCoding sc(X, nAtoms, lambda1);
sc.DataDependentRandomInitDictionary();
- sc.OptimizeCode();
-
- mat D = sc.MatD();
- mat Z = sc.MatZ();
-
+ sc.OptimizeCode();
+
+ mat D = sc.Dictionary();
+ mat Z = sc.Codes();
+
X = D * Z;
-
- sc.SetData(X);
+
+ sc.Data() = X;
sc.DataDependentRandomInitDictionary();
uvec adjacencies = find(Z);
sc.OptimizeDictionary(adjacencies);
-
- mat D_hat = sc.MatD();
+ mat D_hat = sc.Dictionary();
+
BOOST_REQUIRE_SMALL(norm(D - D_hat, "fro"), tol);
-
}
/*
More information about the mlpack-svn mailing list