[mlpack-git] master: First pass at serialization. (71aca20)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:43:17 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125
>---------------------------------------------------------------
commit 71aca2084f18243b73644943a092b6c86a5cd30d
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Sep 30 16:34:52 2015 -0400
First pass at serialization.
>---------------------------------------------------------------
71aca2084f18243b73644943a092b6c86a5cd30d
.../hoeffding_trees/categorical_split_info.hpp | 4 ++
.../hoeffding_categorical_split.hpp | 7 +++
.../hoeffding_trees/hoeffding_numeric_split.hpp | 3 ++
.../hoeffding_numeric_split_impl.hpp | 41 ++++++++++++++
.../methods/hoeffding_trees/hoeffding_split.hpp | 4 ++
.../hoeffding_trees/hoeffding_split_impl.hpp | 62 ++++++++++++++++++++++
.../methods/hoeffding_trees/numeric_split_info.hpp | 7 +++
.../streaming_decision_tree_main.cpp | 4 ++
8 files changed, 132 insertions(+)
diff --git a/src/mlpack/methods/hoeffding_trees/categorical_split_info.hpp b/src/mlpack/methods/hoeffding_trees/categorical_split_info.hpp
index 21a2927..73d29ab 100644
--- a/src/mlpack/methods/hoeffding_trees/categorical_split_info.hpp
+++ b/src/mlpack/methods/hoeffding_trees/categorical_split_info.hpp
@@ -24,6 +24,10 @@ class CategoricalSplitInfo
// range [0, categories).
return size_t(value);
}
+
+  /**
+   * Serialization hook.  Nothing needs to be saved for this object, so both
+   * arguments are ignored and the body is empty.
+   */
+  template<typename Archive>
+  void Serialize(Archive& /* ar */, const unsigned int /* version */)
+  {
+    // Intentionally empty.
+  }
};
} // namespace tree
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp
index 6c2edcf..22aeb5a 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp
@@ -51,6 +51,13 @@ class HoeffdingCategoricalSplit
size_t MajorityClass() const;
+  /**
+   * Save or load the categorical split through the given archive.  The only
+   * member processed here is the sufficientStatistics matrix.
+   */
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */)
+  {
+    using data::CreateNVP;
+
+    ar & CreateNVP(sufficientStatistics, "sufficientStatistics");
+  }
+
private:
arma::Mat<size_t> sufficientStatistics;
};
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split.hpp
index e4ba4f0..d99045e 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split.hpp
@@ -63,6 +63,9 @@ class HoeffdingNumericSplit
size_t Bins() const { return bins; }
+ //! Serialize the numeric split.
+ template<typename Archive>
+ void Serialize(Archive& ar, const unsigned int /* version */);
+
private:
// Cache the values of the points seen before we make bins.
arma::Col<ObservationType> observations;
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp
index f6a1284..0e18cd9 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp
@@ -139,6 +139,47 @@ size_t HoeffdingNumericSplit<FitnessFunction, ObservationType>::
}
}
+// Save or load the numeric split through the given archive.  The three
+// counters are always processed first; which of the remaining members are
+// processed depends on whether samplesSeen exceeds observationsBeforeBinning.
+template<typename FitnessFunction, typename ObservationType>
+template<typename Archive>
+void HoeffdingNumericSplit<FitnessFunction, ObservationType>::Serialize(
+    Archive& ar,
+    const unsigned int /* version */)
+{
+  using data::CreateNVP;
+
+  ar & CreateNVP(samplesSeen, "samplesSeen");
+  ar & CreateNVP(observationsBeforeBinning, "observationsBeforeBinning");
+  ar & CreateNVP(bins, "bins");
+
+  // Evaluate the state only after the counters above have been processed:
+  // when loading, their values come from the archive.
+  const bool loading = Archive::is_loading::value;
+  const bool binned = (samplesSeen > observationsBeforeBinning);
+
+  if (!binned)
+  {
+    // Binning has not happened yet, so only the cached observations and
+    // their labels are stored.
+    ar & CreateNVP(observations, "observations");
+    ar & CreateNVP(labels, "labels");
+
+    if (loading)
+    {
+      // Drop the members that are not part of the pre-binning state.
+      splitPoints.clear();
+      sufficientStatistics.clear();
+    }
+
+    return;
+  }
+
+  // Binning has happened, so only the resulting bins are stored.
+  ar & CreateNVP(splitPoints, "splitPoints");
+  ar & CreateNVP(sufficientStatistics, "sufficientStatistics");
+
+  if (loading)
+  {
+    // Drop the members that are not part of the post-binning state.
+    observations.clear();
+    labels.clear();
+  }
+}
+
} // namespace tree
} // namespace mlpack
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp
index 31d352a..5af29ae 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp
@@ -58,6 +58,10 @@ class HoeffdingSplit
template<typename StreamingDecisionTreeType>
void CreateChildren(std::vector<StreamingDecisionTreeType>& children);
+ //! Serialize the split (members stored depend on whether it has split).
+ template<typename Archive>
+ void Serialize(Archive& ar, const unsigned int /* version */);
+
private:
// We need to keep some information for before we have split.
std::vector<NumericSplitType> numericSplits;
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
index e698552..69ed643 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
@@ -304,6 +304,68 @@ void HoeffdingSplit<
categoricalSplits.clear();
}
+// Save or load the split through the given archive.
+//
+// splitDimension and dimensionMappings are always processed.  If
+// splitDimension equals size_t(-1) (no split has been made), the per-dimension
+// candidate splits and the sampling parameters are processed; otherwise only
+// the chosen categorical/numeric split and the majority class are processed.
+// When loading, the members belonging to the other state are reset.
+template<
+    typename FitnessFunction,
+    typename NumericSplitType,
+    typename CategoricalSplitType
+>
+template<typename Archive>
+void HoeffdingSplit<
+    FitnessFunction,
+    NumericSplitType,
+    CategoricalSplitType
+>::Serialize(Archive& ar, const unsigned int /* version */)
+{
+  using data::CreateNVP;
+
+  ar & CreateNVP(splitDimension, "splitDimension");
+  ar & CreateNVP(dimensionMappings, "dimensionMappings");
+  // NOTE(review): ownership of dimensionMappings after loading is still an
+  // open question; for now a loaded object is marked as owning its mappings.
+  if (Archive::is_loading::value)
+    ownsMappings = true;
+
+  // Depending on whether or not we have split yet, we may need to save
+  // different things.
+  if (splitDimension == size_t(-1))
+  {
+    // We have not yet split.  So we have to serialize the splits.
+    ar & CreateNVP(numericSplits, "numericSplits");
+    ar & CreateNVP(categoricalSplits, "categoricalSplits");
+
+    ar & CreateNVP(numSamples, "numSamples");
+    ar & CreateNVP(numClasses, "numClasses");
+    ar & CreateNVP(maxSamples, "maxSamples");
+    ar & CreateNVP(successProbability, "successProbability");
+
+    if (Archive::is_loading::value)
+    {
+      // Clear things we don't need.  SplitInfo is a name that depends on a
+      // template parameter, so `typename` is needed for it to be parsed as a
+      // type in these value-initialization expressions.
+      majorityClass = 0;
+      categoricalSplit = typename CategoricalSplitType::SplitInfo();
+      numericSplit = typename NumericSplitType::SplitInfo();
+    }
+  }
+  else
+  {
+    // We have split, so we only need to cache the numeric and categorical
+    // split.
+    ar & CreateNVP(categoricalSplit, "categoricalSplit");
+    ar & CreateNVP(numericSplit, "numericSplit");
+    ar & CreateNVP(majorityClass, "majorityClass");
+
+    if (Archive::is_loading::value)
+    {
+      // Clear the pre-split state, which is not stored once a split exists.
+      numericSplits.clear();
+      categoricalSplits.clear();
+
+      numSamples = 0;
+      numClasses = 0;
+      maxSamples = 0;
+      successProbability = 0.0;
+    }
+  }
+}
+
} // namespace tree
} // namespace mlpack
diff --git a/src/mlpack/methods/hoeffding_trees/numeric_split_info.hpp b/src/mlpack/methods/hoeffding_trees/numeric_split_info.hpp
index 1d751dc..eaac2db 100644
--- a/src/mlpack/methods/hoeffding_trees/numeric_split_info.hpp
+++ b/src/mlpack/methods/hoeffding_trees/numeric_split_info.hpp
@@ -31,6 +31,13 @@ class NumericSplitInfo
return bin;
}
+  /**
+   * Save or load the split information through the given archive.  The only
+   * member processed here is the splitPoints vector.
+   */
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */)
+  {
+    using data::CreateNVP;
+
+    ar & CreateNVP(splitPoints, "splitPoints");
+  }
+
private:
arma::Col<ObservationType> splitPoints;
};
diff --git a/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_main.cpp b/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_main.cpp
index 0d85848..35a304d 100644
--- a/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_main.cpp
+++ b/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_main.cpp
@@ -46,8 +46,10 @@ int main(int argc, char** argv)
arma::Row<size_t> labels = labelsIn.t();
// Now create the decision tree.
+ Timer::Start("tree_training");
StreamingDecisionTree<HoeffdingSplit<>> tree(trainingSet, datasetInfo, labels,
max(labels) + 1, confidence, maxSamples);
+ Timer::Stop("tree_training");
// Great. Good job team.
std::stack<StreamingDecisionTree<HoeffdingSplit<>>*> stack;
@@ -65,8 +67,10 @@ int main(int argc, char** argv)
Log::Info << nodes << " nodes in tree.\n";
// Check the accuracy on the training set.
+ Timer::Start("tree_testing");
arma::Row<size_t> predictedLabels;
tree.Classify(trainingSet, predictedLabels);
+ Timer::Stop("tree_testing");
size_t correct = 0;
for (size_t i = 0; i < predictedLabels.n_elem; ++i)
More information about the mlpack-git
mailing list