[mlpack-git] master: Better main executable. (6cdc7c5)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:44:06 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125
>---------------------------------------------------------------
commit 6cdc7c5d230a5f546bd428c687ce715c22bea53a
Author: ryan <ryan at ratml.org>
Date: Fri Oct 2 01:07:21 2015 -0400
Better main executable.
>---------------------------------------------------------------
6cdc7c5d230a5f546bd428c687ce715c22bea53a
.../streaming_decision_tree_main.cpp | 138 ++++++++++++++-------
1 file changed, 91 insertions(+), 47 deletions(-)
diff --git a/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_main.cpp b/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_main.cpp
index 93659a7..6d559c8 100644
--- a/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_main.cpp
+++ b/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_main.cpp
@@ -14,7 +14,7 @@ using namespace mlpack;
using namespace mlpack::tree;
using namespace mlpack::data;
-PARAM_STRING_REQ("training_file", "Training dataset file.", "t");
+PARAM_STRING("training_file", "Training dataset file.", "t", "");
PARAM_STRING("labels_file", "Labels for training dataset.", "l", "");
PARAM_DOUBLE("confidence", "Confidence before splitting (between 0 and 1).",
@@ -22,9 +22,14 @@ PARAM_DOUBLE("confidence", "Confidence before splitting (between 0 and 1).",
PARAM_INT("max_samples", "Maximum number of samples before splitting.", "n",
5000);
-PARAM_STRING("model_file", "File to save trained tree to.", "m", "");
+PARAM_STRING("input_model_file", "File to load trained tree from.", "m", "");
+PARAM_STRING("output_model_file", "File to save trained tree to.", "M", "");
PARAM_STRING("test_file", "File of testing data.", "T", "");
+PARAM_STRING("predictions_file", "File to output label predictions for test "
+ "data into.", "p", "");
+PARAM_STRING("probabilities_file", "In addition to predicting labels, provide "
+ "prediction probabilities in this file.", "P", "");
int main(int argc, char** argv)
@@ -35,55 +40,94 @@ int main(int argc, char** argv)
const string labelsFile = CLI::GetParam<string>("labels_file");
const double confidence = CLI::GetParam<double>("confidence");
const size_t maxSamples = (size_t) CLI::GetParam<int>("max_samples");
-
- arma::mat trainingSet;
+ const string inputModelFile = CLI::GetParam<string>("input_model_file");
+ const string outputModelFile = CLI::GetParam<string>("output_model_file");
+ const string testFile = CLI::GetParam<string>("test_file");
+ const string predictionsFile = CLI::GetParam<string>("predictions_file");
+ const string probabilitiesFile = CLI::GetParam<string>("probabilities_file");
+
+ if ((!predictionsFile.empty() || !probabilitiesFile.empty()) &&
+ testFile.empty())
+ Log::Fatal << "--test_file must be specified if --predictions_file or "
+ << "--probabilities_file is specified." << endl;
+
+ if (trainingFile.empty() && inputModelFile.empty())
+ Log::Fatal << "One of --training_file or --input_model_file must be "
+ << "specified!" << endl;
+
+ if (!trainingFile.empty() && labelsFile.empty())
+ Log::Fatal << "If --training_file is specified, --labels_file must be "
+ << "specified too!" << endl;
+
+ StreamingDecisionTree<HoeffdingSplit<>>* tree = NULL;
DatasetInfo datasetInfo;
- data::Load(trainingFile, trainingSet, datasetInfo, true);
- for (size_t i = 0; i < trainingSet.n_rows; ++i)
- Log::Info << datasetInfo.NumMappings(i) << " mappings in dimension " << i <<
-".\n";
-
- arma::Col<size_t> labelsIn;
- data::Load(labelsFile, labelsIn, true, false);
- arma::Row<size_t> labels = labelsIn.t();
-
- // Now create the decision tree.
- Timer::Start("tree_training");
- StreamingDecisionTree<HoeffdingSplit<>> tree(trainingSet, datasetInfo, labels,
- max(labels) + 1, confidence, maxSamples);
- Timer::Stop("tree_training");
-
- // Great. Good job team.
- std::stack<StreamingDecisionTree<HoeffdingSplit<>>*> stack;
- stack.push(&tree);
- size_t nodes = 0;
- while (!stack.empty())
+ if (inputModelFile.empty())
+ {
+ arma::mat trainingSet;
+ data::Load(trainingFile, trainingSet, datasetInfo, true);
+ for (size_t i = 0; i < trainingSet.n_rows; ++i)
+ Log::Info << datasetInfo.NumMappings(i) << " mappings in dimension "
+ << i << ".\n";
+
+ arma::Col<size_t> labelsIn;
+ data::Load(labelsFile, labelsIn, true, false);
+ arma::Row<size_t> labels = labelsIn.t();
+
+ // Now create the decision tree.
+ Timer::Start("tree_training");
+ tree = new StreamingDecisionTree<HoeffdingSplit<>>(trainingSet, datasetInfo,
+ labels, max(labels) + 1, confidence, maxSamples);
+ Timer::Stop("tree_training");
+ }
+ else
+ {
+ tree = new StreamingDecisionTree<HoeffdingSplit<>>(datasetInfo, 1, 1);
+ data::Load(inputModelFile, "streamingDecisionTree", *tree, true);
+
+ if (!trainingFile.empty())
+ {
+ arma::mat trainingSet;
+ data::Load(trainingFile, trainingSet, datasetInfo, true);
+ for (size_t i = 0; i < trainingSet.n_rows; ++i)
+ Log::Info << datasetInfo.NumMappings(i) << " mappings in dimension "
+ << i << ".\n";
+
+ arma::Col<size_t> labelsIn;
+ data::Load(labelsFile, labelsIn, true, false);
+ arma::Row<size_t> labels = labelsIn.t();
+
+ // Now create the decision tree.
+ Timer::Start("tree_training");
+ tree->Train(trainingSet, labels);
+ Timer::Stop("tree_training");
+ }
+ }
+
+ // Great. Good job team. Now work on the test set if we have one.
+ DatasetInfo testInfo;
+ if (!testFile.empty())
{
- StreamingDecisionTree<HoeffdingSplit<>>* node = stack.top();
- stack.pop();
- ++nodes;
+ arma::mat testSet;
+ data::Load(testFile, testSet, testInfo, true);
+
+ arma::Row<size_t> predictions;
+ arma::rowvec probabilities;
+
+ Timer::Start("tree_testing");
+ tree->Classify(testSet, predictions, probabilities);
+ Timer::Stop("tree_testing");
- for (size_t i = 0; i < node->NumChildren(); ++i)
- stack.push(&node->Child(i));
+ if (!predictionsFile.empty())
+ data::Save(predictionsFile, predictions);
+
+ if (!probabilitiesFile.empty())
+ data::Save(probabilitiesFile, probabilities);
}
- Log::Info << nodes << " nodes in tree.\n";
// Check the accuracy on the training set.
- Timer::Start("tree_testing");
- arma::Row<size_t> predictedLabels;
- tree.Classify(trainingSet, predictedLabels);
- Timer::Stop("tree_testing");
-
- size_t correct = 0;
- for (size_t i = 0; i < predictedLabels.n_elem; ++i)
- if (labels[i] == predictedLabels[i])
- ++correct;
- else if (predictedLabels[i] > 10)
- Log::Warn << "Invalid label " << predictedLabels[i] << " for point " << i
- << "!\n";
-
- Log::Info << correct << " correct out of " << predictedLabels.n_elem << ".\n";
-
- const string modelFile = CLI::GetParam<string>("model_file");
- data::Save(modelFile, "streamingDecisionTree", tree, true);
+ if (!outputModelFile.empty())
+ data::Save(outputModelFile, "streamingDecisionTree", tree, true);
+
+ // Clean up memory.
+ delete tree;
}
More information about the mlpack-git mailing list