[mlpack-svn] r12973 - mlpack/trunk/src/mlpack/methods/lars
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Jun 8 00:52:19 EDT 2012
Author: rcurtin
Date: 2012-06-08 00:52:19 -0400 (Fri, 08 Jun 2012)
New Revision: 12973
Modified:
mlpack/trunk/src/mlpack/methods/lars/lars.cpp
Log:
Some code cleanup.
Modified: mlpack/trunk/src/mlpack/methods/lars/lars.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lars/lars.cpp 2012-06-08 02:24:18 UTC (rev 12972)
+++ mlpack/trunk/src/mlpack/methods/lars/lars.cpp 2012-06-08 04:52:19 UTC (rev 12973)
@@ -85,7 +85,6 @@
// Main loop.
while ((nActive < matX.n_cols) && (maxCorr > tolerance))
{
-
// explicit computation of max correlation, among inactive indices
changeInd = -1;
maxCorr = 0;
@@ -112,7 +111,9 @@
// {
// newGramCol[i] = dot(matX.col(activeSet[i]), matX.col(changeInd));
// }
- vec newGramCol = matGram.elem(changeInd * matX.n_cols + conv_to< uvec >::from(activeSet)); // this is equivalent to the above 5 lines - check this!
+ // This is equivalent to the above 5 lines.
+ vec newGramCol = matGram.elem(changeInd * matX.n_cols +
+ conv_to<uvec>::from(activeSet));
//CholeskyInsert(matX.col(changeInd), newGramCol);
CholeskyInsert(matGram(changeInd, changeInd), newGramCol);
@@ -125,9 +126,7 @@
// compute signs of correlations
vec s = vec(nActive);
for (uword i = 0; i < nActive; i++)
- {
s(i) = corr(activeSet[i]) / fabs(corr(activeSet[i]));
- }
// compute "equiangular" direction in parameter space (betaDirection)
/* We use quotes because in the case of non-unit norm variables,
@@ -314,15 +313,13 @@
activeSet.push_back(varInd);
}
- void LARS::ComputeYHatDirection(const mat& matX,
- const vec& betaDirection,
- vec& yHatDirection)
+void LARS::ComputeYHatDirection(const mat& matX,
+ const vec& betaDirection,
+ vec& yHatDirection)
{
yHatDirection.fill(0);
- for(uword i = 0; i < nActive; i++)
- {
+ for (uword i = 0; i < nActive; i++)
yHatDirection += betaDirection(i) * matX.col(activeSet[i]);
- }
}
void LARS::InterpolateBeta()
@@ -346,14 +343,11 @@
if (matUtriCholFactor.n_rows == 0)
{
matUtriCholFactor = mat(1, 1);
+
if (elasticNet)
- {
matUtriCholFactor(0, 0) = sqrt(dot(newX, newX) + lambda2);
- }
else
- {
matUtriCholFactor(0, 0) = norm(newX, 2);
- }
}
else
{
@@ -369,38 +363,35 @@
if (n == 0)
{
matUtriCholFactor = mat(1, 1);
+
if (elasticNet)
- {
matUtriCholFactor(0, 0) = sqrt(sqNormNewX + lambda2);
- }
else
- {
matUtriCholFactor(0, 0) = sqrt(sqNormNewX);
- }
}
else
{
mat matNewR = mat(n + 1, n + 1);
if (elasticNet)
- {
sqNormNewX += lambda2;
- }
vec matUtriCholFactork = solve(trimatl(trans(matUtriCholFactor)),
- newGramCol);
+ newGramCol);
matNewR(span(0, n - 1), span(0, n - 1)) = matUtriCholFactor;
matNewR(span(0, n - 1), n) = matUtriCholFactork;
matNewR(n, span(0, n - 1)).fill(0.0);
matNewR(n, n) = sqrt(sqNormNewX - dot(matUtriCholFactork,
- matUtriCholFactork));
+ matUtriCholFactork));
matUtriCholFactor = matNewR;
}
}
-void LARS::GivensRotate(const vec::fixed<2>& x, vec::fixed<2>& rotatedX, mat& matG)
+void LARS::GivensRotate(const vec::fixed<2>& x,
+ vec::fixed<2>& rotatedX,
+ mat& matG)
{
if (x(1) == 0)
{
@@ -439,7 +430,7 @@
matUtriCholFactor.shed_col(colToKill); // remove column colToKill
n--;
- for(uword k = colToKill; k < n; k++)
+ for (uword k = colToKill; k < n; k++)
{
mat matG;
vec::fixed<2> rotatedVec;
@@ -451,6 +442,7 @@
matG * matUtriCholFactor(span(k, k + 1), span(k + 1, n - 1));
}
}
+
matUtriCholFactor.shed_row(n);
}
}
More information about the mlpack-svn mailing list