[mlpack-git] master: First pass at Serialize(). (16a6d65)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Mon Jul 13 04:04:30 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/8b2ca720828224607c70d2b539c43aecf8f4ec32...b4659b668021db631b3c8a48e3d735b513706fdc

>---------------------------------------------------------------

commit 16a6d65c43c0a1de0e8d1d9505a8e5cff9a1e71a
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sat Jul 11 12:26:40 2015 +0000

    First pass at Serialize().


>---------------------------------------------------------------

16a6d65c43c0a1de0e8d1d9505a8e5cff9a1e71a
 src/mlpack/methods/hmm/hmm.hpp      | 11 +++++----
 src/mlpack/methods/hmm/hmm_impl.hpp | 47 +++++++++++--------------------------
 2 files changed, 20 insertions(+), 38 deletions(-)

diff --git a/src/mlpack/methods/hmm/hmm.hpp b/src/mlpack/methods/hmm/hmm.hpp
index 03d3049..81b1167 100644
--- a/src/mlpack/methods/hmm/hmm.hpp
+++ b/src/mlpack/methods/hmm/hmm.hpp
@@ -321,16 +321,17 @@ class HMM
    */
   std::string ToString() const;
 
-  //! Save to SaveRestoreUtility
-  void Save(util::SaveRestoreUtility& sr) const;
-  //! Load from SaveRestoreUtility
-  void Load(const util::SaveRestoreUtility& sr);
-
   /**
    * Returns a string indicating the type.
    */
   static std::string const Type() { return "HMM"; }
 
+  /**
+   * Serialize the HMM to or from archive ar; version is currently unused.
+   */
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int version);
+
  protected:
   // Helper functions.
   /**
diff --git a/src/mlpack/methods/hmm/hmm_impl.hpp b/src/mlpack/methods/hmm/hmm_impl.hpp
index b662511..783460d 100644
--- a/src/mlpack/methods/hmm/hmm_impl.hpp
+++ b/src/mlpack/methods/hmm/hmm_impl.hpp
@@ -565,44 +565,25 @@ std::string HMM<Distribution>::ToString() const
   return convert.str();
 }
 
-//! Save to SaveRestoreUtility
+//! Serialize the HMM; the same code path handles both saving and loading.
 template<typename Distribution>
+template<typename Archive>
-void HMM<Distribution>::Save(util::SaveRestoreUtility& sr) const
+void HMM<Distribution>::Serialize(Archive& ar, const unsigned int /* version */)
 {
-  //  Save parameters.
-  sr.SaveParameter(Type(), "type");
-  sr.SaveParameter(Emission()[0].Type(), "emission_type");
-  sr.SaveParameter(dimensionality, "dimensionality");
-  sr.SaveParameter(transition.n_rows, "states");
-  sr.SaveParameter(transition, "transition");
-
-  // Now the emissions.
-  util::SaveRestoreUtility mn;
-  for (size_t i = 0; i < transition.n_rows; ++i)
-  {
-    // Generate name.
-    std::stringstream s;
-    s << "emission_distribution_" << i;
-    Emission()[i].Save(mn);
-    sr.AddChild(mn, s.str());
-  }
-}
+  ar & data::CreateNVP(dimensionality, "dimensionality");
+  ar & data::CreateNVP(transition, "transition");
 
-//! Load from SaveRestoreUtility
-template<typename Distribution>
-void HMM<Distribution>::Load(const util::SaveRestoreUtility& sr)
-{
-  // Load parameters.
-  sr.LoadParameter(dimensionality, "dimensionality");
-  sr.LoadParameter(transition, "transition");
+  // Now serialize each emission.  If we are loading, we must resize the vector
+  // of emissions correctly.
+  if (Archive::is_loading::value)
+    emission.resize(transition.n_rows);
 
-  // Now each emission distribution.
-  Emission().resize(transition.n_rows);
-  for (size_t i = 0; i < transition.n_rows; ++i)
+  // Save or load each emission under its own name ("emission0", ...).
+  for (size_t i = 0; i < emission.size(); ++i)
   {
-    std::stringstream s;
-    s << "emission_distribution_" << i;
-    Emission()[i].Load(sr.Children().at(s.str()));
+    std::ostringstream oss;
+    oss << "emission" << i;
+    ar & data::CreateNVP(emission[i], oss.str());
   }
 }
 



More information about the mlpack-git mailing list