[mlpack-svn] r13803 - mlpack/trunk/src/mlpack/methods/nca
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Oct 31 16:19:22 EDT 2012
Author: rcurtin
Date: 2012-10-31 16:19:22 -0400 (Wed, 31 Oct 2012)
New Revision: 13803
Modified:
mlpack/trunk/src/mlpack/methods/nca/nca_main.cpp
Log:
More parameters for NCA in accordance with the change to SGD.
Modified: mlpack/trunk/src/mlpack/methods/nca/nca_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/nca/nca_main.cpp 2012-10-31 19:54:06 UTC (rev 13802)
+++ mlpack/trunk/src/mlpack/methods/nca/nca_main.cpp 2012-10-31 20:19:22 UTC (rev 13803)
@@ -27,6 +27,12 @@
PARAM_STRING_REQ("output_file", "Output file for learned distance matrix.",
"o");
PARAM_STRING("labels_file", "File of labels for input dataset.", "l", "");
+PARAM_DOUBLE("step_size", "Step size for stochastic gradient descent.", "s",
+ 0.01);
+PARAM_INT("max_iterations", "Maximum number of iterations for stochastic "
+ "gradient descent (0 indicates no limit).", "n", 500000);
+PARAM_DOUBLE("tolerance", "Maximum tolerance for termination of stochastic "
+ "gradient descent.", "t", 1e-7);
using namespace mlpack;
using namespace mlpack::nca;
@@ -43,6 +49,10 @@
const string labelsFile = CLI::GetParam<string>("labels_file");
const string outputFile = CLI::GetParam<string>("output_file");
+ const double stepSize = CLI::GetParam<double>("step_size");
+ const size_t maxIterations = CLI::GetParam<int>("max_iterations");
+ const double tolerance = CLI::GetParam<double>("tolerance");
+
// Load data.
mat data;
data::Load(inputFile.c_str(), data, true);
@@ -68,7 +78,8 @@
}
// Now create the NCA object and run the optimization.
- NCA<LMetric<2> > nca(data, labels.unsafe_col(0));
+ NCA<LMetric<2> > nca(data, labels.unsafe_col(0), stepSize, maxIterations,
+ tolerance);
mat distance;
nca.LearnDistance(distance);
More information about the mlpack-svn mailing list