[mlpack-svn] r13272 - mlpack/trunk/src/mlpack/tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Jul 20 15:37:05 EDT 2012
Author: niche
Date: 2012-07-20 15:37:05 -0400 (Fri, 20 Jul 2012)
New Revision: 13272
Modified:
mlpack/trunk/src/mlpack/tests/sparse_coding_test.cpp
Log:
fixed test for sparse coding dictionary step - the dictionary step now runs for more than one iteration, and the check is to see that the norm of the gradient of the Lagrange dual (with respect to the dual variables) is sufficiently small
Modified: mlpack/trunk/src/mlpack/tests/sparse_coding_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/sparse_coding_test.cpp 2012-07-20 19:35:00 UTC (rev 13271)
+++ mlpack/trunk/src/mlpack/tests/sparse_coding_test.cpp 2012-07-20 19:37:05 UTC (rev 13272)
@@ -54,9 +54,10 @@
uword nPoints = X.n_cols;
// Normalize each point since these are images.
- for (uword i = 0; i < nPoints; ++i)
+ for (uword i = 0; i < nPoints; ++i) {
X.col(i) /= norm(X.col(i), 2);
-
+ }
+
SparseCoding<> sc(X, nAtoms, lambda1);
sc.OptimizeCode();
@@ -92,8 +93,9 @@
for(uword i = 0; i < nPoints; ++i)
{
- vec errCorr = (trans(D) * D + lambda2 *
- eye(nAtoms, nAtoms)) * Z.unsafe_col(i) - trans(D) * X.unsafe_col(i);
+ vec errCorr =
+ (trans(D) * D + lambda2 * eye(nAtoms, nAtoms)) * Z.unsafe_col(i)
+ - trans(D) * X.unsafe_col(i);
SCVerifyCorrectness(Z.unsafe_col(i), errCorr, lambda1);
}
@@ -101,7 +103,7 @@
BOOST_AUTO_TEST_CASE(SparseCodingTestDictionaryStep)
{
- const double tol = 1e-12;
+ const double tol = 1e-7;
double lambda1 = 0.1;
uword nAtoms = 25;
@@ -120,17 +122,10 @@
mat D = sc.Dictionary();
mat Z = sc.Codes();
- X = D * Z;
-
- // This will update sc.data (that is a reference to X).
- DataDependentRandomInitializer::Initialize(X, nAtoms, sc.Dictionary());
-
uvec adjacencies = find(Z);
- sc.OptimizeDictionary(adjacencies);
-
- mat D_hat = sc.Dictionary();
-
- BOOST_REQUIRE_SMALL(norm(D - D_hat, "fro"), tol);
+ double normGradient = sc.OptimizeDictionary(adjacencies, 1e-12);
+
+ BOOST_REQUIRE_SMALL(normGradient, tol);
}
/*
More information about the mlpack-svn
mailing list