[mlpack-git] master: Refactor and code cleanup. Sparse matrix overloads are untouched -- for now. (165977e)

gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 22:16:06 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit 165977e800aba6ce70a1f252864973b3ea485cc0
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Feb 9 14:09:00 2015 -0500

    Refactor and code cleanup.
    Sparse matrix overloads are untouched -- for now.


>---------------------------------------------------------------

165977e800aba6ce70a1f252864973b3ea485cc0
 .../methods/amf/update_rules/nmf_mult_dist.hpp     |  48 +++++--
 .../methods/amf/update_rules/nmf_mult_div.hpp      |  53 +++++--
 .../amf/update_rules/svd_batch_learning.hpp        | 157 +++++++++++----------
 .../svd_complete_incremental_learning.hpp          | 129 +++++++++--------
 .../svd_incomplete_incremental_learning.hpp        | 140 +++++++++---------
 5 files changed, 296 insertions(+), 231 deletions(-)

diff --git a/src/mlpack/methods/amf/update_rules/nmf_mult_dist.hpp b/src/mlpack/methods/amf/update_rules/nmf_mult_dist.hpp
index 7fb9c97..c8e96a7 100644
--- a/src/mlpack/methods/amf/update_rules/nmf_mult_dist.hpp
+++ b/src/mlpack/methods/amf/update_rules/nmf_mult_dist.hpp
@@ -14,11 +14,22 @@ namespace amf {
 
 /**
  * The multiplicative distance update rules for matrices W and H. This follows
- * a method described in the paper 'Algorithms for Non-negative Matrix Factorization'
- * by D. D. Lee and H. S. Seung. This is a multiplicative rule that ensures
- * that the Frobenius norm \f$ \sqrt{\sum_i \sum_j(V-WH)^2} \f$ is
- * non-increasing between subsequent iterations. Both of the update rules
- * for W and H are defined in this file.
+ * a method described in the following paper:
+ *
+ * @code
+ * @inproceedings{lee2001algorithms,
+ *   title={Algorithms for non-negative matrix factorization},
+ *   author={Lee, D.D. and Seung, H.S.},
+ *   booktitle={Advances in Neural Information Processing Systems 13
+ *       (NIPS 2000)},
+ *   pages={556--562},
+ *   year={2001}
+ * }
+ * @endcode
+ *
+ * This is a multiplicative rule that ensures that the Frobenius norm
+ * \f$ \sqrt{\sum_i \sum_j(V-WH)^2} \f$ is non-increasing between subsequent
+ * iterations. Both of the update rules for W and H are defined in this file.
  */
 class NMFMultiplicativeDistanceUpdate
 {
@@ -26,20 +37,25 @@ class NMFMultiplicativeDistanceUpdate
   // Empty constructor required for the UpdateRule template.
   NMFMultiplicativeDistanceUpdate() { }
 
+  /**
+   * Initialize the factorization.  These update rules hold no information, so
+   * the input parameters are ignored.
+   */
   template<typename MatType>
-  void Initialize(const MatType& dataset, const size_t rank)
+  void Initialize(const MatType& /* dataset */, const size_t /* rank */)
   {
-        (void)dataset;
-        (void)rank;
+    // Nothing to do.
   }
 
   /**
-   * The update rule for the basis matrix W. The formula used is
+   * The update rule for the basis matrix W. The formula used is
+   *
    * \f[
    * W_{ia} \leftarrow W_{ia} \frac{(VH^T)_{ia}}{(WHH^T)_{ia}}
    * \f]
-   * The function takes in all the matrices and only changes the
-   * value of the W matrix.
+   *
+   * The function takes in all the matrices and only changes the value of the W
+   * matrix.
    *
    * @param V Input matrix to be factorized.
    * @param W Basis matrix to be updated.
@@ -55,11 +71,13 @@ class NMFMultiplicativeDistanceUpdate
 
   /**
    * The update rule for the encoding matrix H. The formula used is
+   *
    * \f[
    * H_{a\mu} \leftarrow H_{a\mu} \frac{(W^T V)_{a\mu}}{(W^T WH)_{a\mu}}
    * \f]
-   * The function takes in all the matrices and only changes the
-   * value of the H matrix.
+   *
+   * The function takes in all the matrices and only changes the value of the H
+   * matrix.
    *
    * @param V Input matrix to be factorized.
    * @param W Basis matrix.
@@ -74,7 +92,7 @@ class NMFMultiplicativeDistanceUpdate
   }
 };
 
-}; // namespace amf
-}; // namespace mlpack
+} // namespace amf
+} // namespace mlpack
 
 #endif
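
As a quick aside for anyone trying these rules out: below is a minimal usage
sketch (not part of this commit) showing how an update rule like
NMFMultiplicativeDistanceUpdate plugs into the AMF wrapper.  The AMF<>
template parameter order and the policy class names (SimpleResidueTermination,
RandomInitialization) are assumptions here and may differ between mlpack
versions; check amf.hpp for the version you have.

#include <mlpack/core.hpp>
#include <mlpack/methods/amf/amf.hpp>

using namespace mlpack::amf;

int main()
{
  // A small non-negative matrix to factorize.
  arma::mat V(100, 50, arma::fill::randu);
  arma::mat W, H;

  // AMF is parameterized by termination, initialization, and update policies;
  // the update rule from this file slots into the third parameter.
  AMF<SimpleResidueTermination,
      RandomInitialization,
      NMFMultiplicativeDistanceUpdate> nmf;

  // Factorize V (100x50) into W (100x10) * H (10x50).
  nmf.Apply(V, 10, W, H);

  return 0;
}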
diff --git a/src/mlpack/methods/amf/update_rules/nmf_mult_div.hpp b/src/mlpack/methods/amf/update_rules/nmf_mult_div.hpp
index 24600b2..bd972c0 100644
--- a/src/mlpack/methods/amf/update_rules/nmf_mult_div.hpp
+++ b/src/mlpack/methods/amf/update_rules/nmf_mult_div.hpp
@@ -14,9 +14,25 @@ namespace amf {
 
 /**
- * This follows a method described in the paper 'Algorithms for Non-negative
- * Matrix Factorization' by D. D. Lee and H. S. Seung. This is a multiplicative 
- * rule that ensures that the Kullback–Leibler divergence
- * \f$ \sum_i \sum_j (V_{ij} log\frac{V_{ij}}{(WH)_{ij}}-V_{ij}+(WH)_{ij}) \f$
+ * This follows a method described in the following paper:
+ *
+ * @code
+ * @inproceedings{lee2001algorithms,
+ *   title={Algorithms for non-negative matrix factorization},
+ *   author={Lee, D.D. and Seung, H.S.},
+ *   booktitle={Advances in Neural Information Processing Systems 13
+ *       (NIPS 2000)},
+ *   pages={556--562},
+ *   year={2001}
+ * }
+ * @endcode
+ *
+ * This is a multiplicative rule that ensures that the Kullback–Leibler
+ * divergence
+ *
+ * \f[
+ * \sum_i \sum_j (V_{ij} \log\frac{V_{ij}}{(W H)_{ij}} - V_{ij} + (W H)_{ij})
+ * \f]
+ *
  * is non-increasing between subsequent iterations. Both of the update rules
  * for W and H are defined in this file.
  *
@@ -30,21 +46,26 @@ class NMFMultiplicativeDivergenceUpdate
   // Empty constructor required for the WUpdateRule template.
   NMFMultiplicativeDivergenceUpdate() { }
 
+  /**
+   * Initialize the factorization.  These rules don't store any state, so the
+   * input values are ignored.
+   */
   template<typename MatType>
-  void Initialize(const MatType& dataset, const size_t rank)
+  void Initialize(const MatType& /* dataset */, const size_t /* rank */)
   {
-    (void)dataset;
-    (void)rank;
+    // Nothing to do.
   }
 
   /**
    * The update rule for the basis matrix W. The formula used is
+   *
    * \f[
-   * W_{ia} \leftarrow W_{ia} \frac{\sum_{\mu} H_{a\mu} V_{i\mu}/(WH)_{i\mu}}
+   * W_{ia} \leftarrow W_{ia} \frac{\sum_{\mu} H_{a\mu} V_{i\mu} / (W H)_{i\mu}}
    * {\sum_{\nu} H_{a\nu}}
    * \f]
-   * The function takes in all the matrices and only changes the
-   * value of the W matrix.
+   *
+   * The function takes in all the matrices and only changes the value of the W
+   * matrix.
    *
    * @param V Input matrix to be factorized.
    * @param W Basis matrix to be updated.
@@ -52,8 +73,8 @@ class NMFMultiplicativeDivergenceUpdate
    */
   template<typename MatType>
   inline static void WUpdate(const MatType& V,
-                            arma::mat& W,
-                            const arma::mat& H)
+                             arma::mat& W,
+                             const arma::mat& H)
   {
     // Simple implementation left in the header file.
     arma::mat t1;
@@ -81,12 +102,14 @@ class NMFMultiplicativeDivergenceUpdate
 
   /**
    * The update rule for the encoding matrix H. The formula used is
+   *
    * \f[
    * H_{a\mu} \leftarrow H_{a\mu} \frac{\sum_{i} W_{ia} V_{i\mu}/(WH)_{i\mu}}
    * {\sum_{k} H_{ka}}
    * \f]
-   * The function takes in all the matrices and only changes the value
-   * of the H matrix.
+   *
+   * The function takes in all the matrices and only changes the value of the H
+   * matrix.
    *
    * @param V Input matrix to be factorized.
    * @param W Basis matrix.
@@ -122,7 +145,7 @@ class NMFMultiplicativeDivergenceUpdate
   }
 };
 
-}; // namespace amf
-}; // namespace mlpack
+} // namespace amf
+} // namespace mlpack
 
 #endif
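
To make the divergence rule concrete: here is a naive loop-level transcription
of the W update formula above.  This is a sketch for exposition only (it does
not guard against zero entries in W * H), not the vectorized WUpdate() from
this commit.

#include <armadillo>

// Naive transcription of:
//   W_{ia} <- W_{ia} * (sum_mu H_{a,mu} V_{i,mu} / (WH)_{i,mu})
//                    / (sum_nu H_{a,nu}).
void NaiveDivergenceWUpdate(const arma::mat& V, arma::mat& W,
                            const arma::mat& H)
{
  const arma::mat WH = W * H;
  for (arma::uword i = 0; i < W.n_rows; ++i)
  {
    for (arma::uword a = 0; a < W.n_cols; ++a)
    {
      double numerator = 0.0;
      for (arma::uword mu = 0; mu < V.n_cols; ++mu)
        numerator += H(a, mu) * V(i, mu) / WH(i, mu);

      W(i, a) *= numerator / arma::accu(H.row(a));
    }
  }
}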
diff --git a/src/mlpack/methods/amf/update_rules/svd_batch_learning.hpp b/src/mlpack/methods/amf/update_rules/svd_batch_learning.hpp
index 3b9308c..c912ac2 100644
--- a/src/mlpack/methods/amf/update_rules/svd_batch_learning.hpp
+++ b/src/mlpack/methods/amf/update_rules/svd_batch_learning.hpp
@@ -4,8 +4,8 @@
  *
  * SVD factorizer used in AMF (Alternating Matrix Factorization).
  */
-#ifndef __MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCHLEARNING_HPP
-#define __MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCHLEARNING_HPP
+#ifndef __MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCH_LEARNING_HPP
+#define __MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCH_LEARNING_HPP
 
 #include <mlpack/core.hpp>
 
@@ -14,12 +14,24 @@ namespace amf {
 
 /**
  * This class implements SVD batch learning with momentum. This procedure is
- * described in the paper 'A Guide to singular Value Decomposition'
- * by Chih-Chao Ma. Class implements 'Algorithm 4' given in the paper.
- * This factorizer decomposes the matrix V into two matrices W and H such that
- * sum of sum of squared error between V and W*H is minimum. This optimization is
- * performed with gradient descent. To make gradient descent faster momentum is
- * added.
+ * described in the following paper:
+ *
+ * @code
+ * @techreport{ma2008guide,
+ *   title={A Guide to Singular Value Decomposition for Collaborative
+ *       Filtering},
+ *   author={Ma, Chih-Chao},
+ *   year={2008},
+ *   institution={Department of Computer Science, National Taiwan University}
+ * }
+ * @endcode
+ *
+ * This class implements 'Algorithm 4' as given in the paper.
+ *
+ * The factorizer decomposes the matrix V into two matrices W and H such that
+ * the sum of squared errors between V and W * H is minimized. This optimization
+ * is performed with gradient descent. To make gradient descent faster, momentum
+ * is added.
  */
 class SVDBatchLearning
 {
@@ -42,8 +54,8 @@ class SVDBatchLearning
   }
 
   /**
-   * Initialize parameters before factorization.
-   * This function must be called before a new factorization.
+   * Initialize parameters before factorization.  This function must be called
+   * before a new factorization.  This resets the internally-held momentum.
    *
    * @param dataset Input matrix to be factorized.
    * @param rank rank of factorization
@@ -77,28 +89,29 @@ class SVDBatchLearning
 
     size_t r = W.n_cols;
 
-    // initialize the momentum of this iteration
+    // Initialize the momentum of this iteration.
     mW = momentum * mW;
 
-    // compute the step
-    arma::mat deltaW(n, r);
-    deltaW.zeros();
-    for(size_t i = 0;i < n;i++)
+    // Compute the step.
+    arma::mat deltaW;
+    deltaW.zeros(n, r);
+    for (size_t i = 0; i < n; i++)
     {
-      for(size_t j = 0;j < m;j++)
+      for (size_t j = 0; j < m; j++)
       {
-        double val;
-        if((val = V(i, j)) != 0)
+        const double val = V(i, j);
+        if (val != 0)
           deltaW.row(i) += (val - arma::dot(W.row(i), H.col(j))) *
-                                                  arma::trans(H.col(j));
+                                            arma::trans(H.col(j));
       }
-      // add regularization
-      if(kw != 0) deltaW.row(i) -= kw * W.row(i);
+      // Add regularization.
+      if (kw != 0)
+        deltaW.row(i) -= kw * W.row(i);
     }
 
-    // add the step to the momentum
+    // Add the step to the momentum.
     mW += u * deltaW;
-    // add the momentum to W matrix
+    // Add the momentum to the W matrix.
     W += mW;
   }
 
@@ -121,46 +134,46 @@ class SVDBatchLearning
 
     size_t r = W.n_cols;
 
-    // initialize the momentum of this iteration
+    // Initialize the momentum of this iteration.
     mH = momentum * mH;
 
-    // compute the step
-    arma::mat deltaH(r, m);
-    deltaH.zeros();
-    for(size_t j = 0;j < m;j++)
+    // Compute the step.
+    arma::mat deltaH;
+    deltaH.zeros(r, m);
+    for (size_t j = 0; j < m; j++)
     {
-      for(size_t i = 0;i < n;i++)
+      for (size_t i = 0; i < n; i++)
       {
-        double val;
-        if((val = V(i, j)) != 0)
-          deltaH.col(j) += (val - arma::dot(W.row(i), H.col(j))) *
-                                                    arma::trans(W.row(i));
+        const double val = V(i, j);
+        if (val != 0)
+          deltaH.col(j) += (val - arma::dot(W.row(i), H.col(j))) * W.row(i).t();
       }
-      // add regularization
-      if(kh != 0) deltaH.col(j) -= kh * H.col(j);
+      // Add regularization.
+      if (kh != 0)
+        deltaH.col(j) -= kh * H.col(j);
     }
 
-    // add step to the momentum
-    mH += u*deltaH;
-    // add momentum to H
+    // Add this step to the momentum.
+    mH += u * deltaH;
+    // Add the momentum to H.
     H += mH;
   }
 
  private:
-  //! step size of the algorithm
+  //! Step size of the algorithm.
   double u;
-  //! regularization parameter for matrix W
+  //! Regularization parameter for matrix W.
   double kw;
-  //! regularization parameter matrix for matrix H
+  //! Regularization parameter for matrix H.
   double kh;
-  //! momentum value
+  //! Momentum value (between 0 and 1).
   double momentum;
 
-  //! momentum matrix for matrix W
+  //! Momentum matrix for matrix W
   arma::mat mW;
-  //! momentum matrix for matrix H
+  //! Momentum matrix for matrix H
   arma::mat mH;
-}; // class SBDBatchLearning
+}; // class SVDBatchLearning
 
 //! TODO : Merge this template specialized function for sparse matrix using
 //!        common row_col_iterator
@@ -173,26 +186,28 @@ inline void SVDBatchLearning::WUpdate<arma::sp_mat>(const arma::sp_mat& V,
                                                     arma::mat& W,
                                                     const arma::mat& H)
 {
-  size_t n = V.n_rows;
-
-  size_t r = W.n_cols;
+  const size_t n = V.n_rows;
+  const size_t r = W.n_cols;
 
   mW = momentum * mW;
 
-  arma::mat deltaW(n, r);
-  deltaW.zeros();
+  arma::mat deltaW;
+  deltaW.zeros(n, r);
 
-  for(arma::sp_mat::const_iterator it = V.begin();it != V.end();it++)
+  for (arma::sp_mat::const_iterator it = V.begin(); it != V.end(); ++it)
   {
-    size_t row = it.row();
-    size_t col = it.col();
+    const size_t row = it.row();
+    const size_t col = it.col();
     deltaW.row(it.row()) += (*it - arma::dot(W.row(row), H.col(col))) *
-                                                  arma::trans(H.col(col));
+                                             arma::trans(H.col(col));
   }
 
-  if(kw != 0) for(size_t i = 0; i < n; i++)
+  if (kw != 0)
   {
-    deltaW.row(i) -= kw * W.row(i);
+    for (size_t i = 0; i < n; i++)
+    {
+      deltaW.row(i) -= kw * W.row(i);
+    }
   }
 
   mW += u * deltaW;
@@ -204,35 +219,35 @@ inline void SVDBatchLearning::HUpdate<arma::sp_mat>(const arma::sp_mat& V,
                                                     const arma::mat& W,
                                                     arma::mat& H)
 {
-  size_t m = V.n_cols;
-
-  size_t r = W.n_cols;
+  const size_t m = V.n_cols;
+  const size_t r = W.n_cols;
 
   mH = momentum * mH;
 
-  arma::mat deltaH(r, m);
-  deltaH.zeros();
+  arma::mat deltaH;
+  deltaH.zeros(r, m);
 
-  for(arma::sp_mat::const_iterator it = V.begin();it != V.end();it++)
+  for (arma::sp_mat::const_iterator it = V.begin(); it != V.end(); ++it)
   {
-    size_t row = it.row();
-    size_t col = it.col();
+    const size_t row = it.row();
+    const size_t col = it.col();
     deltaH.col(col) += (*it - arma::dot(W.row(row), H.col(col))) *
-                                                arma::trans(W.row(row));
+        W.row(row).t();
   }
 
-  if(kh != 0) for(size_t j = 0; j < m; j++)
+  if (kh != 0)
   {
-    deltaH.col(j) -= kh * H.col(j);
+    for (size_t j = 0; j < m; j++)
+    {
+      deltaH.col(j) -= kh * H.col(j);
+    }
   }
 
-  mH += u*deltaH;
+  mH += u * deltaH;
   H += mH;
 }
 
 } // namespace amf
 } // namespace mlpack
 
-#endif // __MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCHLEARNING_HPP
-
-
+#endif // __MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCH_LEARNING_HPP
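
For readers skimming the batch learning changes: the momentum bookkeeping in
WUpdate() and HUpdate() reduces to the short recurrence below.  The names here
are illustrative, not from the commit.

#include <armadillo>

// One momentum gradient step: decay the accumulated momentum, fold in the
// current (scaled) step, then move the parameters by the momentum.
void MomentumStep(arma::mat& W,            // Parameters being optimized.
                  arma::mat& mW,           // Accumulated momentum, same size.
                  const arma::mat& deltaW, // This iteration's step direction.
                  const double u,          // Step size.
                  const double momentum)   // Momentum coefficient in [0, 1).
{
  mW = momentum * mW;
  mW += u * deltaW;
  W += mW;
}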
diff --git a/src/mlpack/methods/amf/update_rules/svd_complete_incremental_learning.hpp b/src/mlpack/methods/amf/update_rules/svd_complete_incremental_learning.hpp
index 8b354dc..7d706b6 100644
--- a/src/mlpack/methods/amf/update_rules/svd_complete_incremental_learning.hpp
+++ b/src/mlpack/methods/amf/update_rules/svd_complete_incremental_learning.hpp
@@ -4,8 +4,8 @@
  *
  * SVD factorizer used in AMF (Alternating Matrix Factorization).
  */
-#ifndef _MLPACK_METHODS_AMF_SVDCOMPLETEINCREMENTALLEARNING_HPP_INCLUDED
-#define _MLPACK_METHODS_AMF_SVDCOMPLETEINCREMENTALLEARNING_HPP_INCLUDED
+#ifndef __MLPACK_METHODS_AMF_SVD_COMPLETE_INCREMENTAL_LEARNING_HPP
+#define __MLPACK_METHODS_AMF_SVD_COMPLETE_INCREMENTAL_LEARNING_HPP
 
 #include <mlpack/core.hpp>
 
@@ -15,13 +15,24 @@ namespace amf
 {
 
 /**
- * This class computes SVD by method complete incremental batch learning. 
- * This procedure is described in the paper 'A Guide to singular Value Decomposition' 
- * by Chih-Chao Ma. Class implements 'Algorithm 3' given in the paper. Complete
- * incremental learning is an extreme extreme case of incremental learning where 
- * feature vectors are updated after looking at each single score. This approach
- * differs from incomplete incremental learning where feature vectors are updated 
- * after seeing scores of individual users.
+ * This class computes SVD using complete incremental batch learning, as
+ * described in the following paper:
+ *
+ * @code
+ * @techreport{ma2008guide,
+ *   title={A Guide to Singular Value Decomposition for Collaborative
+ *       Filtering},
+ *   author={Ma, Chih-Chao},
+ *   year={2008},
+ *   institution={Department of Computer Science, National Taiwan University}
+ * }
+ * @endcode
+ *
+ * This class implements 'Algorithm 3' given in the paper.  Complete incremental
+ * learning is an extreme case of incremental learning, where feature vectors
+ * are updated after looking at each single element in the input matrix (V).
+ * This approach differs from incomplete incremental learning where feature
+ * vectors are updated after seeing columns of elements in the input matrix.
  *
  * @see SVDIncompleteIncrementalLearning
  */
@@ -30,40 +41,39 @@ class SVDCompleteIncrementalLearning
 {
  public:
   /**
-   * Empty constructor
+   * Initialize the SVDCompleteIncrementalLearning class with the given
+   * parameters.
    *
-   * @param u step value used in batch learning
-   * @param kw regularization constant for W matrix
-   * @param kh regularization constant for H matrix
+   * @param u Step value used in batch learning.
+   * @param kw Regularization constant for W matrix.
+   * @param kh Regularization constant for H matrix.
    */
   SVDCompleteIncrementalLearning(double u = 0.0001,
                                  double kw = 0,
                                  double kh = 0)
             : u(u), kw(kw), kh(kh)
-    {}
+  {
+    // Nothing to do.
+  }
 
   /**
-   * Initialize parameters before factorization.
-   * This function must be called before a new factorization.
+   * Initialize parameters before factorization.  This function must be called
+   * before a new factorization.  For this initialization, the input parameters
+   * are unnecessary; we are only resetting the current element counters to 0.
    *
    * @param dataset Input matrix to be factorized.
    * @param rank rank of factorization
    */
-  void Initialize(const MatType& dataset, const size_t rank)
+  void Initialize(const MatType& /* dataset */, const size_t /* rank */)
   {
-    (void)rank;
-    n = dataset.n_rows;
-    m = dataset.n_cols;
-
-    // initialize the current score counters
+    // Initialize the current score counters.
     currentUserIndex = 0;
     currentItemIndex = 0;
   }
 
   /**
-   * The update rule for the basis matrix W.
-   * The function takes in all the matrices and only changes the
-   * value of the W matrix.
+   * The update rule for the basis matrix W.  The function takes in all the
+   * matrices and only changes the value of the W matrix.
    *
    * @param V Input matrix to be factorized.
    * @param W Basis matrix to be updated.
@@ -73,25 +83,27 @@ class SVDCompleteIncrementalLearning
                       arma::mat& W,
                       const arma::mat& H)
   {
-    arma::mat deltaW(1, W.n_cols);
-    deltaW.zeros();
+    arma::mat deltaW;
+    deltaW.zeros(1, W.n_cols);
 
-    // loop till a non-zero entry is found 
+    // Loop until a non-zero entry is found.
     while(true)
     {
-      double val;
-      // update feature vector if current entry is non-zero and break the loop
-      if((val = V(currentItemIndex, currentUserIndex)) != 0)
+      const double val = V(currentItemIndex, currentUserIndex);
+      // Update feature vector if current entry is non-zero and break the loop.
+      if (val != 0)
       {
-        deltaW += (val - arma::dot(W.row(currentItemIndex), H.col(currentUserIndex))) 
-                                        * arma::trans(H.col(currentUserIndex));
-        // add regularization                               
-        if(kw != 0) deltaW -= kw * W.row(currentItemIndex);
+        deltaW += (val - arma::dot(W.row(currentItemIndex),
+            H.col(currentUserIndex))) * H.col(currentUserIndex).t();
+
+        // Add regularization.
+        if (kw != 0)
+          deltaW -= kw * W.row(currentItemIndex);
         break;
       }
     }
 
-    W.row(currentItemIndex) += u*deltaW;
+    W.row(currentItemIndex) += u * deltaW;
   }
 
   /**
@@ -107,44 +119,40 @@ class SVDCompleteIncrementalLearning
                       const arma::mat& W,
                       arma::mat& H)
   {
-    arma::mat deltaH(H.n_rows, 1);
-    deltaH.zeros();
+    arma::mat deltaH;
+    deltaH.zeros(H.n_rows, 1);
 
-    const double& val = V(currentItemIndex, currentUserIndex);
+    const double val = V(currentItemIndex, currentUserIndex);
 
-    // update H matrix based on the non-zero enrty found in WUpdate function
-    deltaH += (val - arma::dot(W.row(currentItemIndex), H.col(currentUserIndex))) 
-                                      * arma::trans(W.row(currentItemIndex));
-    // add regularization
-    if(kh != 0) deltaH -= kh * H.col(currentUserIndex);
+    // Update H matrix based on the non-zero entry found in WUpdate function.
+    deltaH += (val - arma::dot(W.row(currentItemIndex),
+        H.col(currentUserIndex))) * W.row(currentItemIndex).t();
+    // Add regularization.
+    if (kh != 0)
+      deltaH -= kh * H.col(currentUserIndex);
 
-    // move on to the next entry
+    // Move on to the next entry.
     currentUserIndex = currentUserIndex + 1;
-    if(currentUserIndex == n)
+    if (currentUserIndex == V.n_rows)
     {
       currentUserIndex = 0;
-      currentItemIndex = (currentItemIndex + 1) % m;
+      currentItemIndex = (currentItemIndex + 1) % V.n_cols;
     }
 
     H.col(currentUserIndex++) += u * deltaH;
   }
 
  private:
-  //! step count of batch learning
+  //! Step size of batch learning.
   double u;
-  //! regularization parameter for matrix w
+  //! Regularization parameter for matrix W.
   double kw;
-  //! regualrization matrix for matrix H
+  //! Regularization parameter for matrix H.
   double kh;
 
-  //! number of items
-  size_t n;
-  //! number of users
-  size_t m;
-  
-  //! user of index of current entry
+  //! User index of the current entry.
   size_t currentUserIndex;
-  //! item index of current entry
+  //! Item index of current entry.
   size_t currentItemIndex;
 };
 
@@ -254,9 +262,8 @@ class SVDCompleteIncrementalLearning<arma::sp_mat>
   bool isStart;
 }; // class SVDCompleteIncrementalLearning
 
-}; // namespace amf
-}; // namespace mlpack
-
+} // namespace amf
+} // namespace mlpack
 
-#endif // _MLPACK_METHODS_AMF_SVDCOMPLETEINCREMENTALLEARNING_HPP_INCLUDED
+#endif
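
The 'complete incremental' approach boils down to a plain stochastic gradient
step on one observed entry at a time.  A self-contained sketch of that step
follows; the names are illustrative, and the class's actual traversal and
index bookkeeping differ.

#include <armadillo>

// One step of 'Algorithm 3': given a single observed entry V(i, j) = val,
// nudge the item feature vector W.row(i) and the user feature vector
// H.col(j) along the regularized gradient of the squared error.
void SingleEntryStep(const double val, const arma::uword i,
                     const arma::uword j, arma::mat& W, arma::mat& H,
                     const double u, const double kw, const double kh)
{
  // Prediction error for this entry.
  const double e = val - arma::dot(W.row(i), H.col(j));

  // Regularized gradient steps for the two feature vectors.
  W.row(i) += u * (e * H.col(j).t() - kw * W.row(i));
  H.col(j) += u * (e * W.row(i).t() - kh * H.col(j));
}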
 
diff --git a/src/mlpack/methods/amf/update_rules/svd_incomplete_incremental_learning.hpp b/src/mlpack/methods/amf/update_rules/svd_incomplete_incremental_learning.hpp
index 0903c8d..420b368 100644
--- a/src/mlpack/methods/amf/update_rules/svd_incomplete_incremental_learning.hpp
+++ b/src/mlpack/methods/amf/update_rules/svd_incomplete_incremental_learning.hpp
@@ -4,8 +4,8 @@
  *
  * SVD factorizer used in AMF (Alternating Matrix Factorization).
  */
-#ifndef SVD_INCREMENTAL_LEARNING_HPP_INCLUDED
-#define SVD_INCREMENTAL_LEARNING_HPP_INCLUDED
+#ifndef __MLPACK_METHODS_AMF_SVD_INCOMPLETE_INCREMENTAL_LEARNING_HPP
+#define __MLPACK_METHODS_AMF_SVD_INCOMPLETE_INCREMENTAL_LEARNING_HPP
 
 namespace mlpack
 {
@@ -13,15 +13,25 @@ namespace amf
 {
 
 /**
- * This class computes SVD by method incomplete incremental batch learning. 
- * This procedure is described in the paper 'A Guide to singular Value Decomposition' 
- * by Chih-Chao Ma. Class implements 'Algorithm 2' given in the paper. 
- * Incremental learning modifies only some feature values in W and H after 
- * scanning part of the training data. Incomplete incremental learning approach 
- * is different from batch learning as each object feature vector Mj has an 
- * additional regularization coefficient, which is equal to the number of 
- * existing scores for object j. Therefore, an object with more scores has a 
- * larger regularization coefficient in this incremental learning approach.
+ * This class computes SVD using incomplete incremental batch learning, as
+ * described in the following paper:
+ *
+ * @code
+ * @techreport{ma2008guide,
+ *   title={A Guide to Singular Value Decomposition for Collaborative
+ *       Filtering},
+ *   author={Ma, Chih-Chao},
+ *   year={2008},
+ *   institution={Department of Computer Science, National Taiwan University}
+ * }
+ * @endcode
+ *
+ * This class implements 'Algorithm 2' as given in the paper.  Incremental
+ * learning modifies only some feature values in W and H after scanning part of
+ * the input matrix (V).  This differs from batch learning, which considers
+ * every element in V for each update of W and H.  The regularization technique
+ * is also different: in incomplete incremental learning, regularization takes
+ * into account the number of elements in a given column of V.
  *
  * @see SVDBatchLearning
  */
@@ -29,34 +39,32 @@ class SVDIncompleteIncrementalLearning
 {
  public:
   /**
-   * Empty constructor
+   * Initialize the parameters of SVDIncompleteIncrementalLearning.
    *
-   * @param u step value used in batch learning
-   * @param kw regularization constant for W matrix
-   * @param kh regularization constant for H matrix
+   * @param u Step value used in batch learning.
+   * @param kw Regularization constant for W matrix.
+   * @param kh Regularization constant for H matrix.
    */
   SVDIncompleteIncrementalLearning(double u = 0.001,
                                    double kw = 0,
                                    double kh = 0)
           : u(u), kw(kw), kh(kh)
-  {}
+  {
+    // Nothing to do.
+  }
 
   /**
-   * Initialize parameters before factorization.
-   * This function must be called before a new factorization.
+   * Initialize parameters before factorization.  This function must be called
+   * before a new factorization.  This simply sets the column being considered
+   * to 0, so the input matrix and rank are not used.
    *
    * @param dataset Input matrix to be factorized.
    * @param rank rank of factorization
    */
   template<typename MatType>
-  void Initialize(const MatType& dataset, const size_t rank)
+  void Initialize(const MatType& /* dataset */, const size_t /* rank */)
   {
-    (void)rank;
-  
-    n = dataset.n_rows;
-    m = dataset.n_cols;
-
-    // set the current user to 0
+    // Set the current user to 0.
     currentUserIndex = 0;
   }
 
@@ -74,29 +82,29 @@ class SVDIncompleteIncrementalLearning
                       arma::mat& W,
                       const arma::mat& H)
   {
-    arma::mat deltaW(n, W.n_cols);
-    deltaW.zeros();
+    arma::mat deltaW;
+    deltaW.zeros(V.n_rows, W.n_cols);
 
-    // iterate through all the rating by this user to update corresponding
-    // item feature feature vector
-    for(size_t i = 0;i < n;i++)
+    // Iterate through all the ratings by this user to update the
+    // corresponding item feature vectors.
+    for (size_t i = 0; i < V.n_rows; ++i)
     {
-      double val;
-      // update only if the rating is non-zero
-      if((val = V(i, currentUserIndex)) != 0)
+      const double val = V(i, currentUserIndex);
+      // Update only if the rating is non-zero.
+      if (val != 0)
         deltaW.row(i) += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
-                                         arma::trans(H.col(currentUserIndex));
-      // add regularization
-      if(kw != 0) deltaW.row(i) -= kw * W.row(i);
+            H.col(currentUserIndex).t();
+      // Add regularization.
+      if (kw != 0)
+        deltaW.row(i) -= kw * W.row(i);
     }
 
-    W += u*deltaW;
+    W += u * deltaW;
   }
 
   /**
-   * The update rule for the encoding matrix H.
-   * The function takes in all the matrices and only changes the
-   * value of the H matrix.
+   * The update rule for the encoding matrix H.  The function takes in all the
+   * matrices and only changes the value of the H matrix.
    *
    * @param V Input matrix to be factorized.
    * @param W Basis matrix.
@@ -107,41 +115,37 @@ class SVDIncompleteIncrementalLearning
                       const arma::mat& W,
                       arma::mat& H)
   {
-    arma::mat deltaH(H.n_rows, 1);
-    deltaH.zeros();
+    arma::vec deltaH;
+    deltaH.zeros(H.n_rows);
 
-    // iterate through all the rating by this user to update corresponding
-    // item feature feature vector
-    for(size_t i = 0;i < n;i++)
+    // Iterate through all the ratings by this user to update the
+    // corresponding item feature vectors.
+    for (size_t i = 0; i < V.n_rows; ++i)
     {
-      double val;
-      // update only if the rating is non-zero
-      if((val = V(i, currentUserIndex)) != 0)
+      const double val = V(i, currentUserIndex);
+      // Update only if the rating is non-zero.
+      if (val != 0)
         deltaH += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
-                                                    arma::trans(W.row(i));
+            W.row(i).t();
     }
-    // add regularization
-    if(kh != 0) deltaH -= kh * H.col(currentUserIndex);
+    // Add regularization.
+    if (kh != 0)
+      deltaH -= kh * H.col(currentUserIndex);
 
-    // update H matrix and move on to the next user
+    // Update H matrix and move on to the next user.
     H.col(currentUserIndex++) += u * deltaH;
-    currentUserIndex = currentUserIndex % m;
+    currentUserIndex = currentUserIndex % V.n_cols;
   }
 
  private:
-  //! step count of btach learning
+  //! Step size of batch learning.
   double u;
-  //! regularization parameter for W matrix
+  //! Regularization parameter for W matrix.
   double kw;
-  //! regularization parameter for H matrix
+  //! Regularization parameter for H matrix.
   double kh;
 
-  //! number of items
-  size_t n;
-  //! number of users
-  size_t m;
-
-  //! current user under consideration 
+  //! Current user under consideration.
   size_t currentUserIndex;
 };
 
@@ -155,7 +159,7 @@ inline void SVDIncompleteIncrementalLearning::
                                                           arma::mat& W,
                                                           const arma::mat& H)
 {
-  arma::mat deltaW(n, W.n_cols);
+  arma::mat deltaW(V.n_rows, W.n_cols);
   deltaW.zeros();
   for(arma::sp_mat::const_iterator it = V.begin_col(currentUserIndex);
                                       it != V.end_col(currentUserIndex);it++)
@@ -191,12 +195,10 @@ inline void SVDIncompleteIncrementalLearning::
   if(kh != 0) deltaH -= kh * H.col(currentUserIndex);
 
   H.col(currentUserIndex++) += u * deltaH;
-  currentUserIndex = currentUserIndex % m;
+  currentUserIndex = currentUserIndex % V.n_cols;
 }
 
-}; // namepsace amf
-}; // namespace mlpack
-
-
-#endif // SVD_INCREMENTAL_LEARNING_HPP_INCLUDED
+} // namespace amf
+} // namespace mlpack
 
+#endif
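
For the incomplete incremental variant, one sweep over a single user's column
looks roughly like the sketch below.  It is illustrative only: the class above
cycles currentUserIndex across calls and keeps the W and H updates in separate
functions.

#include <armadillo>

// One 'Algorithm 2' step for user (column) j: every nonzero rating by that
// user contributes to the accumulated steps before they are applied.
void UserColumnStep(const arma::mat& V, const arma::uword j,
                    arma::mat& W, arma::mat& H,
                    const double u, const double kw, const double kh)
{
  arma::mat deltaW(W.n_rows, W.n_cols, arma::fill::zeros);
  arma::vec deltaH(H.n_rows, arma::fill::zeros);

  for (arma::uword i = 0; i < V.n_rows; ++i)
  {
    const double val = V(i, j);
    if (val != 0)
    {
      const double e = val - arma::dot(W.row(i), H.col(j));
      deltaW.row(i) += e * H.col(j).t();
      deltaH += e * W.row(i).t();
    }
    // Regularization is applied to every item feature vector, matching the
    // WUpdate() code above.
    deltaW.row(i) -= kw * W.row(i);
  }
  deltaH -= kh * H.col(j);

  W += u * deltaW;
  H.col(j) += u * deltaH;
}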


