[mlpack-git] master: First pass at serialization. (71aca20)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Dec 23 11:43:17 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/de9cc4b05069e1fa4793d9355f2f595af5ff45d2...6070527af14296cd99739de6c62666cc5d2a2125

>---------------------------------------------------------------

commit 71aca2084f18243b73644943a092b6c86a5cd30d
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Sep 30 16:34:52 2015 -0400

    First pass at serialization.


>---------------------------------------------------------------

71aca2084f18243b73644943a092b6c86a5cd30d
 .../hoeffding_trees/categorical_split_info.hpp     |  4 ++
 .../hoeffding_categorical_split.hpp                |  7 +++
 .../hoeffding_trees/hoeffding_numeric_split.hpp    |  3 ++
 .../hoeffding_numeric_split_impl.hpp               | 41 ++++++++++++++
 .../methods/hoeffding_trees/hoeffding_split.hpp    |  4 ++
 .../hoeffding_trees/hoeffding_split_impl.hpp       | 62 ++++++++++++++++++++++
 .../methods/hoeffding_trees/numeric_split_info.hpp |  7 +++
 .../streaming_decision_tree_main.cpp               |  4 ++
 8 files changed, 132 insertions(+)

diff --git a/src/mlpack/methods/hoeffding_trees/categorical_split_info.hpp b/src/mlpack/methods/hoeffding_trees/categorical_split_info.hpp
index 21a2927..73d29ab 100644
--- a/src/mlpack/methods/hoeffding_trees/categorical_split_info.hpp
+++ b/src/mlpack/methods/hoeffding_trees/categorical_split_info.hpp
@@ -24,6 +24,10 @@ class CategoricalSplitInfo
     // range [0, categories).
     return size_t(value);
   }
+
+  //! Serialize the object.  (Nothing needs to be saved.)
+  template<typename Archive>
+  void Serialize(Archive& /* ar */, const unsigned int /* version */) { }
 };
 
 } // namespace tree
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp
index 6c2edcf..22aeb5a 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_categorical_split.hpp
@@ -51,6 +51,13 @@ class HoeffdingCategoricalSplit
 
   size_t MajorityClass() const;
 
+  //! Serialize the categorical split.
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */)
+  {
+    ar & data::CreateNVP(sufficientStatistics, "sufficientStatistics");
+  }
+
  private:
   arma::Mat<size_t> sufficientStatistics;
 };
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split.hpp
index e4ba4f0..d99045e 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split.hpp
@@ -63,6 +63,9 @@ class HoeffdingNumericSplit
 
   size_t Bins() const { return bins; }
 
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */);
+
  private:
   // Cache the values of the points seen before we make bins.
   arma::Col<ObservationType> observations;
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp
index f6a1284..0e18cd9 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_numeric_split_impl.hpp
@@ -139,6 +139,47 @@ size_t HoeffdingNumericSplit<FitnessFunction, ObservationType>::
   }
 }
 
+template<typename FitnessFunction, typename ObservationType>
+template<typename Archive>
+void HoeffdingNumericSplit<FitnessFunction, ObservationType>::Serialize(
+    Archive& ar,
+    const unsigned int /* version */)
+{
+  using data::CreateNVP;
+
+  ar & CreateNVP(samplesSeen, "samplesSeen");
+  ar & CreateNVP(observationsBeforeBinning, "observationsBeforeBinning");
+  ar & CreateNVP(bins, "bins");
+
+  if (samplesSeen > observationsBeforeBinning)
+  {
+    // The binning has happened, so we only need to save the resulting bins.
+    ar & CreateNVP(splitPoints, "splitPoints");
+    ar & CreateNVP(sufficientStatistics, "sufficientStatistics");
+
+    if (Archive::is_loading::value)
+    {
+      // Clean other objects.
+      observations.clear();
+      labels.clear();
+    }
+  }
+  else
+  {
+    // The binning has not happened yet, so we only need to save the information
+    // required before binning.
+    ar & CreateNVP(observations, "observations");
+    ar & CreateNVP(labels, "labels");
+
+    if (Archive::is_loading::value)
+    {
+      // Clean other objects.
+      splitPoints.clear();
+      sufficientStatistics.clear();
+    }
+  }
+}
+
 } // namespace tree
 } // namespace mlpack
 
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp
index 31d352a..5af29ae 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_split.hpp
@@ -58,6 +58,10 @@ class HoeffdingSplit
   template<typename StreamingDecisionTreeType>
   void CreateChildren(std::vector<StreamingDecisionTreeType>& children);
 
+  //! Serialize the split.
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */);
+
  private:
   // We need to keep some information for before we have split.
   std::vector<NumericSplitType> numericSplits;
diff --git a/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp b/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
index e698552..69ed643 100644
--- a/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
+++ b/src/mlpack/methods/hoeffding_trees/hoeffding_split_impl.hpp
@@ -304,6 +304,68 @@ void HoeffdingSplit<
   categoricalSplits.clear();
 }
 
+template<
+    typename FitnessFunction,
+    typename NumericSplitType,
+    typename CategoricalSplitType
+>
+template<typename Archive>
+void HoeffdingSplit<
+    FitnessFunction,
+    NumericSplitType,
+    CategoricalSplitType
+>::Serialize(Archive& ar, const unsigned int /* version */)
+{
+  using data::CreateNVP;
+
+  ar & CreateNVP(splitDimension, "splitDimension");
+  ar & CreateNVP(dimensionMappings, "dimensionMappings");
+  // TODO: ownership is unresolved; for now a loaded object claims ownership.
+  if (Archive::is_loading::value)
+    ownsMappings = true;
+
+  // Depending on whether or not we have split yet, we may need to save
+  // different things.
+  if (splitDimension == size_t(-1))
+  {
+    // We have not yet split.  So we have to serialize the splits.
+    ar & CreateNVP(numericSplits, "numericSplits");
+    ar & CreateNVP(categoricalSplits, "categoricalSplits");
+
+    ar & CreateNVP(numSamples, "numSamples");
+    ar & CreateNVP(numClasses, "numClasses");
+    ar & CreateNVP(maxSamples, "maxSamples");
+    ar & CreateNVP(successProbability, "successProbability");
+
+    if (Archive::is_loading::value)
+    {
+      // Clear things we don't need.
+      majorityClass = 0;
+      categoricalSplit = CategoricalSplitType::SplitInfo();
+      numericSplit = NumericSplitType::SplitInfo();
+    }
+  }
+  else
+  {
+    // We have split, so we only need to cache the numeric and categorical
+    // split.
+    ar & CreateNVP(categoricalSplit, "categoricalSplit");
+    ar & CreateNVP(numericSplit, "numericSplit");
+    ar & CreateNVP(majorityClass, "majorityClass");
+
+    if (Archive::is_loading::value)
+    {
+      numericSplits.clear();
+      categoricalSplits.clear();
+
+      numSamples = 0;
+      numClasses = 0;
+      maxSamples = 0;
+      successProbability = 0.0;
+    }
+  }
+}
+
 } // namespace tree
 } // namespace mlpack
 
diff --git a/src/mlpack/methods/hoeffding_trees/numeric_split_info.hpp b/src/mlpack/methods/hoeffding_trees/numeric_split_info.hpp
index 1d751dc..eaac2db 100644
--- a/src/mlpack/methods/hoeffding_trees/numeric_split_info.hpp
+++ b/src/mlpack/methods/hoeffding_trees/numeric_split_info.hpp
@@ -31,6 +31,13 @@ class NumericSplitInfo
     return bin;
   }
 
+  //! Serialize the split (save/load the split points).
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */)
+  {
+    ar & data::CreateNVP(splitPoints, "splitPoints");
+  }
+
  private:
   arma::Col<ObservationType> splitPoints;
 };
diff --git a/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_main.cpp b/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_main.cpp
index 0d85848..35a304d 100644
--- a/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_main.cpp
+++ b/src/mlpack/methods/hoeffding_trees/streaming_decision_tree_main.cpp
@@ -46,8 +46,10 @@ int main(int argc, char** argv)
   arma::Row<size_t> labels = labelsIn.t();
 
   // Now create the decision tree.
+  Timer::Start("tree_training");
   StreamingDecisionTree<HoeffdingSplit<>> tree(trainingSet, datasetInfo, labels,
       max(labels) + 1, confidence, maxSamples);
+  Timer::Stop("tree_training");
 
   // Great.  Good job team.
   std::stack<StreamingDecisionTree<HoeffdingSplit<>>*> stack;
@@ -65,8 +67,10 @@ int main(int argc, char** argv)
   Log::Info << nodes << " nodes in tree.\n";
 
   // Check the accuracy on the training set.
+  Timer::Start("tree_testing");
   arma::Row<size_t> predictedLabels;
   tree.Classify(trainingSet, predictedLabels);
+  Timer::Stop("tree_testing");
 
   size_t correct = 0;
   for (size_t i = 0; i < predictedLabels.n_elem; ++i)



More information about the mlpack-git mailing list