[mlpack-svn] r11670 - in mlpack/trunk/src/mlpack/methods: . sparse_coding

fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Feb 29 18:28:53 EST 2012


Author: niche
Date: 2012-02-29 18:28:52 -0500 (Wed, 29 Feb 2012)
New Revision: 11670

Added:
   mlpack/trunk/src/mlpack/methods/sparse_coding/
   mlpack/trunk/src/mlpack/methods/sparse_coding/CMakeLists.txt
   mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding.cpp
   mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding.hpp
   mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding_main.cpp
Log:
Added sparse coding with dictionary learning.

Added: mlpack/trunk/src/mlpack/methods/sparse_coding/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/sparse_coding/CMakeLists.txt	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/sparse_coding/CMakeLists.txt	2012-02-29 23:28:52 UTC (rev 11670)
@@ -0,0 +1,25 @@
+cmake_minimum_required(VERSION 2.8)
+
+# Define the files we need to compile
+# Anything not in this list will not be compiled into the output library
+set(SOURCES
+  sparse_coding.hpp
+  sparse_coding.cpp
+)
+
+# add directory name to sources
+set(DIR_SRCS)
+foreach(file ${SOURCES})
+  set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
+endforeach()
+
+# append sources (with directory name) to list of all MLPACK sources (used at the parent scope)
+set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
+
+add_executable(sparse_coding
+  sparse_coding_main.cpp
+)
+target_link_libraries(sparse_coding
+  mlpack
+)
+install(TARGETS sparse_coding RUNTIME DESTINATION bin)

Added: mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding.cpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding.cpp	2012-02-29 23:28:52 UTC (rev 11670)
@@ -0,0 +1,451 @@
+/**
+ * @file sparse_coding.cpp
+ * @author Nishant Mehta
+ *
+ * Implementation of Sparse Coding with Dictionary Learning using l1 (LASSO) or 
+ * l1+l2 (Elastic Net) regularization
+ */
+
+#include "sparse_coding.hpp"
+
+using namespace std;
+using namespace arma;
+using namespace mlpack::regression;
+
+// tolerance on objective improvement for the overall alternating algorithm
+#define OBJ_TOL 1e-2 // 1E-9
+// tolerance on improvement for the Newton iterations in the dictionary step
+#define NEWTON_TOL 1e-6 // 1E-9
+
+namespace mlpack {
+namespace sparse_coding {
+
+SparseCoding::SparseCoding(const mat& matX, u32 nAtoms, double lambda1, double lambda2) :
+  nDims(matX.n_rows),  
+  nAtoms(nAtoms),
+  nPoints(matX.n_cols),
+  matX(matX),
+  matZ(mat(nAtoms, nPoints)),
+  lambda1(lambda1),
+  lambda2(lambda2)
+{ /* nothing left to do */ }
+  
+  
+void SparseCoding::SetDictionary(const mat& D) {
+  matD = D;
+}
+
+
+void SparseCoding::InitDictionary() {  
+  DataDependentRandomInitDictionary();
+}
+
+
+void SparseCoding::LoadDictionary(const char* dictionaryFilename) {  
+  matD.load(dictionaryFilename);
+}
+
+// almost always a poor choice!
+void SparseCoding::RandomInitDictionary() {
+  matD = randn(nDims, nAtoms);
+  for(u32 j = 0; j < nAtoms; j++) {
+    matD.col(j) /= norm(matD.col(j), 2);
+  }
+}
+
+// the sensible heuristic
+void SparseCoding::DataDependentRandomInitDictionary() {
+  matD = mat(nDims, nAtoms);
+  for(u32 j = 0; j < nAtoms; j++) {
+    vec vecD_j = matD.unsafe_col(j);
+    RandomAtom(vecD_j);
+  }
+}
+
+
+void SparseCoding::RandomAtom(vec& atom) {
+  atom.zeros();
+  const u32 nSeedAtoms = 3;
+  for(u32 i = 0; i < nSeedAtoms; i++) {
+    atom +=  matX.col(rand() % nPoints);
+  }
+  atom /= norm(atom, 2);
+}
+
+
+void SparseCoding::DoSparseCoding(u32 nIterations) {
+
+  bool converged = false;
+  double lastObjVal = 1e99;  
+
+  Log::Info << "Initial Coding Step" << endl;
+  OptimizeCode();
+  uvec adjacencies = find(matZ);
+  Log::Info << "\tSparsity level: " 
+	    << 100.0 * ((double)(adjacencies.n_elem)) 
+                     / ((double)(nAtoms * nPoints))
+	    << "%\n";
+  Log::Info << "\tObjective value: " << Objective() << endl;
+  
+  for(u32 t = 1; t <= nIterations && !converged; t++) {
+    Log::Info << "Iteration " << t << " of " << nIterations << endl;
+
+    Log::Info << "Dictionary Step\n";
+    OptimizeDictionary(adjacencies);
+    //ProjectDictionary(); // is this necessary? solutions to OptimizeDictionary should be feasible
+    Log::Info << "\tObjective value: " << Objective() << endl;
+
+    Log::Info << "Coding Step" << endl;
+    OptimizeCode();
+    adjacencies = find(matZ);
+    Log::Info << "\tSparsity level: " 
+	      << 100.0 * ((double)(adjacencies.n_elem)) 
+                       / ((double)(nAtoms * nPoints))
+	      << "%\n";
+    double curObjVal = Objective();
+    Log::Info << "\tObjective value: " << curObjVal << endl;
+    
+    double objValImprov = lastObjVal - curObjVal;
+    Log::Info << "\t\t\t\t\tImprovement: " << std::scientific
+	      <<  objValImprov << endl;
+    if(objValImprov < OBJ_TOL) {
+      converged = true;
+      Log::Info << "Converged within tolerance\n";
+    }
+    
+    lastObjVal = curObjVal;
+  }
+}
+
+
+void SparseCoding::OptimizeCode() {
+  mat matGram;
+  if(lambda2 > 0) {
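+    // the augmented Gram matrix D^T D + lambda2 * I is the Gram matrix of D
+    // with sqrt(lambda2) * I rows appended, which reduces the Elastic Net to
+    // a LASSO problem (Zou and Hastie, 2005)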
+    matGram = trans(matD) * matD + lambda2 * eye(nAtoms, nAtoms);
+  }
+  else {
+    matGram = trans(matD) * matD;
+  }
+  
+  
+  for(u32 i = 0; i < nPoints; i++) {
+    // report progress
+    if((i % 100) == 0) {
+      Log::Debug << "\t" << i << endl;
+    }
+    
+    // note: we no longer need to pass 0.5 * lambda, since we now use the
+    // standard objective, which already includes the 0.5 scaling of the
+    // quadratic terms; also, precomputing the Gram matrix and handing it to
+    // LARS has proven faster here than LARS's Cholesky mode, though this may
+    // change depending on the dimensionality and sparsity
+
+    bool useCholesky = false;
+    LARS* lars;
+    if(lambda2 > 0) {
+      lars = new LARS(useCholesky, lambda1, lambda2);
+    }
+    else {
+      lars = new LARS(useCholesky, lambda1);
+    }
+    lars->SetGram(matGram);
+    lars->DoLARS(matD, matX.unsafe_col(i));
+
+    vec beta;
+    lars->Solution(beta);
+    matZ.col(i) = beta;
+    delete lars;
+  }
+}
+
+
+void SparseCoding::OptimizeDictionary(uvec adjacencies) {
+  // count number of atomic neighbors for each point x^i
+  uvec neighborCounts = zeros<uvec>(nPoints, 1);
+  if(adjacencies.n_elem > 0) {
+    // adjacencies holds column-major unrolled indices into the
+    // nAtoms-by-nPoints matrix Z, so index / nAtoms gives the column
+    // (i.e., point) index
+    u32 curPointInd = (u32)(adjacencies(0) / nAtoms);
+    u32 curCount = 1;
+    for(u32 l = 1; l < adjacencies.n_elem; l++) {
+      if((u32)(adjacencies(l) / nAtoms) == curPointInd) {
+	curCount++;
+      }
+      else {
+	neighborCounts(curPointInd) = curCount;
+	curPointInd = (u32)(adjacencies(l) / nAtoms);
+	curCount = 1;
+      }
+    }
+    neighborCounts(curPointInd) = curCount;
+  }
+  
+  // handle the case of inactive atoms (atoms not used in the given coding)
+  std::vector<u32> inactiveAtoms;
+  std::vector<u32> activeAtoms;
+  activeAtoms.reserve(nAtoms);
+  for(u32 j = 0; j < nAtoms; j++) {
+    if(accu(matZ.row(j) != 0) == 0) {
+      inactiveAtoms.push_back(j);
+    }
+    else {
+      activeAtoms.push_back(j);
+    }
+  }
+  u32 nActiveAtoms = activeAtoms.size();
+  u32 nInactiveAtoms = inactiveAtoms.size();
+
+  // efficient construction of Z restricted to active atoms
+  mat matActiveZ;
+  if(inactiveAtoms.empty()) {
+    matActiveZ = matZ;
+  }
+  else {
+    uvec inactiveAtomsVec = conv_to< uvec >::from(inactiveAtoms);
+    RemoveRows(matZ, inactiveAtomsVec, matActiveZ);
+  }
+  
+  uvec atomReverseLookup = uvec(nAtoms);
+  for(u32 i = 0; i < nActiveAtoms; i++) {
+    atomReverseLookup(activeAtoms[i]) = i;
+  }
+
+  if(nInactiveAtoms > 0) {
+    Log::Info << "There are " << nInactiveAtoms << " inactive atoms. "
+              << "They will be re-initialized randomly.\n";
+  }
+
+  Log::Debug << "Solving Dual via Newton's Method\n";
+  
+  mat matDEstimate;
+  // Solve the dual problem using Newton's method. Note that the final
+  // multiplication with inv(A) seems to be unavoidable. Although more
+  // expensive, the code written this way (we use solve()) should be more
+  // numerically stable than just using inv(A) for everything.
+  vec dualVars = zeros<vec>(nActiveAtoms);
+  // other initializations that were tried:
+  //   vec dualVars = 1e-14 * ones<vec>(nActiveAtoms);
+  //   vec dualVars = 10.0 * randu(nActiveAtoms, 1); // method used by the
+  //     feature sign code - fails miserably here; perhaps the MATLAB
+  //     optimizer fmincon does something clever?
+  //   vec dualVars = diagvec(solve(matD, matX * trans(matZ)) - matZ * trans(matZ)),
+  //     with negative entries clamped to zero
+
+  bool converged = false;
+  mat matZXT = matActiveZ * trans(matX);
+  mat matZZT = matActiveZ * trans(matActiveZ);
+  for(u32 t = 1; !converged; t++) {
+    mat A = matZZT + diagmat(dualVars);
+    
+    mat matAInvZXT = solve(A, matZXT);
+    
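+    // gradient of the negated Lagrange dual: entry i is
+    // -( ||row i of A^{-1} Z X^T||^2 - 1 ) (cf. Lee et al., 2007)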
+    vec gradient = -( sum(square(matAInvZXT), 1) - ones<vec>(nActiveAtoms) );
+    
+    mat hessian = 
+      -( -2 * (matAInvZXT * trans(matAInvZXT)) % inv(A) );
+    
+    vec searchDirection = -solve(hessian, gradient);
+    //vec searchDirection = -gradient; // steepest descent alternative
+
+    // BEGIN ARMIJO LINE SEARCH: accept the step when
+    // f(x + alpha * d) <= f(x) + c * alpha * <gradient, d>;
+    // otherwise shrink alpha by the factor rho and try again
+    const double c = 1e-4;
+    double alpha = 1.0;
+    const double rho = 0.9;
+    double sufficientDecrease = c * dot(gradient, searchDirection);
+    // the objective at the current dualVars does not depend on alpha, so
+    // compute it once, outside of the backtracking loop
+    double sumDualVars = sum(dualVars);
+    double fOld = -( -trace(trans(matZXT) * matAInvZXT) - sumDualVars );
+
+    double improvement;
+    while(true) {
+      double fNew = 
+        -( -trace(trans(matZXT) * solve(matZZT + diagmat(dualVars + alpha * searchDirection), matZXT))
+           - (sumDualVars + alpha * sum(searchDirection)) );
+
+      if(fNew <= fOld + alpha * sufficientDecrease) {
+        searchDirection = alpha * searchDirection;
+        improvement = fOld - fNew;
+        break;
+      }
+      alpha *= rho;
+    }
+    // END ARMIJO LINE SEARCH
+    
+    dualVars += searchDirection;
+    double normGradient = norm(gradient, 2);
+    Log::Debug << "Newton Iteration " << t << ":" << endl;
+    Log::Debug << "\tnorm of gradient = " << std::scientific << normGradient << endl;
+    Log::Debug << "\timprovement = " << std::scientific << improvement << endl;
+
+    // one could instead declare convergence when normGradient < NEWTON_TOL
+    if(improvement < NEWTON_TOL) {
+      converged = true;
+    }
+  }
+  if(inactiveAtoms.empty()) {
+    matDEstimate = trans(solve(matZZT + diagmat(dualVars), matZXT));
+  }
+  else {
+    mat matDActiveEstimate = trans(solve(matZZT + diagmat(dualVars), matZXT));
+    matDEstimate = zeros(nDims, nAtoms);
+    for(u32 i = 0; i < nActiveAtoms; i++) {
+      matDEstimate.col(activeAtoms[i]) = matDActiveEstimate.col(i);
+    }
+    for(u32 i = 0; i < nInactiveAtoms; i++) {
+      vec vecmatDi = matDEstimate.unsafe_col(inactiveAtoms[i]);
+      RandomAtom(vecmatDi);
+    }
+  }
+  matD = matDEstimate;
+}
+
+
+void SparseCoding::ProjectDictionary() {
+  for(u32 j = 0; j < nAtoms; j++) {
+    double normD_j = norm(matD.col(j), 2);
+    // only renormalize if the atom is infeasible by more than rounding error
+    if(normD_j - 1.0 > 1e-9) {
+      Log::Warn << "Norm exceeded 1 by " << std::scientific << normD_j - 1.0
+                << "\n\tShrinking...\n";
+      matD.col(j) /= normD_j;
+    }
+  }
+}
+
+
+double SparseCoding::Objective() {
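+  // computes 0.5 ||X - D Z||_{Fro}^2 + lambda1 ||Z||_{1,1}
+  // (plus 0.5 lambda2 ||Z||_{Fro}^2 when lambda2 > 0)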
+  double l11NormZ = sum(sum(abs(matZ)));
+  double froNormResidual = norm(matX - matD * matZ, "fro");
+  if(lambda2 > 0) {
+    double froNormZ = norm(matZ, "fro");
+    return 
+      0.5 * (froNormResidual * froNormResidual + lambda2 * froNormZ * froNormZ)
+      + lambda1 * l11NormZ;
+  }
+  else {
+    return 0.5 * froNormResidual * froNormResidual + lambda1 * l11NormZ;
+  }
+}
+
+
+void SparseCoding::PrintDictionary() {
+  matD.print("Dictionary");
+}
+
+
+void SparseCoding::PrintCoding() {
+  matZ.print("Coding matrix");
+}
+
+
+void RemoveRows(const mat& X, uvec rows_to_remove, mat& X_mod) {
+
+  u32 n_cols = X.n_cols;
+  u32 n_rows = X.n_rows;
+  u32 n_to_remove = rows_to_remove.n_elem;
+  u32 n_to_keep = n_rows - n_to_remove;
+  
+  if(n_to_remove == 0) {
+    X_mod = X;
+  }
+  else {
+    X_mod.set_size(n_to_keep, n_cols);
+
+    u32 cur_row = 0;
+    u32 remove_ind = 0;
+    // first, copy the rows before the first row to remove
+    if(rows_to_remove(0) > 0) {
+      // note that this implies that n_rows > 1
+      u32 height = rows_to_remove(0);
+      X_mod(span(cur_row, cur_row + height - 1), span::all) =
+	X(span(0, rows_to_remove(0) - 1), span::all);
+      cur_row += height;
+    }
+    // then, copy the rows between each pair of consecutive rows to remove
+    while(remove_ind < n_to_remove - 1) {
+      u32 height = 
+	rows_to_remove[remove_ind + 1]
+	- rows_to_remove[remove_ind]
+	- 1;
+      if(height > 0) {
+	X_mod(span(cur_row, cur_row + height - 1), 
+	      span::all) =
+	  X(span(rows_to_remove[remove_ind] + 1,
+		 rows_to_remove[remove_ind + 1] - 1), 
+	    span::all);
+	cur_row += height;
+      }
+      remove_ind++;
+    }
+    // finally, copy the rows after the last row to remove
+    if(rows_to_remove[remove_ind] < n_rows - 1) {
+      X_mod(span(cur_row, n_to_keep - 1), 
+	    span::all) = 
+	X(span(rows_to_remove[remove_ind] + 1, n_rows - 1), 
+	  span::all);
+    }
+  }
+}
+
+
+}; // namespace sparse_coding
+}; // namespace mlpack

Added: mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding.hpp	2012-02-29 23:28:52 UTC (rev 11670)
@@ -0,0 +1,220 @@
+/** 
+ * @file sparse_coding.hpp
+ * @author Nishant Mehta
+ *
+ * Definition of the SparseCoding class, which performs l1 (LASSO) or 
+ * l1+l2 (Elastic Net)-regularized sparse coding with dictionary learning
+ *
+ *  @bug Could be lots, let's see!
+ */
+
+#ifndef __MLPACK_METHODS_SPARSE_CODING_SPARSE_CODING_HPP
+#define __MLPACK_METHODS_SPARSE_CODING_SPARSE_CODING_HPP
+
+//#include <armadillo>
+#include <mlpack/core.hpp>
+#include <mlpack/methods/lars/lars.hpp>
+
+namespace mlpack {
+namespace sparse_coding {
+
+/**
+ * An implementation of Sparse Coding with Dictionary Learning that achieves 
+ * sparsity via an l1-norm regularizer on the codes (LASSO) or an (l1+l2)-norm 
+ * regularizer on the codes (the Elastic Net).
+ * Let d be the number of dimensions in the original space, m the number of 
+ * training points, and k the number of atoms in the dictionary (the dimension 
+ * of the learned feature space). The training data X is a d-by-m matrix where 
+ * each column is a point and each row is a dimension. The dictionary D is a 
+ * d-by-k matrix, and the sparse codes matrix Z is a k-by-m matrix.
+ * This class seeks to minimize the objective:
+ * min_{D, Z} 0.5 ||X - D Z||_{Fro}^2 + lambda_1 sum_{i=1}^m ||Z_i||_1
+ *                                    + 0.5 lambda_2 sum_{i=1}^m ||Z_i||_2^2
+ * subject to ||D_j||_2 <= 1 for 1 <= j <= k,
+ * where typically lambda_1 > 0 and lambda_2 = 0.
+ *
+ * This problem is solved by an algorithm that alternates between a dictionary
+ * learning step and a sparse coding step. The dictionary learning step updates 
+ * the dictionary D using a Newton method based on the Lagrange dual (see the 
+ * paper below for details). The sparse coding step involves solving a large 
+ * number of sparse linear regression problems; this can be done efficiently 
+ * using LARS, an algorithm that can solve the LASSO or the Elastic Net (papers below).
+ *
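+ * A minimal usage sketch (the file name, atom count, iteration count, and
+ * regularization value below are illustrative, not prescribed):
+ *
+ * @code
+ * arma::mat matX;
+ * matX.load("data.csv"); // each column is one point
+ * SparseCoding sc(matX, 50, 0.1); // 50 atoms, lambda1 = 0.1, lambda2 = 0
+ * sc.InitDictionary();
+ * sc.DoSparseCoding(30); // run at most 30 alternating iterations
+ * const arma::mat& matD = sc.MatD(); // learned dictionary
+ * const arma::mat& matZ = sc.MatZ(); // learned sparse codes
+ * @endcode
+ *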
+ * Here are those papers:
+ *
+ * @code
+ * @incollection{lee2007efficient,
+ *   title = {Efficient sparse coding algorithms},
+ *   author = {Honglak Lee and Alexis Battle and Rajat Raina and Andrew Y. Ng},
+ *   booktitle = {Advances in Neural Information Processing Systems 19},
+ *   editor = {B. Sch\"{o}lkopf and J. Platt and T. Hoffman},
+ *   publisher = {MIT Press},
+ *   address = {Cambridge, MA},
+ *   pages = {801--808},
+ *   year = {2007}
+ * }
+ * @endcode
+ *
+ * @code
+ * @article{efron2004least,
+ *   title={Least angle regression},
+ *   author={Efron, B. and Hastie, T. and Johnstone, I. and Tibshirani, R.},
+ *   journal={The Annals of statistics},
+ *   volume={32},
+ *   number={2},
+ *   pages={407--499},
+ *   year={2004},
+ *   publisher={Institute of Mathematical Statistics}
+ * }
+ * @endcode
+ *
+ * @code
+ * @article{zou2005regularization,
+ *   title={Regularization and variable selection via the elastic net},
+ *   author={Zou, H. and Hastie, T.},
+ *   journal={Journal of the Royal Statistical Society Series B},
+ *   volume={67},
+ *   number={2},
+ *   pages={301--320},
+ *   year={2005},
+ *   publisher={Royal Statistical Society}
+ * }
+ * @endcode
+ */
+class SparseCoding {
+
+ public:
+  /**
+   * Set the parameters of SparseCoding. lambda2 defaults to 0.
+   *
+   * @param matX Data matrix (each column is a point)
+   * @param nAtoms Number of atoms in dictionary
+   * @param lambda1 Regularization parameter for the l1-norm penalty
+   * @param lambda2 Regularization parameter for the l2-norm penalty
+   */
+  SparseCoding(const arma::mat& matX, arma::u32 nAtoms, double lambda1,
+               double lambda2 = 0);
+  
+
+  /**
+   * Initialize the dictionary; the current heuristic is the data-dependent
+   * random initialization of DataDependentRandomInitDictionary()
+   */
+  void InitDictionary();
+  
+  /** 
+   * Load dictionary from a file
+   * 
+   * @param dictionaryFilename Filename containing dictionary
+   */
+  void LoadDictionary(const char* dictionaryFilename);
+  
+  /**
+   * Initialize dictionary by drawing k vectors uniformly at random from the 
+   * unit sphere
+   */
+  void RandomInitDictionary();
+
+  /**
+   * Initialize dictionary by initializing each atom to a normalized mixture of
+   * a small number of randomly selected points in X
+   */
+  void DataDependentRandomInitDictionary();
+
+  /**
+   * Initialize an atom to a normalized mixture of a small number of randomly
+   * selected points in X
+   *
+   * @param atom The atom to initialize
+   */
+  void RandomAtom(arma::vec& atom);
+
+  
+  // core algorithm functions
+
+  /**
+   * Run Sparse Coding with Dictionary Learning
+   *
+   * @param nIterations Maximum number of iterations to run algorithm
+   */
+  void DoSparseCoding(arma::u32 nIterations);
+
+  /**
+   * Sparse code each point via LARS
+   */
+  void OptimizeCode();
+  
+  /** 
+   * Learn dictionary via Newton method based on Lagrange dual
+   *
+   * @param adjacencies Indices of entries (unrolled column by column) of 
+   *    the coding matrix Z that are non-zero (the adjacency matrix for the 
+   *    bipartite graph of points and atoms)
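+   *    (for the k-by-m coding matrix Z, entry Z(j, i) has unrolled index
+   *    i * k + j, as produced by arma::find())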
+   */
+  void OptimizeDictionary(arma::uvec adjacencies);
+
+  /**
+   * Project each atom of the dictionary onto the unit ball
+   */
+  void ProjectDictionary();
+
+  /**
+   * Compute objective function
+   */
+  double Objective();
+
+
+  // accessors, modifiers, printers
+
+  //! Modifier for matD
+  void SetDictionary(const arma::mat& matD);
+  
+  //! Accessor for matD
+  const arma::mat& MatD() {
+    return matD;
+  }
+  
+  //! Accessor for matZ
+  const arma::mat& MatZ() {
+    return matZ;
+  }
+
+  //! Print the dictionary matD
+  void PrintDictionary();
+
+  //! Print the sparse codes matZ
+  void PrintCoding();
+
+ private:
+  arma::u32 nDims;
+  arma::u32 nAtoms;
+  arma::u32 nPoints;
+
+  // data (columns are points)
+  arma::mat matX;
+
+  // dictionary (columns are atoms)
+  arma::mat matD; 
+  
+  // sparse codes (columns are points)
+  arma::mat matZ; 
+  
+  // l1 regularization term
+  double lambda1; 
+  
+  // l2 regularization term
+  double lambda2; 
+  
+};
+
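+/**
+ * Utility function: sets X_mod to X with the rows indexed by rows_to_remove
+ * deleted. Assumes rows_to_remove is sorted in ascending order with no
+ * duplicates, as is the case for output of arma::find().
+ */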
+void RemoveRows(const arma::mat& X, arma::uvec rows_to_remove, arma::mat& X_mod);
+
+
+}; // namespace sparse_coding
+}; // namespace mlpack
+
+#endif

Added: mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding_main.cpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding_main.cpp	2012-02-29 23:28:52 UTC (rev 11670)
@@ -0,0 +1,108 @@
+/** @file sparse_coding_main.cpp
+ *  @author Nishant Mehta
+ *
+ *  Executable for Sparse Coding on MNIST
+ */
+
+#include <mlpack/core.hpp>
+#include "sparse_coding.hpp"
+
+
+
+PARAM_DOUBLE_REQ("lambda1", "Sparse coding l1-norm regularization parameter.", "l");
+PARAM_DOUBLE("lambda2", "Sparse coding l2-norm regularization parameter.", "", 0);
+
+PARAM_INT_REQ("n_atoms", "Number of atoms in the dictionary.", "k");
+
+PARAM_INT_REQ("n_iterations", "Number of iterations for sparse coding.", "");
+
+PARAM_STRING_REQ("data", "Path to the input data.", "");
+PARAM_STRING("initial_dictionary", "Filename for the initial dictionary.", "", "");
+PARAM_STRING("results_dir", "Directory for results.", "", "");
+
+
+
+using namespace arma;
+using namespace std;
+using namespace mlpack;
+using namespace mlpack::sparse_coding;
+
+int main(int argc, char* argv[]) {
+  CLI::ParseCommandLine(argc, argv);
+  
+  std::srand(time(NULL));
+  
+  double lambda1 = CLI::GetParam<double>("lambda1");
+  double lambda2 = CLI::GetParam<double>("lambda2");
+  
+  // if using fx-run, one could just leave results_dir blank
+  // (std::string copies avoid holding pointers into CLI's internal storage)
+  string resultsDir = CLI::GetParam<string>("results_dir");
+  string dataFullpath = CLI::GetParam<string>("data");
+  string initialDictionaryFullpath = CLI::GetParam<string>("initial_dictionary");
+  
+  size_t nIterations = CLI::GetParam<int>("n_iterations");
+  size_t nAtoms = CLI::GetParam<int>("n_atoms");
+  
+  mat matX;
+  matX.load(dataFullpath);
+  
+  u32 nPoints = matX.n_cols;
+  printf("Loaded %d points in %d dimensions\n", nPoints, matX.n_rows);
+
+  // normalize each point since these are images
+  for(u32 i = 0; i < nPoints; i++) {
+    matX.col(i) /= norm(matX.col(i), 2);
+  }
+  
+  // run Sparse Coding
+  SparseCoding sc(matX, nAtoms, lambda1, lambda2);
+  
+  if(initialDictionaryFullpath.empty()) {
+    sc.DataDependentRandomInitDictionary();
+  }
+  else {
+    mat matInitialD;
+    matInitialD.load(initialDictionaryFullpath);
+    if(matInitialD.n_cols != nAtoms) {
+      Log::Fatal << "The specified initial dictionary to load has " 
+		 << matInitialD.n_cols << " atoms, but the learned dictionary "
+		 << "was specified to have " << nAtoms << " atoms!\n";
+      return EXIT_FAILURE;
+    }
+    if(matInitialD.n_rows != matX.n_rows) {
+      Log::Fatal << "The specified initial dictionary to load has "
+		 << matInitialD.n_rows << " dimensions, but the specified data "
+		 << "has " << matX.n_rows << " dimensions!\n";
+      return EXIT_FAILURE;
+    }
+    sc.SetDictionary(matInitialD);
+  }
+  
+  
+  Timer::Start("sparse_coding");
+  sc.DoSparseCoding(nIterations);
+  Timer::Stop("sparse_coding"); 
+  
+  const mat& learnedD = sc.MatD();
+  const mat& learnedZ = sc.MatZ();
+
+  if(resultsDir.empty()) {
+    data::Save("D.csv", learnedD);
+    data::Save("Z.csv", learnedZ);
+  }
+  else {
+    data::Save(resultsDir + "/D.csv", learnedD);
+    data::Save(resultsDir + "/Z.csv", learnedZ);
+  }
+}



