[mlpack-git] master: Refactor so that dimensionality is a parameter. (0e795f3)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:42:44 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125

>---------------------------------------------------------------

commit 0e795f33823dfd40c7a2b767ae99d7b348ab6070
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Sep 23 17:34:01 2015 -0400

    Refactor so that dimensionality is a parameter.
    
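    This is still an ugly set of abstractions.  Maybe DatasetInfo should contain the
    dimensionality, but I am not yet sure.

    Concretely, CreateChildren() now takes the dimensionality as an argument
    instead of hard-coding it to 3 when constructing the child trees;
    HoeffdingSplit passes numericSplits.size() + categoricalSplits.size().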


>---------------------------------------------------------------

0e795f33823dfd40c7a2b767ae99d7b348ab6070
 src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp  | 1 +
 .../methods/hoeffding_trees/hoeffding_categorical_split_impl.hpp    | 3 ++-
 src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split.hpp      | 1 +
 src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp         | 6 ++++--
 src/mlpack/tests/hoeffding_tree_test.cpp                            | 2 +-
 5 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp
index d41374d..2615ae0 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp
@@ -50,6 +50,7 @@ class HoeffdingCategoricalSplit
   template<typename StreamingDecisionTreeType>
   void CreateChildren(std::vector<StreamingDecisionTreeType>& children,
                       const data::DatasetInfo& datasetInfo,
+                      const size_t dimensionality,
                       SplitInfo& splitInfo);
 
   size_t MajorityClass() const;
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split_impl.hpp
index b05f8df..32d2375 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split_impl.hpp
@@ -44,11 +44,12 @@ template<typename StreamingDecisionTreeType>
 void HoeffdingCategoricalSplit<FitnessFunction>::CreateChildren(
     std::vector<StreamingDecisionTreeType>& children,
     const data::DatasetInfo& datasetInfo,
+    const size_t dimensionality,
     SplitInfo& splitInfo)
 {
   // We'll make one child for each category.
   for (size_t i = 0; i < sufficientStatistics.n_cols; ++i)
-    children.push_back(StreamingDecisionTreeType(datasetInfo, 3,
+    children.push_back(StreamingDecisionTreeType(datasetInfo, dimensionality,
         sufficientStatistics.n_rows));
 
   // Create the according SplitInfo object.
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split.hpp
index 78a46d9..bda7744 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split.hpp
@@ -30,6 +30,7 @@ class HoeffdingNumericSplit
   template<typename StreamingDecisionTreeType>
   void CreateChildren(std::vector<StreamingDecisionTreeType>& children,
                       const data::DatasetInfo& datasetInfo,
+                      const size_t dimensionality,
                       SplitInfo& splitInfo) { } // Nothing to do.
 
   size_t MajorityClass() const { return 0; } // Nothing yet.
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
index 2387b51..6e98760 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
@@ -199,15 +199,17 @@ void HoeffdingSplit<
       ++categoricalSplitIndex;
   }
 
+  // Create the children.
   if (datasetInfo.Type(splitDimension) == data::Datatype::numeric)
   {
     numericSplits[numericSplitIndex].CreateChildren(children, datasetInfo,
-        numericSplit);
+        numericSplits.size() + categoricalSplits.size(), numericSplit);
   }
   else if (datasetInfo.Type(splitDimension) == data::Datatype::categorical)
   {
     categoricalSplits[categoricalSplitIndex].CreateChildren(children,
-        datasetInfo, categoricalSplit);
+        datasetInfo, numericSplits.size() + categoricalSplits.size(),
+        categoricalSplit);
   }
 }
 
diff --git a/src/mlpack/tests/hoeffding_tree_test.cpp b/src/mlpack/tests/hoeffding_tree_test.cpp
index e7f5609..d34f945 100644
--- a/src/mlpack/tests/hoeffding_tree_test.cpp
+++ b/src/mlpack/tests/hoeffding_tree_test.cpp
@@ -212,7 +212,7 @@ BOOST_AUTO_TEST_CASE(HoeffdingCategoricalSplitSplitTest)
   HoeffdingCategoricalSplit<GiniImpurity>::SplitInfo splitInfo(3);
 
   // Create the children.
-  split.CreateChildren(children, info, splitInfo);
+  split.CreateChildren(children, info, 3, splitInfo);
 
   BOOST_REQUIRE_EQUAL(children.size(), 3);
   BOOST_REQUIRE_EQUAL(splitInfo.CalculateDirection(0), 0);



More information about the mlpack-git mailing list