[mlpack-git] master: First pass at batch training. (b841b13)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:45:49 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125

>---------------------------------------------------------------

commit b841b13d0512fa55dbc69609b717c1f5f3a5a39a
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sun Nov 1 17:47:25 2015 +0000

    First pass at batch training.


>---------------------------------------------------------------

b841b13d0512fa55dbc69609b717c1f5f3a5a39a
 .../hoeffding_trees/hoeffding_tree_impl.hpp        | 46 +++++++++++++++++++++-
 1 file changed, 45 insertions(+), 1 deletion(-)

diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp
index badd55a..4c802df 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp
@@ -186,7 +186,51 @@ void HoeffdingTree<
   // Not yet implemented.
   if (batchTraining)
   {
-    throw std::invalid_argument("batch training not yet implemented");
+    // Pass all the points through the nodes, and then split only after that.
+    checkInterval = data.n_cols; // Only split on the last sample.
+    for (size_t i = 0; i < data.n_cols; ++i)
+      Train(data.col(i), labels[i]);
+
+    // Now, if we did split, find out which points go to which child, and
+    // perform the same batch training.
+    if (children.size() > 0)
+    {
+      // We need to create a vector of indices that represent the points that
+      // must go to each child, so we need children.size() vectors, but we don't
+      // know how long they will be.  Therefore, we will create vectors each of
+      // size data.n_cols, but will probably not use all the memory we
+      // allocated, and then pass subvectors to the submat() function.
+      std::vector<arma::uvec> indices(children.size(), arma::uvec(data.n_cols));
+      arma::Col<size_t> counts =
+          arma::zeros<arma::Col<size_t>>(children.size());
+
+      for (size_t i = 0; i < data.n_cols; ++i)
+      {
+        size_t direction = CalculateDirection(data.col(i));
+        size_t currentIndex = counts[direction];
+        indices[direction][currentIndex] = i;
+        counts[direction]++;
+      }
+
+      // Now pass each of these submatrices to the children to perform
+      // batch-mode training.
+      for (size_t i = 0; i < children.size(); ++i)
+      {
+        // The submatrix here is non-contiguous, but I think this will be faster
+        // than copying the points to an ordered state.  We still have to
+        // assemble the labels vector, though.
+        arma::Row<size_t> childLabels = labels.cols(
+            indices[i].subvec(0, counts[i] - 1));
+
+        // Unfortunately, limitations of Armadillo's non-contiguous subviews
+        // prohibit us from passing the non-contiguous subview directly to
+        // Train(), since the col() function is not provided.  So, instead,
+        // we'll just extract the non-contiguous submatrix into its own
+        // matrix.
+        MatType childData = data.cols(indices[i].subvec(0, counts[i] - 1));
+        children[i].Train(childData, childLabels, true);
+      }
+    }
   }
   else
   {



More information about the mlpack-git mailing list