[mlpack-git] master: Simplify the usage of the performance functions by removing the VecType class template parameter. (38ac5cc)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Mon Jun 1 17:28:31 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/0547fb75a32eda7e273651a7e6b6a258c5885a1e...61d7876048f2208cf45d41d71f9d4baa825e2a51

>---------------------------------------------------------------

commit 38ac5ccbb3d6fb77d1b2dfb6fabc09247f1db48f
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Mon Jun 1 23:12:00 2015 +0200

    Simplify the usage of the performance functions by removing the VecType class template parameter.


>---------------------------------------------------------------

38ac5ccbb3d6fb77d1b2dfb6fabc09247f1db48f
 src/mlpack/methods/ann/performance_functions/cee_function.hpp | 7 +++----
 src/mlpack/methods/ann/performance_functions/mse_function.hpp | 8 +++-----
 src/mlpack/methods/ann/performance_functions/sse_function.hpp | 6 ++----
 3 files changed, 8 insertions(+), 13 deletions(-)

diff --git a/src/mlpack/methods/ann/performance_functions/cee_function.hpp b/src/mlpack/methods/ann/performance_functions/cee_function.hpp
index 66b5471..f9fcc64 100644
--- a/src/mlpack/methods/ann/performance_functions/cee_function.hpp
+++ b/src/mlpack/methods/ann/performance_functions/cee_function.hpp
@@ -23,11 +23,9 @@ namespace ann /** Artificial Neural Network. */ {
  * granular way to calculate the error.
  *
  * @tparam Layer The layer that is connected with the output layer.
- * @tparam VecType Type of data (arma::colvec, arma::mat or arma::sp_mat).
  */
 template<
-    class Layer = NeuronLayer< >,
-    typename VecType = arma::colvec
+    class Layer = NeuronLayer< >
 >
 class CrossEntropyErrorFunction
 {
@@ -39,7 +37,8 @@ class CrossEntropyErrorFunction
    * @param target Target data.
    * @return cross-entropy error.
    */
-  static double Error(const VecType& input, const VecType& target)
+  template<typename DataType>
+  static double Error(const DataType& input, const DataType& target)
   {
     if (LayerTraits<Layer>::IsBinary)
       return -arma::dot(arma::trunc_log(arma::abs(target - input)), target);
diff --git a/src/mlpack/methods/ann/performance_functions/mse_function.hpp b/src/mlpack/methods/ann/performance_functions/mse_function.hpp
index ebe29f5..455b4cc 100644
--- a/src/mlpack/methods/ann/performance_functions/mse_function.hpp
+++ b/src/mlpack/methods/ann/performance_functions/mse_function.hpp
@@ -15,10 +15,7 @@ namespace ann /** Artificial Neural Network. */ {
 /**
  * The mean squared error performance function measures the network's
  * performance according to the mean of squared errors.
- *
- * @tparam VecType Type of data (arma::colvec, arma::mat or arma::sp_mat).
  */
-template<typename VecType = arma::colvec>
 class MeanSquaredErrorFunction
 {
   public:
@@ -29,9 +26,10 @@ class MeanSquaredErrorFunction
    * @param target Target data.
    * @return mean of squared errors.
    */
-  static double Error(const VecType& input, const VecType& target)
+  template<typename DataType>
+  static double Error(const DataType& input, const DataType& target)
   {
-    return arma::mean(arma::square(target - input));
+    return arma::mean(arma::mean(arma::square(target - input)));
   }
 
 }; // class MeanSquaredErrorFunction
diff --git a/src/mlpack/methods/ann/performance_functions/sse_function.hpp b/src/mlpack/methods/ann/performance_functions/sse_function.hpp
index 01b5418..f7e50e5 100644
--- a/src/mlpack/methods/ann/performance_functions/sse_function.hpp
+++ b/src/mlpack/methods/ann/performance_functions/sse_function.hpp
@@ -15,10 +15,7 @@ namespace ann /** Artificial Neural Network. */ {
 /**
  * The sum squared error performance function measures the network's performance
  * according to the sum of squared errors.
- *
- * @tparam VecType Type of data (arma::colvec, arma::mat or arma::sp_mat).
  */
-template<typename VecType = arma::colvec>
 class SumSquaredErrorFunction
 {
   public:
@@ -29,7 +26,8 @@ class SumSquaredErrorFunction
    * @param target Target data.
    * @return sum of squared errors.
    */
-  static double Error(const VecType& input, const VecType& target)
+  template<typename DataType>
+  static double Error(const DataType& input, const DataType& target)
   {
     return arma::sum(arma::square(target - input));
   }



More information about the mlpack-git mailing list