[mlpack-git] master: Fix classification error (d7d3493)

gitdub at big.cc.gt.atl.ga.us
Sun Dec 13 21:46:25 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/3b926fd86ab143eb8af7327b9fb89fead7538df0...f535c29999c3d57b06664cceb871b5c937666586

>---------------------------------------------------------------

commit d7d34937b172b6b05ec16151a3d7ed653842d420
Author: Pavel <pashaworking at gmail.com>
Date:   Thu Dec 10 18:52:47 2015 +0500

    Fix classification error
    
    Use the "variances" matrix instead of "invVar" when calculating testProbs.
    
    Using "invVar" puts the product of the variances in the numerator, but it belongs in the denominator.
    In addition, computing the exponentials and then taking their logarithm could lose accuracy; this is fixed as well.
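For reference, here is the log-density of a Gaussian with diagonal covariance, which is the quantity the patched line appears to accumulate per class c. The symbol-to-code mapping (d as data.n_rows, sigma^2_{c,k} as variances(k, i), 1/sigma^2_{c,k} as invVar(k, i), and diffs holding x - mu_c) is our reading of the diff, not something the commit states. The second line shows what the removed det(diag(invVar))^{-1/2} factor evaluates to: the reciprocal of the correct normalizer (det Sigma_c)^{-1/2}, i.e. the variance product ends up in the numerator.

```latex
\log p(\mathbf{x} \mid c)
  = -\frac{d}{2}\log(2\pi)
    - \frac{1}{2}\log\det\Sigma_c
    - \frac{1}{2}\sum_{k=1}^{d}\frac{(x_k-\mu_{c,k})^2}{\sigma_{c,k}^2},
\qquad
\det\Sigma_c = \prod_{k=1}^{d}\sigma_{c,k}^2

\left(\det\Sigma_c^{-1}\right)^{-1/2}
  = \Bigl(\prod_{k=1}^{d}\sigma_{c,k}^{-2}\Bigr)^{-1/2}
  = \Bigl(\prod_{k=1}^{d}\sigma_{c,k}^{2}\Bigr)^{1/2}
```

Because the last term is already a log, adding it directly (as the new exponents(j) line does) gives the same value as exponentiating it and taking the log afterwards, without the risk of exp() underflowing to zero first.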


>---------------------------------------------------------------

d7d34937b172b6b05ec16151a3d7ed653842d420
 src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp b/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
index 2b0b3c7..b40719c 100644
--- a/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
+++ b/src/mlpack/methods/naive_bayes/naive_bayes_classifier_impl.hpp
@@ -186,10 +186,10 @@ void NaiveBayesClassifier<MatType>::Classify(const MatType& data,
     arma::mat rhs = -0.5 * arma::diagmat(invVar.col(i)) * diffs;
     arma::vec exponents(diffs.n_cols);
     for (size_t j = 0; j < diffs.n_cols; ++j)
-      exponents(j) = std::exp(arma::accu(diffs.col(j) % rhs.unsafe_col(j)));
+      exponents(j) = arma::accu(diffs.col(j) % rhs.unsafe_col(j));  //log( exp (value) ) == value
 
-    testProbs.col(i) += log(pow(2 * M_PI, (double) data.n_rows / -2.0) *
-        std::pow(arma::det(arma::diagmat(invVar.col(i))), -0.5) * exponents);
+    //calculate prob as sum of logarithm to decrease floating point errors
+    testProbs.col(i) += (data.n_rows / -2.0 * log(2 * M_PI) - 0.5 * log(arma::det(arma::diagmat(variances.col(i)))) + exponents);
   }
 
   // Now calculate the label.
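Both changes can be checked outside mlpack with a small sketch that uses only the standard library. The function names are hypothetical, plain std::vector stands in for Armadillo matrices, and log(det) is computed as a sum of logs, which is a variation on the patch (it keeps the determinant itself from overflowing or underflowing) rather than what the commit does. The second function keeps the variance in the denominator but reproduces the exp-then-log ordering of the removed lines, so the only difference between the two is numerical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of mlpack): log-density of one point x under a
// Gaussian with diagonal covariance diag(var), computed entirely in log space.
// Mirrors the patched expression
//   -d/2 * log(2*pi) - 0.5 * log(det(diag(var))) + exponent
// except that log(det) is written as a sum of per-dimension logs.
double DiagGaussianLogDensity(const std::vector<double>& x,
                              const std::vector<double>& mean,
                              const std::vector<double>& var)
{
  assert(x.size() == mean.size() && x.size() == var.size());
  const double pi = std::acos(-1.0);  // portable alternative to M_PI
  const double d = static_cast<double>(x.size());
  double logDet = 0.0;    // log(prod var_k) = sum log var_k
  double exponent = 0.0;  // -0.5 * sum (x_k - mu_k)^2 / var_k
  for (std::size_t k = 0; k < x.size(); ++k)
  {
    const double diff = x[k] - mean[k];
    logDet += std::log(var[k]);
    exponent += -0.5 * diff * diff / var[k];
  }
  return -0.5 * d * std::log(2.0 * pi) - 0.5 * logDet + exponent;
}

// Hypothetical "before" ordering: exponentiate first and take the log at the
// end, as the removed lines did with std::exp(...) followed by log(...).
// The determinant uses var, so the variance product is in the denominator.
double NaiveLogDensity(const std::vector<double>& x,
                       const std::vector<double>& mean,
                       const std::vector<double>& var)
{
  assert(x.size() == mean.size() && x.size() == var.size());
  const double pi = std::acos(-1.0);
  const double d = static_cast<double>(x.size());
  double det = 1.0;
  double exponent = 0.0;
  for (std::size_t k = 0; k < x.size(); ++k)
  {
    const double diff = x[k] - mean[k];
    det *= var[k];
    exponent += -0.5 * diff * diff / var[k];
  }
  // std::exp(exponent) underflows to 0 once exponent drops below about -745.
  return std::log(std::pow(2.0 * pi, -d / 2.0) * std::pow(det, -0.5) *
                  std::exp(exponent));
}
```

For moderate inputs the two functions agree to rounding error. For a 1-D point 40 standard deviations from the mean the exponent is -800, exp(-800) underflows to zero in double precision, and NaiveLogDensity returns -inf, while DiagGaussianLogDensity returns about -800.92; a -inf log-probability would tie or distort the label comparison that follows.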


