[mlpack-git] master,mlpack-1.0.x: Perceptron Added (501ef6b)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:49:27 EST 2015


Repository : https://github.com/mlpack/mlpack

On branches: master,mlpack-1.0.x
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit 501ef6b14d079f5e7d5e271343753eb42b42b967
Author: Udit Saxena <saxena.udit at gmail.com>
Date:   Tue Jun 24 06:58:46 2014 +0000

    Perceptron Added


>---------------------------------------------------------------

501ef6b14d079f5e7d5e271343753eb42b42b967
 src/mlpack/methods/CMakeLists.txt                  |   1 +
 .../{decision_stump => perceptron}/CMakeLists.txt  |  15 ++-
 .../InitializationMethods}/CMakeLists.txt          |   2 +-
 .../InitializationMethods/random_init.hpp          |  31 +++++
 .../perceptron/InitializationMethods/zero_init.hpp |  34 +++++
 .../LearnPolicy}/CMakeLists.txt                    |   3 +-
 .../perceptron/LearnPolicy/SimpleWeightUpdate.hpp  |  53 ++++++++
 src/mlpack/methods/perceptron/perceptron.hpp       |  86 ++++++++++++
 src/mlpack/methods/perceptron/perceptron_impl.cpp  | 118 ++++++++++++++++
 .../perceptron_main.cpp}                           |  43 +++---
 src/mlpack/tests/CMakeLists.txt                    |   1 +
 src/mlpack/tests/perceptron_test.cpp               | 150 +++++++++++++++++++++
 12 files changed, 502 insertions(+), 35 deletions(-)

diff --git a/src/mlpack/methods/CMakeLists.txt b/src/mlpack/methods/CMakeLists.txt
index b930aaa..d9eea39 100644
--- a/src/mlpack/methods/CMakeLists.txt
+++ b/src/mlpack/methods/CMakeLists.txt
@@ -22,6 +22,7 @@ set(DIRS
   nmf
 #  lmf
   pca
+  perceptron
   radical
   range_search
   rann
diff --git a/src/mlpack/methods/decision_stump/CMakeLists.txt b/src/mlpack/methods/perceptron/CMakeLists.txt
similarity index 67%
copy from src/mlpack/methods/decision_stump/CMakeLists.txt
copy to src/mlpack/methods/perceptron/CMakeLists.txt
index 0bc9b8b..c25c549 100644
--- a/src/mlpack/methods/decision_stump/CMakeLists.txt
+++ b/src/mlpack/methods/perceptron/CMakeLists.txt
@@ -3,8 +3,8 @@ cmake_minimum_required(VERSION 2.8)
 # Define the files we need to compile.
 # Anything not in this list will not be compiled into MLPACK.
 set(SOURCES
-  decision_stump.hpp
-  decision_stump_impl.hpp
+  perceptron.hpp
+  perceptron_impl.cpp
 )
 
 # Add directory name to sources.
@@ -16,11 +16,14 @@ endforeach()
 # the parent scope).
 set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
 
-add_executable(dec_stu
-  decision_stump_main.cpp
+add_subdirectory(InitializationMethods)
+add_subdirectory(LearnPolicy)
+
+add_executable(percep
+  perceptron_main.cpp
 )
-target_link_libraries(dec_stu
+target_link_libraries(percep
   mlpack
 )
 
-install(TARGETS dec_stu RUNTIME DESTINATION bin)
+install(TARGETS percep RUNTIME DESTINATION bin)
diff --git a/src/mlpack/methods/amf/init_rules/CMakeLists.txt b/src/mlpack/methods/perceptron/InitializationMethods/CMakeLists.txt
similarity index 95%
copy from src/mlpack/methods/amf/init_rules/CMakeLists.txt
copy to src/mlpack/methods/perceptron/InitializationMethods/CMakeLists.txt
index a31d281..d5d9c31 100644
--- a/src/mlpack/methods/amf/init_rules/CMakeLists.txt
+++ b/src/mlpack/methods/perceptron/InitializationMethods/CMakeLists.txt
@@ -2,7 +2,7 @@
 # Anything not in this list will not be compiled into MLPACK.
 set(SOURCES
   random_init.hpp
-  random_acol_init.hpp
+  zero_init.hpp
 )
 
 # Add directory name to sources.
diff --git a/src/mlpack/methods/perceptron/InitializationMethods/random_init.hpp b/src/mlpack/methods/perceptron/InitializationMethods/random_init.hpp
new file mode 100644
index 0000000..7cdeb19
--- /dev/null
+++ b/src/mlpack/methods/perceptron/InitializationMethods/random_init.hpp
@@ -0,0 +1,31 @@
+/*
+ *  @file: random_init.hpp
+ *  @author: Udit Saxena
+ *
+ */
+
+#ifndef _MLPACK_METHOS_PERCEPTRON_RANDOMINIT
+#define _MLPACK_METHOS_PERCEPTRON_RANDOMINIT
+
+#include <mlpack/core.hpp>
+/*
+This class initializes the weightVectors matrix with entries drawn
+uniformly at random from [0, 1] (via arma::randu).
+*/
+namespace mlpack {
+namespace perceptron {
+  class RandomInitialization
+  {
+  public:
+    RandomInitialization()
+    { }
+
+    inline static void initialize(arma::mat& W, size_t row, size_t col)
+    {
+      W = arma::randu<arma::mat>(row,col);
+    }
+  }; // class RandomInitialization
+}; // namespace perceptron
+}; // namespace mlpack
+
+#endif
\ No newline at end of file
diff --git a/src/mlpack/methods/perceptron/InitializationMethods/zero_init.hpp b/src/mlpack/methods/perceptron/InitializationMethods/zero_init.hpp
new file mode 100644
index 0000000..7115c81
--- /dev/null
+++ b/src/mlpack/methods/perceptron/InitializationMethods/zero_init.hpp
@@ -0,0 +1,34 @@
+/*
+ *  @file: zero_init.hpp
+ *  @author: Udit Saxena
+ *
+ */
+
+#ifndef _MLPACK_METHOS_PERCEPTRON_ZEROINIT
+#define _MLPACK_METHOS_PERCEPTRON_ZEROINIT
+
+#include <mlpack/core.hpp>
+/*
+This class is used to initialize the matrix
+weightVectors to zero.
+*/
+namespace mlpack {
+namespace perceptron {
+  class ZeroInitialization
+  {
+  public:
+    ZeroInitialization()
+    { }
+
+    inline static void initialize(arma::mat& W, size_t row, size_t col)
+    {
+      arma::mat tempWeights(row, col);
+      tempWeights.fill(0.0);
+
+      W = tempWeights;
+    }
+  }; // class ZeroInitialization
+}; // namespace perceptron
+}; // namespace mlpack
+
+#endif
\ No newline at end of file
diff --git a/src/mlpack/methods/amf/init_rules/CMakeLists.txt b/src/mlpack/methods/perceptron/LearnPolicy/CMakeLists.txt
similarity index 91%
copy from src/mlpack/methods/amf/init_rules/CMakeLists.txt
copy to src/mlpack/methods/perceptron/LearnPolicy/CMakeLists.txt
index a31d281..a07bc01 100644
--- a/src/mlpack/methods/amf/init_rules/CMakeLists.txt
+++ b/src/mlpack/methods/perceptron/LearnPolicy/CMakeLists.txt
@@ -1,8 +1,7 @@
 # Define the files we need to compile
 # Anything not in this list will not be compiled into MLPACK.
 set(SOURCES
-  random_init.hpp
-  random_acol_init.hpp
+  SimpleWeightUpdate.hpp
 )
 
 # Add directory name to sources.
diff --git a/src/mlpack/methods/perceptron/LearnPolicy/SimpleWeightUpdate.hpp b/src/mlpack/methods/perceptron/LearnPolicy/SimpleWeightUpdate.hpp
new file mode 100644
index 0000000..893ed0d
--- /dev/null
+++ b/src/mlpack/methods/perceptron/LearnPolicy/SimpleWeightUpdate.hpp
@@ -0,0 +1,53 @@
+/*
+ *  @file: SimpleWeightUpdate.hpp
+ *  @author: Udit Saxena
+ *
+ */
+
+#ifndef _MLPACK_METHOD_PERCEPTRON_LEARN_SIMPLEWEIGHTUPDATE
+#define _MLPACK_METHOD_PERCEPTRON_LEARN_SIMPLEWEIGHTUPDATE
+
+#include <mlpack/core.hpp>
+/*
+This class is used to update the weightVectors matrix according to
+the simple update rule as discussed by Rosenblatt:
+  if a vector x has been incorrectly classified under the weight vector w,
+  then w = w - x
+  and  w'= w'+ x
+  where w' is the weight vector of the class that x actually belongs to.
+*/
+namespace mlpack {
+namespace perceptron {
+
+class SimpleWeightUpdate 
+{
+public:
+  SimpleWeightUpdate()
+  { }
+  /*
+  This function is called to update the weightVectors matrix. 
+  It decreases the weights of the incorrectly classified class while
+  increasing the weight of the correct class it should have been classified to.
+  
+  @param: trainData - the training dataset (one instance per column).
+  @param: weightVectors - matrix of weight vectors (one row per class).
+  @param: rowIndex - row of weightVectors for the wrongly predicted class.
+  @param: labelIndex - column of trainData holding the misclassified instance.
+  @param: vectorIndex - row of weightVectors for the correct class.
+ */
+  void UpdateWeights(const arma::mat& trainData, arma::mat& weightVectors,
+                     size_t labelIndex, size_t vectorIndex, size_t rowIndex )
+  {
+    arma::mat instance = trainData.col(labelIndex);
+  
+    weightVectors.row(rowIndex) = weightVectors.row(rowIndex) - 
+                               instance.t();
+
+    weightVectors.row(vectorIndex) = weightVectors.row(vectorIndex) + 
+                                 instance.t();
+  }
+};
+}; // namespace perceptron
+}; // namespace mlpack
+
+#endif
\ No newline at end of file
diff --git a/src/mlpack/methods/perceptron/perceptron.hpp b/src/mlpack/methods/perceptron/perceptron.hpp
new file mode 100644
index 0000000..7d26875
--- /dev/null
+++ b/src/mlpack/methods/perceptron/perceptron.hpp
@@ -0,0 +1,86 @@
+/*
+ * @file: perceptron.hpp
+ * @author: Udit Saxena
+ *
+ *
+ * Definition of Perceptron
+ */
+
+#ifndef _MLPACK_METHODS_PERCEPTRON_HPP
+#define _MLPACK_METHODS_PERCEPTRON_HPP
+
+#include <mlpack/core.hpp>
+#include "InitializationMethods/zero_init.hpp"
+#include "InitializationMethods/random_init.hpp"
+#include "LearnPolicy/SimpleWeightUpdate.hpp"
+
+
+namespace mlpack {
+namespace perceptron {
+
+template <typename LearnPolicy = SimpleWeightUpdate, 
+          typename WeightInitializationPolicy = ZeroInitialization, 
+          typename MatType = arma::mat>
+class Perceptron
+{
+  /*
+  This class implements a simple perceptron i.e. a single layer 
+  neural network. It converges if the supplied training dataset is 
+  linearly separable.
+
+  LearnPolicy: Options of SimpleWeightUpdate and GradientDescent.
+  WeightInitializationPolicy: Option of ZeroInitialization and 
+                              RandomInitialization.
+  */
+public:
+  /*
+  Constructor - Constructs the perceptron. Or rather, builds the weightVectors
+  matrix, which is later used in Classification. 
+  It adds a bias input vector of 1 to the input data to take care of the bias
+  weights.
+
+  @param: data - Input, training data.
+  @param: labels - Labels of dataset.
+  @param: iterations - maximum number of iterations the perceptron
+                       learn algorithm is to be run.
+  */
+  Perceptron(const MatType& data, const arma::Row<size_t>& labels, int iterations);
+
+  /*
+  Classification function. After training, use the weightVectors matrix to 
+  classify test, and put the predicted classes in predictedLabels.
+
+  @param: test - testing data or data to classify. 
+  @param: predictedLabels - vector to store the predicted classes after
+                            classifying test
+  */
+  void Classify(const MatType& test, arma::Row<size_t>& predictedLabels);
+
+private:
+  
+  /* Stores the class labels for the input data*/
+  arma::Row<size_t> classLabels;
+
+  /* Stores the weight vectors for each of the input class labels. */
+  arma::mat weightVectors;
+
+  /* Stores the training data to be used later on in UpdateWeights.*/
+  arma::mat trainData;
+
+  /*
+  This function is called by the constructor to update the weightVectors
+  matrix. It decreases the weights of the incorrectly classified class while
+  increasing the weight of the correct class it should have been classified to.
+
+  @param: rowIndex - index of the row which has been incorrectly predicted.
+  @param: labelIndex - index of the vector in trainData.
+  @param: vectorIndex - index of the class which should have been predicted.
+  */
+  // void UpdateWeights(size_t rowIndex, size_t labelIndex, size_t vectorIndex);
+};
+} // namespace perceptron
+} // namespace mlpack
+
+#include "perceptron_impl.cpp"
+
+#endif
\ No newline at end of file
diff --git a/src/mlpack/methods/perceptron/perceptron_impl.cpp b/src/mlpack/methods/perceptron/perceptron_impl.cpp
new file mode 100644
index 0000000..b29c722
--- /dev/null
+++ b/src/mlpack/methods/perceptron/perceptron_impl.cpp
@@ -0,0 +1,118 @@
+/*
+ *  @file: perceptron_impl.cpp
+ *  @author: Udit Saxena
+ *
+ */
+
+#ifndef _MLPACK_METHODS_PERCEPTRON_IMPL_CPP
+#define _MLPACK_METHODS_PERCEPTRON_IMPL_CPP
+
+#include "perceptron.hpp"
+
+namespace mlpack {
+namespace perceptron {
+
+/*
+  Constructor - Constructs the perceptron. Or rather, builds the weightVectors
+  matrix, which is later used in Classification. 
+  It adds a bias input vector of 1 to the input data to take care of the bias
+  weights.
+
+  @param: data - Input, training data.
+  @param: labels - Labels of dataset.
+  @param: iterations - maximum number of iterations the perceptron
+                       learn algorithm is to be run.
+*/
+template <typename LearnPolicy, typename WeightInitializationPolicy, typename MatType>
+Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Perceptron(const MatType& data,
+                                const arma::Row<size_t>& labels, int iterations)
+{
+  arma::Row<size_t> uniqueLabels = arma::unique(labels);
+
+  WeightInitializationPolicy WIP;
+  WIP.initialize(weightVectors, uniqueLabels.n_elem, data.n_rows + 1);
+  
+  // Start training.
+  classLabels = labels; 
+
+  trainData = data;
+  // inserting a row of 1's at the top of the training data set.
+  MatType zOnes(1, data.n_cols);
+  zOnes.fill(1);
+  trainData.insert_rows(0, zOnes);
+
+  int j, i = 0, converged = 0;
+  size_t tempLabel; 
+  arma::uword maxIndexRow, maxIndexCol;
+  double maxVal;
+  arma::mat tempLabelMat;
+
+  LearnPolicy LP;
+
+  while ((i < iterations) && (!converged))
+  {
+    // This outer loop is for each iteration, 
+    // and we use the 'converged' variable for noting whether or not
+    // convergence has been reached.
+    i++;
+    converged = 1;
+
+    // Now this inner loop is for going through the dataset in each iteration
+    for (j = 0; j < data.n_cols; j++)
+    {
+      // Multiplying for each variable and checking 
+      // whether the current weight vector correctly classifies this.
+      tempLabelMat = weightVectors * trainData.col(j);
+
+      maxVal = tempLabelMat.max(maxIndexRow, maxIndexCol);
+      maxVal *= 2;
+      //checking whether prediction is correct.
+      if(maxIndexRow != classLabels(0,j))
+      {
+        // due to incorrect prediction, convergence set to 0
+        converged = 0;
+        tempLabel = labels(0,j);
+        // Pass maxIndexRow so the wrongly predicted class's weights are reduced,
+        // j to select the training instance used for the update, and
+        // tempLabel to identify the correct class whose weights are increased.
+        LP.UpdateWeights(trainData, weightVectors, j, tempLabel, maxIndexRow);
+      }
+    }
+  }
+}
+
+/*
+  Classification function. After training, use the weightVectors matrix to 
+  classify test, and put the predicted classes in predictedLabels.
+
+  @param: test - testing data or data to classify. 
+  @param: predictedLabels - vector to store the predicted classes after
+                            classifying test
+ */
+template <typename LearnPolicy, typename WeightInitializationPolicy, typename MatType>
+void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Classify(
+                const MatType& test, arma::Row<size_t>& predictedLabels)
+{
+  int i;
+  arma::mat tempLabelMat;
+  arma::uword maxIndexRow, maxIndexCol;
+  double maxVal;
+  MatType testData = test;
+  
+  MatType zOnes(1, test.n_cols);
+  zOnes.fill(1);
+  testData.insert_rows(0, zOnes);
+  
+  for (i = 0; i < test.n_cols; i++)
+  {
+    tempLabelMat = weightVectors * testData.col(i);
+    maxVal = tempLabelMat.max(maxIndexRow, maxIndexCol);
+    maxVal *= 2;
+    predictedLabels(0,i) = maxIndexRow;
+  }
+}
+
+}; // namespace perceptron
+}; // namespace mlpack
+
+#endif
\ No newline at end of file
diff --git a/src/mlpack/methods/decision_stump/decision_stump_main.cpp b/src/mlpack/methods/perceptron/perceptron_main.cpp
similarity index 62%
copy from src/mlpack/methods/decision_stump/decision_stump_main.cpp
copy to src/mlpack/methods/perceptron/perceptron_main.cpp
index 4c998d5..a6082d7 100644
--- a/src/mlpack/methods/decision_stump/decision_stump_main.cpp
+++ b/src/mlpack/methods/perceptron/perceptron_main.cpp
@@ -1,34 +1,31 @@
 /*
+ * @file: perceptron_main.cpp
  * @author: Udit Saxena
- * @file: decision_stump_main.cpp
  *
- * Main executable for the decision stump.
+ *
  */
 
 #include <mlpack/core.hpp>
-#include "decision_stump.hpp"
+#include "perceptron.hpp"
 
 using namespace mlpack;
-using namespace mlpack::decision_stump;
+using namespace mlpack::perceptron;
 using namespace std;
 using namespace arma;
 
-PROGRAM_INFO("Decision Stump","This program implements a decision stump, "
-    "a single level decision tree, on the given training data set. "
-    "Default size of buckets is 6");
+PROGRAM_INFO("","");
 
-// necessary parameters
+//necessary parameters
 PARAM_STRING_REQ("train_file", "A file containing the training set.", "tr");
 PARAM_STRING_REQ("labels_file", "A file containing labels for the training set.",
   "l");
 PARAM_STRING_REQ("test_file", "A file containing the test set.", "te");
-PARAM_STRING_REQ("num_classes","The number of classes","c");
 
-// output parameters (optional)
+//optional parameters.
 PARAM_STRING("output", "The file in which the predicted labels for the test set"
     " will be written.", "o", "output.csv");
-
-PARAM_INT("bucket_size","The size of ranges/buckets to be used while splitting the decision stump.","b", 6);
+PARAM_INT("iterations","The maximum number of iterations the perceptron is "
+  "to be run", "i", 1000)
 
 int main(int argc, char *argv[])
 {
@@ -51,17 +48,10 @@ int main(int argc, char *argv[])
   if (labelsIn.n_rows == 1)
     labelsIn = labelsIn.t();
 
-  size_t inpBucketSize = CLI::GetParam<int>("bucket_size");
-
   // normalize the labels
   data::NormalizeLabels(labelsIn.unsafe_col(0), labels, mappings);
 
-  const size_t num_classes = CLI::GetParam<size_t>("num_classes");
-  /*
-  Should number of classes be input or should it be
-  derived from the labels row ?
-  */
-  const string testingDataFilename = CLI::GetParam<std::string>("test_file");
+  const string testingDataFilename = CLI::GetParam<string>("test_file");
   mat testingData;
   data::Load(testingDataFilename, testingData, true);
 
@@ -69,15 +59,16 @@ int main(int argc, char *argv[])
     Log::Fatal << "Test data dimensionality (" << testingData.n_rows << ") "
         << "must be the same as training data (" << trainingData.n_rows - 1
         << ")!" << std::endl;
+  int iterations = CLI::GetParam<int>("iterations");
   
-  Timer::Start("training");
-  DecisionStump<> ds(trainingData, labels, num_classes, inpBucketSize);
-  Timer::Stop("training");
+  Timer::Start("Training");
+  Perceptron<> p(trainingData, labels, iterations);
+  Timer::Stop("Training");
 
   Row<size_t> predictedLabels(testingData.n_cols);
-  Timer::Start("testing");
-  ds.Classify(testingData, predictedLabels);
-  Timer::Stop("testing");
+  Timer::Start("Testing");
+  p.Classify(testingData, predictedLabels);
+  Timer::Stop("Testing");
 
   vec results;
   data::RevertLabels(predictedLabels, mappings, results);
diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt
index 2aebb62..d35779b 100644
--- a/src/mlpack/tests/CMakeLists.txt
+++ b/src/mlpack/tests/CMakeLists.txt
@@ -34,6 +34,7 @@ add_executable(mlpack_test
   nca_test.cpp
   nmf_test.cpp
   pca_test.cpp
+  perceptron_test.cpp
   radical_test.cpp
   range_search_test.cpp
   save_restore_utility_test.cpp
diff --git a/src/mlpack/tests/perceptron_test.cpp b/src/mlpack/tests/perceptron_test.cpp
new file mode 100644
index 0000000..70b368e
--- /dev/null
+++ b/src/mlpack/tests/perceptron_test.cpp
@@ -0,0 +1,150 @@
+/*
+ * @file: perceptron_test.cpp
+ * @author: Udit Saxena
+ * 
+ * Tests for perceptron.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/methods/perceptron/perceptron.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace arma;
+using namespace mlpack::perceptron;
+
+BOOST_AUTO_TEST_SUITE(PERCEPTRONTEST);
+/*
+This test checks whether the perceptron converges for the
+AND gate classifier.
+*/
+BOOST_AUTO_TEST_CASE(AND)
+{
+  mat trainData;
+  trainData << 0 << 1 << 1 << 0 << endr
+            << 1 << 0 << 1 << 0 << endr;
+  Mat<size_t> labels;
+  labels << 0 << 0 << 1 << 0;
+
+  Perceptron<> p(trainData, labels.row(0), 1000);
+
+  mat testData;
+  testData << 0 << 1 << 1 << 0 << endr
+           << 1 << 0 << 1 << 0 << endr;
+  Row<size_t> predictedLabels(testData.n_cols);
+  p.Classify(testData, predictedLabels);
+
+  BOOST_CHECK_EQUAL(predictedLabels(0,0),0);
+  BOOST_CHECK_EQUAL(predictedLabels(0,1),0);
+  BOOST_CHECK_EQUAL(predictedLabels(0,2),1);
+  BOOST_CHECK_EQUAL(predictedLabels(0,3),0);
+  
+}
+
+/*
+This test checks whether the perceptron converges for the
+OR gate classifier.
+*/
+BOOST_AUTO_TEST_CASE(OR)
+{
+  mat trainData;
+  trainData << 0 << 1 << 1 << 0 << endr
+            << 1 << 0 << 1 << 0 << endr;
+
+  Mat<size_t> labels;
+  labels << 1 << 1 << 1 << 0;
+
+  Perceptron<> p(trainData, labels.row(0), 1000);
+
+  mat testData;
+  testData << 0 << 1 << 1 << 0 << endr
+            << 1 << 0 << 1 << 0 << endr;
+  Row<size_t> predictedLabels(testData.n_cols);
+  p.Classify(testData, predictedLabels);
+
+  BOOST_CHECK_EQUAL(predictedLabels(0,0),1);
+  BOOST_CHECK_EQUAL(predictedLabels(0,1),1);
+  BOOST_CHECK_EQUAL(predictedLabels(0,2),1);
+  BOOST_CHECK_EQUAL(predictedLabels(0,3),0);
+}
+
+/*
+This tests the convergence on a set of linearly 
+separable data with 3 classes. 
+*/
+BOOST_AUTO_TEST_CASE(RANDOM3)
+{
+  mat trainData;
+  trainData << 0 << 1 << 1 << 4 << 5 << 4 << 1 << 2 << 1 << endr
+           << 1 << 0 << 1 << 1 << 1 << 2 << 4 << 5 << 4 << endr;
+
+  Mat<size_t> labels;
+  labels << 0 << 0 << 0 << 1 << 1 << 1 << 2 << 2 << 2;
+
+  Perceptron<> p(trainData, labels.row(0), 1000);
+
+  mat testData;
+  testData << 0 << 1 << 1 << endr
+           << 1 << 0 << 1 << endr;
+  Row<size_t> predictedLabels(testData.n_cols);
+  p.Classify(testData, predictedLabels);
+  
+  for (size_t i = 0; i<predictedLabels.n_cols; i++)
+    BOOST_CHECK_EQUAL(predictedLabels(0,i),0);
+
+}
+
+/*
+This tests the convergence of the perceptron on a dataset
+which has only TWO points which belong to different classes.
+*/
+BOOST_AUTO_TEST_CASE(TWOPOINTS)
+{
+  mat trainData;
+  trainData << 0 << 1 << endr
+           << 1 << 0 << endr;
+
+  Mat<size_t> labels;
+  labels << 0 << 1 ;
+
+  Perceptron<> p(trainData, labels.row(0), 1000);
+
+  mat testData;
+  testData << 0 << 1 << endr
+           << 1 << 0 << endr;
+  Row<size_t> predictedLabels(testData.n_cols);
+  p.Classify(testData, predictedLabels);
+
+  BOOST_CHECK_EQUAL(predictedLabels(0,0),0);
+  BOOST_CHECK_EQUAL(predictedLabels(0,1),1);
+}
+/*
+This tests the behaviour of the perceptron on a dataset
+which is not linearly separable.
+*/
+BOOST_AUTO_TEST_CASE(NONLINSEPDS)
+{
+  mat trainData;
+  trainData << 1 << 2 << 3 << 4 << 5 << 6 << 7 << 8 
+            << 1 << 2 << 3 << 4 << 5 << 6 << 7 << 8 << endr
+            << 1 << 1 << 1 << 1 << 1 << 1 << 1 << 1 
+            << 2 << 2 << 2 << 2 << 2 << 2 << 2 << 2 << endr;
+
+  Mat<size_t> labels;
+  labels << 0 << 0 << 0 << 1 << 0 << 1 << 1 << 1
+         << 0 << 0 << 0 << 1 << 0 << 1 << 1 << 1;
+  Perceptron<> p(trainData, labels.row(0), 1000);
+
+  mat testData;
+  testData << 3 << 4 << 5 << 6 << endr
+           << 3 << 2.3 << 1.7 << 1.5 << endr;
+  Row<size_t> predictedLabels(testData.n_cols);
+  p.Classify(testData, predictedLabels);
+
+  BOOST_CHECK_EQUAL(predictedLabels(0,0),0);
+  BOOST_CHECK_EQUAL(predictedLabels(0,1),0);
+  BOOST_CHECK_EQUAL(predictedLabels(0,2),1);
+  BOOST_CHECK_EQUAL(predictedLabels(0,3),1);
+}
+BOOST_AUTO_TEST_SUITE_END();
\ No newline at end of file



More information about the mlpack-git mailing list