[mlpack-git] master: Update API, and add tests for retraining. (516ef59)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Dec 3 10:50:10 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/5858f46c3113651f598596bade89c1b838410652...c829fc1a2415f3dddb672431bb51ff05cbc40a76
>---------------------------------------------------------------
commit 516ef599298675ec4932353090568de77605293b
Author: ryan <ryan at ratml.org>
Date: Thu Dec 3 10:49:20 2015 -0500
Update API, and add tests for retraining.
>---------------------------------------------------------------
516ef599298675ec4932353090568de77605293b
src/mlpack/tests/lars_test.cpp | 70 +++++++++++++++++++++++++++++++++++-------
1 file changed, 59 insertions(+), 11 deletions(-)
diff --git a/src/mlpack/tests/lars_test.cpp b/src/mlpack/tests/lars_test.cpp
index 99a7155..83bd0a8 100644
--- a/src/mlpack/tests/lars_test.cpp
+++ b/src/mlpack/tests/lars_test.cpp
@@ -24,11 +24,10 @@ void GenerateProblem(arma::mat& X, arma::vec& y, size_t nPoints, size_t nDims)
y = trans(X) * beta;
}
-
void LARSVerifyCorrectness(arma::vec beta, arma::vec errCorr, double lambda)
{
size_t nDims = beta.n_elem;
- const double tol = 1e-12;
+ const double tol = 1e-10;
for(size_t j = 0; j < nDims; j++)
{
if (beta(j) == 0)
@@ -49,7 +48,6 @@ void LARSVerifyCorrectness(arma::vec beta, arma::vec errCorr, double lambda)
}
}
-
void LassoTest(size_t nPoints, size_t nDims, bool elasticNet, bool useCholesky)
{
arma::mat X;
@@ -71,7 +69,7 @@ void LassoTest(size_t nPoints, size_t nDims, bool elasticNet, bool useCholesky)
LARS lars(useCholesky, lambda1, lambda2);
arma::vec betaOpt;
- lars.Regress(X, y, betaOpt);
+ lars.Train(X, y, betaOpt);
arma::vec errCorr = (X * trans(X) + lambda2 *
arma::eye(nDims, nDims)) * betaOpt - X * y;
@@ -80,7 +78,6 @@ void LassoTest(size_t nPoints, size_t nDims, bool elasticNet, bool useCholesky)
}
}
-
BOOST_AUTO_TEST_CASE(LARSTestLassoCholesky)
{
LassoTest(100, 10, false, true);
@@ -92,13 +89,11 @@ BOOST_AUTO_TEST_CASE(LARSTestLassoGram)
LassoTest(100, 10, false, false);
}
-
BOOST_AUTO_TEST_CASE(LARSTestElasticNetCholesky)
{
LassoTest(100, 10, true, true);
}
-
BOOST_AUTO_TEST_CASE(LARSTestElasticNetGram)
{
LassoTest(100, 10, true, false);
@@ -122,7 +117,7 @@ BOOST_AUTO_TEST_CASE(CholeskySingularityTest)
{
LARS lars(true, lambda1, 0.0);
arma::vec betaOpt;
- lars.Regress(X, y, betaOpt);
+ lars.Train(X, y, betaOpt);
arma::vec errCorr = (X * X.t()) * betaOpt - X * y;
@@ -146,7 +141,7 @@ BOOST_AUTO_TEST_CASE(NoCholeskySingularityTest)
{
LARS lars(false, lambda1, 0.0);
arma::vec betaOpt;
- lars.Regress(X, y, betaOpt);
+ lars.Train(X, y, betaOpt);
arma::vec errCorr = (X * X.t()) * betaOpt - X * y;
@@ -174,7 +169,7 @@ BOOST_AUTO_TEST_CASE(PredictTest)
{
LARS lars(useCholesky, lambda1, lambda2);
arma::vec betaOpt;
- lars.Regress(X, y, betaOpt);
+ lars.Train(X, y, betaOpt);
// Calculate what the actual error should be with these regression
// parameters.
@@ -206,7 +201,7 @@ BOOST_AUTO_TEST_CASE(PredictRowMajorTest)
LARS lars(false, 0, 0);
arma::vec betaOpt;
- lars.Regress(X, y, betaOpt);
+ lars.Train(X, y, betaOpt);
// Get both row-major and column-major predictions. Make sure they are the
// same.
@@ -225,4 +220,57 @@ BOOST_AUTO_TEST_CASE(PredictRowMajorTest)
}
}
+/**
+ * Make sure that if we train twice, there is no issue.
+ */
+BOOST_AUTO_TEST_CASE(RetrainTest)
+{
+ arma::mat origX;
+ arma::vec origY;
+ GenerateProblem(origX, origY, 1000, 50);
+
+ arma::mat newX;
+ arma::vec newY;
+ GenerateProblem(newX, newY, 750, 75);
+
+ LARS lars(false, 0.1, 0.1);
+ arma::vec betaOpt;
+ lars.Train(origX, origY, betaOpt);
+
+ // Now train on new data.
+ lars.Train(newX, newY, betaOpt);
+
+ arma::vec errCorr = (newX * trans(newX) + 0.1 *
+ arma::eye(75, 75)) * betaOpt - newX * newY;
+
+ LARSVerifyCorrectness(betaOpt, errCorr, 0.1);
+}
+
+/**
+ * Make sure if we train twice using the Cholesky decomposition, there is no
+ * issue.
+ */
+BOOST_AUTO_TEST_CASE(RetrainCholeskyTest)
+{
+ arma::mat origX;
+ arma::vec origY;
+ GenerateProblem(origX, origY, 1000, 50);
+
+ arma::mat newX;
+ arma::vec newY;
+ GenerateProblem(newX, newY, 750, 75);
+
+ LARS lars(true, 0.1, 0.1);
+ arma::vec betaOpt;
+ lars.Train(origX, origY, betaOpt);
+
+ // Now train on new data.
+ lars.Train(newX, newY, betaOpt);
+
+ arma::vec errCorr = (newX * trans(newX) + 0.1 *
+ arma::eye(75, 75)) * betaOpt - newX * newY;
+
+ LARSVerifyCorrectness(betaOpt, errCorr, 0.1);
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git mailing list