[mlpack-git] master: implement command line programs of softmaxRegression (ec4bdf2)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Mon Nov 16 10:08:30 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/9d882d9461a656dfec814b0ec7ae32bd4aebf8b2...7983dc9bfef684061f040667a69de75887cd2330

>---------------------------------------------------------------

commit ec4bdf27e9309208add86f07287ebf31c9a04295
Author: stereomatchingkiss <stereomatchingkiss at gmail.com>
Date:   Thu Oct 29 12:17:41 2015 +0800

    implement command line programs of softmaxRegression


>---------------------------------------------------------------

ec4bdf27e9309208add86f07287ebf31c9a04295
 .../methods/softmax_regression/CMakeLists.txt      |  1 +
 .../methods/softmax_regression/softmax_main.cpp    | 98 ++++++++++++++++++++++
 2 files changed, 99 insertions(+)

diff --git a/src/mlpack/methods/softmax_regression/CMakeLists.txt b/src/mlpack/methods/softmax_regression/CMakeLists.txt
index df2a33f..b4941f7 100644
--- a/src/mlpack/methods/softmax_regression/CMakeLists.txt
+++ b/src/mlpack/methods/softmax_regression/CMakeLists.txt
@@ -5,6 +5,7 @@ set(SOURCES
   softmax_regression_impl.hpp
   softmax_regression_function.hpp
   softmax_regression_function.cpp
+  softmax_main.cpp
 )
 
 # Add directory name to sources.
diff --git a/src/mlpack/methods/softmax_regression/softmax_main.cpp b/src/mlpack/methods/softmax_regression/softmax_main.cpp
new file mode 100644
index 0000000..bdaeb8d
--- /dev/null
+++ b/src/mlpack/methods/softmax_regression/softmax_main.cpp
@@ -0,0 +1,98 @@
+#include <mlpack/core.hpp>
+#include <mlpack/methods/softmax_regression/softmax_regression.hpp>
+#include <mlpack/core/optimizers/lbfgs/lbfgs.hpp>
+
+#include <set>
+
+// Program name and description, followed by the command-line parameters.
+PROGRAM_INFO("Softmax Regression", "This program performs softmax regression "
+    "on the given dataset and able to store the learned parameters.");
+
+// Required options: the training data file and its matching label file.
+PARAM_STRING_REQ("input_data", "Input dataset to perform training on(read from files).", "i");
+PARAM_STRING_REQ("input_label", "Input labels to perform training"
+                                " on(read from files). The labels must specify by row", "l");
+
+// Output options: optional file that the trained model is saved to.
+PARAM_STRING("output_file", "If specified, the trained results will write into this "
+             "file; Else the training results will not be saved", "p", "");
+
+// Softmax configuration options: iterations, classes, lambda and intercept.
+PARAM_INT("max_iterations", "Maximum number of iterations before "
+          "terminates.", "m", 400);
+
+PARAM_INT("number_of_classes", "Number of classes for classification, "
+          "if you do not specify, it will measure it out automatic",
+          "n", 0);
+
+PARAM_DOUBLE("lambda", "L2-regularization constant", "r", 0.0001);
+
+PARAM_FLAG("intercept", "Add intercept term, if not specify, "
+           "the intercept term will not be added", "t");
+
+
+// Trains a softmax regression model from the command-line options declared
+// above and, when --output_file is given, saves the trained model to it.
+int main(int argc, char** argv)
+{
+  using namespace mlpack;
+
+  CLI::ParseCommandLine(argc, argv);
+
+  const std::string inputFile = CLI::GetParam<std::string>("input_data");
+  const std::string labelFile = CLI::GetParam<std::string>("input_label");
+  const int maxIterations = CLI::GetParam<int>("max_iterations");
+  const int requestedClasses = CLI::GetParam<int>("number_of_classes");
+  const bool saveModel = CLI::HasParam("output_file");
+
+  if (maxIterations < 0)
+    Log::Fatal << "Invalid value for maximum iterations (" << maxIterations <<
+                  ")! Must be greater than or equal to 0." << std::endl;
+
+  // Reject negative counts before the value is converted to size_t below.
+  if (requestedClasses < 0)
+    Log::Fatal << "Invalid value for number of classes (" << requestedClasses
+               << ")! Must be greater than or equal to 0." << std::endl;
+
+  if (!saveModel)
+    Log::Warn << "--output_file is not set; no results will be saved."
+              << std::endl;
+
+  // Armadillo's loader is used directly: per the original author, mlpack's
+  // data::Load() could not detect the type of these files on Windows.
+  arma::mat trainData;
+  arma::Row<size_t> trainLabels;
+  if (!trainData.load(inputFile, arma::auto_detect))
+    Log::Fatal << "Unable to load input_data '" << inputFile << "'."
+               << std::endl;
+  if (!trainLabels.load(labelFile, arma::auto_detect))
+    Log::Fatal << "Unable to load input_label '" << labelFile << "'."
+               << std::endl;
+
+  // Transpose so that each column is one sample, which is what the
+  // column-count check against the number of labels below relies on.
+  trainData = trainData.t();
+
+  if (trainData.n_cols != trainLabels.n_elem)
+    Log::Fatal << "Samples of input_data should same as the size "
+                  "of input_label " << std::endl;
+
+  // 0 (the default) means "infer": use the number of distinct label values.
+  size_t numClasses = static_cast<size_t>(requestedClasses);
+  if (numClasses == 0)
+    numClasses = std::set<size_t>(std::begin(trainLabels),
+                                  std::end(trainLabels)).size();
+
+  using SRF = regression::SoftmaxRegressionFunction;
+  SRF smFunction(trainData, trainLabels, numClasses,
+                 CLI::HasParam("intercept"), CLI::GetParam<double>("lambda"));
+
+  // Build the L-BFGS optimizer (maxIterations is forwarded to it) and model.
+  const size_t numBasis = 5;
+  optimization::L_BFGS<SRF> optimizer(smFunction, numBasis, maxIterations);
+  regression::SoftmaxRegression<> sm(optimizer);
+
+  if (saveModel)
+    data::Save(CLI::GetParam<std::string>("output_file"),
+               "softmax_regression", sm, true);
+}



More information about the mlpack-git mailing list