[mlpack-svn] r12580 - mlpack/trunk/src/mlpack/methods/local_coordinate_coding

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Apr 30 17:23:00 EDT 2012


Author: rcurtin
Date: 2012-04-30 17:23:00 -0400 (Mon, 30 Apr 2012)
New Revision: 12580

Modified:
   mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc.cpp
Log:
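Refactor OptimizeCode() for the new LARS API: pass the Gram matrix to the LARS constructor instead of calling SetGram() separately. Also remove trailing whitespace throughout the file.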


Modified: mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc.cpp	2012-04-30 21:22:43 UTC (rev 12579)
+++ mlpack/trunk/src/mlpack/methods/local_coordinate_coding/lcc.cpp	2012-04-30 21:23:00 UTC (rev 12580)
@@ -32,12 +32,12 @@
 }
 
 
-void LocalCoordinateCoding::InitDictionary() {  
+void LocalCoordinateCoding::InitDictionary() {
   RandomInitDictionary();
 }
 
 
-void LocalCoordinateCoding::LoadDictionary(const char* dictionaryFilename) {  
+void LocalCoordinateCoding::LoadDictionary(const char* dictionaryFilename) {
   matD.load(dictionaryFilename);
 }
 
@@ -74,16 +74,16 @@
 
   bool converged = false;
   double lastObjVal = 1e99;
-  
+
   Log::Info << "Initial Coding Step" << endl;
   OptimizeCode();
   uvec adjacencies = find(matZ);
-  Log::Info << "\tSparsity level: " 
-	    << 100.0 * ((double)(adjacencies.n_elem)) 
+  Log::Info << "\tSparsity level: "
+	    << 100.0 * ((double)(adjacencies.n_elem))
                      / ((double)(nAtoms * nPoints))
 	    << "%\n";
   Log::Info << "\tObjective value: " << Objective(adjacencies) << endl;
-  
+
   for(uword t = 1; t <= nIterations && !converged; t++) {
     Log::Info << "Iteration " << t << " of " << nIterations << endl;
 
@@ -91,12 +91,12 @@
     OptimizeDictionary(adjacencies);
     double dsObjVal = Objective(adjacencies);
     Log::Info << "\tObjective value: " << Objective(adjacencies) << endl;
-    
+
     Log::Info << "Coding Step" << endl;
     OptimizeCode();
     adjacencies = find(matZ);
-    Log::Info << "\tSparsity level: " 
-	      << 100.0 * ((double)(adjacencies.n_elem)) 
+    Log::Info << "\tSparsity level: "
+	      << 100.0 * ((double)(adjacencies.n_elem))
                        / ((double)(nAtoms * nPoints))
 	      << "%\n";
     double curObjVal = Objective(adjacencies);
@@ -105,7 +105,7 @@
     if(curObjVal > dsObjVal) {
       Log::Fatal << "Objective increased in sparse coding step!" << endl;
     }
-    
+
     double objValImprov = lastObjVal - curObjVal;
     Log::Info << "\t\t\t\t\tImprovement: " << std::scientific
 	      <<  objValImprov << endl;
@@ -113,35 +113,35 @@
       converged = true;
       Log::Info << "Converged within tolerance\n";
     }
-    
+
     lastObjVal = curObjVal;
   }
 }
 
 
 void LocalCoordinateCoding::OptimizeCode() {
-  mat matSqDists = 
+  mat matSqDists =
     repmat(trans(sum(square(matD))), 1, nPoints)
     + repmat(sum(square(matX)), nAtoms, 1)
-    - 2 * trans(matD) * matX;			     
-  
+    - 2 * trans(matD) * matX;
+
   mat matInvSqDists = 1.0 / matSqDists;
-  
+
   mat matDTD = trans(matD) * matD;
   mat matDPrimeTDPrime(matDTD.n_rows, matDTD.n_cols);
-  
+
   for(uword i = 0; i < nPoints; i++) {
     // report progress
     if((i % 100) == 0) {
       Log::Debug << "\t" << i << endl;
     }
-    
+
     vec w = matSqDists.unsafe_col(i);
     vec invW = matInvSqDists.unsafe_col(i);
     mat matDPrime = matD * diagmat(invW);
-    
+
     mat matDPrimeTDPrime = diagmat(invW) * matDTD * diagmat(invW);
-    
+
     //LARS lars;
     // do we still need 0.5 * lambda? yes, yes we do
     //lars.Init(matDPrime.memptr(), matX.colptr(i), nDims, nAtoms, true, 0.5 * lambda); // apparently not as fast as using the below duo
@@ -150,11 +150,10 @@
     // the duo
     /* lars.Init(matDPrime.memptr(), matX.colptr(i), nDims, nAtoms, false, 0.5 * lambda); */
     /* lars.SetGram(matDPrimeTDPrime.memptr(), nAtoms); */
-    
+
     bool useCholesky = false;
-    LARS lars(useCholesky, 0.5 * lambda);
-    lars.SetGram(matDPrimeTDPrime);
-    
+    LARS lars(useCholesky, matDPrimeTDPrime, 0.5 * lambda);
+
     lars.DoLARS(matDPrime, matX.unsafe_col(i));
     vec beta;
     lars.Solution(beta);
@@ -182,7 +181,7 @@
     }
     neighborCounts(curPointInd) = curCount;
   }
-  
+
   // build matXPrime := [X x^1 ... x^1 ... x^n ... x^n]
   // where each x^i is repeated for the number of neighbors x^i has
   mat matXPrime = zeros(nDims, nPoints + adjacencies.n_elem);
@@ -195,7 +194,7 @@
     }
     curCol += neighborCounts(i);
   }
-  
+
   // handle the case of inactive atoms (atoms not used in the given coding)
   std::vector<uword> inactiveAtoms;
   std::vector<uword> activeAtoms;
@@ -220,7 +219,7 @@
     uvec inactiveAtomsVec = conv_to< uvec >::from(inactiveAtoms);
     RemoveRows(matZ, inactiveAtomsVec, matActiveZ);
   }
-  
+
   uvec atomReverseLookup = uvec(nAtoms);
   for(uword i = 0; i < nActiveAtoms; i++) {
     atomReverseLookup(activeAtoms[i]) = i;
@@ -230,34 +229,34 @@
   if(nInactiveAtoms > 0) {
     Log::Info << "There are " << nInactiveAtoms << " inactive atoms. They will be re-initialized randomly.\n";
   }
-  
+
   mat matZPrime = zeros(nActiveAtoms, nPoints + adjacencies.n_elem);
   //Log::Debug << "adjacencies.n_elem = " << adjacencies.n_elem << endl;
   matZPrime(span::all, span(0, nPoints - 1)) = matActiveZ;
-  
+
   vec wSquared = ones(nPoints + adjacencies.n_elem, 1);
   //Log::Debug << "building up matZPrime\n";
   for(uword l = 0; l < adjacencies.n_elem; l++) {
     uword atomInd = adjacencies(l) % nAtoms;
     uword pointInd = (uword) (adjacencies(l) / nAtoms);
     matZPrime(atomReverseLookup(atomInd), nPoints + l) = 1.0;
-    wSquared(nPoints + l) = matZ(atomInd, pointInd); 
+    wSquared(nPoints + l) = matZ(atomInd, pointInd);
   }
-  
-  wSquared.subvec(nPoints, wSquared.n_elem - 1) = 
+
+  wSquared.subvec(nPoints, wSquared.n_elem - 1) =
     lambda * abs(wSquared.subvec(nPoints, wSquared.n_elem - 1));
-  
+
   //Log::Debug << "about to solve\n";
   mat matDEstimate;
   if(inactiveAtoms.empty()) {
     mat A = matZPrime * diagmat(wSquared) * trans(matZPrime);
     mat B = matZPrime * diagmat(wSquared) * trans(matXPrime);
-    
+
     //Log::Debug << "solving...\n";
-    matDEstimate = 
+    matDEstimate =
       trans(solve(A, B));
-    /*    
-    matDEstimate = 
+    /*
+    matDEstimate =
       trans(solve(matZPrime * diagmat(wSquared) * trans(matZPrime),
 		  matZPrime * diagmat(wSquared) * trans(matXPrime)));
     */
@@ -265,7 +264,7 @@
   else {
     matDEstimate = zeros(nDims, nAtoms);
     //Log::Debug << "solving...\n";
-    mat matDActiveEstimate = 
+    mat matDActiveEstimate =
       trans(solve(matZPrime * diagmat(wSquared) * trans(matZPrime),
 		  matZPrime * diagmat(wSquared) * trans(matXPrime)));
     for(uword j = 0; j < nActiveAtoms; j++) {
@@ -276,7 +275,7 @@
       RandomAtom(vecD_j);
       /*
       vec new_atom = randn(nDims, 1);
-      matDEstimate.col(inactiveAtoms[i]) = 
+      matDEstimate.col(inactiveAtoms[i]) =
 	new_atom / norm(new_atom, 2);
       */
     }
@@ -315,7 +314,7 @@
   uword n_rows = X.n_rows;
   uword n_to_remove = rows_to_remove.n_elem;
   uword n_to_keep = n_rows - n_to_remove;
-  
+
   if(n_to_remove == 0) {
     X_mod = X;
   }
@@ -334,15 +333,15 @@
     }
     // now, check i'th row to remove to (i + 1)'th row to remove, until i = penultimate row
     while(remove_ind < n_to_remove - 1) {
-      uword height = 
+      uword height =
 	rows_to_remove[remove_ind + 1]
 	- rows_to_remove[remove_ind]
 	- 1;
       if(height > 0) {
-	X_mod(span(cur_row, cur_row + height - 1), 
+	X_mod(span(cur_row, cur_row + height - 1),
 	      span::all) =
 	  X(span(rows_to_remove[remove_ind] + 1,
-		 rows_to_remove[remove_ind + 1] - 1), 
+		 rows_to_remove[remove_ind + 1] - 1),
 	    span::all);
 	cur_row += height;
       }
@@ -350,9 +349,9 @@
     }
     // now that i is last row to remove, check last row to remove to last row
     if(rows_to_remove[remove_ind] < n_rows - 1) {
-      X_mod(span(cur_row, n_to_keep - 1), 
-	    span::all) = 
-	X(span(rows_to_remove[remove_ind] + 1, n_rows - 1), 
+      X_mod(span(cur_row, n_to_keep - 1),
+	    span::all) =
+	X(span(rows_to_remove[remove_ind] + 1, n_rows - 1),
 	  span::all);
     }
   }




More information about the mlpack-svn mailing list