[mlpack-git] master: refine test case (7723421)

gitdub at mlpack.org gitdub at mlpack.org
Tue Apr 12 18:16:41 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/ebf77f8b13323a87c433b6f639deb2369188b00c...b08ae02b90e18f97366b236e7a4d8725cd6e9050

>---------------------------------------------------------------

commit 7723421c6415787fe60cd2ace74e6ff05039e2e3
Author: stereomatchingkiss <stereomatchingkiss at gmail.com>
Date:   Wed Apr 13 06:16:41 2016 +0800

    refine test case


>---------------------------------------------------------------

7723421c6415787fe60cd2ace74e6ff05039e2e3
 src/mlpack/tests/split_data_test.cpp | 54 +++++++++++++++++++++---------------
 1 file changed, 31 insertions(+), 23 deletions(-)

diff --git a/src/mlpack/tests/split_data_test.cpp b/src/mlpack/tests/split_data_test.cpp
index 5d9fcbc..8ec4591 100644
--- a/src/mlpack/tests/split_data_test.cpp
+++ b/src/mlpack/tests/split_data_test.cpp
@@ -16,36 +16,44 @@ using namespace arma;
 
 BOOST_AUTO_TEST_SUITE(SplitDataTest);
 
-void compareData(arma::mat const &inputData, arma::mat const &compareData,
+/**
+ * Compare the data produced by a train/test split against the original data.
+ * @param inputData The original data set before the split.
+ * @param compareData The data to compare against inputData; this may be
+ * either the training data or the test data.
+ * @param inputLabel Labels of compareData, used as column indices into inputData.
+ */
+void CompareData(arma::mat const &inputData, arma::mat const &compareData,
                  arma::Row<size_t> const &inputLabel)
 {
-    for(size_t i = 0; i != compareData.n_cols; ++i){
-        arma::mat const &lhsCol = inputData.col(inputLabel(i));
-        arma::mat const &rhsCol = compareData.col(i);
-        for(size_t j = 0; j != lhsCol.n_rows; ++j){
-            BOOST_REQUIRE_CLOSE(lhsCol(j), rhsCol(j), 1e-5);
-        }
+  for(size_t i = 0; i != compareData.n_cols; ++i){
+    arma::mat const &lhsCol = inputData.col(inputLabel(i));
+    arma::mat const &rhsCol = compareData.col(i);
+    for(size_t j = 0; j != lhsCol.n_rows; ++j){
+      BOOST_REQUIRE_CLOSE(lhsCol(j), rhsCol(j), 1e-5);
     }
+  }
 }
 
 BOOST_AUTO_TEST_CASE(SplitDataSplitResultMat)
 {    
-    arma::mat input(2,10);
-    input.randu();
-    using Labels = arma::Row<size_t>;
-    Labels const labels =
-            arma::linspace<Labels>(0, input.n_cols-1,
-                                   input.n_cols);
-
-    util::TrainTestSplit tts(0.2);
-    auto const value = tts.Split(input, labels);
-    BOOST_REQUIRE(std::get<0>(value).n_cols == 8);
-    BOOST_REQUIRE(std::get<1>(value).n_cols == 2);
-    BOOST_REQUIRE(std::get<2>(value).n_cols == 8);
-    BOOST_REQUIRE(std::get<3>(value).n_cols == 2);
-
-    compareData(input, std::get<0>(value), std::get<2>(value));
-    compareData(input, std::get<1>(value), std::get<3>(value));
+  arma::mat input(2,10);
+  input.randu();
+  using Labels = arma::Row<size_t>;
+  // Set each label equal to its column index, so that CompareData can
+  // check whether the data produced by TrainTestSplit is valid.
+  Labels const labels =
+          arma::linspace<Labels>(0, input.n_cols-1,
+                                 input.n_cols);
+
+  auto const value = util::TrainTestSplit(input, labels, 0.2);
+  BOOST_REQUIRE(std::get<0>(value).n_cols == 8);
+  BOOST_REQUIRE(std::get<1>(value).n_cols == 2);
+  BOOST_REQUIRE(std::get<2>(value).n_cols == 8);
+  BOOST_REQUIRE(std::get<3>(value).n_cols == 2);
+
+  CompareData(input, std::get<0>(value), std::get<2>(value));
+  CompareData(input, std::get<1>(value), std::get<3>(value));
 }
 
 BOOST_AUTO_TEST_SUITE_END();




More information about the mlpack-git mailing list