[mlpack-git] master: Set majority class when making children. (845218e)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:43:05 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125
>---------------------------------------------------------------
commit 845218e03a6eff39019ddf590cfb070262d4354d
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Sep 30 10:58:35 2015 -0400
Set majority class when making children.
Revert all manner of debugging output / changes.
>---------------------------------------------------------------
845218e03a6eff39019ddf590cfb070262d4354d
src/mlpack/core/data/load_impl.hpp | 4 ++--
.../hoeffding_trees/hoeffding_numeric_split_impl.hpp | 7 +++++++
src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp | 3 +++
.../methods/hoeffding_trees/hoeffding_split_impl.hpp | 3 ---
.../methods/hoeffding_trees/streaming_decision_tree.hpp | 8 ++++++--
.../hoeffding_trees/streaming_decision_tree_impl.hpp | 10 ++++++----
.../hoeffding_trees/streaming_decision_tree_main.cpp | 16 ++++++++--------
7 files changed, 32 insertions(+), 19 deletions(-)
diff --git a/src/mlpack/core/data/load_impl.hpp b/src/mlpack/core/data/load_impl.hpp
index 5497063..4f953b5 100644
--- a/src/mlpack/core/data/load_impl.hpp
+++ b/src/mlpack/core/data/load_impl.hpp
@@ -369,11 +369,11 @@ bool Load(const std::string& filename,
eT val = eT(0);
token >> val;
-// if (token.fail())
+ if (token.fail())
{
// Conversion failed; but it may be a NaN or inf. Armadillo has
// convenient functions to check.
-// if (!arma::diskio::convert_naninf(val, token.str()))
+ if (!arma::diskio::convert_naninf(val, token.str()))
{
// We need to perform a mapping.
const size_t dim = (transpose) ? col : row;
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp
index d55e99b..81269b7 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp
@@ -103,9 +103,16 @@ void HoeffdingNumericSplit<FitnessFunction, ObservationType>::CreateChildren(
{
// We'll make one child for each bin.
for (size_t i = 0; i < sufficientStatistics.n_cols; ++i)
+ {
+ // We need to set the majority class for the child, too.
children.push_back(StreamingDecisionTreeType(datasetInfo, dimensionality,
sufficientStatistics.n_rows));
+ arma::uword majorityClass;
+ sufficientStatistics.col(i).max(majorityClass);
+ children[i].MajorityClass() = majorityClass;
+ }
+
// Create the SplitInfo object.
splitInfo = SplitInfo(splitPoints);
}
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp
index 83a6d14..0a9bfcd 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp
@@ -38,6 +38,9 @@ class HoeffdingSplit
//! Get the splitting dimension (size_t(-1) if no split).
size_t SplitDimension() const { return splitDimension; }
+ //! Modify the majority class.
+ size_t& MajorityClass() { return majorityClass; }
+
// Return index that we should go towards.
template<typename VecType>
size_t CalculateDirection(const VecType& point) const;
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
index cb007b1..aad4ef2 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
@@ -112,7 +112,6 @@ size_t HoeffdingSplit<
{
size_t type = dimensionMappings[i].first;
size_t index = dimensionMappings[i].second;
- Log::Warn << "Evaluate fitness function for dimension " << i << ".\n";
if (type == data::Datatype::categorical)
gains[i] = categoricalSplits[index].EvaluateFitnessFunction();
else if (type == data::Datatype::numeric)
@@ -137,8 +136,6 @@ size_t HoeffdingSplit<
}
}
- Log::Warn << "Split check (" << numSamples << "): largest " << largest << ", "
- << "second largest " << secondLargest << ", epsilon " << epsilon << ".\n";
// Are these far enough apart to split?
if (largest - secondLargest > epsilon || numSamples > maxSamples)
{
diff --git a/src/mlpack/methods/hoeffding_trees/streaming_decision_tree.hpp b/src/mlpack/methods/hoeffding_trees/streaming_decision_tree.hpp
index 3d34738..ff60354 100644
--- a/src/mlpack/methods/hoeffding_trees/streaming_decision_tree.hpp
+++ b/src/mlpack/methods/hoeffding_trees/streaming_decision_tree.hpp
@@ -22,11 +22,13 @@ class StreamingDecisionTree
StreamingDecisionTree(const MatType& data,
const data::DatasetInfo& datasetInfo,
const arma::Row<size_t>& labels,
- const size_t numClasses);
+ const size_t numClasses,
+ const double confidence = 0.95);
StreamingDecisionTree(const data::DatasetInfo& datasetInfo,
const size_t dimensionality,
- const size_t numClasses);
+ const size_t numClasses,
+ const double confidence = 0.95);
StreamingDecisionTree(const StreamingDecisionTree& other);
@@ -47,6 +49,8 @@ class StreamingDecisionTree
void Classify(const MatType& data, arma::Row<size_t>& predictions);
+ size_t& MajorityClass() { return split.MajorityClass(); }
+
// How do we encode the actual split itself?
// that's just a split dimension and a rule (categorical or numeric)
diff --git a/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_impl.hpp b/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_impl.hpp
index 99c2d8e..4026927 100644
--- a/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_impl.hpp
@@ -18,8 +18,9 @@ StreamingDecisionTree<SplitType, MatType>::StreamingDecisionTree(
const MatType& data,
const data::DatasetInfo& datasetInfo,
const arma::Row<size_t>& labels,
- const size_t numClasses) :
- split(data.n_rows, numClasses, datasetInfo, 0.95, 5000)
+ const size_t numClasses,
+ const double confidence) :
+ split(data.n_rows, numClasses, datasetInfo, confidence, 1500)
{
Train(data, labels);
}
@@ -28,8 +29,9 @@ template<typename SplitType, typename MatType>
StreamingDecisionTree<SplitType, MatType>::StreamingDecisionTree(
const data::DatasetInfo& datasetInfo,
const size_t dimensionality,
- const size_t numClasses) :
- split(dimensionality, numClasses, datasetInfo, 0.95, 5000)
+ const size_t numClasses,
+ const double confidence) :
+ split(dimensionality, numClasses, datasetInfo, confidence, 1500)
{
// No training. Anything else to do...?
}
diff --git a/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_main.cpp b/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_main.cpp
index c541218..4c2ad7e 100644
--- a/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_main.cpp
+++ b/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_main.cpp
@@ -27,6 +27,7 @@ int main(int argc, char** argv)
const string trainingFile = CLI::GetParam<string>("training_file");
const string labelsFile = CLI::GetParam<string>("labels_file");
+ const double confidence = CLI::GetParam<double>("confidence");
arma::mat trainingSet;
DatasetInfo datasetInfo;
@@ -41,26 +42,22 @@ int main(int argc, char** argv)
// Now create the decision tree.
StreamingDecisionTree<HoeffdingSplit<>> tree(trainingSet, datasetInfo, labels,
- max(labels) + 1);
+ max(labels) + 1, confidence);
// Great. Good job team.
std::stack<StreamingDecisionTree<HoeffdingSplit<>>*> stack;
stack.push(&tree);
+ size_t nodes = 0;
while (!stack.empty())
{
StreamingDecisionTree<HoeffdingSplit<>>* node = stack.top();
stack.pop();
-
- Log::Info << "Node:\n";
- Log::Info << " split dimension " << node->Split().SplitDimension()
- << ".\n";
- Log::Info << " majority class " << node->Split().Classify(arma::vec())
- << ".\n";
- Log::Info << " children " << node->NumChildren() << ".\n";
+ ++nodes;
for (size_t i = 0; i < node->NumChildren(); ++i)
stack.push(&node->Child(i));
}
+ Log::Info << nodes << " nodes in tree.\n";
// Check the accuracy on the training set.
arma::Row<size_t> predictedLabels;
@@ -70,6 +67,9 @@ int main(int argc, char** argv)
for (size_t i = 0; i < predictedLabels.n_elem; ++i)
if (labels[i] == predictedLabels[i])
++correct;
+ else if (predictedLabels[i] > 10)
+ Log::Warn << "Invalid label " << predictedLabels[i] << " for point " << i
+ << "!\n";
Log::Info << correct << " correct out of " << predictedLabels.n_elem << ".\n";
}
More information about the mlpack-git mailing list