[mlpack-svn] r16852 - in mlpack/trunk/src/mlpack: methods/adaboost methods/perceptron methods/perceptron/learning_policies tests tests/data

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Jul 23 16:01:16 EDT 2014


Author: saxena.udit
Date: Wed Jul 23 16:01:15 2014
New Revision: 16852

Log:
Adaboost improved. Tests for the UCI iris dataset added.

Added:
   mlpack/trunk/src/mlpack/tests/adaboost_test.cpp
   mlpack/trunk/src/mlpack/tests/data/iris.txt
   mlpack/trunk/src/mlpack/tests/data/iris_labels.txt
Modified:
   mlpack/trunk/src/mlpack/methods/adaboost/adaboost.hpp
   mlpack/trunk/src/mlpack/methods/adaboost/adaboost_impl.hpp
   mlpack/trunk/src/mlpack/methods/adaboost/adaboost_main.cpp
   mlpack/trunk/src/mlpack/methods/perceptron/learning_policies/simple_weight_update.hpp
   mlpack/trunk/src/mlpack/methods/perceptron/perceptron.hpp
   mlpack/trunk/src/mlpack/methods/perceptron/perceptron_impl.hpp
   mlpack/trunk/src/mlpack/tests/CMakeLists.txt

Modified: mlpack/trunk/src/mlpack/methods/adaboost/adaboost.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/adaboost/adaboost.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/adaboost/adaboost.hpp	Wed Jul 23 16:01:15 2014
@@ -19,6 +19,8 @@
 class Adaboost 
 {
 public:
+  arma::Row<size_t> finalHypothesis;
+
   Adaboost(const MatType& data, const arma::Row<size_t>& labels,
            int iterations, size_t classes, const WeakLearner& other);
 

Modified: mlpack/trunk/src/mlpack/methods/adaboost/adaboost_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/adaboost/adaboost_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/adaboost/adaboost_impl.hpp	Wed Jul 23 16:01:15 2014
@@ -78,25 +78,30 @@
   sumFinalH.fill(0.0);
   
   // load the initial weights into a 2-D matrix
-  const double initWeight = 1 / (data.n_cols * classes);
+  const double initWeight = (double) 1 / (data.n_cols * classes);
   arma::mat D(data.n_cols, classes);
   D.fill(initWeight);
-
+  // D.print("The value of D after initialization.");
+  
   // Weights are to be compressed into this rowvector
   // for focussing on the perceptron weights.
   arma::rowvec weights(predictedLabels.n_cols);
-
+  // weights.print("This is the value of weight just after initialization.");
   // This is the final hypothesis.
-  arma::rowvec finalH(predictedLabels.n_cols);
+  arma::Row<size_t> finalH(predictedLabels.n_cols);
 
+  
+  // int localErrorCount;
   // now start the boosting rounds
   for (i = 0; i < iterations; i++)
   {
+    
     // Initialized to zero in every round.
     rt = 0.0; 
     zt = 0.0;
     
     // Build the weight vectors
+
     buildWeightMatrix(D, weights);
     
     // call the other weak learner and train the labels.
@@ -105,7 +110,16 @@
 
     //Now from predictedLabels, build ht, the weak hypothesis
     buildClassificationMatrix(ht, predictedLabels);
-
+    
+/*    localErrorCount = 0;
+    for (int m = 0; m < labels.n_cols; m++)
+      if (labels(m) != predictedLabels(m))
+      {
+        localErrorCount++;
+        // std::cout<<m<<"th error.\n";
+      }
+    std::cout<<"Local Error is: "<<localErrorCount<<"\n";
+    std::cout<<"Local Error Rate: "<<(double)localErrorCount/predictedLabels.n_cols<<"\n";*/
     // Now, start calculation of alpha(t) using ht
     
     // begin calculation of rt
@@ -115,11 +129,9 @@
       for (k = 0;k < ht.n_cols; k++)
         rt += (D(j,k) * yt(j,k) * ht(j,k));
     }
-
     // end calculation of rt
 
     alphat = 0.5 * log((1 + rt) / (1 - rt));
-
     // end calculation of alphat
     
     // now start modifying weights
@@ -128,6 +140,7 @@
     {
       for (k = 0;k < D.n_cols; k++)
       {  
+        
         // we calculate zt, the normalization constant
         zt += D(j,k) * exp(-1 * alphat * yt(j,k) * ht(j,k));
         D(j,k) = D(j,k) * exp(-1 * alphat * yt(j,k) * ht(j,k));
@@ -136,11 +149,9 @@
         sumFinalH(j,k) += (alphat * ht(j,k));
       }
     }
-
     // normalization of D
 
     D = D / zt;
-  
   }
 
   // Iterations are over, now build a strong hypothesis
@@ -155,7 +166,18 @@
     tempSumFinalH.max(max_index);
     finalH(i) = max_index;
   }
-
+  finalHypothesis = finalH;
+  // labels.print("These are the labels.");
+  // finalH.print("This is the final hypothesis.");
+  /*int counterror = 0;
+  for (i = 0; i < labels.n_cols; i++)
+    if(labels(i) != finalH(i))
+    { 
+      std::cout<<i<<"th prediction not correct!\n";
+      counterror++;
+    }
+  std::cout<<"\nFinally - There are "<<counterror<<" number of misclassified records.\n";  
+  std::cout<<"The error rate is: "<<(double)counterror/labels.n_cols;*/
   //finalH is the final hypothesis.
 }
 

Modified: mlpack/trunk/src/mlpack/methods/adaboost/adaboost_main.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/adaboost/adaboost_main.cpp	(original)
+++ mlpack/trunk/src/mlpack/methods/adaboost/adaboost_main.cpp	Wed Jul 23 16:01:15 2014
@@ -16,10 +16,10 @@
 PROGRAM_INFO("","");
 
 //necessary parameters
-PARAM_STRING_REQ("train_file", "A file containing the training set.", "tr");
+PARAM_STRING_REQ("train_file", "A file containing the training set.", "t");
 PARAM_STRING_REQ("labels_file", "A file containing labels for the training set.",
   "l");
-PARAM_STRING_REQ("test_file", "A file containing the test set.", "te");
+PARAM_STRING_REQ("test_file", "A file containing the test set.", "T");
 
 //optional parameters.
 PARAM_STRING("output", "The file in which the predicted labels for the test set"
@@ -31,11 +31,11 @@
 int main(int argc, char *argv[])
 {
   CLI::ParseCommandLine(argc, argv);
-
+  
   const string trainingDataFilename = CLI::GetParam<string>("train_file");
   mat trainingData;
   data::Load(trainingDataFilename, trainingData, true);
-
+  
   const string labelsFilename = CLI::GetParam<string>("labels_file");
   // Load labels.
   mat labelsIn;
@@ -59,7 +59,7 @@
     labelsIn = trainingData.row(trainingData.n_rows - 1).t();
     trainingData.shed_row(trainingData.n_rows - 1);
   }
-
+  
   // helpers for normalizing the labels
   Col<size_t> labels;
   vec mappings;
@@ -70,7 +70,7 @@
 
   // normalize the labels
   data::NormalizeLabels(labelsIn.unsafe_col(0), labels, mappings);
-
+  
   const string testingDataFilename = CLI::GetParam<string>("test_file");
   mat testingData;
   data::Load(testingDataFilename, testingData, true);
@@ -81,14 +81,15 @@
         << ")!" << std::endl;
   int iterations = CLI::GetParam<int>("iterations");
   
-  int classes = 6;
-
+  int classes = 3;
+  
   // define your own weak learner, perceptron in this case.
-  int iter = 1000;
-  perceptron::Perceptron<> p(trainingData, labels, iter);
-  // 
+  int iter = 4000;
+  
+  perceptron::Perceptron<> p(trainingData, labels.t(), iter);
+  
   Timer::Start("Training");
-  Adaboost<> a(trainingData, labels, iterations, classes, p);
+  Adaboost<> a(trainingData, labels.t(), iterations, classes, p);
   Timer::Stop("Training");
 
   // vec results;

Modified: mlpack/trunk/src/mlpack/methods/perceptron/learning_policies/simple_weight_update.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/perceptron/learning_policies/simple_weight_update.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/perceptron/learning_policies/simple_weight_update.hpp	Wed Jul 23 16:01:15 2014
@@ -35,18 +35,20 @@
    * @param rowIndex Index of the row which has been incorrectly predicted.
    * @param labelIndex Index of the vector in trainData.
    * @param vectorIndex Index of the class which should have been predicted.
+   * @param D Per-instance costs; D(labelIndex) scales both weight updates.
    */
   void UpdateWeights(const arma::mat& trainData,
                      arma::mat& weightVectors,
                      const size_t labelIndex,
                      const size_t vectorIndex,
-                     const size_t rowIndex)
+                     const size_t rowIndex,
+                     const arma::rowvec& D)
   {
     weightVectors.row(rowIndex) = weightVectors.row(rowIndex) - 
-                                  trainData.col(labelIndex).t();
+                                  D(labelIndex) * trainData.col(labelIndex).t();
 
     weightVectors.row(vectorIndex) = weightVectors.row(vectorIndex) +
-                                     trainData.col(labelIndex).t();
+                                     D(labelIndex) * trainData.col(labelIndex).t();
   }
 };
 

Modified: mlpack/trunk/src/mlpack/methods/perceptron/perceptron.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/perceptron/perceptron.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/perceptron/perceptron.hpp	Wed Jul 23 16:01:15 2014
@@ -54,11 +54,16 @@
   void Classify(const MatType& test, arma::Row<size_t>& predictedLabels);
 
   /**
-   *
-   *
-   *
+   *  Alternate constructor which reuses the iteration count of an already
+   *  initialized perceptron and trains this one on data, weighted by D.
+   *
+   *  @param other An already initialized Perceptron whose iteration count is
+   *               reused.
+   *  @param data The data on which to train this Perceptron object.
+   *  @param D Per-instance weights used while training (for boosting).
+   *  @param labels The labels of data.
    */
-  Perceptron(const Perceptron<>& other, MatType& data, const arma::Row<double>& D, const arma::Row<size_t>& labels);
+  Perceptron(const Perceptron<>& other, MatType& data, const arma::rowvec& D, const arma::Row<size_t>& labels);
 
 private:
   //! To store the number of iterations
@@ -74,10 +79,11 @@
   arma::mat trainData;
 
   /**
-   * Train function.
-   *
+   *  Training function. It trains on trainData, weighting each instance by D.
+   *
+   *  @param D Cost vector: D(i) is the cost of mispredicting instance i.
    */
-  void Train();
+  void Train(const arma::rowvec& D);
 };
 
 } // namespace perceptron

Modified: mlpack/trunk/src/mlpack/methods/perceptron/perceptron_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/perceptron/perceptron_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/perceptron/perceptron_impl.hpp	Wed Jul 23 16:01:15 2014
@@ -46,7 +46,10 @@
   trainData.insert_rows(0, zOnes);
 
   iter = iterations;
-  Train();
+  arma::rowvec D(data.n_cols);
+  D.fill(1.0);// giving equal weight to all the points.
+  
+  Train(D);
 }
 
 
@@ -74,34 +77,51 @@
     tempLabelMat.max(maxIndexRow, maxIndexCol);
     predictedLabels(0, i) = maxIndexRow;
   }
+  // predictedLabels.print("These are the labels predicted by the perceptron");
 }
 
+/**
+ *  Alternate constructor which reuses the iteration count of an already
+ *  initialized perceptron and trains this one on data, weighted by D.
+ *
+ *  @param other An already initialized Perceptron whose iteration count is
+ *               reused.
+ *  @param data The data on which to train this Perceptron object.
+ *  @param D Per-instance weights used while training (for boosting).
+ *  @param labels The labels of data.
+ */
 template <typename LearnPolicy, typename WeightInitializationPolicy, typename MatType>
 Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Perceptron(
-  const Perceptron<>& other, MatType& data, const arma::Row<double>& D, const arma::Row<size_t>& labels)
+  const Perceptron<>& other, MatType& data, const arma::rowvec& D, const arma::Row<size_t>& labels)
 {
-  int i;
-  //transform data, as per rules for perceptron
-  for (i = 0;i < data.n_cols; i++)
-    data.col(i) = D(i) * data.col(i);
-
+  
   classLabels = labels;
   trainData = data;
   iter = other.iter;
 
-  Train();
+  // Insert a row of ones at the top of the training data set.
+  MatType zOnes(1, data.n_cols);
+  zOnes.fill(1);
+  trainData.insert_rows(0, zOnes);
+
+  WeightInitializationPolicy WIP;
+  WIP.Initialize(weightVectors, arma::max(labels) + 1, data.n_rows + 1);
+
+  Train(D);
 }
 
 /**
- *  Training Function. 
- *
+ *  Training function. It trains on trainData, weighting each instance by D.
+ *
+ *  @param D Cost vector: D(i) is the cost of mispredicting instance i.
  */
 template<
     typename LearnPolicy,
     typename WeightInitializationPolicy,
     typename MatType
 >
-void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Train()
+void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Train(
+     const arma::rowvec& D)
 {
   int j, i = 0;
   bool converged = false;
@@ -136,7 +156,7 @@
         // Send maxIndexRow for knowing which weight to update, send j to know
         // the value of the vector to update it with.  Send tempLabel to know
         // the correct class.
-        LP.UpdateWeights(trainData, weightVectors, j, tempLabel, maxIndexRow);
+        LP.UpdateWeights(trainData, weightVectors, j, tempLabel, maxIndexRow, D);
       }
     }
   }

Modified: mlpack/trunk/src/mlpack/tests/CMakeLists.txt
==============================================================================
--- mlpack/trunk/src/mlpack/tests/CMakeLists.txt	(original)
+++ mlpack/trunk/src/mlpack/tests/CMakeLists.txt	Wed Jul 23 16:01:15 2014
@@ -1,6 +1,7 @@
 # MLPACK test executable.
 add_executable(mlpack_test
   mlpack_test.cpp
+  adaboost_test.cpp
   allkfn_test.cpp
   allknn_test.cpp
   allkrann_search_test.cpp

Added: mlpack/trunk/src/mlpack/tests/adaboost_test.cpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/tests/adaboost_test.cpp	Wed Jul 23 16:01:15 2014
@@ -0,0 +1,54 @@
+/**
+ * @file adaboost_test.cpp
+ * @author Udit Saxena
+ *
+ * Tests for Adaboost class.
+ */
+
+#include <mlpack/core.hpp>
+#include <mlpack/methods/adaboost/adaboost.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace arma;
+using namespace mlpack::adaboost;
+
+BOOST_AUTO_TEST_SUITE(AdaboostTest);
+
+BOOST_AUTO_TEST_CASE(IrisSet)
+{
+  arma::mat inputData;
+
+  if (!data::Load("iris.txt", inputData))
+    BOOST_FAIL("Cannot load test dataset iris.txt!");
+
+  arma::Mat<size_t> labels;
+
+  if (!data::Load("iris_labels.txt",labels))
+    BOOST_FAIL("Cannot load labels for iris iris_labels.txt");
+
+  // The labels are already 0, 1, 2, so no mapping is needed; derive the
+  // number of classes from them instead of hard-coding it.
+  const size_t classes = labels.max() + 1;
+
+  // Weak learner: a perceptron trained for perceptronIter iterations.
+  const int perceptronIter = 4000;
+  perceptron::Perceptron<> p(inputData, labels.row(0), perceptronIter);
+
+  // Boost the perceptron for a fixed number of rounds.
+  const int iterations = 15;
+  Adaboost<> a(inputData, labels.row(0), iterations, classes, p);
+
+  // Every training point must have received a prediction.
+  BOOST_REQUIRE_EQUAL(a.finalHypothesis.n_elem, labels.n_cols);
+  size_t countError = 0;
+  for (size_t i = 0; i < labels.n_cols; i++)
+    if (labels(i) != a.finalHypothesis(i))
+      countError++;
+  // Deliberately loose sanity bound: do better than uniform random guessing.
+  const double errorRate = static_cast<double>(countError) / labels.n_cols;
+  BOOST_REQUIRE(errorRate < static_cast<double>(classes - 1) / classes);
+}
+BOOST_AUTO_TEST_SUITE_END();
\ No newline at end of file

Added: mlpack/trunk/src/mlpack/tests/data/iris.txt
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/tests/data/iris.txt	Wed Jul 23 16:01:15 2014
@@ -0,0 +1,150 @@
+5.1,3.5,1.4,0.2
+4.9,3.0,1.4,0.2
+4.7,3.2,1.3,0.2
+4.6,3.1,1.5,0.2
+5.0,3.6,1.4,0.2
+5.4,3.9,1.7,0.4
+4.6,3.4,1.4,0.3
+5.0,3.4,1.5,0.2
+4.4,2.9,1.4,0.2
+4.9,3.1,1.5,0.1
+5.4,3.7,1.5,0.2
+4.8,3.4,1.6,0.2
+4.8,3.0,1.4,0.1
+4.3,3.0,1.1,0.1
+5.8,4.0,1.2,0.2
+5.7,4.4,1.5,0.4
+5.4,3.9,1.3,0.4
+5.1,3.5,1.4,0.3
+5.7,3.8,1.7,0.3
+5.1,3.8,1.5,0.3
+5.4,3.4,1.7,0.2
+5.1,3.7,1.5,0.4
+4.6,3.6,1.0,0.2
+5.1,3.3,1.7,0.5
+4.8,3.4,1.9,0.2
+5.0,3.0,1.6,0.2
+5.0,3.4,1.6,0.4
+5.2,3.5,1.5,0.2
+5.2,3.4,1.4,0.2
+4.7,3.2,1.6,0.2
+4.8,3.1,1.6,0.2
+5.4,3.4,1.5,0.4
+5.2,4.1,1.5,0.1
+5.5,4.2,1.4,0.2
+4.9,3.1,1.5,0.1
+5.0,3.2,1.2,0.2
+5.5,3.5,1.3,0.2
+4.9,3.1,1.5,0.1
+4.4,3.0,1.3,0.2
+5.1,3.4,1.5,0.2
+5.0,3.5,1.3,0.3
+4.5,2.3,1.3,0.3
+4.4,3.2,1.3,0.2
+5.0,3.5,1.6,0.6
+5.1,3.8,1.9,0.4
+4.8,3.0,1.4,0.3
+5.1,3.8,1.6,0.2
+4.6,3.2,1.4,0.2
+5.3,3.7,1.5,0.2
+5.0,3.3,1.4,0.2
+7.0,3.2,4.7,1.4
+6.4,3.2,4.5,1.5
+6.9,3.1,4.9,1.5
+5.5,2.3,4.0,1.3
+6.5,2.8,4.6,1.5
+5.7,2.8,4.5,1.3
+6.3,3.3,4.7,1.6
+4.9,2.4,3.3,1.0
+6.6,2.9,4.6,1.3
+5.2,2.7,3.9,1.4
+5.0,2.0,3.5,1.0
+5.9,3.0,4.2,1.5
+6.0,2.2,4.0,1.0
+6.1,2.9,4.7,1.4
+5.6,2.9,3.6,1.3
+6.7,3.1,4.4,1.4
+5.6,3.0,4.5,1.5
+5.8,2.7,4.1,1.0
+6.2,2.2,4.5,1.5
+5.6,2.5,3.9,1.1
+5.9,3.2,4.8,1.8
+6.1,2.8,4.0,1.3
+6.3,2.5,4.9,1.5
+6.1,2.8,4.7,1.2
+6.4,2.9,4.3,1.3
+6.6,3.0,4.4,1.4
+6.8,2.8,4.8,1.4
+6.7,3.0,5.0,1.7
+6.0,2.9,4.5,1.5
+5.7,2.6,3.5,1.0
+5.5,2.4,3.8,1.1
+5.5,2.4,3.7,1.0
+5.8,2.7,3.9,1.2
+6.0,2.7,5.1,1.6
+5.4,3.0,4.5,1.5
+6.0,3.4,4.5,1.6
+6.7,3.1,4.7,1.5
+6.3,2.3,4.4,1.3
+5.6,3.0,4.1,1.3
+5.5,2.5,4.0,1.3
+5.5,2.6,4.4,1.2
+6.1,3.0,4.6,1.4
+5.8,2.6,4.0,1.2
+5.0,2.3,3.3,1.0
+5.6,2.7,4.2,1.3
+5.7,3.0,4.2,1.2
+5.7,2.9,4.2,1.3
+6.2,2.9,4.3,1.3
+5.1,2.5,3.0,1.1
+5.7,2.8,4.1,1.3
+6.3,3.3,6.0,2.5
+5.8,2.7,5.1,1.9
+7.1,3.0,5.9,2.1
+6.3,2.9,5.6,1.8
+6.5,3.0,5.8,2.2
+7.6,3.0,6.6,2.1
+4.9,2.5,4.5,1.7
+7.3,2.9,6.3,1.8
+6.7,2.5,5.8,1.8
+7.2,3.6,6.1,2.5
+6.5,3.2,5.1,2.0
+6.4,2.7,5.3,1.9
+6.8,3.0,5.5,2.1
+5.7,2.5,5.0,2.0
+5.8,2.8,5.1,2.4
+6.4,3.2,5.3,2.3
+6.5,3.0,5.5,1.8
+7.7,3.8,6.7,2.2
+7.7,2.6,6.9,2.3
+6.0,2.2,5.0,1.5
+6.9,3.2,5.7,2.3
+5.6,2.8,4.9,2.0
+7.7,2.8,6.7,2.0
+6.3,2.7,4.9,1.8
+6.7,3.3,5.7,2.1
+7.2,3.2,6.0,1.8
+6.2,2.8,4.8,1.8
+6.1,3.0,4.9,1.8
+6.4,2.8,5.6,2.1
+7.2,3.0,5.8,1.6
+7.4,2.8,6.1,1.9
+7.9,3.8,6.4,2.0
+6.4,2.8,5.6,2.2
+6.3,2.8,5.1,1.5
+6.1,2.6,5.6,1.4
+7.7,3.0,6.1,2.3
+6.3,3.4,5.6,2.4
+6.4,3.1,5.5,1.8
+6.0,3.0,4.8,1.8
+6.9,3.1,5.4,2.1
+6.7,3.1,5.6,2.4
+6.9,3.1,5.1,2.3
+5.8,2.7,5.1,1.9
+6.8,3.2,5.9,2.3
+6.7,3.3,5.7,2.5
+6.7,3.0,5.2,2.3
+6.3,2.5,5.0,1.9
+6.5,3.0,5.2,2.0
+6.2,3.4,5.4,2.3
+5.9,3.0,5.1,1.8
\ No newline at end of file

Added: mlpack/trunk/src/mlpack/tests/data/iris_labels.txt
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/tests/data/iris_labels.txt	Wed Jul 23 16:01:15 2014
@@ -0,0 +1,150 @@
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
+2
\ No newline at end of file



More information about the mlpack-svn mailing list