[mlpack-git] master: Add the number of passes to the program. (7bd2144)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:46:22 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125

>---------------------------------------------------------------

commit 7bd21440b9dceaa3baf9b1e5063532a483784859
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Nov 18 11:20:27 2015 -0800

    Add the number of passes to the program.


>---------------------------------------------------------------

7bd21440b9dceaa3baf9b1e5063532a483784859
 .../methods/hoeffding_trees/hoeffding_tree_main.cpp     | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp b/src/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp
index 86f22cf..e644c2b 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp
@@ -38,6 +38,7 @@ PARAM_FLAG("batch_mode", "If true, samples will be considered in batch instead "
     " memory usage and runtime.", "b");
 PARAM_FLAG("info_gain", "If set, information gain is used instead of Gini "
     "impurity for calculating Hoeffding bounds.", "i");
+PARAM_INT("passes", "Number of passes to take over the dataset.", "p", 1);
 
 // Helper function for once we have chosen a tree type.
 template<typename TreeType>
@@ -73,6 +74,10 @@ int main(int argc, char** argv)
   if (trainingFile.empty() && CLI::HasParam("batch_mode"))
     Log::Warn << "--batch_mode (-b) ignored; no training set provided." << endl;
 
+  if (CLI::HasParam("passes") && CLI::HasParam("batch_mode"))
+    Log::Warn << "--batch_mode (-b) ignored because --passes was specified."
+        << endl;
+
   if (CLI::HasParam("info_gain"))
   {
     if (numericSplitStrategy == "domingos")
@@ -115,6 +120,7 @@ void PerformActions()
   const string predictionsFile = CLI::GetParam<string>("predictions_file");
   const string probabilitiesFile = CLI::GetParam<string>("probabilities_file");
   const bool batchTraining = CLI::HasParam("batch_mode");
+  const size_t passes = (size_t) CLI::GetParam<int>("passes");
 
   TreeType* tree = NULL;
   DatasetInfo datasetInfo;
@@ -155,7 +161,16 @@ void PerformActions()
 
       // Now create the decision tree.
       Timer::Start("tree_training");
-      tree->Train(trainingSet, labels, batchTraining);
+      if (passes > 1)
+      {
+        Log::Info << "Taking " << passes << " passes over the dataset." << endl;
+        for (size_t i = 0; i < passes; ++i)
+          tree->Train(trainingSet, labels, false);
+      }
+      else
+      {
+        tree->Train(trainingSet, labels, batchTraining);
+      }
       Timer::Stop("tree_training");
     }
   }



More information about the mlpack-git mailing list