[mlpack-git] master: Fix #406. (fbdc8d4)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Sun Apr 26 20:34:20 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/59a05b837daaec4678157beadf7783c7067ff607...fbdc8d44bcdb3a9f174b6f9a8839ced81b20f98f
>---------------------------------------------------------------
commit fbdc8d44bcdb3a9f174b6f9a8839ced81b20f98f
Author: ryan <ryan at ratml.org>
Date: Sun Apr 26 20:33:40 2015 -0400
Fix #406.
>---------------------------------------------------------------
fbdc8d44bcdb3a9f174b6f9a8839ced81b20f98f
src/mlpack/methods/cf/cf_impl.hpp | 41 ++++++++++++++++++++-------------------
1 file changed, 21 insertions(+), 20 deletions(-)
diff --git a/src/mlpack/methods/cf/cf_impl.hpp b/src/mlpack/methods/cf/cf_impl.hpp
index 5c4e53f..d857166 100644
--- a/src/mlpack/methods/cf/cf_impl.hpp
+++ b/src/mlpack/methods/cf/cf_impl.hpp
@@ -119,42 +119,35 @@ void CF<FactorizerType>::GetRecommendations(const size_t numRecs,
arma::Mat<size_t>& recommendations,
arma::Col<size_t>& users)
{
- // Generate new table by multiplying approximate values.
- rating = w * h;
+ // We want to avoid calculating the full rating matrix, so we will do nearest
+ // neighbor search only on the H matrix, using the observation that if the
+ // rating matrix X = W*H, then d(X.col(i), X.col(j)) = d(W H.col(i), W
+ // H.col(j)). This can be seen as nearest neighbor search on the H matrix
+ // with the Mahalanobis distance where M^{-1} = W^T W. So, we'll decompose
+ // M^{-1} = L L^T (the Cholesky decomposition), and then multiply H by L^T.
+ // Then we can perform nearest neighbor search.
+ arma::mat l = arma::chol(w.t() * w);
+ arma::mat stretchedH = l * h; // Due to the Armadillo API, l is L^T.
// Now, we will use the decomposed w and h matrices to estimate what the user
// would have rated items as, and then pick the best items.
// Temporarily store feature vector of queried users.
- arma::mat query(rating.n_rows, users.n_elem);
+ arma::mat query(stretchedH.n_rows, users.n_elem);
// Select feature vectors of queried users.
for (size_t i = 0; i < users.n_elem; i++)
- query.col(i) = rating.col(users(i));
+ query.col(i) = stretchedH.col(users(i));
// Temporary storage for neighborhood of the queried users.
arma::Mat<size_t> neighborhood;
// Calculate the neighborhood of the queried users.
// This should be a templatized option.
- neighbor::AllkNN a(rating);
+ neighbor::AllkNN a(stretchedH);
arma::mat resultingDistances; // Temporary storage.
a.Search(query, numUsersForSimilarity, neighborhood, resultingDistances);
- // Temporary storage for storing the average rating for each user in their
- // neighborhood.
- arma::mat averages = arma::zeros<arma::mat>(rating.n_rows, query.n_cols);
-
- // Iterate over each query user.
- for (size_t i = 0; i < neighborhood.n_cols; ++i)
- {
- // Iterate over each neighbor of the query user.
- for (size_t j = 0; j < neighborhood.n_rows; ++j)
- averages.col(i) += rating.col(neighborhood(j, i));
- // Normalize average.
- averages.col(i) /= neighborhood.n_rows;
- }
-
// Generate recommendations for each query user by finding the maximum numRecs
// elements in the averages matrix.
recommendations.set_size(numRecs, users.n_elem);
@@ -163,6 +156,14 @@ void CF<FactorizerType>::GetRecommendations(const size_t numRecs,
values.fill(-DBL_MAX); // The smallest possible value.
for (size_t i = 0; i < users.n_elem; i++)
{
+ // First, calculate average of neighborhood values.
+ arma::vec averages;
+ averages.zeros(cleanedData.n_rows);
+
+ for (size_t j = 0; j < neighborhood.n_rows; ++j)
+ averages += w * h.col(neighborhood(j, i));
+ averages /= neighborhood.n_rows;
+
// Look through the averages column corresponding to the current user.
for (size_t j = 0; j < averages.n_rows; ++j)
{
@@ -171,7 +172,7 @@ void CF<FactorizerType>::GetRecommendations(const size_t numRecs,
continue; // The user already rated the item.
// Is the estimated value better than the worst candidate?
- const double value = averages(j, i);
+ const double value = averages[j];
if (value > values(values.n_rows - 1, i))
{
// It should be inserted. Which position?
More information about the mlpack-git
mailing list