[mlpack-svn] r10694 - mlpack/trunk/src/mlpack/methods/lars

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Sat Dec 10 12:53:35 EST 2011


Author: niche
Date: 2011-12-10 12:53:34 -0500 (Sat, 10 Dec 2011)
New Revision: 10694

Modified:
   mlpack/trunk/src/mlpack/methods/lars/lars_main.cpp
Log:
cleaned up lars_main and added command line options

Modified: mlpack/trunk/src/mlpack/methods/lars/lars_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lars/lars_main.cpp	2011-12-10 04:47:48 UTC (rev 10693)
+++ mlpack/trunk/src/mlpack/methods/lars/lars_main.cpp	2011-12-10 17:53:34 UTC (rev 10694)
@@ -1,13 +1,12 @@
-/** @file new_main.cc
+/**
+ * @file lars_main.cpp
+ * @author Nishant Mehta
  *
- *  Driver file for testing LARS
- *
- *  @author Nishant Mehta (niche)
+ * Executable for LARS
  */
 
-//#include <fastlib/fastlib.h>
+#include <mlpack/core.hpp>
 #include <armadillo>
-
 #include "lars.hpp"
 
 using namespace arma;
@@ -15,70 +14,48 @@
 using namespace mlpack;
 using namespace mlpack::lars;
 
-int main(int argc, char* argv[])
-{
-  //bool use_cholesky = false;
-  double lambda_1 = 1;
-  double lambda_2 = 0.5;
+PROGRAM_INFO("LARS", "An implementation of LARS: Least Angle Regression (Stagewise/laSso)");
 
-  u32 n = 100;
-  u32 p = 10;
+PARAM_STRING_REQ("X", "Covariates filename (observations of input random "
+		 "variables)", "");
+PARAM_STRING_REQ("y", "Targets filename (observations of output random "
+		 "variable)", "");
+PARAM_STRING_REQ("beta", "Solution filename (linear estimator)", "");
 
-  mat X = randu<mat>(n,p);
+PARAM_DOUBLE("lambda1", "Regularization parameter for l1-norm penalty", "", 0);
+PARAM_DOUBLE("lambda2", "Regularization parameter for l2-norm penalty", "", 0);
+PARAM_FLAG("use_cholesky", "Use Cholesky decomposition during computation "
+	   "rather than explicitly computing full Gram matrix", "");
 
-  /*
-  mat X_reg = zeros(n + p, p);
-  X_reg(span(0, n - 1), span::all) = X;
-  for(u32 i = 0; i < p; i++) {
-    X_reg(n + i, i) = sqrt(lambda_2);
-  }
-  //X_reg.print("X_reg");
-  */
 
-  mat beta_true = zeros(p,1);
-  beta_true(0) = 1;
-  beta_true(1) = -1;
-  beta_true(9) = 1;
+int main(int argc, char* argv[])
+{
 
-  vec y = X * beta_true + 0.1 * randu<vec>(n);
-  //vec y = randu(n);
-  //y.load("y.dat", raw_ascii);
-  //y.load("x.dat", raw_ascii);
+  // Handle parameters
+  CLI::ParseCommandLine(argc, argv);
+  
+  double lambda1 = CLI::GetParam<double>("lambda1");
+  double lambda2 = CLI::GetParam<double>("lambda2");
+  bool useCholesky = CLI::GetParam<bool>("use_cholesky");
 
-  vec y_reg = zeros(n + p);
-  y_reg.subvec(0, n - 1) = y;
-  //y_reg.print("y_reg");
-
-  mat Gram = trans(X) * X;
-
-  LARS lars(X, y, true, lambda_1, lambda_2);
-  //lars.Init(X, y, true);
-  //lars.Init(X, y, false);
-  //lars.SetGram(Gram.memptr(), X.n_cols);
-  //lars.Init(X_reg, y_reg, false, lambda_1);
-  //lars.Init(X_reg, y_reg, use_cholesky);
-
+  // load covariates
+  const std::string matXFilename = CLI::GetParam<std::string>("X");
+  mat matX;
+  matX.load(matXFilename, raw_ascii);
+  
+  // load targets
+  const std::string yFilename = CLI::GetParam<std::string>("y");
+  vec y;
+  y.load(yFilename, raw_ascii);
+  
+  // do LARS
+  LARS lars(matX, y, useCholesky, lambda1, lambda2);
   lars.DoLARS();
-
-  u32 path_length = lars.beta_path().size();
-
-  mat beta_matrix = mat(p, path_length);
-  for(u32 i = 0; i < path_length; i++)
-  {
-    beta_matrix.col(i) = lars.beta_path()[i];
-  }
-  //beta_matrix.print("beta matrix");
-
-  vec lambda_path_vec = conv_to<colvec>::from(lars.lambda_path());
-  //lambda_path_vec.print("lambda path");
-
-  X.save("X.dat", raw_ascii);
-  y.save("y.dat", raw_ascii);
-
-  ////beta_matrix.save("beta.dat", raw_ascii);
-  ////lambda_path_vec.save("lambda.dat", raw_ascii);
+  
+  // get and save solution
   vec beta;
   lars.Solution(beta);
-
-  beta.print("final beta");
+  
+  const std::string betaFilename = CLI::GetParam<std::string>("beta");
+  beta.save(betaFilename, raw_ascii);
 }




More information about the mlpack-svn mailing list