[mlpack-svn] r11677 - mlpack/trunk/src/mlpack/tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Mar 1 01:43:01 EST 2012
Author: niche
Date: 2012-03-01 01:43:01 -0500 (Thu, 01 Mar 2012)
New Revision: 11677
Added:
mlpack/trunk/src/mlpack/tests/sparse_coding_test.cpp
Log:
added test cases for sparse coding
Added: mlpack/trunk/src/mlpack/tests/sparse_coding_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/sparse_coding_test.cpp (rev 0)
+++ mlpack/trunk/src/mlpack/tests/sparse_coding_test.cpp 2012-03-01 06:43:01 UTC (rev 11677)
@@ -0,0 +1,153 @@
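+/**
+ * @file sparse_coding_test.cpp
+ *
+ * Tests for the SparseCoding class: the coding step with LASSO and Elastic Net
+ * penalties (checked via the optimality conditions of every code vector), and
+ * the dictionary step (checked by recovering a known dictionary).
+ */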
+
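+// Note: the checks below use BOOST_REQUIRE_SMALL on differences (an absolute
+// tolerance) rather than BOOST_REQUIRE_CLOSE, whose tolerance is relative (a
+// percentage) and would need FPC_WEAK-style tuning. The quantities compared
+// here (e.g. errCorr(j) + lambda) are expected to be zero, and a relative
+// tolerance around zero is not meaningful.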
+
+
+#include <algorithm>
+#include <cmath>
+#include <string>
+
+#include <armadillo>
+#include <mlpack/methods/sparse_coding/sparse_coding.hpp>
+
+#include <boost/test/unit_test.hpp>
+
+using namespace arma;
+using namespace mlpack;
+using namespace mlpack::regression;
+using namespace mlpack::sparse_coding;
+
+BOOST_AUTO_TEST_SUITE(SparseCodingTest);
+
+
+/**
+ * Check, coordinate by coordinate, the optimality conditions of an
+ * l1-penalized solution beta given the gradient errCorr of the smooth part of
+ * the objective at beta:
+ *
+ *  - beta(j) == 0  =>  |errCorr(j)| <= lambda
+ *  - beta(j) <  0  =>  errCorr(j) == lambda
+ *  - beta(j) >  0  =>  errCorr(j) == -lambda
+ *
+ * Each condition is checked up to the absolute tolerance tol.
+ *
+ * @param beta Candidate solution (one code vector).
+ * @param errCorr Gradient of the smooth part of the objective at beta, as
+ *     computed by the caller (e.g. trans(D) * (D * beta - x) for LASSO).
+ * @param lambda Weight of the l1 penalty.
+ * @param tol Absolute tolerance for every comparison (default 1e-12).
+ */
+void VerifyCorrectness(const vec& beta,
+                       const vec& errCorr,
+                       double lambda,
+                       double tol = 1e-12)
+{
+  // Both vectors are indexed by j below; a length mismatch would read out of
+  // bounds, so fail loudly instead.
+  BOOST_REQUIRE_EQUAL(beta.n_elem, errCorr.n_elem);
+
+  for (uword j = 0; j < beta.n_elem; ++j)
+  {
+    // Exact comparison on purpose: this assumes the solver sets inactive
+    // coefficients to exactly 0.
+    if (beta(j) == 0)
+    {
+      // Inactive coordinate: |errCorr(j)| <= lambda, i.e. any excess over
+      // lambda must be (numerically) zero.
+      BOOST_REQUIRE_SMALL(std::max(std::abs(errCorr(j)) - lambda, 0.0), tol);
+    }
+    else if (beta(j) < 0)
+    {
+      // Active, negative coordinate: errCorr(j) == lambda.
+      BOOST_REQUIRE_SMALL(errCorr(j) - lambda, tol);
+    }
+    else
+    {
+      // Active, positive coordinate: errCorr(j) == -lambda.
+      BOOST_REQUIRE_SMALL(errCorr(j) + lambda, tol);
+    }
+  }
+}
+
+
+/**
+ * Coding step with only an l1 penalty (LASSO): after OptimizeCode(), every
+ * code vector must satisfy the l1 optimality conditions.
+ */
+BOOST_AUTO_TEST_CASE(SparseCodingTestCodingStepLasso)
+{
+  const double lambda1 = 0.1;
+  const u32 nAtoms = 25;
+
+  // The data file is resolved relative to the working directory instead of a
+  // developer-specific absolute path.
+  // NOTE(review): assumes the contents of tests/data are available in the
+  // directory the test binary runs from -- confirm against the build setup.
+  const std::string dataFile = "mnist_first250_training_4s_and_9s.arm";
+  mat X;
+  const bool loaded = X.load(dataFile);
+  BOOST_REQUIRE_MESSAGE(loaded, "could not load " << dataFile);
+  const uword nPoints = X.n_cols;
+
+  // Normalize each point (column) to unit l2 norm, since these are images.
+  for (uword i = 0; i < nPoints; ++i)
+    X.col(i) /= norm(X.col(i), 2);
+
+  SparseCoding sc(X, nAtoms, lambda1);
+  sc.DataDependentRandomInitDictionary();
+  sc.OptimizeCode();
+
+  mat D = sc.MatD();
+  mat Z = sc.MatZ();
+
+  // errCorr is the gradient of the smooth part of the objective at z_i,
+  // assumed here to be 0.5 * ||D z_i - x_i||^2.
+  for (uword i = 0; i < nPoints; ++i)
+  {
+    const vec errCorr = trans(D) * (D * Z.unsafe_col(i) - X.unsafe_col(i));
+    VerifyCorrectness(Z.unsafe_col(i), errCorr, lambda1);
+  }
+}
+
+/**
+ * Coding step with both l1 and l2 penalties (Elastic Net): after
+ * OptimizeCode(), every code vector must satisfy the l1 optimality conditions
+ * of the l2-augmented problem.
+ */
+BOOST_AUTO_TEST_CASE(SparseCodingTestCodingStepElasticNet)
+{
+  const double lambda1 = 0.1;
+  const double lambda2 = 0.2;
+  const u32 nAtoms = 25;
+
+  // The data file is resolved relative to the working directory instead of a
+  // developer-specific absolute path.
+  // NOTE(review): assumes the contents of tests/data are available in the
+  // directory the test binary runs from -- confirm against the build setup.
+  const std::string dataFile = "mnist_first250_training_4s_and_9s.arm";
+  mat X;
+  const bool loaded = X.load(dataFile);
+  BOOST_REQUIRE_MESSAGE(loaded, "could not load " << dataFile);
+  const uword nPoints = X.n_cols;
+
+  // Normalize each point (column) to unit l2 norm, since these are images.
+  for (uword i = 0; i < nPoints; ++i)
+    X.col(i) /= norm(X.col(i), 2);
+
+  SparseCoding sc(X, nAtoms, lambda1, lambda2);
+  sc.DataDependentRandomInitDictionary();
+  sc.OptimizeCode();
+
+  mat D = sc.MatD();
+  mat Z = sc.MatZ();
+
+  // errCorr is the gradient of the smooth part of the objective at z_i,
+  // assumed here to be 0.5 * ||D z_i - x_i||^2 + 0.5 * lambda2 * ||z_i||^2,
+  // i.e. (D^T D + lambda2 I) z_i - D^T x_i.
+  for (uword i = 0; i < nPoints; ++i)
+  {
+    const vec errCorr =
+        (trans(D) * D + lambda2 * eye(nAtoms, nAtoms)) * Z.unsafe_col(i) -
+        trans(D) * X.unsafe_col(i);
+
+    VerifyCorrectness(Z.unsafe_col(i), errCorr, lambda1);
+  }
+}
+
+/**
+ * Dictionary step: build data that is exactly reconstructed by a known
+ * dictionary D and codes Z, restart from a random dictionary, and check that
+ * OptimizeDictionary() recovers D.
+ */
+BOOST_AUTO_TEST_CASE(SparseCodingTestDictionaryStep)
+{
+  const double tol = 1e-12;
+
+  const double lambda1 = 0.1;
+  const u32 nAtoms = 25;
+
+  // The data file is resolved relative to the working directory instead of a
+  // developer-specific absolute path.
+  // NOTE(review): assumes the contents of tests/data are available in the
+  // directory the test binary runs from -- confirm against the build setup.
+  const std::string dataFile = "mnist_first250_training_4s_and_9s.arm";
+  mat X;
+  const bool loaded = X.load(dataFile);
+  BOOST_REQUIRE_MESSAGE(loaded, "could not load " << dataFile);
+  const uword nPoints = X.n_cols;
+
+  // Normalize each point (column) to unit l2 norm, since these are images.
+  for (uword i = 0; i < nPoints; ++i)
+    X.col(i) /= norm(X.col(i), 2);
+
+  SparseCoding sc(X, nAtoms, lambda1);
+  sc.DataDependentRandomInitDictionary();
+  sc.OptimizeCode();
+
+  // Keep copies, not references: the dictionary inside sc is re-initialized
+  // and re-optimized below, and D must retain the values from this point.
+  mat D = sc.MatD();
+  mat Z = sc.MatZ();
+
+  // Replace the data by its reconstruction, so D reproduces X from the codes
+  // Z with zero error (up to rounding).
+  X = D * Z;
+
+  sc.SetData(X);
+  sc.DataDependentRandomInitDictionary();
+
+  // find() returns the (linear) indices of the nonzero entries of Z.
+  uvec adjacencies = find(Z);
+  sc.OptimizeDictionary(adjacencies);
+
+  const mat D_hat = sc.MatD();
+
+  // NOTE(review): exact recovery within 1e-12 assumes D is the unique
+  // optimum of the dictionary step (e.g. every atom is used by some code) and
+  // that the optimizer converges to machine precision -- confirm this is not
+  // flaky.
+  BOOST_REQUIRE_SMALL(norm(D - D_hat, "fro"), tol);
+}
+
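+// TODO: implement an end-to-end test (SparseCodingTestWhole) of the complete
+// sparse coding procedure; the skeleton below is disabled and still empty.
+/*
+BOOST_AUTO_TEST_CASE(SparseCodingTestWhole)
+{
+
+}
+*/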
+
+
+BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-svn
mailing list