[mlpack-git] master: refine style and detail test (9e88669)

gitdub at mlpack.org gitdub at mlpack.org
Tue May 31 03:35:14 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/1f562a1aba7ae55475afcc95659511c2b7f694e5...5b8fdce471328f722fcd8c0f22a6d995ce22c98b

>---------------------------------------------------------------

commit 9e88669c6c5a8fcc9efb2ce37d09f99d49e4e302
Author: Keon Kim <kwk236 at gmail.com>
Date:   Mon May 30 19:15:45 2016 +0900

    refine style and detail test


>---------------------------------------------------------------

9e88669c6c5a8fcc9efb2ce37d09f99d49e4e302
 src/mlpack/core/data/split_data.hpp                | 30 +++++++++++-----------
 .../methods/preprocess/preprocess_split_main.cpp   | 24 +++++------------
 src/mlpack/tests/split_data_test.cpp               | 28 ++++++++++++++++++--
 3 files changed, 48 insertions(+), 34 deletions(-)

diff --git a/src/mlpack/core/data/split_data.hpp b/src/mlpack/core/data/split_data.hpp
index 1df1b28..3979caf 100644
--- a/src/mlpack/core/data/split_data.hpp
+++ b/src/mlpack/core/data/split_data.hpp
@@ -42,12 +42,12 @@ namespace data {
  */
 template<typename T, typename U>
 void Split(const arma::Mat<T>& input,
-                    const arma::Row<U>& inputLabel,
-                    arma::Mat<T>& trainData,
-                    arma::Mat<T>& testData,
-                    arma::Row<U>& trainLabel,
-                    arma::Row<U>& testLabel,
-                    const double testRatio)
+           const arma::Row<U>& inputLabel,
+           arma::Mat<T>& trainData,
+           arma::Mat<T>& testData,
+           arma::Row<U>& trainLabel,
+           arma::Row<U>& testLabel,
+           const double testRatio)
 {
   const size_t testSize = static_cast<size_t>(input.n_cols * testRatio);
   const size_t trainSize = input.n_cols - testSize;
@@ -96,9 +96,9 @@ void Split(const arma::Mat<T>& input,
  */
 template<typename T>
 void Split(const arma::Mat<T>& input,
-                    arma::Mat<T>& trainData,
-                    arma::Mat<T>& testData,
-                    const double testRatio)
+           arma::Mat<T>& trainData,
+           arma::Mat<T>& testData,
+           const double testRatio)
 {
   const size_t testSize = static_cast<size_t>(input.n_cols * testRatio);
   const size_t trainSize = input.n_cols - testSize;
@@ -111,11 +111,11 @@ void Split(const arma::Mat<T>& input,
 
   for (size_t i = 0; i != trainSize; ++i)
   {
-     trainData.col(i) = input.col(order[i]);
+    trainData.col(i) = input.col(order[i]);
   }
   for (size_t i = 0; i != testSize; ++i)
   {
-     testData.col(i) = input.col(order[i + trainSize]);
+    testData.col(i) = input.col(order[i + trainSize]);
   }
 }
 
@@ -141,8 +141,8 @@ void Split(const arma::Mat<T>& input,
 template<typename T,typename U>
 std::tuple<arma::Mat<T>, arma::Mat<T>, arma::Row<U>, arma::Row<U>>
 Split(const arma::Mat<T>& input,
-               const arma::Row<U>& inputLabel,
-               const double testRatio)
+      const arma::Row<U>& inputLabel,
+      const double testRatio)
 {
   arma::Mat<T> trainData;
   arma::Mat<T> testData;
@@ -174,13 +174,13 @@ Split(const arma::Mat<T>& input,
 template<typename T>
 std::tuple<arma::Mat<T>, arma::Mat<T>>
 Split(const arma::Mat<T>& input,
-               const double testRatio)
+      const double testRatio) // Fraction of input's columns put in testData.
 {
   arma::Mat<T> trainData;
   arma::Mat<T> testData;
   Split(input, trainData, testData, testRatio);
 
-  return std::make_tuple(trainData, testData);
+  return std::make_tuple(std::move(trainData), std::move(testData));
 }
 
 } // namespace data
diff --git a/src/mlpack/methods/preprocess/preprocess_split_main.cpp b/src/mlpack/methods/preprocess/preprocess_split_main.cpp
index 02e844a..ca8e830 100644
--- a/src/mlpack/methods/preprocess/preprocess_split_main.cpp
+++ b/src/mlpack/methods/preprocess/preprocess_split_main.cpp
@@ -87,7 +87,7 @@ int main(int argc, char** argv)
   if (CLI::HasParam("test_ratio"))
   {
     //sanity check on test_ratio
-    if ((testRatio < 0.0) && (testRatio > 1.0))
+    if ((testRatio < 0.0) || (testRatio > 1.0))
     {
       Log::Fatal << "Invalid parameter for test_ratio. "
         << "test_ratio must be between 0.0 and 1.0" << endl;
@@ -116,16 +116,10 @@ int main(int argc, char** argv)
     Log::Info << "Train Label Count: " << get<2>(value).n_cols << endl;
     Log::Info << "Test Label Count: " << get<3>(value).n_cols << endl;
 
-    // TODO: fix full precision problem
-    mat training = get<0>(value);
-    mat test = get<1>(value);
-    mat trainingLabels = get<2>(value);
-    mat testLabels = get<3>(value);
-
-    data::Save(trainingFile, training, false);
-    data::Save(testFile, test, false);
-    data::Save(trainingLabelsFile, trainingLabels, false);
-    data::Save(testLabelsFile, testLabels, false);
+    data::Save(trainingFile, get<0>(value), false);
+    data::Save(testFile, get<1>(value), false);
+    data::Save(trainingLabelsFile, get<2>(value), false);
+    data::Save(testLabelsFile, get<3>(value), false);
   }
   else // split without parameters
   {
@@ -133,12 +127,8 @@ int main(int argc, char** argv)
     Log::Info << "Train Data Count: " << get<0>(value).n_cols << endl;
     Log::Info << "Test Data Count: " << get<1>(value).n_cols << endl;
 
-    // TODO: fix full precision problem
-    mat training = get<0>(value);
-    mat test = get<1>(value);
-
-    data::Save(trainingFile, training, false);
-    data::Save(testFile, test, false);
+    data::Save(trainingFile, get<0>(value), false);
+    data::Save(testFile, get<1>(value), false);
   }
 }
 
diff --git a/src/mlpack/tests/split_data_test.cpp b/src/mlpack/tests/split_data_test.cpp
index d7b3990..daf4cd5 100644
--- a/src/mlpack/tests/split_data_test.cpp
+++ b/src/mlpack/tests/split_data_test.cpp
@@ -41,6 +41,22 @@ void CompareData(const mat& inputData,
   }
 }
 
+/**
+ * Require compareData to have inputData's shape (checked first, so extra or
+ * missing columns fail cleanly) and the same values once each row of both is
+ * sorted independently; column order is therefore not checked.
+ */
+void CheckMatEqual(const mat& inputData, const mat& compareData)
+{
+  BOOST_REQUIRE_EQUAL(inputData.n_rows, compareData.n_rows);
+  BOOST_REQUIRE_EQUAL(inputData.n_cols, compareData.n_cols);
+  const mat sortedInput = arma::sort(inputData, "ascend", 1);
+  const mat sortedCompare = arma::sort(compareData, "ascend", 1);
+  for (size_t i = 0; i < sortedInput.n_cols; ++i)
+    for (size_t j = 0; j < sortedInput.n_rows; ++j)
+      BOOST_REQUIRE_CLOSE(sortedInput(j, i), sortedCompare(j, i), 1e-5);
+}
+
 /**
  * Check that no labels have been duplicated.
  */
@@ -70,11 +86,15 @@ void CheckDuplication(const Row<size_t>& trainLabels,
 BOOST_AUTO_TEST_CASE(SplitDataResultMat)
 {
   mat input(2, 10);
-  input.randu();
+  size_t count = 0; // Fills input with the distinct values 1, 2, ..., 20.
+  input.imbue([&count] () { return ++count; });
 
   const auto value = Split(input, 0.2);
   BOOST_REQUIRE_EQUAL(std::get<0>(value).n_cols, 8); // train data
   BOOST_REQUIRE_EQUAL(std::get<1>(value).n_cols, 2); // test data
+  // Together, train and test data must hold the same values as input.
+  mat concat = arma::join_rows(std::get<0>(value), std::get<1>(value));
+  CheckMatEqual(input, concat);
 }
 
 BOOST_AUTO_TEST_CASE(SplitLabeledDataResultMat)
@@ -106,12 +126,16 @@ BOOST_AUTO_TEST_CASE(SplitLabeledDataResultMat)
  */
 BOOST_AUTO_TEST_CASE(SplitDataLargerTest)
 {
+  size_t count = 0; // Fills input with the distinct values 1, 2, ..., 4970.
   mat input(10, 497);
-  input.randu();
+  input.imbue([&count] () { return ++count; });
 
   const auto value = Split(input, 0.3);
   BOOST_REQUIRE_EQUAL(std::get<0>(value).n_cols, 497 - size_t(0.3 * 497));
   BOOST_REQUIRE_EQUAL(std::get<1>(value).n_cols, size_t(0.3 * 497));
+  // Together, train and test data must hold the same values as input.
+  mat concat = arma::join_rows(std::get<0>(value), std::get<1>(value));
+  CheckMatEqual(input, concat);
 }
 
 BOOST_AUTO_TEST_CASE(SplitLabeledDataLargerTest)




More information about the mlpack-git mailing list