[mlpack-git] master: refine test case (7723421)
gitdub at mlpack.org
gitdub at mlpack.org
Tue Apr 12 18:16:41 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/ebf77f8b13323a87c433b6f639deb2369188b00c...b08ae02b90e18f97366b236e7a4d8725cd6e9050
>---------------------------------------------------------------
commit 7723421c6415787fe60cd2ace74e6ff05039e2e3
Author: stereomatchingkiss <stereomatchingkiss at gmail.com>
Date: Wed Apr 13 06:16:41 2016 +0800
refine test case
>---------------------------------------------------------------
7723421c6415787fe60cd2ace74e6ff05039e2e3
src/mlpack/tests/split_data_test.cpp | 54 +++++++++++++++++++++---------------
1 file changed, 31 insertions(+), 23 deletions(-)
diff --git a/src/mlpack/tests/split_data_test.cpp b/src/mlpack/tests/split_data_test.cpp
index 5d9fcbc..8ec4591 100644
--- a/src/mlpack/tests/split_data_test.cpp
+++ b/src/mlpack/tests/split_data_test.cpp
@@ -16,36 +16,44 @@ using namespace arma;
BOOST_AUTO_TEST_SUITE(SplitDataTest);
-void compareData(arma::mat const &inputData, arma::mat const &compareData,
+/**
+ * compare the data after train test split
+ * @param inputData The original data set before split
+ * @param compareData The data want to compare with the inputData,
+ * it could be train data or test data
+ * @param inputLabel The label of the compareData
+ */
+void CompareData(arma::mat const &inputData, arma::mat const &compareData,
arma::Row<size_t> const &inputLabel)
{
- for(size_t i = 0; i != compareData.n_cols; ++i){
- arma::mat const &lhsCol = inputData.col(inputLabel(i));
- arma::mat const &rhsCol = compareData.col(i);
- for(size_t j = 0; j != lhsCol.n_rows; ++j){
- BOOST_REQUIRE_CLOSE(lhsCol(j), rhsCol(j), 1e-5);
- }
+ for(size_t i = 0; i != compareData.n_cols; ++i){
+ arma::mat const &lhsCol = inputData.col(inputLabel(i));
+ arma::mat const &rhsCol = compareData.col(i);
+ for(size_t j = 0; j != lhsCol.n_rows; ++j){
+ BOOST_REQUIRE_CLOSE(lhsCol(j), rhsCol(j), 1e-5);
}
+ }
}
BOOST_AUTO_TEST_CASE(SplitDataSplitResultMat)
{
- arma::mat input(2,10);
- input.randu();
- using Labels = arma::Row<size_t>;
- Labels const labels =
- arma::linspace<Labels>(0, input.n_cols-1,
- input.n_cols);
-
- util::TrainTestSplit tts(0.2);
- auto const value = tts.Split(input, labels);
- BOOST_REQUIRE(std::get<0>(value).n_cols == 8);
- BOOST_REQUIRE(std::get<1>(value).n_cols == 2);
- BOOST_REQUIRE(std::get<2>(value).n_cols == 8);
- BOOST_REQUIRE(std::get<3>(value).n_cols == 2);
-
- compareData(input, std::get<0>(value), std::get<2>(value));
- compareData(input, std::get<1>(value), std::get<3>(value));
+ arma::mat input(2,10);
+ input.randu();
+ using Labels = arma::Row<size_t>;
+ //set the labels range same as the col, so the CompareData
+ //can compare the data after TrainTestSplit are valid or not
+ Labels const labels =
+ arma::linspace<Labels>(0, input.n_cols-1,
+ input.n_cols);
+
+ auto const value = util::TrainTestSplit(input, labels, 0.2);
+ BOOST_REQUIRE(std::get<0>(value).n_cols == 8);
+ BOOST_REQUIRE(std::get<1>(value).n_cols == 2);
+ BOOST_REQUIRE(std::get<2>(value).n_cols == 8);
+ BOOST_REQUIRE(std::get<3>(value).n_cols == 2);
+
+ CompareData(input, std::get<0>(value), std::get<2>(value));
+ CompareData(input, std::get<1>(value), std::get<3>(value));
}
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git mailing list