[mlpack-git] master: Allow calculation of RMSE. (14366a2)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Jun 26 17:19:06 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/6d2e43d610eb0dac5c3083d78d097c8103a5f207...14366a2bf0b91e5013783c60a7f45a0e5f7aaf85

>---------------------------------------------------------------

commit 14366a2bf0b91e5013783c60a7f45a0e5f7aaf85
Author: ryan <ryan at ratml.org>
Date:   Fri Jun 26 17:18:53 2015 -0400

    Allow calculation of RMSE.


>---------------------------------------------------------------

14366a2bf0b91e5013783c60a7f45a0e5f7aaf85
 src/mlpack/methods/cf/cf_main.cpp | 84 +++++++++++++++++++++++++++++++++------
 1 file changed, 71 insertions(+), 13 deletions(-)

diff --git a/src/mlpack/methods/cf/cf_main.cpp b/src/mlpack/methods/cf/cf_main.cpp
index cdb6717..09cc435 100644
--- a/src/mlpack/methods/cf/cf_main.cpp
+++ b/src/mlpack/methods/cf/cf_main.cpp
@@ -65,6 +65,8 @@ PARAM_INT("neighborhood", "Size of the neighborhood of similar users to "
 
 PARAM_INT("rank", "Rank of decomposed matrices.", "R", 2);
 
+PARAM_STRING("test_file", "Test set to calculate RMSE on.", "t", "");
+
 template<typename Factorizer>
 void ComputeRecommendations(Factorizer factorizer,
                             arma::mat& dataset,
@@ -96,7 +98,43 @@ void ComputeRecommendations(Factorizer factorizer,
   }
 }
 
+template<typename Factorizer>
+void ComputeRMSE(Factorizer&& factorizer,
+                 const arma::mat& dataset,
+                 const size_t neighborhood,
+                 const size_t rank)
+{
+  CF<Factorizer> c(dataset, factorizer, neighborhood, rank);
+
+  // Load the test set named by the "test_file" parameter.
+  const string testFile = CLI::GetParam<string>("test_file");
+  arma::mat testData;
+  data::Load(testFile, testData, true);
+
+  // Rows 0 and 1 of each test column form the pair whose value we predict.
+  arma::Mat<size_t> combinations(2, testData.n_cols);
+  for (size_t i = 0; i < testData.n_cols; ++i)
+  {
+    combinations(0, i) = size_t(testData(0, i));
+    combinations(1, i) = size_t(testData(1, i));
+  }
+
+  // Now predict a value for every pair in the test set.
+  arma::vec predictions;
+  c.Predict(combinations, predictions);
+
+  // The RMSE is the square root of the mean of the squared errors.  If we
+  // interpret the predictions and the true values (row 2 of the test data) as
+  // vectors, this is just the L2-norm of their difference divided by the
+  // square root of the number of points.
+  const double rmse = arma::norm(predictions - testData.row(2).t()) /
+      std::sqrt((double) testData.n_cols);
+
+  Log::Info << "RMSE is " << rmse << "." << endl;
+}
+
 #define CR(x) ComputeRecommendations(x, dataset, numRecs, neighborhood, rank, recommendations)
+#define RMSE(x) ComputeRMSE(x, dataset, neighborhood, rank)
 
 int main(int argc, char** argv)
 {
@@ -121,20 +159,40 @@ int main(int argc, char** argv)
 
   const string algo = CLI::GetParam<string>("algorithm");
 
-  if(algo == "NMF")
-    CR(NMFALSFactorizer());
-  else if(algo == "SVDBatch")
-    CR(SparseSVDBatchFactorizer());
-  else if(algo == "SVDIncompleteIncremental")
-    CR(SparseSVDIncompleteIncrementalFactorizer());
-  else if(algo == "SVDCompleteIncremental")
-    CR(SparseSVDCompleteIncrementalFactorizer());
-  else if(algo == "RegSVD")
-    CR(RegularizedSVD<>());
+  if (!CLI::HasParam("test_file"))
+  {
+    if (algo == "NMF")
+      CR(NMFALSFactorizer());
+    else if (algo == "SVDBatch")
+      CR(SparseSVDBatchFactorizer());
+    else if (algo == "SVDIncompleteIncremental")
+      CR(SparseSVDIncompleteIncrementalFactorizer());
+    else if (algo == "SVDCompleteIncremental")
+      CR(SparseSVDCompleteIncrementalFactorizer());
+    else if (algo == "RegSVD")
+      CR(RegularizedSVD<>());
+    else
+      Log::Fatal << "Invalid decomposition algorithm.  Choices are 'NMF', "
+          << "'SVDBatch', 'SVDIncompleteIncremental', 'SVDCompleteIncremental',"
+          << " and 'RegSVD'." << endl;
+  }
   else
-    Log::Fatal << "Invalid decomposition algorithm.  Choices are 'NMF', "
-        << "'SVDBatch', 'SVDIncompleteIncremental', 'SVDCompleteIncremental', "
-        << " and 'RegSVD'." << endl;
+  {
+    if (algo == "NMF")
+      RMSE(NMFALSFactorizer());
+    else if (algo == "SVDBatch")
+      RMSE(SparseSVDBatchFactorizer());
+    else if (algo == "SVDIncompleteIncremental")
+      RMSE(SparseSVDIncompleteIncrementalFactorizer());
+    else if (algo == "SVDCompleteIncremental")
+      RMSE(SparseSVDCompleteIncrementalFactorizer());
+    else if (algo == "RegSVD")
+      RMSE(RegularizedSVD<>());
+    else
+      Log::Fatal << "Invalid decomposition algorithm.  Choices are 'NMF', "
+          << "'SVDBatch', 'SVDIncompleteIncremental', 'SVDCompleteIncremental',"
+          << " and 'RegSVD'." << endl;
+  }
 
   const string outputFile = CLI::GetParam<string>("output_file");
   data::Save(outputFile, recommendations);



More information about the mlpack-git mailing list