[mlpack-git] master: Ensure the last batch size isn't zero, to avoid division by zero before updating, when using a batchsize that fulfilled the constraint: (numFunctions % batchSize) == 1. (68e2cbf)

gitdub at mlpack.org
Fri Apr 1 11:57:25 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/9e4e126589d19ddbaaaebec256c77d1f3eb75ce2...6359987ecb0cbf762dc2b2167e574ae595a120d8

>---------------------------------------------------------------

commit 68e2cbf394149076d75cca1a1f415c1ed9e46bca
Author: marcus <marcus.edel at fu-berlin.de>
Date:   Fri Apr 1 17:56:10 2016 +0200

    Ensure the last batch size isn't zero, to avoid division by zero before updating, when using a batchsize that fulfilled the constraint: (numFunctions % batchSize) == 1.


>---------------------------------------------------------------

68e2cbf394149076d75cca1a1f415c1ed9e46bca
 .../core/optimizers/minibatch_sgd/minibatch_sgd_impl.hpp   | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/src/mlpack/core/optimizers/minibatch_sgd/minibatch_sgd_impl.hpp b/src/mlpack/core/optimizers/minibatch_sgd/minibatch_sgd_impl.hpp
index fe7db5c..38c72b3 100644
--- a/src/mlpack/core/optimizers/minibatch_sgd/minibatch_sgd_impl.hpp
+++ b/src/mlpack/core/optimizers/minibatch_sgd/minibatch_sgd_impl.hpp
@@ -120,8 +120,18 @@ double MiniBatchSGD<DecomposableFunctionType>::Optimize(arma::mat& iterate)
         gradient += funcGradient;
       }
 
-      // Now update the iterate.
-      iterate -= (stepSize / lastBatchSize) * gradient;
+      // Ensure the last batch size isn't zero, to avoid division by zero before
+      // updating.
+      if (lastBatchSize > 0)
+      {
+        // Now update the iterate.
+        iterate -= (stepSize / lastBatchSize) * gradient;
+      }
+      else
+      {
+        // Now update the iterate.
+        iterate -= stepSize * gradient;
+      }
 
       // Add that to the overall objective function.
       for (size_t j = 0; j < lastBatchSize; ++j)
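
The diff only shows the guarded update. Below is a minimal standalone sketch of the same idea. It assumes plain `std::vector<double>` in place of Armadillo's `arma::mat`. The helpers `LastBatchSize` and `GuardedUpdate` are hypothetical illustrations and are not part of mlpack's API. The sketch computes the size of the final, possibly short, mini-batch. It then applies the averaged step `(stepSize / n) * gradient` only when `n` is non-zero, and otherwise falls back to the unscaled `stepSize * gradient`, as the patch's `else` branch does.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not mlpack API): number of points in the final
// mini-batch when numFunctions points are split into batches of batchSize.
// The last batch holds the remainder, or a full batch if there is none.
std::size_t LastBatchSize(std::size_t numFunctions, std::size_t batchSize)
{
  // Modulo by zero is undefined, and an empty dataset has no last batch.
  assert(numFunctions > 0 && batchSize > 0);
  const std::size_t remainder = numFunctions % batchSize;
  return (remainder == 0) ? batchSize : remainder;
}

// Hypothetical helper mirroring the patched update:
//   iterate -= (stepSize / n) * gradient   if n > 0
//   iterate -= stepSize * gradient         otherwise (no division by zero)
// Here `gradient` is the gradient summed over the n points of the batch.
void GuardedUpdate(std::vector<double>& iterate,
                   const std::vector<double>& gradient,
                   double stepSize,
                   std::size_t n)
{
  assert(iterate.size() == gradient.size());
  const double scale = (n > 0) ? stepSize / static_cast<double>(n)
                               : stepSize;
  for (std::size_t k = 0; k < iterate.size(); ++k)
    iterate[k] -= scale * gradient[k];
}
```

For example, `LastBatchSize(11, 5)` returns 1 (since `11 % 5 == 1`) and `LastBatchSize(10, 5)` returns 5. Starting from `{1.0, 2.0}` with summed gradient `{4.0, 8.0}` and step size 0.5, `GuardedUpdate` with `n = 2` yields `{0.0, 0.0}`. Calling it again with `n = 0` takes the unscaled branch and yields `{-2.0, -4.0}` instead of dividing by zero. Falling back to an unscaled step follows the commit. Skipping the update entirely for an empty batch would be an alternative design.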
