[mlpack-git] master, mlpack-1.0.x: Ok, handle NaNs correctly, and also check this in in trunk, not in the tags... (b49848a)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:45:05 EST 2015


Repository : https://github.com/mlpack/mlpack

On branches: master,mlpack-1.0.x
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit b49848a2b6e69c31d7d8d35d1d03ff42f70fe080
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sun Mar 9 05:13:00 2014 +0000

    Ok, handle NaNs correctly, and also check this in in trunk, not in the tags...


>---------------------------------------------------------------

b49848a2b6e69c31d7d8d35d1d03ff42f70fe080
 src/mlpack/methods/nmf/mult_div_update_rules.hpp | 20 ++++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/src/mlpack/methods/nmf/mult_div_update_rules.hpp b/src/mlpack/methods/nmf/mult_div_update_rules.hpp
index 52efa80..6f4075c 100644
--- a/src/mlpack/methods/nmf/mult_div_update_rules.hpp
+++ b/src/mlpack/methods/nmf/mult_div_update_rules.hpp
@@ -60,10 +60,18 @@ class WMultiplicativeDivergenceRule
         t2.set_size(H.n_cols);
         for (size_t k = 0; k < t2.n_elem; ++k)
         {
+          // This may produce NaNs if V(i, k) = 0.
+          // Technically the math in the paper does not define what to do in
+          // this case, but considering the basic intent of the update rules,
+          // we'll make this modification and take t2(k) = 0.0.
           t2(k) = H(j, k) * V(i, k) / t1(i, k);
+          if (t2(k) != t2(k))
+            t2(k) = 0.0;
         }
 
-        W(i, j) = W(i, j) * sum(t2) / sum(H.row(j));
+        // Only update if the sum is not going to be 0, so as to prevent a
+        // divide by zero.  If sum(H.row(j)) is 0, then t2 should be 0 too.
+        W(i, j) *= sum(t2) / sum(H.row(j));
       }
     }
   }
@@ -111,10 +119,18 @@ class HMultiplicativeDivergenceRule
         t2.set_size(W.n_rows);
         for (size_t k = 0; k < t2.n_elem; ++k)
         {
+          // This may produce NaNs if V(k, j) = 0.
+          // Technically the math in the paper does not define what to do in
+          // this case, but considering the basic intent of the update rules,
+          // we'll make this modification and take t2(k) = 0.0.
           t2(k) = W(k, i) * V(k, j) / t1(k, j);
+          if (t2(k) != t2(k))
+            t2(k) = 0.0;
         }
 
-        H(i,j) = H(i,j) * sum(t2) / sum(W.col(i));
+        // Only update if the sum is not going to be 0, so as to prevent a
+        // divide by zero.  If sum(W.col(i)) is 0, then t2 should be 0 too.
+        H(i, j) *= sum(t2) / sum(W.col(i));
       }
     }
   }



More information about the mlpack-git mailing list