[mlpack-git] master: Add batch Predict() method. (ee38465)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Mon Apr 27 20:24:57 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/96e6efcc7f5c4af597852ce64f13d3af2c5ba4be...ee384655c4462e422e343e9725437fd772ca4449
>---------------------------------------------------------------
commit ee384655c4462e422e343e9725437fd772ca4449
Author: Ryan Curtin <ryan at ratml.org>
Date: Mon Apr 27 20:24:44 2015 -0400
Add batch Predict() method.
>---------------------------------------------------------------
ee384655c4462e422e343e9725437fd772ca4449
src/mlpack/methods/cf/cf.hpp | 15 ++++++++++
src/mlpack/methods/cf/cf_impl.hpp | 55 ++++++++++++++++++++++++++++++++++
src/mlpack/tests/cf_test.cpp | 63 +++++++++++++++++++++++++++++++++++++++
3 files changed, 133 insertions(+)
diff --git a/src/mlpack/methods/cf/cf.hpp b/src/mlpack/methods/cf/cf.hpp
index 55afb54..33e967d 100644
--- a/src/mlpack/methods/cf/cf.hpp
+++ b/src/mlpack/methods/cf/cf.hpp
@@ -174,6 +174,21 @@ class CF
double Predict(const size_t user, const size_t item) const;
/**
+ * Predict ratings for each user-item combination in the given coordinate list
+ * matrix. The matrix 'combinations' should have two rows and number of
+ * columns equal to the number of desired predictions. The first element of
+ * each column corresponds to the user index, and the second element of each
+ * column corresponds to the item index. The output vector 'predictions' will
+ * have length equal to combinations.n_cols, and predictions[i] will be equal
+ * to the prediction for the user/item combination in combinations.col(i).
+ *
+ * @param combinations User/item combinations to predict.
+ * @param predictions Predicted ratings for each user/item combination.
+ */
+ void Predict(const arma::Mat<size_t>& combinations,
+ arma::vec& predictions) const;
+
+ /**
* Returns a string representation of this object.
*/
std::string ToString() const;
diff --git a/src/mlpack/methods/cf/cf_impl.hpp b/src/mlpack/methods/cf/cf_impl.hpp
index 5ad7e45..63df8d1 100644
--- a/src/mlpack/methods/cf/cf_impl.hpp
+++ b/src/mlpack/methods/cf/cf_impl.hpp
@@ -242,6 +242,61 @@ double CF<FactorizerType>::Predict(const size_t user, const size_t item) const
return rating;
}
+// Predict the rating for a group of user/item combinations.
+template<typename FactorizerType>
+void CF<FactorizerType>::Predict(const arma::Mat<size_t>& combinations,
+ arma::vec& predictions) const
+{
+ // First, for nearest neighbor search, stretch the H matrix.
+ arma::mat l = arma::chol(w.t() * w);
+ arma::mat stretchedH = l * h; // Due to the Armadillo API, l is L^T.
+
+ // Now, we must determine those query indices we need to find the nearest
+ // neighbors for. This is easiest if we just sort the combinations matrix.
+ arma::Mat<size_t> sortedCombinations(combinations.n_rows,
+ combinations.n_cols);
+ arma::uvec ordering = arma::sort_index(combinations.row(0).t());
+ for (size_t i = 0; i < ordering.n_elem; ++i)
+ sortedCombinations.col(i) = combinations.col(ordering[i]);
+
+ // Now, we have to get the list of unique users we will be searching for.
+ arma::Col<size_t> users = arma::unique(combinations.row(0).t());
+
+ // Assemble our query matrix from the stretchedH matrix.
+ arma::mat queries(stretchedH.n_rows, users.n_elem);
+ for (size_t i = 0; i < queries.n_cols; ++i)
+ queries.col(i) = stretchedH.col(users[i]);
+
+ // Now calculate the neighborhood of these users.
+ neighbor::AllkNN a(stretchedH);
+ arma::mat distances;
+ arma::Mat<size_t> neighborhood;
+
+ a.Search(queries, numUsersForSimilarity, neighborhood, distances);
+
+ // Now that we have the neighborhoods we need, calculate the predictions.
+ predictions.set_size(combinations.n_cols);
+
+ size_t user = 0; // Cumulative user count, because we are doing it in order.
+ for (size_t i = 0; i < sortedCombinations.n_cols; ++i)
+ {
+ // Could this be made faster by calculating dot products for multiple items
+ // at once?
+ double rating = 0.0;
+
+ // Map the combination's user to the user ID used for kNN.
+ while (users[user] < sortedCombinations(0, i))
+ ++user;
+
+ for (size_t j = 0; j < neighborhood.n_rows; ++j)
+ rating += arma::as_scalar(w.row(sortedCombinations(1, i)) *
+ h.col(neighborhood(j, user)));
+ rating /= neighborhood.n_rows;
+
+ predictions(ordering[i]) = rating;
+ }
+}
+
template<typename FactorizerType>
void CF<FactorizerType>::CleanData(const arma::mat& data)
{
diff --git a/src/mlpack/tests/cf_test.cpp b/src/mlpack/tests/cf_test.cpp
index fff769a..feb4d4b 100644
--- a/src/mlpack/tests/cf_test.cpp
+++ b/src/mlpack/tests/cf_test.cpp
@@ -235,4 +235,67 @@ BOOST_AUTO_TEST_CASE(CFPredictTest)
BOOST_REQUIRE_LT(totalError, 0.5);
}
+// Do the same thing as the previous test, but ensure that the ratings we
+// predict with the batch Predict() are the same as the individual Predict()
+// calls.
+BOOST_AUTO_TEST_CASE(CFBatchPredictTest)
+{
+ // Load the GroupLens dataset; then, we will remove some values from it.
+ arma::mat dataset;
+ data::Load("GroupLens100k.csv", dataset);
+
+ // Save the columns we've removed.
+ arma::mat savedCols(3, 300); // Remove 300 5-star ratings.
+ size_t currentCol = 0;
+ for (size_t i = 0; i < dataset.n_cols; ++i)
+ {
+ if (currentCol == 300)
+ break;
+
+ if (dataset(2, i) > 4.5) // 5-star rating.
+ {
+ // Make sure we don't have this user yet. This is a slow way to do this
+ // but I don't particularly care here because it's in the tests.
+ bool found = false;
+ for (size_t j = 0; j < currentCol; ++j)
+ {
+ if (savedCols(0, j) == dataset(0, i))
+ {
+ found = true;
+ break;
+ }
+ }
+
+ // If this user doesn't already exist in savedCols, add them. Otherwise
+ // ignore this point.
+ if (!found)
+ {
+ savedCols.col(currentCol) = dataset.col(i);
+ dataset.shed_col(i);
+ ++currentCol;
+ }
+ }
+ }
+
+ // Now create the CF object.
+ CF<> c(dataset);
+
+ // Get predictions for all user/item pairs we held back.
+ arma::Mat<size_t> combinations(2, savedCols.n_cols);
+ for (size_t i = 0; i < savedCols.n_cols; ++i)
+ {
+ combinations(0, i) = size_t(savedCols(0, i));
+ combinations(1, i) = size_t(savedCols(1, i));
+ }
+
+ arma::vec predictions;
+ c.Predict(combinations, predictions);
+
+ for (size_t i = 0; i < combinations.n_cols; ++i)
+ {
+ const double prediction = c.Predict(combinations(0, i), combinations(1, i));
+ BOOST_REQUIRE_CLOSE(prediction, predictions[i], 1e-8);
+ }
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git
mailing list