[mlpack-svn] r10694 - mlpack/trunk/src/mlpack/methods/lars
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Sat Dec 10 12:53:35 EST 2011
Author: niche
Date: 2011-12-10 12:53:34 -0500 (Sat, 10 Dec 2011)
New Revision: 10694
Modified:
mlpack/trunk/src/mlpack/methods/lars/lars_main.cpp
Log:
cleaned up lars_main and added command line options
Modified: mlpack/trunk/src/mlpack/methods/lars/lars_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lars/lars_main.cpp 2011-12-10 04:47:48 UTC (rev 10693)
+++ mlpack/trunk/src/mlpack/methods/lars/lars_main.cpp 2011-12-10 17:53:34 UTC (rev 10694)
@@ -1,13 +1,12 @@
-/** @file new_main.cc
+/**
+ * @file lars_main.cpp
+ * @author Nishant Mehta
*
- * Driver file for testing LARS
- *
- * @author Nishant Mehta (niche)
+ * Executable for LARS
*/
-//#include <fastlib/fastlib.h>
+#include <mlpack/core.hpp>
#include <armadillo>
-
#include "lars.hpp"
using namespace arma;
@@ -15,70 +14,57 @@
using namespace mlpack;
using namespace mlpack::lars;
-int main(int argc, char* argv[])
-{
- //bool use_cholesky = false;
- double lambda_1 = 1;
- double lambda_2 = 0.5;
+PROGRAM_INFO("LARS", "An implementation of LARS: Least Angle Regression (Stagewise/laSso)");
- u32 n = 100;
- u32 p = 10;
+PARAM_STRING_REQ("X", "Covariates filename (observations of input random "
+ "variables)", "");
+PARAM_STRING_REQ("y", "Targets filename (observations of output random "
+ "variable)", "");
+PARAM_STRING_REQ("beta", "Solution filename (linear estimator)", "");
- mat X = randu<mat>(n,p);
+PARAM_DOUBLE("lambda1", "Regularization parameter for l1-norm penalty", "", 0);
+PARAM_DOUBLE("lambda2", "Regularization parameter for l2-norm penalty", "", 0);
+PARAM_FLAG("use_cholesky", "Use Cholesky decomposition during computation "
+ "rather than explicitly computing full Gram matrix", "");
- /*
- mat X_reg = zeros(n + p, p);
- X_reg(span(0, n - 1), span::all) = X;
- for(u32 i = 0; i < p; i++) {
- X_reg(n + i, i) = sqrt(lambda_2);
- }
- //X_reg.print("X_reg");
- */
- mat beta_true = zeros(p,1);
- beta_true(0) = 1;
- beta_true(1) = -1;
- beta_true(9) = 1;
+int main(int argc, char* argv[])
+{
- vec y = X * beta_true + 0.1 * randu<vec>(n);
- //vec y = randu(n);
- //y.load("y.dat", raw_ascii);
- //y.load("x.dat", raw_ascii);
+ // Handle parameters
+ CLI::ParseCommandLine(argc, argv);
+
+ double lambda1 = CLI::GetParam<double>("lambda1");
+ double lambda2 = CLI::GetParam<double>("lambda2");
+ // PARAM_FLAG options carry no value; test whether the flag was passed.
+ bool useCholesky = CLI::HasParam("use_cholesky");
- vec y_reg = zeros(n + p);
- y_reg.subvec(0, n - 1) = y;
- //y_reg.print("y_reg");
-
- mat Gram = trans(X) * X;
-
- LARS lars(X, y, true, lambda_1, lambda_2);
- //lars.Init(X, y, true);
- //lars.Init(X, y, false);
- //lars.SetGram(Gram.memptr(), X.n_cols);
- //lars.Init(X_reg, y_reg, false, lambda_1);
- //lars.Init(X_reg, y_reg, use_cholesky);
-
+ // load covariates (one observation per row, one covariate per column)
+ const std::string matXFilename = CLI::GetParam<std::string>("X");
+ mat matX;
+ if (!matX.load(matXFilename, raw_ascii))
+   Log::Fatal << "Could not load covariates from " << matXFilename << std::endl;
+
+ // load targets (one value per observation)
+ const std::string yFilename = CLI::GetParam<std::string>("y");
+ vec y;
+ if (!y.load(yFilename, raw_ascii))
+   Log::Fatal << "Could not load targets from " << yFilename << std::endl;
+
+ // LARS needs exactly one target per observation (row of the covariates).
+ if (matX.n_rows != y.n_elem)
+   Log::Fatal << "Number of targets (" << y.n_elem << ") does not match "
+       << "number of observations (" << matX.n_rows << ")." << std::endl;
+
+ // do LARS
+ LARS lars(matX, y, useCholesky, lambda1, lambda2);
lars.DoLARS();
-
- u32 path_length = lars.beta_path().size();
-
- mat beta_matrix = mat(p, path_length);
- for(u32 i = 0; i < path_length; i++)
- {
- beta_matrix.col(i) = lars.beta_path()[i];
- }
- //beta_matrix.print("beta matrix");
-
- vec lambda_path_vec = conv_to<colvec>::from(lars.lambda_path());
- //lambda_path_vec.print("lambda path");
-
- X.save("X.dat", raw_ascii);
- y.save("y.dat", raw_ascii);
-
- ////beta_matrix.save("beta.dat", raw_ascii);
- ////lambda_path_vec.save("lambda.dat", raw_ascii);
+
+ // get and save solution
vec beta;
lars.Solution(beta);
-
- beta.print("final beta");
+
+ const std::string betaFilename = CLI::GetParam<std::string>("beta");
+ if (!beta.save(betaFilename, raw_ascii))
+   Log::Fatal << "Could not save solution to " << betaFilename << std::endl;
}
More information about the mlpack-svn mailing list