[mlpack-svn] r10805 - mlpack/trunk/src/mlpack/methods/linear_regression
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Dec 14 16:23:22 EST 2011
Author: rcurtin
Date: 2011-12-14 16:23:21 -0500 (Wed, 14 Dec 2011)
New Revision: 10805
Modified:
mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp
mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.hpp
mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_main.cpp
Log:
Redo LinearRegression API and put it into the regression namespace (not
linear_regression). Make the main executable more functional and document it.
Modified: mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp 2011-12-14 20:11:27 UTC (rev 10804)
+++ mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp 2011-12-14 21:23:21 UTC (rev 10805)
@@ -1,12 +1,17 @@
+/**
+ * @file linear_regression.cpp
+ * @author James Cline
+ *
+ * Implementation of simple linear regression.
+ */
#include "linear_regression.hpp"
-namespace mlpack {
-namespace linear_regression {
+using namespace mlpack;
+using namespace mlpack::regression;
LinearRegression::LinearRegression(arma::mat& predictors,
const arma::colvec& responses)
{
-
/*
* We want to calculate the a_i coefficients of:
* \sum_{i=0}^n (a_i * x_i)
@@ -16,8 +21,7 @@
// We store the number of rows of the predictors.
// Reminder: Armadillo stores the data transposed from how we think of it,
// that is, columns are actually rows (see: column major order).
- size_t n_cols;
- n_cols = predictors.n_cols;
+ size_t n_cols = predictors.n_cols;
// Here we add the row of ones to the predictors.
arma::rowvec ones;
@@ -49,16 +53,14 @@
LinearRegression::~LinearRegression()
{ }
-void LinearRegression::predict(arma::rowvec& predictions,
- const arma::mat& points)
+void LinearRegression::Predict(const arma::mat& points, arma::vec& predictions)
{
// We get the number of columns and rows of the dataset.
- size_t n_cols, n_rows;
- n_cols = points.n_cols;
- n_rows = points.n_rows;
+ const size_t n_cols = points.n_cols;
+ const size_t n_rows = points.n_rows;
// We want to be sure we have the correct number of dimensions in the dataset.
- assert(n_rows == parameters.n_rows - 1);
+ Log::Assert(n_rows == parameters.n_rows - 1);
predictions.zeros(n_cols);
// We set all the predictions to the intercept value initially.
@@ -73,26 +75,6 @@
// Increment each prediction value by x_i * a_i, or the next dimensional
// coefficient and x value.
predictions(j) += parameters(i) * points(i - 1, j);
-
}
}
}
-
-arma::vec LinearRegression::getParameters()
-{
- return parameters;
-}
-
-
-bool LinearRegression::load(const std::string& filename)
-{
- return data::Load(filename, parameters);
-}
-
-bool LinearRegression::save(const std::string& filename)
-{
- return data::Save(filename, parameters);
-}
-
-}; // namespace linear_regression
-}; // namespace mlpack
Modified: mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.hpp 2011-12-14 20:11:27 UTC (rev 10804)
+++ mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.hpp 2011-12-14 21:23:21 UTC (rev 10805)
@@ -10,7 +10,7 @@
#include <mlpack/core.hpp>
namespace mlpack {
-namespace linear_regression {
+namespace regression /** Regression methods. */ {
/**
* A simple linear regression algorithm using ordinary least squares.
@@ -24,7 +24,7 @@
* @param predictors X, matrix of data points to create B with.
* @param responses y, the measured data for each point in X
*/
- LinearRegression(arma::mat& predictors, const arma::colvec& responses);
+ LinearRegression(arma::mat& predictors, const arma::vec& responses);
/**
* Initialize the model from a file.
@@ -41,32 +41,16 @@
/**
* Calculate y_i for each data point in points.
*
- * @param predictions y, will contain calculated values on completion.
* @param points the data points to calculate with.
+ * @param predictions y, will contain calculated values on completion.
*/
- void predict(arma::rowvec& predictions, const arma::mat& points);
+ void Predict(const arma::mat& points, arma::vec& predictions);
- /**
- * Returns the model.
- *
- * @return the parameters which describe the least squares solution.
- */
- arma::vec getParameters();
+ //! Return the parameters (the b vector).
+ const arma::vec& Parameters() const { return parameters; }
+ //! Modify the parameters (the b vector).
+ arma::vec& Parameters() { return parameters; }
- /**
- * Saves the model.
- *
- * @param filename the name of the file to load the model from.
- */
- bool save(const std::string& filename);
-
- /**
- * Loads the model.
- *
- * @param filename the name of the file to load the model from.
- */
- bool load(const std::string& filename);
-
private:
/**
* The calculated B.
Modified: mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_main.cpp 2011-12-14 20:11:27 UTC (rev 10804)
+++ mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_main.cpp 2011-12-14 21:23:21 UTC (rev 10805)
@@ -7,73 +7,94 @@
#include <mlpack/core.hpp>
#include "linear_regression.hpp"
-using namespace mlpack;
+PROGRAM_INFO("Simple Linear Regression Prediction",
+ "An implementation of simple linear regression using ordinary least "
+ "squares. This solves the problem\n\n"
+ " y = X * b + e\n\n"
+ "where X (--input_file) and y (the last row of --input_file, or "
+ "--input_responses) are known and b is the desired variable. The "
+ "calculated b is saved to disk (--output_file).\n"
+ "\n"
+ "Optionally, the calculated value of b is used to predict the responses for"
+ " another matrix X' (--test_file):\n\n"
+ " y' = X' * b\n\n"
+ "and these predicted responses, y', are saved to a file "
+ "(--output_predictions).");
-PARAM_STRING_REQ("train", "A file containing X", "X");
-PARAM_STRING_REQ("test", "A file containing data points to predict on",
- "T");
-PARAM_STRING("responses", "A file containing the y values for X; if not "
- "present, it is assumed the last column of train contains these values.",
- "", "R");
+PARAM_STRING_REQ("input_file", "File containing X (regressors).", "i");
+PARAM_STRING("input_responses", "Optional file containing y (responses). If "
+ "not given, the responses are assumed to be the last row of the input "
+ "file.", "r", "");
-PROGRAM_INFO("Simple Linear Regression", "An implementation of simple linear "
- "regression using ordinary least squares.");
+PARAM_STRING("output_file", "File where parameters (b) will be saved.",
+ "o", "parameters.csv");
+PARAM_STRING("test_file", "File containing X' (test regressors).", "t", "");
+PARAM_STRING("output_predictions", "If --test_file is specified, this file is "
+ "where the predicted responses will be saved.", "p", "predictions.csv");
+
+using namespace mlpack;
+using namespace mlpack::regression;
+using namespace arma;
+using namespace std;
+
int main(int argc, char* argv[])
{
- arma::vec B;
- arma::colvec responses;
- arma::mat predictors, file, points;
-
// Handle parameters
CLI::ParseCommandLine(argc, argv);
- const std::string train_name =
- CLI::GetParam<std::string>("train");
- const std::string test_name =
- CLI::GetParam<std::string>("test");
- const std::string response_name =
- CLI::GetParam<std::string>("responses");
+ const string train_name = CLI::GetParam<string>("input_file");
+ const string test_name = CLI::GetParam<string>("test_file");
+ const string response_name = CLI::GetParam<string>("input_responses");
+ const string output_file = CLI::GetParam<string>("output_file");
+ const string output_predictions = CLI::GetParam<string>("output_predictions");
- data::Load(train_name.c_str(), file, true);
- size_t n_cols = file.n_cols,
- n_rows = file.n_rows;
+ mat regressors;
+ mat responses;
+ data::Load(train_name.c_str(), regressors, true);
+ // Are the responses in a separate file?
if (response_name == "")
{
- predictors = file.submat(0,0, n_rows-2, n_cols-1);
// The responses y (an N x 1 vector), taken from the last row of the input.
- responses = arma::trans(file.row(n_rows-1));
- --n_rows;
+ responses = trans(regressors.row(regressors.n_rows - 1));
+ regressors.shed_row(regressors.n_rows - 1);
}
else
{
- predictors = file;
// The responses y (an N x 1 vector), loaded from the separate responses file.
data::Load(response_name.c_str(), responses, true);
- if (responses.n_rows > 1)
+ if (responses.n_rows == 1)
+ responses = trans(responses); // Probably loaded backwards, but that's ok.
+
+ if (responses.n_cols > 1)
Log::Fatal << "The responses must have one column.\n";
- if (responses.n_cols != n_cols)
+ if (responses.n_rows != regressors.n_cols)
Log::Fatal << "The responses must have the same number of rows as the "
"training file.\n";
}
- data::Load(test_name.c_str(), points, true);
+ LinearRegression lr(regressors, responses.unsafe_col(0));
- if (points.n_rows != n_rows)
- Log::Fatal << "The test data must have the same number of columns as the "
- "training file.\n";
+ // Save the parameters.
+ data::Save(output_file.c_str(), lr.Parameters(), false);
- arma::rowvec predictions;
+ // Did we want to predict, too?
+ if (test_name != "")
+ {
+ arma::mat points;
+ data::Load(test_name.c_str(), points, true);
- linear_regression::LinearRegression lr(predictors, responses);
- lr.predict(predictions, points);
+ if (points.n_rows != regressors.n_rows)
+ Log::Fatal << "The test data must have the same number of columns as the "
+ "training file.\n";
- //data.row(n_rows) = predictions;
- //data::Save("out.csv", data);
- //std::cout << "predictions: " << arma::trans(predictions) << '\n';
+ arma::vec predictions;
+ lr.Predict(points, predictions);
- return 0;
+ // Save predictions.
+ data::Save(output_predictions.c_str(), predictions, false);
+ }
}
More information about the mlpack-svn
mailing list