[mlpack-svn] r11687 - mlpack/trunk/src/mlpack/tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Mar 1 20:52:40 EST 2012
Author: niche
Date: 2012-03-01 20:52:40 -0500 (Thu, 01 Mar 2012)
New Revision: 11687
Added:
mlpack/trunk/src/mlpack/tests/local_coordinate_coding_test.cpp
Modified:
mlpack/trunk/src/mlpack/tests/CMakeLists.txt
Log:
added tests for lcc
Modified: mlpack/trunk/src/mlpack/tests/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/tests/CMakeLists.txt 2012-03-01 23:53:45 UTC (rev 11686)
+++ mlpack/trunk/src/mlpack/tests/CMakeLists.txt 2012-03-02 01:52:40 UTC (rev 11687)
@@ -18,6 +18,7 @@
lin_alg_test.cpp
linear_regression_test.cpp
load_save_test.cpp
+ local_coordinate_coding_test.cpp
lrsdp_test.cpp
math_test.cpp
nbc_test.cpp
Added: mlpack/trunk/src/mlpack/tests/local_coordinate_coding_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/local_coordinate_coding_test.cpp (rev 0)
+++ mlpack/trunk/src/mlpack/tests/local_coordinate_coding_test.cpp 2012-03-02 01:52:40 UTC (rev 11687)
@@ -0,0 +1,129 @@
+/**
+ * @file local_coordinate_coding_test.cpp
+ *
+ * Test for Local Coordinate Coding
+ */
+
+// Note: The checks below use BOOST_REQUIRE_SMALL instead of BOOST_REQUIRE_CLOSE, because
+// BOOST_REQUIRE_CLOSE would need the FPC_WEAK comparison, and it is not obvious how to select it.
+
+
+#include <armadillo>
+#include <mlpack/methods/local_coordinate_coding/lcc.hpp>
+
+#include <boost/test/unit_test.hpp>
+
+using namespace arma;
+using namespace mlpack;
+using namespace mlpack::regression;
+using namespace mlpack::lcc;
+
+BOOST_AUTO_TEST_SUITE(LocalCoordinateCodingTest);
+
+
+void VerifyCorrectness(vec beta, vec errCorr, double lambda) // Asserts a sign-dependent condition on errCorr(j) for every coefficient beta(j).
+{ // NOTE(review): these look like the optimality conditions of an L1-penalised least-squares fit with penalty lambda -- confirm.
+ const double tol = 1e-12; // absolute tolerance handed to every BOOST_REQUIRE_SMALL below
+ size_t nDims = beta.n_elem; // errCorr is indexed over the same range, so it must have at least this many elements
+ for(size_t j = 0; j < nDims; j++)
+ {
+ if (beta(j) == 0)
+ {
+ // make sure that |errCorr(j)| <= lambda: only the excess of |errCorr(j)| over lambda is compared against tol
+ BOOST_REQUIRE_SMALL(std::max(fabs(errCorr(j)) - lambda, 0.0), tol);
+ }
+ else if (beta(j) < 0)
+ {
+ // make sure that errCorr(j) == lambda, up to tol
+ BOOST_REQUIRE_SMALL(errCorr(j) - lambda, tol);
+ }
+ else
+ { // beta(j) > 0
+ // make sure that errCorr(j) == -lambda, up to tol
+ BOOST_REQUIRE_SMALL(errCorr(j) + lambda, tol);
+ }
+ }
+}
+
+
+BOOST_AUTO_TEST_CASE(LocalCoordinateCodingTestCodingStep) // Runs OptimizeCode() once and checks every resulting code column with VerifyCorrectness.
+{
+ double lambda1 = 0.1; // regularization value given to the LCC constructor; half of it is the threshold checked below
+ u32 nAtoms = 25; // number of dictionary atoms (columns of D)
+
+ mat X;
+ X.load("mnist_first250_training_4s_and_9s.arm"); // NOTE(review): the result of load() is not checked -- confirm what happens when the file is missing
+ u32 nPoints = X.n_cols; // each column of X is treated as one data point
+
+ // normalize each point to unit L2 norm since these are images
+ for(u32 i = 0; i < nPoints; i++) {
+ X.col(i) /= norm(X.col(i), 2);
+ }
+
+ LocalCoordinateCoding lcc(X, nAtoms, lambda1);
+ lcc.DataDependentRandomInitDictionary();
+ lcc.OptimizeCode();
+
+ mat D = lcc.MatD(); // dictionary; indexed below by atom column j
+ mat Z = lcc.MatZ(); // codes; indexed below by point column i
+
+ for(u32 i = 0; i < nPoints; i++) {
+ vec sq_dists = vec(nAtoms); // squared Euclidean distance from each atom to point i
+ for(u32 j = 0; j < nAtoms; j++) {
+ vec diff = D.unsafe_col(j) - X.unsafe_col(i);
+ sq_dists[j] = dot(diff, diff);
+ }
+ mat Dprime = D * diagmat(1.0 / sq_dists); // column j of D scaled by 1 / sq_dists[j]; NOTE(review): divides by zero if an atom coincides with point i -- confirm this cannot happen
+ mat zPrime = Z.unsafe_col(i) % sq_dists; // element-wise product, so Dprime * zPrime equals D * Z.col(i)
+ // errCorr: rescaled dictionary transposed times the reconstruction residual for point i
+ vec errCorr = trans(Dprime) * (Dprime * zPrime - X.unsafe_col(i));
+ VerifyCorrectness(zPrime, errCorr, 0.5 * lambda1); // threshold is half of lambda1
+ }
+}
+
+BOOST_AUTO_TEST_CASE(LocalCoordinateCodingTestDictionaryStep) // Runs OptimizeCode() then OptimizeDictionary() and requires the matrix "grad" built below to be near zero.
+{
+ const double tol = 1e-12; // bound on the Frobenius norm of grad
+
+ double lambda = 0.1; // regularization value given to the LCC constructor and used as the weight in grad
+ u32 nAtoms = 25; // number of dictionary atoms (columns of D)
+
+ mat X;
+ X.load("mnist_first250_training_4s_and_9s.arm"); // NOTE(review): the result of load() is not checked -- confirm what happens when the file is missing
+ u32 nPoints = X.n_cols; // each column of X is treated as one data point
+
+ // normalize each point to unit L2 norm since these are images
+ for(u32 i = 0; i < nPoints; i++) {
+ X.col(i) /= norm(X.col(i), 2);
+ }
+
+ LocalCoordinateCoding lcc(X, nAtoms, lambda);
+ lcc.DataDependentRandomInitDictionary();
+ lcc.OptimizeCode();
+ mat Z = lcc.MatZ(); // codes produced by OptimizeCode(); the same Z is used to build grad below
+ uvec adjacencies = find(Z); // indices of the non-zero entries of Z
+ lcc.OptimizeDictionary(adjacencies);
+
+
+ mat D = lcc.MatD(); // dictionary after OptimizeDictionary()
+
+ mat grad = zeros(D.n_rows, D.n_cols); // NOTE(review): presumably the gradient of the LCC objective with respect to D -- confirm against lcc.hpp
+ for(u32 i = 0; i < nPoints; i++) {
+ grad += (D - repmat(X.unsafe_col(i), 1, nAtoms)) * diagmat(abs(Z.unsafe_col(i))); // column j accumulates |Z(j,i)| * (D.col(j) - X.col(i))
+ }
+ grad = lambda * grad + (D * Z - X) * trans(Z); // weighted locality term plus the reconstruction-residual term
+
+ BOOST_REQUIRE_SMALL(norm(grad, "fro"), tol);
+
+}
+
+
+/*
+BOOST_AUTO_TEST_CASE(LocalCoordinateCodingTestWhole)
+{
+
+}
+*/
+
+
+BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-svn
mailing list