[mlpack-svn] r11682 - in mlpack/trunk/src/mlpack/methods: . local_coordinate_coding

fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Mar 1 11:56:54 EST 2012


Author: niche
Date: 2012-03-01 11:56:53 -0500 (Thu, 01 Mar 2012)
New Revision: 11682

Added:
   mlpack/trunk/src/mlpack/methods/local_coordinate_coding/
   mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc.cpp
   mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc.hpp
   mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc_main.cpp
Log:
added local coordinate coding to methods

Added: mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc.cpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc.cpp	2012-03-01 16:56:53 UTC (rev 11682)
@@ -0,0 +1,377 @@
+/**
+ * @file lcc.cpp
+ * @author Nishant Mehta
+ *
+ * Implementation of Local Coordinate Coding
+ */
+
+#include "lcc.hpp"
+
+using namespace arma;
+using namespace std;
+using namespace mlpack::regression;
+using namespace mlpack::lcc;
+
+#define OBJ_TOL 1e-2 // 1E-9
+
+namespace mlpack {
+namespace lcc {
+
+LocalCoordinateCoding::LocalCoordinateCoding(const mat& matX, u32 nAtoms, double lambda) :
+  nDims(matX.n_rows),
+  nAtoms(nAtoms),
+  nPoints(matX.n_cols),
+  matX(matX),
+  matZ(mat(nAtoms, nPoints)),
+  lambda(lambda)
+{ /* nothing left to do */ }
+
+
+// void LocalCoordinateCoding::Init(const mat& matX, u32 nAtoms, double lambda) {
+//   this->matX = matX;
+
+//   nDims = matX.n_rows;
+//   nPoints = matX.n_cols;
+
+//   this->nAtoms = nAtoms;
+//   matD = mat(nDims, nAtoms);
+//   matZ = mat(nAtoms, nPoints);
+  
+//   this->lambda = lambda;
+// }
+
+
+void LocalCoordinateCoding::SetDictionary(const mat& matD) {
+  this->matD = matD;
+}
+
+
+void LocalCoordinateCoding::InitDictionary() {  
+  RandomInitDictionary();
+}
+
+
+void LocalCoordinateCoding::LoadDictionary(const char* dictionaryFilename) {  
+  matD.load(dictionaryFilename);
+}
+
+
+void LocalCoordinateCoding::RandomInitDictionary() {
+  matD = randn(nDims, nAtoms);
+  for(u32 j = 0; j < nAtoms; j++) {
+    matD.col(j) /= norm(matD.col(j), 2);
+  }
+}
+
+
+void LocalCoordinateCoding::DataDependentRandomInitDictionary() {
+  matD = mat(nDims, nAtoms);
+  for(u32 j = 0; j < nAtoms; j++) {
+    vec vecD_j = matD.unsafe_col(j);
+    RandomAtom(vecD_j);
+  }
+}
+
+
+void LocalCoordinateCoding::RandomAtom(vec& atom) {
+  atom.zeros();
+  const u32 nSeedAtoms = 3;
+  for(u32 i = 0; i < nSeedAtoms; i++) {
+    atom +=  matX.col(rand() % nPoints);
+  }
+  atom /= ((double) nSeedAtoms);
+  atom /= norm(atom, 2);
+}
+
+
+void LocalCoordinateCoding::DoLCC(u32 nIterations) {
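+  // Alternate between a sparse coding step and a dictionary learning step,
+  // stopping after nIterations iterations or as soon as the improvement in
+  // the objective falls below OBJ_TOL.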
+
+  bool converged = false;
+  double lastObjVal = 1e99;
+  
+  Log::Info << "Initial Coding Step" << endl;
+  OptimizeCode();
+  uvec adjacencies = find(matZ);
+  Log::Info << "\tSparsity level: " 
+	    << 100.0 * ((double)(adjacencies.n_elem)) 
+                     / ((double)(nAtoms * nPoints))
+	    << "%\n";
+  Log::Info << "\tObjective value: " << Objective(adjacencies) << endl;
+  
+  for(u32 t = 1; t <= nIterations && !converged; t++) {
+    Log::Info << "Iteration " << t << " of " << nIterations << endl;
+
+    Log::Info << "Dictionary Step\n";
+    OptimizeDictionary(adjacencies);
+    double dsObjVal = Objective(adjacencies);
+    Log::Info << "\tObjective value: " << Objective(adjacencies) << endl;
+    
+    Log::Info << "Coding Step" << endl;
+    OptimizeCode();
+    adjacencies = find(matZ);
+    Log::Info << "\tSparsity level: " 
+	      << 100.0 * ((double)(adjacencies.n_elem)) 
+                       / ((double)(nAtoms * nPoints))
+	      << "%\n";
+    double curObjVal = Objective(adjacencies);
+    Log::Info << "\tObjective value: " << curObjVal << endl;
+
+    if(curObjVal > dsObjVal) {
+      Log::Fatal << "Objective increased in sparse coding step!" << endl;
+    }
+    
+    double objValImprov = lastObjVal - curObjVal;
+    Log::Info << "\t\t\t\t\tImprovement: " << std::scientific
+	      <<  objValImprov << endl;
+    if(objValImprov < OBJ_TOL) {
+      converged = true;
+      Log::Info << "Converged within tolerance\n";
+    }
+    
+    lastObjVal = curObjVal;
+  }
+}
+
+
+void LocalCoordinateCoding::OptimizeCode() {
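+  // Each point x^i is coded by solving the weighted lasso problem
+  //   min_z  || x^i - D z ||_2^2 + lambda * sum_j w_j |z_j|,
+  // where w_j = || x^i - d_j ||_2^2.  The substitution beta_j = w_j * z_j
+  // turns this into a standard lasso problem with dictionary
+  // D' = D * diag(1 / w); that problem is solved with LARS, and the solution
+  // is rescaled to recover z_j = beta_j / w_j.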
+  mat matSqDists = 
+    repmat(trans(sum(square(matD))), 1, nPoints)
+    + repmat(sum(square(matX)), nAtoms, 1)
+    - 2 * trans(matD) * matX;			     
+  
+  mat matInvSqDists = 1.0 / matSqDists;
+  
+  mat matDTD = trans(matD) * matD;
+  mat matDPrimeTDPrime(matDTD.n_rows, matDTD.n_cols);
+  
+  for(u32 i = 0; i < nPoints; i++) {
+    // report progress
+    if((i % 100) == 0) {
+      Log::Debug << "\t" << i << endl;
+    }
+    
+    vec w = matSqDists.unsafe_col(i);
+    vec invW = matInvSqDists.unsafe_col(i);
+    mat matDPrime = matD * diagmat(invW);
+    
+    matDPrimeTDPrime = diagmat(invW) * matDTD * diagmat(invW);
+    
+    // Solve the lasso subproblem with LARS, using a precomputed Gram matrix
+    // and without the Cholesky decomposition; this was observed to be faster
+    // here than the Cholesky-based variant, though that may change with the
+    // dimensionality and sparsity. Note that the 0.5 * lambda factor is
+    // still required.
+
+    bool useCholesky = false;
+    LARS lars(useCholesky, 0.5 * lambda);
+    lars.SetGram(matDPrimeTDPrime);
+    
+    lars.DoLARS(matDPrime, matX.unsafe_col(i));
+    vec beta;
+    lars.Solution(beta);
+    matZ.col(i) = beta % invW;
+  }
+}
+
+
+void LocalCoordinateCoding::OptimizeDictionary(uvec adjacencies) {
+  // count number of atomic neighbors for each point x^i
+  uvec neighborCounts = zeros<uvec>(nPoints, 1);
+  if(adjacencies.n_elem > 0) {
+    // this gets the column index
+    u32 curPointInd = (u32)(adjacencies(0) / nAtoms);
+    u32 curCount = 1;
+    for(u32 l = 1; l < adjacencies.n_elem; l++) {
+      if((u32)(adjacencies(l) / nAtoms) == curPointInd) {
+	curCount++;
+      }
+      else {
+	neighborCounts(curPointInd) = curCount;
+	curPointInd = (u32)(adjacencies(l) / nAtoms);
+	curCount = 1;
+      }
+    }
+    neighborCounts(curPointInd) = curCount;
+  }
+  
+  // build matXPrime := [X x^1 ... x^1 ... x^n ... x^n]
+  // where each x^i is repeated for the number of neighbors x^i has
+  mat matXPrime = zeros(nDims, nPoints + adjacencies.n_elem);
+  matXPrime(span::all, span(0, nPoints - 1)) = matX;
+  u32 curCol = nPoints;
+  for(u32 i = 0; i < nPoints; i++) {
+    if(neighborCounts(i) > 0) {
+      matXPrime(span::all, span(curCol, curCol + neighborCounts(i) - 1)) =
+	repmat(matX.col(i), 1, neighborCounts(i));
+    }
+    curCol += neighborCounts(i);
+  }
+  
+  // handle the case of inactive atoms (atoms not used in the given coding)
+  std::vector<u32> inactiveAtoms;
+  std::vector<u32> activeAtoms;
+  activeAtoms.reserve(nAtoms);
+  for(u32 j = 0; j < nAtoms; j++) {
+    if(accu(matZ.row(j) != 0) == 0) {
+      inactiveAtoms.push_back(j);
+    }
+    else {
+      activeAtoms.push_back(j);
+    }
+  }
+  u32 nActiveAtoms = activeAtoms.size();
+  u32 nInactiveAtoms = inactiveAtoms.size();
+
+  // efficient construction of Z restricted to active atoms
+  mat matActiveZ;
+  if(inactiveAtoms.empty()) {
+    matActiveZ = matZ;
+  }
+  else {
+    uvec inactiveAtomsVec = conv_to< uvec >::from(inactiveAtoms);
+    RemoveRows(matZ, inactiveAtomsVec, matActiveZ);
+  }
+  
+  uvec atomReverseLookup = uvec(nAtoms);
+  for(u32 i = 0; i < nActiveAtoms; i++) {
+    atomReverseLookup(activeAtoms[i]) = i;
+  }
+
+
+  if(nInactiveAtoms > 0) {
+    Log::Info << "There are " << nInactiveAtoms << " inactive atoms. They will be re-initialized randomly.\n";
+  }
+  
+  mat matZPrime = zeros(nActiveAtoms, nPoints + adjacencies.n_elem);
+  //Log::Debug << "adjacencies.n_elem = " << adjacencies.n_elem << endl;
+  matZPrime(span::all, span(0, nPoints - 1)) = matActiveZ;
+  
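+  // The first nPoints weights are 1 (for the reconstruction term); the
+  // remaining weights, filled in below, become lambda * |Z(atom, point)| for
+  // each nonzero coding entry (the weighted regularization term).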
+  vec wSquared = ones(nPoints + adjacencies.n_elem, 1);
+  //Log::Debug << "building up matZPrime\n";
+  for(u32 l = 0; l < adjacencies.n_elem; l++) {
+    u32 atomInd = adjacencies(l) % nAtoms;
+    u32 pointInd = (u32) (adjacencies(l) / nAtoms);
+    matZPrime(atomReverseLookup(atomInd), nPoints + l) = 1.0;
+    wSquared(nPoints + l) = matZ(atomInd, pointInd); 
+  }
+  
+  wSquared.subvec(nPoints, wSquared.n_elem - 1) = 
+    lambda * abs(wSquared.subvec(nPoints, wSquared.n_elem - 1));
+  
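+  // The dictionary step is a weighted least-squares problem: with W denoting
+  // diagmat(wSquared), the D minimizing || (X' - D Z') * sqrt(W) ||_F^2
+  // satisfies the normal equations (Z' W Z'^T) D^T = Z' W X'^T, which is the
+  // linear system solved below.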
+  //Log::Debug << "about to solve\n";
+  mat matDEstimate;
+  if(inactiveAtoms.empty()) {
+    mat A = matZPrime * diagmat(wSquared) * trans(matZPrime);
+    mat B = matZPrime * diagmat(wSquared) * trans(matXPrime);
+    
+    //Log::Debug << "solving...\n";
+    matDEstimate = 
+      trans(solve(A, B));
+    /*    
+    matDEstimate = 
+      trans(solve(matZPrime * diagmat(wSquared) * trans(matZPrime),
+		  matZPrime * diagmat(wSquared) * trans(matXPrime)));
+    */
+  }
+  else {
+    matDEstimate = zeros(nDims, nAtoms);
+    //Log::Debug << "solving...\n";
+    mat matDActiveEstimate = 
+      trans(solve(matZPrime * diagmat(wSquared) * trans(matZPrime),
+		  matZPrime * diagmat(wSquared) * trans(matXPrime)));
+    for(u32 j = 0; j < nActiveAtoms; j++) {
+      matDEstimate.col(activeAtoms[j]) = matDActiveEstimate.col(j);
+    }
+    for(u32 j = 0; j < nInactiveAtoms; j++) {
+      vec vecD_j = matDEstimate.unsafe_col(inactiveAtoms[j]);
+      RandomAtom(vecD_j);
+      /*
+      vec new_atom = randn(nDims, 1);
+      matDEstimate.col(inactiveAtoms[i]) = 
+	new_atom / norm(new_atom, 2);
+      */
+    }
+  }
+  matD = matDEstimate;
+}
+// TODO: the function above still needs testing.
+
+
+double LocalCoordinateCoding::Objective(uvec adjacencies) {
+  double weightedL1NormZ = 0;
+  u32 nAdjacencies = adjacencies.n_elem;
+  for(u32 l = 0; l < nAdjacencies; l++) {
+    u32 atomInd = adjacencies(l) % nAtoms;
+    u32 pointInd = (u32) (adjacencies(l) / nAtoms);
+    weightedL1NormZ += fabs(matZ(atomInd, pointInd)) * as_scalar(sum(square(matD.col(atomInd) - matX.col(pointInd))));
+  }
+  double froNormResidual = norm(matX - matD * matZ, "fro");
+  return froNormResidual * froNormResidual + lambda * weightedL1NormZ;
+}
+
+
+void LocalCoordinateCoding::PrintDictionary() {
+  matD.print("Dictionary");
+}
+
+
+void LocalCoordinateCoding::PrintCoding() {
+  matZ.print("Coding matrix");
+}
+
+
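+// Copy X into X_mod with the rows listed in rows_to_remove omitted.
+// rows_to_remove is assumed to be sorted in increasing order; each contiguous
+// block of kept rows is copied with a single submatrix assignment.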
+void RemoveRows(const mat& X, uvec rows_to_remove, mat& X_mod) {
+
+  u32 n_cols = X.n_cols;
+  u32 n_rows = X.n_rows;
+  u32 n_to_remove = rows_to_remove.n_elem;
+  u32 n_to_keep = n_rows - n_to_remove;
+  
+  if(n_to_remove == 0) {
+    X_mod = X;
+  }
+  else {
+    X_mod.set_size(n_to_keep, n_cols);
+
+    u32 cur_row = 0;
+    u32 remove_ind = 0;
+    // first, check 0 to first row to remove
+    if(rows_to_remove(0) > 0) {
+      // note that this implies that n_rows > 1
+      u32 height = rows_to_remove(0);
+      X_mod(span(cur_row, cur_row + height - 1), span::all) =
+	X(span(0, rows_to_remove(0) - 1), span::all);
+      cur_row += height;
+    }
+    // now, check i'th row to remove to (i + 1)'th row to remove, until i = penultimate row
+    while(remove_ind < n_to_remove - 1) {
+      u32 height = 
+	rows_to_remove[remove_ind + 1]
+	- rows_to_remove[remove_ind]
+	- 1;
+      if(height > 0) {
+	X_mod(span(cur_row, cur_row + height - 1), 
+	      span::all) =
+	  X(span(rows_to_remove[remove_ind] + 1,
+		 rows_to_remove[remove_ind + 1] - 1), 
+	    span::all);
+	cur_row += height;
+      }
+      remove_ind++;
+    }
+    // now that i is last row to remove, check last row to remove to last row
+    if(rows_to_remove[remove_ind] < n_rows - 1) {
+      X_mod(span(cur_row, n_to_keep - 1), 
+	    span::all) = 
+	X(span(rows_to_remove[remove_ind] + 1, n_rows - 1), 
+	  span::all);
+    }
+  }
+}
+
+
+} // namespace lcc
+} // namespace mlpack

Added: mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc.hpp	2012-03-01 16:56:53 UTC (rev 11682)
@@ -0,0 +1,192 @@
+/** 
+ * @file lcc.hpp
+ * @author Nishant Mehta
+ *
+ * Definition of the LocalCoordinateCoding class, which performs the Local 
+ * Coordinate Coding algorithm.
+ */
+
+#ifndef __MLPACK_METHODS_LCC_LCC_HPP
+#define __MLPACK_METHODS_LCC_LCC_HPP
+
+//#include <armadillo>
+#include <mlpack/core.hpp>
+#include <mlpack/methods/lars/lars.hpp>
+
+namespace mlpack {
+namespace lcc {
+
+/**
+ * An implementation of Local Coordinate Coding (LCC) that codes data which 
+ * approximately lives on a manifold using a variation of l1-norm regularized 
+ * sparse coding; in LCC, the penalty on the absolute value of each point's 
+ * coefficient for each atom is weighted by the squared distance of that point 
+ * to that atom (see the references below).
+ *
+ * Let d be the number of dimensions in the original space, m the number of 
+ * training points, and k the number of atoms in the dictionary (the dimension 
+ * of the learned feature space). The training data X is a d-by-m matrix where 
+ * each column is a point and each row is a dimension. The dictionary D is a 
+ * d-by-k matrix, and the sparse codes matrix Z is a k-by-m matrix.
+ *
+ * The algorithm seeks to minimize the objective
+ *
+ *   min_{D,Z} ||X - D Z||_{Fro}^2
+ *             + lambda * sum_{i=1}^m sum_{j=1}^k dist(X_i, D_j)^2 |Z_i^j|
+ *
+ * where lambda > 0.
+ *
+ * This problem is solved by an algorithm that alternates between a dictionary
+ * learning step and a sparse coding step. The dictionary learning step updates
+ * the dictionary D by solving a linear system (note that the objective is a
+ * positive definite quadratic program). The sparse coding step involves 
+ * solving a large number of weighted l1-norm regularized linear regression 
+ * problems; this can be done efficiently using LARS, an algorithm 
+ * that can solve the LASSO (paper below).
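+ *
+ * For reference, for fixed D the coding step decomposes into one weighted
+ * lasso problem per point X_i (a restatement of the objective above):
+ *
+ *   min_{Z_i} ||X_i - D Z_i||_2^2 + lambda sum_{j=1}^k dist(X_i, D_j)^2 |Z_i^j|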
+ *
+ * The papers:
+ *
+ * @code
+ * @incollection{NIPS2009_0719,
+ *   title = {Nonlinear Learning using Local Coordinate Coding},
+ *   author = {Kai Yu and Tong Zhang and Yihong Gong},
+ *   booktitle = {Advances in Neural Information Processing Systems 22},
+ *   editor = {Y. Bengio and D. Schuurmans and J. Lafferty and C. K. I. Williams and A. Culotta},
+ *   pages = {2223--2231},
+ *   year = {2009}
+ * }
+ * @endcode
+ *
+ * @code
+ * @article{efron2004least,
+ *   title={Least angle regression},
+ *   author={Efron, B. and Hastie, T. and Johnstone, I. and Tibshirani, R.},
+ *   journal={The Annals of statistics},
+ *   volume={32},
+ *   number={2},
+ *   pages={407--499},
+ *   year={2004},
+ *   publisher={Institute of Mathematical Statistics}
+ * }
+ * @endcode
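+ *
+ * A sketch of typical usage (the file name and the parameter values here are
+ * purely illustrative):
+ *
+ * @code
+ * arma::mat matX;
+ * matX.load("data.csv");
+ *
+ * // 50 atoms, regularization parameter 0.1, at most 15 iterations
+ * mlpack::lcc::LocalCoordinateCoding lcc(matX, 50, 0.1);
+ * lcc.DataDependentRandomInitDictionary();
+ * lcc.DoLCC(15);
+ *
+ * arma::mat learnedD = lcc.MatD();
+ * arma::mat learnedZ = lcc.MatZ();
+ * @endcode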
+ */
+class LocalCoordinateCoding {
+
+ public:
+  //void Init(const arma::mat& matX, u32 nAtoms, double lambda);
+  
+  
+  /**
+   * Set the parameters to LocalCoordinateCoding.
+   *
+   * @param matX Data matrix
+   * @param nAtoms Number of atoms in dictionary
+   * @param lambda Regularization parameter for weighted l1-norm penalty
+   */
+  LocalCoordinateCoding(const arma::mat& matX, arma::u32 nAtoms, double lambda);
+  
+  /**
+   * Initialize the dictionary; currently, this draws each atom uniformly at
+   * random from the unit sphere (see RandomInitDictionary()).
+   */
+  void InitDictionary();
+  
+  /** 
+   * Load dictionary from a file
+   * 
+   * @param dictionaryFilename Filename containing dictionary
+   */
+  void LoadDictionary(const char* dictionaryFilename);
+
+  /**
+   * Initialize dictionary by drawing k vectors uniformly at random from the 
+   * unit sphere
+   */  
+  void RandomInitDictionary();
+
+  /**
+   * Initialize dictionary by initializing each atom to a normalized mixture of
+   * a small number of randomly selected points in X
+   */
+  void DataDependentRandomInitDictionary();
+
+  /**
+   * Initialize an atom to a normalized mixture of a small number of randomly
+   * selected points in X
+   *
+   * @param atom The atom to initialize
+   */
+  void RandomAtom(arma::vec& atom);
+
+
+  // core algorithm functions
+  
+  /**
+   * Run LCC
+   *
+   * @param nIterations Maximum number of iterations to run algorithm
+   */
+  void DoLCC(arma::u32 nIterations);
+
+  /**
+   * Sparse code each point via distance-weighted LARS
+   */
+  void OptimizeCode();
+  
+  /** 
+   * Learn the dictionary by solving a linear system.
+   *
+   * @param adjacencies Indices of entries (unrolled column by column) of 
+   *    the coding matrix Z that are non-zero (the adjacency matrix for the 
+   *    bipartite graph of points and atoms)
+   */
+  void OptimizeDictionary(arma::uvec adjacencies);
+
+  /**
+   * Compute objective function
+   */  
+  double Objective(arma::uvec adjacencies);
+
+
+  // accessors, modifiers, printers
+  
+  //! Modifier for matD
+  void SetDictionary(const arma::mat& matD);
+
+  //! Accessor for matD
+  const arma::mat& MatD() {
+    return matD;
+  }
+
+  //! Accessor for matZ
+  const arma::mat& MatZ() {
+    return matZ;
+  }
+
+  //! Print the dictionary matD.
+  void PrintDictionary();
+
+  //! Print the sparse codes matZ.
+  void PrintCoding();
+
+
+ private:
+  arma::u32 nDims;
+  arma::u32 nAtoms;
+  arma::u32 nPoints;
+
+  // data (columns are points)
+  arma::mat matX;
+
+  // dictionary (columns are atoms)
+  arma::mat matD;
+
+  // sparse codes (columns are points)
+  arma::mat matZ; 
+  
+  // regularization parameter for the weighted l1-norm penalty
+  double lambda;
+  
+};
+
+void RemoveRows(const arma::mat& X, arma::uvec rows_to_remove, arma::mat& X_mod);
+
+
+} // namespace lcc
+} // namespace mlpack
+
+#endif

Added: mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc_main.cpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc_main.cpp	2012-03-01 16:56:53 UTC (rev 11682)
@@ -0,0 +1,106 @@
+/** @file lcc_main.cpp
+ *  @author Nishant Mehta
+ *
+ *  Executable for Local Coordinate Coding
+ */
+
+#include <mlpack/core.hpp>
+#include "lcc.hpp"
+
+PROGRAM_INFO("LCC", "An implementation of Local Coordinate Coding");
+
+PARAM_DOUBLE_REQ("lambda", "weighted l1-norm regularization parameter.", "l");
+
+PARAM_INT_REQ("n_atoms", "number of atoms in dictionary.", "k");
+
+PARAM_INT_REQ("n_iterations", "number of iterations for sparse coding.", "");
+
+PARAM_INT_REQ("digit1", "digit for first class.", "");
+PARAM_INT_REQ("digit2", "digit for second class.", "");
+
+PARAM_STRING_REQ("data", "path to the input data.", "");
+PARAM_STRING("initial_dictionary", "Filename for initial dictionary.", "", "");
+PARAM_STRING("results_dir", "Directory for results.", "", "");
+
+
+
+using namespace arma;
+using namespace std;
+using namespace mlpack;
+using namespace mlpack::lcc;
+
+int main(int argc, char* argv[]) {
+  CLI::ParseCommandLine(argc, argv);
+  
+  double lambda = CLI::GetParam<double>("lambda");
+  
+  // if results_dir is left blank (e.g. when using fx-run), results are saved
+  // to the current directory
+  const char* resultsDir = CLI::GetParam<string>("results_dir").c_str();
+  
+  const char* dataFullpath = CLI::GetParam<string>("data").c_str();
+  
+  const char* initialDictionaryFullpath = CLI::GetParam<string>("initial_dictionary").c_str();
+  
+  size_t nIterations = CLI::GetParam<int>("n_iterations");
+
+  size_t nAtoms = CLI::GetParam<int>("n_atoms");
+  
+  mat matX;
+  matX.load(dataFullpath);
+  
+  u32 nPoints = matX.n_cols;
+
+  // normalize each point since these are images
+  for(u32 i = 0; i < nPoints; i++) {
+    matX.col(i) /= norm(matX.col(i), 2);
+  }
+  
+  // run Local Coordinate Coding
+  LocalCoordinateCoding lcc(matX, nAtoms, lambda);
+  
+  if(strlen(initialDictionaryFullpath) == 0) {
+    lcc.DataDependentRandomInitDictionary();
+  }
+  else {
+    mat matInitialD;
+    matInitialD.load(initialDictionaryFullpath);
+    if(matInitialD.n_cols != nAtoms) {
+      Log::Fatal << "The specified initial dictionary to load has " 
+		 << matInitialD.n_cols << " atoms, but the learned dictionary "
+		 << "was specified to have " << nAtoms << " atoms!\n";
+      return EXIT_FAILURE;
+    }
+    if(matInitialD.n_rows != matX.n_rows) {
+      Log::Fatal << "The specified initial dictionary to load has "
+		 << matInitialD.n_rows << " dimensions, but the specified data "
+		 << "has " << matX.n_rows << " dimensions!\n";
+      return EXIT_FAILURE;
+    }
+    lcc.SetDictionary(matInitialD);
+  }
+  
+  
+  Timer::Start("lcc");
+  lcc.DoLCC(nIterations);
+  Timer::Stop("lcc");
+  
+  mat learnedD = lcc.MatD();
+  
+  mat learnedZ = lcc.MatZ();
+  
+  if(strlen(resultsDir) == 0) {
+    data::Save("D.csv", learnedD);
+    data::Save("Z.csv", learnedZ);
+  }
+  else {
+    // build the output paths under resultsDir
+    string dictionaryFullpath = string(resultsDir) + "/D.csv";
+    data::Save(dictionaryFullpath, learnedD);
+
+    string codesFullpath = string(resultsDir) + "/Z.csv";
+    data::Save(codesFullpath, learnedZ);
+  }
+}



