[mlpack-git] master: Add a Predict() method to LARS. (7d64cd6)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Aug 19 12:21:41 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/71ea42ffea6f3816c642a8052e7443f526bc5f14...7d64cd61aade77d4607aacf00b36f94369fd7cf9

>---------------------------------------------------------------

commit 7d64cd61aade77d4607aacf00b36f94369fd7cf9
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Aug 19 12:21:28 2015 -0400

    Add a Predict() method to LARS.


>---------------------------------------------------------------

7d64cd61aade77d4607aacf00b36f94369fd7cf9
 src/mlpack/methods/lars/lars.cpp |  6 ++++++
 src/mlpack/methods/lars/lars.hpp | 11 ++++++++++-
 src/mlpack/tests/lars_test.cpp   | 42 +++++++++++++++++++++++++++++++++++++++-
 3 files changed, 57 insertions(+), 2 deletions(-)

diff --git a/src/mlpack/methods/lars/lars.cpp b/src/mlpack/methods/lars/lars.cpp
index 914d013..97d644f 100644
--- a/src/mlpack/methods/lars/lars.cpp
+++ b/src/mlpack/methods/lars/lars.cpp
@@ -322,6 +322,12 @@ void LARS::Regress(const arma::mat& matX,
   Timer::Stop("lars_regression");
 }
 
+void LARS::Predict(const arma::mat& points, arma::vec& predictions) const
+{
+  // Only the final solution, betaPath.back(), is used; storing just beta would suffice.
+  predictions = (betaPath.back().t() * points).t();
+}
+
 // Private functions.
 void LARS::Deactivate(const size_t activeVarInd)
 {
diff --git a/src/mlpack/methods/lars/lars.hpp b/src/mlpack/methods/lars/lars.hpp
index f67bf45..7c3cdde 100644
--- a/src/mlpack/methods/lars/lars.hpp
+++ b/src/mlpack/methods/lars/lars.hpp
@@ -136,6 +136,15 @@ class LARS
                arma::vec& beta,
                const bool transposeData = true);
 
+  /**
+   * Predict y_i for each data point in the given data matrix, using the
+   * currently-trained LARS model (so make sure you run Regress() first).
+   *
+   * @param points The data points (one per column) to compute predictions for.
+   * @param predictions y, which will contain the calculated values on completion.
+   */
+  void Predict(const arma::mat& points, arma::vec& predictions) const;
+
   //! Access the set of active dimensions.
   const std::vector<size_t>& ActiveSet() const { return activeSet; }
 
@@ -147,7 +156,7 @@ class LARS
   //! the last element.
   const std::vector<double>& LambdaPath() const { return lambdaPath; }
 
-  //! Access the upper triangular cholesky factor
+  //! Access the upper triangular cholesky factor.
   const arma::mat& MatUtriCholFactor() const { return matUtriCholFactor; }
 
   // Returns a string representation of this object.
diff --git a/src/mlpack/tests/lars_test.cpp b/src/mlpack/tests/lars_test.cpp
index d995a67..d658fca 100644
--- a/src/mlpack/tests/lars_test.cpp
+++ b/src/mlpack/tests/lars_test.cpp
@@ -55,7 +55,7 @@ void LassoTest(size_t nPoints, size_t nDims, bool elasticNet, bool useCholesky)
   arma::mat X;
   arma::vec y;
 
-  for(size_t i = 0; i < 100; i++)
+  for (size_t i = 0; i < 100; i++)
   {
     GenerateProblem(X, y, nPoints, nDims);
 
@@ -155,4 +155,44 @@ BOOST_AUTO_TEST_CASE(NoCholeskySingularityTest)
   }
 }
 
+// Make sure that Predict() gives results consistent with the beta from Regress().
+BOOST_AUTO_TEST_CASE(PredictTest)
+{
+  for (size_t i = 0; i < 2; ++i)
+  {
+    // Run with useCholesky both false (i = 0) and true (i = 1).
+    bool useCholesky = bool(i);
+
+    arma::mat X;
+    arma::vec y;
+
+    GenerateProblem(X, y, 1000, 100);
+
+    for (double lambda1 = 0.0; lambda1 < 1.0; lambda1 += 0.2)
+    {
+      for (double lambda2 = 0.0; lambda2 < 1.0; lambda2 += 0.2)
+      {
+        LARS lars(useCholesky, lambda1, lambda2);
+        arma::vec betaOpt;
+        lars.Regress(X, y, betaOpt);
+
+        // Compute what the predictions should be for these regression
+        // parameters; both sides are multiplied by X before being compared.
+        arma::vec betaOptPred = (X * X.t()) * betaOpt;
+        arma::vec predictions;
+        lars.Predict(X, predictions);
+        arma::vec adjPred = X * predictions;
+
+        for (size_t i = 0; i < betaOptPred.n_elem; ++i)
+        {
+          if (std::abs(betaOptPred[i]) < 1e-5)
+            BOOST_REQUIRE_SMALL(adjPred[i], 1e-5);
+          else
+            BOOST_REQUIRE_CLOSE(adjPred[i], betaOptPred[i], 1e-5);
+        }
+      }
+    }
+  }
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list