[mlpack-svn] r13024 - mlpack/trunk/src/mlpack/methods/sparse_coding

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Jun 11 16:10:21 EDT 2012


Author: rcurtin
Date: 2012-06-11 16:10:20 -0400 (Mon, 11 Jun 2012)
New Revision: 13024

Added:
   mlpack/trunk/src/mlpack/methods/sparse_coding/data_dependent_random_initializer.hpp
   mlpack/trunk/src/mlpack/methods/sparse_coding/nothing_initializer.hpp
   mlpack/trunk/src/mlpack/methods/sparse_coding/random_initializer.hpp
   mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding_impl.hpp
Removed:
   mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding.cpp
Modified:
   mlpack/trunk/src/mlpack/methods/sparse_coding/CMakeLists.txt
   mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding.hpp
   mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding_main.cpp
Log:
Refactor SparseCoding to initialize the dictionary as a template parameter.


Modified: mlpack/trunk/src/mlpack/methods/sparse_coding/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/sparse_coding/CMakeLists.txt	2012-06-11 19:23:47 UTC (rev 13023)
+++ mlpack/trunk/src/mlpack/methods/sparse_coding/CMakeLists.txt	2012-06-11 20:10:20 UTC (rev 13024)
@@ -3,8 +3,11 @@
 # Define the files we need to compile
 # Anything not in this list will not be compiled into the output library
 set(SOURCES
+  data_dependent_random_initializer.hpp
+  nothing_initializer.hpp
+  random_initializer.hpp
   sparse_coding.hpp
-  sparse_coding.cpp
+  sparse_coding_impl.hpp
 )
 
 # add directory name to sources

Added: mlpack/trunk/src/mlpack/methods/sparse_coding/data_dependent_random_initializer.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/sparse_coding/data_dependent_random_initializer.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/sparse_coding/data_dependent_random_initializer.hpp	2012-06-11 20:10:20 UTC (rev 13024)
@@ -0,0 +1,56 @@
+/**
+ * @file data_dependent_random_initializer.hpp
+ * @author Nishant Mehta
+ *
+ * A sensible heuristic for initializing dictionaries for sparse coding.
+ */
+#ifndef __MLPACK_METHODS_SPARSE_CODING_DATA_DEPENDENT_RANDOM_INITIALIZER_HPP
+#define __MLPACK_METHODS_SPARSE_CODING_DATA_DEPENDENT_RANDOM_INITIALIZER_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace sparse_coding {
+
+/**
+ * A data-dependent random dictionary initializer for SparseCoding.  This
+ * creates random dictionary atoms by adding three random observations from the
+ * data together, and then normalizing the atom.
+ */
+class DataDependentRandomInitializer
+{
+ public:
+  /**
+   * Initialize the dictionary by adding together three random observations from
+   * the data, and then normalizing the atom.  This implementation is simple
+   * enough to be included with the definition.
+   *
+   * @param data Dataset to initialize the dictionary with.
+   * @param atoms Number of atoms in dictionary.
+   * @param dictionary Dictionary to initialize.
+   */
+  static void Initialize(const arma::mat& data,
+                         const size_t atoms,
+                         arma::mat& dictionary)
+  {
+    // Set the size of the dictionary.
+    dictionary.set_size(data.n_rows, atoms);
+
+    // Create each atom.
+    for (size_t i = 0; i < atoms; ++i)
+    {
+      // Add three random observations from the data together.
+      dictionary.col(i) = (data.col(math::RandInt(data.n_cols)) +
+          data.col(math::RandInt(data.n_cols)) +
+          data.col(math::RandInt(data.n_cols)));
+
+      // Now normalize the atom.
+      dictionary.col(i) /= norm(dictionary.col(i), 2);
+    }
+  }
+};
+
+}; // namespace sparse_coding
+}; // namespace mlpack
+
+#endif

Added: mlpack/trunk/src/mlpack/methods/sparse_coding/nothing_initializer.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/sparse_coding/nothing_initializer.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/sparse_coding/nothing_initializer.hpp	2012-06-11 20:10:20 UTC (rev 13024)
@@ -0,0 +1,41 @@
+/**
+ * @file nothing_initializer.hpp
+ * @author Ryan Curtin
+ *
+ * An initializer for SparseCoding which does precisely nothing.  It is useful
+ * for when you have an already defined dictionary and you plan on setting it
+ * with SparseCoding::Dictionary().
+ */
+#ifndef __MLPACK_METHODS_SPARSE_CODING_NOTHING_INITIALIZER_HPP
+#define __MLPACK_METHODS_SPARSE_CODING_NOTHING_INITIALIZER_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace sparse_coding {
+
+/**
+ * A DictionaryInitializer for SparseCoding which does not initialize anything;
+ * it is useful for when the dictionary is already known and will be set with
+ * SparseCoding::Dictionary().
+ */
+class NothingInitializer
+{
+ public:
+  /**
+   * This function does not initialize the dictionary.  This will cause problems
+   * for SparseCoding if the dictionary is not set manually before running the
+   * method.
+   */
+  static void Initialize(const arma::mat& /* data */,
+                         const size_t /* atoms */,
+                         arma::mat& /* dictionary */)
+  {
+    // Do nothing!
+  }
+};
+
+}; // namespace sparse_coding
+}; // namespace mlpack
+
+#endif

Added: mlpack/trunk/src/mlpack/methods/sparse_coding/random_initializer.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/sparse_coding/random_initializer.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/sparse_coding/random_initializer.hpp	2012-06-11 20:10:20 UTC (rev 13024)
@@ -0,0 +1,48 @@
+/**
+ * @file random_initializer.hpp
+ * @author Nishant Mehta
+ *
+ * A very simple random dictionary initializer for SparseCoding; it is probably
+ * not a very good choice.
+ */
+#ifndef __MLPACK_METHODS_SPARSE_CODING_RANDOM_INITIALIZER_HPP
+#define __MLPACK_METHODS_SPARSE_CODING_RANDOM_INITIALIZER_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace sparse_coding {
+
+/**
+ * A DictionaryInitializer for use with the SparseCoding class.  This provides a
+ * random, normally distributed dictionary, such that each atom has a norm of 1.
+ */
+class RandomInitializer
+{
+ public:
+  /**
+   * Initialize the dictionary randomly from a normal distribution, such that
+   * each atom has a norm of 1.  This is simple enough to be included with the
+   * definition.
+   *
+   * @param data Dataset to use for initialization.
+   * @param atoms Number of atoms (columns) in the dictionary.
+   * @param dictionary Dictionary to initialize.
+   */
+  static void Initialize(const arma::mat& data,
+                         const size_t atoms,
+                         arma::mat& dictionary)
+  {
+    // Create random dictionary.
+    dictionary.randn(data.n_rows, atoms);
+
+    // Normalize each atom.
+    for (size_t j = 0; j < atoms; ++j)
+      dictionary.col(j) /= norm(dictionary.col(j), 2);
+  }
+};
+
+}; // namespace sparse_coding
+}; // namespace mlpack
+
+#endif

Deleted: mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding.cpp	2012-06-11 19:23:47 UTC (rev 13023)
+++ mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding.cpp	2012-06-11 20:10:20 UTC (rev 13024)
@@ -1,401 +0,0 @@
-/**
- * @file sparse_coding.cpp
- * @author Nishant Mehta
- *
- * Implementation of Sparse Coding with Dictionary Learning using l1 (LASSO) or
- * l1+l2 (Elastic Net) regularization.
- */
-#include "sparse_coding.hpp"
-
-using namespace std;
-using namespace arma;
-using namespace mlpack;
-using namespace mlpack::regression;
-using namespace mlpack::sparse_coding;
-
-// TODO: parameterizable; options to methods?
-#define OBJ_TOL 1e-2 // 1E-9
-#define NEWTON_TOL 1e-6 // 1E-9
-
-SparseCoding::SparseCoding(const mat& data,
-                           const size_t atoms,
-                           const double lambda1,
-                           const double lambda2) :
-    atoms(atoms),
-    data(data),
-    codes(mat(atoms, data.n_cols)),
-    lambda1(lambda1),
-    lambda2(lambda2)
-{ /* Nothing left to do. */ }
-
-// Always a not good decision!
-void SparseCoding::RandomInitDictionary()
-{
-  dictionary = randn(data.n_rows, atoms);
-
-  for (size_t j = 0; j < atoms; ++j)
-    dictionary.col(j) /= norm(dictionary.col(j), 2);
-}
-
-// The sensible heuristic.
-void SparseCoding::DataDependentRandomInitDictionary()
-{
-  dictionary = mat(data.n_rows, atoms);
-  for (size_t j = 0; j < atoms; ++j)
-  {
-    vec vecD_j = dictionary.unsafe_col(j);
-    RandomAtom(vecD_j);
-  }
-}
-
-void SparseCoding::RandomAtom(vec& atom)
-{
-  atom.zeros();
-  const size_t nSeedAtoms = 3;
-  for (size_t i = 0; i < nSeedAtoms; i++)
-    atom += data.col(rand() % data.n_cols);
-
-  atom /= norm(atom, 2);
-}
-
-void SparseCoding::DoSparseCoding(const size_t maxIterations)
-{
-  double lastObjVal = DBL_MAX;
-
-  Log::Info << "Initial Coding Step." << endl;
-
-  OptimizeCode();
-  uvec adjacencies = find(codes);
-
-  Log::Info << "  Sparsity level: "
-      << 100.0 * ((double) (adjacencies.n_elem)) / ((double)
-      (atoms * data.n_cols)) << "%" << endl;
-  Log::Info << "  Objective value: " << Objective() << "." << endl;
-
-  for (size_t t = 1; t != maxIterations; ++t)
-  {
-    Log::Info << "Iteration " << t << " of " << maxIterations << "." << endl;
-
-    Log::Info << "Performing dictionary step... ";
-    OptimizeDictionary(adjacencies);
-    Log::Info << "objective value: " << Objective() << "." << endl;
-
-    Log::Info << "Performing coding step..." << endl;
-    OptimizeCode();
-    adjacencies = find(codes);
-    Log::Info << "  Sparsity level: "
-        << 100.0 *
-        ((double) (adjacencies.n_elem)) / ((double) (atoms * data.n_cols))
-        << "%" << endl;
-
-    double curObjVal = Objective();
-    Log::Info << "  Objective value: " << curObjVal << "." << endl;
-
-    double objValImprov = lastObjVal - curObjVal;
-    Log::Info << "  Improvement: " << scientific << objValImprov << "." << endl;
-
-    if (objValImprov < OBJ_TOL)
-    {
-      Log::Info << "Converged within tolerance " << OBJ_TOL << ".\n";
-      break;
-    }
-
-    lastObjVal = curObjVal;
-  }
-}
-
-void SparseCoding::OptimizeCode()
-{
-  // When using Cholesky version of LARS, this is correct even if lambda2 > 0.
-  mat matGram = trans(dictionary) * dictionary;
-  // mat matGram;
-  // if(lambda2 > 0) {
-  //   matGram = trans(dictionary) * dictionary + lambda2 * eye(atoms, atoms);
-  // }
-  // else {
-  //   matGram = trans(dictionary) * dictionary;
-  // }
-
-  for (size_t i = 0; i < data.n_cols; ++i)
-  {
-    // Report progress.
-    if ((i % 100) == 0)
-      Log::Debug << "Optimization at point " << i << "." << endl;
-
-    bool useCholesky = true;
-    LARS lars(useCholesky, matGram, lambda1, lambda2);
-
-    vec beta;
-    lars.Regress(dictionary, data.unsafe_col(i), beta, true);
-
-    codes.col(i) = beta;
-  }
-}
-
-void SparseCoding::OptimizeDictionary(const uvec& adjacencies)
-{
-  // Count the number of atomic neighbors for each point x^i.
-  uvec neighborCounts = zeros<uvec>(data.n_cols, 1);
-
-  if (adjacencies.n_elem > 0)
-  {
-    // This gets the column index.
-    // TODO: is this integer division intentional?
-    size_t curPointInd = (size_t) (adjacencies(0) / atoms);
-    size_t curCount = 1;
-
-    for (size_t l = 1; l < adjacencies.n_elem; ++l)
-    {
-      if ((size_t) (adjacencies(l) / atoms) == curPointInd)
-      {
-        ++curCount;
-      }
-      else
-      {
-        neighborCounts(curPointInd) = curCount;
-        curPointInd = (size_t) (adjacencies(l) / atoms);
-        curCount = 1;
-      }
-    }
-
-    neighborCounts(curPointInd) = curCount;
-  }
-
-  // Handle the case of inactive atoms (atoms not used in the given coding).
-  std::vector<size_t> inactiveAtoms;
-  std::vector<size_t> activeAtoms;
-  activeAtoms.reserve(atoms);
-
-  for (size_t j = 0; j < atoms; ++j)
-  {
-    if (accu(codes.row(j) != 0) == 0)
-      inactiveAtoms.push_back(j);
-    else
-      activeAtoms.push_back(j);
-  }
-
-  const size_t nActiveAtoms = activeAtoms.size();
-  const size_t nInactiveAtoms = inactiveAtoms.size();
-
-  // Efficient construction of Z restricted to active atoms.
-  mat matActiveZ;
-  if (inactiveAtoms.empty())
-  {
-    matActiveZ = codes;
-  }
-  else
-  {
-    uvec inactiveAtomsVec = conv_to<uvec>::from(inactiveAtoms);
-    RemoveRows(codes, inactiveAtomsVec, matActiveZ);
-  }
-
-  uvec atomReverseLookup = uvec(atoms);
-  for (size_t i = 0; i < nActiveAtoms; ++i)
-    atomReverseLookup(activeAtoms[i]) = i;
-
-  if (nInactiveAtoms > 0)
-  {
-    Log::Info << "There are " << nInactiveAtoms
-        << " inactive atoms. They will be re-initialized randomly.\n";
-  }
-
-  Log::Debug << "Solving Dual via Newton's Method.\n";
-
-  mat dictionaryEstimate;
-  // Solve using Newton's method in the dual - note that the final dot
-  // multiplication with inv(A) seems to be unavoidable. Although more
-  // expensive, the code written this way (we use solve()) should be more
-  // numerically stable than just using inv(A) for everything.
-  vec dualVars = zeros<vec>(nActiveAtoms);
-
-  //vec dualVars = 1e-14 * ones<vec>(nActiveAtoms);
-
-  // Method used by feature sign code - fails miserably here.  Perhaps the
-  // MATLAB optimizer fmincon does something clever?
-  //vec dualVars = 10.0 * randu(nActiveAtoms, 1);
-
-  //vec dualVars = diagvec(solve(dictionary, data * trans(codes))
-  //    - codes * trans(codes));
-  //for (size_t i = 0; i < dualVars.n_elem; i++)
-  //  if (dualVars(i) < 0)
-  //    dualVars(i) = 0;
-
-  bool converged = false;
-  mat codesXT = matActiveZ * trans(data);
-  mat codesZT = matActiveZ * trans(matActiveZ);
-
-  for (size_t t = 1; !converged; ++t)
-  {
-    mat A = codesZT + diagmat(dualVars);
-
-    mat matAInvZXT = solve(A, codesXT);
-
-    vec gradient = -(sum(square(matAInvZXT), 1) - ones<vec>(nActiveAtoms));
-
-    mat hessian = -(-2 * (matAInvZXT * trans(matAInvZXT)) % inv(A));
-
-    vec searchDirection = -solve(hessian, gradient);
-    //vec searchDirection = -gradient;
-
-    // Armijo line search.
-    const double c = 1e-4;
-    double alpha = 1.0;
-    const double rho = 0.9;
-    double sufficientDecrease = c * dot(gradient, searchDirection);
-
-    /*
-    {
-      double sumDualVars = sum(dualVars);
-      double fOld = -(-trace(trans(codesXT) * matAInvZXT) - sumDualVars);
-      Log::Debug << "fOld = " << fOld << "." << endl;
-      double fNew =
-          -(-trace(trans(codesXT) * solve(codesZT +
-          diagmat(dualVars + alpha * searchDirection), codesXT))
-          - (sumDualVars + alpha * sum(searchDirection)) );
-      Log::Debug << "fNew = " << fNew << "." << endl;
-    }
-    */
-
-    double improvement;
-    while (true)
-    {
-      // Calculate objective.
-      double sumDualVars = sum(dualVars);
-      double fOld = -(-trace(trans(codesXT) * matAInvZXT) - sumDualVars);
-      double fNew = -(-trace(trans(codesXT) * solve(codesZT +
-          diagmat(dualVars + alpha * searchDirection), codesXT)) -
-          (sumDualVars + alpha * sum(searchDirection)));
-
-      if (fNew <= fOld + alpha * sufficientDecrease)
-      {
-        searchDirection = alpha * searchDirection;
-        improvement = fOld - fNew;
-        break;
-      }
-
-      alpha *= rho;
-    }
-
-    // End of Armijo line search code.
-
-    dualVars += searchDirection;
-    double normGradient = norm(gradient, 2);
-    Log::Debug << "Newton Method iteration " << t << ":" << endl;
-    Log::Debug << "  Gradient norm: " << std::scientific << normGradient
-        << "." << endl;
-    Log::Debug << "  Improvement: " << std::scientific << improvement << ".\n";
-
-    if (improvement < NEWTON_TOL)
-      converged = true;
-  }
-
-  if (inactiveAtoms.empty())
-  {
-    dictionaryEstimate = trans(solve(codesZT + diagmat(dualVars), codesXT));
-  }
-  else
-  {
-    mat dictionaryActiveEstimate = trans(solve(codesZT + diagmat(dualVars),
-        codesXT));
-    dictionaryEstimate = zeros(data.n_rows, atoms);
-
-    for (size_t i = 0; i < nActiveAtoms; ++i)
-      dictionaryEstimate.col(activeAtoms[i]) = dictionaryActiveEstimate.col(i);
-
-    for (size_t i = 0; i < nInactiveAtoms; ++i)
-    {
-      vec vecdictionaryi = dictionaryEstimate.unsafe_col(inactiveAtoms[i]);
-      RandomAtom(vecdictionaryi);
-    }
-  }
-
-  dictionary = dictionaryEstimate;
-}
-
-void SparseCoding::ProjectDictionary()
-{
-  for (size_t j = 0; j < atoms; j++)
-  {
-    double normD_j = norm(dictionary.col(j), 2);
-    if ((normD_j > 1) && (normD_j - 1.0 > 1e-9))
-    {
-      Log::Warn << "Norm exceeded 1 by " << std::scientific << normD_j - 1.0
-          << ".  Shrinking...\n";
-      dictionary.col(j) /= normD_j;
-    }
-  }
-}
-
-double SparseCoding::Objective()
-{
-  double l11NormZ = sum(sum(abs(codes)));
-  double froNormResidual = norm(data - dictionary * codes, "fro");
-
-  if (lambda2 > 0)
-  {
-    double froNormZ = norm(codes, "fro");
-    return 0.5 *
-      (froNormResidual * froNormResidual + lambda2 * froNormZ * froNormZ) +
-      lambda1 * l11NormZ;
-  }
-  else
-  {
-    return 0.5 * froNormResidual * froNormResidual + lambda1 * l11NormZ;
-  }
-}
-
-void mlpack::sparse_coding::RemoveRows(const mat& X,
-                                       uvec rows_to_remove,
-                                       mat& X_mod)
-{
-  const size_t n_cols = X.n_cols;
-  const size_t n_rows = X.n_rows;
-  const size_t n_to_remove = rows_to_remove.n_elem;
-  const size_t n_to_keep = n_rows - n_to_remove;
-
-  if (n_to_remove == 0)
-  {
-    X_mod = X;
-  }
-  else
-  {
-    X_mod.set_size(n_to_keep, n_cols);
-
-    size_t cur_row = 0;
-    size_t remove_ind = 0;
-    // First, check 0 to first row to remove.
-    if (rows_to_remove(0) > 0)
-    {
-      // Note that this implies that n_rows > 1.
-      size_t height = rows_to_remove(0);
-      X_mod(span(cur_row, cur_row + height - 1), span::all) =
-          X(span(0, rows_to_remove(0) - 1), span::all);
-      cur_row += height;
-    }
-    // Now, check i'th row to remove to (i + 1)'th row to remove, until i is the
-    // penultimate row.
-    while (remove_ind < n_to_remove - 1)
-    {
-      size_t height = rows_to_remove[remove_ind + 1] -
-          rows_to_remove[remove_ind] - 1;
-
-      if (height > 0)
-      {
-        X_mod(span(cur_row, cur_row + height - 1), span::all) =
-            X(span(rows_to_remove[remove_ind] + 1,
-            rows_to_remove[remove_ind + 1] - 1), span::all);
-        cur_row += height;
-      }
-
-      remove_ind++;
-    }
-
-    // Now that i is the last row to remove, check last row to remove to last
-    // row.
-    if (rows_to_remove[remove_ind] < n_rows - 1)
-    {
-      X_mod(span(cur_row, n_to_keep - 1), span::all) =
-          X(span(rows_to_remove[remove_ind] + 1, n_rows - 1), span::all);
-    }
-  }
-}

Modified: mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding.hpp	2012-06-11 19:23:47 UTC (rev 13023)
+++ mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding.hpp	2012-06-11 20:10:20 UTC (rev 13024)
@@ -11,6 +11,11 @@
 #include <mlpack/core.hpp>
 #include <mlpack/methods/lars/lars.hpp>
 
+// Include our three simple dictionary initializers.
+#include "nothing_initializer.hpp"
+#include "data_dependent_random_initializer.hpp"
+#include "random_initializer.hpp"
+
 namespace mlpack {
 namespace sparse_coding {
 
@@ -78,7 +83,20 @@
  *   publisher={Royal Statistical Society}
  * }
  * @endcode
+ *
+ * Before the method is run, the dictionary is initialized using the
+ * DictionaryInitializer template parameter.  Possible choices include the
+ * RandomInitializer, which provides an entirely random dictionary, the
+ * DataDependentRandomInitializer, which provides a random dictionary based
+ * loosely on characteristics of the dataset, and the NothingInitializer, which
+ * does not initialize the dictionary -- instead, the user should set the
+ * dictionary using the Dictionary() mutator method.
+ *
+ * @tparam DictionaryInitializer The class to use to initialize the
+ *     dictionary; must have a static 'void Initialize(const arma::mat& data,
+ *     const size_t atoms, arma::mat& dictionary)' function.
  */
+template<typename DictionaryInitializer = DataDependentRandomInitializer>
 class SparseCoding
 {
  public:
@@ -96,28 +114,6 @@
                const double lambda2 = 0);
 
   /**
-   * Initialize dictionary by drawing k vectors uniformly at random from the
-   * unit sphere.
-   */
-  void RandomInitDictionary();
-
-  /**
-   * Initialize dictionary by initializing each atom to a normalized mixture of
-   * a small number of randomly selected points in X.
-   */
-  void DataDependentRandomInitDictionary();
-
-  /**
-   * Initialize an atom to a normalized mixture of a small number of randomly
-   * selected points in X.
-   *
-   * @param atom The atom to initialize
-   */
-  void RandomAtom(arma::vec& atom);
-
-  // core algorithm functions
-
-  /**
    * Run Sparse Coding with Dictionary Learning.
    *
    * @param maxIterations Maximum number of iterations to run algorithm.  If 0,
@@ -131,7 +127,7 @@
   void OptimizeCode();
 
   /**
-   * Learn dictionary via Newton method based on Lagrange dual
+   * Learn dictionary via Newton method based on Lagrange dual.
    *
    * @param adjacencies Indices of entries (unrolled column by column) of
    *    the coding matrix Z that are non-zero (the adjacency matrix for the
@@ -186,10 +182,13 @@
 };
 
 void RemoveRows(const arma::mat& X,
-                arma::uvec rows_to_remove,
-                arma::mat& X_mod);
+                const arma::uvec& rowsToRemove,
+                arma::mat& modX);
 
 }; // namespace sparse_coding
 }; // namespace mlpack
 
+// Include implementation.
+#include "sparse_coding_impl.hpp"
+
 #endif

Copied: mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding_impl.hpp (from rev 13003, mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding.cpp)
===================================================================
--- mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding_impl.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding_impl.hpp	2012-06-11 20:10:20 UTC (rev 13024)
@@ -0,0 +1,399 @@
+/**
+ * @file sparse_coding_impl.hpp
+ * @author Nishant Mehta
+ *
+ * Implementation of Sparse Coding with Dictionary Learning using l1 (LASSO) or
+ * l1+l2 (Elastic Net) regularization.
+ */
+#ifndef __MLPACK_METHODS_SPARSE_CODING_SPARSE_CODING_IMPL_HPP
+#define __MLPACK_METHODS_SPARSE_CODING_SPARSE_CODING_IMPL_HPP
+
+// In case it hasn't already been included.
+#include "sparse_coding.hpp"
+
+namespace mlpack {
+namespace sparse_coding {
+
+// TODO: parameterizable; options to methods?
+#define OBJ_TOL 1e-2 // 1E-9
+#define NEWTON_TOL 1e-6 // 1E-9
+
+template<typename DictionaryInitializer>
+SparseCoding<DictionaryInitializer>::SparseCoding(const arma::mat& data,
+                                                  const size_t atoms,
+                                                  const double lambda1,
+                                                  const double lambda2) :
+    atoms(atoms),
+    data(data),
+    codes(atoms, data.n_cols),
+    lambda1(lambda1),
+    lambda2(lambda2)
+{
+  // Initialize the dictionary.
+  DictionaryInitializer::Initialize(data, atoms, dictionary);
+}
+
+template<typename DictionaryInitializer>
+void SparseCoding<DictionaryInitializer>::DoSparseCoding(
+    const size_t maxIterations)
+{
+  double lastObjVal = DBL_MAX;
+
+  Log::Info << "Initial Coding Step." << std::endl;
+
+  OptimizeCode();
+  arma::uvec adjacencies = find(codes);
+
+  Log::Info << "  Sparsity level: "
+      << 100.0 * ((double) (adjacencies.n_elem)) / ((double)
+      (atoms * data.n_cols)) << "%" << std::endl;
+  Log::Info << "  Objective value: " << Objective() << "." << std::endl;
+
+  for (size_t t = 1; t != maxIterations; ++t)
+  {
+    Log::Info << "Iteration " << t << " of " << maxIterations << "."
+        << std::endl;
+
+    Log::Info << "Performing dictionary step... ";
+    OptimizeDictionary(adjacencies);
+    Log::Info << "objective value: " << Objective() << "." << std::endl;
+
+    Log::Info << "Performing coding step..." << std::endl;
+    OptimizeCode();
+    adjacencies = find(codes);
+    Log::Info << "  Sparsity level: "
+        << 100.0 *
+        ((double) (adjacencies.n_elem)) / ((double) (atoms * data.n_cols))
+        << "%" << std::endl;
+
+    double curObjVal = Objective();
+    Log::Info << "  Objective value: " << curObjVal << "." << std::endl;
+
+    double objValImprov = lastObjVal - curObjVal;
+    Log::Info << "  Improvement: " << std::scientific << objValImprov << "."
+        << std::endl;
+
+    if (objValImprov < OBJ_TOL)
+    {
+      Log::Info << "Converged within tolerance " << OBJ_TOL << ".\n";
+      break;
+    }
+
+    lastObjVal = curObjVal;
+  }
+}
+
+template<typename DictionaryInitializer>
+void SparseCoding<DictionaryInitializer>::OptimizeCode()
+{
+  // When using Cholesky version of LARS, this is correct even if lambda2 > 0.
+  arma::mat matGram = trans(dictionary) * dictionary;
+  // mat matGram;
+  // if(lambda2 > 0) {
+  //   matGram = trans(dictionary) * dictionary + lambda2 * eye(atoms, atoms);
+  // }
+  // else {
+  //   matGram = trans(dictionary) * dictionary;
+  // }
+
+  for (size_t i = 0; i < data.n_cols; ++i)
+  {
+    // Report progress.
+    if ((i % 100) == 0)
+      Log::Debug << "Optimization at point " << i << "." << std::endl;
+
+    bool useCholesky = true;
+    regression::LARS lars(useCholesky, matGram, lambda1, lambda2);
+
+    arma::vec beta;
+    lars.Regress(dictionary, data.unsafe_col(i), beta, true);
+
+    codes.col(i) = beta;
+  }
+}
+
+template<typename DictionaryInitializer>
+void SparseCoding<DictionaryInitializer>::OptimizeDictionary(
+      const arma::uvec& adjacencies)
+{
+  // Count the number of atomic neighbors for each point x^i.
+  arma::uvec neighborCounts = arma::zeros<arma::uvec>(data.n_cols, 1);
+
+  if (adjacencies.n_elem > 0)
+  {
+    // This gets the column index.
+    // TODO: is this integer division intentional?
+    size_t curPointInd = (size_t) (adjacencies(0) / atoms);
+    size_t curCount = 1;
+
+    for (size_t l = 1; l < adjacencies.n_elem; ++l)
+    {
+      if ((size_t) (adjacencies(l) / atoms) == curPointInd)
+      {
+        ++curCount;
+      }
+      else
+      {
+        neighborCounts(curPointInd) = curCount;
+        curPointInd = (size_t) (adjacencies(l) / atoms);
+        curCount = 1;
+      }
+    }
+
+    neighborCounts(curPointInd) = curCount;
+  }
+
+  // Handle the case of inactive atoms (atoms not used in the given coding).
+  std::vector<size_t> inactiveAtoms;
+  std::vector<size_t> activeAtoms;
+  activeAtoms.reserve(atoms);
+
+  for (size_t j = 0; j < atoms; ++j)
+  {
+    if (accu(codes.row(j) != 0) == 0)
+      inactiveAtoms.push_back(j);
+    else
+      activeAtoms.push_back(j);
+  }
+
+  const size_t nActiveAtoms = activeAtoms.size();
+  const size_t nInactiveAtoms = inactiveAtoms.size();
+
+  // Efficient construction of Z restricted to active atoms.
+  arma::mat matActiveZ;
+  if (inactiveAtoms.empty())
+  {
+    matActiveZ = codes;
+  }
+  else
+  {
+    arma::uvec inactiveAtomsVec =
+        arma::conv_to<arma::uvec>::from(inactiveAtoms);
+    RemoveRows(codes, inactiveAtomsVec, matActiveZ);
+  }
+
+  arma::uvec atomReverseLookup(atoms);
+  for (size_t i = 0; i < nActiveAtoms; ++i)
+    atomReverseLookup(activeAtoms[i]) = i;
+
+  if (nInactiveAtoms > 0)
+  {
+    Log::Info << "There are " << nInactiveAtoms
+        << " inactive atoms. They will be re-initialized randomly.\n";
+  }
+
+  Log::Debug << "Solving Dual via Newton's Method.\n";
+
+  arma::mat dictionaryEstimate;
+  // Solve using Newton's method in the dual.  Note that the element-wise
+  // multiplication with inv(A) in the Hessian seems to be unavoidable.
+  // Although more expensive, using solve() wherever possible should be more
+  // numerically stable than just using inv(A) for everything.
+  arma::vec dualVars = arma::zeros<arma::vec>(nActiveAtoms);
+
+  //vec dualVars = 1e-14 * ones<vec>(nActiveAtoms);
+
+  // Method used by feature sign code - fails miserably here.  Perhaps the
+  // MATLAB optimizer fmincon does something clever?
+  //vec dualVars = 10.0 * randu(nActiveAtoms, 1);
+
+  //vec dualVars = diagvec(solve(dictionary, data * trans(codes))
+  //    - codes * trans(codes));
+  //for (size_t i = 0; i < dualVars.n_elem; i++)
+  //  if (dualVars(i) < 0)
+  //    dualVars(i) = 0;
+
+  bool converged = false;
+  arma::mat codesXT = matActiveZ * trans(data);
+  arma::mat codesZT = matActiveZ * trans(matActiveZ);
+
+  for (size_t t = 1; !converged; ++t)
+  {
+    arma::mat A = codesZT + diagmat(dualVars);
+
+    arma::mat matAInvZXT = solve(A, codesXT);
+
+    arma::vec gradient = -(arma::sum(arma::square(matAInvZXT), 1) -
+        arma::ones<arma::vec>(nActiveAtoms));
+
+    arma::mat hessian = -(-2 * (matAInvZXT * trans(matAInvZXT)) % inv(A));
+
+    arma::vec searchDirection = -solve(hessian, gradient);
+    //vec searchDirection = -gradient;
+
+    // Armijo line search.
+    const double c = 1e-4;
+    double alpha = 1.0;
+    const double rho = 0.9;
+    double sufficientDecrease = c * dot(gradient, searchDirection);
+
+    /*
+    {
+      double sumDualVars = sum(dualVars);
+      double fOld = -(-trace(trans(codesXT) * matAInvZXT) - sumDualVars);
+      Log::Debug << "fOld = " << fOld << "." << std::endl;
+      double fNew =
+          -(-trace(trans(codesXT) * solve(codesZT +
+          diagmat(dualVars + alpha * searchDirection), codesXT))
+          - (sumDualVars + alpha * sum(searchDirection)) );
+      Log::Debug << "fNew = " << fNew << "." << std::endl;
+    }
+    */
+
+    double improvement;
+    while (true)
+    {
+      // Calculate objective.
+      double sumDualVars = sum(dualVars);
+      double fOld = -(-trace(trans(codesXT) * matAInvZXT) - sumDualVars);
+      double fNew = -(-trace(trans(codesXT) * solve(codesZT +
+          diagmat(dualVars + alpha * searchDirection), codesXT)) -
+          (sumDualVars + alpha * sum(searchDirection)));
+
+      if (fNew <= fOld + alpha * sufficientDecrease)
+      {
+        searchDirection = alpha * searchDirection;
+        improvement = fOld - fNew;
+        break;
+      }
+
+      alpha *= rho;
+    }
+
+    // End of Armijo line search code.
+
+    dualVars += searchDirection;
+    double normGradient = norm(gradient, 2);
+    Log::Debug << "Newton Method iteration " << t << ":" << std::endl;
+    Log::Debug << "  Gradient norm: " << std::scientific << normGradient
+        << "." << std::endl;
+    Log::Debug << "  Improvement: " << std::scientific << improvement << ".\n";
+
+    if (improvement < NEWTON_TOL)
+      converged = true;
+  }
+
+  if (inactiveAtoms.empty())
+  {
+    dictionaryEstimate = trans(solve(codesZT + diagmat(dualVars), codesXT));
+  }
+  else
+  {
+    arma::mat dictionaryActiveEstimate = trans(solve(codesZT +
+        diagmat(dualVars), codesXT));
+    dictionaryEstimate = arma::zeros(data.n_rows, atoms);
+
+    for (size_t i = 0; i < nActiveAtoms; ++i)
+      dictionaryEstimate.col(activeAtoms[i]) = dictionaryActiveEstimate.col(i);
+
+    for (size_t i = 0; i < nInactiveAtoms; ++i)
+    {
+      // Make a new random atom estimate.
+      dictionaryEstimate.col(inactiveAtoms[i]) =
+          (data.col(math::RandInt(data.n_cols)) +
+           data.col(math::RandInt(data.n_cols)) +
+           data.col(math::RandInt(data.n_cols)));
+
+      dictionaryEstimate.col(inactiveAtoms[i]) /=
+          norm(dictionaryEstimate.col(inactiveAtoms[i]), 2);
+    }
+  }
+
+  dictionary = dictionaryEstimate;
+}
+
+template<typename DictionaryInitializer>
+void SparseCoding<DictionaryInitializer>::ProjectDictionary()
+{
+  for (size_t j = 0; j < atoms; j++)
+  {
+    double normD_j = norm(dictionary.col(j), 2);
+    if ((normD_j > 1) && (normD_j - 1.0 > 1e-9))
+    {
+      Log::Warn << "Norm exceeded 1 by " << std::scientific << normD_j - 1.0
+          << ".  Shrinking...\n";
+      dictionary.col(j) /= normD_j;
+    }
+  }
+}
+
+template<typename DictionaryInitializer>
+double SparseCoding<DictionaryInitializer>::Objective()
+{
+  double l11NormZ = sum(sum(abs(codes)));
+  double froNormResidual = norm(data - dictionary * codes, "fro");
+
+  if (lambda2 > 0)
+  {
+    double froNormZ = norm(codes, "fro");
+    return 0.5 *
+      (froNormResidual * froNormResidual + lambda2 * froNormZ * froNormZ) +
+      lambda1 * l11NormZ;
+  }
+  else
+  {
+    return 0.5 * froNormResidual * froNormResidual + lambda1 * l11NormZ;
+  }
+}
+
+void RemoveRows(const arma::mat& X,
+                const arma::uvec& rowsToRemove,
+                arma::mat& modX)
+{
+  const size_t cols = X.n_cols;
+  const size_t rows = X.n_rows;
+  const size_t nRemove = rowsToRemove.n_elem;
+  const size_t nKeep = rows - nRemove;
+
+  if (nRemove == 0)
+  {
+    modX = X;
+  }
+  else
+  {
+    modX.set_size(nKeep, cols);
+
+    size_t curRow = 0;
+    size_t removeInd = 0;
+    // First, check 0 to first row to remove.
+    if (rowsToRemove(0) > 0)
+    {
+      // Note that this implies that n_rows > 1.
+      size_t height = rowsToRemove(0);
+      modX(arma::span(curRow, curRow + height - 1), arma::span::all) =
+          X(arma::span(0, rowsToRemove(0) - 1), arma::span::all);
+      curRow += height;
+    }
+    // Now, check i'th row to remove to (i + 1)'th row to remove, until i is the
+    // penultimate row.
+    while (removeInd < nRemove - 1)
+    {
+      size_t height = rowsToRemove[removeInd + 1] -
+          rowsToRemove[removeInd] - 1;
+
+      if (height > 0)
+      {
+        modX(arma::span(curRow, curRow + height - 1), arma::span::all) =
+            X(arma::span(rowsToRemove[removeInd] + 1,
+            rowsToRemove[removeInd + 1] - 1), arma::span::all);
+        curRow += height;
+      }
+
+      removeInd++;
+    }
+
+    // Now that i is the last row to remove, check last row to remove to last
+    // row.
+    if (rowsToRemove[removeInd] < rows - 1)
+    {
+      modX(arma::span(curRow, nKeep - 1), arma::span::all) =
+          X(arma::span(rowsToRemove[removeInd] + 1, rows - 1),
+          arma::span::all);
+    }
+  }
+}
+
+}; // namespace sparse_coding
+}; // namespace mlpack
+
+#endif

Modified: mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding_main.cpp	2012-06-11 19:23:47 UTC (rev 13023)
+++ mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding_main.cpp	2012-06-11 20:10:20 UTC (rev 13024)
@@ -63,14 +63,10 @@
     matX.col(i) /= norm(matX.col(i), 2);
 
   // Run the sparse coding algorithm.
-  SparseCoding sc(matX, nAtoms, lambda1, lambda2);
+  SparseCoding<> sc(matX, nAtoms, lambda1, lambda2);
 
-  if (strlen(initialDictionaryFullpath) == 0)
+  if (strlen(initialDictionaryFullpath) != 0)
   {
-    sc.DataDependentRandomInitDictionary();
-  }
-  else
-  {
     mat matInitialD;
     data::Load(initialDictionaryFullpath, matInitialD);
 




More information about the mlpack-svn mailing list