[mlpack-git] master: Refactor main program to use boost::serialization for saving. (ad4f7b3)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Mon Jul 13 04:04:55 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/8b2ca720828224607c70d2b539c43aecf8f4ec32...b4659b668021db631b3c8a48e3d735b513706fdc
>---------------------------------------------------------------
commit ad4f7b3e38771f658b3cb10e721d5813564b2e3d
Author: Ryan Curtin <ryan at ratml.org>
Date: Sun Jul 12 13:31:18 2015 +0000
Refactor main program to use boost::serialization for saving.
>---------------------------------------------------------------
ad4f7b3e38771f658b3cb10e721d5813564b2e3d
src/mlpack/methods/gmm/CMakeLists.txt | 1 +
src/mlpack/methods/gmm/gmm_main.cpp | 17 ++++++++-----
src/mlpack/methods/gmm/gmm_util.hpp | 45 +++++++++++++++++++++++++++++++++++
3 files changed, 57 insertions(+), 6 deletions(-)
diff --git a/src/mlpack/methods/gmm/CMakeLists.txt b/src/mlpack/methods/gmm/CMakeLists.txt
index 981f853..16633b3 100644
--- a/src/mlpack/methods/gmm/CMakeLists.txt
+++ b/src/mlpack/methods/gmm/CMakeLists.txt
@@ -9,6 +9,7 @@ set(SOURCES
positive_definite_constraint.hpp
diagonal_constraint.hpp
eigenvalue_ratio_constraint.hpp
+ gmm_util.hpp
)
# Add directory name to sources.
diff --git a/src/mlpack/methods/gmm/gmm_main.cpp b/src/mlpack/methods/gmm/gmm_main.cpp
index 4215fc6..e1046f9 100644
--- a/src/mlpack/methods/gmm/gmm_main.cpp
+++ b/src/mlpack/methods/gmm/gmm_main.cpp
@@ -8,6 +8,7 @@
#include "gmm.hpp"
#include "no_constraint.hpp"
+#include "gmm_util.hpp"
#include <mlpack/methods/kmeans/refined_start.hpp>
@@ -39,8 +40,8 @@ PROGRAM_INFO("Gaussian Mixture Model (GMM) Training",
PARAM_STRING_REQ("input_file", "File containing the data on which the model "
"will be fit.", "i");
PARAM_INT("gaussians", "Number of Gaussians in the GMM.", "g", 1);
-PARAM_STRING("output_file", "The file to write the trained GMM parameters into "
- "(as XML).", "o", "gmm.xml");
+PARAM_STRING("output_file", "The file to write the trained GMM parameters "
+    "into.  The serialization format is chosen by the file extension: "
+    "'xml', 'bin', or 'txt'.", "o", "gmm.xml");
PARAM_INT("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
PARAM_INT("trials", "Number of trials to perform in training GMM.", "t", 10);
@@ -137,7 +138,8 @@ int main(int argc, char* argv[])
Timer::Stop("em");
// Save results.
- gmm.Save(CLI::GetParam<string>("output_file"));
+ const string outputFile = CLI::GetParam<string>("output_file");
+ SaveGMM(gmm, outputFile);
}
else
{
@@ -152,7 +154,8 @@ int main(int argc, char* argv[])
Timer::Stop("em");
// Save results.
- gmm.Save(CLI::GetParam<string>("output_file"));
+ const string outputFile = CLI::GetParam<string>("output_file");
+ SaveGMM(gmm, outputFile);
}
}
else
@@ -171,7 +174,8 @@ int main(int argc, char* argv[])
Timer::Stop("em");
// Save results.
- gmm.Save(CLI::GetParam<string>("output_file"));
+ const string outputFile = CLI::GetParam<string>("output_file");
+ SaveGMM(gmm, outputFile);
}
else
{
@@ -188,7 +192,8 @@ int main(int argc, char* argv[])
Timer::Stop("em");
// Save results.
- gmm.Save(CLI::GetParam<string>("output_file"));
+ const string outputFile = CLI::GetParam<string>("output_file");
+ SaveGMM(gmm, outputFile);
}
}
diff --git a/src/mlpack/methods/gmm/gmm_util.hpp b/src/mlpack/methods/gmm/gmm_util.hpp
new file mode 100644
index 0000000..0390883
--- /dev/null
+++ b/src/mlpack/methods/gmm/gmm_util.hpp
@@ -0,0 +1,45 @@
+/**
+ * @file gmm_util.hpp
+ * @author Ryan Curtin
+ *
+ * Utility to save GMMs to files.
+ */
+#ifndef __MLPACK_METHODS_GMM_GMM_UTIL_HPP
+#define __MLPACK_METHODS_GMM_GMM_UTIL_HPP
+
+namespace mlpack {
+namespace gmm {
+
+/**
+ * Save a GMM to a file using boost::serialization.  The archive format is
+ * selected by the extension of the filename: "xml" (XML archive), "bin"
+ * (binary archive), or "txt" (text archive).  Any other extension, or a file
+ * that cannot be opened for writing, is reported through Log::Fatal.
+ *
+ * This does not save a type id, so the concrete GMM type is not recorded in
+ * the file.
+ *
+ * @param g GMM to serialize.
+ * @param filename Name of the output file; its extension selects the format.
+ */
+template<typename GMMType>
+void SaveGMM(GMMType& g, const std::string& filename)
+{
+  using namespace boost::archive;
+
+  // Validate the extension before opening the file, so that an unsupported
+  // extension does not leave an empty (truncated) file behind.
+  const std::string extension = data::Extension(filename);
+  if (extension != "xml" && extension != "bin" && extension != "txt")
+  {
+    Log::Fatal << "Unknown extension '" << extension << "' for GMM model file "
+        << "(known: 'xml', 'bin', 'txt')." << std::endl;
+    return; // Only reached if Log::Fatal does not abort.
+  }
+
+  // Binary archives must be written to a stream opened in binary mode;
+  // otherwise platform newline translation can corrupt the output.
+  const std::ios::openmode mode = (extension == "bin") ?
+      (std::ios::out | std::ios::binary) : std::ios::out;
+  std::ofstream ofs(filename, mode);
+  if (!ofs.is_open())
+  {
+    Log::Fatal << "Cannot open '" << filename << "' for writing." << std::endl;
+    return; // Only reached if Log::Fatal does not abort.
+  }
+
+  // Each archive lives only inside its branch, so it is destroyed (writing any
+  // archive trailer) before 'ofs' is closed.
+  if (extension == "xml")
+  {
+    xml_oarchive ar(ofs);
+    ar << data::CreateNVP(g, "gmm");
+  }
+  else if (extension == "bin")
+  {
+    binary_oarchive ar(ofs);
+    ar << data::CreateNVP(g, "gmm");
+  }
+  else // extension == "txt" (validated above).
+  {
+    text_oarchive ar(ofs);
+    ar << data::CreateNVP(g, "gmm");
+  }
+}
+
+} // namespace gmm
+} // namespace mlpack
+
+#endif
More information about the mlpack-git
mailing list