[mlpack-svn] r15843 - mlpack/trunk/src/mlpack/methods/linear_regression

fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Sep 26 14:52:13 EDT 2013


Author: rcurtin
Date: Thu Sep 26 14:52:13 2013
New Revision: 15843

Log:
Add lambda parameter, so this now supports ridge regression.


Modified:
   mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp
   mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.hpp

Modified: mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp	(original)
+++ mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp	Thu Sep 26 14:52:13 2013
@@ -10,7 +10,9 @@
 using namespace mlpack::regression;
 
 LinearRegression::LinearRegression(arma::mat& predictors,
-                                   const arma::colvec& responses)
+                                   const arma::colvec& responses,
+                                   const double lambda) :
+    lambda(lambda)
 {
   /*
    * We want to calculate the a_i coefficients of:
@@ -18,7 +20,7 @@
    * In order to get the intercept value, we will add a row of ones.
    */
 
-  // We store the number of rows of the predictors.
+  // We store the number of rows and columns of the predictors.
   // Reminder: Armadillo stores the data transposed from how we think of it,
   //           that is, columns are actually rows (see: column major order).
   const size_t nCols = predictors.n_cols;
@@ -31,6 +33,16 @@
   // We set the parameters to the correct size and initialize them to zero.
   parameters.zeros(nCols);
 
+  // Now, add the identity matrix to the predictors (this is equivalent to ridge
+  // regression).  See http://math.stackexchange.com/questions/299481/ for more
+  // information.
+  if (lambda > 0)
+  {
+    predictors.insert_cols(nCols, predictors.n_rows);
+    predictors.cols(nCols, predictors.n_cols - 1) =
+        lambda * arma::eye<arma::mat>(predictors.n_rows, predictors.n_rows);
+  }
+
   // We compute the QR decomposition of the predictors.
   // We transpose the predictors because they are in column major order.
   arma::mat Q, R;
@@ -43,6 +55,10 @@
 
   // We now remove the row of ones we added so the user's data is unmodified.
   predictors.shed_row(0);
+  if (lambda > 0)
+  {
+    predictors.shed_cols(nCols, predictors.n_cols - 1);
+  }
 }
 
 LinearRegression::LinearRegression(const std::string& filename)

Modified: mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.hpp	Thu Sep 26 14:52:13 2013
@@ -14,6 +14,8 @@
 
 /**
  * A simple linear regression algorithm using ordinary least squares.
+ * Optionally, this class can perform ridge regression, if the lambda parameter
+ * is set to a number greater than zero.
  */
 class LinearRegression
 {
@@ -24,7 +26,9 @@
    * @param predictors X, matrix of data points to create B with.
    * @param responses y, the measured data for each point in X
    */
-  LinearRegression(arma::mat& predictors, const arma::vec& responses);
+  LinearRegression(arma::mat& predictors,
+                   const arma::vec& responses,
+                   const double lambda = 0);
 
   /**
    * Initialize the model from a file.
@@ -58,11 +62,11 @@
    * this linear regression model.  This calculation returns
    *
    * \f[
-   * (1 / n) * \| y - X \theta \|^2_2
+   * (1 / n) * \| y - X B \|^2_2
    * \f]
    *
    * where \f$ y \f$ is the responses vector, \f$ X \f$ is the matrix of
-   * predictors, and \f$ \theta \f$ is the parameters of the trained linear
+   * predictors, and \f$ B \f$ is the parameters of the trained linear
    * regression model.
    *
    * As this number decreases to 0, the linear regression fit is better.
@@ -84,6 +88,9 @@
    * Initialized and filled by constructor to hold the least squares solution.
    */
   arma::vec parameters;
+
+  //! The lambda parameter for ridge regression (0 for linear regression).
+  double lambda;
 };
 
 }; // namespace linear_regression



More information about the mlpack-svn mailing list