[mlpack-git] master: add data split executable without labels + tests (e41b3db)
gitdub at mlpack.org
gitdub at mlpack.org
Fri May 27 11:41:29 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/e3a23c256f017ebb8185b15847c82f51d359cdfd...181792d99549467440b2b839f52deec75be10334
>---------------------------------------------------------------
commit e41b3dbde7328a358cfa8adb2ffb5a545a48eb75
Author: Keon Kim <kwk236 at gmail.com>
Date: Fri May 27 22:11:31 2016 +0900
add data split executable without labels + tests
>---------------------------------------------------------------
e41b3dbde7328a358cfa8adb2ffb5a545a48eb75
.../methods/preprocess/preprocess_split_main.cpp | 150 +++++++++++++++------
src/mlpack/tests/split_data_test.cpp | 22 ++-
2 files changed, 130 insertions(+), 42 deletions(-)
diff --git a/src/mlpack/methods/preprocess/preprocess_split_main.cpp b/src/mlpack/methods/preprocess/preprocess_split_main.cpp
index dfd79cf..02e844a 100644
--- a/src/mlpack/methods/preprocess/preprocess_split_main.cpp
+++ b/src/mlpack/methods/preprocess/preprocess_split_main.cpp
@@ -13,11 +13,12 @@ PROGRAM_INFO("Split into Train and Test Data", "This "
// Define parameters for data
PARAM_STRING_REQ("input_file", "File containing data,", "i");
-PARAM_STRING_REQ("input_label", "File containing labels", "I");
-PARAM_STRING_REQ("training_file", "File name to save train data", "t");
-PARAM_STRING_REQ("test_file", "File name to save test data", "T");
-PARAM_STRING_REQ("training_labels_file", "File name to save train label", "l");
-PARAM_STRING_REQ("test_labels_file", "File name to save test label", "L");
+// Optional parameters; empty output names are derived from the input names.
+PARAM_STRING("input_labels", "File containing labels", "I", "");
+PARAM_STRING("training_file", "File name to save train data", "t", "");
+PARAM_STRING("test_file", "File name to save test data", "T", "");
+PARAM_STRING("training_labels_file", "File name to save train label", "l", "");
+PARAM_STRING("test_labels_file", "File name to save test label", "L", "");
// Define optional test ratio, default is 0.2 (Test 20% Train 80%)
PARAM_DOUBLE("test_ratio", "Ratio of test set, if not set,"
@@ -31,46 +32,133 @@ int main(int argc, char** argv)
 {
   // Parse command line options.
   CLI::ParseCommandLine(argc, argv);
-
   const string inputFile = CLI::GetParam<string>("input_file");
-  const string inputLabel = CLI::GetParam<string>("input_label");
-  const string trainingFile = CLI::GetParam<string>("training_file");
-  const string testFile = CLI::GetParam<string>("test_file");
-  const string trainingLabelsFile = CLI::GetParam<string>("training_labels_file");
-  const string testLabelsFile = CLI::GetParam<string>("test_labels_file");
+  const string inputLabels = CLI::GetParam<string>("input_labels");
+  string trainingFile = CLI::GetParam<string>("training_file");
+  string testFile = CLI::GetParam<string>("test_file");
+  string trainingLabelsFile = CLI::GetParam<string>("training_labels_file");
+  string testLabelsFile = CLI::GetParam<string>("test_labels_file");
   const double testRatio = CLI::GetParam<double>("test_ratio");
-  // container for input data and labels
-  arma::mat data;
-  arma::mat labels;
+  // Default output names are built by prefixing only the file-name part of a
+  // path, so "dir/data.csv" becomes "dir/train_data.csv" rather than
+  // "train_dir/data.csv", which would point into a different directory.
+  auto prefixFileName = [](const string& path,
+                           const string& prefix) -> string
+  {
+    const size_t sep = path.find_last_of("/\\");
+    if (sep == string::npos)
+      return prefix + path;
+    return path.substr(0, sep + 1) + prefix + path.substr(sep + 1);
+  };
+
+  // Check the data output parameters, filling in default names if needed.
+  if (trainingFile.empty())
+  {
+    trainingFile = prefixFileName(inputFile, "train_");
+    Log::Warn << "You did not specify --training_file. "
+        << "Training file name is automatically set to: "
+        << trainingFile << endl;
+  }
+  if (testFile.empty())
+  {
+    testFile = prefixFileName(inputFile, "test_");
+    Log::Warn << "You did not specify --test_file. "
+        << "Test file name is automatically set to: " << testFile << endl;
+  }
+
+  // Check the label parameters.  Label output names are only meaningful (and
+  // only defaulted) when input labels are given.
+  if (!inputLabels.empty())
+  {
+    if (trainingLabelsFile.empty())
+    {
+      trainingLabelsFile = prefixFileName(inputLabels, "train_");
+      Log::Warn << "You did not specify --training_labels_file. "
+          << "Training labels file name is automatically set to: "
+          << trainingLabelsFile << endl;
+    }
+    if (testLabelsFile.empty())
+    {
+      testLabelsFile = prefixFileName(inputLabels, "test_");
+      Log::Warn << "You did not specify --test_labels_file. "
+          << "Test labels file name is automatically set to: "
+          << testLabelsFile << endl;
+    }
+  }
+  else if (CLI::HasParam("training_labels_file") ||
+           CLI::HasParam("test_labels_file"))
+  {
+    Log::Fatal << "When specifying --training_labels_file or "
+        << "--test_labels_file, you must also specify --input_labels."
+        << endl;
+  }
-  // Load Data and Labels
+
+  // Check test_ratio; when it is not given, the default of 0.2 is used.
+  if (CLI::HasParam("test_ratio"))
+  {
+    // A ratio outside [0.0, 1.0] is invalid.
+    if (testRatio < 0.0 || testRatio > 1.0)
+    {
+      Log::Fatal << "Invalid parameter for test_ratio. "
+          << "test_ratio must be between 0.0 and 1.0" << endl;
+    }
+  }
+  else
+  {
+    Log::Warn << "You did not specify --test_ratio. "
+        << "Test ratio is automatically set to: 0.2" << endl;
+  }
+
+  // Load the data.
+  arma::mat data;
   data::Load(inputFile, data, true);
-  data::Load(inputLabel, labels, true);
-  arma::rowvec labels_row = labels.row(0); // extract first row
-
-  // Split Data
-  const auto value = data::Split(data, labels_row, testRatio);
-  Log::Info << "Train Data Count: " << get<0>(value).n_cols << endl;
-  Log::Info << "Test Data Count: " << get<1>(value).n_cols << endl;
-  Log::Info << "Train Label Count: " << get<2>(value).n_cols << endl;
-  Log::Info << "Test Label Count: " << get<3>(value).n_cols << endl;
-
-  // Cast double matrix to string matrix
-  //Mat<string> training = conv_to<Mat<string>>::from(get<0>(value));
-  //Mat<string> test = conv_to<Mat<string>>::from(get<1>(value));
-  //Mat<string> trainingLabels = conv_to<Mat<string>>::from(get<2>(value));
-  //Mat<string> testLabels = conv_to<Mat<string>>::from(get<3>(value));
-
-  //Cast double matrix to string matrix
-  mat training = get<0>(value);
-  mat test = get<1>(value);
-  mat trainingLabels = get<2>(value);
-  mat testLabels = get<3>(value);
-
-  data::Save(trainingFile, training, false);
-  data::Save(testFile, test, false);
-  data::Save(trainingLabelsFile, trainingLabels, false);
-  data::Save(testLabelsFile, testLabels, false);
+
+  // If input labels were given, split them along with the data.
+  if (!inputLabels.empty())
+  {
+    arma::mat labels;
+    data::Load(inputLabels, labels, true);
+    // Every data point (column) needs one label; this also rejects an empty
+    // labels file, for which labels.row(0) below would be out of bounds.
+    if (labels.n_rows == 0 || labels.n_cols != data.n_cols)
+    {
+      Log::Fatal << "The number of labels in " << inputLabels << " ("
+          << labels.n_cols << ") does not match the number of points in "
+          << inputFile << " (" << data.n_cols << ")." << endl;
+    }
+    arma::rowvec labels_row = labels.row(0); // extract first row
+
+    const auto value = data::Split(data, labels_row, testRatio);
+    Log::Info << "Train Data Count: " << get<0>(value).n_cols << endl;
+    Log::Info << "Test Data Count: " << get<1>(value).n_cols << endl;
+    Log::Info << "Train Label Count: " << get<2>(value).n_cols << endl;
+    Log::Info << "Test Label Count: " << get<3>(value).n_cols << endl;
+
+    // TODO: fix full precision problem
+    mat training = get<0>(value);
+    mat test = get<1>(value);
+    mat trainingLabels = get<2>(value);
+    mat testLabels = get<3>(value);
+
+    data::Save(trainingFile, training, false);
+    data::Save(testFile, test, false);
+    data::Save(trainingLabelsFile, trainingLabels, false);
+    data::Save(testLabelsFile, testLabels, false);
+  }
+  else // No labels: split the data alone.
+  {
+    const auto value = data::Split(data, testRatio);
+    Log::Info << "Train Data Count: " << get<0>(value).n_cols << endl;
+    Log::Info << "Test Data Count: " << get<1>(value).n_cols << endl;
+
+    // TODO: fix full precision problem
+    mat training = get<0>(value);
+    mat test = get<1>(value);
+
+    data::Save(trainingFile, training, false);
+    data::Save(testFile, test, false);
+  }
 }
diff --git a/src/mlpack/tests/split_data_test.cpp b/src/mlpack/tests/split_data_test.cpp
index 700e31a..d7b3990 100644
--- a/src/mlpack/tests/split_data_test.cpp
+++ b/src/mlpack/tests/split_data_test.cpp
@@ -67,7 +67,17 @@ void CheckDuplication(const Row<size_t>& trainLabels,
BOOST_REQUIRE_EQUAL(counts[i], 1);
}
-BOOST_AUTO_TEST_CASE(SplitDataSplitResultMat)
+BOOST_AUTO_TEST_CASE(SplitDataResultMat)
+{
+ mat input(2, 10, arma::fill::randu);
+ const auto value = Split(input, 0.2);
+ BOOST_REQUIRE_EQUAL(std::get<0>(value).n_cols, 8); // train data
+ BOOST_REQUIRE_EQUAL(std::get<1>(value).n_cols, 2); // test data
+ BOOST_REQUIRE_EQUAL(std::get<0>(value).n_rows, input.n_rows); // same dims
+ BOOST_REQUIRE_EQUAL(std::get<1>(value).n_rows, input.n_rows);
+}
+
+BOOST_AUTO_TEST_CASE(SplitLabeledDataResultMat)
{
mat input(2, 10);
input.randu();
@@ -99,6 +109,20 @@ BOOST_AUTO_TEST_CASE(SplitDataLargerTest)
mat input(10, 497);
input.randu();
+ const auto value = Split(input, 0.3);
+ BOOST_REQUIRE_EQUAL(std::get<0>(value).n_cols, 497 - size_t(0.3 * 497));
+ BOOST_REQUIRE_EQUAL(std::get<1>(value).n_cols, size_t(0.3 * 497));
+ BOOST_REQUIRE_EQUAL(std::get<0>(value).n_rows, input.n_rows);
+ BOOST_REQUIRE_EQUAL(std::get<1>(value).n_rows, input.n_rows);
+ // A lost or duplicated point would (almost surely) change the total sum.
+ BOOST_REQUIRE_CLOSE(arma::accu(std::get<0>(value)) + arma::accu(std::get<1>(value)), arma::accu(input), 1e-5);
+}
+
+BOOST_AUTO_TEST_CASE(SplitLabeledDataLargerTest)
+{
+ mat input(10, 497);
+ input.randu();
+
// Set the labels to the column ID.
const Row<size_t> labels = arma::linspace<Row<size_t>>(0, input.n_cols - 1,
input.n_cols);
More information about the mlpack-git mailing list