[mlpack-git] master: Refactor main det program. (a990392)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Mon Dec 21 13:26:43 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/09cbc6e13aa3cb8a7c4ea6d2e1612977a40c6be7...be72510a765362f86782a8892f0e979aaa4a9f62
>---------------------------------------------------------------
commit a9903929204697424d4255bc7739479e1b52b7a7
Author: ryan <ryan at ratml.org>
Date: Mon Dec 21 13:23:41 2015 -0500
Refactor main det program.
>---------------------------------------------------------------
a9903929204697424d4255bc7739479e1b52b7a7
src/mlpack/methods/det/det_main.cpp | 217 ++++++++++++++++--------------------
1 file changed, 97 insertions(+), 120 deletions(-)
diff --git a/src/mlpack/methods/det/det_main.cpp b/src/mlpack/methods/det/det_main.cpp
index f79c78b..e896b4c 100644
--- a/src/mlpack/methods/det/det_main.cpp
+++ b/src/mlpack/methods/det/det_main.cpp
@@ -26,37 +26,33 @@ PROGRAM_INFO("Density Estimation With Density Estimation Trees",
"for the test set and the variable importances.");
// Input data files.
-PARAM_STRING_REQ("train_file", "The data set on which to build a density "
- "estimation tree.", "t");
-PARAM_STRING("test_file", "A set of test points to estimate the density of.",
- "T", "");
-PARAM_STRING("labels_file", "The labels for the given training data to "
- "generate the class membership of each leaf (as an extra statistic)", "l",
- "");
+PARAM_STRING("training_file", "The data set on which to build a density "
+ "estimation tree.", "t", "");
+
+// Input or output model.
+PARAM_STRING("input_model_file", "File containing already trained density "
+ "estimation tree.", "m", "");
+PARAM_STRING("output_model_file", "File to save trained density estimation tree"
+ " to.", "M", "");
// Output data files.
-PARAM_STRING("unpruned_tree_estimates_file", "The file in which to output the "
- "density estimates on the training set from the large unpruned tree.", "u",
- "");
+PARAM_STRING("test_file", "A set of test points to estimate the density of.",
+ "T", "");
PARAM_STRING("training_set_estimates_file", "The file in which to output the "
"density estimates on the training set from the final optimally pruned "
"tree.", "e", "");
PARAM_STRING("test_set_estimates_file", "The file in which to output the "
"estimates on the test set from the final optimally pruned tree.", "E", "");
-PARAM_STRING("leaf_class_table_file", "The file in which to output the leaf "
- "class membership table.", "L", "leaf_class_membership.txt");
-PARAM_STRING("tree_file", "The file in which to print the final optimally "
- "pruned tree.", "r", "");
PARAM_STRING("vi_file", "The file to output the variable importance values "
"for each feature.", "i", "");
-// Parameters for the algorithm.
+// Parameters for the training algorithm.
PARAM_INT("folds", "The number of folds of cross-validation to perform for the "
"estimation (0 is LOOCV)", "f", 10);
PARAM_INT("min_leaf_size", "The minimum size of a leaf in the unpruned, fully "
- "grown DET.", "N", 5);
+ "grown DET.", "l", 5);
PARAM_INT("max_leaf_size", "The maximum size of a leaf in the unpruned, fully "
- "grown DET.", "M", 10);
+ "grown DET.", "L", 10);
/*
PARAM_FLAG("volume_regularization", "This flag gives the used the option to use"
"a form of regularization similar to the usual alpha-pruning in decision "
@@ -65,61 +61,88 @@ PARAM_FLAG("volume_regularization", "This flag gives the used the option to use"
"penalize low volume leaves.", "R");
*/
-// Some flags for output of some information about the tree.
-PARAM_FLAG("print_tree", "Print the tree out on the command line (or in the "
- "file specified with --tree_file).", "p");
-PARAM_FLAG("print_vi", "Print the variable importance of each feature out on "
- "the command line (or in the file specified with --vi_file).", "I");
-
int main(int argc, char *argv[])
{
CLI::ParseCommandLine(argc, argv);
- string trainSetFile = CLI::GetParam<string>("train_file");
- arma::Mat<double> trainingData;
+ // Validate input parameters.
+ if (CLI::HasParam("training_file") && CLI::HasParam("input_model_file"))
+ Log::Fatal << "Only one of --training_file (-t) or --input_model_file (-m) "
+ << "may be specified!" << endl;
- data::Load(trainSetFile, trainingData, true);
+ if (!CLI::HasParam("training_file") && !CLI::HasParam("input_model_file"))
+ Log::Fatal << "Neither --training_file (-t) nor --input_model_file (-m) "
+ << "are specified!" << endl;
- // Cross-validation here.
- size_t folds = CLI::GetParam<int>("folds");
- if (folds == 0)
+ if (!CLI::HasParam("training_file"))
{
- folds = trainingData.n_cols;
- Log::Info << "Performing leave-one-out cross validation." << endl;
- }
- else
- {
- Log::Info << "Performing " << folds << "-fold cross validation." << endl;
+ if (CLI::HasParam("training_set_estimates_file"))
+ Log::Warn << "--training_set_estimates_file (-e) ignored because "
+ << "--training_file (-t) is not specified." << endl;
+ if (CLI::HasParam("folds"))
+ Log::Warn << "--folds (-f) ignored because --training_file (-t) is not "
+ << "specified." << endl;
+ if (CLI::HasParam("min_leaf_size"))
+ Log::Warn << "--min_leaf_size (-l) ignored because --training_file (-t) "
+ << "is not specified." << endl;
+ if (CLI::HasParam("max_leaf_size"))
+ Log::Warn << "--max_leaf_size (-L) ignored because --training_file (-t) "
+ << "is not specified." << endl;
}
- const string unprunedTreeEstimateFile =
- CLI::GetParam<string>("unpruned_tree_estimates_file");
- const bool regularization = false;
-// const bool regularization = CLI::HasParam("volume_regularization");
- const int maxLeafSize = CLI::GetParam<int>("max_leaf_size");
- const int minLeafSize = CLI::GetParam<int>("min_leaf_size");
+ if (!CLI::HasParam("test_file") && CLI::HasParam("test_set_estimates_file"))
+ Log::Warn << "--test_set_estimates_file (-E) ignored because --test_file "
+ << "(-T) is not specified." << endl;
- // Obtain the optimal tree.
- Timer::Start("det_training");
- DTree *dtreeOpt = Trainer(trainingData, folds, regularization, maxLeafSize,
- minLeafSize, unprunedTreeEstimateFile);
- Timer::Stop("det_training");
+ // Are we training a DET or loading from file?
+ DTree* tree;
+ if (CLI::HasParam("training_file"))
+ {
+ const string trainSetFile = CLI::GetParam<string>("training_file");
+ arma::mat trainingData;
+ data::Load(trainSetFile, trainingData, true);
- // Compute densities for the training points in the optimal tree.
- FILE *fp = NULL;
+ // Cross-validation here.
+ size_t folds = CLI::GetParam<int>("folds");
+ if (folds == 0)
+ {
+ folds = trainingData.n_cols;
+ Log::Info << "Performing leave-one-out cross validation." << endl;
+ }
+ else
+ {
+ Log::Info << "Performing " << folds << "-fold cross validation." << endl;
+ }
- if (CLI::GetParam<string>("training_set_estimates_file") != "")
- {
- fp = fopen(CLI::GetParam<string>("training_set_estimates_file").c_str(),
- "w");
+ const bool regularization = false;
+// const bool regularization = CLI::HasParam("volume_regularization");
+ const int maxLeafSize = CLI::GetParam<int>("max_leaf_size");
+ const int minLeafSize = CLI::GetParam<int>("min_leaf_size");
- // Compute density estimates for each point in the training set.
- Timer::Start("det_estimation_time");
- for (size_t i = 0; i < trainingData.n_cols; i++)
- fprintf(fp, "%lg\n", dtreeOpt->ComputeValue(trainingData.unsafe_col(i)));
- Timer::Stop("det_estimation_time");
+ // Obtain the optimal tree.
+ Timer::Start("det_training");
+ tree = Trainer(trainingData, folds, regularization, maxLeafSize,
+ minLeafSize, "");
+ Timer::Stop("det_training");
- fclose(fp);
+ // Compute training set estimates, if desired.
+ if (CLI::GetParam<string>("training_set_estimates_file") != "")
+ {
+ // Compute density estimates for each point in the training set.
+ arma::rowvec trainingDensities(trainingData.n_cols);
+ Timer::Start("det_estimation_time");
+ for (size_t i = 0; i < trainingData.n_cols; i++)
+ trainingDensities[i] = tree->ComputeValue(trainingData.unsafe_col(i));
+ Timer::Stop("det_estimation_time");
+
+ data::Save(CLI::GetParam<string>("training_set_estimates_file"),
+ trainingDensities);
+ }
+ }
+ else
+ {
+ data::Load(CLI::GetParam<string>("input_model_file"), "det_model", tree,
+ true);
}
// Compute the density at the provided test points and output the density in
@@ -130,72 +153,26 @@ int main(int argc, char *argv[])
arma::mat testData;
data::Load(testFile, testData, true);
- fp = NULL;
+ // Compute test set densities.
+ Timer::Start("det_test_set_estimation");
+ arma::rowvec testDensities(testData.n_cols);
+ for (size_t i = 0; i < testData.n_cols; i++)
+ testDensities[i] = tree->ComputeValue(testData.unsafe_col(i));
+ Timer::Stop("det_test_set_estimation");
if (CLI::GetParam<string>("test_set_estimates_file") != "")
- {
- fp = fopen(CLI::GetParam<string>("test_set_estimates_file").c_str(), "w");
-
- Timer::Start("det_test_set_estimation");
- for (size_t i = 0; i < testData.n_cols; i++)
- fprintf(fp, "%lg\n", dtreeOpt->ComputeValue(testData.unsafe_col(i)));
- Timer::Stop("det_test_set_estimation");
-
- fclose(fp);
- }
- }
-
- // Print the final tree.
- if (CLI::HasParam("print_tree"))
- {
- fp = NULL;
- if (CLI::GetParam<string>("tree_file") != "")
- {
- fp = fopen(CLI::GetParam<string>("tree_file").c_str(), "w");
-
- if (fp != NULL)
- {
- dtreeOpt->WriteTree(fp);
- fclose(fp);
- }
- }
- else
- {
- dtreeOpt->WriteTree(stdout);
- printf("\n");
- }
- }
-
- // Print the leaf memberships for the optimal tree.
- if (CLI::GetParam<string>("labels_file") != "")
- {
- std::string labelsFile = CLI::GetParam<string>("labels_file");
- arma::Mat<size_t> labels;
-
- data::Load(labelsFile, labels, true);
-
- size_t numClasses = 0;
- for (size_t i = 0; i < labels.n_elem; ++i)
- {
- if (labels[i] > numClasses)
- numClasses = labels[i];
- }
-
- Log::Info << numClasses << " found in labels file '" << labelsFile << "'."
- << std::endl;
-
- Log::Assert(trainingData.n_cols == labels.n_cols);
- Log::Assert(labels.n_rows == 1);
-
- PrintLeafMembership(dtreeOpt, trainingData, labels, numClasses,
- CLI::GetParam<string>("leaf_class_table_file"));
+ data::Save(CLI::GetParam<string>("test_set_estimates_file"),
+ testDensities);
}
// Print variable importance.
- if (CLI::HasParam("print_vi"))
- {
- PrintVariableImportance(dtreeOpt, CLI::GetParam<string>("vi_file"));
- }
+ if (CLI::HasParam("vi_file"))
+ PrintVariableImportance(tree, CLI::GetParam<string>("vi_file"));
+
+ // Save the model, if desired.
+ if (CLI::HasParam("output_model_file"))
+ data::Save(CLI::GetParam<string>("output_model_file"), "det_model", tree,
+ false);
- delete dtreeOpt;
+ delete tree;
}
More information about the mlpack-git mailing list