[mlpack-svn] r15650 - mlpack/trunk/src/mlpack/methods/linear_regression
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Aug 23 15:26:14 EDT 2013
Author: rcurtin
Date: Fri Aug 23 15:26:14 2013
New Revision: 15650
Log:
Add some changes from Sumedh (#298). This improves the Predict() function and
adds an error check in the executable (I moved the error check from the
Predict() function).
Modified:
mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp
mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.hpp
mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_main.cpp
Modified: mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp (original)
+++ mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp Fri Aug 23 15:26:14 2013
@@ -55,9 +55,6 @@
parameters = linearRegression.parameters;
}
-LinearRegression::~LinearRegression()
-{ }
-
void LinearRegression::Predict(const arma::mat& points, arma::vec& predictions)
{
// We get the number of columns and rows of the dataset.
@@ -65,26 +62,12 @@
const size_t nRows = points.n_rows;
// We want to be sure we have the correct number of dimensions in the dataset.
- Log::Assert(nRows == parameters.n_rows - 1);
- if (nRows != parameters.n_rows -1)
- {
- Log::Fatal << "The test data must have the same number of columns as the "
- "training file.\n";
- }
+ Log::Assert(points.n_rows == parameters.n_rows - 1);
- predictions.zeros(nCols);
- // We set all the predictions to the intercept value initially.
- predictions += parameters(0);
+ // Get the predictions, but this ignores the intercept value (parameters[0]).
+ predictions = arma::trans(arma::trans(
+ parameters(arma::span(1, parameters.n_elem - 1))) * points);
- // Now we iterate through the dimensions of the data and parameters.
- for (size_t i = 1; i < nRows + 1; ++i)
- {
- // Now we iterate through each row, or point, of the data.
- for (size_t j = 0; j < nCols; ++j)
- {
- // Increment each prediction value by x_i * a_i, or the next dimensional
- // coefficient and x value.
- predictions(j) += parameters(i) * points(i - 1, j);
- }
- }
+ // Now add the intercept.
+ predictions += parameters(0);
}
Modified: mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.hpp Fri Aug 23 15:26:14 2013
@@ -41,15 +41,9 @@
LinearRegression(const LinearRegression& linearRegression);
/**
- * Default constructor.
+ * Empty constructor.
*/
- LinearRegression() {}
-
-
- /**
- * Destructor - no work done.
- */
- ~LinearRegression();
+ LinearRegression() { }
/**
* Calculate y_i for each data point in points.
Modified: mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_main.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_main.cpp (original)
+++ mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_main.cpp Fri Aug 23 15:26:14 2013
@@ -60,39 +60,34 @@
bool computeModel;
- // We want to determine if an input file XOR model file were given
- if (trainName.empty()) // The user specified no input file
+ // We want to determine if an input file XOR model file were given.
+ if (trainName.empty()) // The user specified no input file.
{
- if (modelName.empty()) // The user specified no model file, error and exit
- {
- Log::Fatal << "You must specify either --input_file or --model_file." << std::endl;
- exit(1);
- }
- else // The model file was specified, no problems
- {
+ if (modelName.empty()) // The user specified no model file, error and exit.
+ Log::Fatal << "You must specify either --input_file or --model_file."
+ << endl;
+ else // The model file was specified, no problems.
computeModel = false;
- }
}
- // The user specified an input file but no model file, no problems
+ // The user specified an input file but no model file, no problems.
else if (modelName.empty())
- {
computeModel = true;
- }
+
// The user specified both an input file and model file.
- // This is ambiguous -- which model should we use? A generated one or given one?
- // Report error and exit.
+ // This is ambiguous -- which model should we use? A generated one or given
+ // one? Report error and exit.
else
{
- Log::Fatal << "You must specify either --input_file or --model_file, not both." << std::endl;
- exit(1);
+ Log::Fatal << "You must specify either --input_file or --model_file, not "
+ << "both." << endl;
}
// If they specified a model file, we also need a test file or we
// have nothing to do.
- if(!computeModel && testName.empty())
+ if (!computeModel && testName.empty())
{
- Log::Fatal << "When specifying --model_file, you must also specify --test_file." << std::endl;
- exit(1);
+ Log::Fatal << "When specifying --model_file, you must also specify "
+ << "--test_file." << endl;
}
// An input file was given and we need to generate the model.
@@ -105,19 +100,19 @@
// Are the responses in a separate file?
if (responseName.empty())
{
- // The initial predictors for y, Nx1
+ // The initial predictors for y, Nx1.
responses = trans(regressors.row(regressors.n_rows - 1));
regressors.shed_row(regressors.n_rows - 1);
}
else
{
- // The initial predictors for y, Nx1
+ // The initial predictors for y, Nx1.
Timer::Start("load_responses");
data::Load(responseName, responses, true);
Timer::Stop("load_responses");
if (responses.n_rows == 1)
- responses = trans(responses); // Probably loaded backwards, but that's ok.
+ responses = trans(responses); // Probably loaded backwards.
if (responses.n_cols > 1)
Log::Fatal << "The responses must have one column.\n";
@@ -136,10 +131,9 @@
}
// Did we want to predict, too?
- if (!testName.empty() )
+ if (!testName.empty())
{
-
- // A model file was passed in, so load it
+ // A model file was passed in, so load it.
if (!computeModel)
{
Timer::Start("load_model");
@@ -147,13 +141,21 @@
Timer::Stop("load_model");
}
- // Load the test file data
+ // Load the test file data.
arma::mat points;
- Timer::Stop("load_test_points");
+ Timer::Start("load_test_points");
data::Load(testName, points, true);
Timer::Stop("load_test_points");
- // Perform the predictions using our model
+ // Ensure that test file data has the right number of features.
+ if ((lr.Parameters().n_elem - 1) != points.n_rows)
+ {
+ Log::Fatal << "The model was trained on " << lr.Parameters().n_elem - 1
+ << "-dimensional data, but the test points in '" << testName
+ << "' are " << points.n_rows << "-dimensional!" << endl;
+ }
+
+ // Perform the predictions using our model.
arma::vec predictions;
Timer::Start("prediction");
lr.Predict(points, predictions);
More information about the mlpack-svn mailing list