[mlpack-git] master: change class to function (75e946e)

gitdub at mlpack.org gitdub at mlpack.org
Tue Apr 12 17:47:22 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/ebf77f8b13323a87c433b6f639deb2369188b00c...b08ae02b90e18f97366b236e7a4d8725cd6e9050

>---------------------------------------------------------------

commit 75e946e5fa8ebea610e7f74f5d6a68a2b70d2a97
Author: stereomatchingkiss <stereomatchingkiss at gmail.com>
Date:   Wed Apr 13 05:47:22 2016 +0800

    change class to function


>---------------------------------------------------------------

75e946e5fa8ebea610e7f74f5d6a68a2b70d2a97
 src/mlpack/core/util/split_data.hpp | 150 ++++++++++++++++++------------------
 1 file changed, 74 insertions(+), 76 deletions(-)

diff --git a/src/mlpack/core/util/split_data.hpp b/src/mlpack/core/util/split_data.hpp
index 4538664..32ccaa6 100644
--- a/src/mlpack/core/util/split_data.hpp
+++ b/src/mlpack/core/util/split_data.hpp
@@ -15,91 +15,89 @@ namespace util {
 /**
  *Split training data and test data, please define
  *ARMA_USE_CXX11 to enable move of c++11
+ *@param input data to split; each column is one sample
+ *@param inputLabel label of each column of input
+ *@param trainData output: the columns of input chosen for training
+ *@param testData output: the remaining columns of input, for testing
+ *@param trainLabel output: labels matching the columns of trainData
+ *@param testLabel output: labels matching the columns of testData
+ *@param testRatio fraction of columns put in testData (truncated); NOTE(review): not validated, values outside [0, 1] are unsafe
+ *@code
+ *arma::mat input = loadData();
+ *arma::Row<size_t> label = loadLabel();
+ *arma::mat trainData;
+ *arma::mat testData;
+ *arma::Row<size_t> trainLabel;
+ *arma::Row<size_t> testLabel;
+ *// hold out 25% of the columns as test data
+ *TrainTestSplit(input, label, trainData,
+ *               testData, trainLabel, testLabel,
+ *               0.25);
+ *@endcode
  */
-class TrainTestSplit
+template<typename T, typename U>
+void TrainTestSplit(arma::Mat<T> const &input,
+                    arma::Row<U> const &inputLabel,
+                    arma::Mat<T> &trainData,
+                    arma::Mat<T> &testData,
+                    arma::Row<U> &trainLabel,
+                    arma::Row<U> &testLabel,
+                    double const testRatio)
 {
-public:  
-  /**
-   *Split training data and test data, please define
-   *ARMA_USE_CXX11 to enable move of c++11
-   *@param input input data want to split
-   *@param label input label want to split
-   *@param testRatio the ratio of test data   
-   *@code
-   *arma::mat input = loadData();
-   *arma::Row<size_t> label = loadLabel();
-   *arma::mat trainData;
-   *arma::mat testData;
-   *arma::Row<size_t> trainLabel;
-   *arma::Row<size_t> testLabel;
-   *std::random_device rd;
-   *TrainTestSplit tts(0.25);
-   *tts.Split(input, label, trainData, testData, trainLabel,
-   *          testLabel);
-   *@endcode
-   */
-  template<typename T, typename U>
-  void Split(arma::Mat<T> const &input,
-             arma::Row<U> const &inputLabel,
-             arma::Mat<T> &trainData,
-             arma::Mat<T> &testData,
-             arma::Row<U> &trainLabel,
-             arma::Row<U> &testLabel,
-             double testRatio)
-  {
-    size_t const testSize =
-        static_cast<size_t>(input.n_cols * testRatio);
-    size_t const trainSize = input.n_cols - testSize;
-    trainData.set_size(input.n_rows, trainSize);
-    testData.set_size(input.n_rows, testSize);
-    trainLabel.set_size(trainSize);
-    testLabel.set_size(testSize);
-
-    using Col = arma::Col<size_t>;    
-    Col const sequence = arma::linspace<Col>(0, input.n_cols - 1,
-                                             input.n_cols);
-    arma::Col<size_t> const order = arma::shuffle(sequence);
+  size_t const testSize =
+      static_cast<size_t>(input.n_cols * testRatio);  // truncates toward zero
+  size_t const trainSize = input.n_cols - testSize;  // NOTE(review): underflows if testSize > n_cols
+  trainData.set_size(input.n_rows, trainSize);
+  testData.set_size(input.n_rows, testSize);
+  trainLabel.set_size(trainSize);
+  testLabel.set_size(testSize);
 
-    for(size_t i = 0; i != trainSize; ++i)
-    {      
-      trainData.col(i) = input.col(order[i]);
-      trainLabel(i) = inputLabel(order[i]);
-    }
+  using Col = arma::Col<size_t>;
+  Col const sequence = arma::linspace<Col>(0, input.n_cols - 1,
+                                           input.n_cols);
+  arma::Col<size_t> const order = arma::shuffle(sequence);  // shuffled column indices 0..n_cols-1
 
-    for(size_t i = 0; i != testSize; ++i)
-    {
-      testData.col(i) = input.col(order[i + trainSize]);
-      testLabel(i) = inputLabel(order[i + trainSize]);
-    }
+  for(size_t i = 0; i != trainSize; ++i)  // first trainSize shuffled indices -> training split
+  {
+    trainData.col(i) = input.col(order[i]);
+    trainLabel(i) = inputLabel(order[i]);
   }
 
-  /**
-   *Overload of Split, if you do not like to pass in
-   *so many param, you could call this api instead
-   *@param input input data want to split
-   *@param label input label want to split
-   *@return They are trainData, testData, trainLabel and
-   *testLabel
-   */
-  template<typename T,typename U>
-  std::tuple<arma::Mat<T>, arma::Mat<T>,
-  arma::Row<U>, arma::Row<U>>
-  Split(arma::Mat<T> const &input,
-        arma::Row<U> const &inputLabel,
-        double testRatio)
+  for(size_t i = 0; i != testSize; ++i)  // remaining testSize indices -> test split
   {
-    arma::Mat<T> trainData;
-    arma::Mat<T> testData;
-    arma::Row<U> trainLabel;
-    arma::Row<U> testLabel;
+    testData.col(i) = input.col(order[i + trainSize]);
+    testLabel(i) = inputLabel(order[i + trainSize]);
+  }
+}
 
-    Split(input, inputLabel, trainData, testData,
-          trainLabel, testLabel, testRatio);
+/**
+ *Overload of TrainTestSplit that returns the four outputs
+ *as a std::tuple instead of taking output parameters
+ *@param input data to split; each column is one sample
+ *@param inputLabel label of each column of input
+ *@param testRatio fraction of columns put in the test set
+ *@return tuple of (trainData, testData, trainLabel, testLabel)
+ */
+template<typename T,typename U>
+std::tuple<arma::Mat<T>, arma::Mat<T>,
+arma::Row<U>, arma::Row<U>>
+TrainTestSplit(arma::Mat<T> const &input,
+               arma::Row<U> const &inputLabel,
+               double const testRatio)
+{
+  arma::Mat<T> trainData;
+  arma::Mat<T> testData;
+  arma::Row<U> trainLabel;
+  arma::Row<U> testLabel;
 
-    Split(input, inputLabel, trainData, testData,
-          trainLabel, testLabel, testRatio);
+  TrainTestSplit(input, inputLabel,  // fills the four outputs via the output-parameter overload
+                 trainData, testData,
+                 trainLabel, testLabel,
+                 testRatio);
+
+  return std::make_tuple(trainData, testData,  // NOTE(review): copies the lvalue locals; std::move would avoid it
+                         trainLabel, testLabel);
+}
 
 } // namespace util
 } // namespace mlpack




More information about the mlpack-git mailing list