[mlpack-svn] r10084 - mlpack/trunk/src/mlpack/methods/linear_regression
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Oct 31 02:28:11 EDT 2011
Author: jcline3
Date: 2011-10-31 02:28:11 -0400 (Mon, 31 Oct 2011)
New Revision: 10084
Modified:
mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp
mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.hpp
mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_main.cpp
mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_test.cpp
Log:
Add load and save and a constructor that loads model from file
Modified: mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp 2011-10-31 04:12:31 UTC (rev 10083)
+++ mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp 2011-10-31 06:28:11 UTC (rev 10084)
@@ -4,11 +4,8 @@
namespace linear_regression {
LinearRegression::LinearRegression(arma::mat& predictors,
- const arma::colvec& responses) :
- predictors(predictors), responses(responses) { }
-
-void LinearRegression::run() {
-
+ const arma::colvec& responses)
+{
size_t n_cols, n_rows;
n_cols = predictors.n_cols;
@@ -25,9 +22,15 @@
predictors * responses;
}
+LinearRegression::LinearRegression(const std::string& filename)
+{
+ parameters.load(filename);
+}
+
LinearRegression::~LinearRegression() {}
-void LinearRegression::predict(arma::rowvec& predictions, const arma::mat& points) {
+void LinearRegression::predict(arma::rowvec& predictions, const arma::mat& points)
+{
size_t n_cols, n_rows;
n_cols = points.n_cols;
n_rows = points.n_rows;
@@ -49,5 +52,15 @@
}
+bool LinearRegression::load(const std::string& filename)
+{
+ return parameters.load(filename);
+}
+
+bool LinearRegression::save(const std::string& filename)
+{
+ return parameters.save(filename);
+}
+
}; // namespace linear_regression
}; // namespace mlpack
Modified: mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.hpp 2011-10-31 04:12:31 UTC (rev 10083)
+++ mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.hpp 2011-10-31 06:28:11 UTC (rev 10084)
@@ -10,48 +10,47 @@
*/
class LinearRegression {
public:
- /** Initialize parameters.
+ /** Creates the model.
* @param predictors X, matrix of data points to create B with.
- * @ param responses y, the measured data for each point in X
+ * @param responses y, the measured data for each point in X
*/
LinearRegression(arma::mat& predictors, const arma::colvec& responses);
- /** Destructor - no work done.
+ /** Initialize the model from a file.
+ * @param filename the name of the file to load the model from.
*/
+ LinearRegression(const std::string& filename);
+
+ /** Destructor - no work done. */
~LinearRegression();
- /** Create regression coefficients.
- * y=BX, this creates B.
- **/
- void run();
-
/** Calculate y_i for each data point in points.
* @param predictions y, will contain calculated values on completion.
* @param points the data points to calculate with.
*/
void predict(arma::rowvec& predictions, const arma::mat& points);
- /** Returns B.
+ /** Returns the model.
* @return the parameters which describe the least squares solution.
*/
arma::vec getParameters();
+ /** Saves the model.
+ * @param filename the name of the file to save the model to.
+ */
+ bool save(const std::string& filename);
+
+ /** Loads the model.
+ * @param filename the name of the file to load the model from.
+ */
+ bool load(const std::string& filename);
+
private:
/** The calculated B.
* Initialized and filled by constructor to hold the least squares solution.
*/
arma::vec parameters;
- /** The X values.
- * The regressor values.
- */
- arma::mat& predictors;
-
- /** The y values for the predictors.
- * The response variables, corresponding to points in X.
- */
- const arma::colvec& responses;
-
};
}; // namespace linear_regression
Modified: mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_main.cpp 2011-10-31 04:12:31 UTC (rev 10083)
+++ mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_main.cpp 2011-10-31 06:28:11 UTC (rev 10084)
@@ -64,7 +64,6 @@
arma::rowvec predictions;
linear_regression::LinearRegression lr(predictors, responses);
- lr.run();
lr.predict(predictions, points);
//data.row(n_rows) = predictions;
Modified: mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_test.cpp 2011-10-31 04:12:31 UTC (rev 10083)
+++ mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_test.cpp 2011-10-31 06:28:11 UTC (rev 10084)
@@ -37,7 +37,6 @@
// Initialize and predict
mlpack::linear_regression::LinearRegression lr(predictors, responses);
- lr.run();
lr.predict(predictions, points);
// Output result and verify we have less than .5 error from "correct" value
More information about the mlpack-svn mailing list