[mlpack-git] master: Refactor main program to use boost::serialization for saving. (ad4f7b3)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Mon Jul 13 04:04:55 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/8b2ca720828224607c70d2b539c43aecf8f4ec32...b4659b668021db631b3c8a48e3d735b513706fdc

>---------------------------------------------------------------

commit ad4f7b3e38771f658b3cb10e721d5813564b2e3d
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sun Jul 12 13:31:18 2015 +0000

    Refactor main program to use boost::serialization for saving.


>---------------------------------------------------------------

ad4f7b3e38771f658b3cb10e721d5813564b2e3d
 src/mlpack/methods/gmm/CMakeLists.txt |  1 +
 src/mlpack/methods/gmm/gmm_main.cpp   | 17 ++++++++-----
 src/mlpack/methods/gmm/gmm_util.hpp   | 45 +++++++++++++++++++++++++++++++++++
 3 files changed, 57 insertions(+), 6 deletions(-)

diff --git a/src/mlpack/methods/gmm/CMakeLists.txt b/src/mlpack/methods/gmm/CMakeLists.txt
index 981f853..16633b3 100644
--- a/src/mlpack/methods/gmm/CMakeLists.txt
+++ b/src/mlpack/methods/gmm/CMakeLists.txt
@@ -9,6 +9,7 @@ set(SOURCES
   positive_definite_constraint.hpp
   diagonal_constraint.hpp
   eigenvalue_ratio_constraint.hpp
+  gmm_util.hpp
 )
 
 # Add directory name to sources.
diff --git a/src/mlpack/methods/gmm/gmm_main.cpp b/src/mlpack/methods/gmm/gmm_main.cpp
index 4215fc6..e1046f9 100644
--- a/src/mlpack/methods/gmm/gmm_main.cpp
+++ b/src/mlpack/methods/gmm/gmm_main.cpp
@@ -8,6 +8,7 @@
 
 #include "gmm.hpp"
 #include "no_constraint.hpp"
+#include "gmm_util.hpp"
 
 #include <mlpack/methods/kmeans/refined_start.hpp>
 
@@ -39,8 +40,8 @@ PROGRAM_INFO("Gaussian Mixture Model (GMM) Training",
 PARAM_STRING_REQ("input_file", "File containing the data on which the model "
     "will be fit.", "i");
 PARAM_INT("gaussians", "Number of Gaussians in the GMM.", "g", 1);
-PARAM_STRING("output_file", "The file to write the trained GMM parameters into "
-    "(as XML).", "o", "gmm.xml");
+PARAM_STRING("output_file", "The file to write the trained GMM parameters "
+    "into.", "o", "gmm.xml");
 PARAM_INT("seed", "Random seed.  If 0, 'std::time(NULL)' is used.", "s", 0);
 PARAM_INT("trials", "Number of trials to perform in training GMM.", "t", 10);
 
@@ -137,7 +138,8 @@ int main(int argc, char* argv[])
       Timer::Stop("em");
 
       // Save results.
-      gmm.Save(CLI::GetParam<string>("output_file"));
+      const string outputFile = CLI::GetParam<string>("output_file");
+      SaveGMM(gmm, outputFile);
     }
     else
     {
@@ -152,7 +154,8 @@ int main(int argc, char* argv[])
       Timer::Stop("em");
 
       // Save results.
-      gmm.Save(CLI::GetParam<string>("output_file"));
+      const string outputFile = CLI::GetParam<string>("output_file");
+      SaveGMM(gmm, outputFile);
     }
   }
   else
@@ -171,7 +174,8 @@ int main(int argc, char* argv[])
       Timer::Stop("em");
 
       // Save results.
-      gmm.Save(CLI::GetParam<string>("output_file"));
+      const string outputFile = CLI::GetParam<string>("output_file");
+      SaveGMM(gmm, outputFile);
     }
     else
     {
@@ -188,7 +192,8 @@ int main(int argc, char* argv[])
       Timer::Stop("em");
 
       // Save results.
-      gmm.Save(CLI::GetParam<string>("output_file"));
+      const string outputFile = CLI::GetParam<string>("output_file");
+      SaveGMM(gmm, outputFile);
     }
   }
 
diff --git a/src/mlpack/methods/gmm/gmm_util.hpp b/src/mlpack/methods/gmm/gmm_util.hpp
new file mode 100644
index 0000000..0390883
--- /dev/null
+++ b/src/mlpack/methods/gmm/gmm_util.hpp
@@ -0,0 +1,45 @@
+/**
+ * @file gmm_util.hpp
+ * @author Ryan Curtin
+ *
+ * Utility to save GMMs to files.
+ */
+#ifndef __MLPACK_METHODS_GMM_GMM_UTIL_HPP
+#define __MLPACK_METHODS_GMM_GMM_UTIL_HPP
+
+namespace mlpack {
+namespace gmm {
+
+// Save a GMM to file using boost::serialization.
+// This does not save a type id, however.
+template<typename GMMType>
+void SaveGMM(GMMType& g, const std::string filename)
+{
+  using namespace boost::archive;
+
+  const std::string extension = data::Extension(filename);
+  std::ofstream ofs(filename);
+  if (extension == "xml")
+  {
+    xml_oarchive ar(ofs);
+    ar << data::CreateNVP(g, "gmm");
+  }
+  else if (extension == "bin")
+  {
+    binary_oarchive ar(ofs);
+    ar << data::CreateNVP(g, "gmm");
+  }
+  else if (extension == "txt")
+  {
+    text_oarchive ar(ofs);
+    ar << data::CreateNVP(g, "gmm");
+  }
+  else
+    Log::Fatal << "Unknown extension '" << extension << "' for GMM model file "
+        << "(known: 'xml', 'bin', 'txt')." << std::endl;
+}
+
+} // namespace gmm
+} // namespace mlpack
+
+#endif



More information about the mlpack-git mailing list