[mlpack-svn] r15400 - mlpack/trunk/src/mlpack/methods/nca
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Jul 3 15:03:21 EDT 2013
Author: rcurtin
Date: Wed Jul 3 15:03:21 2013
New Revision: 15400
Log:
Normalize the labels before performing computation, and use arma::Col<size_t>
instead of arma::uvec.
Modified:
mlpack/trunk/src/mlpack/methods/nca/nca_main.cpp
Modified: mlpack/trunk/src/mlpack/methods/nca/nca_main.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/nca/nca_main.cpp (original)
+++ mlpack/trunk/src/mlpack/methods/nca/nca_main.cpp Wed Jul 3 15:03:21 2013
@@ -172,28 +172,33 @@
// Load data.
arma::mat data;
- data::Load(inputFile.c_str(), data, true);
+ data::Load(inputFile, data, true);
// Do we want to load labels separately?
- arma::umat labels(data.n_cols, 1);
+ arma::umat rawLabels(data.n_cols, 1);
if (labelsFile != "")
{
- data::Load(labelsFile.c_str(), labels, true);
+ data::Load(labelsFile, rawLabels, true);
- if (labels.n_rows == 1)
- labels = trans(labels);
+ if (rawLabels.n_rows == 1)
+ rawLabels = trans(rawLabels);
- if (labels.n_cols > 1)
+ if (rawLabels.n_cols > 1)
Log::Fatal << "Labels must have only one column or row!" << endl;
}
else
{
for (size_t i = 0; i < data.n_cols; i++)
- labels[i] = (int) data(data.n_rows - 1, i);
+ rawLabels[i] = (int) data(data.n_rows - 1, i);
data.shed_row(data.n_rows - 1);
}
+ // Now, normalize the labels.
+ arma::uvec mappings;
+ arma::Col<size_t> labels;
+ data::NormalizeLabels(rawLabels.unsafe_col(0), labels, mappings);
+
arma::mat distance;
// Normalize the data, if necessary.
@@ -215,10 +220,9 @@
}
// Now create the NCA object and run the optimization.
- arma::uvec labelsCol = labels.unsafe_col(0);
if (optimizerType == "sgd")
{
- NCA<LMetric<2> > nca(data, labelsCol);
+ NCA<LMetric<2> > nca(data, labels);
nca.Optimizer().StepSize() = stepSize;
nca.Optimizer().MaxIterations() = maxIterations;
nca.Optimizer().Tolerance() = tolerance;
@@ -228,7 +232,7 @@
}
else if (optimizerType == "lbfgs")
{
- NCA<LMetric<2>, L_BFGS> nca(data, labelsCol);
+ NCA<LMetric<2>, L_BFGS> nca(data, labels);
nca.Optimizer().NumBasis() = numBasis;
nca.Optimizer().MaxIterations() = maxIterations;
nca.Optimizer().ArmijoConstant() = armijoConstant;
@@ -242,5 +246,5 @@
}
// Save the output.
- data::Save(CLI::GetParam<string>("output_file").c_str(), distance, true);
+ data::Save(CLI::GetParam<string>("output_file"), distance, true);
}
More information about the mlpack-svn mailing list.