[mlpack-git] master: Fix vector parameter handling and add tests. This should solve #798. (998bb2f)
gitdub at mlpack.org
gitdub at mlpack.org
Thu Oct 6 14:09:41 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/ee0da579342f05c77f7cf54b54d624fc408db713...998bb2fae41210d03ddf007b51d994a9cf6262cf
>---------------------------------------------------------------
commit 998bb2fae41210d03ddf007b51d994a9cf6262cf
Author: Ryan Curtin <ryan at ratml.org>
Date: Thu Oct 6 14:08:01 2016 -0400
Fix vector parameter handling and add tests. This should solve #798.
For taking vectors, we need to have multitoken() when the argument is added. So
this adds some template metaprogramming to handle that case properly, and then
some tests to ensure that it works.
>---------------------------------------------------------------
998bb2fae41210d03ddf007b51d994a9cf6262cf
src/mlpack/core/util/cli.hpp | 36 +++++++++++++++++++++++
src/mlpack/core/util/cli_impl.hpp | 21 ++++++++++++--
src/mlpack/tests/cli_test.cpp | 61 +++++++++++++++++++++++++++++++++++++++
3 files changed, 116 insertions(+), 2 deletions(-)
diff --git a/src/mlpack/core/util/cli.hpp b/src/mlpack/core/util/cli.hpp
index 7eadf8a..f53ef0e 100644
--- a/src/mlpack/core/util/cli.hpp
+++ b/src/mlpack/core/util/cli.hpp
@@ -425,6 +425,42 @@ class CLI
//! Private copy constructor; we don't want copies floating around.
CLI(const CLI& other);
+
+ //! Type trait; value is false unless T matches the specialization below.
+ template<typename T>
+ struct IsStdVector { const static bool value = false; };
+
+ //! Specialization of the trait; value is true for std::vector<eT>.
+ template<typename eT>
+ struct IsStdVector<std::vector<eT>> { const static bool value = true; };
+
+ /**
+ * Add an option if it is not a vector type. This is a utility function used
+ * by CLI::Add. The unnamed last parameter only selects this overload via
+ * boost::disable_if; callers leave it at its default.
+ * @tparam T Type of parameter.
+ * @param optId Option name, in program_options "name" or "name,alias" form.
+ * @param descr Description.
+ */
+ template<typename T>
+ void AddOption(
+ const char* optId,
+ const char* descr,
+ const typename boost::disable_if<IsStdVector<T>>::type* /* junk */ = 0);
+
+ /**
+ * Add an option if it is a vector type; it is registered with multitoken().
+ * This is a utility function used by CLI::Add. The unnamed last parameter
+ * only selects this overload via boost::enable_if; leave it at its default.
+ * @tparam T Type of parameter.
+ * @param optId Option name, in program_options "name" or "name,alias" form.
+ * @param descr Description.
+ */
+ template<typename T>
+ void AddOption(
+ const char* optId,
+ const char* descr,
+ const typename boost::enable_if<IsStdVector<T>>::type* /* junk */ = 0);
};
} // namespace mlpack
diff --git a/src/mlpack/core/util/cli_impl.hpp b/src/mlpack/core/util/cli_impl.hpp
index 12db93e..c4d6cf2 100644
--- a/src/mlpack/core/util/cli_impl.hpp
+++ b/src/mlpack/core/util/cli_impl.hpp
@@ -64,7 +64,6 @@ void CLI::Add(const std::string& identifier,
outstr << "Parameter --" << identifier << "(-" << alias << ") "
<< "is defined multiple times with same alias." << std::endl;
- po::options_description& desc = CLI::GetSingleton().desc;
// Must make use of boost syntax here.
std::string progOptId =
alias.length() ? identifier + "," + alias : identifier;
@@ -73,7 +72,7 @@ void CLI::Add(const std::string& identifier,
AddAlias(alias, identifier);
// Add the option to boost program_options.
- desc.add_options()(progOptId.c_str(), po::value<T>(), description.c_str());
+ GetSingleton().AddOption<T>(progOptId.c_str(), description.c_str());
// Make sure the appropriate metadata is inserted into gmap.
ParamData data;
@@ -145,6 +144,24 @@ T& CLI::GetParam(const std::string& identifier)
return *boost::any_cast<T>(&gmap[key].value);
}
+template<typename T> // Overload selected when IsStdVector<T>::value is false.
+void CLI::AddOption(
+ const char* optId,
+ const char* descr,
+ const typename boost::disable_if<IsStdVector<T>>::type* /* junk */)
+{
+ desc.add_options()(optId, po::value<T>(), descr); // No multitoken() here.
+}
+
+template<typename T> // Overload selected when IsStdVector<T>::value is true.
+void CLI::AddOption(
+ const char* optId,
+ const char* descr,
+ const typename boost::enable_if<IsStdVector<T>>::type* /* junk */)
+{ // Vector options need multitoken() so one flag can take several values.
+ desc.add_options()(optId, po::value<T>()->multitoken(), descr);
+}
+
} // namespace mlpack
#endif
diff --git a/src/mlpack/tests/cli_test.cpp b/src/mlpack/tests/cli_test.cpp
index a3eab2e..40449df 100644
--- a/src/mlpack/tests/cli_test.cpp
+++ b/src/mlpack/tests/cli_test.cpp
@@ -135,6 +135,67 @@ BOOST_AUTO_TEST_CASE(TestBooleanOption)
}
/**
+ * Test that one flag followed by several values fills a vector option.
+ */
+BOOST_AUTO_TEST_CASE(TestVectorOption)
+{
+ PARAM_VECTOR_IN(size_t, "test_vec", "test description", "t");
+
+ int argc = 5; // Program name, one flag, and three values.
+ const char* argv[5];
+ argv[0] = "./test";
+ argv[1] = "--test_vec";
+ argv[2] = "1";
+ argv[3] = "2";
+ argv[4] = "4";
+
+ Log::Fatal.ignoreInput = true; // presumably mutes Log::Fatal; confirm
+ CLI::ParseCommandLine(argc, const_cast<char**>(argv));
+ Log::Fatal.ignoreInput = false;
+
+ BOOST_REQUIRE(CLI::HasParam("test_vec"));
+
+ std::vector<size_t> v = CLI::GetParam<std::vector<size_t>>("test_vec");
+
+ BOOST_REQUIRE_EQUAL(v.size(), 3); // All three values, in argv order.
+ BOOST_REQUIRE_EQUAL(v[0], 1);
+ BOOST_REQUIRE_EQUAL(v[1], 2);
+ BOOST_REQUIRE_EQUAL(v[2], 4);
+}
+
+/**
+ * Test that we can use a vector option by specifying it many times.
+ */
+BOOST_AUTO_TEST_CASE(TestVectorOption2)
+{
+ PARAM_VECTOR_IN(size_t, "test2_vec", "test description", "T");
+
+ int argc = 7; // Program name plus three (flag, value) pairs.
+ const char* argv[7];
+ argv[0] = "./test";
+ argv[1] = "--test2_vec";
+ argv[2] = "1";
+ argv[3] = "--test2_vec";
+ argv[4] = "2";
+ argv[5] = "--test2_vec";
+ argv[6] = "4";
+
+ Log::Fatal.ignoreInput = true;
+ CLI::ParseCommandLine(argc, const_cast<char**>(argv));
+ Log::Fatal.ignoreInput = false;
+
+ // Query the option registered here ("test2_vec"), not TestVectorOption's.
+ BOOST_REQUIRE(CLI::HasParam("test2_vec"));
+
+ std::vector<size_t> v = CLI::GetParam<std::vector<size_t>>("test2_vec");
+
+ BOOST_REQUIRE_EQUAL(v.size(), 3);
+ BOOST_REQUIRE_EQUAL(v[0], 1);
+ BOOST_REQUIRE_EQUAL(v[1], 2);
+ BOOST_REQUIRE_EQUAL(v[2], 4);
+}
+
+/**
* Test that we can correctly output Armadillo objects to PrefixedOutStream
* objects.
*/
More information about the mlpack-git
mailing list