[mlpack-git] master: * added plain SVD factorization - wrapper of arma::svd for CF module (27a79dd)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:55:48 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit 27a79dd52629b3a8b2d13c5b3d71707df330942c
Author: sumedhghaisas <sumedhghaisas at gmail.com>
Date:   Mon Aug 4 21:43:23 2014 +0000

    * added plain SVD factorization - wrapper of arma::svd for CF module


>---------------------------------------------------------------

27a79dd52629b3a8b2d13c5b3d71707df330942c
 src/mlpack/core/arma_extend/SpMat_extra_bones.hpp  | 15 +++++
 src/mlpack/core/arma_extend/SpMat_extra_meat.hpp   | 28 ++++++++
 .../complete_incremental_termination.hpp           | 13 +++-
 src/mlpack/methods/cf/CMakeLists.txt               |  2 +
 src/mlpack/methods/cf/plain_svd.cpp                | 75 ++++++++++++++++++++++
 src/mlpack/methods/cf/plain_svd.hpp                | 30 +++++++++
 6 files changed, 161 insertions(+), 2 deletions(-)

diff --git a/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp b/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp
index 3214aa1..d9a9709 100644
--- a/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp
+++ b/src/mlpack/core/arma_extend/SpMat_extra_bones.hpp
@@ -17,3 +17,18 @@ template<typename T1, typename T2> inline SpMat(
     const uword n_cols,
     const bool sort_locations = true);
 #endif
+
+/*
+ * Extra members for SpMat<eT>: row_col_iterator aliases the plain iterator
+ * types so that SpMat<eT> offers the same name as Mat<eT>::row_col_iterator.
+ */
+typedef iterator row_col_iterator;
+typedef const_iterator const_row_col_iterator;
+
+// begin_row_col(): start of a row_col_iterator traversal (const / non-const)
+inline const_row_col_iterator begin_row_col() const;
+inline row_col_iterator begin_row_col();
+
+// end_row_col(): end of a row_col_iterator traversal (const / non-const)
+inline const_row_col_iterator end_row_col() const;
+inline row_col_iterator end_row_col();
diff --git a/src/mlpack/core/arma_extend/SpMat_extra_meat.hpp b/src/mlpack/core/arma_extend/SpMat_extra_meat.hpp
index 58134e8..7c5558e 100644
--- a/src/mlpack/core/arma_extend/SpMat_extra_meat.hpp
+++ b/src/mlpack/core/arma_extend/SpMat_extra_meat.hpp
@@ -249,3 +249,31 @@ SpMat<eT>::SpMat(const Base<uword,T1>& locations_expr, const Base<eT,T2>& vals_e
   }
 
 #endif
+
+//! begin_row_col(), const overload: forwards to begin()
+template<typename eT>
+inline typename SpMat<eT>::const_row_col_iterator SpMat<eT>::begin_row_col() const
+  {
+  return this->begin();
+  }
+  
+//! begin_row_col(), non-const overload: forwards to begin()
+template<typename eT>
+inline typename SpMat<eT>::row_col_iterator SpMat<eT>::begin_row_col()
+  {
+  return this->begin();
+  }
+  
+//! end_row_col(), const overload: forwards to end()
+template<typename eT>
+inline typename SpMat<eT>::const_row_col_iterator SpMat<eT>::end_row_col() const
+  {
+  return this->end();
+  }
+  
+//! end_row_col(), non-const overload: forwards to end()
+template<typename eT>
+inline typename SpMat<eT>::row_col_iterator SpMat<eT>::end_row_col()
+  {
+  return this->end();
+  }
diff --git a/src/mlpack/methods/amf/termination_policies/complete_incremental_termination.hpp b/src/mlpack/methods/amf/termination_policies/complete_incremental_termination.hpp
index 8b45186..7e28d0a 100644
--- a/src/mlpack/methods/amf/termination_policies/complete_incremental_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/complete_incremental_termination.hpp
@@ -1,5 +1,14 @@
-#ifndef COMPLETE_INCREMENTAL_TERMINATION_HPP_INCLUDED
-#define COMPLETE_INCREMENTAL_TERMINATION_HPP_INCLUDED
+/**
+ * @file complete_incremental_termination.hpp
+ * @author Sumedh Ghaisas
+ *
+ * Complete incremental termination.
+ *
+ * Termination policy for the AMF module; this header lives under
+ * src/mlpack/methods/amf/termination_policies/.
+ */
+#ifndef _MLPACK_METHODS_AMF_COMPLETE_INCREMENTAL_TERMINATION_HPP_INCLUDED
+#define _MLPACK_METHODS_AMF_COMPLETE_INCREMENTAL_TERMINATION_HPP_INCLUDED
 
 namespace mlpack
 {
diff --git a/src/mlpack/methods/cf/CMakeLists.txt b/src/mlpack/methods/cf/CMakeLists.txt
index 6413af4..840dbc9 100644
--- a/src/mlpack/methods/cf/CMakeLists.txt
+++ b/src/mlpack/methods/cf/CMakeLists.txt
@@ -3,6 +3,8 @@
 set(SOURCES
   cf.hpp
   cf_impl.hpp
+  plain_svd.hpp
+  plain_svd.cpp
 )
 
 # Add directory name to sources.
diff --git a/src/mlpack/methods/cf/plain_svd.cpp b/src/mlpack/methods/cf/plain_svd.cpp
new file mode 100644
index 0000000..d52bbcd
--- /dev/null
+++ b/src/mlpack/methods/cf/plain_svd.cpp
@@ -0,0 +1,75 @@
+#include "plain_svd.hpp"
+
+using namespace mlpack;
+using namespace mlpack::svd;
+
+/**
+ * Full singular value decomposition of V, so that V = W * sigma * trans(H).
+ *
+ * @param V Matrix to decompose.
+ * @param W Receives the left singular vectors.
+ * @param sigma Resized to the shape of V; singular values on its diagonal.
+ * @param H Receives the right singular vectors.
+ * @return Root mean squared error between V and W * sigma * trans(H).
+ */
+double PlainSVD::Apply(const arma::mat& V,
+                       arma::mat& W,
+                       arma::mat& sigma,
+                       arma::mat& H) const
+{
+  // arma::svd() signals failure through its return value; stop here instead
+  // of indexing into a singular value vector that was never filled.
+  arma::vec E;
+  if (!arma::svd(W, E, H, V))
+    Log::Fatal << "PlainSVD::Apply(): arma::svd() failed." << std::endl;
+
+  // Place the singular values on the diagonal of a matrix shaped like V.
+  sigma.zeros(V.n_rows, V.n_cols);
+  for (size_t i = 0; i < sigma.n_rows && i < sigma.n_cols; ++i)
+    sigma(i, i) = E(i);
+
+  const arma::mat V_rec = W * sigma * arma::trans(H);
+
+  // RMSE over all entries, computed on the whole matrix at once.
+  return sqrt(arma::accu(arma::square(V - V_rec)) / V.n_elem);
+}
+
+/**
+ * Rank-r factorization V ~= W * trans(H) obtained from the SVD of V.
+ *
+ * @param V Matrix to factorize.
+ * @param r Requested rank; 0 or a value above min(V.n_rows, V.n_cols) is
+ *     replaced by min(V.n_rows, V.n_cols).
+ * @param W Receives the first r left singular vectors, each scaled by its
+ *     singular value.
+ * @param H Receives the first r right singular vectors.
+ * @return Root mean squared error between V and W * trans(H).
+ */
+double PlainSVD::Apply(const arma::mat& V,
+                       size_t r,
+                       arma::mat& W,
+                       arma::mat& H) const
+{
+  // r == 0 counts as invalid too: `r - 1` below would wrap around for it.
+  if (r == 0 || r > V.n_rows || r > V.n_cols)
+  {
+    Log::Info << "Rank " << r << ", given for decomposition is invalid." << std::endl;
+    r = (V.n_rows > V.n_cols) ? V.n_cols : V.n_rows;
+    Log::Info << "Setting decomposition rank to " << r << std::endl;
+  }
+
+  // arma::svd() signals failure through its return value.
+  arma::vec sigma;
+  if (!arma::svd(W, sigma, H, V))
+    Log::Fatal << "PlainSVD::Apply(): arma::svd() failed." << std::endl;
+
+  // Keep the leading r singular triplets and fold the values into W.
+  W = W.cols(0, r - 1);
+  H = H.cols(0, r - 1);
+  sigma = sigma.subvec(0, r - 1);
+  W = W * arma::diagmat(sigma);
+
+  const arma::mat V_rec = W * arma::trans(H);
+  // RMSE over all entries, computed on the whole matrix at once.
+  return sqrt(arma::accu(arma::square(V - V_rec)) / V.n_elem);
+}
diff --git a/src/mlpack/methods/cf/plain_svd.hpp b/src/mlpack/methods/cf/plain_svd.hpp
new file mode 100644
index 0000000..facd5ca
--- /dev/null
+++ b/src/mlpack/methods/cf/plain_svd.hpp
@@ -0,0 +1,30 @@
+#ifndef __MLPACK_METHODS_PLAIN_SVD_HPP
+#define __MLPACK_METHODS_PLAIN_SVD_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack
+{
+namespace svd
+{
+
+//! Thin wrapper around arma::svd(); both overloads return the RMSE of the fit.
+class PlainSVD
+{
+ public:
+  PlainSVD() { }
+
+  //! Full SVD, V = W * sigma * trans(H); sigma is resized to the shape of V.
+  double Apply(const arma::mat& V, arma::mat& W, arma::mat& sigma,
+               arma::mat& H) const;
+
+  //! Rank-r factorization, V ~= W * trans(H); the singular values are folded
+  //! into W.  A rank above min(V.n_rows, V.n_cols) is lowered to that bound.
+  double Apply(const arma::mat& V, size_t r, arma::mat& W,
+               arma::mat& H) const;
+};
+
+} // namespace svd
+} // namespace mlpack
+
+#endif



More information about the mlpack-git mailing list