[mlpack-git] master: Add separate Train() method. (ec1f8d7)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Fri Oct 2 19:20:40 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/7a8b0e1292677b71888fad313772c63bcf0e7b80...de88672879a1893ebfc131538c64e7755251337c
>---------------------------------------------------------------
commit ec1f8d7fbea94fc28430f21dbeb70123068b5391
Author: Ryan Curtin <ryan at ratml.org>
Date: Mon Sep 28 22:04:24 2015 +0000
Add separate Train() method.
>---------------------------------------------------------------
ec1f8d7fbea94fc28430f21dbeb70123068b5391
.../linear_regression/linear_regression.cpp | 41 ++++++++++++----------
.../linear_regression/linear_regression.hpp | 41 +++++++++++++++-------
.../linear_regression/linear_regression_main.cpp | 2 +-
src/mlpack/tests/linear_regression_test.cpp | 21 +++++++++++
4 files changed, 72 insertions(+), 33 deletions(-)
diff --git a/src/mlpack/methods/linear_regression/linear_regression.cpp b/src/mlpack/methods/linear_regression/linear_regression.cpp
index 36c04a2..14f44da 100644
--- a/src/mlpack/methods/linear_regression/linear_regression.cpp
+++ b/src/mlpack/methods/linear_regression/linear_regression.cpp
@@ -18,6 +18,21 @@ LinearRegression::LinearRegression(const arma::mat& predictors,
lambda(lambda),
intercept(intercept)
{
+ Train(predictors, responses, intercept, weights);
+}
+
+LinearRegression::LinearRegression(const LinearRegression& linearRegression) :
+ parameters(linearRegression.parameters),
+ lambda(linearRegression.lambda)
+{ /* Nothing to do. */ }
+
+void LinearRegression::Train(const arma::mat& predictors,
+ const arma::vec& responses,
+ const bool intercept,
+ const arma::vec& weights)
+{
+ this->intercept = intercept;
+
/*
* We want to calculate the a_i coefficients of:
* \sum_{i=0}^n (a_i * x_i^i)
@@ -31,18 +46,19 @@ LinearRegression::LinearRegression(const arma::mat& predictors,
arma::mat p = predictors;
arma::vec r = responses;
+
// Here we add the row of ones to the predictors.
// The intercept is not penalized. Add an "all ones" row to design and set
- // intercept = false to get a penalized intercept
- if(intercept)
+ // intercept = false to get a penalized intercept.
+ if (intercept)
{
p.insert_rows(0, arma::ones<arma::mat>(1,nCols));
}
- if(weights.n_elem > 0)
+ if (weights.n_elem > 0)
{
p = p * diagmat(sqrt(weights));
- r = sqrt(weights) % responses;
+ r = sqrt(weights) % responses;
}
if (lambda != 0.0)
@@ -52,8 +68,8 @@ LinearRegression::LinearRegression(const arma::mat& predictors,
// more information.
p.insert_cols(nCols, predictors.n_rows);
p.submat(p.n_rows - predictors.n_rows, nCols, p.n_rows - 1, nCols +
- predictors.n_rows - 1) = sqrt(lambda) * arma::eye<arma::mat>(predictors.n_rows,
- predictors.n_rows);
+ predictors.n_rows - 1) = sqrt(lambda) *
+ arma::eye<arma::mat>(predictors.n_rows, predictors.n_rows);
}
// We compute the QR decomposition of the predictors.
@@ -77,19 +93,6 @@ LinearRegression::LinearRegression(const arma::mat& predictors,
}
}
-LinearRegression::LinearRegression(const std::string& filename) :
- lambda(0.0)
-{
- arma::mat parameter;
- data::Load(filename, parameter, true);
- parameters = parameter.unsafe_col(0);
-}
-
-LinearRegression::LinearRegression(const LinearRegression& linearRegression) :
- parameters(linearRegression.parameters),
- lambda(linearRegression.lambda)
-{ /* Nothing to do. */ }
-
void LinearRegression::Predict(const arma::mat& points, arma::vec& predictions)
const
{
diff --git a/src/mlpack/methods/linear_regression/linear_regression.hpp b/src/mlpack/methods/linear_regression/linear_regression.hpp
index 511915e..ae4a1b7 100644
--- a/src/mlpack/methods/linear_regression/linear_regression.hpp
+++ b/src/mlpack/methods/linear_regression/linear_regression.hpp
@@ -25,10 +25,10 @@ class LinearRegression
* Creates the model.
*
* @param predictors X, matrix of data points to create B with.
- * @param responses y, the measured data for each point in X
- * @param lambda regularization constant
- * @param intercept include intercept?
- * @param weights observation weights
+ * @param responses y, the measured data for each point in X.
+ * @param lambda Regularization constant for ridge regression.
+ * @param intercept Whether or not to include an intercept term.
+ * @param weights Observation weights (for boosting).
*/
LinearRegression(const arma::mat& predictors,
const arma::vec& responses,
@@ -37,13 +37,6 @@ class LinearRegression
const arma::vec& weights = arma::vec());
/**
- * Initialize the model from a file.
- *
- * @param filename the name of the file to load the model from.
- */
- LinearRegression(const std::string& filename);
-
- /**
* Copy constructor.
*
* @param linearRegression the other instance to copy parameters from.
@@ -51,9 +44,28 @@ class LinearRegression
LinearRegression(const LinearRegression& linearRegression);
/**
- * Empty constructor.
+ * Empty constructor. This gives a non-working model, so make sure Train() is
+ * called (or make sure the model parameters are set) before calling
+ * Predict()!
*/
- LinearRegression() { }
+ LinearRegression() : lambda(0.0), intercept(true) { }
+
+ /**
+ * Train the LinearRegression model on the given data. Careful! This will
+ * completely ignore and overwrite the existing model. This particular
+ * implementation does not have an incremental training algorithm. To set the
+ * regularization parameter lambda, call Lambda() or set a different value in
+ * the constructor.
+ *
+ * @param predictors X, the matrix of data points to train the model on.
+ * @param responses y, the vector of responses to each data point.
+ * @param intercept Whether or not to fit an intercept term.
+ * @param weights Observation weights (for boosting).
+ */
+ void Train(const arma::mat& predictors,
+ const arma::vec& responses,
+ const bool intercept = true,
+ const arma::vec& weights = arma::vec());
/**
* Calculate y_i for each data point in points.
@@ -93,6 +105,9 @@ class LinearRegression
//! Modify the Tikhonov regularization parameter for ridge regression.
double& Lambda() { return lambda; }
+ //! Return whether or not an intercept term is used in the model.
+ bool Intercept() const { return intercept; }
+
/**
* Serialize the model.
*/
diff --git a/src/mlpack/methods/linear_regression/linear_regression_main.cpp b/src/mlpack/methods/linear_regression/linear_regression_main.cpp
index 6aead25..4cdc9c8 100644
--- a/src/mlpack/methods/linear_regression/linear_regression_main.cpp
+++ b/src/mlpack/methods/linear_regression/linear_regression_main.cpp
@@ -145,7 +145,7 @@ int main(int argc, char* argv[])
if (!computeModel)
{
Timer::Start("load_model");
- lr = LinearRegression(modelName);
+ //lr = LinearRegression(modelName);
Timer::Stop("load_model");
}
diff --git a/src/mlpack/tests/linear_regression_test.cpp b/src/mlpack/tests/linear_regression_test.cpp
index 66af7c5..6a0e868 100644
--- a/src/mlpack/tests/linear_regression_test.cpp
+++ b/src/mlpack/tests/linear_regression_test.cpp
@@ -173,4 +173,25 @@ BOOST_AUTO_TEST_CASE(RidgeRegressionTestCase)
BOOST_REQUIRE_SMALL(predictions(i) - responses(i), .05);
}
+/**
+ * Test that a LinearRegression model trained in the constructor and trained in
+ * the Train() method give the same model.
+ */
+BOOST_AUTO_TEST_CASE(LinearRegressionTrainTest)
+{
+ // Random dataset.
+ arma::mat dataset = arma::randu<arma::mat>(5, 1000);
+ arma::vec responses = arma::randu<arma::vec>(1000);
+
+ LinearRegression lr(dataset, responses, 0.3);
+ LinearRegression lrTrain;
+ lrTrain.Lambda() = 0.3;
+
+ lrTrain.Train(dataset, responses);
+
+ BOOST_REQUIRE_EQUAL(lr.Parameters().n_elem, lrTrain.Parameters().n_elem);
+ for (size_t i = 0; i < lr.Parameters().n_elem; ++i)
+ BOOST_REQUIRE_CLOSE(lr.Parameters()[i], lrTrain.Parameters()[i], 1e-5);
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git mailing list