[mlpack-svn] r15650 - mlpack/trunk/src/mlpack/methods/linear_regression

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Aug 23 15:26:14 EDT 2013


Author: rcurtin
Date: Fri Aug 23 15:26:14 2013
New Revision: 15650

Log:
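Add some changes from Sumedh (#298).  This improves the Predict() function by
replacing the explicit loops with a matrix product, and adds an error check in
the executable for test data with the wrong number of dimensions (I moved this
check out of the Predict() function, which now only keeps a Log::Assert()).  It
also fixes the "load_test_points" timer, which was being stopped instead of
started.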


Modified:
   mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp
   mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.hpp
   mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_main.cpp

Modified: mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp	(original)
+++ mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp	Fri Aug 23 15:26:14 2013
@@ -55,9 +55,6 @@
   parameters = linearRegression.parameters;
 }
 
-LinearRegression::~LinearRegression()
-{ }
-
 void LinearRegression::Predict(const arma::mat& points, arma::vec& predictions)
 {
   // We get the number of columns and rows of the dataset.
@@ -65,26 +62,12 @@
   const size_t nRows = points.n_rows;
 
   // We want to be sure we have the correct number of dimensions in the dataset.
-  Log::Assert(nRows == parameters.n_rows - 1);
-  if (nRows != parameters.n_rows -1)
-  {
-    Log::Fatal << "The test data must have the same number of columns as the "
-        "training file.\n";
-  }
+  Log::Assert(points.n_rows == parameters.n_rows - 1);
 
-  predictions.zeros(nCols);
-  // We set all the predictions to the intercept value initially.
-  predictions += parameters(0);
+  // Get the predictions, but this ignores the intercept value (parameters[0]).
+  predictions = arma::trans(arma::trans(
+      parameters(arma::span(1, parameters.n_elem - 1))) * points);
 
-  // Now we iterate through the dimensions of the data and parameters.
-  for (size_t i = 1; i < nRows + 1; ++i)
-  {
-    // Now we iterate through each row, or point, of the data.
-    for (size_t j = 0; j < nCols; ++j)
-    {
-      // Increment each prediction value by x_i * a_i, or the next dimensional
-      // coefficient and x value.
-      predictions(j) += parameters(i) * points(i - 1, j);
-    }
-  }
+  // Now add the intercept.
+  predictions += parameters(0);
 }

Modified: mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.hpp	Fri Aug 23 15:26:14 2013
@@ -41,15 +41,9 @@
   LinearRegression(const LinearRegression& linearRegression);
 
   /**
-   * Default constructor.
+   * Empty constructor.
    */
-  LinearRegression() {}
-
-
-  /**
-   * Destructor - no work done.
-   */
-  ~LinearRegression();
+  LinearRegression() { }
 
   /**
    * Calculate y_i for each data point in points.

Modified: mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_main.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_main.cpp	(original)
+++ mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_main.cpp	Fri Aug 23 15:26:14 2013
@@ -60,39 +60,34 @@
 
   bool computeModel;
 
-  // We want to determine if an input file XOR model file were given
-  if (trainName.empty()) // The user specified no input file
+  // We want to determine if an input file XOR model file were given.
+  if (trainName.empty()) // The user specified no input file.
   {
-    if (modelName.empty()) // The user specified no model file, error and exit
-    {
-      Log::Fatal << "You must specify either --input_file or --model_file." << std::endl;
-      exit(1);
-    }
-    else // The model file was specified, no problems
-    {
+    if (modelName.empty()) // The user specified no model file, error and exit.
+      Log::Fatal << "You must specify either --input_file or --model_file."
+          << endl;
+    else // The model file was specified, no problems.
       computeModel = false;
-    }
   }
-  // The user specified an input file but no model file, no problems
+  // The user specified an input file but no model file, no problems.
   else if (modelName.empty())
-  {
     computeModel = true;
-  }
+
   // The user specified both an input file and model file.
-  // This is ambiguous -- which model should we use? A generated one or given one?
-  // Report error and exit.
+  // This is ambiguous -- which model should we use? A generated one or given
+  // one?  Report error and exit.
   else
   {
-      Log::Fatal << "You must specify either --input_file or --model_file, not both." << std::endl;
-      exit(1);
+    Log::Fatal << "You must specify either --input_file or --model_file, not "
+        << "both." << endl;
   }
 
   // If they specified a model file, we also need a test file or we
   // have nothing to do.
-  if(!computeModel && testName.empty())
+  if (!computeModel && testName.empty())
   {
-    Log::Fatal << "When specifying --model_file, you must also specify --test_file." << std::endl;
-    exit(1);
+    Log::Fatal << "When specifying --model_file, you must also specify "
+        << "--test_file." << endl;
   }
 
   // An input file was given and we need to generate the model.
@@ -105,19 +100,19 @@
     // Are the responses in a separate file?
     if (responseName.empty())
     {
-      // The initial predictors for y, Nx1
+      // The initial predictors for y, Nx1.
       responses = trans(regressors.row(regressors.n_rows - 1));
       regressors.shed_row(regressors.n_rows - 1);
     }
     else
     {
-      // The initial predictors for y, Nx1
+      // The initial predictors for y, Nx1.
       Timer::Start("load_responses");
       data::Load(responseName, responses, true);
       Timer::Stop("load_responses");
 
       if (responses.n_rows == 1)
-        responses = trans(responses); // Probably loaded backwards, but that's ok.
+        responses = trans(responses); // Probably loaded backwards.
 
       if (responses.n_cols > 1)
         Log::Fatal << "The responses must have one column.\n";
@@ -136,10 +131,9 @@
   }
 
   // Did we want to predict, too?
-  if (!testName.empty() )
+  if (!testName.empty())
   {
-
-    // A model file was passed in, so load it
+    // A model file was passed in, so load it.
     if (!computeModel)
     {
       Timer::Start("load_model");
@@ -147,13 +141,21 @@
       Timer::Stop("load_model");
     }
 
-    // Load the test file data
+    // Load the test file data.
     arma::mat points;
-    Timer::Stop("load_test_points");
+    Timer::Start("load_test_points");
     data::Load(testName, points, true);
     Timer::Stop("load_test_points");
 
-    // Perform the predictions using our model
+    // Ensure that test file data has the right number of features.
+    if ((lr.Parameters().n_elem - 1) != points.n_rows)
+    {
+      Log::Fatal << "The model was trained on " << lr.Parameters().n_elem - 1
+          << "-dimensional data, but the test points in '" << testName
+          << "' are " << points.n_rows << "-dimensional!" << endl;
+    }
+
+    // Perform the predictions using our model.
     arma::vec predictions;
     Timer::Start("prediction");
     lr.Predict(points, predictions);



More information about the mlpack-svn mailing list