[mlpack-git] master: No need to specify the type of the input data. (dd18858)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Jul 30 17:02:07 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/6ee21879488fe98612a4619b17f8b51e8da5215b...dd188581a86e64a0e0dc7854e1c7075d6c8bfe90

>---------------------------------------------------------------

commit dd188581a86e64a0e0dc7854e1c7075d6c8bfe90
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Wed Jul 29 20:49:28 2015 +0200

    No need to specify the type of the input data.


>---------------------------------------------------------------

dd188581a86e64a0e0dc7854e1c7075d6c8bfe90
 .../ann/layer/multiclass_classification_layer.hpp  | 33 +++++++---------------
 1 file changed, 10 insertions(+), 23 deletions(-)

diff --git a/src/mlpack/methods/ann/layer/multiclass_classification_layer.hpp b/src/mlpack/methods/ann/layer/multiclass_classification_layer.hpp
index 32f9878..f61863a 100644
--- a/src/mlpack/methods/ann/layer/multiclass_classification_layer.hpp
+++ b/src/mlpack/methods/ann/layer/multiclass_classification_layer.hpp
@@ -21,14 +21,7 @@ namespace ann /** Artificial Neural Network. */ {
  * A convenience typedef is given:
  *
  *  - ClassificationLayer
- *
- * @tparam MatType Type of data (arma::mat or arma::sp_mat).
- * @tparam VecType Type of data (arma::colvec, arma::mat or arma::sp_mat).
  */
-template <
-    typename MatType = arma::mat,
-    typename VecType = arma::colvec
->
 class MulticlassClassificationLayer
 {
  public:
@@ -49,9 +42,10 @@ class MulticlassClassificationLayer
    * @param error The calculated error with respect to the input activation and
    * the given target.
    */
-  void CalculateError(const VecType& inputActivations,
-                      const VecType& target,
-                      VecType& error)
+  template<typename DataType>
+  void CalculateError(const DataType& inputActivations,
+                      const DataType& target,
+                      DataType& error)
   {
     error = inputActivations - target;
   }
@@ -62,18 +56,16 @@ class MulticlassClassificationLayer
    * @param inputActivations Input data used to calculate the output class.
    * @param output Output class of the input activation.
    */
-  void OutputClass(const VecType& inputActivations, VecType& output)
+  template<typename DataType>
+  void OutputClass(const DataType& inputActivations, DataType& output)
   {
     output = inputActivations;
   }
 }; // class MulticlassClassificationLayer
 
 //! Layer traits for the multiclass classification layer.
-template <
-    typename MatType,
-    typename VecType
->
-class LayerTraits<MulticlassClassificationLayer<MatType, VecType> >
+template <>
+class LayerTraits<MulticlassClassificationLayer>
 {
  public:
   static const bool IsBinary = false;
@@ -82,14 +74,9 @@ class LayerTraits<MulticlassClassificationLayer<MatType, VecType> >
 };
 
 /***
- * Standard Input-Layer using the tanh activation function and the
- * Nguyen-Widrow method to initialize the weights.
+ * Alias ClassificationLayer.
  */
-template <
-    typename MatType = arma::mat,
-    typename VecType = arma::colvec
->
-using ClassificationLayer = MulticlassClassificationLayer<MatType, VecType>;
+using ClassificationLayer = MulticlassClassificationLayer;
 
 }; // namespace ann
 }; // namespace mlpack



More information about the mlpack-git mailing list