[mlpack-svn] r16857 - mlpack/trunk/src/mlpack/methods/perceptron

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Jul 25 12:01:52 EDT 2014


Author: marcus
Date: Fri Jul 25 12:01:52 2014
New Revision: 16857

Log:
Fix label normalization bug (transpose); add more comments.

Modified:
   mlpack/trunk/src/mlpack/methods/perceptron/perceptron_main.cpp

Modified: mlpack/trunk/src/mlpack/methods/perceptron/perceptron_main.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/perceptron/perceptron_main.cpp	(original)
+++ mlpack/trunk/src/mlpack/methods/perceptron/perceptron_main.cpp	Fri Jul 25 12:01:52 2014
@@ -2,7 +2,10 @@
  * @file: perceptron_main.cpp
  * @author: Udit Saxena
  *
- * Main executable for the Perceptron.
+ * This program runs the Simple Perceptron Classifier.
+ *
+ * Perceptrons are simple single-layer binary classifiers, which solve linearly
+ * separable problems with a linear decision boundary.
  */
 
 #include <mlpack/core.hpp>
@@ -46,12 +49,13 @@
 PARAM_STRING("output", "The file in which the predicted labels for the test set"
     " will be written.", "o", "output.csv");
 PARAM_INT("iterations","The maximum number of iterations the perceptron is "
-  "to be run", "i", 1000)
+  "to be run", "i", 1000);
 
-int main(int argc, char *argv[])
+int main(int argc, char** argv)
 {
   CLI::ParseCommandLine(argc, argv);
 
+  // Get reference dataset filename.
   const string trainingDataFilename = CLI::GetParam<string>("train_file");
   mat trainingData;
   data::Load(trainingDataFilename, trainingData, true);
@@ -60,60 +64,61 @@
   // Load labels.
   mat labelsIn;
 
+  // Did the user pass in labels?
   if (CLI::HasParam("labels_file"))
   {
-    const string labelsFilename = CLI::GetParam<string>("labels_file");
     // Load labels.
+    const string labelsFilename = CLI::GetParam<string>("labels_file");
     data::Load(labelsFilename, labelsIn, true);
-
-    // Do the labels need to be transposed?
-    if (labelsIn.n_rows == 1)
-      labelsIn = labelsIn.t();
   }
   else
   {
-    // Extract the labels as the last
+    // Use the last row of the training data as the labels.
     Log::Info << "Using the last dimension of training set as labels." << endl;
-
     labelsIn = trainingData.row(trainingData.n_rows - 1).t();
     trainingData.shed_row(trainingData.n_rows - 1);
   }
-  // helpers for normalizing the labels
-  Col<size_t> labels;
-  vec mappings;
 
   // Do the labels need to be transposed?
   if (labelsIn.n_rows == 1)
+  {
     labelsIn = labelsIn.t();
+  }
 
-  // normalize the labels
+  // Normalize the labels.
+  Col<size_t> labels;
+  vec mappings;
   data::NormalizeLabels(labelsIn.unsafe_col(0), labels, mappings);
 
+  // Load test dataset.
   const string testingDataFilename = CLI::GetParam<string>("test_file");
   mat testingData;
   data::Load(testingDataFilename, testingData, true);
-
   if (testingData.n_rows != trainingData.n_rows)
+  {
     Log::Fatal << "Test data dimensionality (" << testingData.n_rows << ") "
         << "must be the same as training data (" << trainingData.n_rows - 1
         << ")!" << std::endl;
+  }
+
   int iterations = CLI::GetParam<int>("iterations");
   
+  // Create and train the classifier.
   Timer::Start("Training");
-  Perceptron<> p(trainingData, labels, iterations);
+  Perceptron<> p(trainingData, labels.t(), iterations);
   Timer::Stop("Training");
 
+  // Time the running of the Perceptron Classifier.
   Row<size_t> predictedLabels(testingData.n_cols);
   Timer::Start("Testing");
   p.Classify(testingData, predictedLabels);
   Timer::Stop("Testing");
 
+  // Un-normalize labels to prepare output.
   vec results;
-  data::RevertLabels(predictedLabels, mappings, results);
+  data::RevertLabels(predictedLabels.t(), mappings, results);
 
-  const string outputFilename = CLI::GetParam<string>("output");
-  data::Save(outputFilename, results, true, true);
   // saving the predictedLabels in the transposed manner in output
-
-  return 0;
+  const string outputFilename = CLI::GetParam<string>("output");
+  data::Save(outputFilename, results, true, false);
 }
\ No newline at end of file



More information about the mlpack-svn mailing list