[mlpack-svn] r10565 - mlpack/trunk/src/mlpack/tests

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Dec 5 17:49:14 EST 2011


Author: niche
Date: 2011-12-05 17:49:14 -0500 (Mon, 05 Dec 2011)
New Revision: 10565

Modified:
   mlpack/trunk/src/mlpack/tests/CMakeLists.txt
   mlpack/trunk/src/mlpack/tests/lars_test.cpp
Log:
Added the LARS test cases to the mlpack test suite.

Modified: mlpack/trunk/src/mlpack/tests/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/tests/CMakeLists.txt	2011-12-05 22:46:43 UTC (rev 10564)
+++ mlpack/trunk/src/mlpack/tests/CMakeLists.txt	2011-12-05 22:49:14 UTC (rev 10565)
@@ -21,6 +21,7 @@
   hmm_test.cpp
   kernel_test.cpp
   kmeans_test.cpp
+  lars_test.cpp
   lin_alg_test.cpp
   linear_regression_test.cpp
   load_save_test.cpp

Modified: mlpack/trunk/src/mlpack/tests/lars_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/lars_test.cpp	2011-12-05 22:46:43 UTC (rev 10564)
+++ mlpack/trunk/src/mlpack/tests/lars_test.cpp	2011-12-05 22:49:14 UTC (rev 10565)
@@ -9,28 +9,26 @@
 
 
 #include <armadillo>
-#include "lars.h"
+#include <mlpack/methods/lars/lars.hpp>
 
-#define BOOST_TEST_DYN_LINK
-#define BOOST_TEST_MODULE HELLO
 #include <boost/test/unit_test.hpp>
 
-//BOOST_AUTO_TEST_SUITE(LARS_Test);
+BOOST_AUTO_TEST_SUITE(LARS_Test);
 
-void GenerateProblem(mat& X, vec& y, u32 nPoints, u32 nDims) {
-  X = randn(nPoints, nDims);
-  vec beta = randn(nDims, 1);
+void GenerateProblem(arma::mat& X, arma::vec& y, size_t nPoints, size_t nDims) {
+  X = arma::randn(nPoints, nDims);
+  arma::vec beta = arma::randn(nDims, 1);
   y = X * beta;
 }
 
 
-void VerifyCorrectness(vec beta, vec errCorr, double lambda) {
-  u32 nDims = beta.n_elem;
+void VerifyCorrectness(arma::vec beta, arma::vec errCorr, double lambda) {
+  size_t nDims = beta.n_elem;
   const double tol = 1e-12;
-  for(u32 j = 0; j < nDims; j++) {
+  for(size_t j = 0; j < nDims; j++) {
     if(beta(j) == 0) {
       // make sure that errCorr(j) <= lambda
-      BOOST_REQUIRE_SMALL(max(errCorr(j) - lambda, 0.0), tol);
+      BOOST_REQUIRE_SMALL(std::max(errCorr(j) - lambda, 0.0), tol);
     }
     else if(beta(j) < 0) {
       // make sure that errCorr(j) == lambda
@@ -44,15 +42,15 @@
 }
 
 
-void LassoTest(u32 nPoints, u32 nDims, bool elasticNet, bool useCholesky) {
-  mat X;
-  vec y;
+void LassoTest(size_t nPoints, size_t nDims, bool elasticNet, bool useCholesky) {
+  arma::mat X;
+  arma::vec y;
   
-  for(u32 i = 0; i < 100; i++) {
+  for(size_t i = 0; i < 100; i++) {
     GenerateProblem(X, y, nPoints, nDims);
     
     // Armadillo's median is broken, so...
-    vec sortedAbsCorr = sort(abs(trans(X) * y));
+    arma::vec sortedAbsCorr = sort(abs(trans(X) * y));
     double lambda_1 = sortedAbsCorr(nDims/2);
     double lambda_2;
     if(elasticNet) {
@@ -62,13 +60,12 @@
       lambda_2 = 0;
     }
     
-    Lars lars;
-    lars.Init(X, y, useCholesky, lambda_1, lambda_2);
+    mlpack::lars::LARS lars(X, y, useCholesky, lambda_1, lambda_2);
     lars.DoLARS();
     
-    vec betaOpt;
+    arma::vec betaOpt;
     lars.Solution(betaOpt);
-    vec errCorr = (trans(X) * X + lambda_2 * eye(nDims, nDims)) * betaOpt - trans(X) * y;
+    arma::vec errCorr = (arma::trans(X) * X + lambda_2 * arma::eye(nDims, nDims)) * betaOpt - arma::trans(X) * y;
     
     VerifyCorrectness(betaOpt, errCorr, lambda_1);
   }
@@ -94,6 +91,4 @@
   LassoTest(100, 10, true, false);
 }
 
-
-
-//BOOST_AUTO_TEST_SUITE_END();
+BOOST_AUTO_TEST_SUITE_END();




More information about the mlpack-svn mailing list