[mlpack-svn] r13219 - mlpack/trunk/src/mlpack/methods/det
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Jul 12 11:16:14 EDT 2012
Author: rcurtin
Date: 2012-07-12 11:16:13 -0400 (Thu, 12 Jul 2012)
New Revision: 13219
Modified:
mlpack/trunk/src/mlpack/methods/det/dt_main.cpp
Log:
Switch to doubles instead of floats in the main executable so that it compiles.
Modified: mlpack/trunk/src/mlpack/methods/det/dt_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dt_main.cpp 2012-07-12 14:52:45 UTC (rev 13218)
+++ mlpack/trunk/src/mlpack/methods/det/dt_main.cpp 2012-07-12 15:16:13 UTC (rev 13219)
@@ -12,8 +12,8 @@
using namespace mlpack::det;
using namespace std;
-PROGRAM_INFO("Density estimation with DET", "This program "
- "provides an example use of the Density Estimation "
+PROGRAM_INFO("Density estimation with DET", "This program provides an example "
+ "use of the Density Estimation "
"Tree for density estimation. For more details, "
"please look at the paper titled "
"'Density Estimation Trees'.");
@@ -74,19 +74,19 @@
"variable importance of each feature "
"out on the command line.", "I");
-int main(int argc, char *argv[])
+int main(int argc, char *argv[])
{
CLI::ParseCommandLine(argc, argv);
string train_set_file = CLI::GetParam<string>("S");
- arma::Mat<float> training_data;
+ arma::Mat<double> training_data;
Log::Info << "Loading training set..." << endl;
if (!data::Load(train_set_file, training_data))
- Log::Fatal << "Training set file "<< train_set_file
+ Log::Fatal << "Training set file "<< train_set_file
<< " can't be loaded." << endl;
- Log::Info << "Training set (" << training_data.n_rows
+ Log::Info << "Training set (" << training_data.n_rows
<< ", " << training_data.n_cols
<< ")" << endl;
@@ -95,22 +95,22 @@
if (folds == 0) {
folds = training_data.n_cols;
Log::Info << "Starting Leave-One-Out Cross validation" << endl;
- } else
- Log::Info << "Starting " << folds
+ } else
+ Log::Info << "Starting " << folds
<< "-fold Cross validation" << endl;
-
+
// obtaining the optimal tree
- string unpruned_tree_estimate_file
+ string unpruned_tree_estimate_file
= CLI::GetParam<string>("u");
Timer::Start("DET/Training");
- DTree<float> *dtree_opt = Trainer<float>
+ DTree<double> *dtree_opt = Trainer<double>
(&training_data, folds, CLI::HasParam("R"), CLI::GetParam<int>("M"),
CLI::GetParam<int>("N"), unpruned_tree_estimate_file);
Timer::Stop("DET/Training");
-
+
// computing densities for the train points in the
// optimal tree
FILE *fp = NULL;
@@ -118,13 +118,13 @@
if (CLI::GetParam<string>("s") != "") {
string optimal_estimates_file = CLI::GetParam<string>("s");
fp = fopen(optimal_estimates_file.c_str(), "w");
- }
+ }
- // Computation timing is more accurate when you do not
+ // Computation timing is more accurate when you do not
// perform the printing.
Timer::Start("DET/EstimationTime");
for (size_t i = 0; i < training_data.n_cols; i++) {
- arma::Col<float> test_p = training_data.unsafe_col(i);
+ arma::Col<double> test_p = training_data.unsafe_col(i);
long double f = dtree_opt->ComputeValue(&test_p);
if (fp != NULL)
fprintf(fp, "%Lg\n", f);
@@ -133,19 +133,19 @@
if (fp != NULL)
fclose(fp);
-
+
// computing the density at the provided test points
// and outputting the density in the given file.
if (CLI::GetParam<string>("T") != "") {
string test_file = CLI::GetParam<string>("T");
- arma::Mat<float> test_data;
+ arma::Mat<double> test_data;
Log::Info << "Loading test set..." << endl;
if (!data::Load(test_file, test_data))
- Log::Fatal << "Test set file "<< test_file
+ Log::Fatal << "Test set file "<< test_file
<< " can't be loaded." << endl;
- Log::Info << "Test set (" << test_data.n_rows
+ Log::Info << "Test set (" << test_data.n_rows
<< ", " << test_data.n_cols
<< ")" << endl;
@@ -159,14 +159,14 @@
Timer::Start("DET/TestSetEstimation");
for (size_t i = 0; i < test_data.n_cols; i++) {
- arma::Col<float> test_p = test_data.unsafe_col(i);
+ arma::Col<double> test_p = test_data.unsafe_col(i);
long double f = dtree_opt->ComputeValue(&test_p);
if (fp != NULL)
fprintf(fp, "%Lg\n", f);
} // end for
Timer::Stop("DET/TestSetEstimation");
- if (fp != NULL)
+ if (fp != NULL)
fclose(fp);
} // Test set estimation
@@ -195,13 +195,13 @@
Log::Info << "Loading label file..." << endl;
if (!data::Load(labels_file, labels))
- Log::Fatal << "Label file "<< labels_file
+ Log::Fatal << "Label file "<< labels_file
<< " can't be loaded." << endl;
- Log::Info << "Labels (" << labels.n_rows
+ Log::Info << "Labels (" << labels.n_rows
<< ", " << labels.n_cols
<< ")" << endl;
-
+
size_t num_classes = CLI::GetParam<int>("C");
if (num_classes == 0)
Log::Fatal << "Please provide the number of classes"
@@ -210,14 +210,14 @@
assert(training_data.n_cols == labels.n_cols);
assert(labels.n_rows == 1);
- PrintLeafMembership<float>
+ PrintLeafMembership<double>
(dtree_opt, training_data, labels, num_classes,
(string) CLI::GetParam<string>("l"));
} // leaf class membership
-
+
if(CLI::HasParam("I")) {
- PrintVariableImportance<float>
+ PrintVariableImportance<double>
(dtree_opt, training_data.n_rows,
(string) CLI::GetParam<string>("i"));
} // print variable importance
More information about the mlpack-svn
mailing list