[mlpack-svn] r13109 - mlpack/trunk/src/mlpack/methods/sparse_coding
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Jun 26 15:16:29 EDT 2012
Author: rcurtin
Date: 2012-06-26 15:16:29 -0400 (Tue, 26 Jun 2012)
New Revision: 13109
Modified:
mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding_impl.hpp
Log:
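Avoid a matrix copy for matActiveZ (when no atoms are inactive, codes is now used
directly instead of being copied into matActiveZ) and potentially speed up the
gradient calculation by avoiding instantiation of arma::ones<arma::vec> (Armadillo
may already have been smart enough to avoid that anyway).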
Modified: mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding_impl.hpp 2012-06-26 19:10:40 UTC (rev 13108)
+++ mlpack/trunk/src/mlpack/methods/sparse_coding/sparse_coding_impl.hpp 2012-06-26 19:16:29 UTC (rev 13109)
@@ -153,12 +153,8 @@
// Efficient construction of Z restricted to active atoms.
arma::mat matActiveZ;
- if (inactiveAtoms.empty())
+ if (!inactiveAtoms.empty())
{
- matActiveZ = codes;
- }
- else
- {
arma::uvec inactiveAtomsVec =
arma::conv_to<arma::uvec>::from(inactiveAtoms);
RemoveRows(codes, inactiveAtomsVec, matActiveZ);
@@ -191,9 +187,22 @@
// dualVars(i) = 0;
bool converged = false;
- arma::mat codesXT = matActiveZ * trans(data);
- arma::mat codesZT = matActiveZ * trans(matActiveZ);
+ // If we have any inactive atoms, we must construct these differently.
+ arma::mat codesXT;
+ arma::mat codesZT;
+
+ if (inactiveAtoms.empty())
+ {
+ codesXT = codes * trans(data);
+ codesZT = codes * trans(codes);
+ }
+ else
+ {
+ codesXT = matActiveZ * trans(data);
+ codesZT = matActiveZ * trans(matActiveZ);
+ }
+
double improvement;
for (size_t t = 1; !converged; ++t)
{
@@ -201,8 +210,8 @@
arma::mat matAInvZXT = solve(A, codesXT);
- arma::vec gradient = -(arma::sum(arma::square(matAInvZXT), 1) -
- arma::ones<arma::vec>(nActiveAtoms));
+ arma::vec gradient = -arma::sum(arma::square(matAInvZXT), 1);
+ gradient += 1;
arma::mat hessian = -(-2 * (matAInvZXT * trans(matAInvZXT)) % inv(A));
More information about the mlpack-svn mailing list