[mlpack-svn] r10084 - mlpack/trunk/src/mlpack/methods/linear_regression

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Oct 31 02:28:11 EDT 2011


Author: jcline3
Date: 2011-10-31 02:28:11 -0400 (Mon, 31 Oct 2011)
New Revision: 10084

Modified:
   mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp
   mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.hpp
   mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_main.cpp
   mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_test.cpp
Log:
Add load() and save() methods, and a constructor that loads the model from a file


Modified: mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp	2011-10-31 04:12:31 UTC (rev 10083)
+++ mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.cpp	2011-10-31 06:28:11 UTC (rev 10084)
@@ -4,11 +4,8 @@
 namespace linear_regression {
 
 LinearRegression::LinearRegression(arma::mat& predictors,
-  const arma::colvec& responses) :
-  predictors(predictors), responses(responses) { }
-
-void LinearRegression::run() {
-
+  const arma::colvec& responses)
+{
   size_t n_cols, n_rows;
 
   n_cols = predictors.n_cols;
@@ -25,9 +22,15 @@
     predictors * responses;
 }
 
+LinearRegression::LinearRegression(const std::string& filename)
+{
+  parameters.load(filename);
+}
+
 LinearRegression::~LinearRegression() {}
 
-void LinearRegression::predict(arma::rowvec& predictions, const arma::mat& points) {
+void LinearRegression::predict(arma::rowvec& predictions, const arma::mat& points)
+{
   size_t n_cols, n_rows;
   n_cols = points.n_cols;
   n_rows = points.n_rows;
@@ -49,5 +52,15 @@
 }
 
 
+bool LinearRegression::load(const std::string& filename)
+{
+  return parameters.load(filename);
+}
+
+bool LinearRegression::save(const std::string& filename)
+{
+  return parameters.save(filename);
+}
+
 }; // namespace linear_regression
 }; // namespace mlpack

Modified: mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.hpp	2011-10-31 04:12:31 UTC (rev 10083)
+++ mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression.hpp	2011-10-31 06:28:11 UTC (rev 10084)
@@ -10,48 +10,47 @@
  */
 class LinearRegression {
   public:
-    /** Initialize parameters. 
+    /** Creates the model.
      *  @param predictors X, matrix of data points to create B with.
-     *  @ param responses y, the measured data for each point in X
+     *  @param responses y, the measured data for each point in X
      */ 
     LinearRegression(arma::mat& predictors, const arma::colvec& responses);
 
-    /** Destructor - no work done.
+    /** Initialize the model from a file.
+     *  @param filename the name of the file to load the model from.
      */
+    LinearRegression(const std::string& filename);
+
+    /** Destructor - no work done. */
     ~LinearRegression();
 
-    /** Create regression coefficients.
-     * y=BX, this creates B.
-     **/
-    void run();
-
     /** Calculate y_i for each data point in points.
      *  @param predictions y, will contain calculated values on completion.
      *  @param points the data points to calculate with.
      */
     void predict(arma::rowvec& predictions, const arma::mat& points);
 
-    /** Returns B.
+    /** Returns the model.
      * @return the parameters which describe the least squares solution.
      */
     arma::vec getParameters();
 
+    /** Saves the model.
+     *  @param filename the name of the file to save the model to.
+     */
+    bool save(const std::string& filename);
+
+    /** Loads the model.
+     *  @param filename the name of the file to load the model from.
+     */
+    bool load(const std::string& filename);
+
   private:
     /** The calculated B.
      * Initialized and filled by constructor to hold the least squares solution.
      */
     arma::vec parameters;
 
-    /** The X values.
-     * The regressor values.
-     */
-    arma::mat& predictors;
-
-    /** The y values for the predictors.
-     * The response variables, corresponding to points in X.
-     */
-    const arma::colvec& responses;
-
 };
 
 }; // namespace linear_regression

Modified: mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_main.cpp	2011-10-31 04:12:31 UTC (rev 10083)
+++ mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_main.cpp	2011-10-31 06:28:11 UTC (rev 10084)
@@ -64,7 +64,6 @@
   arma::rowvec predictions;
 
   linear_regression::LinearRegression lr(predictors, responses);
-  lr.run();
   lr.predict(predictions, points);
 
   //data.row(n_rows) = predictions;

Modified: mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_test.cpp	2011-10-31 04:12:31 UTC (rev 10083)
+++ mlpack/trunk/src/mlpack/methods/linear_regression/linear_regression_test.cpp	2011-10-31 06:28:11 UTC (rev 10084)
@@ -37,7 +37,6 @@
 
   // Initialize and predict
   mlpack::linear_regression::LinearRegression lr(predictors, responses);
-  lr.run();
   lr.predict(predictions, points);
 
   // Output result and verify we have less than .5 error from "correct" value




More information about the mlpack-svn mailing list