[mlpack-svn] r11499 - mlpack/trunk/src/mlpack/methods/linear_regression
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Feb 13 17:43:36 EST 2012
Author: jcline3
Date: 2012-02-13 17:43:36 -0500 (Mon, 13 Feb 2012)
New Revision: 11499
Modified:
mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp
mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_main.cpp
Log:
Various improvements to LinearRegression and our executable.
Now supports --model_file (-m) to do prediction on an already computed
model. Also validates user input to make sure the right combination of
options is provided.
Modified: mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp 2012-02-13 22:40:14 UTC (rev 11498)
+++ mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp 2012-02-13 22:43:36 UTC (rev 11499)
@@ -66,6 +66,11 @@
// We want to be sure we have the correct number of dimensions in the dataset.
Log::Assert(nRows == parameters.n_rows - 1);
+ if (nRows != parameters.n_rows -1)
+ {
+ Log::Fatal << "The test data must have the same number of columns as the "
+ "training file.\n";
+ }
predictions.zeros(nCols);
// We set all the predictions to the intercept value initially.
Modified: mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_main.cpp 2012-02-13 22:40:14 UTC (rev 11498)
+++ mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_main.cpp 2012-02-13 22:43:36 UTC (rev 11499)
@@ -21,11 +21,14 @@
"and these predicted responses, y', are saved to a file "
"(--output_predictions).");
-PARAM_STRING_REQ("input_file", "File containing X (regressors).", "i");
+PARAM_STRING("input_file", "File containing X (regressors).", "i", "");
PARAM_STRING("input_responses", "Optional file containing y (responses). If "
"not given, the responses are assumed to be the last row of the input "
"file.", "r", "");
+PARAM_STRING("model_file", "File containing existing model (parameters).", "m",
+ "");
+
PARAM_STRING("output_file", "File where parameters (b) will be saved.",
"o", "parameters.csv");
@@ -43,58 +46,110 @@
// Handle parameters
CLI::ParseCommandLine(argc, argv);
- const string trainName = CLI::GetParam<string>("input_file");
- const string testName = CLI::GetParam<string>("test_file");
- const string responseName = CLI::GetParam<string>("input_responses");
+ const string modelName = CLI::GetParam<string>("model_file");
const string outputFile = CLI::GetParam<string>("output_file");
const string outputPredictions = CLI::GetParam<string>("output_predictions");
+ const string responseName = CLI::GetParam<string>("input_responses");
+ const string testName = CLI::GetParam<string>("test_file");
+ const string trainName = CLI::GetParam<string>("input_file");
mat regressors;
mat responses;
- data::Load(trainName.c_str(), regressors, true);
- // Are the responses in a separate file?
- if (responseName == "")
+ LinearRegression lr;
+
+ bool computeModel;
+
+ if (trainName.empty())
{
- // The initial predictors for y, Nx1
- responses = trans(regressors.row(regressors.n_rows - 1));
- regressors.shed_row(regressors.n_rows - 1);
+ if (modelName.empty())
+ {
+ Log::Fatal << "You must specify either --input_file or --model_file." << std::endl;
+ exit(1);
+ }
+ else
+ {
+ computeModel = false;
+ }
}
+ else if (modelName.empty())
+ {
+ computeModel = true;
+ }
else
{
- // The initial predictors for y, Nx1
- data::Load(responseName.c_str(), responses, true);
+ Log::Fatal << "You must specify either --input_file or --model_file, not both." << std::endl;
+ exit(1);
+ }
- if (responses.n_rows == 1)
- responses = trans(responses); // Probably loaded backwards, but that's ok.
+ if(!computeModel && testName.empty())
+ {
+ Log::Fatal << "When specifying --model_file, you must also specify --test_file." << std::endl;
+ exit(1);
+ }
- if (responses.n_cols > 1)
- Log::Fatal << "The responses must have one column.\n";
+ if (computeModel)
+ {
+ Timer::Start("load_regressors");
+ data::Load(trainName.c_str(), regressors, true);
+ Timer::Stop("load_regressors");
- if (responses.n_rows != regressors.n_cols)
- Log::Fatal << "The responses must have the same number of rows as the "
- "training file.\n";
- }
+ // Are the responses in a separate file?
+ if (responseName.empty())
+ {
+ // The initial predictors for y, Nx1
+ responses = trans(regressors.row(regressors.n_rows - 1));
+ regressors.shed_row(regressors.n_rows - 1);
+ }
+ else
+ {
+ // The initial predictors for y, Nx1
+ Timer::Start("load_responses");
+ data::Load(responseName.c_str(), responses, true);
+ Timer::Stop("load_responses");
- LinearRegression lr(regressors, responses.unsafe_col(0));
+ if (responses.n_rows == 1)
+ responses = trans(responses); // Probably loaded backwards, but that's ok.
- // Save the parameters.
- data::Save(outputFile.c_str(), lr.Parameters(), false);
+ if (responses.n_cols > 1)
+ Log::Fatal << "The responses must have one column.\n";
+ if (responses.n_rows != regressors.n_cols)
+ Log::Fatal << "The responses must have the same number of rows as the "
+ "training file.\n";
+ }
+
+ Timer::Start("regression");
+ lr = LinearRegression(regressors, responses.unsafe_col(0));
+ Timer::Stop("regression");
+
+ // Save the parameters.
+ data::Save(outputFile.c_str(), lr.Parameters(), true);
+ }
+
// Did we want to predict, too?
- if (testName != "")
+ if (!testName.empty() )
{
+
+ if (!computeModel)
+ {
+ Timer::Start("load_model");
+ lr = LinearRegression(modelName);
+ Timer::Stop("load_model");
+ }
+
arma::mat points;
+ Timer::Stop("load_test_points");
data::Load(testName.c_str(), points, true);
+ Timer::Stop("load_test_points");
- if (points.n_rows != regressors.n_rows)
- Log::Fatal << "The test data must have the same number of columns as the "
- "training file.\n";
-
arma::vec predictions;
+ Timer::Start("prediction");
lr.Predict(points, predictions);
+ Timer::Stop("prediction");
// Save predictions.
- data::Save(outputPredictions.c_str(), predictions, false);
+ predictions = arma::trans(predictions);
+ data::Save(outputPredictions.c_str(), predictions, true);
}
}
More information about the mlpack-svn
mailing list