[mlpack-svn] r12983 - mlpack/trunk/src/mlpack/methods/lars

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Sat Jun 9 17:57:03 EDT 2012


Author: rcurtin
Date: 2012-06-09 17:57:02 -0400 (Sat, 09 Jun 2012)
New Revision: 12983

Modified:
   mlpack/trunk/src/mlpack/methods/lars/lars.cpp
   mlpack/trunk/src/mlpack/methods/lars/lars.hpp
Log:
More refactoring: remove nActive, since it duplicates activeSet.size(), and have
DoLARS() return the solution through a new output parameter (beta), replacing
the separate Solution() method.


Modified: mlpack/trunk/src/mlpack/methods/lars/lars.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lars/lars.cpp	2012-06-09 21:10:20 UTC (rev 12982)
+++ mlpack/trunk/src/mlpack/methods/lars/lars.cpp	2012-06-09 21:57:02 UTC (rev 12983)
@@ -38,6 +38,7 @@
 
 void LARS::DoLARS(const arma::mat& matX,
                   const arma::vec& y,
+                  arma::vec& beta,
                   const bool rowMajor)
 {
   // This matrix may end up holding the transpose -- if necessary.
@@ -52,13 +53,10 @@
 
   // Set up active set variables.  In the beginning, the active set has size 0
   // (all dimensions are inactive).
-  nActive = 0;
-  activeSet = std::vector<arma::uword>(0);
-  isActive = std::vector<bool>(dataRef.n_cols);
-  fill(isActive.begin(), isActive.end(), false);
+  isActive.resize(dataRef.n_cols, false);
 
   // Initialize yHat and beta.
-  arma::vec beta = arma::zeros(dataRef.n_cols);
+  beta = arma::zeros(dataRef.n_cols);
   arma::vec yHat = arma::zeros(dataRef.n_rows);
   arma::vec yHatDirection = arma::vec(dataRef.n_rows);
 
@@ -99,7 +97,7 @@
   }
 
   // Main loop.
-  while ((nActive < dataRef.n_cols) && (maxCorr > tolerance))
+  while ((activeSet.size() < dataRef.n_cols) && (maxCorr > tolerance))
   {
     // Compute the maximum correlation among inactive dimensions.
     maxCorr = 0;
@@ -118,8 +116,8 @@
       //printf("activating %d\n", changeInd);
       if (useCholesky)
       {
-        // vec newGramCol = vec(nActive);
-        // for (uword i = 0; i < nActive; i++)
+        // vec newGramCol = vec(activeSet.size());
+        // for (size_t i = 0; i < activeSet.size(); i++)
         // {
         //   newGramCol[i] = dot(matX.col(activeSet[i]), matX.col(changeInd));
         // }
@@ -136,8 +134,8 @@
     }
 
     // compute signs of correlations
-    arma::vec s = arma::vec(nActive);
-    for (arma::uword i = 0; i < nActive; i++)
+    arma::vec s = arma::vec(activeSet.size());
+    for (size_t i = 0; i < activeSet.size(); i++)
       s(i) = corr(activeSet[i]) / fabs(corr(activeSet[i]));
 
     // compute "equiangular" direction in parameter space (betaDirection)
@@ -167,14 +165,14 @@
     }
     else
     {
-      arma::mat matGramActive = arma::mat(nActive, nActive);
-      for (arma::uword i = 0; i < nActive; i++)
-        for (arma::uword j = 0; j < nActive; j++)
-          matGramActive(i,j) = matGram(activeSet[i], activeSet[j]);
+      arma::mat matGramActive = arma::mat(activeSet.size(), activeSet.size());
+      for (size_t i = 0; i < activeSet.size(); i++)
+        for (size_t j = 0; j < activeSet.size(); j++)
+          matGramActive(i, j) = matGram(activeSet[i], activeSet[j]);
 
-      arma::mat matS = s * arma::ones<arma::mat>(1, nActive);
+      arma::mat matS = s * arma::ones<arma::mat>(1, activeSet.size());
       unnormalizedBetaDirection = solve(matGramActive % trans(matS) % matS,
-          arma::ones<arma::mat>(nActive, 1));
+          arma::ones<arma::mat>(activeSet.size(), 1));
       normalization = 1.0 / sqrt(sum(unnormalizedBetaDirection));
       betaDirection = normalization * unnormalizedBetaDirection % s;
     }
@@ -185,10 +183,10 @@
     double gamma = maxCorr / normalization;
 
     // if not all variables are active
-    if (nActive < dataRef.n_cols)
+    if (activeSet.size() < dataRef.n_cols)
     {
       // compute correlations with direction
-      for (arma::uword ind = 0; ind < dataRef.n_cols; ind++)
+      for (size_t ind = 0; ind < dataRef.n_cols; ind++)
       {
         if (isActive[ind])
           continue;
@@ -208,9 +206,9 @@
     {
       lassocond = false;
       double lassoboundOnGamma = DBL_MAX;
-      arma::uword activeIndToKickOut = -1;
+      size_t activeIndToKickOut = -1;
 
-      for (arma::uword i = 0; i < nActive; i++)
+      for (size_t i = 0; i < activeSet.size(); i++)
       {
         double val = -beta(activeSet[i]) / betaDirection(i);
         if ((val > 0) && (val < lassoboundOnGamma))
@@ -237,7 +235,7 @@
     yHat += gamma * yHatDirection;
 
     // update estimator
-    for (arma::uword i = 0; i < nActive; i++)
+    for (size_t i = 0; i < activeSet.size(); i++)
     {
       beta(activeSet[i]) += gamma * betaDirection(i);
     }
@@ -270,10 +268,10 @@
       corr -= lambda2 * beta;
 
     double curLambda = 0;
-    for (arma::uword i = 0; i < nActive; i++)
+    for (size_t i = 0; i < activeSet.size(); i++)
       curLambda += fabs(corr(activeSet[i]));
 
-    curLambda /= ((double) nActive);
+    curLambda /= ((double) activeSet.size());
 
     lambdaPath.push_back(curLambda);
 
@@ -287,24 +285,20 @@
       }
     }
   }
-}
 
-void LARS::Solution(arma::vec& beta)
-{
-  beta = BetaPath().back();
+  // Unfortunate copy...
+  beta = betaPath.back();
 }
 
 // Private functions.
-void LARS::Deactivate(arma::uword activeVarInd)
+void LARS::Deactivate(const size_t activeVarInd)
 {
-  nActive--;
   isActive[activeSet[activeVarInd]] = false;
   activeSet.erase(activeSet.begin() + activeVarInd);
 }
 
-void LARS::Activate(arma::uword varInd)
+void LARS::Activate(const size_t varInd)
 {
-  nActive++;
   isActive[varInd] = true;
   activeSet.push_back(varInd);
 }
@@ -314,7 +308,7 @@
                                 arma::vec& yHatDirection)
 {
   yHatDirection.fill(0);
-  for (arma::uword i = 0; i < nActive; i++)
+  for (size_t i = 0; i < activeSet.size(); i++)
     yHatDirection += betaDirection(i) * matX.col(activeSet[i]);
 }
 
@@ -413,9 +407,9 @@
   }
 }
 
-void LARS::CholeskyDelete(arma::uword colToKill)
+void LARS::CholeskyDelete(const size_t colToKill)
 {
-  arma::uword n = matUtriCholFactor.n_rows;
+  size_t n = matUtriCholFactor.n_rows;
 
   if (colToKill == (n - 1))
   {
@@ -427,7 +421,7 @@
     matUtriCholFactor.shed_col(colToKill); // remove column colToKill
     n--;
 
-    for (arma::uword k = colToKill; k < n; k++)
+    for (size_t k = colToKill; k < n; k++)
     {
       arma::mat matG;
       arma::vec::fixed<2> rotatedVec;

Modified: mlpack/trunk/src/mlpack/methods/lars/lars.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lars/lars.hpp	2012-06-09 21:10:20 UTC (rev 12982)
+++ mlpack/trunk/src/mlpack/methods/lars/lars.hpp	2012-06-09 21:57:02 UTC (rev 12983)
@@ -87,7 +87,7 @@
    * Set the parameters to LARS.  Both lambda1 and lambda2 default to 0.
    *
    * @param useCholesky Whether or not to use Cholesky decomposition when
-   *    solving linear system. If no, compute full Gram matrix at beginning.
+   *    solving linear system (as opposed to using the full Gram matrix).
    * @param lambda1 Regularization parameter for l1-norm penalty.
    * @param lambda2 Regularization parameter for l2-norm penalty.
    * @param tolerance Run until the maximum correlation of elements in (X^T y)
@@ -103,7 +103,7 @@
    * lambda1 and lambda2 default to 0.
    *
    * @param useCholesky Whether or not to use Cholesky decomposition when
-   *    solving linear system.
+   *    solving linear system (as opposed to using the full Gram matrix).
    * @param gramMatrix Gram matrix.
    * @param lambda1 Regularization parameter for l1-norm penalty.
    * @param lambda2 Regularization parameter for l2-norm penalty.
@@ -122,24 +122,21 @@
    * However, because LARS is more efficient on a row-major matrix, this method
    * will (internally) transpose the matrix.  If this transposition is not
    * necessary (i.e., you want to pass in a row-major matrix), pass 'true' for
-   * the rowmajor parameter.
+   * the rowMajor parameter.
    *
    * @param matX Column-major input data (or row-major input data if rowMajor =
    *     true).
    * @param y A vector of targets.
+   * @param beta Vector to store the solution in.
    * @param rowMajor Set to true if matX is row-major.
    */
   void DoLARS(const arma::mat& matX,
               const arma::vec& y,
+              arma::vec& beta,
               const bool rowMajor = false);
 
-  /*
-   * Load the solution vector, which is the last vector from the solution path
-   */
-  void Solution(arma::vec& beta);
-
   //! Accessor for activeSet.
-  const std::vector<arma::uword>& ActiveSet() const { return activeSet; }
+  const std::vector<size_t>& ActiveSet() const { return activeSet; }
 
   //! Accessor for betaPath.
   const std::vector<arma::vec>& BetaPath() const { return betaPath; }
@@ -182,11 +179,8 @@
   //! Value of lambda_1 for each solution in solution path.
   std::vector<double> lambdaPath;
 
-  //! Number of dimensions in active set.
-  arma::uword nActive;
-
   //! Active set of dimensions.
-  std::vector<arma::uword> activeSet;
+  std::vector<size_t> activeSet;
 
   //! Active set membership indicator (for each dimension).
   std::vector<bool> isActive;
@@ -196,14 +190,14 @@
    *
    * @param activeVarInd Index of element to remove from active set.
    */
-  void Deactivate(arma::uword activeVarInd);
+  void Deactivate(const size_t activeVarInd);
 
   /**
    * Add dimension varInd to active set.
    *
    * @param varInd Dimension to add to active set.
    */
-  void Activate(arma::uword varInd);
+  void Activate(const size_t varInd);
 
   // compute "equiangular" direction in output space
   void ComputeYHatDirection(const arma::mat& matX,
@@ -217,9 +211,11 @@
 
   void CholeskyInsert(double sqNormNewX, const arma::vec& newGramCol);
 
-  void GivensRotate(const arma::vec::fixed<2>& x, arma::vec::fixed<2>& rotatedX, arma::mat& G);
+  void GivensRotate(const arma::vec::fixed<2>& x,
+                    arma::vec::fixed<2>& rotatedX,
+                    arma::mat& G);
 
-  void CholeskyDelete(arma::uword colToKill);
+  void CholeskyDelete(const size_t colToKill);
 };
 
 }; // namespace regression




More information about the mlpack-svn mailing list