[mlpack-git] master: Allow calculation of RMSE. (14366a2)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Fri Jun 26 17:19:06 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/6d2e43d610eb0dac5c3083d78d097c8103a5f207...14366a2bf0b91e5013783c60a7f45a0e5f7aaf85
>---------------------------------------------------------------
commit 14366a2bf0b91e5013783c60a7f45a0e5f7aaf85
Author: ryan <ryan at ratml.org>
Date: Fri Jun 26 17:18:53 2015 -0400
Allow calculation of RMSE.
>---------------------------------------------------------------
14366a2bf0b91e5013783c60a7f45a0e5f7aaf85
src/mlpack/methods/cf/cf_main.cpp | 84 +++++++++++++++++++++++++++++++++------
1 file changed, 71 insertions(+), 13 deletions(-)
diff --git a/src/mlpack/methods/cf/cf_main.cpp b/src/mlpack/methods/cf/cf_main.cpp
index cdb6717..09cc435 100644
--- a/src/mlpack/methods/cf/cf_main.cpp
+++ b/src/mlpack/methods/cf/cf_main.cpp
@@ -65,6 +65,8 @@ PARAM_INT("neighborhood", "Size of the neighborhood of similar users to "
PARAM_INT("rank", "Rank of decomposed matrices.", "R", 2);
+PARAM_STRING("test_file", "Test set to calculate RMSE on.", "t", ""); // When set, main() reports RMSE instead of computing recommendations.
+
template<typename Factorizer>
void ComputeRecommendations(Factorizer factorizer,
arma::mat& dataset,
@@ -96,7 +98,43 @@ void ComputeRecommendations(Factorizer factorizer,
}
}
+template<typename Factorizer> // Builds a CF model and logs its RMSE on the 'test_file' data.
+void ComputeRMSE(Factorizer&& factorizer, // strategy handed to the CF model
+ const arma::mat& dataset, // data the CF model is built from
+ const size_t neighborhood, // neighborhood size passed to CF
+ const size_t rank) // rank passed to CF
+{
+ CF<Factorizer> c(dataset, factorizer, neighborhood, rank);
+
+ // Load the test set named by the 'test_file' parameter.
+ const string testFile = CLI::GetParam<string>("test_file");
+ arma::mat testData;
+ data::Load(testFile, testData, true); // NOTE(review): 'true' presumably makes a load failure fatal; confirm.
+
+ // Rows 0 and 1 of each test column, converted to size_t, form the pair to predict.
+ arma::Mat<size_t> combinations(2, testData.n_cols);
+ for (size_t i = 0; i < testData.n_cols; ++i)
+ {
+ combinations(0, i) = size_t(testData(0, i));
+ combinations(1, i) = size_t(testData(1, i));
+ }
+
+ // Predict a value for every pair; row 2 of the test set holds the true values.
+ arma::vec predictions;
+ c.Predict(combinations, predictions);
+
+ // RMSE = sqrt(sum of squared errors / number of points), which equals the
+ // L2-norm of (predictions - true values) divided by sqrt(number of points).
+ // NOTE(review): an empty test set makes the divisor zero, and a test file
+ // with fewer than 3 rows is not validated before row(2) is read.
+ const double rmse = arma::norm(predictions - testData.row(2).t()) /
+ std::sqrt((double) testData.n_cols);
+
+ Log::Info << "RMSE is " << rmse << "." << endl;
+}
+
#define CR(x) ComputeRecommendations(x, dataset, numRecs, neighborhood, rank, recommendations)
+#define RMSE(x) ComputeRMSE(x, dataset, neighborhood, rank) // uses dataset/neighborhood/rank from the expansion site's scope
int main(int argc, char** argv)
{
@@ -121,20 +159,40 @@ int main(int argc, char** argv)
const string algo = CLI::GetParam<string>("algorithm");
- if(algo == "NMF")
- CR(NMFALSFactorizer());
- else if(algo == "SVDBatch")
- CR(SparseSVDBatchFactorizer());
- else if(algo == "SVDIncompleteIncremental")
- CR(SparseSVDIncompleteIncrementalFactorizer());
- else if(algo == "SVDCompleteIncremental")
- CR(SparseSVDCompleteIncrementalFactorizer());
- else if(algo == "RegSVD")
- CR(RegularizedSVD<>());
+ if (!CLI::HasParam("test_file"))
+ {
+ if (algo == "NMF")
+ CR(NMFALSFactorizer());
+ else if (algo == "SVDBatch")
+ CR(SparseSVDBatchFactorizer());
+ else if (algo == "SVDIncompleteIncremental")
+ CR(SparseSVDIncompleteIncrementalFactorizer());
+ else if (algo == "SVDCompleteIncremental")
+ CR(SparseSVDCompleteIncrementalFactorizer());
+ else if (algo == "RegSVD")
+ CR(RegularizedSVD<>());
+ else
+ Log::Fatal << "Invalid decomposition algorithm. Choices are 'NMF', "
+ << "'SVDBatch', 'SVDIncompleteIncremental', 'SVDCompleteIncremental',"
+ << " and 'RegSVD'." << endl;
+ }
else
- Log::Fatal << "Invalid decomposition algorithm. Choices are 'NMF', "
- << "'SVDBatch', 'SVDIncompleteIncremental', 'SVDCompleteIncremental', "
- << " and 'RegSVD'." << endl;
+ {
+ if (algo == "NMF")
+ RMSE(NMFALSFactorizer());
+ else if (algo == "SVDBatch")
+ RMSE(SparseSVDBatchFactorizer());
+ else if (algo == "SVDIncompleteIncremental")
+ RMSE(SparseSVDIncompleteIncrementalFactorizer());
+ else if (algo == "SVDCompleteIncremental")
+ RMSE(SparseSVDCompleteIncrementalFactorizer());
+ else if (algo == "RegSVD")
+ RMSE(RegularizedSVD<>());
+ else
+ Log::Fatal << "Invalid decomposition algorithm. Choices are 'NMF', "
+ << "'SVDBatch', 'SVDIncompleteIncremental', 'SVDCompleteIncremental',"
+ << " and 'RegSVD'." << endl;
+ }
const string outputFile = CLI::GetParam<string>("output_file");
data::Save(outputFile, recommendations);
More information about the mlpack-git mailing list