[mlpack-git] master: Allow passing a row-major matrix to Predict() to save time. (4842fa9)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Aug 19 12:46:49 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/7d64cd61aade77d4607aacf00b36f94369fd7cf9...284914879d42af40b028a7f5a002f4cd82c5fe05

>---------------------------------------------------------------

commit 4842fa9fb26fc83bd08de692d30eacf74b5236d3
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Aug 19 12:45:38 2015 -0400

    Allow passing a row-major matrix to Predict() to save time.


>---------------------------------------------------------------

4842fa9fb26fc83bd08de692d30eacf74b5236d3
 src/mlpack/methods/lars/lars.cpp |  9 +++++++--
 src/mlpack/methods/lars/lars.hpp |  8 ++++++--
 src/mlpack/tests/lars_test.cpp   | 30 ++++++++++++++++++++++++++++++
 3 files changed, 43 insertions(+), 4 deletions(-)

diff --git a/src/mlpack/methods/lars/lars.cpp b/src/mlpack/methods/lars/lars.cpp
index 97d644f..fa23f91 100644
--- a/src/mlpack/methods/lars/lars.cpp
+++ b/src/mlpack/methods/lars/lars.cpp
@@ -322,10 +322,15 @@ void LARS::Regress(const arma::mat& matX,
   Timer::Stop("lars_regression");
 }
 
-void LARS::Predict(const arma::mat& points, arma::vec& predictions) const
+void LARS::Predict(const arma::mat& points,
+                   arma::vec& predictions,
+                   const bool rowMajor) const
 {
   // We really only need to store beta internally...
-  predictions = (betaPath.back().t() * points).t();
+  if (rowMajor)
+    predictions = points * betaPath.back();
+  else
+    predictions = (betaPath.back().t() * points).t();
 }
 
 // Private functions.
diff --git a/src/mlpack/methods/lars/lars.hpp b/src/mlpack/methods/lars/lars.hpp
index 7c3cdde..d4ceb9c 100644
--- a/src/mlpack/methods/lars/lars.hpp
+++ b/src/mlpack/methods/lars/lars.hpp
@@ -138,12 +138,16 @@ class LARS
 
   /**
    * Predict y_i for each data point in the given data matrix, using the
-   * currently-trained LARS model (so make sure you run Regress() first).
+   * currently-trained LARS model (so make sure you run Regress() first).  If
+   * each row of the data matrix is a point (the transpose of the usual mlpack
+   * layout of one point per column), set rowMajor = true to avoid a transpose.
    *
    * @param points The data points to regress on.
    * @param predictions y, which will contained calculated values on completion.
    */
-  void Predict(const arma::mat& points, arma::vec& predictions) const;
+  void Predict(const arma::mat& points,
+               arma::vec& predictions,
+               const bool rowMajor = false) const;
 
   //! Access the set of active dimensions.
   const std::vector<size_t>& ActiveSet() const { return activeSet; }
diff --git a/src/mlpack/tests/lars_test.cpp b/src/mlpack/tests/lars_test.cpp
index d658fca..99a7155 100644
--- a/src/mlpack/tests/lars_test.cpp
+++ b/src/mlpack/tests/lars_test.cpp
@@ -183,6 +183,7 @@ BOOST_AUTO_TEST_CASE(PredictTest)
         lars.Predict(X, predictions);
         arma::vec adjPred = X * predictions;
 
+        BOOST_REQUIRE_EQUAL(predictions.n_elem, 1000);
         for (size_t i = 0; i < betaOptPred.n_elem; ++i)
         {
           if (std::abs(betaOptPred[i]) < 1e-5)
@@ -195,4 +196,33 @@ BOOST_AUTO_TEST_CASE(PredictTest)
   }
 }
 
+BOOST_AUTO_TEST_CASE(PredictRowMajorTest)
+{
+  arma::mat X;
+  arma::vec y;
+  GenerateProblem(X, y, 1000, 100);
+
+  // Set lambdas to 0.
+
+  LARS lars(false, 0, 0);
+  arma::vec betaOpt;
+  lars.Regress(X, y, betaOpt);
+
+  // Get both row-major and column-major predictions.  Make sure they are the
+  // same.
+  arma::vec rowMajorPred, colMajorPred;
+
+  lars.Predict(X, colMajorPred);
+  lars.Predict(X.t(), rowMajorPred, true);
+
+  BOOST_REQUIRE_EQUAL(colMajorPred.n_elem, rowMajorPred.n_elem);
+  for (size_t i = 0; i < colMajorPred.n_elem; ++i)
+  {
+    if (std::abs(colMajorPred[i]) < 1e-5)
+      BOOST_REQUIRE_SMALL(rowMajorPred[i], 1e-5);
+    else
+      BOOST_REQUIRE_CLOSE(colMajorPred[i], rowMajorPred[i], 1e-5);
+  }
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list