[mlpack-git] master: Refactor lcc program. (443b00a)

gitdub at big.cc.gt.atl.ga.us
Wed Dec 16 14:12:39 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/cd5986e141b41781fdc13a9c89443f9be33b56bd...31c10fef76ac1d85c6415c92d2ccd429c430105f

>---------------------------------------------------------------

commit 443b00af67ac6013187eb8e75340218d775d0fe8
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Dec 16 19:11:55 2015 +0000

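    Refactor lcc program.

    Training and encoding are now separate steps: the dictionary is learned
    from --training_file (-t, formerly --input_file), and points in --test_file
    (-T) are encoded with it.  Trained models can be loaded and saved with
    --input_model_file (-m) and --output_model_file (-M).  --objective_tolerance
    is renamed to --tolerance (-o), and --initial_dictionary now uses the short
    option -i instead of -D.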


>---------------------------------------------------------------

443b00af67ac6013187eb8e75340218d775d0fe8
 .../methods/local_coordinate_coding/lcc_main.cpp   | 209 ++++++++++++++-------
 1 file changed, 137 insertions(+), 72 deletions(-)

diff --git a/src/mlpack/methods/local_coordinate_coding/lcc_main.cpp b/src/mlpack/methods/local_coordinate_coding/lcc_main.cpp
index 26887e8..f10d6c8 100644
--- a/src/mlpack/methods/local_coordinate_coding/lcc_main.cpp
+++ b/src/mlpack/methods/local_coordinate_coding/lcc_main.cpp
@@ -39,29 +39,30 @@ PROGRAM_INFO("Local Coordinate Coding",
     "Optionally, the input data matrix X can be normalized before coding with "
     "the -N option.");
 
-PARAM_STRING_REQ("input_file", "Filename of the input data.", "i");
-PARAM_INT_REQ("atoms", "Number of atoms in the dictionary.", "k");
-
+// Training parameters.
+PARAM_STRING("training_file", "Filename of the training data (X).", "t", "");
+PARAM_INT("atoms", "Number of atoms in the dictionary.", "k", 0);
 PARAM_DOUBLE("lambda", "Weighted l1-norm regularization parameter.", "l", 0.0);
-
 PARAM_INT("max_iterations", "Maximum number of iterations for LCC (0 indicates "
     "no limit).", "n", 0);
-
 PARAM_STRING("initial_dictionary", "Filename for optional initial dictionary.",
-    "D", "");
-
-PARAM_STRING("dictionary_file", "Filename to save the output dictionary to.",
-    "d", "dictionary.csv");
-PARAM_STRING("codes_file", "Filename to save the output codes to.", "c",
-    "codes.csv");
-
+    "i", "");
 PARAM_FLAG("normalize", "If set, the input data matrix will be normalized "
     "before coding.", "N");
+PARAM_DOUBLE("tolerance", "Tolerance for objective function.", "o", 0.01);
 
-PARAM_INT("seed", "Random seed.  If 0, 'std::time(NULL)' is used.", "s", 0);
+// Load/save a model.
+PARAM_STRING("input_model_file", "File containing input LCC model.", "m", "");
+PARAM_STRING("output_model_file", "File to save trained LCC model to.", "M",
+    "");
 
-PARAM_DOUBLE("objective_tolerance", "Tolerance for objective function.", "o",
-    0.01);
+// Test on another dataset.
+PARAM_STRING("test_file", "File of test points to encode.", "T", "");
+PARAM_STRING("dictionary_file", "Filename to save the output dictionary to.",
+    "d", "");
+PARAM_STRING("codes_file", "Filename to save the output codes to.", "c", "");
+
+PARAM_INT("seed", "Random seed.  If 0, 'std::time(NULL)' is used.", "s", 0);
 
 using namespace arma;
 using namespace std;
@@ -79,85 +80,149 @@ int main(int argc, char* argv[])
   else
     RandomSeed((size_t) std::time(NULL));
 
-  const double lambda = CLI::GetParam<double>("lambda");
-
-  const string inputFile = CLI::GetParam<string>("input_file");
-  const string dictionaryFile = CLI::GetParam<string>("dictionary_file");
-  const string codesFile = CLI::GetParam<string>("codes_file");
-  const string initialDictionaryFile =
-      CLI::GetParam<string>("initial_dictionary");
-
-  const size_t maxIterations = CLI::GetParam<int>("max_iterations");
-  const size_t atoms = CLI::GetParam<int>("atoms");
+  // Check for parameter validity.
+  if (CLI::HasParam("input_model_file") && CLI::HasParam("initial_dictionary"))
+    Log::Fatal << "Cannot specify both --input_model_file (-m) and "
+        << "--initial_dictionary (-i)!" << endl;
 
-  const bool normalize = CLI::HasParam("normalize");
+  if (CLI::HasParam("training_file") && !CLI::HasParam("atoms"))
+    Log::Fatal << "If --training_file is specified, the number of atoms in the "
+        << "dictionary must be specified with --atoms (-k)!" << endl;
 
-  const double objTolerance = CLI::GetParam<double>("objective_tolerance");
+  if (!CLI::HasParam("training_file") && !CLI::HasParam("input_model_file"))
+    Log::Fatal << "One of --training_file (-t) or --input_model_file (-m) must "
+        << "be specified!" << endl;
 
-  mat input;
-  data::Load(inputFile, input, true);
+  if (!CLI::HasParam("codes_file") && !CLI::HasParam("dictionary_file") &&
+      !CLI::HasParam("output_model_file"))
+    Log::Warn << "Neither --codes_file (-c), --dictionary_file (-d), nor "
+        << "--output_model_file (-M) are specified; no output will be saved."
+        << endl;
 
-  Log::Info << "Loaded " << input.n_cols << " point in " << input.n_rows
-      << " dimensions." << endl;
+  if (CLI::HasParam("codes_file") && !CLI::HasParam("test_file"))
+    Log::Fatal << "--codes_file (-c) is specified, but no test matrix ("
+        << "specified with --test_file or -T) is given to encode!" << endl;
 
-  // Normalize each point if the user asked for it.
-  if (normalize)
+  if (!CLI::HasParam("training_file"))
   {
-    Log::Info << "Normalizing data before coding..." << endl;
-    for (size_t i = 0; i < input.n_cols; ++i)
-      input.col(i) /= norm(input.col(i), 2);
+    if (CLI::HasParam("atoms"))
+      Log::Warn << "--atoms (-k) ignored because --training_file (-t) is not "
+          << "specified." << endl;
+    if (CLI::HasParam("lambda"))
+      Log::Warn << "--lambda (-l) ignored because --training_file (-t) is not "
+          << "specified." << endl;
+    if (CLI::HasParam("initial_dictionary"))
+      Log::Warn << "--initial_dictionary (-i) ignored because --training_file "
+          << "(-t) is not specified." << endl;
+    if (CLI::HasParam("max_iterations"))
+      Log::Warn << "--max_iterations (-n) ignored because --training_file (-t) "
+          << "is not specified." << endl;
+    if (CLI::HasParam("normalize"))
+      Log::Warn << "--normalize (-N) ignored because --training_file (-t) is "
+          << "not specified." << endl;
+    if (CLI::HasParam("tolerance"))
+      Log::Warn << "--tolerance (-o) ignored because --training_file (-t) is "
+          << "not specified." << endl;
   }
 
-  // If there is an initial dictionary, be sure we do not initialize one.
-  if (initialDictionaryFile != "")
+  // Do we have an existing model?
+  LocalCoordinateCoding lcc(0, 0.0);
+  if (CLI::HasParam("input_model_file"))
   {
-    LocalCoordinateCoding lcc(atoms, lambda, maxIterations, objTolerance);
+    data::Load(CLI::GetParam<string>("input_model_file"), "lcc_model", lcc,
+        true);
+  }
 
-    // Load initial dictionary directly into LCC object.
-    data::Load(initialDictionaryFile, lcc.Dictionary(), true);
+  if (CLI::HasParam("training_file"))
+  {
+    mat matX;
+    data::Load(CLI::GetParam<string>("training_file"), matX, true);
 
-    // Validate size of initial dictionary.
-    if (lcc.Dictionary().n_cols != atoms)
+    // Normalize each point if the user asked for it.
+    if (CLI::HasParam("normalize"))
     {
-      Log::Fatal << "The initial dictionary has " << lcc.Dictionary().n_cols
-          << " atoms, but the number of atoms was specified to be " << atoms
-          << "!" << endl;
+      Log::Info << "Normalizing data before coding..." << endl;
+      for (size_t i = 0; i < matX.n_cols; ++i)
+        matX.col(i) /= norm(matX.col(i), 2);
     }
 
-    if (lcc.Dictionary().n_rows != input.n_rows)
+    lcc.Lambda() = CLI::GetParam<double>("lambda");
+    lcc.Atoms() = (size_t) CLI::GetParam<int>("atoms");
+    lcc.MaxIterations() = (size_t) CLI::GetParam<int>("max_iterations");
+    lcc.Tolerance() = CLI::GetParam<double>("tolerance");
+
+    // Inform the user if we are overwriting their model.
+    if (CLI::HasParam("input_model_file"))
     {
-      Log::Fatal << "The initial dictionary has " << lcc.Dictionary().n_rows
-          << " dimensions, but the data has " << input.n_rows << " dimensions!"
-          << endl;
+      Log::Info << "Using dictionary from existing model in '"
+          << CLI::GetParam<string>("input_model_file") << "' as initial "
+          << "dictionary for training." << endl;
+      lcc.Train<NothingInitializer>(matX);
     }
+    else if (CLI::HasParam("initial_dictionary"))
+    {
+      // Load initial dictionary directly into LCC object.
+      data::Load(CLI::GetParam<string>("initial_dictionary"), lcc.Dictionary(),
+          true);
+
+      // Validate the size of the initial dictionary.
+      if (lcc.Dictionary().n_cols != lcc.Atoms())
+      {
+        Log::Fatal << "The initial dictionary has " << lcc.Dictionary().n_cols
+            << " atoms, but the number of atoms was specified to be "
+            << lcc.Atoms() << "!" << endl;
+      }
+
+      if (lcc.Dictionary().n_rows != matX.n_rows)
+      {
+        Log::Fatal << "The initial dictionary has " << lcc.Dictionary().n_rows
+            << " dimensions, but the data has " << matX.n_rows << " dimensions!"
+            << endl;
+      }
+
+      // Train the model.
+      lcc.Train<NothingInitializer>(matX);
+    }
+    else
+    {
+      // Run with the default initialization.
+      lcc.Train(matX);
+    }
+  }
+
+  // Now, do we have any matrix to encode?
+  if (CLI::HasParam("test_file"))
+  {
+    mat matY;
+    data::Load(CLI::GetParam<string>("test_file"), matY, true);
 
-    // Run LCC.
-    lcc.Train<NothingInitializer>(input);
+    if (matY.n_rows != lcc.Dictionary().n_rows)
+      Log::Fatal << "Model was trained with a dimensionality of "
+          << lcc.Dictionary().n_rows << ", but data in test file "
+          << CLI::GetParam<string>("test_file") << " has a dimensionality of "
+          << matY.n_rows << "!" << endl;
 
-    // Save the results.
-    Log::Info << "Saving dictionary matrix to '" << dictionaryFile << "'.\n";
-    data::Save(dictionaryFile, lcc.Dictionary());
+    // Normalize each point if the user asked for it.
+    if (CLI::HasParam("normalize"))
+    {
+      Log::Info << "Normalizing test data before coding..." << endl;
+      for (size_t i = 0; i < matY.n_cols; ++i)
+        matY.col(i) /= norm(matY.col(i), 2);
+    }
 
     mat codes;
-    lcc.Encode(input, codes);
+    lcc.Encode(matY, codes);
 
-    Log::Info << "Saving sparse codes to '" << codesFile << "'.\n";
-    data::Save(codesFile, codes);
+    if (CLI::HasParam("codes_file"))
+      data::Save(CLI::GetParam<string>("codes_file"), codes);
   }
-  else
-  {
-    // No initial dictionary.
-    LocalCoordinateCoding lcc(input, atoms, lambda, maxIterations,
-        objTolerance);
 
-    // Save the results.
-    Log::Info << "Saving dictionary matrix to '" << dictionaryFile << "'.\n";
-    data::Save(dictionaryFile, lcc.Dictionary());
+  // Did the user want to save the dictionary?
+  if (CLI::HasParam("dictionary_file"))
+    data::Save(CLI::GetParam<string>("dictionary_file"), lcc.Dictionary());
 
-    mat codes;
-    lcc.Encode(input, codes);
-
-    Log::Info << "Saving sparse codes to '" << codesFile << "'.\n";
-    data::Save(codesFile, codes);
-  }
+  // Did the user want to save the model?
+  if (CLI::HasParam("output_model_file"))
+    data::Save(CLI::GetParam<string>("output_model_file"), "lcc_model", lcc,
+        false); // Non-fatal on failure.
 }



More information about the mlpack-git mailing list