[mlpack-svn] r17466 - in mlpack/tags/mlpack-1.0.11: . src/mlpack/methods/amf/update_rules

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Sun Dec 7 14:37:51 EST 2014


Author: rcurtin
Date: Sun Dec  7 14:37:51 2014
New Revision: 17466

Log:
Merge r17388.


Modified:
   mlpack/tags/mlpack-1.0.11/   (props changed)
   mlpack/tags/mlpack-1.0.11/src/mlpack/methods/amf/update_rules/svd_batch_learning.hpp

Modified: mlpack/tags/mlpack-1.0.11/src/mlpack/methods/amf/update_rules/svd_batch_learning.hpp
==============================================================================
--- mlpack/tags/mlpack-1.0.11/src/mlpack/methods/amf/update_rules/svd_batch_learning.hpp	(original)
+++ mlpack/tags/mlpack-1.0.11/src/mlpack/methods/amf/update_rules/svd_batch_learning.hpp	Sun Dec  7 14:37:51 2014
@@ -22,13 +22,29 @@
 
 #include <mlpack/core.hpp>
 
-namespace mlpack
-{
-namespace amf
-{
+namespace mlpack {
+namespace amf {
+
+/**
+ * This class implements SVD batch learning with momentum. This procedure is
+ * described in the paper 'A Guide to Singular Value Decomposition' by
+ * Chih-Chao Ma; this class implements 'Algorithm 4' given in that paper.
+ * The factorizer decomposes the matrix V into two matrices W and H such that
+ * the sum of squared errors between V and W*H is minimized. This optimization
+ * is performed with gradient descent; momentum is added to make gradient
+ * descent converge faster.
+ */
 class SVDBatchLearning
 {
  public:
+  /**
+   * SVD Batch learning constructor.
+   *
+   * @param u step value used in batch learning
+   * @param kw regularization constant for W matrix
+   * @param kh regularization constant for H matrix
+   * @param momentum momentum applied to batch learning process
+   */
   SVDBatchLearning(double u = 0.0002,
                    double kw = 0,
                    double kh = 0,
@@ -78,7 +94,7 @@
       {
         double val;
         if((val = V(i, j)) != 0)
-          deltaW.row(i) += (val - arma::dot(W.row(i), H.col(j))) * 
+          deltaW.row(i) += (val - arma::dot(W.row(i), H.col(j))) *
                                                   arma::trans(H.col(j));
       }
       if(kw != 0) deltaW.row(i) -= kw * W.row(i);
@@ -118,7 +134,7 @@
       {
         double val;
         if((val = V(i, j)) != 0)
-          deltaH.col(j) += (val - arma::dot(W.row(i), H.col(j))) * 
+          deltaH.col(j) += (val - arma::dot(W.row(i), H.col(j))) *
                                                     arma::trans(W.row(i));
       }
       if(kh != 0) deltaH.col(j) -= kh * H.col(j);
@@ -127,7 +143,7 @@
     mH += u*deltaH;
     H += mH;
   }
-  
+
  private:
   double u;
   double kw;
@@ -140,7 +156,13 @@
   arma::mat mH;
 };
 
-template<> 
+//! TODO: Merge this sparse-matrix template specialization with the dense
+//!       version by using a common row_col_iterator.
+
+/**
+ * WUpdate function specialization for sparse matrices (arma::sp_mat).
+ */
+template<>
 inline void SVDBatchLearning::WUpdate<arma::sp_mat>(const arma::sp_mat& V,
                                                     arma::mat& W,
                                                     const arma::mat& H)
@@ -158,7 +180,7 @@
   {
     size_t row = it.row();
     size_t col = it.col();
-    deltaW.row(it.row()) += (*it - arma::dot(W.row(row), H.col(col))) * 
+    deltaW.row(it.row()) += (*it - arma::dot(W.row(row), H.col(col))) *
                                                   arma::trans(H.col(col));
   }
 
@@ -189,7 +211,7 @@
   {
     size_t row = it.row();
     size_t col = it.col();
-    deltaH.col(col) += (*it - arma::dot(W.row(row), H.col(col))) * 
+    deltaH.col(col) += (*it - arma::dot(W.row(row), H.col(col))) *
                                                 arma::trans(W.row(row));
   }
 
@@ -205,7 +227,6 @@
 } // namespace amf
 } // namespace mlpack
 
-
 #endif
 
 



More information about the mlpack-svn mailing list