[mlpack-git] master: Add separate Train() method. (ec1f8d7)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Oct 2 19:20:40 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/7a8b0e1292677b71888fad313772c63bcf0e7b80...de88672879a1893ebfc131538c64e7755251337c

>---------------------------------------------------------------

commit ec1f8d7fbea94fc28430f21dbeb70123068b5391
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Sep 28 22:04:24 2015 +0000

    Add separate Train() method.


>---------------------------------------------------------------

ec1f8d7fbea94fc28430f21dbeb70123068b5391
 .../linear_regression/linear_regression.cpp        | 41 ++++++++++++----------
 .../linear_regression/linear_regression.hpp        | 41 +++++++++++++++-------
 .../linear_regression/linear_regression_main.cpp   |  2 +-
 src/mlpack/tests/linear_regression_test.cpp        | 21 +++++++++++
 4 files changed, 72 insertions(+), 33 deletions(-)

diff --git a/src/mlpack/methods/linear_regression/linear_regression.cpp b/src/mlpack/methods/linear_regression/linear_regression.cpp
index 36c04a2..14f44da 100644
--- a/src/mlpack/methods/linear_regression/linear_regression.cpp
+++ b/src/mlpack/methods/linear_regression/linear_regression.cpp
@@ -18,6 +18,21 @@ LinearRegression::LinearRegression(const arma::mat& predictors,
     lambda(lambda),
     intercept(intercept)
 {
+  Train(predictors, responses, intercept, weights);
+}
+
+LinearRegression::LinearRegression(const LinearRegression& linearRegression) :
+    parameters(linearRegression.parameters),
+    lambda(linearRegression.lambda), intercept(linearRegression.intercept)
+{ /* Nothing to do. */ }
+
+void LinearRegression::Train(const arma::mat& predictors,
+                             const arma::vec& responses,
+                             const bool intercept,
+                             const arma::vec& weights)
+{
+  this->intercept = intercept;
+
   /*
    * We want to calculate the a_i coefficients of:
    * \sum_{i=0}^n (a_i * x_i^i)
@@ -31,18 +46,19 @@ LinearRegression::LinearRegression(const arma::mat& predictors,
 
   arma::mat p = predictors;
   arma::vec r = responses;
+
   // Here we add the row of ones to the predictors.
   // The intercept is not penalized. Add an "all ones" row to design and set
-  // intercept = false to get a penalized intercept
-  if(intercept)
+  // intercept = false to get a penalized intercept.
+  if (intercept)
   {
     p.insert_rows(0, arma::ones<arma::mat>(1,nCols));
   }
 
-   if(weights.n_elem > 0)
+  if (weights.n_elem > 0)
   {
     p = p * diagmat(sqrt(weights));
-    r =  sqrt(weights) % responses;
+    r = sqrt(weights) % responses;
   }
 
   if (lambda != 0.0)
@@ -52,8 +68,8 @@ LinearRegression::LinearRegression(const arma::mat& predictors,
     // more information.
     p.insert_cols(nCols, predictors.n_rows);
     p.submat(p.n_rows - predictors.n_rows, nCols, p.n_rows - 1, nCols +
-    predictors.n_rows - 1) = sqrt(lambda) * arma::eye<arma::mat>(predictors.n_rows,
-        predictors.n_rows);
+    predictors.n_rows - 1) = sqrt(lambda) *
+        arma::eye<arma::mat>(predictors.n_rows, predictors.n_rows);
   }
 
   // We compute the QR decomposition of the predictors.
@@ -77,19 +93,6 @@ LinearRegression::LinearRegression(const arma::mat& predictors,
   }
 }
 
-LinearRegression::LinearRegression(const std::string& filename) :
-    lambda(0.0)
-{
-  arma::mat parameter;
-  data::Load(filename, parameter, true);
-  parameters = parameter.unsafe_col(0);
-}
-
-LinearRegression::LinearRegression(const LinearRegression& linearRegression) :
-    parameters(linearRegression.parameters),
-    lambda(linearRegression.lambda)
-{ /* Nothing to do. */ }
-
 void LinearRegression::Predict(const arma::mat& points, arma::vec& predictions)
     const
 {
diff --git a/src/mlpack/methods/linear_regression/linear_regression.hpp b/src/mlpack/methods/linear_regression/linear_regression.hpp
index 511915e..ae4a1b7 100644
--- a/src/mlpack/methods/linear_regression/linear_regression.hpp
+++ b/src/mlpack/methods/linear_regression/linear_regression.hpp
@@ -25,10 +25,10 @@ class LinearRegression
    * Creates the model.
    *
    * @param predictors X, matrix of data points to create B with.
-   * @param responses y, the measured data for each point in X
-   * @param lambda regularization constant
-   * @param intercept include intercept?
-   * @param weights observation weights
+   * @param responses y, the measured data for each point in X.
+   * @param lambda Regularization constant for ridge regression.
+   * @param intercept Whether or not to include an intercept term.
+   * @param weights Observation weights (for boosting).
    */
   LinearRegression(const arma::mat& predictors,
                    const arma::vec& responses,
@@ -37,13 +37,6 @@ class LinearRegression
                    const arma::vec& weights = arma::vec());
 
   /**
-   * Initialize the model from a file.
-   *
-   * @param filename the name of the file to load the model from.
-   */
-  LinearRegression(const std::string& filename);
-
-  /**
    * Copy constructor.
    *
    * @param linearRegression the other instance to copy parameters from.
@@ -51,9 +44,28 @@ class LinearRegression
   LinearRegression(const LinearRegression& linearRegression);
 
   /**
-   * Empty constructor.
+   * Empty constructor.  This gives a non-working model, so make sure Train() is
+   * called (or make sure the model parameters are set) before calling
+   * Predict()!
    */
-  LinearRegression() { }
+  LinearRegression() : lambda(0.0), intercept(true) { }
+
+  /**
+   * Train the LinearRegression model on the given data.  Careful!  This will
+   * completely ignore and overwrite the existing model.  This particular
+   * implementation does not have an incremental training algorithm.  To set the
+   * regularization parameter lambda, call Lambda() or set a different value in
+   * the constructor.
+   *
+   * @param predictors X, the matrix of data points to train the model on.
+   * @param responses y, the vector of responses to each data point.
+   * @param intercept Whether or not to fit an intercept term.
+   * @param weights Observation weights (for boosting).
+   */
+  void Train(const arma::mat& predictors,
+             const arma::vec& responses,
+             const bool intercept = true,
+             const arma::vec& weights = arma::vec());
 
   /**
    * Calculate y_i for each data point in points.
@@ -93,6 +105,9 @@ class LinearRegression
   //! Modify the Tikhonov regularization parameter for ridge regression.
   double& Lambda() { return lambda; }
 
+  //! Return whether or not an intercept term is used in the model.
+  bool Intercept() const { return intercept; }
+
   /**
    * Serialize the model.
    */
diff --git a/src/mlpack/methods/linear_regression/linear_regression_main.cpp b/src/mlpack/methods/linear_regression/linear_regression_main.cpp
index 6aead25..4cdc9c8 100644
--- a/src/mlpack/methods/linear_regression/linear_regression_main.cpp
+++ b/src/mlpack/methods/linear_regression/linear_regression_main.cpp
@@ -145,7 +145,7 @@ int main(int argc, char* argv[])
     if (!computeModel)
     {
       Timer::Start("load_model");
-      lr = LinearRegression(modelName);
+      //lr = LinearRegression(modelName);
       Timer::Stop("load_model");
     }
 
diff --git a/src/mlpack/tests/linear_regression_test.cpp b/src/mlpack/tests/linear_regression_test.cpp
index 66af7c5..6a0e868 100644
--- a/src/mlpack/tests/linear_regression_test.cpp
+++ b/src/mlpack/tests/linear_regression_test.cpp
@@ -173,4 +173,25 @@ BOOST_AUTO_TEST_CASE(RidgeRegressionTestCase)
     BOOST_REQUIRE_SMALL(predictions(i) - responses(i), .05);
 }
 
+/**
+ * Test that a LinearRegression model trained in the constructor and trained in
+ * the Train() method give the same model.
+ */
+BOOST_AUTO_TEST_CASE(LinearRegressionTrainTest)
+{
+  // Random dataset.
+  arma::mat dataset = arma::randu<arma::mat>(5, 1000);
+  arma::vec responses = arma::randu<arma::vec>(1000);
+
+  LinearRegression lr(dataset, responses, 0.3);
+  LinearRegression lrTrain;
+  lrTrain.Lambda() = 0.3;
+
+  lrTrain.Train(dataset, responses);
+
+  BOOST_REQUIRE_EQUAL(lr.Parameters().n_elem, lrTrain.Parameters().n_elem);
+  for (size_t i = 0; i < lr.Parameters().n_elem; ++i)
+    BOOST_REQUIRE_CLOSE(lr.Parameters()[i], lrTrain.Parameters()[i], 1e-5);
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list