[mlpack-git] master: Add function to get the number of cols/slices. Thanks Shangtong for pointing this out. (d6c23a4)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Mar 11 09:00:09 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/f349aa79f97f893b2d98f5c795bff3751c9be71a...d6c23a4f1c83da3c4c604a4a7fcaa390562427a4
>---------------------------------------------------------------
commit d6c23a4f1c83da3c4c604a4a7fcaa390562427a4
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Wed Mar 11 13:59:24 2015 +0100
Add function to get the number of cols/slices. Thanks Shangtong for pointing this out.
>---------------------------------------------------------------
d6c23a4f1c83da3c4c604a4a7fcaa390562427a4
src/mlpack/methods/ann/trainer/trainer.hpp | 56 ++++++++++++++----------------
1 file changed, 26 insertions(+), 30 deletions(-)
diff --git a/src/mlpack/methods/ann/trainer/trainer.hpp b/src/mlpack/methods/ann/trainer/trainer.hpp
index d2724e6..249d3be 100644
--- a/src/mlpack/methods/ann/trainer/trainer.hpp
+++ b/src/mlpack/methods/ann/trainer/trainer.hpp
@@ -80,7 +80,10 @@ class Trainer
InputType& validationData,
OutputType& validationLabels)
{
- Index(trainingData);
+ // This generates [0 1 2 3 ... (ElementCount(trainingData) - 1)]. The
+ // sequence will be used to iterate through the training data.
+ index = arma::linspace<arma::Col<size_t> >(0,
+ ElementCount(trainingData) - 1, ElementCount(trainingData));
epoch = 0;
while(true)
@@ -138,7 +141,7 @@ class Trainer
// Reset the training error.
trainingError = 0;
- for (size_t i = 0; i < data.n_cols; i++)
+ for (size_t i = 0; i < index.n_elem; i++)
{
net.FeedForward(Element(data, index(i)),
Element(target, index(i)), error);
@@ -150,10 +153,10 @@ class Trainer
net.ApplyGradients();
}
- if ((data.n_cols % batchSize) != 0)
+ if ((index.n_elem % batchSize) != 0)
net.ApplyGradients();
- trainingError /= data.n_cols;
+ trainingError /= index.n_elem;
}
/**
@@ -168,67 +171,60 @@ class Trainer
// Reset the validation error.
validationError = 0;
- for (size_t i = 0; i < data.n_cols; i++)
+ for (size_t i = 0; i < ElementCount(data); i++)
{
validationError += net.Evaluate(Element(data, i),
Element(target, i), error);
}
- validationError /= data.n_cols;
+ validationError /= ElementCount(data);
}
/*
- * Generate index sequence to iterate through the data.
+ * Create a Col object which uses memory from an existing matrix object.
+ * (This approach is currently not alias safe)
*
- * @param input The reference data.
+ * @param data The reference data.
+ * @param sliceNum Provide a Col object of the specified index.
*/
template<typename eT>
- void Index(const arma::Mat<eT>& input)
+ arma::Col<eT> Element(arma::Mat<eT>& input, const size_t colNum)
{
- // This generates [0 1 2 3 ... (input.n_cols - 1)]. The sequence
- // will be used to iterate through the training data.
- index = arma::linspace<arma::Col<size_t> >(0, input.n_cols - 1,
- input.n_cols);
+ return arma::Col<eT>(input.colptr(colNum), input.n_rows, false, true);
}
/*
- * Generate index sequence to iterate through the data.
+ * Provide the reference to the matrix representing a single slice.
*
- * @param input The reference data.
+ * @param data The reference data.
+ * @param sliceNum Provide a single slice of the specified index.
*/
template<typename eT>
- void Index(const arma::Cube<eT>& input)
+ const arma::Mat<eT>& Element(arma::Cube<eT>& input, const size_t sliceNum)
{
- // This generates [0 1 2 3 ... (input.n_slices - 1)]. The sequence
- // will be used to iterate through the training data.
- index = arma::linspace<arma::Col<size_t> >(0, input.n_slices - 1,
- input.n_slices);
+ return *(input.mat_ptrs[sliceNum]);
}
/*
- * Create a Col object which uses memory from an existing matrix object.
- * (This approach is currently not alias safe)
+ * Get the number of elements.
*
* @param data The reference data.
- * @param sliceNum Provide a Col object of the specified index.
*/
template<typename eT>
- arma::Col<eT> Element(arma::Mat<eT>& input, const size_t colNum)
+ size_t ElementCount(const arma::Mat<eT>& data) const
{
- return arma::Col<eT>(input.colptr(colNum), input.n_rows, false, true);
+ return data.n_cols;
}
/*
- * Provide the reference to the matrix representing a single slice.
+ * Get the number of elements.
*
* @param data The reference data.
- * @param sliceNum Provide a single slice of the specified index.
*/
template<typename eT>
- const arma::Mat<eT>& Element(arma::Cube<eT>& input,
- const size_t sliceNum)
+ size_t ElementCount(const arma::Cube<eT>& data) const
{
- return *(input.mat_ptrs[sliceNum]);
+ return data.n_slices;
}
//! The network which should be trained and evaluated.
More information about the mlpack-git mailing list