[mlpack-svn] r10201 - mlpack/trunk/src/mlpack/methods/linear_regression
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Nov 8 22:40:34 EST 2011
Author: jcline3
Date: 2011-11-08 22:40:34 -0500 (Tue, 08 Nov 2011)
New Revision: 10201
Modified:
mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp
Log:
comments
Modified: mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp 2011-11-09 01:42:26 UTC (rev 10200)
+++ mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp 2011-11-09 03:40:34 UTC (rev 10201)
@@ -6,18 +6,30 @@
LinearRegression::LinearRegression(arma::mat& predictors,
const arma::colvec& responses)
{
+
+ /*
+ * We want to calculate the a_i coefficients of:
+ * \sum_{i=0}^n (a_i * x_i)
+ * We add a row of ones so that x_0 = 1, which makes a_0 the intercept.
+ */
+
+ // The number of columns and rows
size_t n_cols, n_rows;
n_cols = predictors.n_cols;
n_rows = predictors.n_rows;
+ // Add a row of ones, to get the intercept
arma::rowvec ones;
ones.ones(n_cols);
predictors.insert_rows(0,ones);
+ // We have an additional row, now
++n_rows;
- parameters.set_size(n_cols);
+ // Set the parameters to the correct size, all zeros.
+ parameters.zeros(n_cols);
+ // inverse( A^T * A ) * A^T * responses, where A = trans(predictors)
parameters = arma::inv((predictors * arma::trans(predictors))) *
predictors * responses;
}
@@ -33,18 +45,22 @@
void LinearRegression::predict(arma::rowvec& predictions, const arma::mat& points)
{
+ // The number of columns and rows
size_t n_cols, n_rows;
n_cols = points.n_cols;
n_rows = points.n_rows;
+ // Sanity check
assert(n_rows == parameters.n_rows - 1);
predictions.zeros(n_cols);
- predictions += parameters(0);
+ // Set to a_0
+ predictions = parameters(0);
for(size_t i = 1; i < n_rows; ++i)
{
for(size_t j = 0; j < n_cols; ++j)
{
+ // Add in the next term: a_i * x_i
predictions(j) += parameters(i) * points(i-1,j);
}
More information about the mlpack-svn
mailing list