[mlpack-svn] r16814 - in mlpack/trunk/src/mlpack: core/optimizers/sgd methods methods/regularized_svd

fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Jul 10 15:16:30 EDT 2014


Author: siddharth.950
Date: Thu Jul 10 15:16:30 2014
New Revision: 16814

Log:
Adding Regularized SVD Code

Added:
   mlpack/trunk/src/mlpack/methods/regularized_svd/
   mlpack/trunk/src/mlpack/methods/regularized_svd/CMakeLists.txt
   mlpack/trunk/src/mlpack/methods/regularized_svd/regularized_svd.hpp
   mlpack/trunk/src/mlpack/methods/regularized_svd/regularized_svd_function.cpp
   mlpack/trunk/src/mlpack/methods/regularized_svd/regularized_svd_function.hpp
   mlpack/trunk/src/mlpack/methods/regularized_svd/regularized_svd_impl.hpp
Modified:
   mlpack/trunk/src/mlpack/core/optimizers/sgd/sgd_impl.hpp
   mlpack/trunk/src/mlpack/methods/CMakeLists.txt

Modified: mlpack/trunk/src/mlpack/core/optimizers/sgd/sgd_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/optimizers/sgd/sgd_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/optimizers/sgd/sgd_impl.hpp	Thu Jul 10 15:16:30 2014
@@ -7,6 +7,7 @@
 #ifndef __MLPACK_CORE_OPTIMIZERS_SGD_SGD_IMPL_HPP
 #define __MLPACK_CORE_OPTIMIZERS_SGD_SGD_IMPL_HPP
 
+#include <mlpack/methods/regularized_svd/regularized_svd_function.hpp>
 // In case it hasn't been included yet.
 #include "sgd.hpp"
 

Modified: mlpack/trunk/src/mlpack/methods/CMakeLists.txt
==============================================================================
--- mlpack/trunk/src/mlpack/methods/CMakeLists.txt	(original)
+++ mlpack/trunk/src/mlpack/methods/CMakeLists.txt	Thu Jul 10 15:16:30 2014
@@ -27,6 +27,7 @@
   radical
   range_search
   rann
+  regularized_svd
   sparse_autoencoder
   sparse_coding
 )

Added: mlpack/trunk/src/mlpack/methods/regularized_svd/CMakeLists.txt
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/methods/regularized_svd/CMakeLists.txt	Thu Jul 10 15:16:30 2014
@@ -0,0 +1,17 @@
+# Define the files we need to compile.
+# Anything not in this list will not be compiled into MLPACK.
+set(SOURCES
+  regularized_svd.hpp
+  regularized_svd_impl.hpp
+  regularized_svd_function.hpp
+  regularized_svd_function.cpp
+)
+
+# Add directory name to sources.
+set(DIR_SRCS)
+foreach(file ${SOURCES})
+  set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
+endforeach()
+# Append sources (with directory name) to list of all MLPACK sources (used at
+# the parent scope).
+set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)

Added: mlpack/trunk/src/mlpack/methods/regularized_svd/regularized_svd.hpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/methods/regularized_svd/regularized_svd.hpp	Thu Jul 10 15:16:30 2014
@@ -0,0 +1,71 @@
+/**
+ * @file regularized_svd.hpp
+ * @author Siddharth Agrawal
+ *
+ * An implementation of Regularized SVD.
+ */
+
+#ifndef __MLPACK_METHODS_REGULARIZED_SVD_REGULARIZED_SVD_HPP
+#define __MLPACK_METHODS_REGULARIZED_SVD_REGULARIZED_SVD_HPP
+
+#include <mlpack/core.hpp>
+#include <mlpack/core/optimizers/sgd/sgd.hpp>
+
+#include "regularized_svd_function.hpp"
+
+namespace mlpack {
+namespace svd {
+
+template<
+  template<typename> class OptimizerType = mlpack::optimization::SGD
+>
+class RegularizedSVD
+{
+ public:
+ 
+  /**
+   * Constructor for Regularized SVD. Obtains the user and item matrices after
+   * training on the passed data. The constructor creates an object of class
+   * RegularizedSVDFunction for the optimization and uses the SGD optimizer by
+   * default, which relies on a template specialization of Optimize().
+   *
+   * @param data Dataset for which SVD is calculated.
+   * @param u User matrix in the matrix decomposition.
+   * @param v Item matrix in the matrix decomposition.
+   * @param rank Rank used for matrix factorization.
+   * @param iterations Number of optimization iterations.
+   * @param alpha Learning rate for the SGD optimizer.
+   * @param lambda Regularization parameter for the optimization.
+   */
+  RegularizedSVD(const arma::mat& data,
+                 arma::mat& u,
+                 arma::mat& v,
+                 const size_t rank,
+                 const size_t iterations = 10,
+                 const double alpha = 0.01,
+                 const double lambda = 0.02);
+                 
+ private:
+  //! Rating data.
+  const arma::mat& data;
+  //! Rank used for matrix factorization.
+  size_t rank;
+  //! Number of optimization iterations.
+  size_t iterations;
+  //! Learning rate for the SGD optimizer.
+  double alpha;
+  //! Regularization parameter for the optimization.
+  double lambda;
+  //! Function that will be held by the optimizer.
+  RegularizedSVDFunction rSVDFunc;
+  //! Default SGD optimizer for the class.
+  mlpack::optimization::SGD<RegularizedSVDFunction> optimizer;
+};
+
+}; // namespace svd
+}; // namespace mlpack
+
+// Include implementation.
+#include "regularized_svd_impl.hpp"
+
+#endif
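
For reference, here is a minimal usage sketch of the new class (not part of the
commit). It assumes the coordinate-list layout that RegularizedSVDFunction
reads: each column of the data matrix is one rating, with row 0 holding the
user index, row 1 the item index, and row 2 the rating value; the sample values
below are made up.

  #include <mlpack/core.hpp>
  #include <mlpack/methods/regularized_svd/regularized_svd.hpp>

  using namespace mlpack;

  int main()
  {
    // Four ratings: one (user, item, rating) triple per column.
    arma::mat data;
    data << 0 << 0 << 1 << 1 << arma::endr
         << 0 << 1 << 0 << 2 << arma::endr
         << 4 << 3 << 5 << 2 << arma::endr;

    // Output factors: u is rank x numUsers, v is rank x numItems.
    arma::mat u, v;

    // Rank 2, 10 iterations, step size (alpha) 0.01, regularization 0.02.
    svd::RegularizedSVD<> rSVD(data, u, v, 2, 10, 0.01, 0.02);

    u.print("u:");
    v.print("v:");

    return 0;
  }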

Added: mlpack/trunk/src/mlpack/methods/regularized_svd/regularized_svd_function.cpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/methods/regularized_svd/regularized_svd_function.cpp	Thu Jul 10 15:16:30 2014
@@ -0,0 +1,181 @@
+/**
+ * @file regularized_svd_function.cpp
+ * @author Siddharth Agrawal
+ *
+ * An implementation of the RegularizedSVDFunction class.
+ */
+
+#include "regularized_svd_function.hpp"
+
+namespace mlpack {
+namespace svd {
+
+RegularizedSVDFunction::RegularizedSVDFunction(const arma::mat& data,
+                                               const size_t rank,
+                                               const double lambda) :
+    data(data),
+    rank(rank),
+    lambda(lambda)
+{
+  // Number of users and items in the data.
+  numUsers = max(data.row(0)) + 1;
+  numItems = max(data.row(1)) + 1;
+  
+  // Initialize the parameters.
+  initialPoint.randu(rank, numUsers + numItems);
+}
+
+double RegularizedSVDFunction::Evaluate(const arma::mat& parameters) const
+{
+  // The cost for the optimization is as follows:
+  //          f(u, v) = sum((rating(i, j) - u(i).t() * v(j))^2)
+  // The sum is over all the ratings in the rating matrix.
+  // 'i' points to the user and 'j' points to the item being considered.
+  // The regularization term is added to the above cost, where the vectors u(i)
+  // and v(j) are regularized for each rating they contribute to.
+
+  double cost = 0.0;
+
+  for(size_t i = 0; i < data.n_cols; i++)
+  {
+    // Indices for accessing the correct parameter columns.
+    const size_t user = data(0, i);
+    const size_t item = data(1, i) + numUsers;
+
+    // Calculate the squared error in the prediction.
+    const double rating = data(2, i);
+    double ratingError = rating - arma::dot(parameters.col(user),
+                                            parameters.col(item));
+    double ratingErrorSquared = ratingError * ratingError;
+  
+    // Calculate the regularization penalty corresponding to the parameters.
+    double userVecNorm = arma::norm(parameters.col(user), 2);
+    double itemVecNorm = arma::norm(parameters.col(item), 2);
+    double regularizationError = lambda * (userVecNorm * userVecNorm +
+                                           itemVecNorm * itemVecNorm);
+                                           
+    cost += (ratingErrorSquared + regularizationError);
+  }
+  
+  return cost;
+}
+
+double RegularizedSVDFunction::Evaluate(const arma::mat& parameters,
+                                        const size_t i) const
+{
+  // Indices for accessing the correct parameter columns.
+  const size_t user = data(0, i);
+  const size_t item = data(1, i) + numUsers;
+  
+  // Calculate the squared error in the prediction.
+  const double rating = data(2, i);
+  double ratingError = rating - arma::dot(parameters.col(user),
+                                          parameters.col(item));
+  double ratingErrorSquared = ratingError * ratingError;
+  
+  // Calculate the regularization penalty corresponding to the parameters.
+  double userVecNorm = arma::norm(parameters.col(user), 2);
+  double itemVecNorm = arma::norm(parameters.col(item), 2);
+  double regularizationError = lambda * (userVecNorm * userVecNorm +
+                                         itemVecNorm * itemVecNorm);
+                                         
+  return (ratingErrorSquared + regularizationError);
+}
+
+void RegularizedSVDFunction::Gradient(const arma::mat& parameters,
+                                      arma::mat& gradient) const
+{
+  // For an example with rating corresponding to user 'i' and item 'j', the
+  // gradients for the parameters is as follows:
+  //           grad(u(i)) = lambda * u(i) - error * v(j)
+  //           grad(v(j)) = lambda * v(j) - error * u(i)
+  // 'error' is the prediction error for that example, which is:
+  //           rating(i, j) - u(i).t() * v(j)
+  // The full gradient is calculated by summing the contributions over all the
+  // training examples.
+
+  gradient.zeros(rank, numUsers + numItems);
+
+  for(size_t i = 0; i < data.n_cols; i++)
+  {
+    // Indices for accessing the correct parameter columns.
+    const size_t user = data(0, i);
+    const size_t item = data(1, i) + numUsers;
+
+    // Prediction error for the example.
+    const double rating = data(2, i);
+    double ratingError = rating - arma::dot(parameters.col(user),
+                                            parameters.col(item));
+
+    // Gradient is non-zero only for the parameter columns corresponding to the
+    // example.
+    gradient.col(user) += lambda * parameters.col(user) -
+                          ratingError * parameters.col(item);
+    gradient.col(item) += lambda * parameters.col(item) -
+                          ratingError * parameters.col(user);
+  }
+}
+
+}; // namespace svd
+}; // namespace mlpack
+
+// Template specialization for the SGD optimizer.
+namespace mlpack {
+namespace optimization {
+
+template<>
+double SGD<mlpack::svd::RegularizedSVDFunction>::Optimize(arma::mat& parameters)
+{
+  // Find the number of functions to use.
+  const size_t numFunctions = function.NumFunctions();
+
+  // To keep track of where we are and how things are going.
+  size_t currentFunction = 0;
+  double overallObjective = 0;
+
+  // Calculate the first objective function.
+  for(size_t i = 0; i < numFunctions; i++)
+    overallObjective += function.Evaluate(parameters, i);
+    
+  const arma::mat& data = function.Dataset();
+
+  // Now iterate!
+  for(size_t i = 1; i != maxIterations; i++, currentFunction++)
+  {
+    // Is this iteration the start of a sequence?
+    if((currentFunction % numFunctions) == 0)
+    {
+      // Reset the counter variables.
+      overallObjective = 0;
+      currentFunction = 0;
+    }
+
+    const size_t numUsers = function.NumUsers();
+
+    // Indices for accessing the correct parameter columns.
+    const size_t user = data(0, currentFunction);
+    const size_t item = data(1, currentFunction) + numUsers;
+
+    // Prediction error for the example.
+    const double rating = data(2, currentFunction);
+    double ratingError = rating - arma::dot(parameters.col(user),
+                                            parameters.col(item));
+                                            
+    double lambda = function.Lambda();
+
+    // Gradient is non-zero only for the parameter columns corresponding to the
+    // example.
+    parameters.col(user) -= stepSize * (lambda * parameters.col(user) -
+                                        ratingError * parameters.col(item));
+    parameters.col(item) -= stepSize * (lambda * parameters.col(item) -
+                                        ratingError * parameters.col(user));
+
+    // Now add that to the overall objective function.
+    overallObjective += function.Evaluate(parameters, currentFunction);
+  }
+
+  return overallObjective;
+}
+
+}; // namespace optimization
+}; // namespace mlpack
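
Restating the in-code comments above as equations (a sketch in added notation:
u_i and v_j are the parameter columns for user i and item j, r_ij the observed
rating, Omega the set of observed (user, item) pairs, and alpha the SGD step
size):

  \[
    f(U, V) = \sum_{(i, j) \in \Omega}
        \left( (r_{ij} - u_i^{\top} v_j)^2
               + \lambda \left( \| u_i \|_2^2 + \| v_j \|_2^2 \right) \right)
  \]

  With the per-rating error e_{ij} = r_{ij} - u_i^{\top} v_j, the only non-zero
  gradient blocks for that rating are

  \[
    \nabla_{u_i} = \lambda u_i - e_{ij} v_j, \qquad
    \nabla_{v_j} = \lambda v_j - e_{ij} u_i,
  \]

  and the specialized SGD step above applies u_i <- u_i - alpha * grad(u_i) and
  v_j <- v_j - alpha * grad(v_j).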

Added: mlpack/trunk/src/mlpack/methods/regularized_svd/regularized_svd_function.hpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/methods/regularized_svd/regularized_svd_function.hpp	Thu Jul 10 15:16:30 2014
@@ -0,0 +1,115 @@
+/**
+ * @file regularized_svd_function.hpp
+ * @author Siddharth Agrawal
+ *
+ * An implementation of the RegularizedSVDFunction class.
+ */
+
+#ifndef __MLPACK_METHODS_REGULARIZED_SVD_REGULARIZED_SVD_FUNCTION_HPP
+#define __MLPACK_METHODS_REGULARIZED_SVD_REGULARIZED_SVD_FUNCTION_HPP
+
+#include <mlpack/core.hpp>
+#include <mlpack/core/optimizers/sgd/sgd.hpp>
+
+namespace mlpack {
+namespace svd {
+
+class RegularizedSVDFunction
+{
+ public:
+  
+  /**
+   * Constructor for RegularizedSVDFunction class. The constructor calculates
+   * the number of users and items in the passed data. It also randomly
+   * initializes the parameter values.
+   *
+   * @param data Dataset for which SVD is calculated.
+   * @param rank Rank used for matrix factorization.
+   * @param lambda Regularization parameter used for optimization.
+   */
+  RegularizedSVDFunction(const arma::mat& data,
+                         const size_t rank,
+                         const double lambda);
+  
+  /**
+   * Evaluates the cost function over all examples in the data.
+   *
+   * @param parameters Parameters (user/item matrices) of the decomposition.
+   */
+  double Evaluate(const arma::mat& parameters) const;
+  
+  /**
+   * Evaluates the cost function for one training example. Useful for the SGD
+   * optimizer abstraction which uses one training example at a time.
+   *
+   * @param parameters Parameters (user/item matrices) of the decomposition.
+   * @param i Index of the training example to be used.
+   */
+  double Evaluate(const arma::mat& parameters,
+                  const size_t i) const;
+  
+  /**
+   * Evaluates the full gradient of the cost function over all the training
+   * examples.
+   *
+   * @param parameters Parameters (user/item matrices) of the decomposition.
+   * @param gradient Calculated gradient for the parameters.
+   */
+  void Gradient(const arma::mat& parameters,
+                arma::mat& gradient) const;
+  
+  //! Return the initial point for the optimization.
+  const arma::mat& GetInitialPoint() const { return initialPoint; }
+  
+  //! Return the dataset passed into the constructor.
+  const arma::mat& Dataset() const { return data; }
+  
+  //! Return the number of training examples. Useful for SGD optimizer.
+  size_t NumFunctions() const { return data.n_cols; }
+  
+  //! Return the number of users in the data.
+  size_t NumUsers() const { return numUsers; }
+  
+  //! Return the number of items in the data.
+  size_t NumItems() const { return numItems; }
+  
+  //! Return the regularization parameter.
+  double Lambda() const { return lambda; }
+  
+  //! Return the rank used for the factorization.
+  size_t Rank() const { return rank; }
+                         
+ private:
+  //! Rating data.
+  const arma::mat& data;
+  //! Initial parameter point.
+  arma::mat initialPoint;
+  //! Rank used for matrix factorization.
+  size_t rank;
+  //! Regularization parameter for the optimization.
+  double lambda;
+  //! Number of users in the given dataset.
+  size_t numUsers;
+  //! Number of items in the given dataset.
+  size_t numItems;
+};
+
+}; // namespace svd
+}; // namespace mlpack
+
+namespace mlpack {
+namespace optimization {
+
+  /**
+   * Template specialization for SGD optimizer. Used because the gradient
+   * affects only a small number of parameters per example, and thus the normal
+   * abstraction does not work as fast as we might like it to.
+   */
+  template<>
+  double SGD<mlpack::svd::RegularizedSVDFunction>::Optimize(
+      arma::mat& parameters);
+
+}; // namespace optimization
+}; // namespace mlpack
+
+#endif
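
A rough sketch (not part of the commit) of driving this function class with the
SGD optimizer directly, following the same pattern the RegularizedSVD
constructor uses in regularized_svd_impl.hpp below; the helper name
FactorizeSketch is hypothetical:

  #include <mlpack/core.hpp>
  #include <mlpack/core/optimizers/sgd/sgd.hpp>
  #include <mlpack/methods/regularized_svd/regularized_svd_function.hpp>

  using namespace mlpack;

  void FactorizeSketch(const arma::mat& data, const size_t rank)
  {
    // The function object also randomly initializes the parameters.
    svd::RegularizedSVDFunction rSVDFunc(data, rank, 0.02 /* lambda */);

    // 10 passes over the ratings with step size 0.01; the specialized
    // Optimize() declared above updates only two columns per step.
    optimization::SGD<svd::RegularizedSVDFunction> sgd(
        rSVDFunc, 0.01, 10 * rSVDFunc.NumFunctions());

    arma::mat parameters = rSVDFunc.GetInitialPoint();
    const double objective = sgd.Optimize(parameters);

    // Columns 0 .. numUsers - 1 hold user vectors; the rest hold item vectors.
    const arma::mat u = parameters.cols(0, rSVDFunc.NumUsers() - 1);
    const arma::mat v = parameters.cols(rSVDFunc.NumUsers(),
        rSVDFunc.NumUsers() + rSVDFunc.NumItems() - 1);

    Log::Info << "Objective " << objective << "; u is " << u.n_rows << " x "
        << u.n_cols << ", v is " << v.n_rows << " x " << v.n_cols << "."
        << std::endl;
  }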

Added: mlpack/trunk/src/mlpack/methods/regularized_svd/regularized_svd_impl.hpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/methods/regularized_svd/regularized_svd_impl.hpp	Thu Jul 10 15:16:30 2014
@@ -0,0 +1,47 @@
+/**
+ * @file regularized_svd_impl.hpp
+ * @author Siddharth Agrawal
+ *
+ * An implementation of Regularized SVD.
+ */
+
+#ifndef __MLPACK_METHODS_REGULARIZED_SVD_REGULARIZED_SVD_IMPL_HPP
+#define __MLPACK_METHODS_REGULARIZED_SVD_REGULARIZED_SVD_IMPL_HPP
+
+namespace mlpack {
+namespace svd {
+
+template<template<typename> class OptimizerType>
+RegularizedSVD<OptimizerType>::RegularizedSVD(const arma::mat& data,
+                                              arma::mat& u,
+                                              arma::mat& v,
+                                              const size_t rank,
+                                              const size_t iterations,
+                                              const double alpha,
+                                              const double lambda) :
+    data(data),
+    rank(rank),
+    iterations(iterations),
+    alpha(alpha),
+    lambda(lambda),
+    rSVDFunc(data, rank, lambda),
+    optimizer(rSVDFunc, alpha, iterations * data.n_cols)
+{
+  arma::mat parameters = rSVDFunc.GetInitialPoint();
+
+  // Train the model.
+  Timer::Start("regularized_svd_optimization");
+  const double out = optimizer.Optimize(parameters);
+  Timer::Stop("regularized_svd_optimization");
+  
+  const size_t numUsers = max(data.row(0)) + 1;
+  const size_t numItems = max(data.row(1)) + 1;
+  
+  u = parameters.submat(0, 0, rank - 1, numUsers - 1);
+  v = parameters.submat(0, numUsers, rank - 1, numUsers + numItems - 1);
+}
+
+}; // namespace svd
+}; // namespace mlpack
+
+#endif
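
Once u and v have been extracted as above, a predicted rating is just the dot
product of the matching columns; a small sketch (the helper name PredictRating
is hypothetical):

  #include <mlpack/core.hpp>

  // Predicted rating for one (user, item) pair, given the factors produced by
  // RegularizedSVD: u is rank x numUsers and v is rank x numItems.
  double PredictRating(const arma::mat& u, const arma::mat& v,
                       const size_t user, const size_t item)
  {
    return arma::dot(u.col(user), v.col(item));
  }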


