[mlpack-svn] r16776 - in mlpack/trunk/src/mlpack/methods/amf: termination_policies update_rules

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Jul 7 17:31:54 EDT 2014


Author: sumedhghaisas
Date: Mon Jul  7 17:31:54 2014
New Revision: 16776

Log:
* added SVD Incomplete incremental learning
* added Termination Policy wrapper for SVD Incomplete Learning


Added:
   mlpack/trunk/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp
   mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_incremental_learning.hpp
Modified:
   mlpack/trunk/src/mlpack/methods/amf/termination_policies/CMakeLists.txt
   mlpack/trunk/src/mlpack/methods/amf/update_rules/CMakeLists.txt
   mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp

Modified: mlpack/trunk/src/mlpack/methods/amf/termination_policies/CMakeLists.txt
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/termination_policies/CMakeLists.txt	(original)
+++ mlpack/trunk/src/mlpack/methods/amf/termination_policies/CMakeLists.txt	Mon Jul  7 17:31:54 2014
@@ -4,6 +4,7 @@
   simple_residue_termination.hpp
   simple_tolerance_termination.hpp
   validation_rmse_termination.hpp
+  incomplete_incremental_termination.hpp
 )
 
 # Add directory name to sources.

Added: mlpack/trunk/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp	Mon Jul  7 17:31:54 2014
@@ -0,0 +1,60 @@
+/**
+ * @file incomplete_incremental_termination.hpp
+ * @author Sumedh Ghaisas
+ */
+#ifndef _INCOMPLETE_INCREMENTAL_TERMINATION_HPP_INCLUDED
+#define _INCOMPLETE_INCREMENTAL_TERMINATION_HPP_INCLUDED
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace amf {
+
+/**
+ * Wrapper that adapts a termination policy to incremental ("incomplete")
+ * learning.  Step() is presumably called once per update-rule iteration;
+ * this wrapper forwards only every incrementalIndex-th call to the wrapped
+ * policy's Step().  Initialize() sets incrementalIndex to V.n_rows.
+ *
+ * NOTE(review): SVDIncrementalLearning::HUpdate advances one column of V per
+ * call and wraps at V.n_cols, so one full pass over V takes n_cols
+ * iterations.  Confirm whether n_rows (used below) or n_cols is the
+ * intended period.
+ *
+ * @tparam TerminationPolicy Wrapped policy; it must provide Initialize(V),
+ *     IsConverged(), Step(W, H) and Index().
+ */
+template <class TerminationPolicy>
+class IncompleteIncrementalTermination
+{
+ public:
+  /**
+   * Construct the wrapper around a copy of the given policy.  The counters
+   * get safe values here so that Step() is well-defined even before
+   * Initialize() is called.
+   *
+   * @param t_policy Termination policy to wrap.
+   */
+  IncompleteIncrementalTermination(TerminationPolicy t_policy = TerminationPolicy())
+            : t_policy(t_policy), incrementalIndex(1), iteration(0) {}
+
+  /**
+   * Initialize the wrapped policy and reset the iteration counter.
+   *
+   * @param V Input matrix being factorized.
+   */
+  template <class MatType>
+  void Initialize(const MatType& V)
+  {
+    t_policy.Initialize(V);
+
+    // Fall back to 1 for a matrix with no rows so Step() never computes x % 0.
+    incrementalIndex = (V.n_rows > 0) ? size_t(V.n_rows) : size_t(1);
+    iteration = 0;
+  }
+
+  /** @return Whether the wrapped policy reports convergence. */
+  bool IsConverged()
+  {
+    return t_policy.IsConverged();
+  }
+
+  /**
+   * Record one iteration.  The wrapped policy's Step() is called on every
+   * incrementalIndex-th call, starting with the first one.
+   *
+   * @param W Current basis matrix.
+   * @param H Current encoding matrix.
+   */
+  void Step(const arma::mat& W, const arma::mat& H)
+  {
+    if (iteration % incrementalIndex == 0)
+      t_policy.Step(W, H);
+    iteration++;
+  }
+
+  /**
+   * @return The wrapped policy's Index() value.
+   *
+   * NOTE(review): this returns a reference; if TerminationPolicy::Index()
+   * returns by value, the reference would dangle.  Confirm for every
+   * wrapped policy.
+   */
+  const double& Index()
+  {
+    return t_policy.Index();
+  }
+
+  /** @return Number of Step() calls on this wrapper since Initialize(). */
+  const size_t& Iteration()
+  {
+    return iteration;
+  }
+
+ private:
+  //! The wrapped termination policy.
+  TerminationPolicy t_policy;
+
+  //! Number of Step() calls per forwarded call to t_policy.Step().
+  size_t incrementalIndex;
+  //! Number of Step() calls since construction or the last Initialize().
+  size_t iteration;
+};
+
+}; // namespace amf
+}; // namespace mlpack
+
+#endif
+

Modified: mlpack/trunk/src/mlpack/methods/amf/update_rules/CMakeLists.txt
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/update_rules/CMakeLists.txt	(original)
+++ mlpack/trunk/src/mlpack/methods/amf/update_rules/CMakeLists.txt	Mon Jul  7 17:31:54 2014
@@ -4,6 +4,8 @@
   nmf_als.hpp
   nmf_mult_dist.hpp
   nmf_mult_div.hpp
+  svd_batchlearning.hpp
+  svd_incremental_learning.hpp
 )
 
 # Add directory name to sources.

Modified: mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp	Mon Jul  7 17:31:54 2014
@@ -143,7 +143,8 @@
   {
     size_t row = it.row();
     size_t col = it.col();
-    deltaW.row(it.row()) += (*it - arma::dot(W.row(row), H.col(col))) * arma::trans(H.col(col));
+    deltaW.row(it.row()) += (*it - arma::dot(W.row(row), H.col(col))) * 
+                                                  arma::trans(H.col(col));
   }
 
   if(kw != 0) for(size_t i = 0; i < n; i++)
@@ -173,7 +174,8 @@
   {
     size_t row = it.row();
     size_t col = it.col();
-    deltaH.col(col) += (*it - arma::dot(W.row(row), H.col(col))) * arma::trans(W.row(row));
+    deltaH.col(col) += (*it - arma::dot(W.row(row), H.col(col))) * 
+                                                arma::trans(W.row(row));
   }
 
   if(kh != 0) for(size_t j = 0; j < m; j++)

Added: mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_incremental_learning.hpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_incremental_learning.hpp	Mon Jul  7 17:31:54 2014
@@ -0,0 +1,147 @@
+#ifndef SVD_INCREMENTAL_LEARNING_HPP_INCLUDED
+#define SVD_INCREMENTAL_LEARNING_HPP_INCLUDED
+
+#include <mlpack/core.hpp>
+
+namespace mlpack
+{
+namespace amf
+{
+
+/**
+ * Incremental learning update rules for an SVD-style factorization
+ * V ~ W * H.  Each WUpdate()/HUpdate() pair uses a single column of V, the
+ * one with index currentUserIndex.  HUpdate() then moves on to the next
+ * column, wrapping back to column 0 after the last one.  For dense input,
+ * zero entries of V are skipped in the data term.  For sparse input, only
+ * stored entries are visited.
+ */
+class SVDIncrementalLearning
+{
+ public:
+  /**
+   * Create the update rule.
+   *
+   * @param u Step size applied to each update.
+   * @param kw Regularization weight for W; 0 disables it.
+   * @param kh Regularization weight for H; 0 disables it.
+   * @param min Lower value bound; stored but not used by this class.
+   * @param max Upper value bound; stored but not used by this class.
+   */
+  SVDIncrementalLearning(double u = 0.001,
+                         double kw = 0,
+                         double kh = 0,
+                         // DBL_MIN is the smallest *positive* double, so
+                         // -DBL_MIN is ~0; -DBL_MAX is the most negative.
+                         double min = -DBL_MAX,
+                         double max = DBL_MAX)
+        : u(u), kw(kw), kh(kh), min(min), max(max),
+          n(0), m(0), currentUserIndex(0)
+    {}
+
+  /**
+   * Record the dataset dimensions and restart at column 0.
+   *
+   * @param dataset Input matrix to be factorized.
+   * @param rank Rank of the factorization (not used by this rule).
+   */
+  template<typename MatType>
+  void Initialize(const MatType& dataset, const size_t rank)
+  {
+    (void) rank;
+    n = dataset.n_rows;
+    m = dataset.n_cols;
+
+    currentUserIndex = 0;
+  }
+
+  /**
+   * The update rule for the basis matrix W, using only the current column of
+   * V.  Only W is modified.
+   *
+   * @param V Input matrix to be factorized.
+   * @param W Basis matrix to be updated.
+   * @param H Encoding matrix.
+   */
+  template<typename MatType>
+  inline void WUpdate(const MatType& V,
+                      arma::mat& W,
+                      const arma::mat& H)
+  {
+    arma::mat deltaW(n, W.n_cols);
+    deltaW.zeros();
+    for (size_t i = 0; i < n; i++)
+    {
+      const double val = V(i, currentUserIndex);
+      if (val != 0)
+        deltaW.row(i) += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
+            arma::trans(H.col(currentUserIndex));
+      // Regularize row i only.  Subtracting the 1 x r row W.row(i) from the
+      // whole n x r deltaW is a dimension mismatch whenever n > 1.
+      if (kw != 0)
+        deltaW.row(i) -= kw * W.row(i);
+    }
+
+    W += u * deltaW;
+  }
+
+  /**
+   * The update rule for the encoding matrix H.  Only the current column of
+   * H is modified, and the rule then advances to the next column of V
+   * (wrapping after the last one).
+   *
+   * @param V Input matrix to be factorized.
+   * @param W Basis matrix.
+   * @param H Encoding matrix to be updated.
+   */
+  template<typename MatType>
+  inline void HUpdate(const MatType& V,
+                      const arma::mat& W,
+                      arma::mat& H)
+  {
+    arma::mat deltaH(H.n_rows, 1);
+    deltaH.zeros();
+
+    for (size_t i = 0; i < n; i++)
+    {
+      const double val = V(i, currentUserIndex);
+      if (val != 0)
+        deltaH += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
+            arma::trans(W.row(i));
+    }
+    if (kh != 0)
+      deltaH -= kh * H.col(currentUserIndex);
+
+    H.col(currentUserIndex) += u * deltaH;
+    currentUserIndex = (currentUserIndex + 1) % m;
+  }
+
+ private:
+  //! Step size.
+  double u;
+  //! Regularization weight for W.
+  double kw;
+  //! Regularization weight for H.
+  double kh;
+  //! Lower value bound (currently unused).
+  double min;
+  //! Upper value bound (currently unused).
+  double max;
+
+  //! Number of rows of the dataset.
+  size_t n;
+  //! Number of columns of the dataset.
+  size_t m;
+
+  //! Index of the column of V processed by the next update.
+  size_t currentUserIndex;
+};
+
+/**
+ * Sparse specialization of WUpdate(): visits only the stored entries of the
+ * current column of V.
+ */
+template<>
+inline void SVDIncrementalLearning::WUpdate<arma::sp_mat>(const arma::sp_mat& V,
+                                                          arma::mat& W,
+                                                          const arma::mat& H)
+{
+  arma::mat deltaW(n, W.n_cols);
+  deltaW.zeros();
+  for (arma::sp_mat::const_iterator it = V.begin_col(currentUserIndex);
+       it != V.end_col(currentUserIndex); ++it)
+  {
+    const double val = *it;
+    const size_t i = it.row();
+    deltaW.row(i) += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
+        arma::trans(H.col(currentUserIndex));
+    // Regularize row i only; subtracting a row from all of deltaW is a
+    // dimension mismatch.
+    // NOTE(review): unlike the dense overload, rows with no stored entry in
+    // this column are not regularized here; confirm this is intended.
+    if (kw != 0)
+      deltaW.row(i) -= kw * W.row(i);
+  }
+
+  W += u * deltaW;
+}
+
+/**
+ * Sparse specialization of HUpdate(): visits only the stored entries of the
+ * current column of V, then advances to the next column.
+ */
+template<>
+inline void SVDIncrementalLearning::HUpdate<arma::sp_mat>(const arma::sp_mat& V,
+                                                          const arma::mat& W,
+                                                          arma::mat& H)
+{
+  arma::mat deltaH(H.n_rows, 1);
+  deltaH.zeros();
+
+  for (arma::sp_mat::const_iterator it = V.begin_col(currentUserIndex);
+       it != V.end_col(currentUserIndex); ++it)
+  {
+    // The iterator already yields V(row, currentUserIndex); re-reading it
+    // through V(i, j) would cost an extra sparse element lookup.
+    const double val = *it;
+    const size_t i = it.row();
+    if (val != 0)
+      deltaH += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
+          arma::trans(W.row(i));
+  }
+  if (kh != 0)
+    deltaH -= kh * H.col(currentUserIndex);
+
+  H.col(currentUserIndex) += u * deltaH;
+  currentUserIndex = (currentUserIndex + 1) % m;
+}
+
+} // namespace amf
+} // namespace mlpack
+
+
+#endif // SVD_INCREMENTAL_LEARNING_HPP_INCLUDED
+



More information about the mlpack-svn mailing list