[mlpack-git] master: Require specification of number of classes. (16cb76a)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Sep 11 07:52:59 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/a4d2dc275f6bdc74898386405decc91f072b2465...a33bc45442b3ce8830ea1a3e930c89d05c6dc9c6

>---------------------------------------------------------------

commit 16cb76a12b3e8f973e51f38bca97f0b798d3fbd5
Author: Ryan Curtin <ryan at ratml.org>
Date:   Tue Sep 8 13:04:58 2015 +0000

    Require specification of number of classes.


>---------------------------------------------------------------

16cb76a12b3e8f973e51f38bca97f0b798d3fbd5
 src/mlpack/methods/adaboost/adaboost_impl.hpp      |   2 +-
 .../methods/decision_stump/decision_stump.hpp      |   4 +-
 .../methods/decision_stump/decision_stump_impl.hpp |  10 +-
 src/mlpack/methods/perceptron/perceptron.hpp       |  72 +++++++++-----
 src/mlpack/methods/perceptron/perceptron_impl.hpp  | 107 +++++++++++----------
 src/mlpack/methods/perceptron/perceptron_main.cpp  |   2 +-
 src/mlpack/tests/adaboost_test.cpp                 |  40 +++++---
 src/mlpack/tests/perceptron_test.cpp               |  12 +--
 8 files changed, 141 insertions(+), 108 deletions(-)

diff --git a/src/mlpack/methods/adaboost/adaboost_impl.hpp b/src/mlpack/methods/adaboost/adaboost_impl.hpp
index af11586..dc585f3 100644
--- a/src/mlpack/methods/adaboost/adaboost_impl.hpp
+++ b/src/mlpack/methods/adaboost/adaboost_impl.hpp
@@ -93,7 +93,7 @@ AdaBoost<MatType, WeakLearner>::AdaBoost(
     BuildWeightMatrix(D, weights);
 
     // call the other weak learner and train the labels.
-    WeakLearner w(other, tempData, weights, labels);
+    WeakLearner w(other, tempData, labels, weights);
     w.Classify(tempData, predictedLabels);
 
     // Now from predictedLabels, build ht, the weak hypothesis
diff --git a/src/mlpack/methods/decision_stump/decision_stump.hpp b/src/mlpack/methods/decision_stump/decision_stump.hpp
index 831bbca..411f4d1 100644
--- a/src/mlpack/methods/decision_stump/decision_stump.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump.hpp
@@ -67,8 +67,8 @@ class DecisionStump
    */
   DecisionStump(const DecisionStump<>& other,
                 const MatType& data,
-                const arma::rowvec& weights,
-                const arma::Row<size_t>& labels);
+                const arma::Row<size_t>& labels,
+                const arma::rowvec& weights);
 
   //! Access the splitting attribute.
   int SplitAttribute() const { return splitAttribute; }
diff --git a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
index d18365c..568e296 100644
--- a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
@@ -129,12 +129,10 @@ void DecisionStump<MatType>::Classify(const MatType& test,
  * @param isWeight Whether we need to run a weighted Decision Stump.
  */
 template <typename MatType>
-DecisionStump<MatType>::DecisionStump(
-                        const DecisionStump<>& other,
-                        const MatType& data,
-                        const arma::rowvec& weights,
-                        const arma::Row<size_t>& labels
-                        )
+DecisionStump<MatType>::DecisionStump(const DecisionStump<>& other,
+                                      const MatType& data,
+                                      const arma::Row<size_t>& labels,
+                                      const arma::rowvec& weights)
 {
   numClass = other.numClass;
   bucketSize = other.bucketSize;
diff --git a/src/mlpack/methods/perceptron/perceptron.hpp b/src/mlpack/methods/perceptron/perceptron.hpp
index ef0d3c6..0c4541f 100644
--- a/src/mlpack/methods/perceptron/perceptron.hpp
+++ b/src/mlpack/methods/perceptron/perceptron.hpp
@@ -32,28 +32,22 @@ class Perceptron
 {
  public:
   /**
-   * Constructor - constructs the perceptron by building the weights
-   * matrix, which is later used in Classification.  It adds a bias input vector
-   * of 1 to the input data to take care of the bias weights.
+   * Constructor: constructs the perceptron by building the weights matrix,
+   * which is later used in classification.  The number of classes should be
+   * specified separately, and the labels vector should contain values in the
+   * range [0, numClasses - 1].  The data::NormalizeLabels() function can be
+   * used if the labels vector does not contain values in the required range.
    *
    * @param data Input, training data.
    * @param labels Labels of dataset.
+   * @param numClasses Number of classes in the dataset.
    * @param iterations Maximum number of iterations for the perceptron learning
    *     algorithm.
    */
   Perceptron(const MatType& data,
              const arma::Row<size_t>& labels,
-             const int maxIterations);
-
-  /**
-   * Classification function. After training, use the weights matrix to
-   * classify test, and put the predicted classes in predictedLabels.
-   *
-   * @param test Testing data or data to classify.
-   * @param predictedLabels Vector to store the predicted classes after
-   *     classifying test.
-   */
-  void Classify(const MatType& test, arma::Row<size_t>& predictedLabels);
+             const size_t numClasses,
+             const size_t maxIterations);
 
   /**
    * Alternate constructor which copies parameters from an already initiated
@@ -67,8 +61,37 @@ class Perceptron
    */
   Perceptron(const Perceptron<>& other,
              const MatType& data,
-             const arma::rowvec& D,
-             const arma::Row<size_t>& labels);
+             const arma::Row<size_t>& labels,
+             const arma::rowvec& instanceWeights);
+
+  /**
+   * Train the perceptron on the given data for up to the maximum number of
+   * iterations (specified in the constructor or through MaxIterations()).  A
+   * single iteration corresponds to a single pass through the data, so if you
+   * want to pass through the dataset only once, set MaxIterations() to 1.
+   *
+   * This training does not reset the model weights, so you can call Train() on
+   * multiple datasets sequentially.
+   *
+   * @param data Dataset on which training should be performed.
+   * @param labels Labels of the dataset.  Make sure that these labels don't
+   *      contain any values greater than or equal to NumClasses()!
+   * @param instanceWeights Cost vector. Stores the cost of mispredicting
+   *      instances.  This is useful for boosting.
+   */
+  void Train(const MatType& data,
+             const arma::Row<size_t>& labels,
+             const arma::rowvec& instanceWeights = arma::rowvec());
+
+  /**
+   * Classification function. After training, use the weights matrix to
+   * classify test, and put the predicted classes in predictedLabels.
+   *
+   * @param test Testing data or data to classify.
+   * @param predictedLabels Vector to store the predicted classes after
+   *     classifying test.
+   */
+  void Classify(const MatType& test, arma::Row<size_t>& predictedLabels);
 
   /**
    * Serialize the perceptron.
@@ -76,6 +99,14 @@ class Perceptron
   template<typename Archive>
   void Serialize(Archive& ar, const unsigned int /* version */);
 
+  //! Get the maximum number of iterations.
+  size_t MaxIterations() const { return maxIterations; }
+  //! Modify the maximum number of iterations.
+  size_t& MaxIterations() { return maxIterations; }
+
+  //! Get the number of classes this perceptron has been trained for.
+  size_t NumClasses() const { return weights.n_cols; }
+
 private:
   //! The maximum number of iterations during training.
   size_t maxIterations;
@@ -90,15 +121,6 @@ private:
 
   //! The biases for each class.
   arma::vec biases;
-
-  /**
-   * Training Function. It trains on trainData using the cost matrix D
-   *
-   * @param D Cost matrix. Stores the cost of mispredicting instances
-   */
-  void Train(const MatType& data,
-             const arma::Row<size_t>& labels,
-             const arma::rowvec& D = arma::rowvec());
 };
 
 } // namespace perceptron
diff --git a/src/mlpack/methods/perceptron/perceptron_impl.hpp b/src/mlpack/methods/perceptron/perceptron_impl.hpp
index 641e001..dca2515 100644
--- a/src/mlpack/methods/perceptron/perceptron_impl.hpp
+++ b/src/mlpack/methods/perceptron/perceptron_impl.hpp
@@ -30,24 +30,54 @@ template<
 Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Perceptron(
     const MatType& data,
     const arma::Row<size_t>& labels,
-    const int maxIterations) :
+    const size_t numClasses,
+    const size_t maxIterations) :
     maxIterations(maxIterations)
 {
   WeightInitializationPolicy WIP;
-  WIP.Initialize(weights, biases, data.n_rows, arma::max(labels) + 1);
+  WIP.Initialize(weights, biases, data.n_rows, numClasses);
 
   // Start training.
   Train(data, labels);
 }
 
+/**
+ * Alternate constructor which copies parameters from an already initiated
+ * perceptron.
+ *
+ * @param other The already-initialized Perceptron object from which the
+ *      parameters are copied.
+ * @param data The data on which to train this Perceptron object.
+ * @param labels The labels of data.
+ * @param instanceWeights Weight vector to use while training.  For boosting
+ *      purposes.
+ */
+template<
+    typename LearnPolicy,
+    typename WeightInitializationPolicy,
+    typename MatType
+>
+Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Perceptron(
+    const Perceptron<>& other,
+    const MatType& data,
+    const arma::Row<size_t>& labels,
+    const arma::rowvec& instanceWeights) :
+    maxIterations(other.maxIterations)
+{
+  // Initialize weights and biases using the number of classes in other.
+  WeightInitializationPolicy WIP;
+  WIP.Initialize(weights, biases, data.n_rows, other.NumClasses());
+
+  Train(data, labels, instanceWeights);
+}
 
 /**
  * Classification function. After training, use the weights matrix to classify
  * test, and put the predicted classes in predictedLabels.
  *
- * @param test testing data or data to classify.
- * @param predictedLabels vector to store the predicted classes after
- *      classifying test
+ * @param test Testing data or data to classify.
+ * @param predictedLabels Vector to store the predicted classes after
+ *      classifying test.
  */
 template<
     typename LearnPolicy,
@@ -71,50 +101,13 @@ void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Classify(
 }
 
 /**
- * Alternate constructor which copies parameters from an already initiated
- * perceptron.
- *
- * @param other The other initiated Perceptron object from which we copy the
- *     values from.
- * @param data The data on which to train this Perceptron object on.
- * @param D Weight vector to use while training. For boosting purposes.
- * @param labels The labels of data.
- */
-template<
-    typename LearnPolicy,
-    typename WeightInitializationPolicy,
-    typename MatType
->
-Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Perceptron(
-    const Perceptron<>& other,
-    const MatType& data,
-    const arma::rowvec& D,
-    const arma::Row<size_t>& labels) :
-    maxIterations(other.maxIterations)
-{
-  // Insert a row of ones at the top of the training data set.
-  WeightInitializationPolicy WIP;
-  WIP.Initialize(weights, biases, data.n_rows, arma::max(labels) + 1);
-
-  Train(data, labels, D);
-}
-
-//! Serialize the perceptron.
-template<typename LearnPolicy,
-         typename WeightInitializationPolicy,
-         typename MatType>
-template<typename Archive>
-void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Serialize(
-    Archive& ar,
-    const unsigned int /* version */)
-{
-  // For now, do nothing.
-}
-
-/**
- * Training Function. It trains on trainData using the cost matrix D
+ * Training function.  It trains on data using the (optional) instance weight
+ * vector instanceWeights.
  *
- * @param D Cost matrix. Stores the cost of mispredicting instances
+ * @param data Data to train on.
+ * @param labels Labels of data.
+ * @param instanceWeights Cost vector. Stores the cost of mispredicting
+ *      instances.  This is useful for boosting.
  */
 template<
     typename LearnPolicy,
@@ -124,7 +117,7 @@ template<
 void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Train(
     const MatType& data,
     const arma::Row<size_t>& labels,
-    const arma::rowvec& D)
+    const arma::rowvec& instanceWeights)
 {
   size_t j, i = 0;
   bool converged = false;
@@ -134,7 +127,7 @@ void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Train(
 
   LearnPolicy LP;
 
-  const bool hasWeights = (D.n_elem > 0);
+  const bool hasWeights = (instanceWeights.n_elem > 0);
 
   while ((i < maxIterations) && (!converged))
   {
@@ -164,7 +157,7 @@ void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Train(
         // the correct class.
         if (hasWeights)
           LP.UpdateWeights(data.col(j), weights, biases, maxIndexRow, tempLabel,
-              D(j));
+              instanceWeights(j));
         else
           LP.UpdateWeights(data.col(j), weights, biases, maxIndexRow,
               tempLabel);
@@ -173,6 +166,18 @@ void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Train(
   }
 }
 
+//! Serialize the perceptron.
+template<typename LearnPolicy,
+         typename WeightInitializationPolicy,
+         typename MatType>
+template<typename Archive>
+void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Serialize(
+    Archive& ar,
+    const unsigned int /* version */)
+{
+  // For now, do nothing.
+}
+
 } // namespace perceptron
 } // namespace mlpack
 
diff --git a/src/mlpack/methods/perceptron/perceptron_main.cpp b/src/mlpack/methods/perceptron/perceptron_main.cpp
index d846d66..49fec08 100644
--- a/src/mlpack/methods/perceptron/perceptron_main.cpp
+++ b/src/mlpack/methods/perceptron/perceptron_main.cpp
@@ -105,7 +105,7 @@ int main(int argc, char** argv)
 
   // Create and train the classifier.
   Timer::Start("Training");
-  Perceptron<> p(trainingData, labels.t(), iterations);
+  Perceptron<> p(trainingData, labels.t(), max(labels) + 1, iterations);
   Timer::Stop("Training");
 
   // Time the running of the Perceptron Classifier.
diff --git a/src/mlpack/tests/adaboost_test.cpp b/src/mlpack/tests/adaboost_test.cpp
index 9a66ed6..ef8269c 100644
--- a/src/mlpack/tests/adaboost_test.cpp
+++ b/src/mlpack/tests/adaboost_test.cpp
@@ -38,9 +38,10 @@ BOOST_AUTO_TEST_CASE(HammingLossBoundIris)
 
   // Define your own weak learner, perceptron in this case.
   // Run the perceptron for perceptron_iter iterations.
-  int perceptron_iter = 400;
+  int perceptronIter = 400;
 
-  perceptron::Perceptron<> p(inputData, labels.row(0), perceptron_iter);
+  perceptron::Perceptron<> p(inputData, labels.row(0), max(labels.row(0)) + 1,
+      perceptronIter);
 
   // Define parameters for the adaboost
   int iterations = 100;
@@ -79,10 +80,11 @@ BOOST_AUTO_TEST_CASE(WeakLearnerErrorIris)
 
   // Define your own weak learner, perceptron in this case.
   // Run the perceptron for perceptron_iter iterations.
-  int perceptron_iter = 400;
+  int perceptronIter = 400;
 
   arma::Row<size_t> perceptronPrediction(labels.n_cols);
-  perceptron::Perceptron<> p(inputData, labels.row(0), perceptron_iter);
+  perceptron::Perceptron<> p(inputData, labels.row(0), max(labels.row(0)) + 1,
+      perceptronIter);
   p.Classify(inputData, perceptronPrediction);
 
   int countWeakLearnerError = 0;
@@ -126,9 +128,10 @@ BOOST_AUTO_TEST_CASE(HammingLossBoundVertebralColumn)
 
   // Define your own weak learner, perceptron in this case.
   // Run the perceptron for perceptron_iter iterations.
-  int perceptron_iter = 800;
+  int perceptronIter = 800;
 
-  perceptron::Perceptron<> p(inputData, labels.row(0), perceptron_iter);
+  perceptron::Perceptron<> p(inputData, labels.row(0), max(labels.row(0)) + 1,
+      perceptronIter);
 
   // Define parameters for the adaboost
   int iterations = 50;
@@ -167,10 +170,11 @@ BOOST_AUTO_TEST_CASE(WeakLearnerErrorVertebralColumn)
 
   // Define your own weak learner, perceptron in this case.
   // Run the perceptron for perceptron_iter iterations.
-  int perceptron_iter = 800;
+  int perceptronIter = 800;
 
   arma::Row<size_t> perceptronPrediction(labels.n_cols);
-  perceptron::Perceptron<> p(inputData, labels.row(0), perceptron_iter);
+  perceptron::Perceptron<> p(inputData, labels.row(0), max(labels.row(0)) + 1,
+      perceptronIter);
   p.Classify(inputData, perceptronPrediction);
 
   int countWeakLearnerError = 0;
@@ -215,9 +219,10 @@ BOOST_AUTO_TEST_CASE(HammingLossBoundNonLinearSepData)
 
   // Define your own weak learner, perceptron in this case.
   // Run the perceptron for perceptron_iter iterations.
-  int perceptron_iter = 800;
+  int perceptronIter = 800;
 
-  perceptron::Perceptron<> p(inputData, labels.row(0), perceptron_iter);
+  perceptron::Perceptron<> p(inputData, labels.row(0), max(labels.row(0)) + 1,
+      perceptronIter);
 
   // Define parameters for the adaboost
   int iterations = 50;
@@ -256,10 +261,11 @@ BOOST_AUTO_TEST_CASE(WeakLearnerErrorNonLinearSepData)
 
   // Define your own weak learner, perceptron in this case.
   // Run the perceptron for perceptron_iter iterations.
-  int perceptron_iter = 800;
+  int perceptronIter = 800;
 
   arma::Row<size_t> perceptronPrediction(labels.n_cols);
-  perceptron::Perceptron<> p(inputData, labels.row(0), perceptron_iter);
+  perceptron::Perceptron<> p(inputData, labels.row(0), max(labels.row(0)) + 1,
+      perceptronIter);
   p.Classify(inputData, perceptronPrediction);
 
   int countWeakLearnerError = 0;
@@ -596,7 +602,7 @@ BOOST_AUTO_TEST_CASE(ClassifyTest_VERTEBRALCOL)
   // Define your own weak learner, perceptron in this case.
   // Run the perceptron for perceptron_iter iterations.
 
-  int perceptron_iter = 1000;
+  int perceptronIter = 1000;
 
   arma::mat testData;
 
@@ -609,7 +615,8 @@ BOOST_AUTO_TEST_CASE(ClassifyTest_VERTEBRALCOL)
     BOOST_FAIL("Cannot load labels for vc2_test_labels.txt");
 
   arma::Row<size_t> perceptronPrediction(labels.n_cols);
-  perceptron::Perceptron<> p(inputData, labels.row(0), perceptron_iter);
+  perceptron::Perceptron<> p(inputData, labels.row(0), max(labels.row(0)) + 1,
+      perceptronIter);
   p.Classify(inputData, perceptronPrediction);
 
   // Define parameters for the adaboost
@@ -714,9 +721,10 @@ BOOST_AUTO_TEST_CASE(ClassifyTest_IRIS)
 
   // Define your own weak learner, perceptron in this case.
   // Run the perceptron for perceptron_iter iterations.
-  int perceptron_iter = 800;
+  int perceptronIter = 800;
 
-  perceptron::Perceptron<> p(inputData, labels.row(0), perceptron_iter);
+  perceptron::Perceptron<> p(inputData, labels.row(0), max(labels.row(0)) + 1,
+      perceptronIter);
 
   // Define parameters for the adaboost
   int iterations = 50;
diff --git a/src/mlpack/tests/perceptron_test.cpp b/src/mlpack/tests/perceptron_test.cpp
index 82445d1..8f2de1e 100644
--- a/src/mlpack/tests/perceptron_test.cpp
+++ b/src/mlpack/tests/perceptron_test.cpp
@@ -28,7 +28,7 @@ BOOST_AUTO_TEST_CASE(And)
   Mat<size_t> labels;
   labels << 0 << 0 << 1 << 0;
 
-  Perceptron<> p(trainData, labels.row(0), 1000);
+  Perceptron<> p(trainData, labels.row(0), 2, 1000);
 
   mat testData;
   testData << 0 << 1 << 1 << 0 << endr
@@ -54,7 +54,7 @@ BOOST_AUTO_TEST_CASE(Or)
   Mat<size_t> labels;
   labels << 1 << 1 << 1 << 0;
 
-  Perceptron<> p(trainData, labels.row(0), 1000);
+  Perceptron<> p(trainData, labels.row(0), 2, 1000);
 
   mat testData;
   testData << 0 << 1 << 1 << 0 << endr
@@ -81,7 +81,7 @@ BOOST_AUTO_TEST_CASE(Random3)
   Mat<size_t> labels;
   labels << 0 << 0 << 0 << 1 << 1 << 1 << 2 << 2 << 2;
 
-  Perceptron<> p(trainData, labels.row(0), 1000);
+  Perceptron<> p(trainData, labels.row(0), 3, 1000);
 
   mat testData;
   testData << 0 << 1 << 1 << endr
@@ -107,7 +107,7 @@ BOOST_AUTO_TEST_CASE(TwoPoints)
   Mat<size_t> labels;
   labels << 0 << 1;
 
-  Perceptron<> p(trainData, labels.row(0), 1000);
+  Perceptron<> p(trainData, labels.row(0), 2, 1000);
 
   mat testData;
   testData << 0 << 1 << endr
@@ -135,7 +135,7 @@ BOOST_AUTO_TEST_CASE(NonLinearlySeparableDataset)
   labels << 0 << 0 << 0 << 1 << 0 << 1 << 1 << 1
          << 0 << 0 << 0 << 1 << 0 << 1 << 1 << 1;
 
-  Perceptron<> p(trainData, labels.row(0), 1000);
+  Perceptron<> p(trainData, labels.row(0), 2, 1000);
 
   mat testData;
   testData << 3 << 4   << 5   << 6   << endr
@@ -161,7 +161,7 @@ BOOST_AUTO_TEST_CASE(SecondaryConstructor)
   labels << 0 << 0 << 0 << 1 << 0 << 1 << 1 << 1
          << 0 << 0 << 0 << 1 << 0 << 1 << 1 << 1;
 
-  Perceptron<> p1(trainData, labels.row(0), 1000);
+  Perceptron<> p1(trainData, labels.row(0), 2, 1000);
 
   Perceptron<> p2(p1);
 }



More information about the mlpack-git mailing list