[mlpack-svn] r10562 - mlpack/trunk/src/mlpack/tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Dec 5 17:10:14 EST 2011
Author: niche
Date: 2011-12-05 17:10:14 -0500 (Mon, 05 Dec 2011)
New Revision: 10562
Added:
mlpack/trunk/src/mlpack/tests/lars_test.cpp
Log:
added LARS tests, but still need to make small changes before adding to CMake
Added: mlpack/trunk/src/mlpack/tests/lars_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/lars_test.cpp (rev 0)
+++ mlpack/trunk/src/mlpack/tests/lars_test.cpp 2011-12-05 22:10:14 UTC (rev 10562)
@@ -0,0 +1,99 @@
+/**
+ * @file lars_test.cpp
+ *
+ * Tests for LARS (least-angle regression) on LASSO and elastic net problems.
+ */
+
+// Note: We don't use BOOST_REQUIRE_CLOSE in the code below because we need
+// to use FPC_WEAK, and it's not at all intuitive how to do that.
+
+
+#include <armadillo>
+#include "lars.h"
+
+#define BOOST_TEST_DYN_LINK
+#define BOOST_TEST_MODULE HELLO
+#include <boost/test/unit_test.hpp>
+
+//BOOST_AUTO_TEST_SUITE(LARS_Test);
+
+/**
+ * Build a random noiseless regression problem.
+ *
+ * X becomes an nPoints x nDims matrix drawn with randn(), and y is set to
+ * X * trueBeta for a coefficient vector trueBeta also drawn with randn().
+ *
+ * @param X Output design matrix.
+ * @param y Output response vector.
+ * @param nPoints Number of rows of X.
+ * @param nDims Number of columns of X (length of the hidden coefficients).
+ */
+void GenerateProblem(mat& X, vec& y, u32 nPoints, u32 nDims) {
+  X = randn(nPoints, nDims);
+  const vec trueBeta = randn(nDims, 1);
+  y = X * trueBeta;
+}
+
+
+/**
+ * Check the LASSO / elastic net optimality (KKT) conditions for a solution.
+ *
+ * errCorr is expected to be the gradient of the smooth part of the
+ * objective at beta, i.e. (X^T X + lambda_2 I) beta - X^T y.
+ *
+ * @param beta Coefficient vector returned by LARS.
+ * @param errCorr Gradient of the smooth part of the objective at beta.
+ * @param lambda The l1 penalty (lambda_1).
+ */
+void VerifyCorrectness(const vec& beta, const vec& errCorr, double lambda) {
+  const u32 nDims = beta.n_elem;
+  const double tol = 1e-12;
+  for (u32 j = 0; j < nDims; j++) {
+    if (beta(j) == 0) {
+      // Inactive coefficient (exact comparison; assumes LARS leaves inactive
+      // coefficients exactly zero). The condition is |errCorr(j)| <= lambda,
+      // so both sides must be checked:
+      // errCorr(j) <= lambda and -errCorr(j) <= lambda.
+      BOOST_REQUIRE_SMALL(max(errCorr(j) - lambda, 0.0), tol);
+      BOOST_REQUIRE_SMALL(max(-errCorr(j) - lambda, 0.0), tol);
+    }
+    else if (beta(j) < 0) {
+      // Negative coefficient: require errCorr(j) == lambda.
+      BOOST_REQUIRE_SMALL(errCorr(j) - lambda, tol);
+    }
+    else { // beta(j) > 0
+      // Positive coefficient: require errCorr(j) == -lambda.
+      BOOST_REQUIRE_SMALL(errCorr(j) + lambda, tol);
+    }
+  }
+}
+
+
+/**
+ * Solve 100 random LARS problems of a fixed size and verify each solution
+ * against the optimality conditions checked by VerifyCorrectness.
+ *
+ * @param nPoints Number of rows of each generated X.
+ * @param nDims Number of columns of each generated X.
+ * @param elasticNet If true, lambda_2 = lambda_1 / 2; otherwise lambda_2 = 0.
+ * @param useCholesky Forwarded to Lars::Init.
+ */
+void LassoTest(u32 nPoints, u32 nDims, bool elasticNet, bool useCholesky) {
+  mat X;
+  vec y;
+
+  for (u32 trial = 0; trial < 100; trial++) {
+    GenerateProblem(X, y, nPoints, nDims);
+
+    // lambda_1 is the element at index nDims/2 of the sorted |X^T y| (an
+    // upper median). It is computed by sorting because the original author
+    // noted that Armadillo's median was broken.
+    vec sortedAbsCorr = sort(abs(trans(X) * y));
+    double lambda_1 = sortedAbsCorr(nDims / 2);
+    double lambda_2 = elasticNet ? (lambda_1 / 2) : 0;
+
+    Lars lars;
+    lars.Init(X, y, useCholesky, lambda_1, lambda_2);
+    lars.DoLARS();
+
+    vec betaOpt;
+    lars.Solution(betaOpt);
+
+    // Gradient of 0.5 * ||y - X b||^2 + 0.5 * lambda_2 * ||b||^2 at betaOpt.
+    vec errCorr = (trans(X) * X + lambda_2 * eye(nDims, nDims)) * betaOpt
+        - trans(X) * y;
+
+    VerifyCorrectness(betaOpt, errCorr, lambda_1);
+  }
+}
+
+
+BOOST_AUTO_TEST_CASE(LARS_Test_Lasso_Cholesky) {
+  // Pure LASSO (lambda_2 = 0) solved with the Cholesky option, as the test
+  // name states. The previous arguments (true, false) ran elastic net without
+  // Cholesky, duplicating LARS_Test_ElasticNet_Gram.
+  const bool elasticNet = false;
+  const bool useCholesky = true;
+  LassoTest(100, 10, elasticNet, useCholesky);
+}
+
+
+BOOST_AUTO_TEST_CASE(LARS_Test_Lasso_Gram) {
+  // Pure LASSO (lambda_2 = 0) without the Cholesky option.
+  const bool elasticNet = false;
+  const bool useCholesky = false;
+  LassoTest(100, 10, elasticNet, useCholesky);
+}
+
+
+BOOST_AUTO_TEST_CASE(LARS_Test_ElasticNet_Cholesky) {
+  // Elastic net (lambda_2 = lambda_1 / 2) with the Cholesky option.
+  const bool elasticNet = true;
+  const bool useCholesky = true;
+  LassoTest(100, 10, elasticNet, useCholesky);
+}
+
+
+BOOST_AUTO_TEST_CASE(LARS_Test_ElasticNet_Gram) {
+  // Elastic net (lambda_2 = lambda_1 / 2) without the Cholesky option.
+  const bool elasticNet = true;
+  const bool useCholesky = false;
+  LassoTest(100, 10, elasticNet, useCholesky);
+}
+
+
+
+//BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-svn mailing list