[mlpack-svn] r11687 - mlpack/trunk/src/mlpack/tests

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Mar 1 20:52:40 EST 2012


Author: niche
Date: 2012-03-01 20:52:40 -0500 (Thu, 01 Mar 2012)
New Revision: 11687

Added:
   mlpack/trunk/src/mlpack/tests/local_coordinate_coding_test.cpp
Modified:
   mlpack/trunk/src/mlpack/tests/CMakeLists.txt
Log:
added tests for lcc

Modified: mlpack/trunk/src/mlpack/tests/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/tests/CMakeLists.txt	2012-03-01 23:53:45 UTC (rev 11686)
+++ mlpack/trunk/src/mlpack/tests/CMakeLists.txt	2012-03-02 01:52:40 UTC (rev 11687)
@@ -18,6 +18,7 @@
   lin_alg_test.cpp
   linear_regression_test.cpp
   load_save_test.cpp
+  local_coordinate_coding_test.cpp
   lrsdp_test.cpp
   math_test.cpp
   nbc_test.cpp

Added: mlpack/trunk/src/mlpack/tests/local_coordinate_coding_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/local_coordinate_coding_test.cpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/tests/local_coordinate_coding_test.cpp	2012-03-02 01:52:40 UTC (rev 11687)
@@ -0,0 +1,129 @@
+/**
+ * @file local_coordinate_coding_test.cpp
+ *
+ * Test for Local Coordinate Coding
+ */
+
+// Note: We don't use BOOST_REQUIRE_CLOSE in the code below because we need
+// to use FPC_WEAK, and it's not at all intuitive how to do that.
+
+
+#include <armadillo>
+#include <mlpack/methods/local_coordinate_coding/lcc.hpp>
+
+#include <boost/test/unit_test.hpp>
+
+using namespace arma;
+using namespace mlpack;
+using namespace mlpack::regression;
+using namespace mlpack::lcc;
+
+BOOST_AUTO_TEST_SUITE(LocalCoordinateCodingTest);
+
+
+// Checks every entry of errCorr against the sign of the matching entry of
+// beta, each to within an absolute tolerance of 1e-12:
+//   beta(j) <  0  =>  errCorr(j) == lambda
+//   beta(j) == 0  =>  |errCorr(j)| <= lambda
+//   otherwise     =>  errCorr(j) == -lambda   (i.e. beta(j) > 0)
+// NOTE(review): presumably the optimality conditions of an L1-penalized
+// least-squares fit -- confirm.
+void VerifyCorrectness(vec beta, vec errCorr, double lambda)
+{
+  const double tol = 1e-12;
+  for (size_t j = 0; j < beta.n_elem; ++j)
+  {
+    const double coeff = beta(j);
+    if (coeff < 0) {
+      BOOST_REQUIRE_SMALL(errCorr(j) - lambda, tol);
+    } else if (coeff == 0) {
+      // Only the excess of |errCorr(j)| over lambda, clamped at 0, is tested.
+      BOOST_REQUIRE_SMALL(std::max(fabs(errCorr(j)) - lambda, 0.0), tol);
+    } else {
+      BOOST_REQUIRE_SMALL(errCorr(j) + lambda, tol);
+    }
+  }
+}
+
+
+// Coding step: each column of Z from OptimizeCode(), rescaled entrywise by the
+// squared atom-to-point distances, must pass VerifyCorrectness().
+BOOST_AUTO_TEST_CASE(LocalCoordinateCodingTestCodingStep)
+{
+  double lambda1 = 0.1;
+  u32 nAtoms = 25;
+  // Abort on a failed load() rather than continuing without the data set.
+  mat X;
+  BOOST_REQUIRE(X.load("mnist_first250_training_4s_and_9s.arm"));
+  u32 nPoints = X.n_cols;
+
+  // normalize each point since these are images
+  for(u32 i = 0; i < nPoints; i++) {
+    X.col(i) /= norm(X.col(i), 2);
+  }
+
+  LocalCoordinateCoding lcc(X, nAtoms, lambda1);
+  lcc.DataDependentRandomInitDictionary();
+  lcc.OptimizeCode();
+  mat D = lcc.MatD();
+  mat Z = lcc.MatZ();
+
+  for(u32 i = 0; i < nPoints; i++) {
+    vec sq_dists = vec(nAtoms); // squared distance from each atom to point i
+    for(u32 j = 0; j < nAtoms; j++) {
+      vec diff = D.unsafe_col(j) - X.unsafe_col(i);
+      sq_dists[j] = dot(diff, diff);
+    }
+    mat Dprime = D * diagmat(1.0 / sq_dists); // atom j scaled by 1/sq_dists[j]
+    mat zPrime = Z.unsafe_col(i) % sq_dists;  // code j scaled by sq_dists[j]
+    vec errCorr = trans(Dprime) * (Dprime * zPrime - X.unsafe_col(i));
+    VerifyCorrectness(zPrime, errCorr, 0.5 * lambda1);
+  }
+}
+
+// Dictionary step: after OptimizeCode() and OptimizeDictionary(), the matrix
+// `grad` built below from D, Z and X must have Frobenius norm under 1e-12.
+BOOST_AUTO_TEST_CASE(LocalCoordinateCodingTestDictionaryStep)
+{
+  const double tol = 1e-12;
+  double lambda = 0.1;
+  u32 nAtoms = 25;
+
+  // Abort on a failed load() rather than continuing without the data set.
+  mat X;
+  BOOST_REQUIRE(X.load("mnist_first250_training_4s_and_9s.arm"));
+  u32 nPoints = X.n_cols;
+
+  // normalize each point since these are images
+  for(u32 i = 0; i < nPoints; i++) {
+    X.col(i) /= norm(X.col(i), 2);
+  }
+
+  LocalCoordinateCoding lcc(X, nAtoms, lambda);
+  lcc.DataDependentRandomInitDictionary();
+  lcc.OptimizeCode();
+  mat Z = lcc.MatZ();
+  uvec adjacencies = find(Z); // indices of the nonzero entries of Z
+  lcc.OptimizeDictionary(adjacencies);
+  mat D = lcc.MatD();
+
+  // grad = lambda * sum_i (D - x_i 1^T) diag(|z_i|) + (D Z - X) Z^T
+  // NOTE(review): presumably the gradient of the objective w.r.t. D -- confirm.
+  mat grad = zeros(D.n_rows, D.n_cols);
+  for(u32 i = 0; i < nPoints; i++) {
+    grad += (D - repmat(X.unsafe_col(i), 1, nAtoms)) * diagmat(abs(Z.unsafe_col(i)));
+  }
+  grad = lambda * grad + (D * Z - X) * trans(Z);
+  BOOST_REQUIRE_SMALL(norm(grad, "fro"), tol);
+}
+
+
+/*
+BOOST_AUTO_TEST_CASE(LocalCoordinateCodingTestWhole)
+{
+
+}
+*/
+
+
+BOOST_AUTO_TEST_SUITE_END();




More information about the mlpack-svn mailing list