[mlpack-svn] r16681 - in mlpack/trunk/src/mlpack/methods/amf: . update_rules

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Jun 11 11:17:05 EDT 2014


Author: sumedhghaisas
Date: Wed Jun 11 11:17:04 2014
New Revision: 16681

Log:
* Added a momentum term to the SVD batch learning update rule
* AMF now calls Initialize() on the update rule before starting the optimization
* Every update rule must now implement Initialize(), which accepts the data
  matrix and the rank


Modified:
   mlpack/trunk/src/mlpack/methods/amf/amf.hpp
   mlpack/trunk/src/mlpack/methods/amf/amf_impl.hpp
   mlpack/trunk/src/mlpack/methods/amf/update_rules/nmf_als.hpp
   mlpack/trunk/src/mlpack/methods/amf/update_rules/nmf_mult_dist.hpp
   mlpack/trunk/src/mlpack/methods/amf/update_rules/nmf_mult_div.hpp
   mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp

Modified: mlpack/trunk/src/mlpack/methods/amf/amf.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/amf.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/amf/amf.hpp	Wed Jun 11 11:17:04 2014
@@ -83,7 +83,7 @@
   double Apply(const MatType& V,
              const size_t r,
              arma::mat& W,
-             arma::mat& H) const;
+             arma::mat& H);
 
  private:
   //! The maximum number of iterations allowed before giving up.

Modified: mlpack/trunk/src/mlpack/methods/amf/amf_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/amf_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/amf/amf_impl.hpp	Wed Jun 11 11:17:04 2014
@@ -43,7 +43,7 @@
     const MatType& V,
     const size_t r,
     arma::mat& W,
-    arma::mat& H) const
+    arma::mat& H)
 {
   const size_t n = V.n_rows;
   const size_t m = V.n_cols;
@@ -61,7 +61,7 @@
   double norm = 0;
   arma::mat WH;
 
-  std::cout << tolerance << std::endl;
+  update.Initialize(V, r);
 
   while (((oldResidue - residue) / oldResidue >= tolerance || iteration < 4) && iteration != maxIterations)
   {
@@ -84,8 +84,6 @@
     normOld = norm;
 
     iteration++;
-
-    std::cout << residue << std::endl;
   }
 
   Log::Info << "AMF converged to residue of " << sqrt(residue) << " in "

Modified: mlpack/trunk/src/mlpack/methods/amf/update_rules/nmf_als.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/update_rules/nmf_als.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/amf/update_rules/nmf_als.hpp	Wed Jun 11 11:17:04 2014
@@ -27,6 +27,13 @@
   // Empty constructor required for the UpdateRule template.
   NMFALSUpdate() { }
 
+  template<typename MatType>
+  void Initialize(const MatType& dataset, const size_t rank)
+  {
+      (void)dataset;
+      (void)rank;
+  }
+
   /**
    * The update rule for the basis matrix W. The formula used is
    * \f[

Modified: mlpack/trunk/src/mlpack/methods/amf/update_rules/nmf_mult_dist.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/update_rules/nmf_mult_dist.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/amf/update_rules/nmf_mult_dist.hpp	Wed Jun 11 11:17:04 2014
@@ -26,6 +26,13 @@
   // Empty constructor required for the UpdateRule template.
   NMFMultiplicativeDistanceUpdate() { }
 
+  template<typename MatType>
+  void Initialize(const MatType& dataset, const size_t rank)
+  {
+        (void)dataset;
+        (void)rank;
+  }
+
   /**
    * The update rule for the basis matrix W. The formula used is
    * \f[

Modified: mlpack/trunk/src/mlpack/methods/amf/update_rules/nmf_mult_div.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/update_rules/nmf_mult_div.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/amf/update_rules/nmf_mult_div.hpp	Wed Jun 11 11:17:04 2014
@@ -25,13 +25,20 @@
   // Empty constructor required for the WUpdateRule template.
   NMFMultiplicativeDivergenceUpdate() { }
 
+  template<typename MatType>
+  void Initialize(const MatType& dataset, const size_t rank)
+  {
+    (void)dataset;
+    (void)rank;
+  }
+
   /**
    * The update rule for the basis matrix W. The formula used is
    * \f[
    * W_{ia} \leftarrow W_{ia} \frac{\sum_{\mu} H_{a\mu} V_{i\mu}/(WH)_{i\mu}}
    * {\sum_{\nu} H_{a\nu}}
    * \f]
-   * The function takes in all the matrices and only changes the 
+   * The function takes in all the matrices and only changes the
    * value of the W matrix.
    *
    * @param V Input matrix to be factorized.
@@ -73,7 +80,7 @@
    * H_{a\mu} \leftarrow H_{a\mu} \frac{\sum_{i} W_{ia} V_{i\mu}/(WH)_{i\mu}}
    * {\sum_{k} H_{ka}}
    * \f]
-   * The function takes in all the matrices and only changes the value 
+   * The function takes in all the matrices and only changes the value
    * of the H matrix.
    *
    * @param V Input matrix to be factorized.

Modified: mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp	Wed Jun 11 11:17:04 2014
@@ -13,9 +13,21 @@
     SVDBatchLearning(double u = 0.000001,
                      double kw = 0,
                      double kh = 0,
+                     double momentum = 0.2,
                      double min = -DBL_MIN,
                      double max = DBL_MAX)
-        : u(u), kw(kw), kh(kh), min(min), max(max) {}
+        : u(u), kw(kw), kh(kh), min(min), max(max), momentum(momentum)
+    {}
+
+    template<typename MatType>
+    void Initialize(const MatType& dataset, const size_t rank)
+    {
+        const size_t n = dataset.n_rows;
+        const size_t m = dataset.n_cols;
+
+        mW.zeros(n, rank);
+        mH.zeros(rank, m);
+    }
 
     /**
     * The update rule for the basis matrix W.
@@ -29,13 +41,15 @@
     template<typename MatType>
     inline void WUpdate(const MatType& V,
                                arma::mat& W,
-                               const arma::mat& H) const
+                               const arma::mat& H)
     {
         size_t n = V.n_rows;
         size_t m = V.n_cols;
 
         size_t r = W.n_cols;
 
+        mW = momentum * mW;
+
         arma::mat deltaW(n, r);
         deltaW.zeros();
 
@@ -46,7 +60,8 @@
             deltaW.row(i) -= kw * W.row(i);
         }
 
-        W += u * deltaW;
+        mW += u * deltaW;
+        W += mW;
     }
 
     /**
@@ -61,13 +76,15 @@
     template<typename MatType>
     inline void HUpdate(const MatType& V,
                                const arma::mat& W,
-                               arma::mat& H) const
+                               arma::mat& H)
     {
         size_t n = V.n_rows;
         size_t m = V.n_cols;
 
         size_t r = W.n_cols;
 
+        mH = momentum * mH;
+
         arma::mat deltaH(r, m);
         deltaH.zeros();
 
@@ -78,7 +95,8 @@
             deltaH.col(j) -= kh * H.col(j);
         }
 
-        H += u*deltaH;
+        mH += u*deltaH;
+        H += mH;
     }
 private:
 
@@ -94,6 +112,10 @@
     double kh;
     double min;
     double max;
+    double momentum;
+
+    arma::mat mW;
+    arma::mat mH;
 };
 } // namespace amf
 } // namespace mlpack



More information about the mlpack-svn mailing list