[mlpack-git] master: Refactor main det program. (a990392)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Mon Dec 21 13:26:43 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/09cbc6e13aa3cb8a7c4ea6d2e1612977a40c6be7...be72510a765362f86782a8892f0e979aaa4a9f62

>---------------------------------------------------------------

commit a9903929204697424d4255bc7739479e1b52b7a7
Author: ryan <ryan at ratml.org>
Date:   Mon Dec 21 13:23:41 2015 -0500

    Refactor main det program.


>---------------------------------------------------------------

a9903929204697424d4255bc7739479e1b52b7a7
 src/mlpack/methods/det/det_main.cpp | 217 ++++++++++++++++--------------------
 1 file changed, 97 insertions(+), 120 deletions(-)

diff --git a/src/mlpack/methods/det/det_main.cpp b/src/mlpack/methods/det/det_main.cpp
index f79c78b..e896b4c 100644
--- a/src/mlpack/methods/det/det_main.cpp
+++ b/src/mlpack/methods/det/det_main.cpp
@@ -26,37 +26,33 @@ PROGRAM_INFO("Density Estimation With Density Estimation Trees",
     "for the test set and the variable importances.");
 
 // Input data files.
-PARAM_STRING_REQ("train_file", "The data set on which to build a density "
-    "estimation tree.", "t");
-PARAM_STRING("test_file", "A set of test points to estimate the density of.",
-    "T", "");
-PARAM_STRING("labels_file", "The labels for the given training data to "
-    "generate the class membership of each leaf (as an extra statistic)", "l",
-    "");
+PARAM_STRING("training_file", "The data set on which to build a density "
+    "estimation tree.", "t", "");
+
+// Input or output model.
+PARAM_STRING("input_model_file", "File containing already trained density "
+    "estimation tree.", "m", "");
+PARAM_STRING("output_model_file", "File to save trained density estimation tree"
+    " to.", "M", "");
 
 // Output data files.
-PARAM_STRING("unpruned_tree_estimates_file", "The file in which to output the "
-    "density estimates on the training set from the large unpruned tree.", "u",
-    "");
+PARAM_STRING("test_file", "A set of test points to estimate the density of.",
+    "T", "");
 PARAM_STRING("training_set_estimates_file", "The file in which to output the "
     "density estimates on the training set from the final optimally pruned "
     "tree.", "e", "");
 PARAM_STRING("test_set_estimates_file", "The file in which to output the "
     "estimates on the test set from the final optimally pruned tree.", "E", "");
-PARAM_STRING("leaf_class_table_file", "The file in which to output the leaf "
-    "class membership table.", "L", "leaf_class_membership.txt");
-PARAM_STRING("tree_file", "The file in which to print the final optimally "
-    "pruned tree.", "r", "");
 PARAM_STRING("vi_file", "The file to output the variable importance values "
     "for each feature.", "i", "");
 
-// Parameters for the algorithm.
+// Parameters for the training algorithm.
 PARAM_INT("folds", "The number of folds of cross-validation to perform for the "
     "estimation (0 is LOOCV)", "f", 10);
 PARAM_INT("min_leaf_size", "The minimum size of a leaf in the unpruned, fully "
-    "grown DET.", "N", 5);
+    "grown DET.", "l", 5);
 PARAM_INT("max_leaf_size", "The maximum size of a leaf in the unpruned, fully "
-    "grown DET.", "M", 10);
+    "grown DET.", "L", 10);
 /*
 PARAM_FLAG("volume_regularization", "This flag gives the used the option to use"
     "a form of regularization similar to the usual alpha-pruning in decision "
@@ -65,61 +61,88 @@ PARAM_FLAG("volume_regularization", "This flag gives the used the option to use"
     "penalize low volume leaves.", "R");
 */
 
-// Some flags for output of some information about the tree.
-PARAM_FLAG("print_tree", "Print the tree out on the command line (or in the "
-    "file specified with --tree_file).", "p");
-PARAM_FLAG("print_vi", "Print the variable importance of each feature out on "
-    "the command line (or in the file specified with --vi_file).", "I");
-
 int main(int argc, char *argv[])
 {
   CLI::ParseCommandLine(argc, argv);
 
-  string trainSetFile = CLI::GetParam<string>("train_file");
-  arma::Mat<double> trainingData;
+  // Validate input parameters.
+  if (CLI::HasParam("training_file") && CLI::HasParam("input_model_file"))
+    Log::Fatal << "Only one of --training_file (-t) or --input_model_file (-m) "
+        << "may be specified!" << endl;
 
-  data::Load(trainSetFile, trainingData, true);
+  if (!CLI::HasParam("training_file") && !CLI::HasParam("input_model_file"))
+    Log::Fatal << "Neither --training_file (-t) nor --input_model_file (-m) "
+        << "are specified!" << endl;
 
-  // Cross-validation here.
-  size_t folds = CLI::GetParam<int>("folds");
-  if (folds == 0)
+  if (!CLI::HasParam("training_file"))
   {
-    folds = trainingData.n_cols;
-    Log::Info << "Performing leave-one-out cross validation." << endl;
-  }
-  else
-  {
-    Log::Info << "Performing " << folds << "-fold cross validation." << endl;
+    if (CLI::HasParam("training_set_estimates_file"))
+      Log::Warn << "--training_set_estimates_file (-e) ignored because "
+          << "--training_file (-t) is not specified." << endl;
+    if (CLI::HasParam("folds"))
+      Log::Warn << "--folds (-f) ignored because --training_file (-t) is not "
+          << "specified." << endl;
+    if (CLI::HasParam("min_leaf_size"))
+      Log::Warn << "--min_leaf_size (-l) ignored because --training_file (-t) "
+          << "is not specified." << endl;
+    if (CLI::HasParam("max_leaf_size"))
+      Log::Warn << "--max_leaf_size (-L) ignored because --training_file (-t) "
+          << "is not specified." << endl;
   }
 
-  const string unprunedTreeEstimateFile =
-      CLI::GetParam<string>("unpruned_tree_estimates_file");
-  const bool regularization = false;
-//  const bool regularization = CLI::HasParam("volume_regularization");
-  const int maxLeafSize = CLI::GetParam<int>("max_leaf_size");
-  const int minLeafSize = CLI::GetParam<int>("min_leaf_size");
+  if (!CLI::HasParam("test_file") && CLI::HasParam("test_set_estimates_file"))
+    Log::Warn << "--test_set_estimates_file (-E) ignored because --test_file "
+        << "(-T) is not specified." << endl;
 
-  // Obtain the optimal tree.
-  Timer::Start("det_training");
-  DTree *dtreeOpt = Trainer(trainingData, folds, regularization, maxLeafSize,
-      minLeafSize, unprunedTreeEstimateFile);
-  Timer::Stop("det_training");
+  // Are we training a DET or loading from file?
+  DTree* tree;
+  if (CLI::HasParam("training_file"))
+  {
+    const string trainSetFile = CLI::GetParam<string>("training_file");
+    arma::mat trainingData;
+    data::Load(trainSetFile, trainingData, true);
 
-  // Compute densities for the training points in the optimal tree.
-  FILE *fp = NULL;
+    // Cross-validation here.
+    size_t folds = CLI::GetParam<int>("folds");
+    if (folds == 0)
+    {
+      folds = trainingData.n_cols;
+      Log::Info << "Performing leave-one-out cross validation." << endl;
+    }
+    else
+    {
+      Log::Info << "Performing " << folds << "-fold cross validation." << endl;
+    }
 
-  if (CLI::GetParam<string>("training_set_estimates_file") != "")
-  {
-    fp = fopen(CLI::GetParam<string>("training_set_estimates_file").c_str(),
-        "w");
+    const bool regularization = false;
+//    const bool regularization = CLI::HasParam("volume_regularization");
+    const int maxLeafSize = CLI::GetParam<int>("max_leaf_size");
+    const int minLeafSize = CLI::GetParam<int>("min_leaf_size");
 
-    // Compute density estimates for each point in the training set.
-    Timer::Start("det_estimation_time");
-    for (size_t i = 0; i < trainingData.n_cols; i++)
-      fprintf(fp, "%lg\n", dtreeOpt->ComputeValue(trainingData.unsafe_col(i)));
-    Timer::Stop("det_estimation_time");
+    // Obtain the optimal tree.
+    Timer::Start("det_training");
+    tree = Trainer(trainingData, folds, regularization, maxLeafSize,
+        minLeafSize, "");
+    Timer::Stop("det_training");
 
-    fclose(fp);
+    // Compute training set estimates, if desired.
+    if (CLI::GetParam<string>("training_set_estimates_file") != "")
+    {
+      // Compute density estimates for each point in the training set.
+      arma::rowvec trainingDensities(trainingData.n_cols);
+      Timer::Start("det_estimation_time");
+      for (size_t i = 0; i < trainingData.n_cols; i++)
+        trainingDensities[i] = tree->ComputeValue(trainingData.unsafe_col(i));
+      Timer::Stop("det_estimation_time");
+
+      data::Save(CLI::GetParam<string>("training_set_estimates_file"),
+          trainingDensities);
+    }
+  }
+  else
+  {
+    data::Load(CLI::GetParam<string>("input_model_file"), "det_model", tree,
+        true);
   }
 
   // Compute the density at the provided test points and output the density in
@@ -130,72 +153,26 @@ int main(int argc, char *argv[])
     arma::mat testData;
     data::Load(testFile, testData, true);
 
-    fp = NULL;
+    // Compute test set densities.
+    Timer::Start("det_test_set_estimation");
+    arma::rowvec testDensities(testData.n_cols);
+    for (size_t i = 0; i < testData.n_cols; i++)
+      testDensities[i] = tree->ComputeValue(testData.unsafe_col(i));
+    Timer::Stop("det_test_set_estimation");
 
     if (CLI::GetParam<string>("test_set_estimates_file") != "")
-    {
-      fp = fopen(CLI::GetParam<string>("test_set_estimates_file").c_str(), "w");
-
-      Timer::Start("det_test_set_estimation");
-      for (size_t i = 0; i < testData.n_cols; i++)
-        fprintf(fp, "%lg\n", dtreeOpt->ComputeValue(testData.unsafe_col(i)));
-      Timer::Stop("det_test_set_estimation");
-
-      fclose(fp);
-    }
-  }
-
-  // Print the final tree.
-  if (CLI::HasParam("print_tree"))
-  {
-    fp = NULL;
-    if (CLI::GetParam<string>("tree_file") != "")
-    {
-      fp = fopen(CLI::GetParam<string>("tree_file").c_str(), "w");
-
-      if (fp != NULL)
-      {
-        dtreeOpt->WriteTree(fp);
-        fclose(fp);
-      }
-    }
-    else
-    {
-      dtreeOpt->WriteTree(stdout);
-      printf("\n");
-    }
-  }
-
-  // Print the leaf memberships for the optimal tree.
-  if (CLI::GetParam<string>("labels_file") != "")
-  {
-    std::string labelsFile = CLI::GetParam<string>("labels_file");
-    arma::Mat<size_t> labels;
-
-    data::Load(labelsFile, labels, true);
-
-    size_t numClasses = 0;
-    for (size_t i = 0; i < labels.n_elem; ++i)
-    {
-      if (labels[i] > numClasses)
-        numClasses = labels[i];
-    }
-
-    Log::Info << numClasses << " found in labels file '" << labelsFile << "'."
-        << std::endl;
-
-    Log::Assert(trainingData.n_cols == labels.n_cols);
-    Log::Assert(labels.n_rows == 1);
-
-    PrintLeafMembership(dtreeOpt, trainingData, labels, numClasses,
-       CLI::GetParam<string>("leaf_class_table_file"));
+      data::Save(CLI::GetParam<string>("test_set_estimates_file"),
+          testDensities);
   }
 
   // Print variable importance.
-  if (CLI::HasParam("print_vi"))
-  {
-    PrintVariableImportance(dtreeOpt, CLI::GetParam<string>("vi_file"));
-  }
+  if (CLI::HasParam("vi_file"))
+    PrintVariableImportance(tree, CLI::GetParam<string>("vi_file"));
+
+  // Save the model, if desired.
+  if (CLI::HasParam("output_model_file"))
+    data::Save(CLI::GetParam<string>("output_model_file"), "det_model", tree,
+        false);
 
-  delete dtreeOpt;
+  delete tree;
 }



More information about the mlpack-git mailing list