[mlpack-svn] r13219 - mlpack/trunk/src/mlpack/methods/det

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Jul 12 11:16:14 EDT 2012


Author: rcurtin
Date: 2012-07-12 11:16:13 -0400 (Thu, 12 Jul 2012)
New Revision: 13219

Modified:
   mlpack/trunk/src/mlpack/methods/det/dt_main.cpp
Log:
Switch from floats to doubles in the main executable so that it compiles.


Modified: mlpack/trunk/src/mlpack/methods/det/dt_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dt_main.cpp	2012-07-12 14:52:45 UTC (rev 13218)
+++ mlpack/trunk/src/mlpack/methods/det/dt_main.cpp	2012-07-12 15:16:13 UTC (rev 13219)
@@ -12,8 +12,8 @@
 using namespace mlpack::det;
 using namespace std;
 
-PROGRAM_INFO("Density estimation with DET", "This program "
-             "provides an example use of the Density Estimation "
+PROGRAM_INFO("Density estimation with DET", "This program provides an example "
+    "use of the Density Estimation "
 	     "Tree for density estimation. For more details, "
 	     "please look at the paper titled "
 	     "'Density Estimation Trees'.");
@@ -74,19 +74,19 @@
 	   "variable importance of each feature "
 	   "out on the command line.", "I");
 
-int main(int argc, char *argv[]) 
+int main(int argc, char *argv[])
 {
   CLI::ParseCommandLine(argc, argv);
 
   string train_set_file = CLI::GetParam<string>("S");
-  arma::Mat<float> training_data;
+  arma::Mat<double> training_data;
 
   Log::Info << "Loading training set..." << endl;
   if (!data::Load(train_set_file, training_data))
-    Log::Fatal << "Training set file "<< train_set_file 
+    Log::Fatal << "Training set file "<< train_set_file
 	       << " can't be loaded." << endl;
 
-  Log::Info << "Training set (" << training_data.n_rows 
+  Log::Info << "Training set (" << training_data.n_rows
 	    << ", " << training_data.n_cols
 	    << ")" << endl;
 
@@ -95,22 +95,22 @@
   if (folds == 0) {
     folds = training_data.n_cols;
     Log::Info << "Starting Leave-One-Out Cross validation" << endl;
-  } else 
-    Log::Info << "Starting " << folds 
+  } else
+    Log::Info << "Starting " << folds
 	      << "-fold Cross validation" << endl;
 
 
- 
+
   // obtaining the optimal tree
-  string unpruned_tree_estimate_file 
+  string unpruned_tree_estimate_file
     = CLI::GetParam<string>("u");
 
   Timer::Start("DET/Training");
-  DTree<float> *dtree_opt = Trainer<float>
+  DTree<double> *dtree_opt = Trainer<double>
     (&training_data, folds, CLI::HasParam("R"), CLI::GetParam<int>("M"),
      CLI::GetParam<int>("N"), unpruned_tree_estimate_file);
   Timer::Stop("DET/Training");
- 
+
   // computing densities for the train points in the
   // optimal tree
   FILE *fp = NULL;
@@ -118,13 +118,13 @@
   if (CLI::GetParam<string>("s") != "") {
     string optimal_estimates_file = CLI::GetParam<string>("s");
     fp = fopen(optimal_estimates_file.c_str(), "w");
-  }   
+  }
 
-  // Computation timing is more accurate when you do not 
+  // Computation timing is more accurate when you do not
   // perform the printing.
   Timer::Start("DET/EstimationTime");
   for (size_t i = 0; i < training_data.n_cols; i++) {
-    arma::Col<float> test_p = training_data.unsafe_col(i);
+    arma::Col<double> test_p = training_data.unsafe_col(i);
     long double f = dtree_opt->ComputeValue(&test_p);
     if (fp != NULL)
       fprintf(fp, "%Lg\n", f);
@@ -133,19 +133,19 @@
 
   if (fp != NULL)
     fclose(fp);
-  
 
+
   // computing the density at the provided test points
   // and outputting the density in the given file.
   if (CLI::GetParam<string>("T") != "") {
     string test_file = CLI::GetParam<string>("T");
-    arma::Mat<float> test_data;
+    arma::Mat<double> test_data;
     Log::Info << "Loading test set..." << endl;
     if (!data::Load(test_file, test_data))
-      Log::Fatal << "Test set file "<< test_file 
+      Log::Fatal << "Test set file "<< test_file
 		 << " can't be loaded." << endl;
 
-    Log::Info << "Test set (" << test_data.n_rows 
+    Log::Info << "Test set (" << test_data.n_rows
 	      << ", " << test_data.n_cols
 	      << ")" << endl;
 
@@ -159,14 +159,14 @@
 
     Timer::Start("DET/TestSetEstimation");
     for (size_t i = 0; i < test_data.n_cols; i++) {
-      arma::Col<float> test_p = test_data.unsafe_col(i);
+      arma::Col<double> test_p = test_data.unsafe_col(i);
       long double f = dtree_opt->ComputeValue(&test_p);
       if (fp != NULL)
 	fprintf(fp, "%Lg\n", f);
     } // end for
     Timer::Stop("DET/TestSetEstimation");
 
-    if (fp != NULL) 
+    if (fp != NULL)
       fclose(fp);
   } // Test set estimation
 
@@ -195,13 +195,13 @@
 
     Log::Info << "Loading label file..." << endl;
     if (!data::Load(labels_file, labels))
-      Log::Fatal << "Label file "<< labels_file 
+      Log::Fatal << "Label file "<< labels_file
 		 << " can't be loaded." << endl;
 
-    Log::Info << "Labels (" << labels.n_rows 
+    Log::Info << "Labels (" << labels.n_rows
 	      << ", " << labels.n_cols
 	      << ")" << endl;
-    
+
     size_t num_classes = CLI::GetParam<int>("C");
     if (num_classes == 0)
       Log::Fatal << "Please provide the number of classes"
@@ -210,14 +210,14 @@
     assert(training_data.n_cols == labels.n_cols);
     assert(labels.n_rows == 1);
 
-    PrintLeafMembership<float>
+    PrintLeafMembership<double>
       (dtree_opt, training_data, labels, num_classes,
        (string) CLI::GetParam<string>("l"));
   } // leaf class membership
-  
 
+
   if(CLI::HasParam("I")) {
-    PrintVariableImportance<float>
+    PrintVariableImportance<double>
       (dtree_opt, training_data.n_rows,
        (string) CLI::GetParam<string>("i"));
   } // print variable importance




More information about the mlpack-svn mailing list