[mlpack-svn] r13822 - mlpack/trunk/src/mlpack/methods/nca
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Nov 1 17:18:35 EDT 2012
Author: rcurtin
Date: 2012-11-01 17:18:35 -0400 (Thu, 01 Nov 2012)
New Revision: 13822
Modified:
mlpack/trunk/src/mlpack/methods/nca/nca_softmax_error_function_impl.hpp
Log:
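For real, I seriously considered anger management classes while hunting this one
down. (Fix: compute x_ik from the unstretched dataset, not the stretched one, when
accumulating the gradient terms, and return a zero gradient when the denominator
of p_i is 0.)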
Modified: mlpack/trunk/src/mlpack/methods/nca/nca_softmax_error_function_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/nca/nca_softmax_error_function_impl.hpp 2012-11-01 21:17:32 UTC (rev 13821)
+++ mlpack/trunk/src/mlpack/methods/nca/nca_softmax_error_function_impl.hpp 2012-11-01 21:18:35 UTC (rev 13822)
@@ -159,27 +159,30 @@
// If the points are in the same class, we must add to the second term of
// the gradient as well as the numerator of p_i. We will divide by the
- // denominator of p_ik later.
+ // denominator of p_ik later. For x_ik we are not using stretched points.
+ arma::vec x_ik = dataset.col(i) - dataset.col(k);
if (labels[i] == labels[k])
{
numerator += eval;
- secondTerm += eval *
- (stretchedDataset.col(i) - stretchedDataset.col(k)) *
- arma::trans(stretchedDataset.col(i) - stretchedDataset.col(k));
+ secondTerm += eval * x_ik * trans(x_ik);
}
// We always have to add to the denominator of p_i and the first term of the
// gradient computation. We will divide by the denominator of p_ik later.
denominator += eval;
- firstTerm += eval *
- (stretchedDataset.col(i) - stretchedDataset.col(k)) *
- arma::trans(stretchedDataset.col(i) - stretchedDataset.col(k));
+ firstTerm += eval * x_ik * trans(x_ik);
}
// Calculate p_i.
double p = 0;
if (denominator == 0)
+ {
Log::Warn << "Denominator of p_" << i << " is 0!" << std::endl;
+ // If the denominator is zero, then all p_ik should be zero and there is
+ // no gradient contribution from this point.
+ gradient.zeros(coordinates.n_rows, coordinates.n_rows);
+ return;
+ }
else
{
p = numerator / denominator;
More information about the mlpack-svn mailing list