[mlpack-git] master: add TrainTestSplit without label (2cad593)

gitdub at mlpack.org gitdub at mlpack.org
Thu May 26 15:09:27 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/e3a23c256f017ebb8185b15847c82f51d359cdfd...181792d99549467440b2b839f52deec75be10334

>---------------------------------------------------------------

commit 2cad593c482275d8d1658ddab01516e59f495f19
Author: Keon Kim <kwk236 at gmail.com>
Date:   Fri May 27 03:54:43 2016 +0900

    add TrainTestSplit without labels; rename labeled variants to LabelTrainTestSplit


>---------------------------------------------------------------

2cad593c482275d8d1658ddab01516e59f495f19
 src/mlpack/core/data/split_data.hpp                | 109 +++++++++++++++++----
 .../methods/preprocess/preprocess_split_main.cpp   |  26 +++--
 src/mlpack/tests/split_data_test.cpp               |   6 +-
 3 files changed, 111 insertions(+), 30 deletions(-)

diff --git a/src/mlpack/core/data/split_data.hpp b/src/mlpack/core/data/split_data.hpp
index 38196fd..b2af0c2 100644
--- a/src/mlpack/core/data/split_data.hpp
+++ b/src/mlpack/core/data/split_data.hpp
@@ -1,8 +1,8 @@
 /**
  * @file split_data.hpp
- * @author Tham Ngap Wei
+ * @author Tham Ngap Wei, Keon Kim
  *
- * Defines TrainTestSplit(), a utility function to split a dataset into a
+ * Defines TrainTestSplit() and LabelTrainTestSplit(), utility functions to split a dataset into a
  * training set and a test set.
  */
 #ifndef MLPACK_CORE_UTIL_SPLIT_DATA_HPP
@@ -12,7 +12,6 @@
 
 namespace mlpack {
 namespace data {
-
 /**
  * Given an input dataset and labels, split into a training set and test set.
  * Example usage below.  This overload places the split dataset into the four
@@ -29,7 +28,7 @@ namespace data {
  *
  * // Split the dataset into a training and test set, with 30% of the data being
  * // held out for the test set.
- * TrainTestSplit(input, label, trainData,
+ * LabelTrainTestSplit(input, label, trainData,
  *                testData, trainLabel, testLabel, 0.3);
  * @endcode
  *
@@ -42,13 +41,13 @@ namespace data {
  * @param testRatio Percentage of dataset to use for test set (between 0 and 1).
  */
 template<typename T, typename U>
-void TrainTestSplit(const arma::Mat<T>& input,
-                    const arma::Row<U>& inputLabel,
-                    arma::Mat<T>& trainData,
-                    arma::Mat<T>& testData,
-                    arma::Row<U>& trainLabel,
-                    arma::Row<U>& testLabel,
-                    const double testRatio)
+void LabelTrainTestSplit(const arma::Mat<T>& input,
+                         const arma::Row<U>& inputLabel,
+                         arma::Mat<T>& trainData,
+                         arma::Mat<T>& testData,
+                         arma::Row<U>& trainLabel,
+                         arma::Row<U>& testLabel,
+                         const double testRatio)
 {
   const size_t testSize = static_cast<size_t>(input.n_cols * testRatio);
   const size_t trainSize = input.n_cols - testSize;
@@ -75,6 +74,52 @@ void TrainTestSplit(const arma::Mat<T>& input,
 }
 
 /**
+ * Given an input dataset, split into a training set and test set.
+ * Example usage below. This overload places the split dataset into the two
+ * output parameters given (trainData, testData).
+ *
+ * @code
+ * arma::mat input = loadData();
+ * arma::mat trainData;
+ * arma::mat testData;
+ * math::RandomSeed(100); // Set the seed if you like.
+ *
+ * // Split the dataset into a training and test set, with 30% of the data being
+ * // held out for the test set.
+ * TrainTestSplit(input, trainData, testData, 0.3);
+ * @endcode
+ *
+ * @param input Input dataset to split.
+ * @param trainData Matrix to store training data into.
+ * @param testData Matrix to store test data into.
+ * @param testRatio Percentage of dataset to use for test set (between 0 and 1).
+ */
+template<typename T>
+void TrainTestSplit(const arma::Mat<T>& input,
+                    arma::Mat<T>& trainData,
+                    arma::Mat<T>& testData,
+                    const double testRatio)
+{
+  const size_t testSize = static_cast<size_t>(input.n_cols * testRatio);
+  const size_t trainSize = input.n_cols - testSize;
+  trainData.set_size(input.n_rows, trainSize);
+  testData.set_size(input.n_rows, testSize);
+
+  const arma::Col<size_t> order =
+      arma::shuffle(arma::linspace<arma::Col<size_t>>(0, input.n_cols -1,
+                                                      input.n_cols));
+
+  for (size_t i = 0; i != trainSize; ++i)
+  {
+     trainData.col(i) = input.col(order[i]);
+  }
+  for (size_t i = 0; i != testSize; ++i)
+  {
+     testData.col(i) = input.col(order[i + trainSize]);
+  }
+}
+
+/**
  * Given an input dataset and labels, split into a training set and test set.
  * Example usage below.  This overload returns the split dataset as a std::tuple
  * with four elements: an arma::Mat<T> containing the training data, an
@@ -84,36 +129,60 @@ void TrainTestSplit(const arma::Mat<T>& input,
  * @code
  * arma::mat input = loadData();
  * arma::Row<size_t> label = loadLabel();
- * auto splitResult = TrainTestSplit(input, label, 0.2);
+ * auto splitResult = LabelTrainTestSplit(input, label, 0.2);
  * @endcode
  *
  * @param input Input dataset to split.
  * @param label Input labels to split.
- * @param trainData Matrix to store training data into.
- * @param testData Matrix to store test data into.
- * @param trainLabel Vector to store training labels into.
- * @param testLabel Vector to store test labels into.
  * @param testRatio Percentage of dataset to use for test set (between 0 and 1).
  * @return std::tuple containing trainData (arma::Mat<T>), testData
  *      (arma::Mat<T>), trainLabel (arma::Row<U>), and testLabel (arma::Row<U>).
  */
 template<typename T,typename U>
 std::tuple<arma::Mat<T>, arma::Mat<T>, arma::Row<U>, arma::Row<U>>
-TrainTestSplit(const arma::Mat<T>& input,
-               const arma::Row<U>& inputLabel,
-               const double testRatio)
+LabelTrainTestSplit(const arma::Mat<T>& input,
+                    const arma::Row<U>& inputLabel,
+                    const double testRatio)
 {
   arma::Mat<T> trainData;
   arma::Mat<T> testData;
   arma::Row<U> trainLabel;
   arma::Row<U> testLabel;
 
-  TrainTestSplit(input, inputLabel, trainData, testData, trainLabel, testLabel,
+  LabelTrainTestSplit(input, inputLabel, trainData, testData, trainLabel, testLabel,
       testRatio);
 
   return std::make_tuple(trainData, testData, trainLabel, testLabel);
 }
 
+/**
+ * Given an input dataset, split into a training set and test set.
+ * Example usage below.  This overload returns the split dataset as a std::tuple
+ * with two elements: an arma::Mat<T> containing the training data and an
+ * arma::Mat<T> containing the test data.
+ *
+ * @code
+ * arma::mat input = loadData();
+ * auto splitResult = TrainTestSplit(input, 0.2);
+ * @endcode
+ *
+ * @param input Input dataset to split.
+ * @param testRatio Percentage of dataset to use for test set (between 0 and 1).
+ * @return std::tuple containing trainData (arma::Mat<T>)
+ *      and testData (arma::Mat<T>).
+ */
+template<typename T>
+std::tuple<arma::Mat<T>, arma::Mat<T>>
+TrainTestSplit(const arma::Mat<T>& input,
+               const double testRatio)
+{
+  arma::Mat<T> trainData;
+  arma::Mat<T> testData;
+  TrainTestSplit(input, trainData, testData, testRatio);
+
+  return std::make_tuple(trainData, testData);
+}
+
 } // namespace data
 } // namespace mlpack
 
diff --git a/src/mlpack/methods/preprocess/preprocess_split_main.cpp b/src/mlpack/methods/preprocess/preprocess_split_main.cpp
index 357e77f..d6ffb89 100644
--- a/src/mlpack/methods/preprocess/preprocess_split_main.cpp
+++ b/src/mlpack/methods/preprocess/preprocess_split_main.cpp
@@ -42,23 +42,35 @@ int main(int argc, char** argv)
 
   // container for input data and labels
   arma::mat data;
-  arma::Mat<size_t> labels;
+  arma::mat labels;
 
   // Load Data and Labels
   data::Load(inputFile, data, true);
   data::Load(inputLabel, labels, true);
-  arma::Row<size_t> labels_row = labels.row(0); // extract first row
+  arma::rowvec labels_row = labels.row(0); // extract first row
 
   // Split Data
-  const auto value = data::TrainTestSplit(data, labels_row, testRatio);
+  const auto value = data::LabelTrainTestSplit(data, labels_row, testRatio);
   Log::Info << "Train Data Count: " << get<0>(value).n_cols << endl;
   Log::Info << "Test Data Count: " << get<1>(value).n_cols << endl;
   Log::Info << "Train Label Count: " << get<2>(value).n_cols << endl;
   Log::Info << "Test Label Count: " << get<3>(value).n_cols << endl;
 
-  data::Save(trainingFile, get<0>(value), false);
-  data::Save(testFile, get<1>(value), false);
-  data::Save(trainingLabelsFile, get<2>(value), false);
-  data::Save(testLabelsFile, get<3>(value), false);
+  // Cast double matrix to string matrix
+  //Mat<string> training = conv_to<Mat<string>>::from(get<0>(value));
+  //Mat<string> test = conv_to<Mat<string>>::from(get<1>(value));
+  //Mat<string> trainingLabels = conv_to<Mat<string>>::from(get<2>(value));
+  //Mat<string> testLabels = conv_to<Mat<string>>::from(get<3>(value));
+
+  //Cast double matrix to string matrix
+  mat training = get<0>(value);
+  mat test = get<1>(value);
+  mat trainingLabels = get<2>(value);
+  mat testLabels = get<3>(value);
+
+  data::Save(trainingFile, training, false);
+  data::Save(testFile, test, false);
+  data::Save(trainingLabelsFile, trainingLabels, false);
+  data::Save(testLabelsFile, testLabels, false);
 }
 
diff --git a/src/mlpack/tests/split_data_test.cpp b/src/mlpack/tests/split_data_test.cpp
index 1cf7136..462708e 100644
--- a/src/mlpack/tests/split_data_test.cpp
+++ b/src/mlpack/tests/split_data_test.cpp
@@ -73,11 +73,11 @@ BOOST_AUTO_TEST_CASE(SplitDataSplitResultMat)
   input.randu();
 
   // Set the labels to the column ID, so that CompareData can compare the data
-  // after TrainTestSplit is called.
+  // after LabelTrainTestSplit is called.
   const Row<size_t> labels = arma::linspace<Row<size_t>>(0, input.n_cols - 1,
       input.n_cols);
 
-  const auto value = TrainTestSplit(input, labels, 0.2);
+  const auto value = LabelTrainTestSplit(input, labels, 0.2);
   BOOST_REQUIRE_EQUAL(std::get<0>(value).n_cols, 8);
   BOOST_REQUIRE_EQUAL(std::get<1>(value).n_cols, 2);
   BOOST_REQUIRE_EQUAL(std::get<2>(value).n_cols, 8);
@@ -103,7 +103,7 @@ BOOST_AUTO_TEST_CASE(SplitDataLargerTest)
   const Row<size_t> labels = arma::linspace<Row<size_t>>(0, input.n_cols - 1,
       input.n_cols);
 
-  const auto value = TrainTestSplit(input, labels, 0.3);
+  const auto value = LabelTrainTestSplit(input, labels, 0.3);
   BOOST_REQUIRE_EQUAL(std::get<0>(value).n_cols, 497 - size_t(0.3 * 497));
   BOOST_REQUIRE_EQUAL(std::get<1>(value).n_cols, size_t(0.3 * 497));
   BOOST_REQUIRE_EQUAL(std::get<2>(value).n_cols, 497 - size_t(0.3 * 497));




More information about the mlpack-git mailing list