[mlpack-svn] r16978 - in mlpack/trunk/src/mlpack: methods/cf tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Aug 6 17:38:07 EDT 2014
Author: sumedhghaisas
Date: Wed Aug 6 17:38:07 2014
New Revision: 16978
Log:
* minor changes
Modified:
mlpack/trunk/src/mlpack/methods/cf/cf.hpp
mlpack/trunk/src/mlpack/methods/cf/cf_impl.hpp
mlpack/trunk/src/mlpack/methods/cf/plain_svd.cpp
mlpack/trunk/src/mlpack/methods/cf/plain_svd.hpp
mlpack/trunk/src/mlpack/tests/plain_svd_test.cpp
Modified: mlpack/trunk/src/mlpack/methods/cf/cf.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/cf/cf.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/cf/cf.hpp Wed Aug 6 17:38:07 2014
@@ -1,6 +1,7 @@
/**
* @file cf.hpp
* @author Mudit Raj Gupta
+ * @author Sumedh Ghaisas
*
* Collaborative filtering.
*
Modified: mlpack/trunk/src/mlpack/methods/cf/cf_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/cf/cf_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/cf/cf_impl.hpp Wed Aug 6 17:38:07 2014
@@ -1,6 +1,7 @@
/**
* @file cf.cpp
* @author Mudit Raj Gupta
+ * @author Sumedh Ghaisas
*
* Collaborative Filtering.
*
Modified: mlpack/trunk/src/mlpack/methods/cf/plain_svd.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/cf/plain_svd.cpp (original)
+++ mlpack/trunk/src/mlpack/methods/cf/plain_svd.cpp Wed Aug 6 17:38:07 2014
@@ -1,3 +1,9 @@
+/**
+ * @file plain_svd.cpp
+ * @author Sumedh Ghaisas
+ *
+ * Implementation of the wrapper class for Armadillo's SVD.
+ */
#include "plain_svd.hpp"
using namespace mlpack;
@@ -8,9 +14,11 @@
arma::mat& sigma,
arma::mat& H) const
{
+ // get svd factorization
arma::vec E;
arma::svd(W, E, H, V);
+ // construct sigma matrix
sigma.zeros(V.n_rows, V.n_cols);
for(size_t i = 0;i < sigma.n_rows && i < sigma.n_cols;i++)
@@ -18,6 +26,7 @@
arma::mat V_rec = W * sigma * arma::trans(H);
+ // return normalized frobenius error
return arma::norm(V - V_rec, "fro") / arma::norm(V, "fro");
}
@@ -26,6 +35,7 @@
arma::mat& W,
arma::mat& H) const
{
+ // check if the given rank is valid
if(r > V.n_rows || r > V.n_cols)
{
Log::Info << "Rank " << r << ", given for decomposition is invalid." << std::endl;
@@ -33,19 +43,27 @@
Log::Info << "Setting decomposition rank to " << r << std::endl;
}
+ // get svd factorization
arma::vec sigma;
arma::svd(W, sigma, H, V);
+ // remove the part of W and H depending upon the value of rank
W = W.submat(0, 0, W.n_rows - 1, r - 1);
H = H.submat(0, 0, H.n_cols - 1, r - 1);
+ // keep only the first r singular values
sigma = sigma.subvec(0, r - 1);
-
+
+ // fold the singular values into W; they could equally have been
+ // folded into H instead
W = W * arma::diagmat(sigma);
+ // take transpose of the matrix H as required by CF module
H = arma::trans(H);
+ // reconstruct the matrix
arma::mat V_rec = W * H;
+ // return the normalized frobenius norm
return arma::norm(V - V_rec, "fro") / arma::norm(V, "fro");
}
Modified: mlpack/trunk/src/mlpack/methods/cf/plain_svd.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/cf/plain_svd.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/cf/plain_svd.hpp Wed Aug 6 17:38:07 2014
@@ -1,3 +1,9 @@
+/**
+ * @file plain_svd.hpp
+ * @author Sumedh Ghaisas
+ *
+ * Wrapper class for Armadillo's SVD.
+ */
#ifndef __MLPACK_METHODS_PLAIN_SVD_HPP
#define __MLPACK_METHODS_PLAIN_SVD_HPP
@@ -8,23 +14,50 @@
namespace svd
{
+/**
+ * This class acts as a wrapper class for Armadillo's SVD implementation to be
+ * used by the Collaborative Filtering module.
+ *
+ * @see CF
+ */
class PlainSVD
{
public:
+ // empty constructor
PlainSVD() {};
+ /**
+ * Factorizer function which computes the SVD of the given matrix and returns
+ * the normalized Frobenius norm of the reconstruction error,
+ * i.e. ||V - W * sigma * trans(H)||_F / ||V||_F.
+ *
+ * @param V input matrix
+ * @param W matrix of left singular vectors
+ * @param sigma matrix holding the singular values on its diagonal
+ * @param H matrix of right singular vectors
+ *
+ * @note V = W * sigma * arma::trans(H)
+ */
double Apply(const arma::mat& V,
arma::mat& W,
arma::mat& sigma,
arma::mat& H) const;
-
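+ /**
+ * Factorizer function which computes a rank-r truncated SVD and returns
+ * matrices in the form required by the CF module, along with the normalized
+ * Frobenius norm of the reconstruction error, ||V - W * H||_F / ||V||_F.
+ *
+ * @param V input matrix
+ * @param r rank of the decomposition; if r exceeds either dimension of V it
+ *     is reset (and a message is logged)
+ * @param W left singular vectors (first r columns) scaled by the singular
+ *     values
+ * @param H transpose of the right singular vectors (first r columns)
+ *
+ * @note V is approximated by W * H (exact only when no rank is dropped)
+ */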
double Apply(const arma::mat& V,
size_t r,
arma::mat& W,
arma::mat& H) const;
-};
+}; // class PlainSVD
-};
-};
+}; // namespace svd
+}; // namespace mlpack
#endif
Modified: mlpack/trunk/src/mlpack/tests/plain_svd_test.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/tests/plain_svd_test.cpp (original)
+++ mlpack/trunk/src/mlpack/tests/plain_svd_test.cpp Wed Aug 6 17:38:07 2014
@@ -38,6 +38,7 @@
mat W_t = randu<mat>(30, 3);
mat H_t = randu<mat>(3, 40);
+ // create a low-rank matrix (rank at most 3)
mat test = W_t * H_t;
PlainSVD svd;
More information about the mlpack-svn
mailing list