[mlpack-git] master: Refactor main CF program to allow loading/saving of models. (a8222e0)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Tue Dec 22 17:02:10 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/6ab20afd8adaf9dcb86bc9a8ea98a24dd8b18743...eb41f4bc27b484c347acc006255104e2f8cc4eef

>---------------------------------------------------------------

commit a8222e0440cb31a8f53d21cd548faef02f9d7810
Author: ryan <ryan at ratml.org>
Date:   Tue Dec 22 14:38:27 2015 -0500

    Refactor main CF program to allow loading/saving of models.


>---------------------------------------------------------------

a8222e0440cb31a8f53d21cd548faef02f9d7810
 src/mlpack/methods/cf/cf_main.cpp | 150 ++++++++++++++++++++++----------------
 1 file changed, 89 insertions(+), 61 deletions(-)

diff --git a/src/mlpack/methods/cf/cf_main.cpp b/src/mlpack/methods/cf/cf_main.cpp
index 42ef2b6..b22b72e 100644
--- a/src/mlpack/methods/cf/cf_main.cpp
+++ b/src/mlpack/methods/cf/cf_main.cpp
@@ -50,37 +50,38 @@ PROGRAM_INFO("Collaborating Filtering", "This program performs collaborative "
     "'SVDIncompleteIncremental' -- SVD incomplete incremental learning\n"
     "'SVDCompleteIncremental' -- SVD complete incremental learning\n");
 
-// Parameters for program.
-PARAM_STRING_REQ("input_file", "Input dataset to perform CF on.", "i");
-PARAM_STRING("query_file", "List of users for which recommendations are to "
-    "be generated.", "q", "");
-PARAM_FLAG("all_user_recommendations", "Generate recommendations for all "
-    "users.", "A");
-
-PARAM_STRING("output_file","File to save output recommendations to.", "o",
-    "recommendations.csv");
-
+// Parameters for training a model.
+PARAM_STRING("training_file", "Input dataset to perform CF on.", "t", "");
 PARAM_STRING("algorithm", "Algorithm used for matrix factorization.", "a",
     "NMF");
-
-PARAM_INT("recommendations", "Number of recommendations to generate for each "
-    "query user.", "r", 5);
 PARAM_INT("neighborhood", "Size of the neighborhood of similar users to "
     "consider for each query user.", "n", 5);
-
 PARAM_INT("rank", "Rank of decomposed matrices (if 0, a heuristic is used to "
     "estimate the rank).", "R", 0);
-
-PARAM_STRING("test_file", "Test set to calculate RMSE on.", "t", "");
+PARAM_STRING("test_file", "Test set to calculate RMSE on.", "T", "");
 
 // Offer the user the option to set the maximum number of iterations, and
 // terminate only based on the number of iterations.
-PARAM_INT("max_iterations", "Maximum number of iterations.", "m", 1000);
+PARAM_INT("max_iterations", "Maximum number of iterations.", "N", 1000);
 PARAM_FLAG("iteration_only_termination", "Terminate only when the maximum "
     "number of iterations is reached.", "I");
 PARAM_DOUBLE("min_residue", "Residue required to terminate the factorization "
     "(lower values generally mean better fits).", "r", 1e-5);
 
+// Load/save a model.
+PARAM_STRING("input_model_file", "File to load trained CF model from.", "m",
+    "");
+PARAM_STRING("output_model_file", "File to save trained CF model to.", "M", "");
+
+// Settings for querying the model (whether it was trained or loaded).
+PARAM_STRING("query_file", "List of users for which recommendations are to "
+    "be generated.", "q", "");
+PARAM_FLAG("all_user_recommendations", "Generate recommendations for all "
+    "users.", "A");
+PARAM_STRING("output_file", "File to save output recommendations to.", "o", "");
+PARAM_INT("recommendations", "Number of recommendations to generate for each "
+    "query user.", "c", 5);
+
 PARAM_INT("seed", "Set the random seed (0 uses std::time(NULL)).", "s", 0);
 
 void ComputeRecommendations(CF& cf,
@@ -137,15 +138,8 @@ void ComputeRMSE(CF& cf)
   Log::Info << "RMSE is " << rmse << "." << endl;
 }
 
-template<typename Factorizer>
-void PerformAction(Factorizer&& factorizer,
-                   arma::mat& dataset,
-                   const size_t rank)
+void PerformAction(CF& c) // Query, evaluate, and/or save an existing model.
 {
-  // Parameters for generating the CF object.
-  const size_t neighborhood = (size_t) CLI::GetParam<int>("neighborhood");
-  CF c(dataset, factorizer, neighborhood, rank);
-
   if (CLI::HasParam("query_file") || CLI::HasParam("all_user_recommendations"))
   {
     // Get parameters for generating recommendations.
@@ -161,9 +155,22 @@ void PerformAction(Factorizer&& factorizer,
   }
 
   if (CLI::HasParam("test_file"))
-  {
     ComputeRMSE(c);
-  }
+
+  if (CLI::HasParam("output_model_file")) // Serialized under "cf_model".
+    data::Save(CLI::GetParam<string>("output_model_file"), "cf_model", c);
+}
+
+template<typename Factorizer> // Overload that builds a model from a dataset.
+void PerformAction(Factorizer&& factorizer,
+                   arma::mat& dataset,
+                   const size_t rank)
+{
+  // Build the CF object from the dataset and factorizer, then hand it off.
+  const size_t neighborhood = (size_t) CLI::GetParam<int>("neighborhood");
+  CF c(dataset, factorizer, neighborhood, rank);
+
+  PerformAction(c);
 }
 
 void AssembleFactorizerType(const std::string& algorithm,
@@ -236,45 +243,66 @@ int main(int argc, char** argv)
   else
     math::RandomSeed(CLI::GetParam<int>("seed"));
 
-  // Read from the input file.
-  const string inputFile = CLI::GetParam<string>("input_file");
-  arma::mat dataset;
-  data::Load(inputFile, dataset, true);
-
-  // Recommendation matrix.
-  arma::Mat<size_t> recommendations;
+  // Validate parameters.
+  if (CLI::HasParam("training_file") && CLI::HasParam("input_model_file"))
+    Log::Fatal << "Only one of --training_file (-t) or --input_model_file (-m) "
+        << "may be specified!" << endl;
 
-  // Get parameters.
-  const size_t rank = (size_t) CLI::GetParam<int>("rank");
+  if (!CLI::HasParam("training_file") && !CLI::HasParam("input_model_file"))
+    Log::Fatal << "Neither --training_file (-t) nor --input_model_file (-m) are"
+        << " specified!" << endl;
 
   // Check that nothing stupid is happening.
   if (CLI::HasParam("query_file") && CLI::HasParam("all_user_recommendations"))
     Log::Fatal << "Both --query_file and --all_user_recommendations are given, "
         << "but only one is allowed!" << endl;
 
-  // Perform decomposition to prepare for recommendations.
-  Log::Info << "Performing CF matrix decomposition on dataset..." << endl;
-
-  const string algo = CLI::GetParam<string>("algorithm");
-
-  // Issue an error if an invalid factorizer is used.
-  if (algo != "NMF" &&
-      algo != "SVDBatch" &&
-      algo != "SVDIncompleteIncremental" &&
-      algo != "SVDCompleteIncremental" &&
-      algo != "RegSVD")
-    Log::Fatal << "Invalid decomposition algorithm.  Choices are 'NMF', "
-        << "'SVDBatch', 'SVDIncompleteIncremental', 'SVDCompleteIncremental',"
-        << " and 'RegSVD'." << endl;
-
-  // Issue a warning if the user provided a minimum residue but it will be
-  // ignored.
-  if (CLI::HasParam("min_residue") &&
-      CLI::HasParam("iteration_only_termination"))
-    Log::Warn << "--min_residue ignored, because --iteration_only_termination "
-        << "is specified." << endl;
-
-  // Perform the factorization and do whatever the user wanted.
-  AssembleFactorizerType(algo, dataset,
-      CLI::HasParam("iteration_only_termination"), rank);
+  // Either load from a model, or train a model.
+  if (CLI::HasParam("training_file"))
+  {
+    // Read from the input file.
+    const string trainingFile = CLI::GetParam<string>("training_file");
+    arma::mat dataset;
+    data::Load(trainingFile, dataset, true);
+
+    // Get parameters.  Reject a negative rank up front, because it would
+    // wrap around to a huge value in the size_t cast below.
+    if (CLI::GetParam<int>("rank") < 0)
+      Log::Fatal << "--rank (-R) must be nonnegative!" << endl;
+    const size_t rank = (size_t) CLI::GetParam<int>("rank");
+
+    // Perform decomposition to prepare for recommendations.
+    Log::Info << "Performing CF matrix decomposition on dataset..." << endl;
+
+    const string algo = CLI::GetParam<string>("algorithm");
+
+    // Issue an error if an invalid factorizer is used.
+    if (algo != "NMF" &&
+        algo != "SVDBatch" &&
+        algo != "SVDIncompleteIncremental" &&
+        algo != "SVDCompleteIncremental" &&
+        algo != "RegSVD")
+      Log::Fatal << "Invalid decomposition algorithm.  Choices are 'NMF', "
+          << "'SVDBatch', 'SVDIncompleteIncremental', 'SVDCompleteIncremental',"
+          << " and 'RegSVD'." << endl;
+
+    // Issue a warning if the user provided a minimum residue but it will be
+    // ignored.
+    if (CLI::HasParam("min_residue") &&
+        CLI::HasParam("iteration_only_termination"))
+      Log::Warn << "--min_residue ignored, because --iteration_only_termination"
+          << " is specified." << endl;
+
+    // Perform the factorization and do whatever the user wanted.
+    AssembleFactorizerType(algo, dataset,
+        CLI::HasParam("iteration_only_termination"), rank);
+  }
+  else
+  {
+    // Load an input model.
+    CF c;
+    data::Load(CLI::GetParam<string>("input_model_file"), "cf_model", c, true);
+
+    PerformAction(c);
+  }
 }



More information about the mlpack-git mailing list