[mlpack-git] master: Allow passing a row-major matrix to Predict() to save time. (4842fa9)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Aug 19 12:46:49 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/7d64cd61aade77d4607aacf00b36f94369fd7cf9...284914879d42af40b028a7f5a002f4cd82c5fe05
>---------------------------------------------------------------
commit 4842fa9fb26fc83bd08de692d30eacf74b5236d3
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Aug 19 12:45:38 2015 -0400
Allow passing a row-major matrix to Predict() to save time.
>---------------------------------------------------------------
4842fa9fb26fc83bd08de692d30eacf74b5236d3
src/mlpack/methods/lars/lars.cpp | 9 +++++++--
src/mlpack/methods/lars/lars.hpp | 8 ++++++--
src/mlpack/tests/lars_test.cpp | 30 ++++++++++++++++++++++++++++++
3 files changed, 43 insertions(+), 4 deletions(-)
diff --git a/src/mlpack/methods/lars/lars.cpp b/src/mlpack/methods/lars/lars.cpp
index 97d644f..fa23f91 100644
--- a/src/mlpack/methods/lars/lars.cpp
+++ b/src/mlpack/methods/lars/lars.cpp
@@ -322,10 +322,15 @@ void LARS::Regress(const arma::mat& matX,
Timer::Stop("lars_regression");
}
-void LARS::Predict(const arma::mat& points, arma::vec& predictions) const
+void LARS::Predict(const arma::mat& points,
+ arma::vec& predictions,
+ const bool rowMajor) const // rowMajor: true if points holds one point per row.
{
// We really only need to store beta internally...
- predictions = (betaPath.back().t() * points).t();
+ if (rowMajor) // Row-major input: points is (n x d), one point per row.
+ predictions = points * betaPath.back(); // (n x d) * (d x 1) yields the n predictions with no transpose.
+ else // Column-major input (the usual mlpack layout): points is (d x n).
+ predictions = (betaPath.back().t() * points).t(); // (1 x d) * (d x n), then transposed to a column.
}
// Private functions.
diff --git a/src/mlpack/methods/lars/lars.hpp b/src/mlpack/methods/lars/lars.hpp
index 7c3cdde..d4ceb9c 100644
--- a/src/mlpack/methods/lars/lars.hpp
+++ b/src/mlpack/methods/lars/lars.hpp
@@ -138,12 +138,16 @@ class LARS
/**
* Predict y_i for each data point in the given data matrix, using the
- * currently-trained LARS model (so make sure you run Regress() first).
+ * currently-trained LARS model (so make sure you run Regress() first). If
+ * points is row-major (one point per row, rather than mlpack's usual one
+ * point per column), set rowMajor = true to avoid an extra transpose.
*
* @param points The data points to regress on.
* @param predictions y, which will contained calculated values on completion.
*/
- void Predict(const arma::mat& points, arma::vec& predictions) const;
+ void Predict(const arma::mat& points,
+ arma::vec& predictions,
+ const bool rowMajor = false) const;
//! Access the set of active dimensions.
const std::vector<size_t>& ActiveSet() const { return activeSet; }
diff --git a/src/mlpack/tests/lars_test.cpp b/src/mlpack/tests/lars_test.cpp
index d658fca..99a7155 100644
--- a/src/mlpack/tests/lars_test.cpp
+++ b/src/mlpack/tests/lars_test.cpp
@@ -183,6 +183,7 @@ BOOST_AUTO_TEST_CASE(PredictTest)
lars.Predict(X, predictions);
arma::vec adjPred = X * predictions;
+ BOOST_REQUIRE_EQUAL(predictions.n_elem, 1000);
for (size_t i = 0; i < betaOptPred.n_elem; ++i)
{
if (std::abs(betaOptPred[i]) < 1e-5)
@@ -195,4 +196,33 @@ BOOST_AUTO_TEST_CASE(PredictTest)
}
}
+BOOST_AUTO_TEST_CASE(PredictRowMajorTest) // Row-major Predict() must match the column-major result.
+{
+ arma::mat X;
+ arma::vec y;
+ GenerateProblem(X, y, 1000, 100);
+
+ // Set lambdas to 0.
+
+ LARS lars(false, 0, 0);
+ arma::vec betaOpt;
+ lars.Regress(X, y, betaOpt);
+
+ // Get both row-major and column-major predictions. Make sure they are the
+ // same.
+ arma::vec rowMajorPred, colMajorPred;
+
+ lars.Predict(X, colMajorPred); // Default layout: one point per column.
+ lars.Predict(X.t(), rowMajorPred, true); // X.t() stores one point per row.
+
+ BOOST_REQUIRE_EQUAL(colMajorPred.n_elem, rowMajorPred.n_elem);
+ for (size_t i = 0; i < colMajorPred.n_elem; ++i)
+ {
+ if (std::abs(colMajorPred[i]) < 1e-5) // A relative check is unstable near zero; use an absolute one.
+ BOOST_REQUIRE_SMALL(rowMajorPred[i], 1e-5);
+ else
+ BOOST_REQUIRE_CLOSE(colMajorPred[i], rowMajorPred[i], 1e-5); // Tolerance is a percentage.
+ }
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git
mailing list