[mlpack-git] master: refine style and detail test (9e88669)
From: gitdub at mlpack.org
Tue May 31 03:35:14 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/1f562a1aba7ae55475afcc95659511c2b7f694e5...5b8fdce471328f722fcd8c0f22a6d995ce22c98b
>---------------------------------------------------------------
commit 9e88669c6c5a8fcc9efb2ce37d09f99d49e4e302
Author: Keon Kim <kwk236 at gmail.com>
Date: Mon May 30 19:15:45 2016 +0900
refine style and detail test
>---------------------------------------------------------------
9e88669c6c5a8fcc9efb2ce37d09f99d49e4e302
src/mlpack/core/data/split_data.hpp | 30 +++++++++++-----------
.../methods/preprocess/preprocess_split_main.cpp | 24 +++++------------
src/mlpack/tests/split_data_test.cpp | 28 ++++++++++++++++++--
3 files changed, 48 insertions(+), 34 deletions(-)
diff --git a/src/mlpack/core/data/split_data.hpp b/src/mlpack/core/data/split_data.hpp
index 1df1b28..3979caf 100644
--- a/src/mlpack/core/data/split_data.hpp
+++ b/src/mlpack/core/data/split_data.hpp
@@ -42,12 +42,12 @@ namespace data {
*/
template<typename T, typename U>
void Split(const arma::Mat<T>& input,
- const arma::Row<U>& inputLabel,
- arma::Mat<T>& trainData,
- arma::Mat<T>& testData,
- arma::Row<U>& trainLabel,
- arma::Row<U>& testLabel,
- const double testRatio)
+ const arma::Row<U>& inputLabel,
+ arma::Mat<T>& trainData,
+ arma::Mat<T>& testData,
+ arma::Row<U>& trainLabel,
+ arma::Row<U>& testLabel,
+ const double testRatio)
{
const size_t testSize = static_cast<size_t>(input.n_cols * testRatio);
const size_t trainSize = input.n_cols - testSize;
@@ -96,9 +96,9 @@ void Split(const arma::Mat<T>& input,
*/
template<typename T>
void Split(const arma::Mat<T>& input,
- arma::Mat<T>& trainData,
- arma::Mat<T>& testData,
- const double testRatio)
+ arma::Mat<T>& trainData,
+ arma::Mat<T>& testData,
+ const double testRatio)
{
const size_t testSize = static_cast<size_t>(input.n_cols * testRatio);
const size_t trainSize = input.n_cols - testSize;
@@ -111,11 +111,11 @@ void Split(const arma::Mat<T>& input,
for (size_t i = 0; i != trainSize; ++i)
{
- trainData.col(i) = input.col(order[i]);
+ trainData.col(i) = input.col(order[i]);
}
for (size_t i = 0; i != testSize; ++i)
{
- testData.col(i) = input.col(order[i + trainSize]);
+ testData.col(i) = input.col(order[i + trainSize]);
}
}
@@ -141,8 +141,8 @@ void Split(const arma::Mat<T>& input,
template<typename T,typename U>
std::tuple<arma::Mat<T>, arma::Mat<T>, arma::Row<U>, arma::Row<U>>
Split(const arma::Mat<T>& input,
- const arma::Row<U>& inputLabel,
- const double testRatio)
+ const arma::Row<U>& inputLabel,
+ const double testRatio)
{
arma::Mat<T> trainData;
arma::Mat<T> testData;
@@ -174,13 +174,13 @@ Split(const arma::Mat<T>& input,
template<typename T>
std::tuple<arma::Mat<T>, arma::Mat<T>>
Split(const arma::Mat<T>& input,
- const double testRatio)
+ const double testRatio)
{
arma::Mat<T> trainData;
arma::Mat<T> testData;
Split(input, trainData, testData, testRatio);
- return std::make_tuple(trainData, testData);
+ return std::make_tuple(std::move(trainData), std::move(testData));
}
} // namespace data
diff --git a/src/mlpack/methods/preprocess/preprocess_split_main.cpp b/src/mlpack/methods/preprocess/preprocess_split_main.cpp
index 02e844a..ca8e830 100644
--- a/src/mlpack/methods/preprocess/preprocess_split_main.cpp
+++ b/src/mlpack/methods/preprocess/preprocess_split_main.cpp
@@ -87,7 +87,7 @@ int main(int argc, char** argv)
if (CLI::HasParam("test_ratio"))
{
//sanity check on test_ratio
- if ((testRatio < 0.0) && (testRatio > 1.0))
+ if ((testRatio < 0.0) || (testRatio > 1.0))
{
Log::Fatal << "Invalid parameter for test_ratio. "
<< "test_ratio must be between 0.0 and 1.0" << endl;
@@ -116,16 +116,10 @@ int main(int argc, char** argv)
Log::Info << "Train Label Count: " << get<2>(value).n_cols << endl;
Log::Info << "Test Label Count: " << get<3>(value).n_cols << endl;
- // TODO: fix full precision problem
- mat training = get<0>(value);
- mat test = get<1>(value);
- mat trainingLabels = get<2>(value);
- mat testLabels = get<3>(value);
-
- data::Save(trainingFile, training, false);
- data::Save(testFile, test, false);
- data::Save(trainingLabelsFile, trainingLabels, false);
- data::Save(testLabelsFile, testLabels, false);
+ data::Save(trainingFile, get<0>(value), false);
+ data::Save(testFile, get<1>(value), false);
+ data::Save(trainingLabelsFile, get<2>(value), false);
+ data::Save(testLabelsFile, get<3>(value), false);
}
else // split without parameters
{
@@ -133,12 +127,8 @@ int main(int argc, char** argv)
Log::Info << "Train Data Count: " << get<0>(value).n_cols << endl;
Log::Info << "Test Data Count: " << get<1>(value).n_cols << endl;
- // TODO: fix full precision problem
- mat training = get<0>(value);
- mat test = get<1>(value);
-
- data::Save(trainingFile, training, false);
- data::Save(testFile, test, false);
+ data::Save(trainingFile, get<0>(value), false);
+ data::Save(testFile, get<1>(value), false);
}
}
diff --git a/src/mlpack/tests/split_data_test.cpp b/src/mlpack/tests/split_data_test.cpp
index d7b3990..daf4cd5 100644
--- a/src/mlpack/tests/split_data_test.cpp
+++ b/src/mlpack/tests/split_data_test.cpp
@@ -41,6 +41,22 @@ void CompareData(const mat& inputData,
}
}
+void CheckMatEqual(const mat& inputData,
+ const mat& compareData)
+{
+ const mat& sortedInput = arma::sort(inputData, "ascend", 1);
+ const mat& sortedCompare = arma::sort(compareData, "ascend", 1);
+ for (size_t i = 0; i < sortedInput.n_cols; ++i)
+ {
+ const mat& lhsCol = sortedInput.col(i);
+ const mat& rhsCol = sortedCompare.col(i);
+ for (size_t j = 0; j < lhsCol.n_rows; ++j)
+ {
+ BOOST_REQUIRE_CLOSE(lhsCol(j), rhsCol(j), 1e-5);
+ }
+ }
+}
+
/**
* Check that no labels have been duplicated.
*/
@@ -70,11 +86,15 @@ void CheckDuplication(const Row<size_t>& trainLabels,
BOOST_AUTO_TEST_CASE(SplitDataResultMat)
{
mat input(2, 10);
- input.randu();
+ size_t count = 0; // count for putting unique sequential values
+ input.imbue([&count] () { return ++count; });
const auto value = Split(input, 0.2);
BOOST_REQUIRE_EQUAL(std::get<0>(value).n_cols, 8); // train data
BOOST_REQUIRE_EQUAL(std::get<1>(value).n_cols, 2); // test data
+
+ mat concat = arma::join_rows(std::get<0>(value), std::get<1>(value));
+ CheckMatEqual(input, concat);
}
BOOST_AUTO_TEST_CASE(SplitLabeledDataResultMat)
@@ -106,12 +126,16 @@ BOOST_AUTO_TEST_CASE(SplitLabeledDataResultMat)
*/
BOOST_AUTO_TEST_CASE(SplitDataLargerTest)
{
+ size_t count = 0;
mat input(10, 497);
- input.randu();
+ input.imbue([&count] () { return ++count; });
const auto value = Split(input, 0.3);
BOOST_REQUIRE_EQUAL(std::get<0>(value).n_cols, 497 - size_t(0.3 * 497));
BOOST_REQUIRE_EQUAL(std::get<1>(value).n_cols, size_t(0.3 * 497));
+
+ mat concat = arma::join_rows(std::get<0>(value), std::get<1>(value));
+ CheckMatEqual(input, concat);
}
BOOST_AUTO_TEST_CASE(SplitLabeledDataLargerTest)
More information about the mlpack-git mailing list