[mlpack-git] master: 1 : change the option name of command line 2 : allow users to reuse trained results 3 : allow users to test the predict results of test data (c1e25a3)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Mon Nov 16 10:08:36 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/9d882d9461a656dfec814b0ec7ae32bd4aebf8b2...7983dc9bfef684061f040667a69de75887cd2330
>---------------------------------------------------------------
commit c1e25a37f12655c324c5ab8bda3ece7c8bf5e175
Author: stereomatchingkiss <stereomatchingkiss at gmail.com>
Date: Fri Oct 30 10:15:53 2015 +0800
1: rename the command-line options
2: allow users to reuse a previously trained model
3: allow users to measure prediction accuracy on labeled test data
>---------------------------------------------------------------
c1e25a37f12655c324c5ab8bda3ece7c8bf5e175
.../softmax_regression/softmax_regression_main.cpp | 226 +++++++++++++++++----
1 file changed, 183 insertions(+), 43 deletions(-)
diff --git a/src/mlpack/methods/softmax_regression/softmax_regression_main.cpp b/src/mlpack/methods/softmax_regression/softmax_regression_main.cpp
index a117bcb..8e870f8 100644
--- a/src/mlpack/methods/softmax_regression/softmax_regression_main.cpp
+++ b/src/mlpack/methods/softmax_regression/softmax_regression_main.cpp
@@ -2,25 +2,32 @@
#include <mlpack/methods/softmax_regression/softmax_regression.hpp>
#include <mlpack/core/optimizers/lbfgs/lbfgs.hpp>
+#include <memory>
#include <set>
// Define parameters for the executable.
-PROGRAM_INFO("Softmax Regression", "This program performs softmax regression "
- "on the given dataset and able to store the learned parameters.");
+PROGRAM_INFO("Softmax Regression", "This program performs softmax regression "
+ "on the given dataset. It can save the learned model, reuse a previously "
+ "saved model, and report the accuracy of the model on a labeled test set.");
-// Required options.
+// Training data (required unless --input_model is given).
-PARAM_STRING_REQ("input_data", "Input dataset to perform training on(read from files).", "i");
-PARAM_STRING_REQ("input_label",
- "Input labels to perform training"
- " on(read from files). The labels must order as a row", "l");
+// No short alias: "t" is already used by --intercept.
+PARAM_STRING("training_file", "A file containing the training set (the matrix "
+ "of predictors, X).", "", "");
+PARAM_STRING("labels_file", "A file containing labels (0 to the number of "
+ "classes minus 1) for the points in the training set (y). The labels must "
+ "be stored as a single row.", "l", "");
-// Output options.
-PARAM_STRING("output_file", "If specified, the trained results will write into this "
- "file; Else the training results will not be saved", "p", "");
+// Model loading/saving.
+PARAM_STRING("input_model", "File containing existing model (parameters).", "i",
+ "");
+PARAM_STRING("output_model", "File to save trained softmax regression model "
+ "to.", "m", "");
+
+// Testing.
+PARAM_STRING("test_data", "File containing test dataset.", "T", "");
+PARAM_STRING("test_labels", "File containing test labels.", "L", "");
// Softmax configuration options.
PARAM_INT("max_iterations", "Maximum number of iterations before "
- "terminates.", "m", 400);
+ "termination.", "M", 400);
PARAM_INT("number_of_classes", "Number of classes for classification, "
"if you do not specify, it will measure it out automatic",
@@ -31,6 +38,20 @@ PARAM_DOUBLE("lambda", "L2-regularization constant", "r", 0.0001);
PARAM_FLAG("intercept", "Add intercept term, if not specify, "
"the intercept term will not be added", "t");
+/**
+ * Return numClasses unchanged if it is nonzero; otherwise infer the number of
+ * classes from trainLabels.
+ */
+size_t calculateNumberOfClasses(size_t numClasses,
+ const arma::Row<size_t> &trainLabels);
+
+/**
+ * If testFile and testLabelsFile are both given, predict the labels of the
+ * test points with model and print the per-class and total accuracy. Does
+ * nothing if both are empty; it is a fatal error to give only one of them.
+ */
+template<typename Model>
+void testPredictAcc(const std::string &testFile,
+ const std::string &testLabelsFile,
+ size_t numClasses,
+ Model &model);
+
+/**
+ * Load a model from inputModelFile if it is non-empty; otherwise train a new
+ * model on trainingFile and labelFile, passing maxIterations to L-BFGS as its
+ * iteration limit.
+ */
+template<typename Model>
+std::unique_ptr<Model> trainSoftmax(const std::string &trainingFile,
+ const std::string &labelFile,
+ const std::string &inputModelFile,
+ size_t maxIterations);
int main(int argc, char** argv)
{
@@ -38,9 +59,24 @@ int main(int argc, char** argv)
CLI::ParseCommandLine(argc, argv);
- const auto inputFile = CLI::GetParam<std::string>("input_data");
- const auto labelFile = CLI::GetParam<std::string>("input_label");
- const auto maxIterations = CLI::GetParam<int>("max_iterations");
+ const std::string trainingFile = CLI::GetParam<std::string>("training_file");
+ const std::string inputModelFile = CLI::GetParam<std::string>("input_model");
+
+ // One of --training_file and --input_model must be specified.
+ if(inputModelFile.empty() && trainingFile.empty())
+ {
+ Log::Fatal << "One of --input_model or --training_file must be specified."
+ << std::endl;
+ }
+
+ // trainSoftmax() loads --input_model when it is given and does not train in
+ // that case, so tell the user the training data will not be used.
+ if(!inputModelFile.empty() && !trainingFile.empty())
+ {
+ Log::Warn << "Both --input_model and --training_file are specified; "
+ << "--training_file will be ignored." << std::endl;
+ }
+
+ const std::string labelFile = CLI::GetParam<std::string>("labels_file");
+ if(!trainingFile.empty() && labelFile.empty())
+ {
+ Log::Fatal << "--labels_file must be specified with --training_file"
+ << std::endl;
+ }
+
+ const int maxIterations = CLI::GetParam<int>("max_iterations");
if (maxIterations < 0)
{
@@ -48,52 +84,156 @@ int main(int argc, char** argv)
")! Must be greater than or equal to 0." << std::endl;
}
+ const std::string outputModelFile = CLI::GetParam<std::string>("output_model");
+
-// Make sure we have an output file if we're not doing the work in-place.
+ // Without --output_model the model is never written to disk.
- if (!CLI::HasParam("output_file"))
+ if (outputModelFile.empty())
{
- Log::Warn << "--output_file is not set; "
+ Log::Warn << "--output_model is not set; "
<< "no results will be saved." << std::endl;
}
- arma::mat trainData;
- arma::Row<size_t> trainLabels;
- trainData.load(inputFile, arma::auto_detect);
- trainData = trainData.t();
- trainLabels.load(labelFile, arma::auto_detect);
- //load functions of mlpack do not works on windows, it will complain
- //"[FATAL] Unable to detect type of 'softmax_data.txt'; incorrect extension?"
- //data::Load(inputFile, trainData, true);
- //data::Load(labelFile, trainLabels, true);
+ // Load the model from --input_model, or train one on --training_file.
+ using SM = regression::SoftmaxRegression<>;
+ std::unique_ptr<SM> sm = trainSoftmax<SM>(trainingFile,
+ labelFile,
+ inputModelFile,
+ static_cast<size_t>(maxIterations));
- std::cout<<trainData<<"\n\n";
- std::cout<<trainLabels<<"\n\n";
- if(trainData.n_cols != trainLabels.n_elem)
+ // Report accuracy on --test_data/--test_labels (no-op if neither is given).
+ testPredictAcc(CLI::GetParam<std::string>("test_data"),
+ CLI::GetParam<std::string>("test_labels"),
+ sm->NumClasses(), *sm);
+
+ if(!outputModelFile.empty())
{
- Log::Fatal << "Samples of input_data should same as the size "
- "of input_label " << std::endl;
+ data::Save(outputModelFile, "softmax_regression_model", *sm, true);
}
+}
- size_t numClasses = CLI::GetParam<int>("number_of_classes");
- if(numClasses == 0){
- const std::set<size_t> unique_labels(std::begin(trainLabels),
- std::end(trainLabels));
- numClasses = unique_labels.size();
- }
- const auto intercept = !CLI::HasParam("intercept") ? false : true;
-
- using SRF = regression::SoftmaxRegressionFunction;
- SRF smFunction(trainData, trainLabels, numClasses,
- intercept, CLI::GetParam<double>("lambda"));
+/**
+ * Return the number of classes to train with.
+ *
+ * @param numClasses Value of --number_of_classes; 0 means "infer it".
+ * @param trainLabels Training labels.
+ * @return numClasses if it is nonzero; otherwise one more than the largest
+ *     label (0 if trainLabels is empty). Labels are used as class indices, so
+ *     counting distinct labels would undercount whenever some label value is
+ *     absent, leaving the largest label out of range.
+ */
+size_t calculateNumberOfClasses(size_t numClasses,
+ arma::Row<size_t> const &trainLabels)
+{
+ if(numClasses == 0)
+ {
+ for(const size_t label : trainLabels)
+ {
+ if(label >= numClasses)
+ {
+ numClasses = label + 1;
+ }
+ }
+ }
+ return numClasses;
+}
- const size_t numBasis = 5;
- optimization::L_BFGS<SRF> optimizer(smFunction, numBasis, maxIterations);
- regression::SoftmaxRegression<> sm(optimizer);
+/**
+ * Print the per-class and total accuracy of model on a labeled test set.
+ *
+ * Does nothing if both file names are empty. It is a fatal error to give only
+ * one of them, to give files that cannot be loaded or whose sizes disagree,
+ * or to give a test label that is not smaller than numClasses.
+ *
+ * @param testFile File with one test point per row (transposed after loading
+ *     so that each column is a point).
+ * @param testLabelsFile File with one label per test point.
+ * @param numClasses Number of classes of model.
+ * @param model Trained model used for prediction.
+ */
+template<typename Model>
+void testPredictAcc(const std::string &testFile,
+ const std::string &testLabelsFile,
+ size_t numClasses,
+ Model &model)
+{
+ using namespace mlpack;
+ if(testFile.empty() && testLabelsFile.empty())
+ {
+ return;
+ }
+
+ // Exactly one of the two files was given.
+ if(testFile.empty() || testLabelsFile.empty())
+ {
+ Log::Fatal << "--test_data must be specified with --test_labels and vice "
+ << "versa" << std::endl;
+ }
+
+ arma::mat testData;
+ arma::Row<size_t> testLabels;
+ // load() reports failure through its return value.
+ if(!testData.load(testFile, arma::auto_detect))
+ {
+ Log::Fatal << "Cannot load test data from '" << testFile << "'."
+ << std::endl;
+ }
+ testData = testData.t();
+ if(!testLabels.load(testLabelsFile, arma::auto_detect))
+ {
+ Log::Fatal << "Cannot load test labels from '" << testLabelsFile << "'."
+ << std::endl;
+ }
+
+ if(testData.n_cols != testLabels.n_elem)
+ {
+ Log::Fatal << "Labels of --test_labels should same as the samples size "
+ "of --test_data " << std::endl;
+ }
+
+ if(testLabels.n_elem == 0)
+ {
+ Log::Warn << "The test set is empty; no accuracy is reported."
+ << std::endl;
+ return;
+ }
+
+ // Labels index the per-class counters below; reject out-of-range labels
+ // instead of writing past the end of the vectors.
+ for(arma::uword i = 0; i != testLabels.n_elem; ++i)
+ {
+ if(testLabels(i) >= numClasses)
+ {
+ Log::Fatal << "Test label " << testLabels(i) << " is out of range; the "
+ << "model has " << numClasses << " classes." << std::endl;
+ }
+ }
+
+ arma::vec predictLabels;
+ model.Predict(testData, predictLabels);
+ std::vector<size_t> bingoLabels(numClasses, 0);
+ std::vector<size_t> labelSize(numClasses, 0);
+ for(arma::uword i = 0; i != predictLabels.n_elem; ++i)
+ {
+ if(predictLabels(i) == testLabels(i))
+ {
+ ++bingoLabels[testLabels(i)];
+ }
+ ++labelSize[testLabels(i)];
+ }
+ size_t totalBingo = 0;
+ for(size_t i = 0; i != bingoLabels.size(); ++i)
+ {
+ if(labelSize[i] == 0)
+ {
+ // Avoid printing 0/0 = nan for classes absent from the test set.
+ std::cout<<"Accuracy of label "<<i<<" is undefined (no test points)"
+ <<std::endl;
+ }
+ else
+ {
+ std::cout<<"Accuracy of label "<<i<<" is "
+ <<(bingoLabels[i]/static_cast<double>(labelSize[i]))
+ <<std::endl;
+ }
+ totalBingo += bingoLabels[i];
+ }
+ // testLabels.n_elem is nonzero here (checked above).
+ std::cout<<"\nTotal accuracy is "
+ <<(totalBingo)/static_cast<double>(testLabels.n_elem)
+ <<std::endl;
+}
- if(CLI::HasParam("output_file"))
- {
- data::Save(CLI::GetParam<std::string>("output_file"),
- "softmax_regression", sm, true);
- }
+/**
+ * Obtain a softmax regression model by loading it or by training it.
+ *
+ * If inputModelFile is non-empty, the model is loaded from it (object name
+ * "softmax_regression_model") and the other file arguments are ignored.
+ * Otherwise a new model is trained on trainingFile and labelFile with L-BFGS,
+ * using the --number_of_classes, --lambda and --intercept options.
+ *
+ * @param trainingFile File with one training point per row (transposed after
+ *     loading so that each column is a point).
+ * @param labelFile File with one label per training point.
+ * @param inputModelFile File holding a previously saved model, or "".
+ * @param maxIterations Iteration limit passed to L-BFGS.
+ * @return Owning pointer to the loaded or trained model.
+ */
+template<typename Model>
+std::unique_ptr<Model> trainSoftmax(const std::string &trainingFile,
+ const std::string &labelFile,
+ const std::string &inputModelFile,
+ size_t maxIterations)
+{
+ using namespace mlpack;
+ using SRF = regression::SoftmaxRegressionFunction;
+
+ std::unique_ptr<Model> sm;
+ if(!inputModelFile.empty())
+ {
+ // Placeholder model (0, 0, false) for data::Load() to fill in.
+ // NOTE(review): assumes Load() restores every model parameter - confirm.
+ sm.reset(new Model(0, 0, false));
+ mlpack::data::Load(inputModelFile,
+ "softmax_regression_model",
+ *sm, true);
+ return sm;
+ }
+
+ // arma's load() is used instead of data::Load(), which on Windows fails with
+ // "[FATAL] Unable to detect type of 'softmax_data.txt'; incorrect extension?"
+ arma::mat trainData;
+ arma::Row<size_t> trainLabels;
+ if(!trainData.load(trainingFile, arma::auto_detect))
+ {
+ Log::Fatal << "Cannot load training data from '" << trainingFile << "'."
+ << std::endl;
+ }
+ trainData = trainData.t();
+ if(!trainLabels.load(labelFile, arma::auto_detect))
+ {
+ Log::Fatal << "Cannot load training labels from '" << labelFile << "'."
+ << std::endl;
+ }
+
+ if(trainData.n_cols != trainLabels.n_elem)
+ {
+ Log::Fatal << "The number of labels in --labels_file must match the "
+ << "number of points in --training_file." << std::endl;
+ }
+
+ // Reject negative values before they wrap around to a huge size_t.
+ const int requestedClasses = CLI::GetParam<int>("number_of_classes");
+ if(requestedClasses < 0)
+ {
+ Log::Fatal << "Invalid number of classes (" << requestedClasses
+ << ")! Must be greater than or equal to 0." << std::endl;
+ }
+ const size_t numClasses =
+ calculateNumberOfClasses(static_cast<size_t>(requestedClasses),
+ trainLabels);
+ const bool intercept = CLI::HasParam("intercept");
+
+ SRF smFunction(trainData, trainLabels, numClasses,
+ intercept, CLI::GetParam<double>("lambda"));
+ // L-BFGS history size (number of stored memory points).
+ const size_t numBasis = 5;
+ optimization::L_BFGS<SRF> optimizer(smFunction, numBasis, maxIterations);
+ sm.reset(new Model(optimizer));
+ return sm;
}
More information about the mlpack-git mailing list