[mlpack-git] master: add data split executable without labels + tests (e41b3db)

gitdub at mlpack.org gitdub at mlpack.org
Fri May 27 11:41:29 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/e3a23c256f017ebb8185b15847c82f51d359cdfd...181792d99549467440b2b839f52deec75be10334

>---------------------------------------------------------------

commit e41b3dbde7328a358cfa8adb2ffb5a545a48eb75
Author: Keon Kim <kwk236 at gmail.com>
Date:   Fri May 27 22:11:31 2016 +0900

    add data split executable without labels + tests


>---------------------------------------------------------------

e41b3dbde7328a358cfa8adb2ffb5a545a48eb75
 .../methods/preprocess/preprocess_split_main.cpp   | 150 +++++++++++++++------
 src/mlpack/tests/split_data_test.cpp               |  22 ++-
 2 files changed, 130 insertions(+), 42 deletions(-)

diff --git a/src/mlpack/methods/preprocess/preprocess_split_main.cpp b/src/mlpack/methods/preprocess/preprocess_split_main.cpp
index dfd79cf..02e844a 100644
--- a/src/mlpack/methods/preprocess/preprocess_split_main.cpp
+++ b/src/mlpack/methods/preprocess/preprocess_split_main.cpp
@@ -13,11 +13,12 @@ PROGRAM_INFO("Split into Train and Test Data", "This "
 
 // Define parameters for data
 PARAM_STRING_REQ("input_file", "File containing data,", "i");
-PARAM_STRING_REQ("input_label", "File containing labels", "I");
-PARAM_STRING_REQ("training_file", "File name to save train data", "t");
-PARAM_STRING_REQ("test_file", "File name to save test data", "T");
-PARAM_STRING_REQ("training_labels_file", "File name to save train label", "l");
-PARAM_STRING_REQ("test_labels_file", "File name to save test label", "L");
+// Optional parameters; the trailing "" argument is each option's default.
+PARAM_STRING("input_labels", "File containing labels", "I", "");
+PARAM_STRING("training_file", "File name to save train data", "t", "");
+PARAM_STRING("test_file", "File name to save test data", "T", "");
+PARAM_STRING("training_labels_file", "File name to save train label", "l", "");
+PARAM_STRING("test_labels_file", "File name to save test label", "L", "");
 
 // Define optional test ratio, default is 0.2 (Test 20% Train 80%)
 PARAM_DOUBLE("test_ratio", "Ratio of test set, if not set,"
@@ -31,46 +32,113 @@ int main(int argc, char** argv)
 {
   // Parse command line options.
   CLI::ParseCommandLine(argc, argv);
-
   const string inputFile = CLI::GetParam<string>("input_file");
-  const string inputLabel = CLI::GetParam<string>("input_label");
-  const string trainingFile = CLI::GetParam<string>("training_file");
-  const string testFile = CLI::GetParam<string>("test_file");
-  const string trainingLabelsFile = CLI::GetParam<string>("training_labels_file");
-  const string testLabelsFile = CLI::GetParam<string>("test_labels_file");
+  const string inputLabels = CLI::GetParam<string>("input_labels");
+  string trainingFile = CLI::GetParam<string>("training_file");
+  string testFile = CLI::GetParam<string>("test_file");
+  string trainingLabelsFile = CLI::GetParam<string>("training_labels_file");
+  string testLabelsFile = CLI::GetParam<string>("test_labels_file");
   const double testRatio = CLI::GetParam<double>("test_ratio");
 
-  // container for input data and labels
-  arma::mat data;
-  arma::mat labels;
+  // Fill in default output file names for the data if none were given.
+  if (trainingFile.empty())
+  {
+    trainingFile = "train_" + inputFile;
+    Log::Warn << "You did not specify --training_file. "
+      << "Training file name is automatically set to: "
+      << trainingFile << endl;
+  }
+  if (testFile.empty())
+  {
+    testFile = "test_" + inputFile;
+    Log::Warn << "You did not specify --test_file. "
+      << "Test file name is automatically set to: " << testFile << endl;
+  }
+
+  // Labels are optional; label output names are only valid alongside them.
+  if (!inputLabels.empty())
+  {
+    if (trainingLabelsFile.empty())
+    {
+      trainingLabelsFile = "train_" + inputLabels;
+      Log::Warn << "You did not specify --training_labels_file. "
+        << "Training labels file name is automatically set to: "
+        << trainingLabelsFile << endl;
+    }
+    if (testLabelsFile.empty())
+    {
+      testLabelsFile = "test_" + inputLabels;
+      Log::Warn << "You did not specify --test_labels_file. "
+        << "Test labels file name is automatically set to: "
+        << testLabelsFile << endl;
+    }
+  }
+  else
+  {
+    if (!trainingLabelsFile.empty()
+        || !testLabelsFile.empty())
+    {
+      Log::Fatal << "When specifying --training_labels_file or "
+        << "--test_labels_file, you must also specify --input_labels." << endl;
+    }
+  }
 
-  // Load Data and Labels
+  // Validate test_ratio if given; otherwise report the default being used.
+  if (CLI::HasParam("test_ratio"))
+  {
+    // Sanity check: the ratio must lie in [0.0, 1.0].
+    if ((testRatio < 0.0) || (testRatio > 1.0))
+    {
+      Log::Fatal << "Invalid parameter for test_ratio. "
+        << "test_ratio must be between 0.0 and 1.0" << endl;
+    }
+  }
+  else // if test_ratio is not set
+  {
+    Log::Warn << "You did not specify --test_ratio. "
+      << "Test ratio is automatically set to: 0.2"<< endl;
+  }
+
+  // Load the data matrix.
+  arma::mat data;
   data::Load(inputFile, data, true);
-  data::Load(inputLabel, labels, true);
-  arma::rowvec labels_row = labels.row(0); // extract first row
-
-  // Split Data
-  const auto value = data::Split(data, labels_row, testRatio);
-  Log::Info << "Train Data Count: " << get<0>(value).n_cols << endl;
-  Log::Info << "Test Data Count: " << get<1>(value).n_cols << endl;
-  Log::Info << "Train Label Count: " << get<2>(value).n_cols << endl;
-  Log::Info << "Test Label Count: " << get<3>(value).n_cols << endl;
-
-  // Cast double matrix to string matrix
-  //Mat<string> training = conv_to<Mat<string>>::from(get<0>(value));
-  //Mat<string> test = conv_to<Mat<string>>::from(get<1>(value));
-  //Mat<string> trainingLabels = conv_to<Mat<string>>::from(get<2>(value));
-  //Mat<string> testLabels = conv_to<Mat<string>>::from(get<3>(value));
-
-  //Cast double matrix to string matrix
-  mat training = get<0>(value);
-  mat test = get<1>(value);
-  mat trainingLabels = get<2>(value);
-  mat testLabels = get<3>(value);
-
-  data::Save(trainingFile, training, false);
-  data::Save(testFile, test, false);
-  data::Save(trainingLabelsFile, trainingLabels, false);
-  data::Save(testLabelsFile, testLabels, false);
+
+  // If labels were given, split the data and the labels together.
+  if (!inputLabels.empty())
+  {
+    arma::mat labels;
+    data::Load(inputLabels, labels, true);
+    arma::rowvec labels_row = labels.row(0); // extract first row
+
+    const auto value = data::Split(data, labels_row, testRatio);
+    Log::Info << "Train Data Count: " << get<0>(value).n_cols << endl;
+    Log::Info << "Test Data Count: " << get<1>(value).n_cols << endl;
+    Log::Info << "Train Label Count: " << get<2>(value).n_cols << endl;
+    Log::Info << "Test Label Count: " << get<3>(value).n_cols << endl;
+
+    // TODO: fix full precision problem
+    mat training = get<0>(value);
+    mat test = get<1>(value);
+    mat trainingLabels = get<2>(value);
+    mat testLabels = get<3>(value);
+
+    data::Save(trainingFile, training, false);
+    data::Save(testFile, test, false);
+    data::Save(trainingLabelsFile, trainingLabels, false);
+    data::Save(testLabelsFile, testLabels, false);
+  }
+  else // No labels were given: split the data matrix only.
+  {
+    const auto value = data::Split(data, testRatio);
+    Log::Info << "Train Data Count: " << get<0>(value).n_cols << endl;
+    Log::Info << "Test Data Count: " << get<1>(value).n_cols << endl;
+
+    // TODO: fix full precision problem
+    mat training = get<0>(value);
+    mat test = get<1>(value);
+
+    data::Save(trainingFile, training, false);
+    data::Save(testFile, test, false);
+  }
 }
 
diff --git a/src/mlpack/tests/split_data_test.cpp b/src/mlpack/tests/split_data_test.cpp
index 700e31a..d7b3990 100644
--- a/src/mlpack/tests/split_data_test.cpp
+++ b/src/mlpack/tests/split_data_test.cpp
@@ -67,7 +67,17 @@ void CheckDuplication(const Row<size_t>& trainLabels,
     BOOST_REQUIRE_EQUAL(counts[i], 1);
 }
 
-BOOST_AUTO_TEST_CASE(SplitDataSplitResultMat)
+BOOST_AUTO_TEST_CASE(SplitDataResultMat)
+{
+  mat input(2, 10);
+  input.randu();
+  // Unlabeled split: 0.2 of 10 columns -> 2 test, 8 train.
+  const auto value = Split(input, 0.2);
+  BOOST_REQUIRE_EQUAL(std::get<0>(value).n_cols, 8); // train data
+  BOOST_REQUIRE_EQUAL(std::get<1>(value).n_cols, 2); // test data
+}
+
+BOOST_AUTO_TEST_CASE(SplitLabeledDataResultMat)
 {
   mat input(2, 10);
   input.randu();
@@ -99,6 +109,16 @@ BOOST_AUTO_TEST_CASE(SplitDataLargerTest)
   mat input(10, 497);
   input.randu();
 
+  const auto value = Split(input, 0.3); // 149 test, 348 train columns
+  BOOST_REQUIRE_EQUAL(std::get<0>(value).n_cols, 497 - size_t(0.3 * 497));
+  BOOST_REQUIRE_EQUAL(std::get<1>(value).n_cols, size_t(0.3 * 497));
+}
+
+BOOST_AUTO_TEST_CASE(SplitLabeledDataLargerTest)
+{
+  mat input(10, 497);
+  input.randu();
+
   // Set the labels to the column ID.
   const Row<size_t> labels = arma::linspace<Row<size_t>>(0, input.n_cols - 1,
       input.n_cols);




More information about the mlpack-git mailing list