[mlpack-git] master: Print test and training error. (a00f5df)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:46:42 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125

>---------------------------------------------------------------

commit a00f5dfe7fa4d8d96f99844b50382be9731acc11
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Nov 23 14:15:03 2015 -0800

    Print test and training error.


>---------------------------------------------------------------

a00f5dfe7fa4d8d96f99844b50382be9731acc11
 .../hoeffding_trees/hoeffding_tree_main.cpp        | 42 ++++++++++++++++++++++
 1 file changed, 42 insertions(+)

diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp b/src/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp
index 7e9821e..3cdd976 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp
@@ -28,6 +28,7 @@ PARAM_STRING("input_model_file", "File to load trained tree from.", "m", "");
 PARAM_STRING("output_model_file", "File to save trained tree to.", "M", "");
 
 PARAM_STRING("test_file", "File of testing data.", "T", "");
+PARAM_STRING("test_labels_file", "Labels of test data.", "L", "");
 PARAM_STRING("predictions_file", "File to output label predictions for test "
     "data into.", "p", "");
 PARAM_STRING("probabilities_file", "In addition to predicting labels, provide "
@@ -216,6 +217,29 @@ void PerformActions(const typename TreeType::NumericSplit& numericSplit)
     }
   }
 
+  if (!trainingFile.empty())
+  {
+    // Measure accuracy on the training set (number of correct predictions).
+    arma::mat trainingSet;
+    data::Load(trainingFile, trainingSet, datasetInfo, true);
+    arma::Row<size_t> predictions;
+    tree->Classify(trainingSet, predictions);
+
+    arma::Col<size_t> labelsIn;
+    data::Load(labelsFile, labelsIn, true, false);
+    arma::Row<size_t> labels = labelsIn.t();
+
+    size_t correct = 0;
+    for (size_t i = 0; i < labels.n_elem; ++i)
+      if (labels[i] == predictions[i])
+        ++correct;
+
+    Log::Info << correct << " out of " << labels.n_elem << " correct "
+        << "on training set (" << double(correct) / double(labels.n_elem) *
+        100.0 << ")." << endl;
+
+  }
+
   // The tree is trained or loaded.  Now do any testing if we need.
   if (!testFile.empty())
   {
@@ -229,6 +253,24 @@ void PerformActions(const typename TreeType::NumericSplit& numericSplit)
     tree->Classify(testSet, predictions, probabilities);
     Timer::Stop("tree_testing");
 
+    if (CLI::HasParam("test_labels_file"))
+    {
+      string testLabelsFile = CLI::GetParam<string>("test_labels_file");
+      arma::Col<size_t> testLabelsIn;
+      data::Load(testLabelsFile, testLabelsIn, true, false);
+      arma::Row<size_t> testLabels = testLabelsIn.t();
+
+      size_t correct = 0;
+      for (size_t i = 0; i < testLabels.n_elem; ++i)
+      {
+        if (predictions[i] == testLabels[i])
+          ++correct;
+      }
+      Log::Info << correct << " out of " << testLabels.n_elem << " correct "
+          << "on test set (" << double(correct) / double(testLabels.n_elem) *
+          100.0 << ")." << endl;
+    }
+
     if (!predictionsFile.empty())
       data::Save(predictionsFile, predictions);
 



More information about the mlpack-git mailing list