[mlpack-git] master: Print test and training error. (a00f5df)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:46:42 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125
>---------------------------------------------------------------
commit a00f5dfe7fa4d8d96f99844b50382be9731acc11
Author: Ryan Curtin <ryan at ratml.org>
Date: Mon Nov 23 14:15:03 2015 -0800
Print test and training error.
>---------------------------------------------------------------
a00f5dfe7fa4d8d96f99844b50382be9731acc11
.../hoeffding_trees/hoeffding_tree_main.cpp | 42 ++++++++++++++++++++++
1 file changed, 42 insertions(+)
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp b/src/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp
index 7e9821e..3cdd976 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp
@@ -28,6 +28,7 @@ PARAM_STRING("input_model_file", "File to load trained tree from.", "m", "");
PARAM_STRING("output_model_file", "File to save trained tree to.", "M", "");
PARAM_STRING("test_file", "File of testing data.", "T", "");
+PARAM_STRING("test_labels_file", "Labels of test data.", "L", "");
PARAM_STRING("predictions_file", "File to output label predictions for test "
"data into.", "p", "");
PARAM_STRING("probabilities_file", "In addition to predicting labels, provide "
@@ -216,6 +217,29 @@ void PerformActions(const typename TreeType::NumericSplit& numericSplit)
}
}
+ if (!trainingFile.empty())
+ {
+ // Get training error.
+ arma::mat trainingSet;
+ data::Load(trainingFile, trainingSet, datasetInfo, true);
+ arma::Row<size_t> predictions;
+ tree->Classify(trainingSet, predictions);
+
+ arma::Col<size_t> labelsIn;
+ data::Load(labelsFile, labelsIn, true, false);
+ arma::Row<size_t> labels = labelsIn.t();
+
+ size_t correct = 0;
+ for (size_t i = 0; i < labels.n_elem; ++i)
+ if (labels[i] == predictions[i])
+ ++correct;
+
+ Log::Info << correct << " out of " << labels.n_elem << " correct "
+ << "on training set (" << double(correct) / double(labels.n_elem) *
+ 100.0 << ")." << endl;
+
+ }
+
// The tree is trained or loaded. Now do any testing if we need.
if (!testFile.empty())
{
@@ -229,6 +253,24 @@ void PerformActions(const typename TreeType::NumericSplit& numericSplit)
tree->Classify(testSet, predictions, probabilities);
Timer::Stop("tree_testing");
+ if (CLI::HasParam("test_labels_file"))
+ {
+ string testLabelsFile = CLI::GetParam<string>("test_labels_file");
+ arma::Col<size_t> testLabelsIn;
+ data::Load(testLabelsFile, testLabelsIn, true, false);
+ arma::Row<size_t> testLabels = testLabelsIn.t();
+
+ size_t correct = 0;
+ for (size_t i = 0; i < testLabels.n_elem; ++i)
+ {
+ if (predictions[i] == testLabels[i])
+ ++correct;
+ }
+ Log::Info << correct << " out of " << testLabels.n_elem << " correct "
+ << "on test set (" << double(correct) / double(testLabels.n_elem) *
+ 100.0 << ")." << endl;
+ }
+
if (!predictionsFile.empty())
data::Save(predictionsFile, predictions);
More information about the mlpack-git
mailing list