[mlpack-git] master: * changed PlainSVD to SVDwrapper * SVDwrapper is templatized to support other SVD factorizers * Added simple typedefs for simple API (ed80898)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:57:28 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit ed808983957a436d4b9fa7d50535ec407c68572f
Author: sumedhghaisas <sumedhghaisas at gmail.com>
Date:   Mon Aug 11 17:43:24 2014 +0000

    * changed PlainSVD to SVDwrapper
    * SVDwrapper is templatized to support other SVD factorizers
    * Added simple typedefs for simple API


>---------------------------------------------------------------

ed808983957a436d4b9fa7d50535ec407c68572f
 src/mlpack/core.hpp                                |   5 +
 src/mlpack/methods/amf/amf.hpp                     |  62 ++++++++++
 .../simple_tolerance_termination.hpp               |   2 +-
 .../amf/update_rules/svd_batch_learning.hpp        |  51 +++++++-
 .../svd_complete_incremental_learning.hpp          |  12 +-
 src/mlpack/methods/cf/CMakeLists.txt               |   4 +-
 src/mlpack/methods/cf/cf.hpp                       |   6 +-
 src/mlpack/methods/cf/plain_svd.cpp                |  69 -----------
 .../methods/cf/{plain_svd.hpp => svd_wrapper.hpp}  |  32 ++++--
 src/mlpack/methods/cf/svd_wrapper_impl.hpp         | 128 +++++++++++++++++++++
 src/mlpack/tests/CMakeLists.txt                    |   2 +-
 .../{plain_svd_test.cpp => armadillo_svd_test.cpp} |  18 +--
 12 files changed, 285 insertions(+), 106 deletions(-)

diff --git a/src/mlpack/core.hpp b/src/mlpack/core.hpp
index de7fd77..5a63f12 100644
--- a/src/mlpack/core.hpp
+++ b/src/mlpack/core.hpp
@@ -172,6 +172,11 @@
 #include <mlpack/core/kernels/spherical_kernel.hpp>
 #include <mlpack/core/kernels/triangular_kernel.hpp>
 
+// Mirror Armadillo's C++11 detection so mlpack code can test MLPACK_USE_CXX11.
+#ifdef ARMA_USE_CXX11
+  #define MLPACK_USE_CXX11
+#endif
+
 #endif
 
 // Clean up unfortunate Windows preprocessor definitions, even if this file was
diff --git a/src/mlpack/methods/amf/amf.hpp b/src/mlpack/methods/amf/amf.hpp
index 6f7e91e..ba6e96a 100644
--- a/src/mlpack/methods/amf/amf.hpp
+++ b/src/mlpack/methods/amf/amf.hpp
@@ -12,9 +12,17 @@
 #define __MLPACK_METHODS_AMF_AMF_HPP
 
 #include <mlpack/core.hpp>
+
 #include <mlpack/methods/amf/update_rules/nmf_mult_dist.hpp>
+#include <mlpack/methods/amf/update_rules/nmf_als.hpp>
+#include <mlpack/methods/amf/update_rules/svd_batch_learning.hpp>
+#include <mlpack/methods/amf/update_rules/svd_incomplete_incremental_learning.hpp>
+#include <mlpack/methods/amf/update_rules/svd_complete_incremental_learning.hpp>
+
 #include <mlpack/methods/amf/init_rules/random_init.hpp>
+
 #include <mlpack/methods/amf/termination_policies/simple_residue_termination.hpp>
+#include <mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp>
 
 namespace mlpack {
 namespace amf {
@@ -122,6 +130,60 @@ class AMF
   UpdateRuleType update;
 }; // class AMF
 
+//! AMF with NMF alternating least squares updates; CF's default factorizer.
+typedef amf::AMF<amf::SimpleResidueTermination,
+                 amf::RandomInitialization,
+                 amf::NMFALSUpdate> NMFALSFactorizer;
+
+//! SVD-based factorizers: alias templates over the matrix type under C++11,
+//! otherwise fixed arma::mat and Sparse-prefixed arma::sp_mat typedefs.
+#ifdef MLPACK_USE_CXX11
+template<class MatType>
+using SVDBatchFactorizer = amf::AMF<amf::SimpleToleranceTermination<MatType>,
+                                    amf::RandomInitialization,
+                                    amf::SVDBatchLearning>;
+
+template<class MatType>
+using SVDIncompleteIncrementalFactorizer = amf::AMF<
+    amf::SimpleToleranceTermination<MatType>,
+    amf::RandomInitialization,
+    amf::SVDIncompleteIncrementalLearning>;
+
+template<class MatType>
+using SVDCompleteIncrementalFactorizer = amf::AMF<
+    amf::SimpleToleranceTermination<MatType>,
+    amf::RandomInitialization,
+    amf::SVDCompleteIncrementalLearning<MatType> >;
+#else
+typedef amf::AMF<amf::SimpleToleranceTermination<arma::sp_mat>,
+                 amf::RandomInitialization,
+                 amf::SVDBatchLearning> SparseSVDBatchFactorizer;
+typedef amf::AMF<amf::SimpleToleranceTermination<arma::mat>,
+                 amf::RandomInitialization,
+                 amf::SVDBatchLearning> SVDBatchFactorizer;
+
+typedef amf::AMF<amf::SimpleToleranceTermination<arma::sp_mat>,
+                 amf::RandomInitialization,
+                 amf::SVDIncompleteIncrementalLearning>
+        SparseSVDIncompleteIncrementalFactorizer;
+
+typedef amf::AMF<amf::SimpleToleranceTermination<arma::mat>,
+                 amf::RandomInitialization,
+                 amf::SVDIncompleteIncrementalLearning>
+        SVDIncompleteIncrementalFactorizer;
+
+typedef amf::AMF<amf::SimpleToleranceTermination<arma::sp_mat>,
+                 amf::RandomInitialization,
+                 amf::SVDCompleteIncrementalLearning<arma::sp_mat> >
+        SparseSVDCompleteIncrementalFactorizer;
+
+typedef amf::AMF<amf::SimpleToleranceTermination<arma::mat>,
+                 amf::RandomInitialization,
+                 amf::SVDCompleteIncrementalLearning<arma::mat> >
+        SVDCompleteIncrementalFactorizer;
+#endif // MLPACK_USE_CXX11
+
+
 }; // namespace amf
 }; // namespace mlpack
 
diff --git a/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp b/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
index 3c12e21..7a34ccb 100644
--- a/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
+++ b/src/mlpack/methods/amf/termination_policies/simple_tolerance_termination.hpp
@@ -177,7 +177,7 @@ class SimpleToleranceTermination
   //! minimum residue point
   bool isCopy;
   
-  //! variables to store information of minimum residue point
+  //! variables to store information of minimum residue point
   arma::mat W;
   arma::mat H;
   double c_indexOld;
diff --git a/src/mlpack/methods/amf/update_rules/svd_batch_learning.hpp b/src/mlpack/methods/amf/update_rules/svd_batch_learning.hpp
index 84e41ce..d692696 100644
--- a/src/mlpack/methods/amf/update_rules/svd_batch_learning.hpp
+++ b/src/mlpack/methods/amf/update_rules/svd_batch_learning.hpp
@@ -2,7 +2,7 @@
  * @file svd_batch_learning.hpp
  * @author Sumedh Ghaisas
  *
- * SVD factorization used in AMF (Alternating Matrix Factorization).
+ * SVD factorizer used in AMF (Alternating Matrix Factorization).
  */
 #ifndef __MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCHLEARNING_HPP
 #define __MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCHLEARNING_HPP
@@ -13,6 +13,16 @@ namespace mlpack
 {
 namespace amf
 {
+
+/**
+ * This class implements SVD batch learning with momentum, as described in
+ * 'A Guide to Singular Value Decomposition for Collaborative Filtering' by
+ * Chih-Chao Ma; it follows 'Algorithm 4' of that paper.  It decomposes the
+ * matrix V into two matrices W and H such that the sum of squared errors
+ * between V and W * H is minimized.  The minimization is done by gradient
+ * descent, with a momentum term added to speed it up.  Regularization of W
+ * and H is controlled by kw and kh.
+ */
 class SVDBatchLearning
 {
  public:
@@ -29,8 +39,17 @@ class SVDBatchLearning
                    double kh = 0,
                    double momentum = 0.9)
         : u(u), kw(kw), kh(kh), momentum(momentum)
-    {}
+  {
+    // nothing else to do; parameters are stored by the initializer list
+  }
 
+  /**
+   * Initialize value before factorization.
+   * This function must be called before each new factorization.
+   *
+   * @param dataset Input matrix to be factorized.
+   * @param rank rank of factorization
+   */
   template<typename MatType>
   void Initialize(const MatType& dataset, const size_t rank)
   {
@@ -60,11 +79,12 @@ class SVDBatchLearning
 
     size_t r = W.n_cols;
 
+    // decay the momentum carried over from the previous iteration
     mW = momentum * mW;
 
+    // compute the step
     arma::mat deltaW(n, r);
     deltaW.zeros();
-
     for(size_t i = 0;i < n;i++)
     {
       for(size_t j = 0;j < m;j++)
@@ -74,10 +94,13 @@ class SVDBatchLearning
           deltaW.row(i) += (val - arma::dot(W.row(i), H.col(j))) * 
                                                   arma::trans(H.col(j));
       }
+      // add regularization
       if(kw != 0) deltaW.row(i) -= kw * W.row(i);
     }
 
+    // add the step to the momentum
     mW += u * deltaW;
+    // add the momentum to W matrix
     W += mW;
   }
 
@@ -100,11 +123,12 @@ class SVDBatchLearning
 
     size_t r = W.n_cols;
 
+    // decay the momentum carried over from the previous iteration
     mH = momentum * mH;
 
+    // compute the step
     arma::mat deltaH(r, m);
     deltaH.zeros();
-
     for(size_t j = 0;j < m;j++)
     {
       for(size_t i = 0;i < n;i++)
@@ -114,23 +138,38 @@ class SVDBatchLearning
           deltaH.col(j) += (val - arma::dot(W.row(i), H.col(j))) * 
                                                     arma::trans(W.row(i));
       }
+      // add regularization
       if(kh != 0) deltaH.col(j) -= kh * H.col(j);
     }
 
+    // add step to the momentum
     mH += u*deltaH;
+    // add momentum to H
     H += mH;
   }
   
  private:
+  //! step size of the algorithm
   double u;
+  //! regularization parameter for matrix W
   double kw;
+  //! regularization parameter for matrix H
   double kh;
+  //! momentum value
   double momentum;
 
+  //! momentum matrix for matrix W
   arma::mat mW;
+  //! momentum matrix for matrix H
   arma::mat mH;
-};
+}; // class SVDBatchLearning
 
+//! TODO: merge these sparse-matrix specializations with the generic versions
+//!       using a common row_col_iterator.
+
+/**
+ * WUpdate function specialization for sparse matrix
+ */
 template<> 
 inline void SVDBatchLearning::WUpdate<arma::sp_mat>(const arma::sp_mat& V,
                                                     arma::mat& W,
@@ -197,6 +236,6 @@ inline void SVDBatchLearning::HUpdate<arma::sp_mat>(const arma::sp_mat& V,
 } // namespace mlpack
 
 
-#endif
+#endif // __MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCHLEARNING_HPP
 
 
diff --git a/src/mlpack/methods/amf/update_rules/svd_complete_incremental_learning.hpp b/src/mlpack/methods/amf/update_rules/svd_complete_incremental_learning.hpp
index 6ff7b3f..fe3150c 100644
--- a/src/mlpack/methods/amf/update_rules/svd_complete_incremental_learning.hpp
+++ b/src/mlpack/methods/amf/update_rules/svd_complete_incremental_learning.hpp
@@ -1,5 +1,11 @@
-#ifndef SVD_COMPLETE_INCREMENTAL_LEARNING_HPP_INCLUDED
-#define SVD_COMPLETE_INCREMENTAL_LEARNING_HPP_INCLUDED
+/**
+ * @file svd_complete_incremental_learning.hpp
+ * @author Sumedh Ghaisas
+ *
+ * SVD complete incremental learning update rule used in AMF.
+ */
+#ifndef _MLPACK_METHODS_AMF_SVDCOMPLETEINCREMENTALLEARNING_HPP_INCLUDED
+#define _MLPACK_METHODS_AMF_SVDCOMPLETEINCREMENTALLEARNING_HPP_INCLUDED
 
 #include <mlpack/core.hpp>
 
@@ -217,5 +223,5 @@ class SVDCompleteIncrementalLearning<arma::sp_mat>
 }
 
 
-#endif // SVD_COMPLETE_INCREMENTAL_LEARNING_HPP_INCLUDED
+#endif // _MLPACK_METHODS_AMF_SVDCOMPLETEINCREMENTALLEARNING_HPP_INCLUDED
 
diff --git a/src/mlpack/methods/cf/CMakeLists.txt b/src/mlpack/methods/cf/CMakeLists.txt
index 840dbc9..5238d8c 100644
--- a/src/mlpack/methods/cf/CMakeLists.txt
+++ b/src/mlpack/methods/cf/CMakeLists.txt
@@ -3,8 +3,8 @@
 set(SOURCES
   cf.hpp
   cf_impl.hpp
-  plain_svd.hpp
-  plain_svd.cpp
+  svd_wrapper.hpp
+  svd_wrapper_impl.hpp
 )
 
 # Add directory name to sources.
diff --git a/src/mlpack/methods/cf/cf.hpp b/src/mlpack/methods/cf/cf.hpp
index 8a7b503..f8321ab 100644
--- a/src/mlpack/methods/cf/cf.hpp
+++ b/src/mlpack/methods/cf/cf.hpp
@@ -21,7 +21,7 @@
 #include <iostream>
 
 namespace mlpack {
-namespace cf /** Collaborative filtering. */{
+namespace cf /** Collaborative filtering. */ {
 
 /**
  * This class implements Collaborative Filtering (CF). This implementation
@@ -56,9 +56,7 @@ namespace cf /** Collaborative filtering. */{
  *     Apply(arma::sp_mat& data, size_t rank, arma::mat& W, arma::mat& H).
  */
 template<
-    typename FactorizerType = amf::AMF<amf::SimpleResidueTermination,
-                                       amf::RandomInitialization, 
-                                       amf::NMFALSUpdate> >
+    typename FactorizerType = amf::NMFALSFactorizer>
 class CF
 {
  public:
diff --git a/src/mlpack/methods/cf/plain_svd.cpp b/src/mlpack/methods/cf/plain_svd.cpp
deleted file mode 100644
index cdedb69..0000000
--- a/src/mlpack/methods/cf/plain_svd.cpp
+++ /dev/null
@@ -1,69 +0,0 @@
-/**
- * @file plain_svd.cpp
- * @author Sumedh Ghaisas
- *
- * Implementation of the wrapper class for Armadillo's SVD.
- */
-#include "plain_svd.hpp"
-
-using namespace mlpack;
-using namespace mlpack::svd;
-
-double PlainSVD::Apply(const arma::mat& V,
-                       arma::mat& W,
-                       arma::mat& sigma,
-                       arma::mat& H) const
-{
-  // get svd factorization
-  arma::vec E;
-  arma::svd(W, E, H, V);
-
-  // construct sigma matrix 
-  sigma.zeros(V.n_rows, V.n_cols);
-
-  for(size_t i = 0;i < sigma.n_rows && i < sigma.n_cols;i++)
-    sigma(i, i) = E(i, 0);
-
-  arma::mat V_rec = W * sigma * arma::trans(H);
-
-  // return normalized frobenius error
-  return arma::norm(V - V_rec, "fro") / arma::norm(V, "fro");
-}
-
-double PlainSVD::Apply(const arma::mat& V,
-                       size_t r,
-                       arma::mat& W,
-                       arma::mat& H) const
-{
-  // check if the given rank is valid
-  if(r > V.n_rows || r > V.n_cols)
-  {
-    Log::Info << "Rank " << r << ", given for decomposition is invalid." << std::endl;
-    r = (V.n_rows > V.n_cols) ? V.n_cols : V.n_rows;
-    Log::Info << "Setting decomposition rank to " << r << std::endl;
-  }
-
-  // get svd factorization
-  arma::vec sigma;
-  arma::svd(W, sigma, H, V);
-
-  // remove the part of W and H depending upon the value of rank
-  W = W.submat(0, 0, W.n_rows - 1, r - 1);
-  H = H.submat(0, 0, H.n_cols - 1, r - 1);
-
-  // take only required eigenvalues
-  sigma = sigma.subvec(0, r - 1);
-  
-  // eigenvalue matrix is multiplied to W
-  // it can either be multiplied to H matrix
-  W = W * arma::diagmat(sigma);
-  
-  // take transpose of the matrix H as required by CF module
-  H = arma::trans(H);
-
-  // reconstruct the matrix
-  arma::mat V_rec = W * H;
-
-  // return the normalized frobenius norm
-  return arma::norm(V - V_rec, "fro") / arma::norm(V, "fro");
-}
diff --git a/src/mlpack/methods/cf/plain_svd.hpp b/src/mlpack/methods/cf/svd_wrapper.hpp
similarity index 64%
rename from src/mlpack/methods/cf/plain_svd.hpp
rename to src/mlpack/methods/cf/svd_wrapper.hpp
index f1191d4..27ef452 100644
--- a/src/mlpack/methods/cf/plain_svd.hpp
+++ b/src/mlpack/methods/cf/svd_wrapper.hpp
@@ -1,30 +1,33 @@
 /**
- * @file plain_svd.hpp
+ * @file svd_wrapper.hpp
  * @author Sumedh Ghaisas
  *
- * Wrapper class for Armadillo's SVD.
+ * Wrapper class for SVD factorizers used for Collaborative Filtering.
  */
-#ifndef __MLPACK_METHODS_PLAIN_SVD_HPP
-#define __MLPACK_METHODS_PLAIN_SVD_HPP
+#ifndef __MLPACK_METHODS_SVDWRAPPER_HPP
+#define __MLPACK_METHODS_SVDWRAPPER_HPP
 
 #include <mlpack/core.hpp>
 
 namespace mlpack
 {
-namespace svd
+namespace cf
 {
 
 /**
- * This class acts as a wrapper class for Armadillo's SVD implementation to be 
- * used by Collaborative Filteraing module.
  *
  * @see CF
  */
-class PlainSVD
+//! Default Factorizer for SVDWrapper; SVDWrapper<DummyClass> uses arma::svd().
+class DummyClass {};
+
+template<class Factorizer = DummyClass>
+class SVDWrapper
 {
  public:
   // empty constructor
-  PlainSVD() {};
+  SVDWrapper(const Factorizer& factorizer = Factorizer()) 
+    : factorizer(factorizer) {};
 
   /**
    * Factorizer function which takes SVD of the given matrix and returns the 
@@ -55,9 +58,16 @@ class PlainSVD
                size_t r,
                arma::mat& W,
                arma::mat& H) const;
-}; // class PlainSVD
                
-}; // namespace svd
+ private:
+  //! svd factorizer
+  Factorizer factorizer;
+}; // class SVDWrapper
+
+//! include the implementation
+#include "svd_wrapper_impl.hpp"
+
+}; // namespace cf
 }; // namespace mlpack
 
 #endif
diff --git a/src/mlpack/methods/cf/svd_wrapper_impl.hpp b/src/mlpack/methods/cf/svd_wrapper_impl.hpp
new file mode 100644
index 0000000..f3368c3
--- /dev/null
+++ b/src/mlpack/methods/cf/svd_wrapper_impl.hpp
@@ -0,0 +1,128 @@
+/**
+ * @file svd_wrapper_impl.hpp
+ * @author Sumedh Ghaisas
+ *
+ * Implementation of the SVD wrapper class.
+ */
+
+template<class Factorizer>
+double mlpack::cf::SVDWrapper<Factorizer>::Apply(const arma::mat& V,
+                                                 arma::mat& W,
+                                                 arma::mat& sigma,
+                                                 arma::mat& H) const
+{
+  // Let the wrapped factorizer compute V = W * diag(s) * H^T.
+  arma::vec singularValues;
+  factorizer.Apply(W, singularValues, H, V);
+
+  // Expand s into a V.n_rows x V.n_cols matrix with s on the main diagonal.
+  sigma.zeros(V.n_rows, V.n_cols);
+  const size_t diagLength = (sigma.n_rows < sigma.n_cols) ? sigma.n_rows
+                                                          : sigma.n_cols;
+  for(size_t k = 0; k < diagLength; ++k)
+    sigma(k, k) = singularValues(k);
+
+  // Relative reconstruction error ||V - W * sigma * H^T||_F / ||V||_F.
+  const arma::mat reconstruction = W * sigma * arma::trans(H);
+  return arma::norm(V - reconstruction, "fro") / arma::norm(V, "fro");
+}
+
+template<> // inline below: full specialization defined in a header (ODR)
+inline double mlpack::cf::SVDWrapper<DummyClass>::Apply(const arma::mat& V,
+                                                        arma::mat& W,
+                                                        arma::mat& sigma,
+                                                        arma::mat& H) const
+{
+  // Full Armadillo SVD: V = W * diag(E) * H^T.
+  arma::vec E;
+  arma::svd(W, E, H, V);
+
+  // Build a V.n_rows x V.n_cols sigma matrix with E on its main diagonal.
+  sigma.zeros(V.n_rows, V.n_cols);
+
+  for(size_t i = 0; i < sigma.n_rows && i < sigma.n_cols; i++)
+    sigma(i, i) = E(i);
+
+  arma::mat V_rec = W * sigma * arma::trans(H);
+
+  // Return the reconstruction error relative to ||V||_F (Frobenius norm).
+  return arma::norm(V - V_rec, "fro") / arma::norm(V, "fro");
+}
+
+template<class Factorizer>
+double mlpack::cf::SVDWrapper<Factorizer>::Apply(const arma::mat& V,
+                                                 size_t r,
+                                                 arma::mat& W,
+                                                 arma::mat& H) const
+{
+  // Clamp the requested rank to min(n_rows, n_cols), the largest valid rank.
+  if(r > V.n_rows || r > V.n_cols)
+  {
+    Log::Info << "Rank " << r << ", given for decomposition is invalid." << std::endl;
+    r = (V.n_rows > V.n_cols) ? V.n_cols : V.n_rows;
+    Log::Info << "Setting decomposition rank to " << r << std::endl;
+  }
+
+  // Full factorization V = W * diag(sigma) * H^T by the wrapped factorizer.
+  arma::vec sigma;
+  factorizer.Apply(W, sigma, H, V);
+
+  // Keep the first r columns (all rows) of W and H; H may be non-square.
+  W = W.cols(0, r - 1);
+  H = H.cols(0, r - 1);
+
+  // Keep the first r singular values (assumed sorted in descending order).
+  sigma = sigma.subvec(0, r - 1);
+
+  // Fold the singular values into W by scaling its columns; they could
+  // equally have been folded into H instead.
+  W = W * arma::diagmat(sigma);
+
+  // Store H^T (r x V.n_cols), the orientation the CF module expects.
+  H = arma::trans(H);
+
+  // Rank-r reconstruction of V.
+  arma::mat V_rec = W * H;
+
+  // Return the reconstruction error relative to ||V||_F (Frobenius norm).
+  return arma::norm(V - V_rec, "fro") / arma::norm(V, "fro");
+}
+
+template<> // inline below: full specialization defined in a header (ODR)
+inline double mlpack::cf::SVDWrapper<DummyClass>::Apply(const arma::mat& V,
+                                                        size_t r,
+                                                        arma::mat& W,
+                                                        arma::mat& H) const
+{
+  // Clamp the requested rank to min(n_rows, n_cols), the largest valid rank.
+  if(r > V.n_rows || r > V.n_cols)
+  {
+    Log::Info << "Rank " << r << ", given for decomposition is invalid." << std::endl;
+    r = (V.n_rows > V.n_cols) ? V.n_cols : V.n_rows;
+    Log::Info << "Setting decomposition rank to " << r << std::endl;
+  }
+
+  // Full Armadillo SVD: V = W * diag(sigma) * H^T.
+  arma::vec sigma;
+  arma::svd(W, sigma, H, V);
+
+  // Keep the first r columns (all rows) of W and H.
+  W = W.cols(0, r - 1);
+  H = H.cols(0, r - 1);
+
+  // Keep the r largest singular values (arma::svd sorts them descending).
+  sigma = sigma.subvec(0, r - 1);
+
+  // Fold the singular values into W by scaling its columns; they could
+  // equally have been folded into H instead.
+  W = W * arma::diagmat(sigma);
+
+  // Store H^T (r x V.n_cols), the orientation the CF module expects.
+  H = arma::trans(H);
+
+  // Rank-r reconstruction of V.
+  arma::mat V_rec = W * H;
+
+  // Return the reconstruction error relative to ||V||_F (Frobenius norm).
+  return arma::norm(V - V_rec, "fro") / arma::norm(V, "fro");
+}
diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt
index d177223..5f4c0e9 100644
--- a/src/mlpack/tests/CMakeLists.txt
+++ b/src/mlpack/tests/CMakeLists.txt
@@ -55,7 +55,7 @@ add_executable(mlpack_test
   svd_batch_test.cpp
   svd_incremental_test.cpp
   nystroem_method_test.cpp
-  plain_svd_test.cpp
+  armadillo_svd_test.cpp
 )
 # Link dependencies of test executable.
 target_link_libraries(mlpack_test
diff --git a/src/mlpack/tests/plain_svd_test.cpp b/src/mlpack/tests/armadillo_svd_test.cpp
similarity index 65%
rename from src/mlpack/tests/plain_svd_test.cpp
rename to src/mlpack/tests/armadillo_svd_test.cpp
index 0aed9a4..b83fcca 100644
--- a/src/mlpack/tests/plain_svd_test.cpp
+++ b/src/mlpack/tests/armadillo_svd_test.cpp
@@ -1,24 +1,24 @@
 #include <mlpack/core.hpp>
-#include <mlpack/methods/cf/plain_svd.hpp>
+#include <mlpack/methods/cf/svd_wrapper.hpp>
 
 #include <boost/test/unit_test.hpp>
 #include "old_boost_test_definitions.hpp"
 
-BOOST_AUTO_TEST_SUITE(PlainSVDTest);
+BOOST_AUTO_TEST_SUITE(ArmadilloSVDTest);
 
 using namespace std;
 using namespace mlpack;
-using namespace mlpack::svd;
+using namespace mlpack::cf;
 using namespace arma;
 
 /**
- * Test PlainSVD for normal factorization
+ * Test SVDWrapper<> (Armadillo's SVD) for a full factorization
  */
-BOOST_AUTO_TEST_CASE(PlainSVDNormalFactorizationTest)
+BOOST_AUTO_TEST_CASE(ArmadilloSVDNormalFactorizationTest)
 {
   mat test = randu<mat>(20, 20);
 
-  PlainSVD svd;
+  SVDWrapper<> svd;
   arma::mat W, H, sigma;
   double result = svd.Apply(test, W, sigma, H);
   
@@ -31,9 +31,9 @@ BOOST_AUTO_TEST_CASE(PlainSVDNormalFactorizationTest)
 }
 
 /**
- * Test PlainSVD for low rank matrix factorization
+ * Test SVDWrapper<> (Armadillo's SVD) for low-rank matrix factorization
  */
-BOOST_AUTO_TEST_CASE(PlainSVDLowRankFactorizationTest)
+BOOST_AUTO_TEST_CASE(ArmadilloSVDLowRankFactorizationTest)
 {
   mat W_t = randu<mat>(30, 3);
   mat H_t = randu<mat>(3, 40);
@@ -41,7 +41,7 @@ BOOST_AUTO_TEST_CASE(PlainSVDLowRankFactorizationTest)
   // create a row-rank matrix
   mat test = W_t * H_t;
 
-  PlainSVD svd;
+  SVDWrapper<> svd;
   arma::mat W, H;
   double result = svd.Apply(test, 3, W, H);
   



More information about the mlpack-git mailing list