[mlpack-git] master: Add tests for Train(). (93c0f0d)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Mon Dec 21 15:25:53 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/be72510a765362f86782a8892f0e979aaa4a9f62...51205e0ad285b2cf421546d8876fc63e994f2d73

>---------------------------------------------------------------

commit 93c0f0d96ffcce0afbc6f71019177c4f1416e3a3
Author: ryan <ryan at ratml.org>
Date:   Mon Dec 21 15:10:10 2015 -0500

    Add tests for Train().


>---------------------------------------------------------------

93c0f0d96ffcce0afbc6f71019177c4f1416e3a3
 src/mlpack/tests/cf_test.cpp | 141 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 141 insertions(+)

diff --git a/src/mlpack/tests/cf_test.cpp b/src/mlpack/tests/cf_test.cpp
index 9fbce77..aaf4e2c 100644
--- a/src/mlpack/tests/cf_test.cpp
+++ b/src/mlpack/tests/cf_test.cpp
@@ -318,4 +318,145 @@ BOOST_AUTO_TEST_CASE(CFBatchPredictTest)
   }
 }
 
+/**
+ * Retrain an already-trained model; batch and single predictions must agree.
+ */
+BOOST_AUTO_TEST_CASE(TrainTest)
+{
+  // Generate random data.
+  arma::mat randomData = arma::randu<arma::mat>(100, 100);
+  CF c(randomData);
+
+  // Now retrain with data we know about.
+  arma::mat dataset;
+  data::Load("GroupLens100k.csv", dataset);
+
+  // Hold back up to 300 5-star ratings, at most one per user.  Each column of
+  // dataset and savedCols is a (user, item, rating) triple.
+  arma::mat savedCols(3, 300);
+  size_t currentCol = 0;
+  for (size_t i = 0; i < dataset.n_cols && currentCol < savedCols.n_cols; )
+  {
+    bool holdBack = (dataset(2, i) > 4.5); // 5-star rating.
+
+    // Make sure we don't have this user yet.  This linear search is slow, but
+    // that is acceptable here because it's in the tests.
+    for (size_t j = 0; holdBack && (j < currentCol); ++j)
+    {
+      if (savedCols(0, j) == dataset(0, i))
+        holdBack = false;
+    }
+
+    if (holdBack)
+    {
+      // Move the point into savedCols.  shed_col() shifts the following
+      // column into position i, so i must not be advanced here; otherwise
+      // that column would never be examined.
+      savedCols.col(currentCol) = dataset.col(i);
+      dataset.shed_col(i);
+      ++currentCol;
+    }
+    else
+      ++i;
+  }
+
+  // If fewer than 300 points qualified, drop the unfilled (possibly
+  // uninitialized) columns so we never predict on garbage user/item indices.
+  savedCols.resize(savedCols.n_rows, currentCol);
+
+  // Make data into sparse matrix.
+  arma::sp_mat cleanedData;
+  CF::CleanData(dataset, cleanedData);
+
+  // Now retrain.
+  c.Train(cleanedData);
+
+  // Get predictions for all user/item pairs we held back.
+  arma::Mat<size_t> combinations(2, savedCols.n_cols);
+  for (size_t i = 0; i < savedCols.n_cols; ++i)
+  {
+    combinations(0, i) = size_t(savedCols(0, i));
+    combinations(1, i) = size_t(savedCols(1, i));
+  }
+
+  // Batch predictions must match one-at-a-time predictions.
+  arma::vec predictions;
+  c.Predict(combinations, predictions);
+  for (size_t i = 0; i < combinations.n_cols; ++i)
+  {
+    const double prediction = c.Predict(combinations(0, i), combinations(1, i));
+    BOOST_REQUIRE_CLOSE(prediction, predictions[i], 1e-8);
+  }
+}
+
+/**
+ * Train a default-constructed model; batch and single predictions must agree.
+ */
+BOOST_AUTO_TEST_CASE(EmptyConstructorTrainTest)
+{
+  // Use default constructor.
+  CF c;
+
+  // Now train with data we know about.
+  arma::mat dataset;
+  data::Load("GroupLens100k.csv", dataset);
+
+  // Hold back up to 300 5-star ratings, at most one per user.  Each column of
+  // dataset and savedCols is a (user, item, rating) triple.
+  arma::mat savedCols(3, 300);
+  size_t currentCol = 0;
+  for (size_t i = 0; i < dataset.n_cols && currentCol < savedCols.n_cols; )
+  {
+    bool holdBack = (dataset(2, i) > 4.5); // 5-star rating.
+
+    // Make sure we don't have this user yet.  This linear search is slow, but
+    // that is acceptable here because it's in the tests.
+    for (size_t j = 0; holdBack && (j < currentCol); ++j)
+    {
+      if (savedCols(0, j) == dataset(0, i))
+        holdBack = false;
+    }
+
+    if (holdBack)
+    {
+      // Move the point into savedCols.  shed_col() shifts the following
+      // column into position i, so i must not be advanced here; otherwise
+      // that column would never be examined.
+      savedCols.col(currentCol) = dataset.col(i);
+      dataset.shed_col(i);
+      ++currentCol;
+    }
+    else
+      ++i;
+  }
+
+  // If fewer than 300 points qualified, drop the unfilled (possibly
+  // uninitialized) columns so we never predict on garbage user/item indices.
+  savedCols.resize(savedCols.n_rows, currentCol);
+
+  // Make data into sparse matrix.
+  arma::sp_mat cleanedData;
+  CF::CleanData(dataset, cleanedData);
+
+  // Now train.
+  c.Train(cleanedData);
+
+  // Get predictions for all user/item pairs we held back.
+  arma::Mat<size_t> combinations(2, savedCols.n_cols);
+  for (size_t i = 0; i < savedCols.n_cols; ++i)
+  {
+    combinations(0, i) = size_t(savedCols(0, i));
+    combinations(1, i) = size_t(savedCols(1, i));
+  }
+
+  // Batch predictions must match one-at-a-time predictions.
+  arma::vec predictions;
+  c.Predict(combinations, predictions);
+  for (size_t i = 0; i < combinations.n_cols; ++i)
+  {
+    const double prediction = c.Predict(combinations(0, i), combinations(1, i));
+    BOOST_REQUIRE_CLOSE(prediction, predictions[i], 1e-8);
+  }
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list