[mlpack-git] master: Add another test, and check that we aren't duplicating points. (213a04d)
gitdub at mlpack.org
gitdub at mlpack.org
Fri Apr 22 10:48:30 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/b08ae02b90e18f97366b236e7a4d8725cd6e9050...213a04d31645134b61aa7d3702360bb34796d7de
>---------------------------------------------------------------
commit 213a04d31645134b61aa7d3702360bb34796d7de
Author: Ryan Curtin <ryan at ratml.org>
Date: Fri Apr 22 10:48:30 2016 -0400
Add another test, and check that we aren't duplicating points.
>---------------------------------------------------------------
213a04d31645134b61aa7d3702360bb34796d7de
src/mlpack/tests/split_data_test.cpp | 56 +++++++++++++++++++++++++++++++++++-
1 file changed, 55 insertions(+), 1 deletion(-)
diff --git a/src/mlpack/tests/split_data_test.cpp b/src/mlpack/tests/split_data_test.cpp
index a23accd..1cf7136 100644
--- a/src/mlpack/tests/split_data_test.cpp
+++ b/src/mlpack/tests/split_data_test.cpp
@@ -41,6 +41,32 @@ void CompareData(const mat& inputData,
}
}
+/**
+ * Check that the train and test labels, taken together, contain every index in
+ * [0, trainLabels.n_elem + testLabels.n_elem) exactly once.  This catches both
+ * duplicated points (an index seen more than once) and dropped points (an
+ * index never seen).
+ *
+ * This assumes the labels were set to the column IDs of the input matrix, as
+ * the tests in this file do with arma::linspace.
+ *
+ * @param trainLabels Labels of the points placed in the training set.
+ * @param testLabels Labels of the points placed in the test set.
+ */
+void CheckDuplication(const Row<size_t>& trainLabels,
+                      const Row<size_t>& testLabels)
+{
+  // counts[i] holds how many times point i appears across both label sets.
+  Row<size_t> counts(trainLabels.n_elem + testLabels.n_elem);
+  counts.zeros();
+
+  for (const size_t label : trainLabels)
+  {
+    // Guard the index: an out-of-range label would write past counts.
+    BOOST_REQUIRE_LT(label, counts.n_elem);
+    counts[label]++;
+  }
+  for (const size_t label : testLabels)
+  {
+    BOOST_REQUIRE_LT(label, counts.n_elem);
+    counts[label]++;
+  }
+
+  // Every point must appear exactly once: 0 means it was dropped, more than 1
+  // means it was duplicated.  Compare against size_t(1) so the Boost macro
+  // does not perform a signed/unsigned comparison.
+  for (size_t i = 0; i < counts.n_elem; ++i)
+    BOOST_REQUIRE_EQUAL(counts[i], size_t(1));
+}
+
BOOST_AUTO_TEST_CASE(SplitDataSplitResultMat)
{
mat input(2, 10);
@@ -51,7 +77,7 @@ BOOST_AUTO_TEST_CASE(SplitDataSplitResultMat)
const Row<size_t> labels = arma::linspace<Row<size_t>>(0, input.n_cols - 1,
input.n_cols);
- auto const value = TrainTestSplit(input, labels, 0.2);
+ const auto value = TrainTestSplit(input, labels, 0.2);
BOOST_REQUIRE_EQUAL(std::get<0>(value).n_cols, 8);
BOOST_REQUIRE_EQUAL(std::get<1>(value).n_cols, 2);
BOOST_REQUIRE_EQUAL(std::get<2>(value).n_cols, 8);
@@ -59,6 +85,34 @@ BOOST_AUTO_TEST_CASE(SplitDataSplitResultMat)
CompareData(input, std::get<0>(value), std::get<2>(value));
CompareData(input, std::get<1>(value), std::get<3>(value));
+
+ // The last thing to check is that we aren't duplicating any points in the
+ // train or test labels.
+ CheckDuplication(std::get<2>(value), std::get<3>(value));
+}
+
+/**
+ * The same test as above, but on a larger dataset.  The number of points (497)
+ * is not evenly divisible by the test ratio (0.3 * 497 = 149.1), so this also
+ * pins down how the fractional test-set size is handled.
+ */
+BOOST_AUTO_TEST_CASE(SplitDataLargerTest)
+{
+  const size_t numPoints = 497;
+  const double testRatio = 0.3;
+
+  // Expected split sizes: the test set gets size_t(testRatio * numPoints)
+  // points (the conversion truncates the fractional part), and the training
+  // set gets the remainder.
+  const size_t testSize = size_t(testRatio * numPoints);
+  const size_t trainSize = numPoints - testSize;
+
+  mat input(10, numPoints);
+  input.randu();
+
+  // Set the labels to the column ID.
+  const Row<size_t> labels = arma::linspace<Row<size_t>>(0, input.n_cols - 1,
+      input.n_cols);
+
+  const auto value = TrainTestSplit(input, labels, testRatio);
+  BOOST_REQUIRE_EQUAL(std::get<0>(value).n_cols, trainSize);
+  BOOST_REQUIRE_EQUAL(std::get<1>(value).n_cols, testSize);
+  BOOST_REQUIRE_EQUAL(std::get<2>(value).n_cols, trainSize);
+  BOOST_REQUIRE_EQUAL(std::get<3>(value).n_cols, testSize);
+
+  CompareData(input, std::get<0>(value), std::get<2>(value));
+  CompareData(input, std::get<1>(value), std::get<3>(value));
+
+  // Make sure no point was duplicated or dropped by the split.
+  CheckDuplication(std::get<2>(value), std::get<3>(value));
+}
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git
mailing list