[mlpack-svn] r13821 - mlpack/trunk/src/mlpack/methods/nca

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Nov 1 17:17:32 EDT 2012


Author: rcurtin
Date: 2012-11-01 17:17:32 -0400 (Thu, 01 Nov 2012)
New Revision: 13821

Modified:
   mlpack/trunk/src/mlpack/methods/nca/nca.hpp
   mlpack/trunk/src/mlpack/methods/nca/nca_impl.hpp
   mlpack/trunk/src/mlpack/methods/nca/nca_main.cpp
Log:
Add a parameter to control whether or not the dataset is shuffled during SGD.


Modified: mlpack/trunk/src/mlpack/methods/nca/nca.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/nca/nca.hpp	2012-11-01 16:58:42 UTC (rev 13820)
+++ mlpack/trunk/src/mlpack/methods/nca/nca.hpp	2012-11-01 21:17:32 UTC (rev 13821)
@@ -50,12 +50,15 @@
    * @param stepSize Step size for stochastic gradient descent.
    * @param maxIterations Maximum iterations for stochastic gradient descent.
    * @param tolerance Tolerance for termination of stochastic gradient descent.
+   * @param shuffle Whether or not to shuffle the dataset during SGD.
+   * @param metric Instantiated metric to use.
    */
   NCA(const arma::mat& dataset,
       const arma::uvec& labels,
       const double stepSize = 0.01,
       const size_t maxIterations = 500000,
       const double tolerance = 1e-10,
+      const bool shuffle = true,
       MetricType metric = MetricType());
 
   /**
@@ -101,6 +104,8 @@
   size_t maxIterations;
   //! Tolerance for termination of stochastic gradient descent.
   double tolerance;
+  //! Whether or not to shuffle the dataset for SGD.
+  bool shuffle;
 };
 
 }; // namespace nca

Modified: mlpack/trunk/src/mlpack/methods/nca/nca_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/nca/nca_impl.hpp	2012-11-01 16:58:42 UTC (rev 13820)
+++ mlpack/trunk/src/mlpack/methods/nca/nca_impl.hpp	2012-11-01 21:17:32 UTC (rev 13821)
@@ -1,4 +1,4 @@
-/***
+/**
  * @file nca_impl.hpp
  * @author Ryan Curtin
  *
@@ -24,13 +24,15 @@
                      const double stepSize,
                      const size_t maxIterations,
                      const double tolerance,
+                     const bool shuffle,
                      MetricType metric) :
     dataset(dataset),
     labels(labels),
     metric(metric),
     stepSize(stepSize),
     maxIterations(maxIterations),
-    tolerance(tolerance)
+    tolerance(tolerance),
+    shuffle(shuffle)
 { /* Nothing to do. */ }
 
 template<typename MetricType>
@@ -42,7 +44,7 @@
 
   // We will use stochastic gradient descent to optimize the NCA error function.
   optimization::SGD<SoftmaxErrorFunction<MetricType> > sgd(errorFunc, stepSize,
-      maxIterations, tolerance);
+      maxIterations, tolerance, shuffle);
 
   Timer::Start("nca_sgd_optimization");
 

Modified: mlpack/trunk/src/mlpack/methods/nca/nca_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/nca/nca_main.cpp	2012-11-01 16:58:42 UTC (rev 13820)
+++ mlpack/trunk/src/mlpack/methods/nca/nca_main.cpp	2012-11-01 21:17:32 UTC (rev 13821)
@@ -36,6 +36,8 @@
 PARAM_FLAG("normalize", "Normalize data; useful for datasets where points are "
     "far apart, or when SGD is converging to an objective of NaN.", "N");
 PARAM_INT("seed", "Random seed.  If 0, 'std::time(NULL)' is used.", "s", 0);
+PARAM_FLAG("linear_scan", "Don't shuffle the order in which data points are "
+    "visited for SGD.", "L");
 
 using namespace mlpack;
 using namespace mlpack::nca;
@@ -61,6 +63,7 @@
   const size_t maxIterations = (size_t) CLI::GetParam<int>("max_iterations");
   const double tolerance = CLI::GetParam<double>("tolerance");
   const bool normalize = CLI::HasParam("normalize");
+  const bool shuffle = !CLI::HasParam("linear_scan");
 
   // Load data.
   mat data;
@@ -107,11 +110,13 @@
 
   // Now create the NCA object and run the optimization.
   NCA<LMetric<2> > nca(data, labels.unsafe_col(0), stepSize, maxIterations,
-      tolerance);
+      tolerance, shuffle);
 
   mat distance;
   nca.LearnDistance(distance);
 
+  Log::Warn << trans(distance);
+
   // Save the output.
   data::Save(CLI::GetParam<string>("output_file").c_str(), distance, true);
 }




More information about the mlpack-svn mailing list