[mlpack-git] master: Refactor main CF program to allow loading/saving of models. (a8222e0)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Tue Dec 22 17:02:10 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/6ab20afd8adaf9dcb86bc9a8ea98a24dd8b18743...eb41f4bc27b484c347acc006255104e2f8cc4eef
>---------------------------------------------------------------
commit a8222e0440cb31a8f53d21cd548faef02f9d7810
Author: ryan <ryan at ratml.org>
Date: Tue Dec 22 14:38:27 2015 -0500
Refactor main CF program to allow loading/saving of models.
>---------------------------------------------------------------
a8222e0440cb31a8f53d21cd548faef02f9d7810
src/mlpack/methods/cf/cf_main.cpp | 150 ++++++++++++++++++++++----------------
1 file changed, 89 insertions(+), 61 deletions(-)
diff --git a/src/mlpack/methods/cf/cf_main.cpp b/src/mlpack/methods/cf/cf_main.cpp
index 42ef2b6..b22b72e 100644
--- a/src/mlpack/methods/cf/cf_main.cpp
+++ b/src/mlpack/methods/cf/cf_main.cpp
@@ -50,37 +50,38 @@ PROGRAM_INFO("Collaborating Filtering", "This program performs collaborative "
"'SVDIncompleteIncremental' -- SVD incomplete incremental learning\n"
"'SVDCompleteIncremental' -- SVD complete incremental learning\n");
-// Parameters for program.
-PARAM_STRING_REQ("input_file", "Input dataset to perform CF on.", "i");
-PARAM_STRING("query_file", "List of users for which recommendations are to "
- "be generated.", "q", "");
-PARAM_FLAG("all_user_recommendations", "Generate recommendations for all "
- "users.", "A");
-
-PARAM_STRING("output_file","File to save output recommendations to.", "o",
- "recommendations.csv");
-
+// Parameters for training a model.
+PARAM_STRING("training_file", "Input dataset to perform CF on.", "t", "");
PARAM_STRING("algorithm", "Algorithm used for matrix factorization.", "a",
"NMF");
-
-PARAM_INT("recommendations", "Number of recommendations to generate for each "
- "query user.", "r", 5);
PARAM_INT("neighborhood", "Size of the neighborhood of similar users to "
"consider for each query user.", "n", 5);
-
PARAM_INT("rank", "Rank of decomposed matrices (if 0, a heuristic is used to "
"estimate the rank).", "R", 0);
-
-PARAM_STRING("test_file", "Test set to calculate RMSE on.", "t", "");
+PARAM_STRING("test_file", "Test set to calculate RMSE on.", "T", "");
// Offer the user the option to set the maximum number of iterations, and
// terminate only based on the number of iterations.
-PARAM_INT("max_iterations", "Maximum number of iterations.", "m", 1000);
+PARAM_INT("max_iterations", "Maximum number of iterations.", "N", 1000);
PARAM_FLAG("iteration_only_termination", "Terminate only when the maximum "
"number of iterations is reached.", "I");
PARAM_DOUBLE("min_residue", "Residue required to terminate the factorization "
"(lower values generally mean better fits).", "r", 1e-5);
+// Load/save a model.
+PARAM_STRING("input_model_file", "File to load trained CF model from.", "m",
+ "");
+PARAM_STRING("output_model_file", "File to save trained CF model to.", "M", "");
+
+// Query settings. (Short option -n is already taken by --neighborhood.)
+PARAM_STRING("query_file", "List of users for which recommendations are to "
+ "be generated.", "q", "");
+PARAM_FLAG("all_user_recommendations", "Generate recommendations for all "
+ "users.", "A");
+PARAM_STRING("output_file","File to save output recommendations to.", "o", "");
+PARAM_INT("recommendations", "Number of recommendations to generate for each "
+ "query user.", "c", 5);
+
PARAM_INT("seed", "Set the random seed (0 uses std::time(NULL)).", "s", 0);
void ComputeRecommendations(CF& cf,
@@ -137,15 +138,8 @@ void ComputeRMSE(CF& cf)
Log::Info << "RMSE is " << rmse << "." << endl;
}
-template<typename Factorizer>
-void PerformAction(Factorizer&& factorizer,
- arma::mat& dataset,
- const size_t rank)
+void PerformAction(CF& c)
{
- // Parameters for generating the CF object.
- const size_t neighborhood = (size_t) CLI::GetParam<int>("neighborhood");
- CF c(dataset, factorizer, neighborhood, rank);
-
if (CLI::HasParam("query_file") || CLI::HasParam("all_user_recommendations"))
{
// Get parameters for generating recommendations.
@@ -161,9 +155,22 @@ void PerformAction(Factorizer&& factorizer,
}
if (CLI::HasParam("test_file"))
- {
ComputeRMSE(c);
- }
+
+ if (CLI::HasParam("output_model_file"))
+ data::Save(CLI::GetParam<string>("output_model_file"), "cf_model", c);
+}
+
+template<typename Factorizer>
+void PerformAction(Factorizer&& factorizer,
+ arma::mat& dataset,
+ const size_t rank)
+{
+ // Train a CF model with the user-requested neighborhood size, then act on it.
+ CF cf(dataset, factorizer, (size_t) CLI::GetParam<int>("neighborhood"),
+ rank);
+
+ PerformAction(cf);
+}
void AssembleFactorizerType(const std::string& algorithm,
@@ -236,45 +243,66 @@ int main(int argc, char** argv)
else
math::RandomSeed(CLI::GetParam<int>("seed"));
- // Read from the input file.
- const string inputFile = CLI::GetParam<string>("input_file");
- arma::mat dataset;
- data::Load(inputFile, dataset, true);
-
- // Recommendation matrix.
- arma::Mat<size_t> recommendations;
+ // Validate parameters.
+ if (CLI::HasParam("training_file") && CLI::HasParam("input_model_file"))
+ Log::Fatal << "Only one of --training_file (-t) or --input_model_file (-m) "
+ << "may be specified!" << endl;
- // Get parameters.
- const size_t rank = (size_t) CLI::GetParam<int>("rank");
+ if (!CLI::HasParam("training_file") && !CLI::HasParam("input_model_file"))
+ Log::Fatal << "Neither --training_file (-t) nor --input_model_file (-m) are"
+ << " specified!" << endl;
// Check that nothing stupid is happening.
if (CLI::HasParam("query_file") && CLI::HasParam("all_user_recommendations"))
Log::Fatal << "Both --query_file and --all_user_recommendations are given, "
<< "but only one is allowed!" << endl;
- // Perform decomposition to prepare for recommendations.
- Log::Info << "Performing CF matrix decomposition on dataset..." << endl;
-
- const string algo = CLI::GetParam<string>("algorithm");
-
- // Issue an error if an invalid factorizer is used.
- if (algo != "NMF" &&
- algo != "SVDBatch" &&
- algo != "SVDIncompleteIncremental" &&
- algo != "SVDCompleteIncremental" &&
- algo != "RegSVD")
- Log::Fatal << "Invalid decomposition algorithm. Choices are 'NMF', "
- << "'SVDBatch', 'SVDIncompleteIncremental', 'SVDCompleteIncremental',"
- << " and 'RegSVD'." << endl;
-
- // Issue a warning if the user provided a minimum residue but it will be
- // ignored.
- if (CLI::HasParam("min_residue") &&
- CLI::HasParam("iteration_only_termination"))
- Log::Warn << "--min_residue ignored, because --iteration_only_termination "
- << "is specified." << endl;
-
- // Perform the factorization and do whatever the user wanted.
- AssembleFactorizerType(algo, dataset,
- CLI::HasParam("iteration_only_termination"), rank);
+ // Either load from a model, or train a model.
+ if (CLI::HasParam("training_file"))
+ {
+ // Read from the input file.
+ const string trainingFile = CLI::GetParam<string>("training_file");
+ arma::mat dataset;
+ data::Load(trainingFile, dataset, true);
+
+ // NOTE(review): 'recommendations' below is never used in this branch.
+ arma::Mat<size_t> recommendations;
+
+ // Get parameters.
+ const size_t rank = (size_t) CLI::GetParam<int>("rank");
+
+ // Perform decomposition to prepare for recommendations.
+ Log::Info << "Performing CF matrix decomposition on dataset..." << endl;
+
+ const string algo = CLI::GetParam<string>("algorithm");
+
+ // Issue an error if an invalid factorizer is used.
+ if (algo != "NMF" &&
+ algo != "SVDBatch" &&
+ algo != "SVDIncompleteIncremental" &&
+ algo != "SVDCompleteIncremental" &&
+ algo != "RegSVD")
+ Log::Fatal << "Invalid decomposition algorithm. Choices are 'NMF', "
+ << "'SVDBatch', 'SVDIncompleteIncremental', 'SVDCompleteIncremental',"
+ << " and 'RegSVD'." << endl;
+
+ // Issue a warning if the user provided a minimum residue but it will be
+ // ignored.
+ if (CLI::HasParam("min_residue") &&
+ CLI::HasParam("iteration_only_termination"))
+ Log::Warn << "--min_residue ignored, because --iteration_only_termination"
+ << " is specified." << endl;
+
+ // Perform the factorization and do whatever the user wanted.
+ AssembleFactorizerType(algo, dataset,
+ CLI::HasParam("iteration_only_termination"), rank);
+ }
+ else
+ {
+ // Load an input model.
+ CF c;
+ data::Load(CLI::GetParam<string>("input_model_file"), "cf_model", c, true);
+
+ PerformAction(c);
+ }
}
More information about the mlpack-git
mailing list