[mlpack-svn] r15400 - mlpack/trunk/src/mlpack/methods/nca

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Jul 3 15:03:21 EDT 2013


Author: rcurtin
Date: Wed Jul  3 15:03:21 2013
New Revision: 15400

Log:
Normalize the labels before performing computation, and use arma::Col<size_t>
instead of arma::uvec for the normalized labels passed to NCA.


Modified:
   mlpack/trunk/src/mlpack/methods/nca/nca_main.cpp

Modified: mlpack/trunk/src/mlpack/methods/nca/nca_main.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/nca/nca_main.cpp	(original)
+++ mlpack/trunk/src/mlpack/methods/nca/nca_main.cpp	Wed Jul  3 15:03:21 2013
@@ -172,28 +172,33 @@
 
   // Load data.
   arma::mat data;
-  data::Load(inputFile.c_str(), data, true);
+  data::Load(inputFile, data, true);
 
   // Do we want to load labels separately?
-  arma::umat labels(data.n_cols, 1);
+  arma::umat rawLabels(data.n_cols, 1);
   if (labelsFile != "")
   {
-    data::Load(labelsFile.c_str(), labels, true);
+    data::Load(labelsFile, rawLabels, true);
 
-    if (labels.n_rows == 1)
-      labels = trans(labels);
+    if (rawLabels.n_rows == 1)
+      rawLabels = trans(rawLabels);
 
-    if (labels.n_cols > 1)
+    if (rawLabels.n_cols > 1)
       Log::Fatal << "Labels must have only one column or row!" << endl;
   }
   else
   {
     for (size_t i = 0; i < data.n_cols; i++)
-      labels[i] = (int) data(data.n_rows - 1, i);
+      rawLabels[i] = (int) data(data.n_rows - 1, i);
 
     data.shed_row(data.n_rows - 1);
   }
 
+  // Now, normalize the labels.
+  arma::uvec mappings;
+  arma::Col<size_t> labels;
+  data::NormalizeLabels(rawLabels.unsafe_col(0), labels, mappings);
+
   arma::mat distance;
 
   // Normalize the data, if necessary.
@@ -215,10 +220,9 @@
   }
 
   // Now create the NCA object and run the optimization.
-  arma::uvec labelsCol = labels.unsafe_col(0);
   if (optimizerType == "sgd")
   {
-    NCA<LMetric<2> > nca(data, labelsCol);
+    NCA<LMetric<2> > nca(data, labels);
     nca.Optimizer().StepSize() = stepSize;
     nca.Optimizer().MaxIterations() = maxIterations;
     nca.Optimizer().Tolerance() = tolerance;
@@ -228,7 +232,7 @@
   }
   else if (optimizerType == "lbfgs")
   {
-    NCA<LMetric<2>, L_BFGS> nca(data, labelsCol);
+    NCA<LMetric<2>, L_BFGS> nca(data, labels);
     nca.Optimizer().NumBasis() = numBasis;
     nca.Optimizer().MaxIterations() = maxIterations;
     nca.Optimizer().ArmijoConstant() = armijoConstant;
@@ -242,5 +246,5 @@
   }
 
   // Save the output.
-  data::Save(CLI::GetParam<string>("output_file").c_str(), distance, true);
+  data::Save(CLI::GetParam<string>("output_file"), distance, true);
 }



More information about the mlpack-svn mailing list