[mlpack-git] master: Add function to get the number of cols/slices. Thanks Shangtong for pointing this out. (d6c23a4)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Mar 11 09:00:09 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/f349aa79f97f893b2d98f5c795bff3751c9be71a...d6c23a4f1c83da3c4c604a4a7fcaa390562427a4

>---------------------------------------------------------------

commit d6c23a4f1c83da3c4c604a4a7fcaa390562427a4
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Wed Mar 11 13:59:24 2015 +0100

    Add function to get the number of cols/slices. Thanks Shangtong for pointing this out.


>---------------------------------------------------------------

d6c23a4f1c83da3c4c604a4a7fcaa390562427a4
 src/mlpack/methods/ann/trainer/trainer.hpp | 56 ++++++++++++++----------------
 1 file changed, 26 insertions(+), 30 deletions(-)

diff --git a/src/mlpack/methods/ann/trainer/trainer.hpp b/src/mlpack/methods/ann/trainer/trainer.hpp
index d2724e6..249d3be 100644
--- a/src/mlpack/methods/ann/trainer/trainer.hpp
+++ b/src/mlpack/methods/ann/trainer/trainer.hpp
@@ -80,7 +80,10 @@ class Trainer
                InputType& validationData,
                OutputType& validationLabels)
     {
-      Index(trainingData);
+      // This generates [0 1 2 3 ... (ElementCount(trainingData) - 1)]. The
+      // sequence will be used to iterate through the training data.
+      index = arma::linspace<arma::Col<size_t> >(0,
+          ElementCount(trainingData) - 1, ElementCount(trainingData));
       epoch = 0;
 
       while(true)
@@ -138,7 +141,7 @@ class Trainer
       // Reset the training error.
       trainingError = 0;
 
-      for (size_t i = 0; i < data.n_cols; i++)
+      for (size_t i = 0; i < index.n_elem; i++)
       {
         net.FeedForward(Element(data, index(i)),
             Element(target, index(i)), error);
@@ -150,10 +153,10 @@ class Trainer
           net.ApplyGradients();
       }
 
-      if ((data.n_cols % batchSize) != 0)
+      if ((index.n_elem % batchSize) != 0)
         net.ApplyGradients();
 
-      trainingError /= data.n_cols;
+      trainingError /= index.n_elem;
     }
 
     /**
@@ -168,67 +171,60 @@ class Trainer
       // Reset the validation error.
       validationError = 0;
 
-      for (size_t i = 0; i < data.n_cols; i++)
+      for (size_t i = 0; i < ElementCount(data); i++)
       {
          validationError += net.Evaluate(Element(data, i),
             Element(target, i), error);
       }
 
-      validationError /= data.n_cols;
+      validationError /= ElementCount(data);
     }
 
     /*
-     * Generate index sequence to iterate through the data.
+     * Create a Col object which uses memory from an existing matrix object.
+     * (This approach is currently not alias safe)
      *
-     * @param input The reference data.
+     * @param input The reference data.
+     * @param colNum Index of the column to provide as a Col object.
      */
     template<typename eT>
-    void Index(const arma::Mat<eT>& input)
+    arma::Col<eT> Element(arma::Mat<eT>& input, const size_t colNum)
     {
-      // This generates [0 1 2 3 ... (input.n_cols - 1)]. The sequence
-      // will be used to iterate through the training data.
-      index = arma::linspace<arma::Col<size_t> >(0, input.n_cols - 1,
-          input.n_cols);
+      return arma::Col<eT>(input.colptr(colNum), input.n_rows, false, true);
     }
 
     /*
-     * Generate index sequence to iterate through the data.
+     * Provide the reference to the matrix representing a single slice.
      *
-     * @param input The reference data.
+     * @param input The reference data.
+     * @param sliceNum Provide a single slice of the specified index.
      */
     template<typename eT>
-    void Index(const arma::Cube<eT>& input)
+    const arma::Mat<eT>& Element(arma::Cube<eT>& input, const size_t sliceNum)
     {
-      // This generates [0 1 2 3 ... (input.n_slices - 1)]. The sequence
-      // will be used to iterate through the training data.
-      index = arma::linspace<arma::Col<size_t> >(0, input.n_slices - 1,
-          input.n_slices);
+      return *(input.mat_ptrs[sliceNum]);
     }
 
     /*
-     * Create a Col object which uses memory from an existing matrix object.
-     * (This approach is currently not alias safe)
+     * Get the number of elements.
      *
      * @param data The reference data.
-     * @param sliceNum Provide a Col object of the specified index.
      */
     template<typename eT>
-    arma::Col<eT> Element(arma::Mat<eT>& input, const size_t colNum)
+    size_t ElementCount(const arma::Mat<eT>& data) const
     {
-      return arma::Col<eT>(input.colptr(colNum), input.n_rows, false, true);
+      return data.n_cols;
     }
 
     /*
-     * Provide the reference to the matrix representing a single slice.
+     * Get the number of elements.
      *
      * @param data The reference data.
-     * @param sliceNum Provide a single slice of the specified index.
      */
     template<typename eT>
-    const arma::Mat<eT>& Element(arma::Cube<eT>& input,
-        const size_t sliceNum)
+    size_t ElementCount(const arma::Cube<eT>& data) const
     {
-      return *(input.mat_ptrs[sliceNum]);
+      return data.n_slices;
     }
 
     //! The network which should be trained and evaluated.



More information about the mlpack-git mailing list