[mlpack-svn] r16776 - in mlpack/trunk/src/mlpack/methods/amf: termination_policies update_rules
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Jul 7 17:31:54 EDT 2014
Author: sumedhghaisas
Date: Mon Jul 7 17:31:54 2014
New Revision: 16776
Log:
* added SVD Incomplete incremental learning
* added Termination Policy wrapper for SVD Incomplete Learning
Added:
mlpack/trunk/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp
mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_incremental_learning.hpp
Modified:
mlpack/trunk/src/mlpack/methods/amf/termination_policies/CMakeLists.txt
mlpack/trunk/src/mlpack/methods/amf/update_rules/CMakeLists.txt
mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp
Modified: mlpack/trunk/src/mlpack/methods/amf/termination_policies/CMakeLists.txt
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/termination_policies/CMakeLists.txt (original)
+++ mlpack/trunk/src/mlpack/methods/amf/termination_policies/CMakeLists.txt Mon Jul 7 17:31:54 2014
@@ -4,6 +4,7 @@
simple_residue_termination.hpp
simple_tolerance_termination.hpp
validation_rmse_termination.hpp
+ incomplete_incremental_termination.hpp
)
# Add directory name to sources.
Added: mlpack/trunk/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/methods/amf/termination_policies/incomplete_incremental_termination.hpp Mon Jul 7 17:31:54 2014
@@ -0,0 +1,60 @@
+/**
+ * @file incomplete_incremental_termination.hpp
+ * @author Sumedh Ghaisas
+ */
+#ifndef _INCOMPLETE_INCREMENTAL_TERMINATION_HPP_INCLUDED
+#define _INCOMPLETE_INCREMENTAL_TERMINATION_HPP_INCLUDED
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace amf {
+
+/** Forwards Step() to a wrapped policy only once every V.n_rows calls. */
+template <class TerminationPolicy>
+class IncompleteIncrementalTermination
+{
+ public:
+  //! Wrap a copy of t_policy; counters hold safe values until Initialize().
+  IncompleteIncrementalTermination(TerminationPolicy t_policy = TerminationPolicy())
+      : t_policy(t_policy), incrementalIndex(1), iteration(0) {}
+
+  //! Initialize the wrapped policy with V and restart the iteration count.
+  template <class MatType>
+  void Initialize(const MatType& V)
+  {
+    t_policy.Initialize(V);
+    // NOTE(review): the period is V.n_rows; confirm this matches the number of
+    // updates the paired update rule needs for one full pass over V.
+    incrementalIndex = V.n_rows;
+    iteration = 0;
+  }
+
+  //! Convergence is decided entirely by the wrapped policy.
+  bool IsConverged() { return t_policy.IsConverged(); }
+
+  //! Count one update; forward W and H to the wrapped policy once per period.
+  void Step(const arma::mat& W, const arma::mat& H)
+  {
+    // A zero period (V had no rows) would otherwise be a modulo by zero.
+    if (incrementalIndex == 0 || iteration % incrementalIndex == 0)
+      t_policy.Step(W, H);
+    ++iteration;
+  }
+
+  //! Value reported by the wrapped policy's Index().
+  const double& Index() { return t_policy.Index(); }
+  //! Number of Step() calls since construction or the last Initialize().
+  const size_t& Iteration() { return iteration; }
+
+ private:
+  TerminationPolicy t_policy;   //!< Wrapped policy.
+  size_t incrementalIndex;      //!< Period, in Step() calls, between forwards.
+  size_t iteration;             //!< Step() calls counted so far.
+};
+
+}; // namespace amf
+}; // namespace mlpack
+
+#endif
+
Modified: mlpack/trunk/src/mlpack/methods/amf/update_rules/CMakeLists.txt
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/update_rules/CMakeLists.txt (original)
+++ mlpack/trunk/src/mlpack/methods/amf/update_rules/CMakeLists.txt Mon Jul 7 17:31:54 2014
@@ -4,6 +4,8 @@
nmf_als.hpp
nmf_mult_dist.hpp
nmf_mult_div.hpp
+ svd_batchlearning.hpp
+ svd_incremental_learning.hpp
)
# Add directory name to sources.
Modified: mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_batchlearning.hpp Mon Jul 7 17:31:54 2014
@@ -143,7 +143,8 @@
{
size_t row = it.row();
size_t col = it.col();
- deltaW.row(it.row()) += (*it - arma::dot(W.row(row), H.col(col))) * arma::trans(H.col(col));
+ deltaW.row(it.row()) += (*it - arma::dot(W.row(row), H.col(col))) *
+ arma::trans(H.col(col));
}
if(kw != 0) for(size_t i = 0; i < n; i++)
@@ -173,7 +174,8 @@
{
size_t row = it.row();
size_t col = it.col();
- deltaH.col(col) += (*it - arma::dot(W.row(row), H.col(col))) * arma::trans(W.row(row));
+ deltaH.col(col) += (*it - arma::dot(W.row(row), H.col(col))) *
+ arma::trans(W.row(row));
}
if(kh != 0) for(size_t j = 0; j < m; j++)
Added: mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_incremental_learning.hpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/methods/amf/update_rules/svd_incremental_learning.hpp Mon Jul 7 17:31:54 2014
@@ -0,0 +1,147 @@
+#ifndef SVD_INCREMENTAL_LEARNING_HPP_INCLUDED
+#define SVD_INCREMENTAL_LEARNING_HPP_INCLUDED
+
+namespace mlpack
+{
+namespace amf
+{
+/** Gradient update rule that processes one column of V per W/H update. */
+class SVDIncrementalLearning
+{
+ public:
+  /**
+   * @param u Step size applied to each accumulated update.
+   * @param kw Regularization weight for W; 0 disables it.
+   * @param kh Regularization weight for H; 0 disables it.
+   * @param min Stored but never read in this class. NOTE(review): -DBL_MIN
+   *     is a tiny negative number; -DBL_MAX may have been intended.
+   * @param max Stored but never read in this class.
+   */
+  SVDIncrementalLearning(double u = 0.001, double kw = 0, double kh = 0,
+                         double min = -DBL_MIN, double max = DBL_MAX)
+      : u(u), kw(kw), kh(kh), min(min), max(max) {}
+
+  /**
+   * Record the dimensions of dataset and restart at its first column.
+   * The rank argument is not used.
+   */
+  template<typename MatType>
+  void Initialize(const MatType& dataset, const size_t /* rank */)
+  {
+    n = dataset.n_rows;
+    m = dataset.n_cols;
+    currentUserIndex = 0;
+  }
+
+  /**
+   * The update rule for the basis matrix W. Only W is changed; the update
+   * is computed from the current column of V.
+   *
+   * @param V Input matrix to be factorized.
+   * @param W Basis matrix to be updated.
+   * @param H Encoding matrix.
+   */
+  template<typename MatType>
+  inline void WUpdate(const MatType& V, arma::mat& W, const arma::mat& H)
+  {
+    arma::mat deltaW(n, W.n_cols);
+    deltaW.zeros();
+    for (size_t i = 0; i < n; ++i)
+    {
+      const double val = V(i, currentUserIndex);
+      // Zero entries contribute no error term.
+      if (val != 0)
+        deltaW.row(i) += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
+            arma::trans(H.col(currentUserIndex));
+      // Regularization applies to row i of deltaW, not to the whole matrix.
+      if (kw != 0) deltaW.row(i) -= kw * W.row(i);
+    }
+    W += u * deltaW;
+  }
+
+  /**
+   * The update rule for the encoding matrix H. Only the current column of H
+   * is changed; afterwards the rule moves on to the next column, wrapping
+   * around after the last one.
+   *
+   * @param V Input matrix to be factorized.
+   * @param W Basis matrix.
+   * @param H Encoding matrix to be updated.
+   */
+  template<typename MatType>
+  inline void HUpdate(const MatType& V, const arma::mat& W, arma::mat& H)
+  {
+    arma::mat deltaH(H.n_rows, 1);
+    deltaH.zeros();
+    for (size_t i = 0; i < n; ++i)
+    {
+      const double val = V(i, currentUserIndex);
+      if (val != 0)
+        deltaH += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
+            arma::trans(W.row(i));
+    }
+    if (kh != 0) deltaH -= kh * H.col(currentUserIndex);
+    H.col(currentUserIndex++) += u * deltaH;
+    currentUserIndex = currentUserIndex % m;
+  }
+
+ private:
+  double u;    //!< Step size.
+  double kw;   //!< Regularization weight for W.
+  double kh;   //!< Regularization weight for H.
+  double min;  //!< Not read in this class.
+  double max;  //!< Not read in this class.
+  size_t n;    //!< Row count recorded by Initialize().
+  size_t m;    //!< Column count recorded by Initialize().
+  size_t currentUserIndex;  //!< Column of V used by the next update.
+};
+
+// WUpdate() for sparse V: visits only the stored entries of the current column.
+template<>
+inline void SVDIncrementalLearning::WUpdate<arma::sp_mat>(const arma::sp_mat& V,
+                                                          arma::mat& W,
+                                                          const arma::mat& H)
+{
+  arma::mat deltaW(n, W.n_cols);
+  deltaW.zeros();
+  for (arma::sp_mat::const_iterator it = V.begin_col(currentUserIndex);
+      it != V.end_col(currentUserIndex); ++it)
+  {
+    const size_t i = it.row();
+    deltaW.row(i) += (*it - arma::dot(W.row(i), H.col(currentUserIndex))) *
+        arma::trans(H.col(currentUserIndex));
+    // Regularization applies to row i of deltaW, not to the whole matrix;
+    // rows without a stored entry are never visited, so never regularized.
+    if (kw != 0) deltaW.row(i) -= kw * W.row(i);
+  }
+  W += u * deltaW;
+}
+
+// HUpdate() for sparse V: visits only the stored entries of the current column.
+template<>
+inline void SVDIncrementalLearning::HUpdate<arma::sp_mat>(const arma::sp_mat& V,
+                                                          const arma::mat& W,
+                                                          arma::mat& H)
+{
+  arma::mat deltaH(H.n_rows, 1);
+  deltaH.zeros();
+  for (arma::sp_mat::const_iterator it = V.begin_col(currentUserIndex);
+      it != V.end_col(currentUserIndex); ++it)
+  {
+    // *it is the value stored at (i, currentUserIndex).
+    const double val = *it;
+    const size_t i = it.row();
+    if (val != 0)
+      deltaH += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
+          arma::trans(W.row(i));
+  }
+  if (kh != 0) deltaH -= kh * H.col(currentUserIndex);
+  // Apply the step to the current column, then move on, wrapping at m.
+  H.col(currentUserIndex++) += u * deltaH;
+  currentUserIndex = currentUserIndex % m;
+}
+
+}; // namespace amf
+}; // namespace mlpack
+
+
+#endif // SVD_INCREMENTAL_LEARNING_HPP_INCLUDED
+
More information about the mlpack-svn
mailing list