[mlpack-svn] r13269 - mlpack/trunk/src/mlpack/methods/det

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Jul 20 15:16:12 EDT 2012


Author: rcurtin
Date: 2012-07-20 15:16:12 -0400 (Fri, 20 Jul 2012)
New Revision: 13269

Modified:
   mlpack/trunk/src/mlpack/methods/det/dt_main.cpp
   mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp
Log:
Clean up main executable file.


Modified: mlpack/trunk/src/mlpack/methods/det/dt_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dt_main.cpp	2012-07-20 18:53:40 UTC (rev 13268)
+++ mlpack/trunk/src/mlpack/methods/det/dt_main.cpp	2012-07-20 19:16:12 UTC (rev 13269)
@@ -13,66 +13,52 @@
 using namespace std;
 
 PROGRAM_INFO("Density estimation with DET", "This program provides an example "
-    "use of the Density Estimation "
-	     "Tree for density estimation. For more details, "
-	     "please look at the paper titled "
-	     "'Density Estimation Trees'.");
+    "use of the Density Estimation Tree for density estimation. For more "
+    "details, please look at the paper titled 'Density Estimation Trees'.");
 
-// input data files
-PARAM_STRING_REQ("input/training_set", "The data set on which to "
-		 "perform density estimation.", "S");
-PARAM_STRING("input/test_set", "An extra set of test points on "
-	     "which to estimate the density given the estimator.",
-	     "T", "");
+// Input data files.
+PARAM_STRING_REQ("input/training_set", "The data set on which to perform "
+    "density estimation.", "S");
+PARAM_STRING("input/test_set", "An extra set of test points on which to "
+    "estimate the density given the estimator.", "T", "");
 PARAM_STRING("input/labels", "The labels for the given training data to "
-	     "generate the class membership of each leaf (as an "
-	     "extra statistic)", "L", "");
+    "generate the class membership of each leaf (as an extra statistic)", "L",
+    "");
 
-// output data files
-PARAM_STRING("output/unpruned_tree_estimates", "The file "
-	     "in which to output the estimates on the "
-	     "training set from the large unpruned tree.", "u", "");
-PARAM_STRING("output/training_set_estimates", "The file "
-	     "in which to output the estimates on the "
-	     "training set from the final optimally pruned"
-	     " tree.", "s", "");
-PARAM_STRING("output/test_set_estimates", "The file "
-	     "in which to output the estimates on the "
-	     "test set from the final optimally pruned"
-	     " tree.", "t", "");
-PARAM_STRING("output/leaf_class_table", "The file "
-	     "in which to output the leaf class membership "
-	     "table.", "l", "leaf_class_membership.txt");
-PARAM_STRING("output/tree", "The file in which to print "
-	     "the final optimally pruned tree.", "p", "");
-PARAM_STRING("output/vi", "The file to output the "
-	     "variable importance values for each feature.",
-	     "i", "");
+// Output data files.
+PARAM_STRING("output/unpruned_tree_estimates", "The file in which to output the"
+    " estimates on the training set from the large unpruned tree.", "u", "");
+PARAM_STRING("output/training_set_estimates", "The file in which to output the "
+    "estimates on the training set from the final optimally pruned tree.", "s",
+    "");
+PARAM_STRING("output/test_set_estimates", "The file in which to output the "
+    "estimates on the test set from the final optimally pruned tree.", "t", "");
+PARAM_STRING("output/leaf_class_table", "The file in which to output the leaf "
+    "class membership table.", "l", "leaf_class_membership.txt");
+PARAM_STRING("output/tree", "The file in which to print the final optimally "
+    "pruned tree.", "p", "");
+PARAM_STRING("output/vi", "The file to output the variable importance values "
+    "for each feature.", "i", "");
 
-// parameters for the algorithm
-PARAM_INT("param/number_of_classes", "The number of classes present "
-	  "in the 'labels' set provided", "C", 0);
-PARAM_INT("param/folds", "The number of folds of cross-validation"
-	  " to performed for the estimation (enter 0 for LOOCV)",
-	  "F", 10);
-PARAM_INT("DET/min_leaf_size", "The minimum size of a leaf"
-	  " in the unpruned fully grown DET.", "N", 5);
-PARAM_INT("DET/max_leaf_size", "The maximum size of a leaf"
-	  " in the unpruned fully grown DET.", "M", 10);
-PARAM_FLAG("DET/use_volume_reg", "This flag gives the used the "
-	   "option to use a form of regularization similar to "
-	   "the usual alpha-pruning in decision tree. But "
-	   "instead of regularizing on the number of leaves, "
-	   "you regularize on the sum of the inverse of the volume "
-	   "of the leaves (meaning you penalize  "
-	   "low volume leaves.", "R");
+// Parameters for the algorithm.
+PARAM_INT("param/number_of_classes", "The number of classes present in the "
+    "'labels' set provided", "C", 0);
+PARAM_INT("param/folds", "The number of folds of cross-validation to perform "
+    "for the estimation (enter 0 for LOOCV)", "F", 10);
+PARAM_INT("DET/min_leaf_size", "The minimum size of a leaf in the unpruned "
+    "fully grown DET.", "N", 5);
+PARAM_INT("DET/max_leaf_size", "The maximum size of a leaf in the unpruned "
+    "fully grown DET.", "M", 10);
+PARAM_FLAG("DET/use_volume_reg", "This flag gives the used the option to use a "
+    "form of regularization similar to the usual alpha-pruning in decision "
+    "tree. But instead of regularizing on the number of leaves, you regularize "
+    "on the sum of the inverse of the volume of the leaves (meaning you "
+    "penalize low volume leaves.", "R");
 
-// some flags for output of some information about the tree
-PARAM_FLAG("flag/print_tree", "If you just wish to print the tree "
-	   "out on the command line.", "P");
-PARAM_FLAG("flag/print_vi", "If you just wish to print the "
-	   "variable importance of each feature "
-	   "out on the command line.", "I");
+// Some flags for output of some information about the tree.
+PARAM_FLAG("flag/print_tree", "Print the tree out on the command line.", "P");
+PARAM_FLAG("flag/print_vi", "Print the variable importance of each feature "
+    "out on the command line.", "I");
 
 int main(int argc, char *argv[])
 {
@@ -81,147 +67,136 @@
   string train_set_file = CLI::GetParam<string>("S");
   arma::Mat<double> training_data;
 
-  Log::Info << "Loading training set..." << endl;
-  if (!data::Load(train_set_file, training_data))
-    Log::Fatal << "Training set file "<< train_set_file
-	       << " can't be loaded." << endl;
+  data::Load(train_set_file, training_data, true);
 
-  Log::Info << "Training set (" << training_data.n_rows
-	    << ", " << training_data.n_cols
-	    << ")" << endl;
-
-  // cross-validation here
+  // Cross-validation here.
   size_t folds = CLI::GetParam<int>("F");
-  if (folds == 0) {
+  if (folds == 0)
+  {
     folds = training_data.n_cols;
-    Log::Info << "Starting Leave-One-Out Cross validation" << endl;
-  } else
-    Log::Info << "Starting " << folds
-	      << "-fold Cross validation" << endl;
+    Log::Info << "Performing leave-one-out cross validation." << endl;
+  }
+  else
+  {
+    Log::Info << "Performing " << folds << "-fold cross validation." << endl;
+  }
 
+  const string unpruned_tree_estimate_file =
+      CLI::GetParam<string>("output/unpruned_tree_estimates");
+  const bool regularization = CLI::HasParam("DET/use_volume_reg");
+  const int maxLeafSize = CLI::GetParam<int>("DET/max_leaf_size");
+  const int minLeafSize = CLI::GetParam<int>("DET/min_leaf_size");
 
+  // Obtain the optimal tree.
+  Timer::Start("det_training");
+  DTree<double> *dtree_opt = Trainer<double>(&training_data, folds,
+      regularization, maxLeafSize, minLeafSize, unpruned_tree_estimate_file);
+  Timer::Stop("det_training");
 
-  // obtaining the optimal tree
-  string unpruned_tree_estimate_file
-    = CLI::GetParam<string>("u");
-
-  Timer::Start("DET/Training");
-  DTree<double> *dtree_opt = Trainer<double>
-    (&training_data, folds, CLI::HasParam("R"), CLI::GetParam<int>("M"),
-     CLI::GetParam<int>("N"), unpruned_tree_estimate_file);
-  Timer::Stop("DET/Training");
-
-  // computing densities for the train points in the
-  // optimal tree
+  // Compute densities for the training points in the optimal tree.
   FILE *fp = NULL;
 
-  if (CLI::GetParam<string>("s") != "") {
-    string optimal_estimates_file = CLI::GetParam<string>("s");
-    fp = fopen(optimal_estimates_file.c_str(), "w");
+  if (CLI::GetParam<string>("output/training_set_estimates") != "")
+  {
+    fp = fopen(CLI::GetParam<string>("output/training_set_estimates").c_str(),
+        "w");
   }
 
-  // Computation timing is more accurate when you do not
-  // perform the printing.
-  Timer::Start("DET/EstimationTime");
-  for (size_t i = 0; i < training_data.n_cols; i++) {
-    arma::Col<double> test_p = training_data.unsafe_col(i);
-    long double f = dtree_opt->ComputeValue(test_p);
+  // Computation timing is more accurate when printing is not performed.
+  Timer::Start("det_estimation_time");
+  for (size_t i = 0; i < training_data.n_cols; i++)
+  {
+    arma::vec test_p = training_data.unsafe_col(i);
+    double f = dtree_opt->ComputeValue(test_p);
+
     if (fp != NULL)
-      fprintf(fp, "%Lg\n", f);
-  } // end for
-  Timer::Stop("DET/EstimationTime");
+      fprintf(fp, "%lg\n", f);
+  }
+  Timer::Stop("det_estimation_time");
 
   if (fp != NULL)
     fclose(fp);
 
+  // Compute the density at the provided test points and output the density in
+  // the given file.
+  if (CLI::GetParam<string>("input/test_set") != "")
+  {
+    const string test_file = CLI::GetParam<string>("input/test_set");
+    arma::mat test_data;
+    data::Load(test_file, test_data, true);
 
-  // computing the density at the provided test points
-  // and outputting the density in the given file.
-  if (CLI::GetParam<string>("T") != "") {
-    string test_file = CLI::GetParam<string>("T");
-    arma::Mat<double> test_data;
-    Log::Info << "Loading test set..." << endl;
-    if (!data::Load(test_file, test_data))
-      Log::Fatal << "Test set file "<< test_file
-		 << " can't be loaded." << endl;
-
-    Log::Info << "Test set (" << test_data.n_rows
-	      << ", " << test_data.n_cols
-	      << ")" << endl;
-
     fp = NULL;
 
-    if (CLI::GetParam<string>("t") != "") {
-      string test_density_file
-	= CLI::GetParam<string>("t");
-      fp = fopen(test_density_file.c_str(), "w");
+    if (CLI::GetParam<string>("output/test_set_estimates") != "")
+    {
+      fp = fopen(CLI::GetParam<string>("output/test_set_estimates").c_str(),
+          "w");
     }
 
-    Timer::Start("DET/TestSetEstimation");
-    for (size_t i = 0; i < test_data.n_cols; i++) {
-      arma::Col<double> test_p = test_data.unsafe_col(i);
-      long double f = dtree_opt->ComputeValue(test_p);
+    Timer::Start("det_test_set_estimation");
+    for (size_t i = 0; i < test_data.n_cols; i++)
+    {
+      arma::vec test_p = test_data.unsafe_col(i);
+      double f = dtree_opt->ComputeValue(test_p);
+
       if (fp != NULL)
-	fprintf(fp, "%Lg\n", f);
-    } // end for
-    Timer::Stop("DET/TestSetEstimation");
+        fprintf(fp, "%lg\n", f);
+    }
+    Timer::Stop("det_test_set_estimation");
 
     if (fp != NULL)
       fclose(fp);
-  } // Test set estimation
+  }
 
-  // printing the final tree
-  if (CLI::HasParam("P")) {
-
+  // Print the final tree.
+  if (CLI::HasParam("flag/print_tree"))
+  {
     fp = NULL;
-    if (CLI::GetParam<string>("p") != "") {
-      string print_tree_file = CLI::GetParam<string>("p");
-      fp = fopen(print_tree_file.c_str(), "w");
+    if (CLI::GetParam<string>("output/tree") != "")
+    {
+      fp = fopen(CLI::GetParam<string>("output/tree").c_str(), "w");
 
-      if (fp != NULL) {
-	dtree_opt->WriteTree(0, fp);
-	fclose(fp);
+      if (fp != NULL)
+      {
+        dtree_opt->WriteTree(0, fp);
+        fclose(fp);
       }
-    } else {
+    }
+    else
+    {
       dtree_opt->WriteTree(0, stdout);
       printf("\n");
     }
-  } // Printing the tree
+  }
 
-  // print the leaf memberships for the optimal tree
-  if (CLI::GetParam<string>("L") != "") {
-    std::string labels_file = CLI::GetParam<string>("L");
-    arma::Mat<int> labels;
+  // Print the leaf memberships for the optimal tree.
+  if (CLI::GetParam<string>("input/labels") != "")
+  {
+    std::string labels_file = CLI::GetParam<string>("input/labels");
+    arma::Mat<size_t> labels;
 
-    Log::Info << "Loading label file..." << endl;
-    if (!data::Load(labels_file, labels))
-      Log::Fatal << "Label file "<< labels_file
-		 << " can't be loaded." << endl;
+    data::Load(labels_file, labels, true);
 
-    Log::Info << "Labels (" << labels.n_rows
-	      << ", " << labels.n_cols
-	      << ")" << endl;
-
-    size_t num_classes = CLI::GetParam<int>("C");
+    size_t num_classes = CLI::GetParam<int>("param/number_of_classes");
     if (num_classes == 0)
-      Log::Fatal << "Please provide the number of classes"
-		 << " present in the label file" << endl;
+    {
+      Log::Fatal << "Number of classes (param/number_of_classes) not specified!"
+          << endl;
+    }
 
-    assert(training_data.n_cols == labels.n_cols);
-    assert(labels.n_rows == 1);
+    Log::Assert(training_data.n_cols == labels.n_cols);
+    Log::Assert(labels.n_rows == 1);
 
-    PrintLeafMembership<double>
-      (dtree_opt, training_data, labels, num_classes,
-       (string) CLI::GetParam<string>("l"));
-  } // leaf class membership
+    PrintLeafMembership<double>(dtree_opt, training_data, labels, num_classes,
+       CLI::GetParam<string>("output/leaf_class_table"));
+  }
 
-
-  if (CLI::HasParam("I"))
+  // Print variable importance.
+  if (CLI::HasParam("flag/print_vi"))
   {
-    PrintVariableImportance<double>(dtree_opt, CLI::GetParam<string>("i"));
-  } // print variable importance
+    PrintVariableImportance<double>(dtree_opt,
+        CLI::GetParam<string>("output/vi"));
+  }
 
-
   delete dtree_opt;
-  return 0;
-} // end main
+}

Modified: mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp	2012-07-20 18:53:40 UTC (rev 13268)
+++ mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp	2012-07-20 19:16:12 UTC (rev 13269)
@@ -19,7 +19,7 @@
 template<typename eT>
 void PrintLeafMembership(DTree<eT> *dtree,
                          const arma::Mat<eT>& data,
-                         const arma::Mat<int>& labels,
+                         const arma::Mat<size_t>& labels,
                          size_t num_classes,
                          string leaf_class_membership_file = "")
 {
@@ -33,7 +33,7 @@
   {
     arma::Col<eT> test_p = data.unsafe_col(i);
     int leaf_tag = dtree->FindBucket(test_p);
-    int label = labels[i];
+    size_t label = labels[i];
     table(leaf_tag, label) += 1;
   }
 




More information about the mlpack-svn mailing list