[mlpack-svn] r10805 - mlpack/trunk/src/mlpack/methods/linear_regression

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Dec 14 16:23:22 EST 2011


Author: rcurtin
Date: 2011-12-14 16:23:21 -0500 (Wed, 14 Dec 2011)
New Revision: 10805

Modified:
   mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp
   mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.hpp
   mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_main.cpp
Log:
Redo LinearRegression API and put it into the regression namespace (not
linear_regression).  Make the main executable more functional and document it.


Modified: mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp	2011-12-14 20:11:27 UTC (rev 10804)
+++ mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp	2011-12-14 21:23:21 UTC (rev 10805)
@@ -1,12 +1,17 @@
+/**
+ * @file linear_regression.cpp
+ * @author James Cline
+ *
+ * Implementation of simple linear regression.
+ */
 #include "linear_regression.hpp"
 
-namespace mlpack {
-namespace linear_regression {
+using namespace mlpack;
+using namespace mlpack::regression;
 
 LinearRegression::LinearRegression(arma::mat& predictors,
                                    const arma::colvec& responses)
 {
-
   /*
    * We want to calculate the a_i coefficients of:
    * \sum_{i=0}^n (a_i * x_i^i)
@@ -16,8 +21,7 @@
   // We store the number of rows of the predictors.
   // Reminder: Armadillo stores the data transposed from how we think of it,
   //           that is, columns are actually rows (see: column major order).
-  size_t n_cols;
-  n_cols = predictors.n_cols;
+  size_t n_cols = predictors.n_cols;
 
   // Here we add the row of ones to the predictors.
   arma::rowvec ones;
@@ -49,16 +53,14 @@
 LinearRegression::~LinearRegression()
 { }
 
-void LinearRegression::predict(arma::rowvec& predictions,
-                               const arma::mat& points)
+void LinearRegression::Predict(const arma::mat& points, arma::vec& predictions)
 {
   // We get the number of columns and rows of the dataset.
-  size_t n_cols, n_rows;
-  n_cols = points.n_cols;
-  n_rows = points.n_rows;
+  const size_t n_cols = points.n_cols;
+  const size_t n_rows = points.n_rows;
 
   // We want to be sure we have the correct number of dimensions in the dataset.
-  assert(n_rows == parameters.n_rows - 1);
+  Log::Assert(n_rows == parameters.n_rows - 1);
 
   predictions.zeros(n_cols);
   // We set all the predictions to the intercept value initially.
@@ -73,26 +75,6 @@
       // Increment each prediction value by x_i * a_i, or the next dimensional
       // coefficient and x value.
       predictions(j) += parameters(i) * points(i - 1, j);
-
     }
   }
 }
-
-arma::vec LinearRegression::getParameters()
-{
-  return parameters;
-}
-
-
-bool LinearRegression::load(const std::string& filename)
-{
-  return data::Load(filename, parameters);
-}
-
-bool LinearRegression::save(const std::string& filename)
-{
-  return data::Save(filename, parameters);
-}
-
-}; // namespace linear_regression
-}; // namespace mlpack

Modified: mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.hpp	2011-12-14 20:11:27 UTC (rev 10804)
+++ mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.hpp	2011-12-14 21:23:21 UTC (rev 10805)
@@ -10,7 +10,7 @@
 #include <mlpack/core.hpp>
 
 namespace mlpack {
-namespace linear_regression {
+namespace regression /** Regression methods. */ {
 
 /**
  * A simple linear regresion algorithm using ordinary least squares.
@@ -24,7 +24,7 @@
    * @param predictors X, matrix of data points to create B with.
    * @param responses y, the measured data for each point in X
    */
-  LinearRegression(arma::mat& predictors, const arma::colvec& responses);
+  LinearRegression(arma::mat& predictors, const arma::vec& responses);
 
   /**
    * Initialize the model from a file.
@@ -41,32 +41,16 @@
   /**
    * Calculate y_i for each data point in points.
    *
-   * @param predictions y, will contain calculated values on completion.
    * @param points the data points to calculate with.
+   * @param predictions y, will contain calculated values on completion.
    */
-  void predict(arma::rowvec& predictions, const arma::mat& points);
+  void Predict(const arma::mat& points, arma::vec& predictions);
 
-  /**
-   * Returns the model.
-   *
-   * @return the parameters which describe the least squares solution.
-   */
-  arma::vec getParameters();
+  //! Return the parameters (the b vector).
+  const arma::vec& Parameters() const { return parameters; }
+  //! Modify the parameters (the b vector).
+  arma::vec& Parameters() { return parameters; }
 
-  /**
-   * Saves the model.
-   *
-   * @param filename the name of the file to load the model from.
-   */
-  bool save(const std::string& filename);
-
-  /**
-   * Loads the model.
-   *
-   * @param filename the name of the file to load the model from.
-   */
-  bool load(const std::string& filename);
-
  private:
   /**
    * The calculated B.

Modified: mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_main.cpp	2011-12-14 20:11:27 UTC (rev 10804)
+++ mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_main.cpp	2011-12-14 21:23:21 UTC (rev 10805)
@@ -7,73 +7,94 @@
 #include <mlpack/core.hpp>
 #include "linear_regression.hpp"
 
-using namespace mlpack;
+PROGRAM_INFO("Simple Linear Regression Prediction",
+    "An implementation of simple linear regression using ordinary least "
+    "squares. This solves the problem\n\n"
+    "  y = X * b + e\n\n"
+    "where X (--input_file) and y (the last row of --input_file, or "
+    "--input_responses) are known and b is the desired variable.  The "
+    "calculated b is saved to disk (--output_file).\n"
+    "\n"
+    "Optionally, the calculated value of b is used to predict the responses for"
+    " another matrix X' (--test_file):\n\n"
+    "   y' = X' * b\n\n"
+    "and these predicted responses, y', are saved to a file "
+    "(--output_predictions).");
 
-PARAM_STRING_REQ("train", "A file containing X", "X");
-PARAM_STRING_REQ("test", "A file containing data points to predict on",
-    "T");
-PARAM_STRING("responses", "A file containing the y values for X; if not "
-    "present, it is assumed the last column of train contains these values.",
-    "", "R");
+PARAM_STRING_REQ("input_file", "File containing X (regressors).", "i");
+PARAM_STRING("input_responses", "Optional file containing y (responses). If "
+    "not given, the responses are assumed to be the last row of the input "
+    "file.", "r", "");
 
-PROGRAM_INFO("Simple Linear Regression", "An implementation of simple linear "
-    "regression using ordinary least squares.");
+PARAM_STRING("output_file", "File where parameters (b) will be saved.",
+    "o", "parameters.csv");
 
+PARAM_STRING("test_file", "File containing X' (test regressors).", "t", "");
+PARAM_STRING("output_predictions", "If --test_file is specified, this file is "
+    "where the predicted responses will be saved.", "p", "predictions.csv");
+
+using namespace mlpack;
+using namespace mlpack::regression;
+using namespace arma;
+using namespace std;
+
 int main(int argc, char* argv[])
 {
-  arma::vec B;
-  arma::colvec responses;
-  arma::mat predictors, file, points;
-
   // Handle parameters
   CLI::ParseCommandLine(argc, argv);
 
-  const std::string train_name =
-      CLI::GetParam<std::string>("train");
-  const std::string test_name =
-      CLI::GetParam<std::string>("test");
-  const std::string response_name =
-      CLI::GetParam<std::string>("responses");
+  const string train_name = CLI::GetParam<string>("input_file");
+  const string test_name = CLI::GetParam<string>("test_file");
+  const string response_name = CLI::GetParam<string>("input_responses");
+  const string output_file = CLI::GetParam<string>("output_file");
+  const string output_predictions = CLI::GetParam<string>("output_predictions");
 
-  data::Load(train_name.c_str(), file, true);
-  size_t n_cols = file.n_cols,
-         n_rows = file.n_rows;
+  mat regressors;
+  mat responses;
+  data::Load(train_name.c_str(), regressors, true);
 
+  // Are the responses in a separate file?
   if (response_name == "")
   {
-    predictors = file.submat(0,0, n_rows-2, n_cols-1);
     // The initial predictors for y, Nx1
-    responses = arma::trans(file.row(n_rows-1));
-    --n_rows;
+    responses = trans(regressors.row(regressors.n_rows - 1));
+    regressors.shed_row(regressors.n_rows - 1);
   }
   else
   {
-    predictors = file;
     // The initial predictors for y, Nx1
     data::Load(response_name.c_str(), responses, true);
 
-    if (responses.n_rows > 1)
+    if (responses.n_rows == 1)
+      responses = trans(responses); // Probably loaded backwards, but that's ok.
+
+    if (responses.n_cols > 1)
       Log::Fatal << "The responses must have one column.\n";
 
-    if (responses.n_cols != n_cols)
+    if (responses.n_rows != regressors.n_cols)
       Log::Fatal << "The responses must have the same number of rows as the "
           "training file.\n";
   }
 
-  data::Load(test_name.c_str(), points, true);
+  LinearRegression lr(regressors, responses.unsafe_col(0));
 
-  if (points.n_rows != n_rows)
-    Log::Fatal << "The test data must have the same number of columns as the "
-        "training file.\n";
+  // Save the parameters.
+  data::Save(output_file.c_str(), lr.Parameters(), false);
 
-  arma::rowvec predictions;
+  // Did we want to predict, too?
+  if (test_name != "")
+  {
+    arma::mat points;
+    data::Load(test_name.c_str(), points, true);
 
-  linear_regression::LinearRegression lr(predictors, responses);
-  lr.predict(predictions, points);
+    if (points.n_rows != regressors.n_rows)
+      Log::Fatal << "The test data must have the same number of columns as the "
+          "training file.\n";
 
-  //data.row(n_rows) = predictions;
-  //data::Save("out.csv", data);
-  //std::cout << "predictions: " << arma::trans(predictions) << '\n';
+    arma::vec predictions;
+    lr.Predict(points, predictions);
 
-  return 0;
+    // Save predictions.
+    data::Save(output_predictions.c_str(), predictions, false);
+  }
 }




More information about the mlpack-svn mailing list