[mlpack-svn] r10565 - mlpack/trunk/src/mlpack/tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Dec 5 17:49:14 EST 2011
Author: niche
Date: 2011-12-05 17:49:14 -0500 (Mon, 05 Dec 2011)
New Revision: 10565
Modified:
mlpack/trunk/src/mlpack/tests/CMakeLists.txt
mlpack/trunk/src/mlpack/tests/lars_test.cpp
Log:
added test cases for lars
Modified: mlpack/trunk/src/mlpack/tests/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/tests/CMakeLists.txt 2011-12-05 22:46:43 UTC (rev 10564)
+++ mlpack/trunk/src/mlpack/tests/CMakeLists.txt 2011-12-05 22:49:14 UTC (rev 10565)
@@ -21,6 +21,7 @@
hmm_test.cpp
kernel_test.cpp
kmeans_test.cpp
+ lars_test.cpp
lin_alg_test.cpp
linear_regression_test.cpp
load_save_test.cpp
Modified: mlpack/trunk/src/mlpack/tests/lars_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/lars_test.cpp 2011-12-05 22:46:43 UTC (rev 10564)
+++ mlpack/trunk/src/mlpack/tests/lars_test.cpp 2011-12-05 22:49:14 UTC (rev 10565)
@@ -9,28 +9,26 @@
#include <armadillo>
-#include "lars.h"
+#include <mlpack/methods/lars/lars.hpp>
-#define BOOST_TEST_DYN_LINK
-#define BOOST_TEST_MODULE HELLO
#include <boost/test/unit_test.hpp>
-//BOOST_AUTO_TEST_SUITE(LARS_Test);
+BOOST_AUTO_TEST_SUITE(LARS_Test);
-void GenerateProblem(mat& X, vec& y, u32 nPoints, u32 nDims) {
- X = randn(nPoints, nDims);
- vec beta = randn(nDims, 1);
+void GenerateProblem(arma::mat& X, arma::vec& y, size_t nPoints, size_t nDims) {
+ X = arma::randn(nPoints, nDims);
+ arma::vec beta = arma::randn(nDims, 1);
y = X * beta;
}
-void VerifyCorrectness(vec beta, vec errCorr, double lambda) {
- u32 nDims = beta.n_elem;
+void VerifyCorrectness(arma::vec beta, arma::vec errCorr, double lambda) {
+ size_t nDims = beta.n_elem;
const double tol = 1e-12;
- for(u32 j = 0; j < nDims; j++) {
+ for(size_t j = 0; j < nDims; j++) {
if(beta(j) == 0) {
// make sure that errCorr(j) <= lambda
- BOOST_REQUIRE_SMALL(max(errCorr(j) - lambda, 0.0), tol);
+ BOOST_REQUIRE_SMALL(std::max(errCorr(j) - lambda, 0.0), tol);
}
else if(beta(j) < 0) {
// make sure that errCorr(j) == lambda
@@ -44,15 +42,15 @@
}
-void LassoTest(u32 nPoints, u32 nDims, bool elasticNet, bool useCholesky) {
- mat X;
- vec y;
+void LassoTest(size_t nPoints, size_t nDims, bool elasticNet, bool useCholesky) {
+ arma::mat X;
+ arma::vec y;
- for(u32 i = 0; i < 100; i++) {
+ for(size_t i = 0; i < 100; i++) {
GenerateProblem(X, y, nPoints, nDims);
// Armadillo's median is broken, so...
- vec sortedAbsCorr = sort(abs(trans(X) * y));
+ arma::vec sortedAbsCorr = sort(abs(trans(X) * y));
double lambda_1 = sortedAbsCorr(nDims/2);
double lambda_2;
if(elasticNet) {
@@ -62,13 +60,12 @@
lambda_2 = 0;
}
- Lars lars;
- lars.Init(X, y, useCholesky, lambda_1, lambda_2);
+ mlpack::lars::LARS lars(X, y, useCholesky, lambda_1, lambda_2);
lars.DoLARS();
- vec betaOpt;
+ arma::vec betaOpt;
lars.Solution(betaOpt);
- vec errCorr = (trans(X) * X + lambda_2 * eye(nDims, nDims)) * betaOpt - trans(X) * y;
+ arma::vec errCorr = (arma::trans(X) * X + lambda_2 * arma::eye(nDims, nDims)) * betaOpt - arma::trans(X) * y;
VerifyCorrectness(betaOpt, errCorr, lambda_1);
}
@@ -94,6 +91,4 @@
LassoTest(100, 10, true, false);
}
-
-
-//BOOST_AUTO_TEST_SUITE_END();
+BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-svn mailing list