[mlpack-git] master: change class to function (75e946e)
gitdub at mlpack.org
gitdub at mlpack.org
Tue Apr 12 17:47:22 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/ebf77f8b13323a87c433b6f639deb2369188b00c...b08ae02b90e18f97366b236e7a4d8725cd6e9050
>---------------------------------------------------------------
commit 75e946e5fa8ebea610e7f74f5d6a68a2b70d2a97
Author: stereomatchingkiss <stereomatchingkiss at gmail.com>
Date: Wed Apr 13 05:47:22 2016 +0800
change class to function
>---------------------------------------------------------------
75e946e5fa8ebea610e7f74f5d6a68a2b70d2a97
src/mlpack/core/util/split_data.hpp | 150 ++++++++++++++++++------------------
1 file changed, 74 insertions(+), 76 deletions(-)
diff --git a/src/mlpack/core/util/split_data.hpp b/src/mlpack/core/util/split_data.hpp
index 4538664..32ccaa6 100644
--- a/src/mlpack/core/util/split_data.hpp
+++ b/src/mlpack/core/util/split_data.hpp
@@ -15,91 +15,89 @@ namespace util {
/**
*Split training data and test data, please define
*ARMA_USE_CXX11 to enable move of c++11
+ *@param input input data to split; each column is one sample
+ *@param inputLabel labels of the input data, one per column
+ *@param trainData output: training data taken from input
+ *@param testData output: test data taken from input
+ *@param trainLabel output: labels matching trainData
+ *@param testLabel output: labels matching testData
+ *@param testRatio fraction of columns put in the test set (rounded down)
+ *@code
+ *arma::mat input = loadData();
+ *arma::Row<size_t> label = loadLabel();
+ *arma::mat trainData;
+ *arma::mat testData;
+ *arma::Row<size_t> trainLabel;
+ *arma::Row<size_t> testLabel;
+ *double const testRatio = 0.25;
+ *TrainTestSplit(input, label, trainData, testData,
+ * trainLabel, testLabel,
+ * testRatio);
+ *@endcode
*/
-class TrainTestSplit
+template<typename T, typename U>
+void TrainTestSplit(arma::Mat<T> const &input,
+ arma::Row<U> const &inputLabel,
+ arma::Mat<T> &trainData,
+ arma::Mat<T> &testData,
+ arma::Row<U> &trainLabel,
+ arma::Row<U> &testLabel,
+ double const testRatio)
{
-public:
- /**
- *Split training data and test data, please define
- *ARMA_USE_CXX11 to enable move of c++11
- *@param input input data want to split
- *@param label input label want to split
- *@param testRatio the ratio of test data
- *@code
- *arma::mat input = loadData();
- *arma::Row<size_t> label = loadLabel();
- *arma::mat trainData;
- *arma::mat testData;
- *arma::Row<size_t> trainLabel;
- *arma::Row<size_t> testLabel;
- *std::random_device rd;
- *TrainTestSplit tts(0.25);
- *tts.Split(input, label, trainData, testData, trainLabel,
- * testLabel);
- *@endcode
- */
- template<typename T, typename U>
- void Split(arma::Mat<T> const &input,
- arma::Row<U> const &inputLabel,
- arma::Mat<T> &trainData,
- arma::Mat<T> &testData,
- arma::Row<U> &trainLabel,
- arma::Row<U> &testLabel,
- double testRatio)
- {
- size_t const testSize =
- static_cast<size_t>(input.n_cols * testRatio);
- size_t const trainSize = input.n_cols - testSize;
- trainData.set_size(input.n_rows, trainSize);
- testData.set_size(input.n_rows, testSize);
- trainLabel.set_size(trainSize);
- testLabel.set_size(testSize);
-
- using Col = arma::Col<size_t>;
- Col const sequence = arma::linspace<Col>(0, input.n_cols - 1,
- input.n_cols);
- arma::Col<size_t> const order = arma::shuffle(sequence);
+ size_t const testSize =
+ static_cast<size_t>(input.n_cols * testRatio);
+ size_t const trainSize = input.n_cols - testSize;
+ trainData.set_size(input.n_rows, trainSize);
+ testData.set_size(input.n_rows, testSize);
+ trainLabel.set_size(trainSize);
+ testLabel.set_size(testSize);
- for(size_t i = 0; i != trainSize; ++i)
- {
- trainData.col(i) = input.col(order[i]);
- trainLabel(i) = inputLabel(order[i]);
- }
+ using Col = arma::Col<size_t>;
+ Col const sequence = arma::linspace<Col>(0, input.n_cols - 1,
+ input.n_cols);
+ arma::Col<size_t> const order = arma::shuffle(sequence);
- for(size_t i = 0; i != testSize; ++i)
- {
- testData.col(i) = input.col(order[i + trainSize]);
- testLabel(i) = inputLabel(order[i + trainSize]);
- }
+ for(size_t i = 0; i != trainSize; ++i)
+ {
+ trainData.col(i) = input.col(order[i]);
+ trainLabel(i) = inputLabel(order[i]);
}
- /**
- *Overload of Split, if you do not like to pass in
- *so many param, you could call this api instead
- *@param input input data want to split
- *@param label input label want to split
- *@return They are trainData, testData, trainLabel and
- *testLabel
- */
- template<typename T,typename U>
- std::tuple<arma::Mat<T>, arma::Mat<T>,
- arma::Row<U>, arma::Row<U>>
- Split(arma::Mat<T> const &input,
- arma::Row<U> const &inputLabel,
- double testRatio)
+ for(size_t i = 0; i != testSize; ++i)
{
- arma::Mat<T> trainData;
- arma::Mat<T> testData;
- arma::Row<U> trainLabel;
- arma::Row<U> testLabel;
+ testData.col(i) = input.col(order[i + trainSize]);
+ testLabel(i) = inputLabel(order[i + trainSize]);
+ }
+}
- Split(input, inputLabel, trainData, testData,
- trainLabel, testLabel, testRatio);
+/**
+ *Overload of TrainTestSplit that returns the results as a tuple
+ *instead of taking output parameters
+ *@param input input data to split; each column is one sample
+ *@param inputLabel labels of the input data, one per column
+ *@param testRatio fraction of columns put in the test set (rounded down)
+ *@return tuple of trainData, testData, trainLabel and testLabel
+ */
+template<typename T,typename U>
+std::tuple<arma::Mat<T>, arma::Mat<T>,
+arma::Row<U>, arma::Row<U>>
+TrainTestSplit(arma::Mat<T> const &input,
+ arma::Row<U> const &inputLabel,
+ double const testRatio)
+{
+ arma::Mat<T> trainData;
+ arma::Mat<T> testData;
+ arma::Row<U> trainLabel;
+ arma::Row<U> testLabel;
- return std::make_tuple(trainData, testData,
- trainLabel, testLabel);
- }
-};
+ TrainTestSplit(input, inputLabel,
+ trainData, testData,
+ trainLabel, testLabel,
+ testRatio);
+
+ return std::make_tuple(trainData, testData,
+ trainLabel, testLabel);
+}
} // namespace util
} // namespace mlpack
More information about the mlpack-git
mailing list