[mlpack-git] master: Simplify the usage of the performance function by reducing the template parameter. (38ac5cc)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Mon Jun 1 17:28:31 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/0547fb75a32eda7e273651a7e6b6a258c5885a1e...61d7876048f2208cf45d41d71f9d4baa825e2a51
>---------------------------------------------------------------
commit 38ac5ccbb3d6fb77d1b2dfb6fabc09247f1db48f
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date: Mon Jun 1 23:12:00 2015 +0200
Simplify the usage of the performance function by reducing the template parameter.
>---------------------------------------------------------------
38ac5ccbb3d6fb77d1b2dfb6fabc09247f1db48f
src/mlpack/methods/ann/performance_functions/cee_function.hpp | 7 +++----
src/mlpack/methods/ann/performance_functions/mse_function.hpp | 8 +++-----
src/mlpack/methods/ann/performance_functions/sse_function.hpp | 6 ++----
3 files changed, 8 insertions(+), 13 deletions(-)
diff --git a/src/mlpack/methods/ann/performance_functions/cee_function.hpp b/src/mlpack/methods/ann/performance_functions/cee_function.hpp
index 66b5471..f9fcc64 100644
--- a/src/mlpack/methods/ann/performance_functions/cee_function.hpp
+++ b/src/mlpack/methods/ann/performance_functions/cee_function.hpp
@@ -23,11 +23,9 @@ namespace ann /** Artificial Neural Network. */ {
* granular way to calculate the error.
*
* @tparam Layer The layer that is connected with the output layer.
- * @tparam VecType Type of data (arma::colvec, arma::mat or arma::sp_mat).
*/
template<
- class Layer = NeuronLayer< >,
- typename VecType = arma::colvec
+ class Layer = NeuronLayer< >
>
class CrossEntropyErrorFunction
{
@@ -39,7 +37,8 @@ class CrossEntropyErrorFunction
* @param target Target data.
* @return cross-entropy error.
*/
- static double Error(const VecType& input, const VecType& target)
+ template<typename DataType>
+ static double Error(const DataType& input, const DataType& target)
{
if (LayerTraits<Layer>::IsBinary)
return -arma::dot(arma::trunc_log(arma::abs(target - input)), target);
diff --git a/src/mlpack/methods/ann/performance_functions/mse_function.hpp b/src/mlpack/methods/ann/performance_functions/mse_function.hpp
index ebe29f5..455b4cc 100644
--- a/src/mlpack/methods/ann/performance_functions/mse_function.hpp
+++ b/src/mlpack/methods/ann/performance_functions/mse_function.hpp
@@ -15,10 +15,7 @@ namespace ann /** Artificial Neural Network. */ {
/**
* The mean squared error performance function measures the network's
* performance according to the mean of squared errors.
- *
- * @tparam VecType Type of data (arma::colvec, arma::mat or arma::sp_mat).
*/
-template<typename VecType = arma::colvec>
class MeanSquaredErrorFunction
{
public:
@@ -29,9 +26,10 @@ class MeanSquaredErrorFunction
* @param target Target data.
* @return mean of squared errors.
*/
- static double Error(const VecType& input, const VecType& target)
+ template<typename DataType>
+ static double Error(const DataType& input, const DataType& target)
{
- return arma::mean(arma::square(target - input));
+ return arma::mean(arma::mean(arma::square(target - input)));
}
}; // class MeanSquaredErrorFunction
diff --git a/src/mlpack/methods/ann/performance_functions/sse_function.hpp b/src/mlpack/methods/ann/performance_functions/sse_function.hpp
index 01b5418..f7e50e5 100644
--- a/src/mlpack/methods/ann/performance_functions/sse_function.hpp
+++ b/src/mlpack/methods/ann/performance_functions/sse_function.hpp
@@ -15,10 +15,7 @@ namespace ann /** Artificial Neural Network. */ {
/**
* The sum squared error performance function measures the network's performance
* according to the sum of squared errors.
- *
- * @tparam VecType Type of data (arma::colvec, arma::mat or arma::sp_mat).
*/
-template<typename VecType = arma::colvec>
class SumSquaredErrorFunction
{
public:
@@ -29,7 +26,8 @@ class SumSquaredErrorFunction
* @param target Target data.
* @return sum of squared errors.
*/
- static double Error(const VecType& input, const VecType& target)
+ template<typename DataType>
+ static double Error(const DataType& input, const DataType& target)
{
return arma::sum(arma::square(target - input));
}
More information about the mlpack-git mailing list