[mlpack-svn] r13823 - mlpack/trunk/src/mlpack/methods/nca
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Nov 1 17:34:11 EDT 2012
Author: rcurtin
Date: 2012-11-01 17:34:11 -0400 (Thu, 01 Nov 2012)
New Revision: 13823
Modified:
mlpack/trunk/src/mlpack/methods/nca/nca.hpp
mlpack/trunk/src/mlpack/methods/nca/nca_impl.hpp
Log:
Allow passing an initial matrix.
Modified: mlpack/trunk/src/mlpack/methods/nca/nca.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/nca/nca.hpp 2012-11-01 21:18:35 UTC (rev 13822)
+++ mlpack/trunk/src/mlpack/methods/nca/nca.hpp 2012-11-01 21:34:11 UTC (rev 13823)
@@ -63,7 +63,10 @@
/**
* Perform Neighborhood Components Analysis. The output distance learning
- * matrix is written into the passed reference.
+ * matrix is written into the passed reference. If LearnDistance() is called
+ * with an outputMatrix which has the correct size (dataset.n_rows x
+ * dataset.n_rows), that matrix will be used as the starting point for
+ * optimization.
*
* @param output_matrix Covariance matrix of Mahalanobis distance.
*/
Modified: mlpack/trunk/src/mlpack/methods/nca/nca_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/nca/nca_impl.hpp 2012-11-01 21:18:35 UTC (rev 13822)
+++ mlpack/trunk/src/mlpack/methods/nca/nca_impl.hpp 2012-11-01 21:34:11 UTC (rev 13823)
@@ -38,7 +38,10 @@
template<typename MetricType>
void NCA<MetricType>::LearnDistance(arma::mat& outputMatrix)
{
- outputMatrix = arma::eye<arma::mat>(dataset.n_rows, dataset.n_rows);
+ // See if we were passed an initialized matrix.
+ if ((outputMatrix.n_rows != dataset.n_rows) ||
+ (outputMatrix.n_cols != dataset.n_rows))
+ outputMatrix.eye(dataset.n_rows, dataset.n_rows);
SoftmaxErrorFunction<MetricType> errorFunc(dataset, labels, metric);
More information about the mlpack-svn mailing list