[mlpack-git] master: 1 : remove seed variable 2 : do not store testRatio, pass in by function (595ff70)

gitdub at mlpack.org gitdub at mlpack.org
Tue Apr 12 17:32:10 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/ebf77f8b13323a87c433b6f639deb2369188b00c...b08ae02b90e18f97366b236e7a4d8725cd6e9050

>---------------------------------------------------------------

commit 595ff7061fc3efc327553dbe18ae603d5e93f326
Author: stereomatchingkiss <stereomatchingkiss at gmail.com>
Date:   Wed Apr 13 05:32:10 2016 +0800

    1 : remove seed variable
    2 : do not store testRatio, pass in by function


>---------------------------------------------------------------

595ff7061fc3efc327553dbe18ae603d5e93f326
 src/mlpack/core/util/split_data.hpp | 45 +++++--------------------------------
 1 file changed, 5 insertions(+), 40 deletions(-)

diff --git a/src/mlpack/core/util/split_data.hpp b/src/mlpack/core/util/split_data.hpp
index dd4b0df..4538664 100644
--- a/src/mlpack/core/util/split_data.hpp
+++ b/src/mlpack/core/util/split_data.hpp
@@ -20,25 +20,11 @@ class TrainTestSplit
 {
 public:  
   /**
-   * @brief TrainTestSplit
-   * @param testRatio the ratio of test data
-   * @param seed seed of the random device
-   * @warning slice should not less than 1
-   */
-  TrainTestSplit(double testRatio,                
-                 arma::arma_rng::seed_type seed = 0) :
-    seed(seed),    
-    testRatio(testRatio)
-  {    
-  }
-
-  /**
   *Split the input into training data and test data. Define
   *ARMA_USE_CXX11 to enable C++11 move semantics.
   *@param input the input data to split
   *@param label the input labels to split
   *@param testRatio the fraction of the input columns used as test data
-   *@param seed seed of the random device
    *@code
    *arma::mat input = loadData();
    *arma::Row<size_t> label = loadLabel();
@@ -58,7 +44,8 @@ public:
              arma::Mat<T> &trainData,
              arma::Mat<T> &testData,
              arma::Row<U> &trainLabel,
-             arma::Row<U> &testLabel)
+             arma::Row<U> &testLabel,
+             double testRatio)
   {
     size_t const testSize =
         static_cast<size_t>(input.n_cols * testRatio);
@@ -69,7 +56,6 @@ public:
     testLabel.set_size(testSize);
 
     using Col = arma::Col<size_t>;    
-    arma::arma_rng::set_seed(seed);
     Col const sequence = arma::linspace<Col>(0, input.n_cols - 1,
                                              input.n_cols);
     arma::Col<size_t> const order = arma::shuffle(sequence);
@@ -99,7 +85,8 @@ public:
   std::tuple<arma::Mat<T>, arma::Mat<T>,
   arma::Row<U>, arma::Row<U>>
   Split(arma::Mat<T> const &input,
-        arma::Row<U> const &inputLabel)
+        arma::Row<U> const &inputLabel,
+        double testRatio)
   {
     arma::Mat<T> trainData;
     arma::Mat<T> testData;
@@ -107,33 +94,11 @@ public:
     arma::Row<U> testLabel;
 
     Split(input, inputLabel, trainData, testData,
-          trainLabel, testLabel);
+          trainLabel, testLabel, testRatio);
 
     return std::make_tuple(trainData, testData,
                            trainLabel, testLabel);
   }
-
-  void Seed(arma::arma_rng::seed_type value)
-  {
-    seed = value;
-  }
-  arma::arma_rng::seed_type Seed() const
-  {
-    return seed;
-  }   
-
-  void TestRatio(double value)
-  {
-    testRatio = value;
-  }
-  double TestRatio() const
-  {
-    return testRatio;
-  }
-
-private:  
-  arma::arma_rng::seed_type seed;  
-  double testRatio;
 };
 
 } // namespace util




More information about the mlpack-git mailing list