[mlpack-git] master: Add network traits class. (89e0ddc)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Jan 7 04:10:07 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/65bd0550f4ab82a07183603e1a2e9fa7f2bb43d7...89e0ddc4f1e6cae89d5715501b26084d0cb22135

>---------------------------------------------------------------

commit 89e0ddc4f1e6cae89d5715501b26084d0cb22135
Author: Marcus Edel <marcus.edel at fu-berlin.de>
Date:   Wed Jan 7 10:10:00 2015 +0100

    Add network traits class.


>---------------------------------------------------------------

89e0ddc4f1e6cae89d5715501b26084d0cb22135
 src/mlpack/methods/ann/ffnn.hpp           | 17 +++++++++++++
 src/mlpack/methods/ann/network_traits.hpp | 40 +++++++++++++++++++++++++++++++
 src/mlpack/methods/ann/rnn.hpp            | 16 +++++++++++++
 3 files changed, 73 insertions(+)

diff --git a/src/mlpack/methods/ann/ffnn.hpp b/src/mlpack/methods/ann/ffnn.hpp
index 07ffe53..55f770c 100644
--- a/src/mlpack/methods/ann/ffnn.hpp
+++ b/src/mlpack/methods/ann/ffnn.hpp
@@ -9,6 +9,7 @@
 
 #include <mlpack/core.hpp>
 
+#include <mlpack/methods/ann/network_traits.hpp>
 #include <mlpack/methods/ann/performance_functions/cee_function.hpp>
 #include <mlpack/methods/ann/layer/layer_traits.hpp>
 
@@ -316,7 +317,23 @@ class FFNN
     double err;
 }; // class FFNN
 
+
+//! Network traits for the FFNN network. An FFNN is reported as a feed
+//! forward neural network and not as a recurrent one.
+template<typename ConnectionTypes,
+         typename OutputLayerType,
+         class PerformanceFunction>
+class NetworkTraits<FFNN<ConnectionTypes,
+                         OutputLayerType,
+                         PerformanceFunction> >
+{
+ public:
+  static const bool IsFNN = true;  //!< FFNN is a feed forward network.
+  static const bool IsRNN = false; //!< FFNN is not a recurrent network.
+};
+
 }; // namespace ann
 }; // namespace mlpack
 
 #endif
+
diff --git a/src/mlpack/methods/ann/network_traits.hpp b/src/mlpack/methods/ann/network_traits.hpp
new file mode 100644
index 0000000..4c74f9a
--- /dev/null
+++ b/src/mlpack/methods/ann/network_traits.hpp
@@ -0,0 +1,40 @@
+/**
+ * @file network_traits.hpp
+ * @author Marcus Edel
+ *
+ * NetworkTraits class, a template class to get information about various
+ * networks.
+ */
+#ifndef __MLPACK_METHODS_ANN_NETWORK_TRAITS_HPP
+#define __MLPACK_METHODS_ANN_NETWORK_TRAITS_HPP
+
+namespace mlpack {
+namespace ann {
+
+/**
+ * This is a template class that can provide information about various
+ * networks. By default, this class will provide the weakest possible
+ * assumptions on networks, and each network should override values as
+ * necessary. If a network doesn't need to override a value, then there's no
+ * need to write a NetworkTraits specialization for that class.
+ *
+ * @tparam NetworkType Type of the network to describe. Specialize
+ *     NetworkTraits for a network type to override the defaults below;
+ *     see the FFNN and RNN specializations for examples.
+ */
+template<typename NetworkType>
+class NetworkTraits
+{
+ public:
+  //! This is true if the network is a feed forward neural network.
+  static const bool IsFNN = false;
+
+  //! This is true if the network is a recurrent neural network.
+  static const bool IsRNN = false;
+};
+
+}; // namespace ann
+}; // namespace mlpack
+
+#endif // __MLPACK_METHODS_ANN_NETWORK_TRAITS_HPP
+
diff --git a/src/mlpack/methods/ann/rnn.hpp b/src/mlpack/methods/ann/rnn.hpp
index 4ecb499..d5862ad 100644
--- a/src/mlpack/methods/ann/rnn.hpp
+++ b/src/mlpack/methods/ann/rnn.hpp
@@ -11,6 +11,7 @@
 
 #include <boost/ptr_container/ptr_vector.hpp>
 
+#include <mlpack/methods/ann/network_traits.hpp>
 #include <mlpack/methods/ann/performance_functions/cee_function.hpp>
 #include <mlpack/methods/ann/layer/layer_traits.hpp>
 #include <mlpack/methods/ann/connections/connection_traits.hpp>
@@ -74,6 +75,7 @@ class RNN
 
       // Reset the overall error.
       err = 0;
+      error = MatType(target.n_elem, input.n_rows);
 
       // Iterate through the input sequence and perform the feed forward pass.
       for (seqNum = 0; seqNum < input.n_rows; seqNum++)
@@ -581,7 +583,21 @@ class RNN
     OutputLayerType& outputLayer;
 }; // class RNN
 
+//! Network traits for the RNN network.
+template <
+  typename ConnectionTypes,
+  typename OutputLayerType,
+  class PerformanceFunction
+>
+class NetworkTraits<RNN<ConnectionTypes, OutputLayerType, PerformanceFunction> >
+{
+ public:
+  static const bool IsFNN = false; //!< An RNN is not a feed forward network.
+  static const bool IsRNN = true;  //!< An RNN is a recurrent network.
+};
+
 }; // namespace ann
 }; // namespace mlpack
 
 #endif
+



More information about the mlpack-git mailing list