[mlpack-git] master: Add non-training constructor and write Serialize(). (a33bc45)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Fri Sep 11 07:53:10 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/a4d2dc275f6bdc74898386405decc91f072b2465...a33bc45442b3ce8830ea1a3e930c89d05c6dc9c6
>---------------------------------------------------------------
commit a33bc45442b3ce8830ea1a3e930c89d05c6dc9c6
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Sep 9 02:08:03 2015 +0000
Add non-training constructor and write Serialize().
>---------------------------------------------------------------
a33bc45442b3ce8830ea1a3e930c89d05c6dc9c6
src/mlpack/methods/perceptron/perceptron.hpp | 30 ++++++++++++++++++---
src/mlpack/methods/perceptron/perceptron_impl.hpp | 33 +++++++++++++++++++----
src/mlpack/tests/serialization_test.cpp | 33 +++++++++++++++++++++++
3 files changed, 88 insertions(+), 8 deletions(-)
diff --git a/src/mlpack/methods/perceptron/perceptron.hpp b/src/mlpack/methods/perceptron/perceptron.hpp
index 0c4541f..ada0154 100644
--- a/src/mlpack/methods/perceptron/perceptron.hpp
+++ b/src/mlpack/methods/perceptron/perceptron.hpp
@@ -32,6 +32,20 @@ class Perceptron
{
public:
/**
+ * Constructor: create the perceptron with the given number of classes and
+ * initialize the weight matrix, but do not perform any training. (Call the
+ * Train() function to perform training.)
+ *
+ * @param numClasses Number of classes in the dataset.
+ * @param dimensionality Dimensionality of the dataset.
+ * @param maxIterations Maximum number of iterations for the perceptron
+ * learning algorithm.
+ */
+ Perceptron(const size_t numClasses,
+ const size_t dimensionality,
+ const size_t maxIterations = 1000);
+
+ /**
* Constructor: constructs the perceptron by building the weights matrix,
* which is later used in classification. The number of classes should be
* specified separately, and the labels vector should contain values in the
@@ -41,13 +55,13 @@ class Perceptron
* @param data Input, training data.
* @param labels Labels of dataset.
* @param numClasses Number of classes in the dataset.
- * @param iterations Maximum number of iterations for the perceptron learning
- * algorithm.
+ * @param maxIterations Maximum number of iterations for the perceptron
+ * learning algorithm.
*/
Perceptron(const MatType& data,
const arma::Row<size_t>& labels,
const size_t numClasses,
- const size_t maxIterations);
+ const size_t maxIterations = 1000);
/**
* Alternate constructor which copies parameters from an already initiated
@@ -107,6 +121,16 @@ class Perceptron
//! Get the number of classes this perceptron has been trained for.
size_t NumClasses() const { return weights.n_cols; }
+ //! Get the weight matrix.
+ const arma::mat& Weights() const { return weights; }
+ //! Modify the weight matrix. You had better know what you are doing!
+ arma::mat& Weights() { return weights; }
+
+ //! Get the biases.
+ const arma::vec& Biases() const { return biases; }
+ //! Modify the biases. You had better know what you are doing!
+ arma::vec& Biases() { return biases; }
+
private:
//! The maximum number of iterations during training.
size_t maxIterations;
diff --git a/src/mlpack/methods/perceptron/perceptron_impl.hpp b/src/mlpack/methods/perceptron/perceptron_impl.hpp
index dca2515..be6c27e 100644
--- a/src/mlpack/methods/perceptron/perceptron_impl.hpp
+++ b/src/mlpack/methods/perceptron/perceptron_impl.hpp
@@ -13,6 +13,25 @@ namespace mlpack {
namespace perceptron {
/**
+ * Construct the perceptron with the given number of classes and maximum number
+ * of iterations.
+ */
+template<
+ typename LearnPolicy,
+ typename WeightInitializationPolicy,
+ typename MatType
+>
+Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Perceptron(
+ const size_t numClasses,
+ const size_t dimensionality,
+ const size_t maxIterations) :
+ maxIterations(maxIterations)
+{
+ WeightInitializationPolicy wip;
+ wip.Initialize(weights, biases, dimensionality, numClasses);
+}
+
+/**
* Constructor - constructs the perceptron. Or rather, builds the weights
* matrix, which is later used in classification. It adds a bias input vector
* of 1 to the input data to take care of the bias weights.
@@ -34,8 +53,8 @@ Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Perceptron(
const size_t maxIterations) :
maxIterations(maxIterations)
{
- WeightInitializationPolicy WIP;
- WIP.Initialize(weights, biases, data.n_rows, numClasses);
+ WeightInitializationPolicy wip;
+ wip.Initialize(weights, biases, data.n_rows, numClasses);
// Start training.
Train(data, labels);
@@ -65,8 +84,8 @@ Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Perceptron(
maxIterations(other.maxIterations)
{
// Insert a row of ones at the top of the training data set.
- WeightInitializationPolicy WIP;
- WIP.Initialize(weights, biases, data.n_rows, other.NumClasses());
+ WeightInitializationPolicy wip;
+ wip.Initialize(weights, biases, data.n_rows, other.NumClasses());
Train(data, labels, instanceWeights);
}
@@ -175,7 +194,11 @@ void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Serialize(
Archive& ar,
const unsigned int /* version */)
{
- // For now, do nothing.
+ // We just need to serialize the maximum number of iterations, the weights,
+ // and the biases.
+ ar & data::CreateNVP(maxIterations, "maxIterations");
+ ar & data::CreateNVP(weights, "weights");
+ ar & data::CreateNVP(biases, "biases");
}
} // namespace perceptron
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index e1be4ca..abdc4d4 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -22,12 +22,15 @@
#include <mlpack/core/metrics/mahalanobis_distance.hpp>
#include <mlpack/core/tree/binary_space_tree.hpp>
+#include <mlpack/methods/perceptron/perceptron.hpp>
+
using namespace mlpack;
using namespace mlpack::distribution;
using namespace mlpack::regression;
using namespace mlpack::bound;
using namespace mlpack::metric;
using namespace mlpack::tree;
+using namespace mlpack::perceptron;
using namespace arma;
using namespace boost;
using namespace boost::archive;
@@ -687,4 +690,34 @@ BOOST_AUTO_TEST_CASE(BinarySpaceTreeOverwriteTest)
CheckTrees(tree, xmlTree, textTree, binaryTree);
}
+BOOST_AUTO_TEST_CASE(PerceptronTest)
+{
+ // Create a perceptron. Train it randomly. Then check that it hasn't
+ // changed.
+ arma::mat data;
+ data.randu(3, 100);
+ arma::Row<size_t> labels(100);
+ for (size_t i = 0; i < labels.n_elem; ++i)
+ {
+ if (data(1, i) > 0.5)
+ labels[i] = 0;
+ else
+ labels[i] = 1;
+ }
+
+ Perceptron<> p(data, labels, 2, 15);
+
+ Perceptron<> pXml(2, 3), pText(2, 3), pBinary(2, 3);
+ SerializeObjectAll(p, pXml, pText, pBinary);
+
+ // Now check that things are the same.
+ CheckMatrices(p.Weights(), pXml.Weights(), pText.Weights(),
+ pBinary.Weights());
+ CheckMatrices(p.Biases(), pXml.Biases(), pText.Biases(), pBinary.Biases());
+
+ BOOST_REQUIRE_EQUAL(p.MaxIterations(), pXml.MaxIterations());
+ BOOST_REQUIRE_EQUAL(p.MaxIterations(), pText.MaxIterations());
+ BOOST_REQUIRE_EQUAL(p.MaxIterations(), pBinary.MaxIterations());
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git mailing list