[mlpack-svn] r13118 - mlpack/trunk/src/mlpack/methods/local_coordinate_coding

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Jun 27 15:56:53 EDT 2012


Author: rcurtin
Date: 2012-06-27 15:56:53 -0400 (Wed, 27 Jun 2012)
New Revision: 13118

Added:
   mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc_impl.hpp
Removed:
   mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc.cpp
Modified:
   mlpack/trunk/src/mlpack/methods/local_coordinate_coding/CMakeLists.txt
   mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc.hpp
   mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc_main.cpp
Log:
Template LocalCoordinateCoding to accept different kinds of dictionary
initialization.  Loading a dictionary from a file is temporarily broken in
the LCC executable, but this will be fixed soon.


Modified: mlpack/trunk/src/mlpack/methods/local_coordinate_coding/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/local_coordinate_coding/CMakeLists.txt	2012-06-27 18:39:40 UTC (rev 13117)
+++ mlpack/trunk/src/mlpack/methods/local_coordinate_coding/CMakeLists.txt	2012-06-27 19:56:53 UTC (rev 13118)
@@ -7,7 +7,7 @@
 # that you have files in both sections
 set(SOURCES
    lcc.hpp
-   lcc.cpp
+   lcc_impl.hpp
 )
 
 # add directory name to sources

Deleted: mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc.cpp	2012-06-27 18:39:40 UTC (rev 13117)
+++ mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc.cpp	2012-06-27 19:56:53 UTC (rev 13118)
@@ -1,391 +0,0 @@
-/**
- * @file lcc.cpp
- * @author Nishant Mehta
- *
- * Implementation of Local Coordinate Coding
- */
-
-#include "lcc.hpp"
-
-using namespace arma;
-using namespace std;
-using namespace mlpack::regression;
-using namespace mlpack::lcc;
-
-#define OBJ_TOL 1e-2 // 1E-9
-
-namespace mlpack {
-namespace lcc {
-
-LocalCoordinateCoding::LocalCoordinateCoding(const mat& matX,
-                                             uword nAtoms,
-                                             double lambda) :
-  nDims(matX.n_rows),
-  nAtoms(nAtoms),
-  nPoints(matX.n_cols),
-  matX(matX),
-  matZ(mat(nAtoms, nPoints)),
-  lambda(lambda)
-{ /* nothing left to do */ }
-
-
-void LocalCoordinateCoding::SetDictionary(const mat& matD)
-{
-  this->matD = matD;
-}
-
-
-void LocalCoordinateCoding::InitDictionary()
-{
-  RandomInitDictionary();
-}
-
-
-void LocalCoordinateCoding::LoadDictionary(const char* dictionaryFilename)
-{
-  matD.load(dictionaryFilename);
-}
-
-
-void LocalCoordinateCoding::RandomInitDictionary()
-{
-  matD = randn(nDims, nAtoms);
-  for (uword j = 0; j < nAtoms; j++)
-    matD.col(j) /= norm(matD.col(j), 2);
-}
-
-
-void LocalCoordinateCoding::DataDependentRandomInitDictionary()
-{
-  matD = mat(nDims, nAtoms);
-  for (uword j = 0; j < nAtoms; j++)
-  {
-    vec vecD_j = matD.unsafe_col(j);
-    RandomAtom(vecD_j);
-  }
-}
-
-
-void LocalCoordinateCoding::RandomAtom(vec& atom)
-{
-  atom.zeros();
-  const uword nSeedAtoms = 3;
-  for (uword i = 0; i < nSeedAtoms; i++)
-    atom += matX.col(rand() % nPoints);
-
-  atom /= ((double) nSeedAtoms);
-  atom /= norm(atom, 2);
-}
-
-void LocalCoordinateCoding::DoLCC(uword nIterations)
-{
-  bool converged = false;
-  double lastObjVal = 1e99;
-
-  Log::Info << "Initial Coding Step" << endl;
-  OptimizeCode();
-  uvec adjacencies = find(matZ);
-  Log::Info << "\tSparsity level: " << 100.0 * ((double)(adjacencies.n_elem)) /
-      ((double)(nAtoms * nPoints)) << "%\n";
-  Log::Info << "\tObjective value: " << Objective(adjacencies) << endl;
-
-  for (uword t = 1; t <= nIterations && !converged; t++)
-  {
-    Log::Info << "Iteration " << t << " of " << nIterations << endl;
-
-    Log::Info << "Dictionary Step\n";
-    OptimizeDictionary(adjacencies);
-    double dsObjVal = Objective(adjacencies);
-    Log::Info << "\tObjective value: " << Objective(adjacencies) << endl;
-
-    Log::Info << "Coding Step" << endl;
-    OptimizeCode();
-    adjacencies = find(matZ);
-    Log::Info << "\tSparsity level: " << 100.0 * ((double)(adjacencies.n_elem))
-        / ((double)(nAtoms * nPoints)) << "%\n";
-    double curObjVal = Objective(adjacencies);
-    Log::Info << "\tObjective value: " << curObjVal << endl;
-
-    if (curObjVal > dsObjVal)
-    {
-      Log::Fatal << "Objective increased in sparse coding step!" << endl;
-    }
-
-    double objValImprov = lastObjVal - curObjVal;
-    Log::Info << "\t\t\t\t\tImprovement: " << std::scientific << objValImprov
-        << endl;
-
-    if (objValImprov < OBJ_TOL)
-    {
-      converged = true;
-      Log::Info << "Converged within tolerance\n";
-    }
-
-    lastObjVal = curObjVal;
-  }
-}
-
-void LocalCoordinateCoding::OptimizeCode()
-{
-  mat matSqDists = repmat(trans(sum(square(matD))), 1, nPoints) +
-      repmat(sum(square(matX)), nAtoms, 1) - 2 * trans(matD) * matX;
-
-  mat matInvSqDists = 1.0 / matSqDists;
-
-  mat matDTD = trans(matD) * matD;
-  mat matDPrimeTDPrime(matDTD.n_rows, matDTD.n_cols);
-
-  for (uword i = 0; i < nPoints; i++)
-  {
-    // report progress
-    if ((i % 100) == 0)
-    {
-      Log::Debug << "\t" << i << endl;
-    }
-
-    vec w = matSqDists.unsafe_col(i);
-    vec invW = matInvSqDists.unsafe_col(i);
-    mat matDPrime = matD * diagmat(invW);
-
-    mat matDPrimeTDPrime = diagmat(invW) * matDTD * diagmat(invW);
-
-    //LARS lars;
-    // do we still need 0.5 * lambda? yes, yes we do
-    //lars.Init(matDPrime.memptr(), matX.colptr(i), nDims, nAtoms, true, 0.5 *
-    //lambda); // apparently not as fast as using the below duo
-    // this may change, depending on the dimensionality and sparsity
-
-    // the duo
-    /* lars.Init(matDPrime.memptr(), matX.colptr(i), nDims, nAtoms, false, 0.5 *
-     * lambda); */
-    /* lars.SetGram(matDPrimeTDPrime.memptr(), nAtoms); */
-
-    bool useCholesky = false;
-    LARS lars(useCholesky, matDPrimeTDPrime, 0.5 * lambda);
-
-    vec beta;
-    lars.Regress(matDPrime, matX.unsafe_col(i), beta, true);
-    matZ.col(i) = beta % invW;
-  }
-}
-
-void LocalCoordinateCoding::OptimizeDictionary(uvec adjacencies)
-{
-  // count number of atomic neighbors for each point x^i
-  uvec neighborCounts = zeros<uvec>(nPoints, 1);
-  if (adjacencies.n_elem > 0)
-  {
-    // this gets the column index
-    uword curPointInd = (uword)(adjacencies(0) / nAtoms);
-    uword curCount = 1;
-    for (uword l = 1; l < adjacencies.n_elem; l++)
-    {
-      if ((uword) (adjacencies(l) / nAtoms) == curPointInd)
-      {
-        curCount++;
-      }
-      else
-      {
-        neighborCounts(curPointInd) = curCount;
-        curPointInd = (uword)(adjacencies(l) / nAtoms);
-        curCount = 1;
-      }
-    }
-    neighborCounts(curPointInd) = curCount;
-  }
-
-  // build matXPrime := [X x^1 ... x^1 ... x^n ... x^n]
-  // where each x^i is repeated for the number of neighbors x^i has
-  mat matXPrime = zeros(nDims, nPoints + adjacencies.n_elem);
-  matXPrime(span::all, span(0, nPoints - 1)) = matX;
-  uword curCol = nPoints;
-  for (uword i = 0; i < nPoints; i++)
-  {
-    if (neighborCounts(i) > 0)
-    {
-      matXPrime(span::all, span(curCol, curCol + neighborCounts(i) - 1)) =
-          repmat(matX.col(i), 1, neighborCounts(i));
-    }
-    curCol += neighborCounts(i);
-  }
-
-  // handle the case of inactive atoms (atoms not used in the given coding)
-  std::vector<uword> inactiveAtoms;
-  std::vector<uword> activeAtoms;
-  activeAtoms.reserve(nAtoms);
-  for (uword j = 0; j < nAtoms; j++)
-  {
-    if (accu(matZ.row(j) != 0) == 0)
-    {
-      inactiveAtoms.push_back(j);
-    }
-    else
-    {
-      activeAtoms.push_back(j);
-    }
-  }
-  uword nActiveAtoms = activeAtoms.size();
-  uword nInactiveAtoms = inactiveAtoms.size();
-
-  // efficient construction of Z restricted to active atoms
-  mat matActiveZ;
-  if (inactiveAtoms.empty())
-  {
-    matActiveZ = matZ;
-  }
-  else
-  {
-    uvec inactiveAtomsVec = conv_to<uvec>::from(inactiveAtoms);
-    RemoveRows(matZ, inactiveAtomsVec, matActiveZ);
-  }
-
-  uvec atomReverseLookup = uvec(nAtoms);
-  for (uword i = 0; i < nActiveAtoms; i++)
-  {
-    atomReverseLookup(activeAtoms[i]) = i;
-  }
-
-
-  if (nInactiveAtoms > 0)
-  {
-    Log::Info << "There are " << nInactiveAtoms << " inactive atoms. They will"
-        << " be re-initialized randomly.\n";
-  }
-
-  mat matZPrime = zeros(nActiveAtoms, nPoints + adjacencies.n_elem);
-  //Log::Debug << "adjacencies.n_elem = " << adjacencies.n_elem << endl;
-  matZPrime(span::all, span(0, nPoints - 1)) = matActiveZ;
-
-  vec wSquared = ones(nPoints + adjacencies.n_elem, 1);
-  //Log::Debug << "building up matZPrime\n";
-  for (uword l = 0; l < adjacencies.n_elem; l++)
-  {
-    uword atomInd = adjacencies(l) % nAtoms;
-    uword pointInd = (uword) (adjacencies(l) / nAtoms);
-    matZPrime(atomReverseLookup(atomInd), nPoints + l) = 1.0;
-    wSquared(nPoints + l) = matZ(atomInd, pointInd);
-  }
-
-  wSquared.subvec(nPoints, wSquared.n_elem - 1) = lambda *
-      abs(wSquared.subvec(nPoints, wSquared.n_elem - 1));
-
-  //Log::Debug << "about to solve\n";
-  mat matDEstimate;
-  if (inactiveAtoms.empty())
-  {
-    mat A = matZPrime * diagmat(wSquared) * trans(matZPrime);
-    mat B = matZPrime * diagmat(wSquared) * trans(matXPrime);
-
-    //Log::Debug << "solving...\n";
-    matDEstimate =
-      trans(solve(A, B));
-    /*
-    matDEstimate =
-      trans(solve(matZPrime * diagmat(wSquared) * trans(matZPrime),
-                  matZPrime * diagmat(wSquared) * trans(matXPrime)));
-    */
-  }
-  else
-  {
-    matDEstimate = zeros(nDims, nAtoms);
-    //Log::Debug << "solving...\n";
-    mat matDActiveEstimate =
-      trans(solve(matZPrime * diagmat(wSquared) * trans(matZPrime),
-                  matZPrime * diagmat(wSquared) * trans(matXPrime)));
-    for (uword j = 0; j < nActiveAtoms; j++)
-    {
-      matDEstimate.col(activeAtoms[j]) = matDActiveEstimate.col(j);
-    }
-
-    for (uword j = 0; j < nInactiveAtoms; j++)
-    {
-      vec vecD_j = matDEstimate.unsafe_col(inactiveAtoms[j]);
-      RandomAtom(vecD_j);
-      /*
-      vec new_atom = randn(nDims, 1);
-      matDEstimate.col(inactiveAtoms[i]) = new_atom / norm(new_atom, 2);
-      */
-    }
-  }
-
-  matD = matDEstimate;
-}
-
-double LocalCoordinateCoding::Objective(uvec adjacencies)
-{
-  double weightedL1NormZ = 0;
-  uword nAdjacencies = adjacencies.n_elem;
-  for (uword l = 0; l < nAdjacencies; l++)
-  {
-    uword atomInd = adjacencies(l) % nAtoms;
-    uword pointInd = (uword) (adjacencies(l) / nAtoms);
-    weightedL1NormZ += fabs(matZ(atomInd, pointInd)) *
-        as_scalar(sum(square(matD.col(atomInd) - matX.col(pointInd))));
-  }
-
-  double froNormResidual = norm(matX - matD * matZ, "fro");
-  return froNormResidual * froNormResidual + lambda * weightedL1NormZ;
-}
-
-void LocalCoordinateCoding::PrintDictionary()
-{
-  matD.print("Dictionary");
-}
-
-void LocalCoordinateCoding::PrintCoding()
-{
-  matZ.print("Coding matrix");
-}
-
-void RemoveRows(const mat& X, uvec rows_to_remove, mat& X_mod)
-{
-  uword n_cols = X.n_cols;
-  uword n_rows = X.n_rows;
-  uword n_to_remove = rows_to_remove.n_elem;
-  uword n_to_keep = n_rows - n_to_remove;
-
-  if (n_to_remove == 0)
-  {
-    X_mod = X;
-  }
-  else
-  {
-    X_mod.set_size(n_to_keep, n_cols);
-
-    uword cur_row = 0;
-    uword remove_ind = 0;
-    // first, check 0 to first row to remove
-    if (rows_to_remove(0) > 0)
-    {
-      // note that this implies that n_rows > 1
-      uword height = rows_to_remove(0);
-      X_mod(span(cur_row, cur_row + height - 1), span::all) =
-          X(span(0, rows_to_remove(0) - 1), span::all);
-      cur_row += height;
-    }
-    // now, check i'th row to remove to (i + 1)'th row to remove, until i =
-    // penultimate row
-    while (remove_ind < n_to_remove - 1)
-    {
-      uword height = rows_to_remove[remove_ind + 1] - rows_to_remove[remove_ind]
-          - 1;
-      if (height > 0)
-      {
-        X_mod(span(cur_row, cur_row + height - 1), span::all) =
-            X(span(rows_to_remove[remove_ind] + 1,
-            rows_to_remove[remove_ind + 1] - 1), span::all);
-        cur_row += height;
-      }
-      remove_ind++;
-    }
-    // now that i is last row to remove, check last row to remove to last row
-    if (rows_to_remove[remove_ind] < n_rows - 1)
-    {
-      X_mod(span(cur_row, n_to_keep - 1), span::all) =
-          X(span(rows_to_remove[remove_ind] + 1, n_rows - 1), span::all);
-    }
-  }
-}
-
-}; // namespace lcc
-}; // namespace mlpack

Modified: mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc.hpp	2012-06-27 18:39:40 UTC (rev 13117)
+++ mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc.hpp	2012-06-27 19:56:53 UTC (rev 13118)
@@ -5,13 +5,16 @@
  * Definition of the LocalCoordinateCoding class, which performs the Local
  * Coordinate Coding algorithm.
  */
-#ifndef __MLPACK_METHODS_LCC_LCC_HPP
-#define __MLPACK_METHODS_LCC_LCC_HPP
+#ifndef __MLPACK_METHODS_LOCAL_COORDINATE_CODING_LCC_HPP
+#define __MLPACK_METHODS_LOCAL_COORDINATE_CODING_LCC_HPP
 
 #include <mlpack/core.hpp>
 #include <mlpack/methods/lars/lars.hpp>
 
-// Include three simple dictionary initializers from
+// Include three simple dictionary initializers from sparse coding.
+#include "../sparse_coding/nothing_initializer.hpp"
+#include "../sparse_coding/data_dependent_random_initializer.hpp"
+#include "../sparse_coding/random_initializer.hpp"
 
 namespace mlpack {
 namespace lcc {
@@ -67,6 +70,8 @@
  * }
  * @endcode
  */
+template<typename DictionaryInitializer =
+    sparse_coding::DataDependentRandomInitializer>
 class LocalCoordinateCoding
 {
  public:
@@ -77,45 +82,11 @@
    * @param nAtoms Number of atoms in dictionary
    * @param lambda Regularization parameter for weighted l1-norm penalty
    */
-  LocalCoordinateCoding(const arma::mat& matX, arma::uword nAtoms,
-      double lambda);
+  LocalCoordinateCoding(const arma::mat& matX,
+                        arma::uword nAtoms,
+                        double lambda);
 
   /**
-   * Initialize dictionary somehow.
-   */
-  void InitDictionary();
-
-  /**
-   * Load dictionary from a file
-   *
-   * @param dictionaryFilename Filename containing dictionary
-   */
-  void LoadDictionary(const char* dictionaryFilename);
-
-  /**
-   * Initialize dictionary by drawing k vectors uniformly at random from the
-   * unit sphere
-   */
-  void RandomInitDictionary();
-
-  /**
-   * Initialize dictionary by initializing each atom to a normalized mixture of
-   * a small number of randomly selected points in X
-   */
-  void DataDependentRandomInitDictionary();
-
-  /**
-   * Initialize an atom to a normalized mixture of a small number of randomly
-   * selected points in X
-   *
-   * @param atom The atom to initialize
-   */
-  void RandomAtom(arma::vec& atom);
-
-
-  // core algorithm functions
-
-  /**
    * Run LCC
    *
    * @param nIterations Maximum number of iterations to run algorithm
@@ -185,4 +156,7 @@
 }; // namespace lcc
 }; // namespace mlpack
 
+// Include implementation.
+#include "lcc_impl.hpp"
+
 #endif

Copied: mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc_impl.hpp (from rev 13116, mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc.cpp)
===================================================================
--- mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc_impl.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc_impl.hpp	2012-06-27 19:56:53 UTC (rev 13118)
@@ -0,0 +1,356 @@
+/**
+ * @file lcc_impl.hpp
+ * @author Nishant Mehta
+ *
+ * Implementation of Local Coordinate Coding
+ */
+#ifndef __MLPACK_METHODS_LOCAL_COORDINATE_CODING_LCC_IMPL_HPP
+#define __MLPACK_METHODS_LOCAL_COORDINATE_CODING_LCC_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "lcc.hpp"
+
+#define OBJ_TOL 1e-2 // 1E-9
+
+namespace mlpack {
+namespace lcc {
+
+template<typename DictionaryInitializer> // ctor: record sizes, build matD
+LocalCoordinateCoding<DictionaryInitializer>::LocalCoordinateCoding(
+    const arma::mat& matX, // data; rows are dimensions, columns are points
+    arma::uword nAtoms, // number of atoms in the dictionary
+    double lambda) : // regularization weight for the weighted l1 penalty
+    nDims(matX.n_rows),
+    nAtoms(nAtoms),
+    nPoints(matX.n_cols),
+    matX(matX),
+    matZ(nAtoms, nPoints), // coding matrix: one column of codes per point
+    lambda(lambda)
+{
+  // Let the DictionaryInitializer policy build the starting dictionary matD.
+  DictionaryInitializer::Initialize(matX, nAtoms, matD);
+}
+
+template<typename DictionaryInitializer> // replaces the current dictionary
+void LocalCoordinateCoding<DictionaryInitializer>::SetDictionary(
+    const arma::mat& matD) // new dictionary; shadows the member of that name
+{
+  this->matD = matD; // this-> selects the member rather than the argument
+}
+
+template<typename DictionaryInitializer> // runs the alternating optimization
+void LocalCoordinateCoding<DictionaryInitializer>::DoLCC(
+    arma::uword nIterations) // maximum number of outer iterations
+{
+  bool converged = false;
+  double lastObjVal = 1e99; // large sentinel: the first improvement is huge
+  // Start from codes computed for the initial dictionary.
+  Log::Info << "Initial Coding Step" << std::endl;
+  OptimizeCode();
+  arma::uvec adjacencies = find(matZ); // linear indices of nonzero codes
+  Log::Info << "\tSparsity level: " << 100.0 * ((double)(adjacencies.n_elem)) /
+      ((double)(nAtoms * nPoints)) << "%\n";
+  Log::Info << "\tObjective value: " << Objective(adjacencies) << std::endl;
+  // Alternate dictionary and coding steps until the objective stalls.
+  for (arma::uword t = 1; t <= nIterations && !converged; t++)
+  {
+    Log::Info << "Iteration " << t << " of " << nIterations << std::endl;
+
+    Log::Info << "Dictionary Step\n";
+    OptimizeDictionary(adjacencies);
+    double dsObjVal = Objective(adjacencies); // evaluated once, reused below
+    Log::Info << "\tObjective value: " << dsObjVal << std::endl;
+
+    Log::Info << "Coding Step" << std::endl;
+    OptimizeCode();
+    adjacencies = find(matZ); // the sparsity pattern may have changed
+    Log::Info << "\tSparsity level: " << 100.0 * ((double)(adjacencies.n_elem))
+        / ((double)(nAtoms * nPoints)) << "%\n";
+    double curObjVal = Objective(adjacencies);
+    Log::Info << "\tObjective value: " << curObjVal << std::endl;
+    // The coding step is not expected to raise the objective.
+    if (curObjVal > dsObjVal)
+    {
+      Log::Fatal << "Objective increased in sparse coding step!" << std::endl;
+    }
+
+    double objValImprov = lastObjVal - curObjVal;
+    Log::Info << "\t\t\t\t\tImprovement: " << std::scientific << objValImprov
+        << std::endl;
+    // Stop once an iteration improves the objective by less than OBJ_TOL.
+    if (objValImprov < OBJ_TOL)
+    {
+      converged = true;
+      Log::Info << "Converged within tolerance\n";
+    }
+
+    lastObjVal = curObjVal;
+  }
+}
+
+template<typename DictionaryInitializer> // coding step: recompute matZ
+void LocalCoordinateCoding<DictionaryInitializer>::OptimizeCode()
+{
+  arma::mat matSqDists = repmat(trans(sum(square(matD))), 1, nPoints) +
+      repmat(sum(square(matX)), nAtoms, 1) - 2 * trans(matD) * matX;
+  // Entry (j, i) above is ||d_j||^2 + ||x_i||^2 - 2 d_j' x_i = ||d_j - x_i||^2.
+  arma::mat matInvSqDists = 1.0 / matSqDists;
+  // NOTE(review): an atom equal to a data point gives a zero distance here.
+  arma::mat matDTD = trans(matD) * matD;
+  // matDTD does not depend on the point, so it is formed once for all points.
+
+  for (arma::uword i = 0; i < nPoints; i++)
+  {
+    // report progress
+    if ((i % 100) == 0)
+    {
+      Log::Debug << "\t" << i << std::endl;
+    }
+
+    // Per-point weights: inverse squared distance from x^i to each atom.
+    arma::vec invW = matInvSqDists.unsafe_col(i);
+    arma::mat matDPrime = matD * diagmat(invW);
+
+    arma::mat matDPrimeTDPrime = diagmat(invW) * matDTD * diagmat(invW);
+
+    //LARS lars;
+    // do we still need 0.5 * lambda? yes, yes we do
+    //lars.Init(matDPrime.memptr(), matX.colptr(i), nDims, nAtoms, true, 0.5 *
+    //lambda); // apparently not as fast as using the below duo
+    // this may change, depending on the dimensionality and sparsity
+
+    // the duo
+    /* lars.Init(matDPrime.memptr(), matX.colptr(i), nDims, nAtoms, false, 0.5 *
+     * lambda); */
+    /* lars.SetGram(matDPrimeTDPrime.memptr(), nAtoms); */
+
+    bool useCholesky = false;
+    regression::LARS lars(useCholesky, matDPrimeTDPrime, 0.5 * lambda);
+
+    arma::vec beta;
+    lars.Regress(matDPrime, matX.unsafe_col(i), beta, true);
+    matZ.col(i) = beta % invW; // undo the column scaling applied to matD
+  }
+}
+
+template<typename DictionaryInitializer>
+void LocalCoordinateCoding<DictionaryInitializer>::OptimizeDictionary(
+    arma::uvec adjacencies) // linear indices of the nonzero entries of matZ
+{
+  // Dictionary step: re-estimate matD for the fixed codes in matZ.
+  arma::uvec neighborCounts = arma::zeros<arma::uvec>(nPoints, 1); // per x^i
+  if (adjacencies.n_elem > 0)
+  {
+    // Count nonzero codes per point; index / nAtoms is the column (point).
+    arma::uword curPointInd = (arma::uword) (adjacencies(0) / nAtoms);
+    arma::uword curCount = 1;
+    for (arma::uword l = 1; l < adjacencies.n_elem; l++)
+    {
+      if ((arma::uword) (adjacencies(l) / nAtoms) == curPointInd)
+      {
+        curCount++;
+      }
+      else
+      {
+        neighborCounts(curPointInd) = curCount; // run for this point ended
+        curPointInd = (arma::uword)(adjacencies(l) / nAtoms);
+        curCount = 1;
+      }
+    }
+    neighborCounts(curPointInd) = curCount; // flush the final run
+  }
+
+  // build matXPrime := [X x^1 ... x^1 ... x^n ... x^n]
+  // where each x^i is repeated for the number of neighbors x^i has
+  arma::mat matXPrime = arma::zeros(nDims, nPoints + adjacencies.n_elem);
+  matXPrime(arma::span::all, arma::span(0, nPoints - 1)) = matX;
+  arma::uword curCol = nPoints; // next free column after the copy of X
+  for (arma::uword i = 0; i < nPoints; i++)
+  {
+    if (neighborCounts(i) > 0)
+    {
+      matXPrime(arma::span::all, arma::span(curCol, curCol + neighborCounts(i)
+          - 1)) = repmat(matX.col(i), 1, neighborCounts(i));
+    }
+    curCol += neighborCounts(i);
+  }
+
+  // Split atoms into inactive (all-zero row of matZ) and active ones.
+  std::vector<arma::uword> inactiveAtoms;
+  std::vector<arma::uword> activeAtoms;
+  activeAtoms.reserve(nAtoms);
+  for (arma::uword j = 0; j < nAtoms; j++)
+  {
+    if (accu(matZ.row(j) != 0) == 0)
+    {
+      inactiveAtoms.push_back(j);
+    }
+    else
+    {
+      activeAtoms.push_back(j);
+    }
+  }
+  arma::uword nActiveAtoms = activeAtoms.size();
+  arma::uword nInactiveAtoms = inactiveAtoms.size();
+
+  // efficient construction of Z restricted to active atoms
+  arma::mat matActiveZ;
+  if (inactiveAtoms.empty())
+  {
+    matActiveZ = matZ;
+  }
+  else
+  {
+    arma::uvec inactiveAtomsVec = arma::conv_to<arma::uvec>::from(
+        inactiveAtoms);
+    RemoveRows(matZ, inactiveAtomsVec, matActiveZ);
+  }
+
+  arma::uvec atomReverseLookup = arma::uvec(nAtoms); // atom -> active row
+  for (arma::uword i = 0; i < nActiveAtoms; i++)
+  {
+    atomReverseLookup(activeAtoms[i]) = i; // inactive entries are not set
+  }
+
+  if (nInactiveAtoms > 0)
+  {
+    Log::Info << "There are " << nInactiveAtoms << " inactive atoms. They will"
+        << " be re-initialized randomly.\n";
+  }
+
+  arma::mat matZPrime = arma::zeros(nActiveAtoms, nPoints + adjacencies.n_elem);
+  //Log::Debug << "adjacencies.n_elem = " << adjacencies.n_elem << std::endl;
+  matZPrime(arma::span::all, arma::span(0, nPoints - 1)) = matActiveZ;
+
+  arma::vec wSquared = arma::ones(nPoints + adjacencies.n_elem, 1);
+  //Log::Debug << "building up matZPrime\n";
+  for (arma::uword l = 0; l < adjacencies.n_elem; l++)
+  {
+    arma::uword atomInd = adjacencies(l) % nAtoms; // row of matZ
+    arma::uword pointInd = (arma::uword) (adjacencies(l) / nAtoms); // column
+    matZPrime(atomReverseLookup(atomInd), nPoints + l) = 1.0;
+    wSquared(nPoints + l) = matZ(atomInd, pointInd); // scaled just below
+  }
+  // The first nPoints weights stay 1; the rest become lambda * |z|.
+  wSquared.subvec(nPoints, wSquared.n_elem - 1) = lambda *
+      abs(wSquared.subvec(nPoints, wSquared.n_elem - 1));
+  // Fit the dictionary by solving the weighted normal equations.
+  //Log::Debug << "about to solve\n";
+  arma::mat matDEstimate;
+  if (inactiveAtoms.empty())
+  {
+    arma::mat A = matZPrime * diagmat(wSquared) * trans(matZPrime);
+    arma::mat B = matZPrime * diagmat(wSquared) * trans(matXPrime);
+
+    //Log::Debug << "solving...\n";
+    matDEstimate =
+      trans(solve(A, B));
+    /*
+    matDEstimate =
+      trans(solve(matZPrime * diagmat(wSquared) * trans(matZPrime),
+                  matZPrime * diagmat(wSquared) * trans(matXPrime)));
+    */
+  }
+  else
+  {
+    matDEstimate = arma::zeros(nDims, nAtoms);
+    //Log::Debug << "solving...\n";
+    arma::mat matDActiveEstimate =
+      trans(solve(matZPrime * diagmat(wSquared) * trans(matZPrime),
+                  matZPrime * diagmat(wSquared) * trans(matXPrime)));
+    for (arma::uword j = 0; j < nActiveAtoms; j++)
+    {
+      matDEstimate.col(activeAtoms[j]) = matDActiveEstimate.col(j);
+    }
+
+    for (arma::uword j = 0; j < nInactiveAtoms; j++)
+    {
+      // Reinitialize the unused atom from the data.
+      // Add three randomly chosen data points (columns of matX) together.
+      matDEstimate.col(inactiveAtoms[j]) =
+          (matX.col(math::RandInt(matX.n_cols)) +
+           matX.col(math::RandInt(matX.n_cols)) +
+           matX.col(math::RandInt(matX.n_cols)));
+
+      // Now normalize the atom.
+      matDEstimate.col(inactiveAtoms[j]) /=
+          norm(matDEstimate.col(inactiveAtoms[j]), 2);
+    }
+  }
+
+  matD = matDEstimate;
+}
+
+template<typename DictionaryInitializer>
+double LocalCoordinateCoding<DictionaryInitializer>::Objective(
+    arma::uvec adjacencies)
+{
+  // Locality term: sum over nonzero codes of |z| * ||d_atom - x_point||^2.
+  double localityPenalty = 0;
+  for (arma::uword k = 0; k < adjacencies.n_elem; k++)
+  {
+    const arma::uword row = adjacencies(k) % nAtoms; // atom index
+    const arma::uword col = (arma::uword) (adjacencies(k) / nAtoms); // point
+    localityPenalty += fabs(matZ(row, col)) *
+        as_scalar(sum(square(matD.col(row) - matX.col(col))));
+  }
+
+  const double residualNorm = norm(matX - matD * matZ, "fro");
+  return residualNorm * residualNorm + lambda * localityPenalty;
+}
+
+// Copy X into X_mod, leaving out the rows listed in rows_to_remove, which
+// must be sorted ascending without duplicates.  X_mod is resized to fit.
+// inline: defined in a header, so it must not be emitted once per includer.
+inline void RemoveRows(const arma::mat& X, arma::uvec rows_to_remove,
+    arma::mat& X_mod)
+{
+  arma::uword n_rows = X.n_rows;
+  arma::uword n_to_remove = rows_to_remove.n_elem;
+  arma::uword n_to_keep = n_rows - n_to_remove;
+
+  if (n_to_remove == 0)
+  {
+    X_mod = X;
+    return;
+  }
+
+  X_mod.set_size(n_to_keep, X.n_cols);
+
+  arma::uword cur_row = 0;
+  arma::uword remove_ind = 0;
+  // first, copy the rows that precede the first removed row
+  if (rows_to_remove(0) > 0)
+  {
+    arma::uword height = rows_to_remove(0);
+    X_mod(arma::span(cur_row, cur_row + height - 1), arma::span::all) =
+        X(arma::span(0, rows_to_remove(0) - 1), arma::span::all);
+    cur_row += height;
+  }
+  // then copy each run of kept rows between consecutive removed rows
+  while (remove_ind < n_to_remove - 1)
+  {
+    arma::uword height = rows_to_remove[remove_ind + 1] -
+        rows_to_remove[remove_ind] - 1;
+    if (height > 0)
+    {
+      X_mod(arma::span(cur_row, cur_row + height - 1), arma::span::all) =
+          X(arma::span(rows_to_remove[remove_ind] + 1,
+          rows_to_remove[remove_ind + 1] - 1), arma::span::all);
+      cur_row += height;
+    }
+    remove_ind++;
+  }
+  // finally, copy the rows that follow the last removed row
+  if (rows_to_remove[remove_ind] < n_rows - 1)
+  {
+    X_mod(arma::span(cur_row, n_to_keep - 1), arma::span::all) =
+        X(arma::span(rows_to_remove[remove_ind] + 1, n_rows - 1),
+        arma::span::all);
+  }
+}
+
+}; // namespace lcc
+}; // namespace mlpack
+
+#endif

Modified: mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc_main.cpp	2012-06-27 18:39:40 UTC (rev 13117)
+++ mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc_main.cpp	2012-06-27 19:56:53 UTC (rev 13118)
@@ -54,11 +54,11 @@
   }
 
   // run Local Coordinate Coding
-  LocalCoordinateCoding lcc(matX, nAtoms, lambda);
+  LocalCoordinateCoding<> lcc(matX, nAtoms, lambda);
 
   if (strlen(initialDictionaryFullpath) == 0)
   {
-    lcc.DataDependentRandomInitDictionary();
+//    lcc.DataDependentRandomInitDictionary();
   }
   else
   {




More information about the mlpack-svn mailing list