[mlpack-git] master: 1: remove seed variable; 2: do not store testRatio, pass it in as a function argument (595ff70)
gitdub at mlpack.org
gitdub at mlpack.org
Tue Apr 12 17:32:10 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/ebf77f8b13323a87c433b6f639deb2369188b00c...b08ae02b90e18f97366b236e7a4d8725cd6e9050
>---------------------------------------------------------------
commit 595ff7061fc3efc327553dbe18ae603d5e93f326
Author: stereomatchingkiss <stereomatchingkiss at gmail.com>
Date: Wed Apr 13 05:32:10 2016 +0800
1: Remove the seed variable.
2: Do not store testRatio; pass it in as a function argument instead.
>---------------------------------------------------------------
595ff7061fc3efc327553dbe18ae603d5e93f326
src/mlpack/core/util/split_data.hpp | 45 +++++--------------------------------
1 file changed, 5 insertions(+), 40 deletions(-)
diff --git a/src/mlpack/core/util/split_data.hpp b/src/mlpack/core/util/split_data.hpp
index dd4b0df..4538664 100644
--- a/src/mlpack/core/util/split_data.hpp
+++ b/src/mlpack/core/util/split_data.hpp
@@ -20,25 +20,11 @@ class TrainTestSplit
{
public:
/**
- * @brief TrainTestSplit
- * @param testRatio the ratio of test data
- * @param seed seed of the random device
- * @warning slice should not less than 1
- */
- TrainTestSplit(double testRatio,
- arma::arma_rng::seed_type seed = 0) :
- seed(seed),
- testRatio(testRatio)
- {
- }
-
- /**
*Split training data and test data, please define
*ARMA_USE_CXX11 to enable move of c++11
*@param input input data want to split
*@param label input label want to split
*@param testRatio the ratio of test data
- *@param seed seed of the random device
*@code
*arma::mat input = loadData();
*arma::Row<size_t> label = loadLabel();
@@ -58,7 +44,8 @@ public:
arma::Mat<T> &trainData,
arma::Mat<T> &testData,
arma::Row<U> &trainLabel,
- arma::Row<U> &testLabel)
+ arma::Row<U> &testLabel,
+ double testRatio)
{
size_t const testSize =
static_cast<size_t>(input.n_cols * testRatio);
@@ -69,7 +56,6 @@ public:
testLabel.set_size(testSize);
using Col = arma::Col<size_t>;
- arma::arma_rng::set_seed(seed);
Col const sequence = arma::linspace<Col>(0, input.n_cols - 1,
input.n_cols);
arma::Col<size_t> const order = arma::shuffle(sequence);
@@ -99,7 +85,8 @@ public:
std::tuple<arma::Mat<T>, arma::Mat<T>,
arma::Row<U>, arma::Row<U>>
Split(arma::Mat<T> const &input,
- arma::Row<U> const &inputLabel)
+ arma::Row<U> const &inputLabel,
+ double testRatio)
{
arma::Mat<T> trainData;
arma::Mat<T> testData;
@@ -107,33 +94,11 @@ public:
arma::Row<U> testLabel;
Split(input, inputLabel, trainData, testData,
- trainLabel, testLabel);
+ trainLabel, testLabel, testRatio);
return std::make_tuple(trainData, testData,
trainLabel, testLabel);
}
-
- void Seed(arma::arma_rng::seed_type value)
- {
- seed = value;
- }
- arma::arma_rng::seed_type Seed() const
- {
- return seed;
- }
-
- void TestRatio(double value)
- {
- testRatio = value;
- }
- double TestRatio() const
- {
- return testRatio;
- }
-
-private:
- arma::arma_rng::seed_type seed;
- double testRatio;
};
} // namespace util
More information about the mlpack-git
mailing list