[mlpack-git] master: Add batch training option. (e7294b8)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:45:53 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125

>---------------------------------------------------------------

commit e7294b8c1fe7b78792de7afc8a5d2d945ffd819d
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Nov 2 13:37:50 2015 +0000

    Add batch training option.


>---------------------------------------------------------------

e7294b8c1fe7b78792de7afc8a5d2d945ffd819d
 src/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp b/src/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp
index f09e4d5..32e465e 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_tree_main.cpp
@@ -32,6 +32,9 @@ PARAM_STRING("probabilities_file", "In addition to predicting labels, provide "
 
 PARAM_STRING("numeric_split_strategy", "The splitting strategy to use for "
     "numeric features: 'domingos' or 'binary'.", "N", "binary");
+PARAM_FLAG("batch_mode", "If true, samples will be considered in batch instead "
+    "of as a stream.  This generally results in better trees but at the cost of"
+    " memory usage and runtime.", "b");
 
 // Helper function for once we have chosen a tree type.
 template<typename TreeType>
@@ -64,6 +67,9 @@ int main(int argc, char** argv)
     Log::Fatal << "If --training_file is specified, --labels_file must be "
         << "specified too!" << endl;
 
+  if (trainingFile.empty() && CLI::HasParam("batch_mode"))
+    Log::Warn << "--batch_mode (-b) ignored; no training set provided." << endl;
+
   if (numericSplitStrategy == "domingos")
     PerformActions<HoeffdingTree<GiniImpurity, HoeffdingDoubleNumericSplit,
         HoeffdingCategoricalSplit>>();
@@ -89,6 +95,7 @@ void PerformActions()
   const string testFile = CLI::GetParam<string>("test_file");
   const string predictionsFile = CLI::GetParam<string>("predictions_file");
   const string probabilitiesFile = CLI::GetParam<string>("probabilities_file");
+  const bool batchTraining = CLI::HasParam("batch_mode");
 
   TreeType* tree = NULL;
   DatasetInfo datasetInfo;
@@ -107,7 +114,7 @@ void PerformActions()
     // Now create the decision tree.
     Timer::Start("tree_training");
     tree = new TreeType(trainingSet, datasetInfo, labels, max(labels) + 1,
-        confidence, maxSamples);
+        batchTraining, confidence, maxSamples);
     Timer::Stop("tree_training");
   }
   else
@@ -129,7 +136,7 @@ void PerformActions()
 
       // Now create the decision tree.
       Timer::Start("tree_training");
-      tree->Train(trainingSet, labels);
+      tree->Train(trainingSet, labels, batchTraining);
       Timer::Stop("tree_training");
     }
   }



More information about the mlpack-git mailing list