[mlpack-git] master: Add support for the BinaryNumericSplit. (1a19ef9)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:44:43 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125
>---------------------------------------------------------------
commit 1a19ef9bfb5e927877acc50f43bbc132d94ba64b
Author: Ryan Curtin <ryan at ratml.org>
Date: Sun Oct 18 07:29:30 2015 -0400
Add support for the BinaryNumericSplit.
Currently it crashes on datasets larger than a couple thousand points, but I do
not have a working gdb on this system, so I can't debug it efficiently and will
save that for later.
>---------------------------------------------------------------
1a19ef9bfb5e927877acc50f43bbc132d94ba64b
.../streaming_decision_tree_main.cpp | 54 +++++++++++++++++-----
1 file changed, 43 insertions(+), 11 deletions(-)
diff --git a/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_main.cpp b/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_main.cpp
index 6d559c8..ef17cd7 100644
--- a/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_main.cpp
+++ b/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_main.cpp
@@ -7,7 +7,7 @@
#include <mlpack/core.hpp>
#include <mlpack/methods/hoeffding_trees/streaming_decision_tree.hpp>
#include <mlpack/methods/hoeffding_trees/hoeffding_split.hpp>
-#include <stack>
+#include <mlpack/methods/hoeffding_trees/binary_numeric_split.hpp>
using namespace std;
using namespace mlpack;
@@ -31,20 +31,26 @@ PARAM_STRING("predictions_file", "File to output label predictions for test "
PARAM_STRING("probabilities_file", "In addition to predicting labels, provide "
"prediction probabilities in this file.", "P", "");
+PARAM_STRING("numeric_split_strategy", "The splitting strategy to use for "
+ "numeric features: 'domingos' or 'binary'.", "N", "binary");
+
+// Helper function for once we have chosen a tree type.
+template<typename TreeType>
+void PerformActions();
int main(int argc, char** argv)
{
CLI::ParseCommandLine(argc, argv);
+ // Check input parameters for validity.
const string trainingFile = CLI::GetParam<string>("training_file");
const string labelsFile = CLI::GetParam<string>("labels_file");
- const double confidence = CLI::GetParam<double>("confidence");
- const size_t maxSamples = (size_t) CLI::GetParam<int>("max_samples");
const string inputModelFile = CLI::GetParam<string>("input_model_file");
- const string outputModelFile = CLI::GetParam<string>("output_model_file");
const string testFile = CLI::GetParam<string>("test_file");
const string predictionsFile = CLI::GetParam<string>("predictions_file");
const string probabilitiesFile = CLI::GetParam<string>("probabilities_file");
+ const string numericSplitStrategy =
+ CLI::GetParam<string>("numeric_split_strategy");
if ((!predictionsFile.empty() || !probabilitiesFile.empty()) &&
testFile.empty())
@@ -59,7 +65,33 @@ int main(int argc, char** argv)
Log::Fatal << "If --training_file is specified, --labels_file must be "
<< "specified too!" << endl;
- StreamingDecisionTree<HoeffdingSplit<>>* tree = NULL;
+ if (numericSplitStrategy == "domingos")
+ PerformActions<StreamingDecisionTree<HoeffdingSplit<GiniImpurity,
+ HoeffdingDoubleNumericSplit, HoeffdingCategoricalSplit>>>();
+ else if (numericSplitStrategy == "binary")
+ PerformActions<StreamingDecisionTree<HoeffdingSplit<GiniImpurity,
+ BinaryDoubleNumericSplit, HoeffdingCategoricalSplit>>>();
+ else
+ Log::Fatal << "Unrecognized numeric split strategy ("
+ << numericSplitStrategy << ")! Must be 'domingos' or 'binary'."
+ << endl;
+}
+
+template<typename TreeType>
+void PerformActions()
+{
+ // Load necessary parameters.
+ const string trainingFile = CLI::GetParam<string>("training_file");
+ const string labelsFile = CLI::GetParam<string>("labels_file");
+ const double confidence = CLI::GetParam<double>("confidence");
+ const size_t maxSamples = (size_t) CLI::GetParam<int>("max_samples");
+ const string inputModelFile = CLI::GetParam<string>("input_model_file");
+ const string outputModelFile = CLI::GetParam<string>("output_model_file");
+ const string testFile = CLI::GetParam<string>("test_file");
+ const string predictionsFile = CLI::GetParam<string>("predictions_file");
+ const string probabilitiesFile = CLI::GetParam<string>("probabilities_file");
+
+ TreeType* tree = NULL;
DatasetInfo datasetInfo;
if (inputModelFile.empty())
{
@@ -67,7 +99,7 @@ int main(int argc, char** argv)
data::Load(trainingFile, trainingSet, datasetInfo, true);
for (size_t i = 0; i < trainingSet.n_rows; ++i)
Log::Info << datasetInfo.NumMappings(i) << " mappings in dimension "
- << i << ".\n";
+ << i << "." << endl;
arma::Col<size_t> labelsIn;
data::Load(labelsFile, labelsIn, true, false);
@@ -75,13 +107,13 @@ int main(int argc, char** argv)
// Now create the decision tree.
Timer::Start("tree_training");
- tree = new StreamingDecisionTree<HoeffdingSplit<>>(trainingSet, datasetInfo,
- labels, max(labels) + 1, confidence, maxSamples);
+ tree = new TreeType(trainingSet, datasetInfo, labels, max(labels) + 1,
+ confidence, maxSamples);
Timer::Stop("tree_training");
}
else
{
- tree = new StreamingDecisionTree<HoeffdingSplit<>>(datasetInfo, 1, 1);
+ tree = new TreeType(datasetInfo, 1, 1);
data::Load(inputModelFile, "streamingDecisionTree", *tree, true);
if (!trainingFile.empty())
@@ -90,7 +122,7 @@ int main(int argc, char** argv)
data::Load(trainingFile, trainingSet, datasetInfo, true);
for (size_t i = 0; i < trainingSet.n_rows; ++i)
Log::Info << datasetInfo.NumMappings(i) << " mappings in dimension "
- << i << ".\n";
+ << i << "." << endl;
arma::Col<size_t> labelsIn;
data::Load(labelsFile, labelsIn, true, false);
@@ -103,7 +135,7 @@ int main(int argc, char** argv)
}
}
- // Great. Good job team. Now work on the test set if we have one.
+ // The tree is trained or loaded. Now do any testing if we need.
DatasetInfo testInfo;
if (!testFile.empty())
{
More information about the mlpack-git
mailing list