[mlpack-svn] r13823 - mlpack/trunk/src/mlpack/methods/nca

fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Nov 1 17:34:11 EDT 2012


Author: rcurtin
Date: 2012-11-01 17:34:11 -0400 (Thu, 01 Nov 2012)
New Revision: 13823

Modified:
   mlpack/trunk/src/mlpack/methods/nca/nca.hpp
   mlpack/trunk/src/mlpack/methods/nca/nca_impl.hpp
Log:
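Allow passing an initial matrix to NCA::LearnDistance(); if it is correctly sized, it is used as the starting point for optimization.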


Modified: mlpack/trunk/src/mlpack/methods/nca/nca.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/nca/nca.hpp	2012-11-01 21:18:35 UTC (rev 13822)
+++ mlpack/trunk/src/mlpack/methods/nca/nca.hpp	2012-11-01 21:34:11 UTC (rev 13823)
@@ -63,7 +63,10 @@
 
   /**
    * Perform Neighborhood Components Analysis.  The output distance learning
-   * matrix is written into the passed reference.
+   * matrix is written into the passed reference.  If LearnDistance() is called
+   * with an outputMatrix which has the correct size (dataset.n_rows x
+   * dataset.n_rows), that matrix will be used as the starting point for
+   * optimization.
    *
    * @param output_matrix Covariance matrix of Mahalanobis distance.
    */

Modified: mlpack/trunk/src/mlpack/methods/nca/nca_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/nca/nca_impl.hpp	2012-11-01 21:18:35 UTC (rev 13822)
+++ mlpack/trunk/src/mlpack/methods/nca/nca_impl.hpp	2012-11-01 21:34:11 UTC (rev 13823)
@@ -38,7 +38,10 @@
 template<typename MetricType>
 void NCA<MetricType>::LearnDistance(arma::mat& outputMatrix)
 {
-  outputMatrix = arma::eye<arma::mat>(dataset.n_rows, dataset.n_rows);
+  // See if we were passed an initialized matrix.
+  if ((outputMatrix.n_rows != dataset.n_rows) ||
+      (outputMatrix.n_cols != dataset.n_rows))
+    outputMatrix.eye(dataset.n_rows, dataset.n_rows);
 
   SoftmaxErrorFunction<MetricType> errorFunc(dataset, labels, metric);
 




More information about the mlpack-svn mailing list