[mlpack-git] master: Add tests for Train(). (93c0f0d)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Mon Dec 21 15:25:53 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/be72510a765362f86782a8892f0e979aaa4a9f62...51205e0ad285b2cf421546d8876fc63e994f2d73
>---------------------------------------------------------------
commit 93c0f0d96ffcce0afbc6f71019177c4f1416e3a3
Author: ryan <ryan at ratml.org>
Date: Mon Dec 21 15:10:10 2015 -0500
Add tests for Train().
>---------------------------------------------------------------
93c0f0d96ffcce0afbc6f71019177c4f1416e3a3
src/mlpack/tests/cf_test.cpp | 141 +++++++++++++++++++++++++++++++++++++++++++
1 file changed, 141 insertions(+)
diff --git a/src/mlpack/tests/cf_test.cpp b/src/mlpack/tests/cf_test.cpp
index 9fbce77..aaf4e2c 100644
--- a/src/mlpack/tests/cf_test.cpp
+++ b/src/mlpack/tests/cf_test.cpp
@@ -318,4 +318,145 @@ BOOST_AUTO_TEST_CASE(CFBatchPredictTest)
}
}
+/**
+ * Make sure we can train an already-trained model and it works okay.
+ */
+BOOST_AUTO_TEST_CASE(TrainTest)
+{
+  // Generate random data and train on it, so Train() below is a retrain.
+  arma::mat randomData = arma::randu<arma::mat>(100, 100);
+  CF c(randomData);
+
+  // Now retrain with data we know about.
+  arma::mat dataset;
+  data::Load("GroupLens100k.csv", dataset);
+
+  // Save the columns we've removed: up to 300 5-star ratings, one per user.
+  arma::mat savedCols(3, 300);
+  size_t currentCol = 0;
+  size_t i = 0;
+  while (i < dataset.n_cols && currentCol < savedCols.n_cols)
+  {
+    // Only 5-star ratings from users we have not held back yet qualify.  The
+    // linear scan is slow, but I don't particularly care here because it's in
+    // the tests.
+    bool holdOut = (dataset(2, i) > 4.5);
+    for (size_t j = 0; holdOut && j < currentCol; ++j)
+    {
+      if (savedCols(0, j) == dataset(0, i))
+        holdOut = false;
+    }
+
+    if (!holdOut)
+    {
+      ++i;
+      continue;
+    }
+
+    // shed_col() shifts the following column into position i, so i must not
+    // advance here or that column would never be examined.
+    savedCols.col(currentCol) = dataset.col(i);
+    dataset.shed_col(i);
+    ++currentCol;
+  }
+
+  // Make data into sparse matrix.
+  arma::sp_mat cleanedData;
+  CF::CleanData(dataset, cleanedData);
+
+  // Now retrain.
+  c.Train(cleanedData);
+
+  // Get predictions for all user/item pairs we held back.  Only the first
+  // currentCol columns of savedCols were filled in.
+  arma::Mat<size_t> combinations(2, currentCol);
+  for (size_t col = 0; col < currentCol; ++col)
+  {
+    combinations(0, col) = static_cast<size_t>(savedCols(0, col));
+    combinations(1, col) = static_cast<size_t>(savedCols(1, col));
+  }
+
+  arma::vec predictions;
+  c.Predict(combinations, predictions);
+
+  // The batch and single-pair Predict() overloads should agree.
+  BOOST_REQUIRE_EQUAL(predictions.n_elem, combinations.n_cols);
+  for (size_t col = 0; col < combinations.n_cols; ++col)
+  {
+    const double prediction = c.Predict(combinations(0, col),
+        combinations(1, col));
+    BOOST_REQUIRE_CLOSE(prediction, predictions[col], 1e-8);
+  }
+}
+
+/**
+ * Make sure we can train a model after using the empty constructor.
+ */
+BOOST_AUTO_TEST_CASE(EmptyConstructorTrainTest)
+{
+  // Use default constructor.
+  CF c;
+
+  // Now train with data we know about.
+  arma::mat dataset;
+  data::Load("GroupLens100k.csv", dataset);
+
+  // Save the columns we've removed: up to 300 5-star ratings, one per user.
+  arma::mat savedCols(3, 300);
+  size_t currentCol = 0;
+  size_t i = 0;
+  while (i < dataset.n_cols && currentCol < savedCols.n_cols)
+  {
+    // Only 5-star ratings from users we have not held back yet qualify.  The
+    // linear scan is slow, but I don't particularly care here because it's in
+    // the tests.
+    bool holdOut = (dataset(2, i) > 4.5);
+    for (size_t j = 0; holdOut && j < currentCol; ++j)
+    {
+      if (savedCols(0, j) == dataset(0, i))
+        holdOut = false;
+    }
+
+    if (!holdOut)
+    {
+      ++i;
+      continue;
+    }
+
+    // shed_col() shifts the following column into position i, so i must not
+    // advance here or that column would never be examined.
+    savedCols.col(currentCol) = dataset.col(i);
+    dataset.shed_col(i);
+    ++currentCol;
+  }
+
+  // Make data into sparse matrix.
+  arma::sp_mat cleanedData;
+  CF::CleanData(dataset, cleanedData);
+
+  // Now train.
+  c.Train(cleanedData);
+
+  // Get predictions for all user/item pairs we held back.  Only the first
+  // currentCol columns of savedCols were filled in.
+  arma::Mat<size_t> combinations(2, currentCol);
+  for (size_t col = 0; col < currentCol; ++col)
+  {
+    combinations(0, col) = static_cast<size_t>(savedCols(0, col));
+    combinations(1, col) = static_cast<size_t>(savedCols(1, col));
+  }
+
+  arma::vec predictions;
+  c.Predict(combinations, predictions);
+
+  // The batch and single-pair Predict() overloads should agree.
+  BOOST_REQUIRE_EQUAL(predictions.n_elem, combinations.n_cols);
+  for (size_t col = 0; col < combinations.n_cols; ++col)
+  {
+    const double prediction = c.Predict(combinations(0, col),
+        combinations(1, col));
+    BOOST_REQUIRE_CLOSE(prediction, predictions[col], 1e-8);
+  }
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git
mailing list