[mlpack-git] master: Add functions that allow changing training-time parameters. (ea908de)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 30 11:46:41 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eb30b1775fa57bb1c0519be0e5faf0ccfbb2e336...ea908deb6ae205b99ae8ba063b716c1bd726babd

>---------------------------------------------------------------

commit ea908deb6ae205b99ae8ba063b716c1bd726babd
Author: ryan <ryan at ratml.org>
Date:   Wed Dec 30 11:45:32 2015 -0500

    Add functions that allow changing training-time parameters.


>---------------------------------------------------------------

ea908deb6ae205b99ae8ba063b716c1bd726babd
 .../methods/hoeffding_trees/hoeffding_tree.hpp     |  20 ++++
 .../hoeffding_trees/hoeffding_tree_impl.hpp        |  64 ++++++++++++
 src/mlpack/tests/hoeffding_tree_test.cpp           | 111 +++++++++++++++++++++
 3 files changed, 195 insertions(+)

diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp
index ddf3d1b..f8f8a0c 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_tree.hpp
@@ -192,6 +192,26 @@ class HoeffdingTree
   //! Modify a child.
   HoeffdingTree& Child(const size_t i) { return *children[i]; }
 
+  //! Get the confidence required for a split, as stored in this node.
+  double SuccessProbability() const { return successProbability; }
+  //! Set the confidence required for a split on this node and all descendants.
+  void SuccessProbability(const double successProbability);
+
+  //! Get the minimum number of samples for a split, as stored in this node.
+  size_t MinSamples() const { return minSamples; }
+  //! Set the minimum number of samples for a split, recursing into children.
+  void MinSamples(const size_t minSamples);
+
+  //! Get the maximum number of samples before a split is forced.
+  size_t MaxSamples() const { return maxSamples; }
+  //! Set the maximum number of samples before a forced split, recursively.
+  void MaxSamples(const size_t maxSamples);
+
+  //! Get the number of samples before a split check is performed.
+  size_t CheckInterval() const { return checkInterval; }
+  //! Set the number of samples between split checks, recursively.
+  void CheckInterval(const size_t checkInterval);
+
   /**
    * Given a point and that this node is not a leaf, calculate the index of the
    * child node this point would go towards.  This method is primarily used by
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp
index babbc33..3a6f52b 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_tree_impl.hpp
@@ -414,6 +414,70 @@ template<
     template<typename> class NumericSplitType,
     template<typename> class CategoricalSplitType
 >
+void HoeffdingTree<FitnessFunction, NumericSplitType, CategoricalSplitType>::
+    SuccessProbability(const double successProbability)
+{
+  // The argument shadows the member, so qualify the member with this->.
+  this->successProbability = successProbability;
+  // Push the same value down so every descendant agrees with this node.
+  const size_t numChildren = children.size();
+  for (size_t child = 0; child < numChildren; ++child)
+    children[child]->SuccessProbability(successProbability);
+}
+
+// Set minSamples on this node, then recursively on each of its children.
+template<typename FitnessFunction,
+         template<typename> class NumericSplitType,
+         template<typename> class CategoricalSplitType>
+void HoeffdingTree<FitnessFunction, NumericSplitType, CategoricalSplitType>::
+    MinSamples(const size_t minSamples)
+{
+  // The argument shadows the member, so qualify the member with this->.
+  this->minSamples = minSamples;
+
+  // Push the same value down so every descendant agrees with this node.
+  const size_t numChildren = children.size();
+  for (size_t child = 0; child < numChildren; ++child)
+    children[child]->MinSamples(minSamples);
+}
+
+// Set maxSamples on this node, then recursively on each of its children.
+template<typename FitnessFunction,
+         template<typename> class NumericSplitType,
+         template<typename> class CategoricalSplitType>
+void HoeffdingTree<FitnessFunction, NumericSplitType, CategoricalSplitType>::
+    MaxSamples(const size_t maxSamples)
+{
+  // The argument shadows the member, so qualify the member with this->.
+  this->maxSamples = maxSamples;
+
+  // Push the same value down so every descendant agrees with this node.
+  const size_t numChildren = children.size();
+  for (size_t child = 0; child < numChildren; ++child)
+    children[child]->MaxSamples(maxSamples);
+}
+
+// Set checkInterval on this node, then recursively on each of its children.
+template<typename FitnessFunction,
+         template<typename> class NumericSplitType,
+         template<typename> class CategoricalSplitType>
+void HoeffdingTree<FitnessFunction, NumericSplitType, CategoricalSplitType>::
+    CheckInterval(const size_t checkInterval)
+{
+  // The argument shadows the member, so qualify the member with this->.
+  this->checkInterval = checkInterval;
+
+  // Push the same value down so every descendant agrees with this node.
+  const size_t numChildren = children.size();
+  for (size_t child = 0; child < numChildren; ++child)
+    children[child]->CheckInterval(checkInterval);
+}
+
+template<
+    typename FitnessFunction,
+    template<typename> class NumericSplitType,
+    template<typename> class CategoricalSplitType
+>
 template<typename VecType>
 size_t HoeffdingTree<
     FitnessFunction,
diff --git a/src/mlpack/tests/hoeffding_tree_test.cpp b/src/mlpack/tests/hoeffding_tree_test.cpp
index cd0fe2f..ef7c146 100644
--- a/src/mlpack/tests/hoeffding_tree_test.cpp
+++ b/src/mlpack/tests/hoeffding_tree_test.cpp
@@ -983,4 +983,115 @@ BOOST_AUTO_TEST_CASE(BatchTrainingTest)
   BOOST_REQUIRE_GT(batchCorrect, streamCorrect);
 }
 
+// Make sure that changing the confidence properly propagates to all leaves.
+BOOST_AUTO_TEST_CASE(ConfidenceChangeTest)
+{
+  // Generate three interleaved classes that differ in dimensions 1 and 2.
+  arma::mat dataset(4, 9000);
+  arma::Row<size_t> labels(9000);
+  data::DatasetInfo info(4); // All features are numeric, except the fourth.
+  info.MapString("0", 3);
+  for (size_t i = 0; i < 9000; i += 3)
+  {
+    dataset(0, i) = mlpack::math::Random();
+    dataset(1, i) = mlpack::math::Random();
+    dataset(2, i) = mlpack::math::Random();
+    dataset(3, i) = 0.0;
+    labels[i] = 0;
+
+    dataset(0, i + 1) = mlpack::math::Random();
+    dataset(1, i + 1) = mlpack::math::Random() - 1.0;
+    dataset(2, i + 1) = mlpack::math::Random() + 0.5;
+    dataset(3, i + 1) = 0.0;
+    labels[i + 1] = 2;
+
+    dataset(0, i + 2) = mlpack::math::Random();
+    dataset(1, i + 2) = mlpack::math::Random() + 1.0;
+    dataset(2, i + 2) = mlpack::math::Random() + 0.8;
+    dataset(3, i + 2) = 0.0;
+    labels[i + 2] = 1;
+  }
+
+  HoeffdingTree<> tree(info, 3, 0.5); // Low success probability.
+
+  size_t i = 0;
+  while ((tree.NumChildren() == 0) && (i < 9000))
+  {
+    tree.Train(dataset.col(i), labels[i]);
+    i++;
+  }
+
+  BOOST_REQUIRE_LT(i, 9000); // The root must have split by now.
+
+  // Now we have split the root node, but we need to make sure we can feed
+  // through many more points (ten passes over the data) while requiring a
+  // confidence of 1.0, and make sure no splits happen.
+  // NOTE(review): assumes MaxSamples(0) disables forced splits; confirm.
+  tree.SuccessProbability(1.0);
+  tree.MaxSamples(0);
+
+  // Train unconditionally: the root already has children here, so a
+  // NumChildren() == 0 guard would skip every point and test nothing.
+  for (i = 0; i < 90000; ++i)
+    tree.Train(dataset.col(i % 9000), labels[i % 9000]);
+
+  // None of the root's children may have split.
+  for (size_t c = 0; c < tree.NumChildren(); ++c)
+    BOOST_REQUIRE_EQUAL(tree.Child(c).NumChildren(), 0);
+}
+
+//! Check that every node of a trained tree picks up changed parameters.
+BOOST_AUTO_TEST_CASE(ParameterChangeTest)
+{
+  // Build three interleaved classes that differ in dimensions 1 and 2.
+  arma::mat dataset(4, 9000);
+  arma::Row<size_t> labels(9000);
+  data::DatasetInfo info(4); // All features are numeric, except the fourth.
+  info.MapString("0", 3);
+  for (size_t col = 0; col < 9000; col += 3)
+  {
+    dataset(0, col) = mlpack::math::Random();
+    dataset(1, col) = mlpack::math::Random();
+    dataset(2, col) = mlpack::math::Random();
+    dataset(3, col) = 0.0;
+    labels[col] = 0;
+
+    dataset(0, col + 1) = mlpack::math::Random();
+    dataset(1, col + 1) = mlpack::math::Random() - 1.0;
+    dataset(2, col + 1) = mlpack::math::Random() + 0.5;
+    dataset(3, col + 1) = 0.0;
+    labels[col + 1] = 2;
+
+    dataset(0, col + 2) = mlpack::math::Random();
+    dataset(1, col + 2) = mlpack::math::Random() + 1.0;
+    dataset(2, col + 2) = mlpack::math::Random() + 0.8;
+    dataset(3, col + 2) = 0.0;
+    labels[col + 2] = 1;
+  }
+
+  HoeffdingTree<> tree(dataset, info, labels, 3, true); // Batch training.
+
+  // Set all four parameters on the root only; the setters must recurse.
+  tree.SuccessProbability(0.7);
+  tree.MinSamples(17);
+  tree.MaxSamples(192);
+  tree.CheckInterval(3);
+
+  std::stack<HoeffdingTree<>*> pending;
+  pending.push(&tree);
+  while (!pending.empty())
+  {
+    HoeffdingTree<>* current = pending.top();
+    pending.pop();
+    // Each visited node must report exactly the values set above.
+    BOOST_REQUIRE_CLOSE(current->SuccessProbability(), 0.7, 1e-5);
+    BOOST_REQUIRE_EQUAL(current->MinSamples(), 17);
+    BOOST_REQUIRE_EQUAL(current->MaxSamples(), 192);
+    BOOST_REQUIRE_EQUAL(current->CheckInterval(), 3);
+
+    for (size_t c = 0; c < current->NumChildren(); ++c)
+      pending.push(&current->Child(c));
+  }
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list