[mlpack-git] master: Add another test, and check that we aren't duplicating points. (213a04d)

gitdub at mlpack.org gitdub at mlpack.org
Fri Apr 22 10:48:30 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/b08ae02b90e18f97366b236e7a4d8725cd6e9050...213a04d31645134b61aa7d3702360bb34796d7de

>---------------------------------------------------------------

commit 213a04d31645134b61aa7d3702360bb34796d7de
Author: Ryan Curtin <ryan at ratml.org>
Date:   Fri Apr 22 10:48:30 2016 -0400

    Add another test, and check that we aren't duplicating points.


>---------------------------------------------------------------

213a04d31645134b61aa7d3702360bb34796d7de
 src/mlpack/tests/split_data_test.cpp | 56 +++++++++++++++++++++++++++++++++++-
 1 file changed, 55 insertions(+), 1 deletion(-)

diff --git a/src/mlpack/tests/split_data_test.cpp b/src/mlpack/tests/split_data_test.cpp
index a23accd..1cf7136 100644
--- a/src/mlpack/tests/split_data_test.cpp
+++ b/src/mlpack/tests/split_data_test.cpp
@@ -41,6 +41,32 @@ void CompareData(const mat& inputData,
   }
 }
 
+/**
+ * Require every value in [0, N) to appear exactly once (N = total labels).
+ */
+void CheckDuplication(const Row<size_t>& trainLabels,
+                      const Row<size_t>& testLabels)
+{
+  // counts[v] is the number of times label v occurs in train + test combined.
+  Row<size_t> counts(trainLabels.n_elem + testLabels.n_elem);
+  counts.zeros();
+
+  for (size_t i = 0; i < trainLabels.n_elem; ++i)
+  {
+    BOOST_REQUIRE_LT(trainLabels[i], counts.n_elem);
+    counts[trainLabels[i]]++;
+  }
+  for (size_t i = 0; i < testLabels.n_elem; ++i)
+  {
+    BOOST_REQUIRE_LT(testLabels[i], counts.n_elem);
+    counts[testLabels[i]]++;
+  }
+
+  // Every label must be seen exactly once: no duplicates, none missing.
+  for (size_t i = 0; i < counts.n_elem; ++i)
+    BOOST_REQUIRE_EQUAL(counts[i], 1);
+}
+
 BOOST_AUTO_TEST_CASE(SplitDataSplitResultMat)
 {
   mat input(2, 10);
@@ -51,7 +77,7 @@ BOOST_AUTO_TEST_CASE(SplitDataSplitResultMat)
   const Row<size_t> labels = arma::linspace<Row<size_t>>(0, input.n_cols - 1,
       input.n_cols);
 
-  auto const value = TrainTestSplit(input, labels, 0.2);
+  const auto value = TrainTestSplit(input, labels, 0.2);
   BOOST_REQUIRE_EQUAL(std::get<0>(value).n_cols, 8);
   BOOST_REQUIRE_EQUAL(std::get<1>(value).n_cols, 2);
   BOOST_REQUIRE_EQUAL(std::get<2>(value).n_cols, 8);
@@ -59,6 +85,34 @@ BOOST_AUTO_TEST_CASE(SplitDataSplitResultMat)
 
   CompareData(input, std::get<0>(value), std::get<2>(value));
   CompareData(input, std::get<1>(value), std::get<3>(value));
+
+  // The last thing to check is that we aren't duplicating any points in the
+  // train or test labels.
+  CheckDuplication(std::get<2>(value), std::get<3>(value));
+}
+
+/**
+ * The same checks as SplitDataSplitResultMat, on a 10 x 497 random input.
+ */
+BOOST_AUTO_TEST_CASE(SplitDataLargerTest)
+{
+  mat input(10, 497);
+  input.randu();
+
+  // Label i is column index i, so each label identifies its source column.
+  const Row<size_t> labels = arma::linspace<Row<size_t>>(0, input.n_cols - 1,
+      input.n_cols);
+
+  const auto value = TrainTestSplit(input, labels, 0.3);
+  BOOST_REQUIRE_EQUAL(std::get<0>(value).n_cols, 497 - size_t(0.3 * 497));
+  BOOST_REQUIRE_EQUAL(std::get<1>(value).n_cols, size_t(0.3 * 497));
+  BOOST_REQUIRE_EQUAL(std::get<2>(value).n_cols, 497 - size_t(0.3 * 497));
+  BOOST_REQUIRE_EQUAL(std::get<3>(value).n_cols, size_t(0.3 * 497));
+
+  CompareData(input, std::get<0>(value), std::get<2>(value));
+  CompareData(input, std::get<1>(value), std::get<3>(value));
+
+  CheckDuplication(std::get<2>(value), std::get<3>(value));
 }
 
 BOOST_AUTO_TEST_SUITE_END();




More information about the mlpack-git mailing list