[mlpack-git] master: Style fixes. (35149da)

gitdub at mlpack.org gitdub at mlpack.org
Fri Apr 22 10:17:30 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/b08ae02b90e18f97366b236e7a4d8725cd6e9050...213a04d31645134b61aa7d3702360bb34796d7de

>---------------------------------------------------------------

commit 35149daee158f2d4299eedb590376ba4a4c21a5a
Author: Ryan Curtin <ryan at ratml.org>
Date:   Fri Apr 22 10:17:30 2016 -0400

    Style fixes.


>---------------------------------------------------------------

35149daee158f2d4299eedb590376ba4a4c21a5a
 src/mlpack/core/data/split_data.hpp  | 123 ++++++++++++++++++++---------------
 src/mlpack/tests/split_data_test.cpp |  55 +++++++++-------
 2 files changed, 101 insertions(+), 77 deletions(-)

diff --git a/src/mlpack/core/data/split_data.hpp b/src/mlpack/core/data/split_data.hpp
index 1c33ba8..38196fd 100644
--- a/src/mlpack/core/data/split_data.hpp
+++ b/src/mlpack/core/data/split_data.hpp
@@ -1,5 +1,12 @@
-#ifndef __MLPACK_CORE_UTIL_SPLIT_DATA_HPP
-#define __MLPACK_CORE_UTIL_SPLIT_DATA_HPP
+/**
+ * @file split_data.hpp
+ * @author Tham Ngap Wei
+ *
+ * Defines TrainTestSplit(), a utility function to split a dataset into a
+ * training set and a test set.
+ */
+#ifndef MLPACK_CORE_UTIL_SPLIT_DATA_HPP
+#define MLPACK_CORE_UTIL_SPLIT_DATA_HPP
 
 #include <mlpack/core.hpp>
 
@@ -7,37 +14,43 @@ namespace mlpack {
 namespace data {
 
 /**
- *Split training data and test data
- *@param input input data want to split
- *@param label input label want to split
- *@param trainData training data split by input
- *@param testData test data split by input
- *@param trainLabel train label split by input
- *@param testLabel test label split by input
- *@param testRatio the ratio of test data
- *@code
- *arma::mat input = loadData();
- *arma::Row<size_t> label = loadLabel();
- *arma::mat trainData;
- *arma::mat testData;
- *arma::Row<size_t> trainLabel;
- *arma::Row<size_t> testLabel;
- *arma::arma_rng::set_seed(100); //set the seed if you like
- *TrainTestSplit(input, label, trainData,
- *               testData, trainLabel, testLabel);
- *@endcode
+ * Given an input dataset and labels, split into a training set and test set.
+ * Example usage below.  This overload places the split dataset into the four
+ * output parameters given (trainData, testData, trainLabel, and testLabel).
+ *
+ * @code
+ * arma::mat input = loadData();
+ * arma::Row<size_t> label = loadLabel();
+ * arma::mat trainData;
+ * arma::mat testData;
+ * arma::Row<size_t> trainLabel;
+ * arma::Row<size_t> testLabel;
+ * math::RandomSeed(100); // Set the seed if you like.
+ *
+ * // Split the dataset into a training and test set, with 30% of the data being
+ * // held out for the test set.
+ * TrainTestSplit(input, label, trainData,
+ *                testData, trainLabel, testLabel, 0.3);
+ * @endcode
+ *
+ * @param input Input dataset to split.
+ * @param inputLabel Input labels to split.
+ * @param trainData Matrix to store training data into.
+ * @param testData Matrix to store test data into.
+ * @param trainLabel Vector to store training labels into.
+ * @param testLabel Vector to store test labels into.
+ * @param testRatio Fraction of dataset to use for test set (between 0 and 1).
  */
 template<typename T, typename U>
-void TrainTestSplit(const arma::Mat<T> &input,
-                    const arma::Row<U> &inputLabel,
-                    arma::Mat<T> &trainData,
-                    arma::Mat<T> &testData,
-                    arma::Row<U> &trainLabel,
-                    arma::Row<U> &testLabel,
+void TrainTestSplit(const arma::Mat<T>& input,
+                    const arma::Row<U>& inputLabel,
+                    arma::Mat<T>& trainData,
+                    arma::Mat<T>& testData,
+                    arma::Row<U>& trainLabel,
+                    arma::Row<U>& testLabel,
                     const double testRatio)
 {
-  size_t const testSize =
-      static_cast<size_t>(input.n_cols * testRatio);
+  const size_t testSize = static_cast<size_t>(input.n_cols * testRatio);
   const size_t trainSize = input.n_cols - testSize;
   trainData.set_size(input.n_rows, trainSize);
   testData.set_size(input.n_rows, testSize);
@@ -48,13 +61,13 @@ void TrainTestSplit(const arma::Mat<T> &input,
       arma::shuffle(arma::linspace<arma::Col<size_t>>(0, input.n_cols - 1,
                                                       input.n_cols));
 
-  for(size_t i = 0; i != trainSize; ++i)
+  for (size_t i = 0; i != trainSize; ++i)
   {
     trainData.col(i) = input.col(order[i]);
     trainLabel(i) = inputLabel(order[i]);
   }
 
-  for(size_t i = 0; i != testSize; ++i)
+  for (size_t i = 0; i != testSize; ++i)
   {
     testData.col(i) = input.col(order[i + trainSize]);
     testLabel(i) = inputLabel(order[i + trainSize]);
@@ -62,23 +75,32 @@ void TrainTestSplit(const arma::Mat<T> &input,
 }
 
 /**
- *Overload of Split, if you do not like to pass in
- *so many param, you could call this api instead
- *@param input input data want to split
- *@param label input label want to split
- *@return They are trainData, testData, trainLabel and
- *testLabel
- *@code
- *arma::mat input = loadData();
- *arma::Row<size_t> label = loadLabel();
- *auto splitResult = TrainTestSplit(input, label, 0.2);
- *@endcode
+ * Given an input dataset and labels, split into a training set and test set.
+ * Example usage below.  This overload returns the split dataset as a std::tuple
+ * with four elements: an arma::Mat<T> containing the training data, an
+ * arma::Mat<T> containing the test data, an arma::Row<U> containing the
+ * training labels, and an arma::Row<U> containing the test labels.
+ *
+ * @code
+ * arma::mat input = loadData();
+ * arma::Row<size_t> label = loadLabel();
+ * auto splitResult = TrainTestSplit(input, label, 0.2);
+ * @endcode
+ *
+ * @param input Input dataset to split.
+ * @param inputLabel Input labels to split.
+ * @param trainData Matrix to store training data into.
+ * @param testData Matrix to store test data into.
+ * @param trainLabel Vector to store training labels into.
+ * @param testLabel Vector to store test labels into.
+ * @param testRatio Fraction of dataset to use for test set (between 0 and 1).
+ * @return std::tuple containing trainData (arma::Mat<T>), testData
+ *      (arma::Mat<T>), trainLabel (arma::Row<U>), and testLabel (arma::Row<U>).
  */
 template<typename T,typename U>
-std::tuple<arma::Mat<T>, arma::Mat<T>,
-arma::Row<U>, arma::Row<U>>
-TrainTestSplit(const arma::Mat<T> &input,
-               const arma::Row<U> &inputLabel,
+std::tuple<arma::Mat<T>, arma::Mat<T>, arma::Row<U>, arma::Row<U>>
+TrainTestSplit(const arma::Mat<T>& input,
+               const arma::Row<U>& inputLabel,
                const double testRatio)
 {
   arma::Mat<T> trainData;
@@ -86,13 +108,10 @@ TrainTestSplit(const arma::Mat<T> &input,
   arma::Row<U> trainLabel;
   arma::Row<U> testLabel;
 
-  TrainTestSplit(input, inputLabel,
-                 trainData, testData,
-                 trainLabel, testLabel,
-                 testRatio);
+  TrainTestSplit(input, inputLabel, trainData, testData, trainLabel, testLabel,
+      testRatio);
 
-  return std::make_tuple(trainData, testData,
-                         trainLabel, testLabel);
+  return std::make_tuple(trainData, testData, trainLabel, testLabel);
 }
 
 } // namespace data
diff --git a/src/mlpack/tests/split_data_test.cpp b/src/mlpack/tests/split_data_test.cpp
index bb9cd8d..a23accd 100644
--- a/src/mlpack/tests/split_data_test.cpp
+++ b/src/mlpack/tests/split_data_test.cpp
@@ -1,8 +1,8 @@
 /**
- * @file sparse_autoencoder_test.cpp
- * @author Siddharth Agrawal
+ * @file split_data_test.cpp
+ * @author Tham Ngap Wei
  *
- * Test the SparseAutoencoder class.
+ * Test the TrainTestSplit() function.
  */
 #include <mlpack/core.hpp>
 #include <mlpack/core/data/split_data.hpp>
@@ -17,19 +17,25 @@ using namespace mlpack::data;
 BOOST_AUTO_TEST_SUITE(SplitDataTest);
 
 /**
- * compare the data after train test split
- * @param inputData The original data set before split
- * @param compareData The data want to compare with the inputData,
- * it could be train data or test data
- * @param inputLabel The label of the compareData
+ * Compare the data after train test split.  This assumes that the labels
+ * correspond to each column, so that we can easily check each point against its
+ * original.
+ *
+ * @param inputData The original data set before split.
+ * @param compareData The data to compare against inputData;
+ *   it could be train data or test data.
+ * @param inputLabel The labels of each point in compareData.
  */
-void CompareData(arma::mat const &inputData, arma::mat const &compareData,
-                 arma::Row<size_t> const &inputLabel)
+void CompareData(const mat& inputData,
+                 const mat& compareData,
+                 const Row<size_t>& inputLabel)
 {
-  for(size_t i = 0; i != compareData.n_cols; ++i){
-    arma::mat const &lhsCol = inputData.col(inputLabel(i));
-    arma::mat const &rhsCol = compareData.col(i);
-    for(size_t j = 0; j != lhsCol.n_rows; ++j){
+  for (size_t i = 0; i != compareData.n_cols; ++i)
+  {
+    const mat& lhsCol = inputData.col(inputLabel(i));
+    const mat& rhsCol = compareData.col(i);
+    for (size_t j = 0; j != lhsCol.n_rows; ++j)
+    {
       BOOST_REQUIRE_CLOSE(lhsCol(j), rhsCol(j), 1e-5);
     }
   }
@@ -37,20 +43,19 @@ void CompareData(arma::mat const &inputData, arma::mat const &compareData,
 
 BOOST_AUTO_TEST_CASE(SplitDataSplitResultMat)
 {
-  arma::mat input(2,10);
+  mat input(2, 10);
   input.randu();
-  using Labels = arma::Row<size_t>;
-  //set the labels range same as the col, so the CompareData
-  //can compare the data after TrainTestSplit are valid or not
-  Labels const labels =
-          arma::linspace<Labels>(0, input.n_cols-1,
-                                 input.n_cols);
+
+  // Set the labels to the column ID, so that CompareData can compare the data
+  // after TrainTestSplit is called.
+  const Row<size_t> labels = arma::linspace<Row<size_t>>(0, input.n_cols - 1,
+      input.n_cols);
 
   auto const value = TrainTestSplit(input, labels, 0.2);
-  BOOST_REQUIRE(std::get<0>(value).n_cols == 8);
-  BOOST_REQUIRE(std::get<1>(value).n_cols == 2);
-  BOOST_REQUIRE(std::get<2>(value).n_cols == 8);
-  BOOST_REQUIRE(std::get<3>(value).n_cols == 2);
+  BOOST_REQUIRE_EQUAL(std::get<0>(value).n_cols, 8);
+  BOOST_REQUIRE_EQUAL(std::get<1>(value).n_cols, 2);
+  BOOST_REQUIRE_EQUAL(std::get<2>(value).n_cols, 8);
+  BOOST_REQUIRE_EQUAL(std::get<3>(value).n_cols, 2);
 
   CompareData(input, std::get<0>(value), std::get<2>(value));
   CompareData(input, std::get<1>(value), std::get<3>(value));




More information about the mlpack-git mailing list