[mlpack-svn] r16681 - in mlpack/trunk/src/mlpack/methods/amf: . update_rules
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Jun 11 11:17:05 EDT 2014
Author: sumedhghaisas
Date: Wed Jun 11 11:17:04 2014
New Revision: 16681
Log:
* Added momentum to SVD batch learning.
* AMF now calls Initialize() on the update rule before starting the optimization.
* Every update rule must now implement Initialize(), which accepts the data
matrix and the rank.
Modified:
mlpack/trunk/src/mlpack/methods/amf/amf.hpp
mlpack/trunk/src/mlpack/methods/amf/amf_impl.hpp
mlpack/trunk/src/mlpack/methods/amf/update_rules/nmf_als.hpp
mlpack/trunk/src/mlpack/methods/amf/update_rules/nmf_mult_dist.hpp
mlpack/trunk/src/mlpack/methods/amf/update_rules/nmf_mult_div.hpp
mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp
Modified: mlpack/trunk/src/mlpack/methods/amf/amf.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/amf.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/amf/amf.hpp Wed Jun 11 11:17:04 2014
@@ -83,7 +83,7 @@
double Apply(const MatType& V,
const size_t r,
arma::mat& W,
- arma::mat& H) const;
+ arma::mat& H);
private:
//! The maximum number of iterations allowed before giving up.
Modified: mlpack/trunk/src/mlpack/methods/amf/amf_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/amf_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/amf/amf_impl.hpp Wed Jun 11 11:17:04 2014
@@ -43,7 +43,7 @@
const MatType& V,
const size_t r,
arma::mat& W,
- arma::mat& H) const
+ arma::mat& H)
{
const size_t n = V.n_rows;
const size_t m = V.n_cols;
@@ -61,7 +61,7 @@
double norm = 0;
arma::mat WH;
- std::cout << tolerance << std::endl;
+ update.Initialize(V, r);
while (((oldResidue - residue) / oldResidue >= tolerance || iteration < 4) && iteration != maxIterations)
{
@@ -84,8 +84,6 @@
normOld = norm;
iteration++;
-
- std::cout << residue << std::endl;
}
Log::Info << "AMF converged to residue of " << sqrt(residue) << " in "
Modified: mlpack/trunk/src/mlpack/methods/amf/update_rules/nmf_als.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/update_rules/nmf_als.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/amf/update_rules/nmf_als.hpp Wed Jun 11 11:17:04 2014
@@ -27,6 +27,13 @@
// Empty constructor required for the UpdateRule template.
NMFALSUpdate() { }
+ template<typename MatType>
+ void Initialize(const MatType& dataset, const size_t rank)
+ {
+ (void)dataset;
+ (void)rank;
+ }
+
/**
* The update rule for the basis matrix W. The formula used is
* \f[
Modified: mlpack/trunk/src/mlpack/methods/amf/update_rules/nmf_mult_dist.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/update_rules/nmf_mult_dist.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/amf/update_rules/nmf_mult_dist.hpp Wed Jun 11 11:17:04 2014
@@ -26,6 +26,13 @@
// Empty constructor required for the UpdateRule template.
NMFMultiplicativeDistanceUpdate() { }
+ template<typename MatType>
+ void Initialize(const MatType& dataset, const size_t rank)
+ {
+ (void)dataset;
+ (void)rank;
+ }
+
/**
* The update rule for the basis matrix W. The formula used is
* \f[
Modified: mlpack/trunk/src/mlpack/methods/amf/update_rules/nmf_mult_div.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/update_rules/nmf_mult_div.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/amf/update_rules/nmf_mult_div.hpp Wed Jun 11 11:17:04 2014
@@ -25,13 +25,20 @@
// Empty constructor required for the WUpdateRule template.
NMFMultiplicativeDivergenceUpdate() { }
+ template<typename MatType>
+ void Initialize(const MatType& dataset, const size_t rank)
+ {
+ (void)dataset;
+ (void)rank;
+ }
+
/**
* The update rule for the basis matrix W. The formula used is
* \f[
* W_{ia} \leftarrow W_{ia} \frac{\sum_{\mu} H_{a\mu} V_{i\mu}/(WH)_{i\mu}}
* {\sum_{\nu} H_{a\nu}}
* \f]
- * The function takes in all the matrices and only changes the
+ * The function takes in all the matrices and only changes the
* value of the W matrix.
*
* @param V Input matrix to be factorized.
@@ -73,7 +80,7 @@
* H_{a\mu} \leftarrow H_{a\mu} \frac{\sum_{i} W_{ia} V_{i\mu}/(WH)_{i\mu}}
* {\sum_{k} H_{ka}}
* \f]
- * The function takes in all the matrices and only changes the value
+ * The function takes in all the matrices and only changes the value
* of the H matrix.
*
* @param V Input matrix to be factorized.
Modified: mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp Wed Jun 11 11:17:04 2014
@@ -13,9 +13,21 @@
SVDBatchLearning(double u = 0.000001,
double kw = 0,
double kh = 0,
+ double momentum = 0.2,
double min = -DBL_MIN,
double max = DBL_MAX)
- : u(u), kw(kw), kh(kh), min(min), max(max) {}
+ : u(u), kw(kw), kh(kh), min(min), max(max), momentum(momentum)
+ {}
+
+ template<typename MatType>
+ void Initialize(const MatType& dataset, const size_t rank)
+ {
+ const size_t n = dataset.n_rows;
+ const size_t m = dataset.n_cols;
+
+ mW.zeros(n, rank);
+ mH.zeros(rank, m);
+ }
/**
* The update rule for the basis matrix W.
@@ -29,13 +41,15 @@
template<typename MatType>
inline void WUpdate(const MatType& V,
arma::mat& W,
- const arma::mat& H) const
+ const arma::mat& H)
{
size_t n = V.n_rows;
size_t m = V.n_cols;
size_t r = W.n_cols;
+ mW = momentum * mW;
+
arma::mat deltaW(n, r);
deltaW.zeros();
@@ -46,7 +60,8 @@
deltaW.row(i) -= kw * W.row(i);
}
- W += u * deltaW;
+ mW += u * deltaW;
+ W += mW;
}
/**
@@ -61,13 +76,15 @@
template<typename MatType>
inline void HUpdate(const MatType& V,
const arma::mat& W,
- arma::mat& H) const
+ arma::mat& H)
{
size_t n = V.n_rows;
size_t m = V.n_cols;
size_t r = W.n_cols;
+ mH = momentum * mH;
+
arma::mat deltaH(r, m);
deltaH.zeros();
@@ -78,7 +95,8 @@
deltaH.col(j) -= kh * H.col(j);
}
- H += u*deltaH;
+ mH += u*deltaH;
+ H += mH;
}
private:
@@ -94,6 +112,10 @@
double kh;
double min;
double max;
+ double momentum;
+
+ arma::mat mW;
+ arma::mat mH;
};
} // namespace amf
} // namespace mlpack
More information about the mlpack-svn mailing list