[mlpack-svn] r13269 - mlpack/trunk/src/mlpack/methods/det
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Jul 20 15:16:12 EDT 2012
Author: rcurtin
Date: 2012-07-20 15:16:12 -0400 (Fri, 20 Jul 2012)
New Revision: 13269
Modified:
mlpack/trunk/src/mlpack/methods/det/dt_main.cpp
mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp
Log:
Clean up main executable file.
Modified: mlpack/trunk/src/mlpack/methods/det/dt_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dt_main.cpp 2012-07-20 18:53:40 UTC (rev 13268)
+++ mlpack/trunk/src/mlpack/methods/det/dt_main.cpp 2012-07-20 19:16:12 UTC (rev 13269)
@@ -13,66 +13,52 @@
using namespace std;
PROGRAM_INFO("Density estimation with DET", "This program provides an example "
- "use of the Density Estimation "
- "Tree for density estimation. For more details, "
- "please look at the paper titled "
- "'Density Estimation Trees'.");
+ "use of the Density Estimation Tree for density estimation. For more "
+ "details, please look at the paper titled 'Density Estimation Trees'.");
-// input data files
-PARAM_STRING_REQ("input/training_set", "The data set on which to "
- "perform density estimation.", "S");
-PARAM_STRING("input/test_set", "An extra set of test points on "
- "which to estimate the density given the estimator.",
- "T", "");
+// Input data files.
+PARAM_STRING_REQ("input/training_set", "The data set on which to perform "
+ "density estimation.", "S");
+PARAM_STRING("input/test_set", "An extra set of test points on which to "
+ "estimate the density given the estimator.", "T", "");
PARAM_STRING("input/labels", "The labels for the given training data to "
- "generate the class membership of each leaf (as an "
- "extra statistic)", "L", "");
+ "generate the class membership of each leaf (as an extra statistic)", "L",
+ "");
-// output data files
-PARAM_STRING("output/unpruned_tree_estimates", "The file "
- "in which to output the estimates on the "
- "training set from the large unpruned tree.", "u", "");
-PARAM_STRING("output/training_set_estimates", "The file "
- "in which to output the estimates on the "
- "training set from the final optimally pruned"
- " tree.", "s", "");
-PARAM_STRING("output/test_set_estimates", "The file "
- "in which to output the estimates on the "
- "test set from the final optimally pruned"
- " tree.", "t", "");
-PARAM_STRING("output/leaf_class_table", "The file "
- "in which to output the leaf class membership "
- "table.", "l", "leaf_class_membership.txt");
-PARAM_STRING("output/tree", "The file in which to print "
- "the final optimally pruned tree.", "p", "");
-PARAM_STRING("output/vi", "The file to output the "
- "variable importance values for each feature.",
- "i", "");
+// Output data files.
+PARAM_STRING("output/unpruned_tree_estimates", "The file in which to output the"
+ " estimates on the training set from the large unpruned tree.", "u", "");
+PARAM_STRING("output/training_set_estimates", "The file in which to output the "
+ "estimates on the training set from the final optimally pruned tree.", "s",
+ "");
+PARAM_STRING("output/test_set_estimates", "The file in which to output the "
+ "estimates on the test set from the final optimally pruned tree.", "t", "");
+PARAM_STRING("output/leaf_class_table", "The file in which to output the leaf "
+ "class membership table.", "l", "leaf_class_membership.txt");
+PARAM_STRING("output/tree", "The file in which to print the final optimally "
+ "pruned tree.", "p", "");
+PARAM_STRING("output/vi", "The file to output the variable importance values "
+ "for each feature.", "i", "");
-// parameters for the algorithm
-PARAM_INT("param/number_of_classes", "The number of classes present "
- "in the 'labels' set provided", "C", 0);
-PARAM_INT("param/folds", "The number of folds of cross-validation"
- " to performed for the estimation (enter 0 for LOOCV)",
- "F", 10);
-PARAM_INT("DET/min_leaf_size", "The minimum size of a leaf"
- " in the unpruned fully grown DET.", "N", 5);
-PARAM_INT("DET/max_leaf_size", "The maximum size of a leaf"
- " in the unpruned fully grown DET.", "M", 10);
-PARAM_FLAG("DET/use_volume_reg", "This flag gives the used the "
- "option to use a form of regularization similar to "
- "the usual alpha-pruning in decision tree. But "
- "instead of regularizing on the number of leaves, "
- "you regularize on the sum of the inverse of the volume "
- "of the leaves (meaning you penalize "
- "low volume leaves.", "R");
+// Parameters for the algorithm.
+PARAM_INT("param/number_of_classes", "The number of classes present in the "
+ "'labels' set provided", "C", 0);
+PARAM_INT("param/folds", "The number of folds of cross-validation to perform "
+ "for the estimation (enter 0 for LOOCV)", "F", 10);
+PARAM_INT("DET/min_leaf_size", "The minimum size of a leaf in the unpruned "
+ "fully grown DET.", "N", 5);
+PARAM_INT("DET/max_leaf_size", "The maximum size of a leaf in the unpruned "
+ "fully grown DET.", "M", 10);
+PARAM_FLAG("DET/use_volume_reg", "This flag gives the used the option to use a "
+ "form of regularization similar to the usual alpha-pruning in decision "
+ "tree. But instead of regularizing on the number of leaves, you regularize "
+ "on the sum of the inverse of the volume of the leaves (meaning you "
+ "penalize low volume leaves.", "R");
-// some flags for output of some information about the tree
-PARAM_FLAG("flag/print_tree", "If you just wish to print the tree "
- "out on the command line.", "P");
-PARAM_FLAG("flag/print_vi", "If you just wish to print the "
- "variable importance of each feature "
- "out on the command line.", "I");
+// Some flags for output of some information about the tree.
+PARAM_FLAG("flag/print_tree", "Print the tree out on the command line.", "P");
+PARAM_FLAG("flag/print_vi", "Print the variable importance of each feature "
+ "out on the command line.", "I");
int main(int argc, char *argv[])
{
@@ -81,147 +67,136 @@
string train_set_file = CLI::GetParam<string>("S");
arma::Mat<double> training_data;
- Log::Info << "Loading training set..." << endl;
- if (!data::Load(train_set_file, training_data))
- Log::Fatal << "Training set file "<< train_set_file
- << " can't be loaded." << endl;
+ data::Load(train_set_file, training_data, true);
- Log::Info << "Training set (" << training_data.n_rows
- << ", " << training_data.n_cols
- << ")" << endl;
-
- // cross-validation here
+ // Cross-validation here.
size_t folds = CLI::GetParam<int>("F");
- if (folds == 0) {
+ if (folds == 0)
+ {
folds = training_data.n_cols;
- Log::Info << "Starting Leave-One-Out Cross validation" << endl;
- } else
- Log::Info << "Starting " << folds
- << "-fold Cross validation" << endl;
+ Log::Info << "Performing leave-one-out cross validation." << endl;
+ }
+ else
+ {
+ Log::Info << "Performing " << folds << "-fold cross validation." << endl;
+ }
+ const string unpruned_tree_estimate_file =
+ CLI::GetParam<string>("output/unpruned_tree_estimates");
+ const bool regularization = CLI::HasParam("DET/use_volume_reg");
+ const int maxLeafSize = CLI::GetParam<int>("DET/max_leaf_size");
+ const int minLeafSize = CLI::GetParam<int>("DET/min_leaf_size");
+ // Obtain the optimal tree.
+ Timer::Start("det_training");
+ DTree<double> *dtree_opt = Trainer<double>(&training_data, folds,
+ regularization, maxLeafSize, minLeafSize, unpruned_tree_estimate_file);
+ Timer::Stop("det_training");
- // obtaining the optimal tree
- string unpruned_tree_estimate_file
- = CLI::GetParam<string>("u");
-
- Timer::Start("DET/Training");
- DTree<double> *dtree_opt = Trainer<double>
- (&training_data, folds, CLI::HasParam("R"), CLI::GetParam<int>("M"),
- CLI::GetParam<int>("N"), unpruned_tree_estimate_file);
- Timer::Stop("DET/Training");
-
- // computing densities for the train points in the
- // optimal tree
+ // Compute densities for the training points in the optimal tree.
FILE *fp = NULL;
- if (CLI::GetParam<string>("s") != "") {
- string optimal_estimates_file = CLI::GetParam<string>("s");
- fp = fopen(optimal_estimates_file.c_str(), "w");
+ if (CLI::GetParam<string>("output/training_set_estimates") != "")
+ {
+ fp = fopen(CLI::GetParam<string>("output/training_set_estimates").c_str(),
+ "w");
}
- // Computation timing is more accurate when you do not
- // perform the printing.
- Timer::Start("DET/EstimationTime");
- for (size_t i = 0; i < training_data.n_cols; i++) {
- arma::Col<double> test_p = training_data.unsafe_col(i);
- long double f = dtree_opt->ComputeValue(test_p);
+ // Computation timing is more accurate when printing is not performed.
+ Timer::Start("det_estimation_time");
+ for (size_t i = 0; i < training_data.n_cols; i++)
+ {
+ arma::vec test_p = training_data.unsafe_col(i);
+ double f = dtree_opt->ComputeValue(test_p);
+
if (fp != NULL)
- fprintf(fp, "%Lg\n", f);
- } // end for
- Timer::Stop("DET/EstimationTime");
+ fprintf(fp, "%lg\n", f);
+ }
+ Timer::Stop("det_estimation_time");
if (fp != NULL)
fclose(fp);
+ // Compute the density at the provided test points and output the density in
+ // the given file.
+ if (CLI::GetParam<string>("input/test_set") != "")
+ {
+ const string test_file = CLI::GetParam<string>("input/test_set");
+ arma::mat test_data;
+ data::Load(test_file, test_data, true);
- // computing the density at the provided test points
- // and outputting the density in the given file.
- if (CLI::GetParam<string>("T") != "") {
- string test_file = CLI::GetParam<string>("T");
- arma::Mat<double> test_data;
- Log::Info << "Loading test set..." << endl;
- if (!data::Load(test_file, test_data))
- Log::Fatal << "Test set file "<< test_file
- << " can't be loaded." << endl;
-
- Log::Info << "Test set (" << test_data.n_rows
- << ", " << test_data.n_cols
- << ")" << endl;
-
fp = NULL;
- if (CLI::GetParam<string>("t") != "") {
- string test_density_file
- = CLI::GetParam<string>("t");
- fp = fopen(test_density_file.c_str(), "w");
+ if (CLI::GetParam<string>("output/test_set_estimates") != "")
+ {
+ fp = fopen(CLI::GetParam<string>("output/test_set_estimates").c_str(),
+ "w");
}
- Timer::Start("DET/TestSetEstimation");
- for (size_t i = 0; i < test_data.n_cols; i++) {
- arma::Col<double> test_p = test_data.unsafe_col(i);
- long double f = dtree_opt->ComputeValue(test_p);
+ Timer::Start("det_test_set_estimation");
+ for (size_t i = 0; i < test_data.n_cols; i++)
+ {
+ arma::vec test_p = test_data.unsafe_col(i);
+ double f = dtree_opt->ComputeValue(test_p);
+
if (fp != NULL)
- fprintf(fp, "%Lg\n", f);
- } // end for
- Timer::Stop("DET/TestSetEstimation");
+ fprintf(fp, "%lg\n", f);
+ }
+ Timer::Stop("det_test_set_estimation");
if (fp != NULL)
fclose(fp);
- } // Test set estimation
+ }
- // printing the final tree
- if (CLI::HasParam("P")) {
-
+ // Print the final tree.
+ if (CLI::HasParam("flag/print_tree"))
+ {
fp = NULL;
- if (CLI::GetParam<string>("p") != "") {
- string print_tree_file = CLI::GetParam<string>("p");
- fp = fopen(print_tree_file.c_str(), "w");
+ if (CLI::GetParam<string>("output/tree") != "")
+ {
+ fp = fopen(CLI::GetParam<string>("output/tree").c_str(), "w");
- if (fp != NULL) {
- dtree_opt->WriteTree(0, fp);
- fclose(fp);
+ if (fp != NULL)
+ {
+ dtree_opt->WriteTree(0, fp);
+ fclose(fp);
}
- } else {
+ }
+ else
+ {
dtree_opt->WriteTree(0, stdout);
printf("\n");
}
- } // Printing the tree
+ }
- // print the leaf memberships for the optimal tree
- if (CLI::GetParam<string>("L") != "") {
- std::string labels_file = CLI::GetParam<string>("L");
- arma::Mat<int> labels;
+ // Print the leaf memberships for the optimal tree.
+ if (CLI::GetParam<string>("input/labels") != "")
+ {
+ std::string labels_file = CLI::GetParam<string>("input/labels");
+ arma::Mat<size_t> labels;
- Log::Info << "Loading label file..." << endl;
- if (!data::Load(labels_file, labels))
- Log::Fatal << "Label file "<< labels_file
- << " can't be loaded." << endl;
+ data::Load(labels_file, labels, true);
- Log::Info << "Labels (" << labels.n_rows
- << ", " << labels.n_cols
- << ")" << endl;
-
- size_t num_classes = CLI::GetParam<int>("C");
+ size_t num_classes = CLI::GetParam<int>("param/number_of_classes");
if (num_classes == 0)
- Log::Fatal << "Please provide the number of classes"
- << " present in the label file" << endl;
+ {
+ Log::Fatal << "Number of classes (param/number_of_classes) not specified!"
+ << endl;
+ }
- assert(training_data.n_cols == labels.n_cols);
- assert(labels.n_rows == 1);
+ Log::Assert(training_data.n_cols == labels.n_cols);
+ Log::Assert(labels.n_rows == 1);
- PrintLeafMembership<double>
- (dtree_opt, training_data, labels, num_classes,
- (string) CLI::GetParam<string>("l"));
- } // leaf class membership
+ PrintLeafMembership<double>(dtree_opt, training_data, labels, num_classes,
+ CLI::GetParam<string>("output/leaf_class_table"));
+ }
-
- if (CLI::HasParam("I"))
+ // Print variable importance.
+ if (CLI::HasParam("flag/print_vi"))
{
- PrintVariableImportance<double>(dtree_opt, CLI::GetParam<string>("i"));
- } // print variable importance
+ PrintVariableImportance<double>(dtree_opt,
+ CLI::GetParam<string>("output/vi"));
+ }
-
delete dtree_opt;
- return 0;
-} // end main
+}
Modified: mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp 2012-07-20 18:53:40 UTC (rev 13268)
+++ mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp 2012-07-20 19:16:12 UTC (rev 13269)
@@ -19,7 +19,7 @@
template<typename eT>
void PrintLeafMembership(DTree<eT> *dtree,
const arma::Mat<eT>& data,
- const arma::Mat<int>& labels,
+ const arma::Mat<size_t>& labels,
size_t num_classes,
string leaf_class_membership_file = "")
{
@@ -33,7 +33,7 @@
{
arma::Col<eT> test_p = data.unsafe_col(i);
int leaf_tag = dtree->FindBucket(test_p);
- int label = labels[i];
+ size_t label = labels[i];
table(leaf_tag, label) += 1;
}
More information about the mlpack-svn mailing list