[mlpack-git] master: Add batch Predict() method. (ee38465)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Mon Apr 27 20:24:57 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/96e6efcc7f5c4af597852ce64f13d3af2c5ba4be...ee384655c4462e422e343e9725437fd772ca4449

>---------------------------------------------------------------

commit ee384655c4462e422e343e9725437fd772ca4449
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Apr 27 20:24:44 2015 -0400

    Add batch Predict() method.


>---------------------------------------------------------------

ee384655c4462e422e343e9725437fd772ca4449
 src/mlpack/methods/cf/cf.hpp      | 15 ++++++++++
 src/mlpack/methods/cf/cf_impl.hpp | 55 ++++++++++++++++++++++++++++++++++
 src/mlpack/tests/cf_test.cpp      | 63 +++++++++++++++++++++++++++++++++++++++
 3 files changed, 133 insertions(+)

diff --git a/src/mlpack/methods/cf/cf.hpp b/src/mlpack/methods/cf/cf.hpp
index 55afb54..33e967d 100644
--- a/src/mlpack/methods/cf/cf.hpp
+++ b/src/mlpack/methods/cf/cf.hpp
@@ -174,6 +174,21 @@ class CF
   double Predict(const size_t user, const size_t item) const;
 
   /**
+   * Predict ratings for each user-item combination in the given coordinate list
+   * matrix.  The matrix 'combinations' should have two rows and a number of
+   * columns equal to the number of desired predictions.  The first element of
+   * each column corresponds to the user index, and the second element of each
+   * column corresponds to the item index.  The output vector 'predictions' will
+   * have length equal to combinations.n_cols, and predictions[i] will be equal
+   * to the prediction for the user/item combination in combinations.col(i).
+   *
+   * @param combinations User/item combinations to predict.
+   * @param predictions Predicted ratings for each user/item combination.
+   */
+  void Predict(const arma::Mat<size_t>& combinations,
+               arma::vec& predictions) const;
+
+  /**
    * Returns a string representation of this object.
    */
   std::string ToString() const;
diff --git a/src/mlpack/methods/cf/cf_impl.hpp b/src/mlpack/methods/cf/cf_impl.hpp
index 5ad7e45..63df8d1 100644
--- a/src/mlpack/methods/cf/cf_impl.hpp
+++ b/src/mlpack/methods/cf/cf_impl.hpp
@@ -242,6 +242,61 @@ double CF<FactorizerType>::Predict(const size_t user, const size_t item) const
   return rating;
 }
 
+// Batch Predict(): predictions[i] is the rating for combinations.col(i).
+template<typename FactorizerType>
+void CF<FactorizerType>::Predict(const arma::Mat<size_t>& combinations,
+                                 arma::vec& predictions) const
+{
+  // Stretch H by the Cholesky factor of W^T W for nearest neighbor search.
+  arma::mat l = arma::chol(w.t() * w); // l.t() * l == w.t() * w.
+  arma::mat stretchedH = l * h; // Due to the Armadillo API, l is L^T.
+
+  // Sort the columns by user index (row 0); ordering[i] is the original
+  // position of sorted column i and is used to write each result back.
+  arma::Mat<size_t> sortedCombinations(combinations.n_rows,
+                                       combinations.n_cols);
+  arma::uvec ordering = arma::sort_index(combinations.row(0).t());
+  for (size_t i = 0; i < ordering.n_elem; ++i)
+    sortedCombinations.col(i) = combinations.col(ordering[i]);
+
+  // Unique user indices; arma::unique() returns them in ascending order.
+  arma::Col<size_t> users = arma::unique(combinations.row(0).t());
+
+  // Build one query column per unique user from the stretchedH matrix.
+  arma::mat queries(stretchedH.n_rows, users.n_elem);
+  for (size_t i = 0; i < queries.n_cols; ++i)
+    queries.col(i) = stretchedH.col(users[i]);
+
+  // Find the numUsersForSimilarity nearest neighbors of each query user.
+  neighbor::AllkNN a(stretchedH);
+  arma::mat distances;
+  arma::Mat<size_t> neighborhood;
+
+  a.Search(queries, numUsersForSimilarity, neighborhood, distances);
+
+  // Size the output, then compute one prediction per user/item combination.
+  predictions.set_size(combinations.n_cols);
+
+  size_t user = 0; // Index into users; only moves forward (sorted order).
+  for (size_t i = 0; i < sortedCombinations.n_cols; ++i)
+  {
+    // Could this be made faster by calculating dot products for multiple items
+    // at once?
+    double rating = 0.0;
+
+    // Advance to this combination's user; this is its column in neighborhood.
+    while (users[user] < sortedCombinations(0, i))
+      ++user;
+
+    for (size_t j = 0; j < neighborhood.n_rows; ++j)
+      rating += arma::as_scalar(w.row(sortedCombinations(1, i)) *
+          h.col(neighborhood(j, user)));
+    rating /= neighborhood.n_rows; // Average over all neighbors.
+
+    predictions(ordering[i]) = rating;
+  }
+}
+
 template<typename FactorizerType>
 void CF<FactorizerType>::CleanData(const arma::mat& data)
 {
diff --git a/src/mlpack/tests/cf_test.cpp b/src/mlpack/tests/cf_test.cpp
index fff769a..feb4d4b 100644
--- a/src/mlpack/tests/cf_test.cpp
+++ b/src/mlpack/tests/cf_test.cpp
@@ -235,4 +235,67 @@ BOOST_AUTO_TEST_CASE(CFPredictTest)
   BOOST_REQUIRE_LT(totalError, 0.5);
 }
 
+// Do the same thing as the previous test, but ensure that the ratings we
+// predict with the batch Predict() are the same as the individual Predict()
+// calls.
+BOOST_AUTO_TEST_CASE(CFBatchPredictTest)
+{
+  // Load the GroupLens dataset; then, we will remove some values from it.
+  arma::mat dataset;
+  data::Load("GroupLens100k.csv", dataset);
+
+  // Held-out columns; rows are user, item, rating (see uses below).
+  arma::mat savedCols(3, 300); // Remove 300 5-star ratings.
+  size_t currentCol = 0; // Number of columns held out so far.
+  for (size_t i = 0; i < dataset.n_cols; ++i)
+  {
+    if (currentCol == 300)
+      break;
+
+    if (dataset(2, i) > 4.5) // 5-star rating.
+    {
+      // Make sure we don't have this user yet.  This is a slow way to do this
+      // but I don't particularly care here because it's in the tests.
+      bool found = false;
+      for (size_t j = 0; j < currentCol; ++j)
+      {
+        if (savedCols(0, j) == dataset(0, i))
+        {
+          found = true;
+          break;
+        }
+      }
+
+      // If this user doesn't already exist in savedCols, add them.  Otherwise
+      // ignore this point.
+      if (!found)
+      {
+        savedCols.col(currentCol) = dataset.col(i);
+        dataset.shed_col(i); // NOTE(review): ++i skips the shifted column.
+        ++currentCol;
+      }
+    }
+  }
+
+  // Now create the CF object.
+  CF<> c(dataset);
+
+  // Build the 2-row (user, item) query matrix from the held-out columns.
+  arma::Mat<size_t> combinations(2, savedCols.n_cols);
+  for (size_t i = 0; i < savedCols.n_cols; ++i)
+  {
+    combinations(0, i) = size_t(savedCols(0, i));
+    combinations(1, i) = size_t(savedCols(1, i));
+  }
+
+  arma::vec predictions;
+  c.Predict(combinations, predictions);
+
+  for (size_t i = 0; i < combinations.n_cols; ++i)
+  {
+    const double prediction = c.Predict(combinations(0, i), combinations(1, i));
+    BOOST_REQUIRE_CLOSE(prediction, predictions[i], 1e-8); // Percent tolerance.
+  }
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list