[mlpack-git] master: Refactor lcc program. (443b00a)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Dec 16 14:12:39 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/cd5986e141b41781fdc13a9c89443f9be33b56bd...31c10fef76ac1d85c6415c92d2ccd429c430105f
>---------------------------------------------------------------
commit 443b00af67ac6013187eb8e75340218d775d0fe8
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Dec 16 19:11:55 2015 +0000
Refactor lcc program.
>---------------------------------------------------------------
443b00af67ac6013187eb8e75340218d775d0fe8
.../methods/local_coordinate_coding/lcc_main.cpp | 209 ++++++++++++++-------
1 file changed, 137 insertions(+), 72 deletions(-)
diff --git a/src/mlpack/methods/local_coordinate_coding/lcc_main.cpp b/src/mlpack/methods/local_coordinate_coding/lcc_main.cpp
index 26887e8..f10d6c8 100644
--- a/src/mlpack/methods/local_coordinate_coding/lcc_main.cpp
+++ b/src/mlpack/methods/local_coordinate_coding/lcc_main.cpp
@@ -39,29 +39,30 @@ PROGRAM_INFO("Local Coordinate Coding",
"Optionally, the input data matrix X can be normalized before coding with "
"the -N option.");
-PARAM_STRING_REQ("input_file", "Filename of the input data.", "i");
-PARAM_INT_REQ("atoms", "Number of atoms in the dictionary.", "k");
-
+// Training parameters.
+PARAM_STRING("training_file", "Filename of the training data (X).", "t", "");
+PARAM_INT("atoms", "Number of atoms in the dictionary.", "k", 0);
PARAM_DOUBLE("lambda", "Weighted l1-norm regularization parameter.", "l", 0.0);
-
PARAM_INT("max_iterations", "Maximum number of iterations for LCC (0 indicates "
"no limit).", "n", 0);
-
PARAM_STRING("initial_dictionary", "Filename for optional initial dictionary.",
- "D", "");
-
-PARAM_STRING("dictionary_file", "Filename to save the output dictionary to.",
- "d", "dictionary.csv");
-PARAM_STRING("codes_file", "Filename to save the output codes to.", "c",
- "codes.csv");
-
+ "i", "");
PARAM_FLAG("normalize", "If set, the input data matrix will be normalized "
"before coding.", "N");
+PARAM_DOUBLE("tolerance", "Tolerance for objective function.", "o", 0.01);
-PARAM_INT("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
+// Load/save a model.
+PARAM_STRING("input_model_file", "File containing input LCC model.", "m", "");
+PARAM_STRING("output_model_file", "File to save trained LCC model to.", "M",
+ "");
-PARAM_DOUBLE("objective_tolerance", "Tolerance for objective function.", "o",
- 0.01);
+// Test on another dataset.
+PARAM_STRING("test_file", "File of test points to encode.", "T", "");
+PARAM_STRING("dictionary_file", "Filename to save the output dictionary to.",
+ "d", "");
+PARAM_STRING("codes_file", "Filename to save the output codes to.", "c", "");
+
+PARAM_INT("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
using namespace arma;
using namespace std;
@@ -79,85 +80,149 @@ int main(int argc, char* argv[])
else
RandomSeed((size_t) std::time(NULL));
- const double lambda = CLI::GetParam<double>("lambda");
-
- const string inputFile = CLI::GetParam<string>("input_file");
- const string dictionaryFile = CLI::GetParam<string>("dictionary_file");
- const string codesFile = CLI::GetParam<string>("codes_file");
- const string initialDictionaryFile =
- CLI::GetParam<string>("initial_dictionary");
-
- const size_t maxIterations = CLI::GetParam<int>("max_iterations");
- const size_t atoms = CLI::GetParam<int>("atoms");
+ // Check for parameter validity.
+ if (CLI::HasParam("input_model_file") && CLI::HasParam("initial_dictionary"))
+ Log::Fatal << "Cannot specify both --input_model_file (-m) and "
+ << "--initial_dictionary (-i)!" << endl;
- const bool normalize = CLI::HasParam("normalize");
+ if (CLI::HasParam("training_file") && !CLI::HasParam("atoms"))
+ Log::Fatal << "If --training_file is specified, the number of atoms in the "
+ << "dictionary must be specified with --atoms (-k)!" << endl;
- const double objTolerance = CLI::GetParam<double>("objective_tolerance");
+ if (!CLI::HasParam("training_file") && !CLI::HasParam("input_model_file"))
+ Log::Fatal << "One of --training_file (-t) or --input_model_file (-m) must "
+ << "be specified!" << endl;
- mat input;
- data::Load(inputFile, input, true);
+ if (!CLI::HasParam("codes_file") && !CLI::HasParam("dictionary_file") &&
+ !CLI::HasParam("output_model_file"))
+ Log::Warn << "Neither --codes_file (-c), --dictionary_file (-d), nor "
+ << "--output_model_file (-M) are specified; no output will be saved."
+ << endl;
- Log::Info << "Loaded " << input.n_cols << " point in " << input.n_rows
- << " dimensions." << endl;
+ if (CLI::HasParam("codes_file") && !CLI::HasParam("test_file"))
+ Log::Fatal << "--codes_file (-c) is specified, but no test matrix ("
+ << "specified with --test_file or -T) is given to encode!" << endl;
- // Normalize each point if the user asked for it.
- if (normalize)
+ if (!CLI::HasParam("training_file"))
{
- Log::Info << "Normalizing data before coding..." << endl;
- for (size_t i = 0; i < input.n_cols; ++i)
- input.col(i) /= norm(input.col(i), 2);
+ if (CLI::HasParam("atoms"))
+ Log::Warn << "--atoms (-k) ignored because --training_file (-t) is not "
+ << "specified." << endl;
+ if (CLI::HasParam("lambda"))
+ Log::Warn << "--lambda (-l) ignored because --training_file (-t) is not "
+ << "specified." << endl;
+ if (CLI::HasParam("initial_dictionary"))
+ Log::Warn << "--initial_dictionary (-i) ignored because --training_file "
+ << "(-t) is not specified." << endl;
+ if (CLI::HasParam("max_iterations"))
+ Log::Warn << "--max_iterations (-n) ignored because --training_file (-t) "
+ << "is not specified." << endl;
+ if (CLI::HasParam("normalize"))
+ Log::Warn << "--normalize (-N) ignored because --training_file (-t) is "
+ << "not specified." << endl;
+ if (CLI::HasParam("tolerance"))
+ Log::Warn << "--tolerance (-o) ignored because --training_file (-t) is "
+ << "not specified." << endl;
}
- // If there is an initial dictionary, be sure we do not initialize one.
- if (initialDictionaryFile != "")
+ // Do we have an existing model?
+ LocalCoordinateCoding lcc(0, 0.0);
+ if (CLI::HasParam("input_model_file"))
{
- LocalCoordinateCoding lcc(atoms, lambda, maxIterations, objTolerance);
+ data::Load(CLI::GetParam<string>("input_model_file"), "lcc_model", lcc,
+ true);
+ }
- // Load initial dictionary directly into LCC object.
- data::Load(initialDictionaryFile, lcc.Dictionary(), true);
+ if (CLI::HasParam("training_file"))
+ {
+ mat matX;
+ data::Load(CLI::GetParam<string>("training_file"), matX, true);
- // Validate size of initial dictionary.
- if (lcc.Dictionary().n_cols != atoms)
+ // Normalize each point if the user asked for it.
+ if (CLI::HasParam("normalize"))
{
- Log::Fatal << "The initial dictionary has " << lcc.Dictionary().n_cols
- << " atoms, but the number of atoms was specified to be " << atoms
- << "!" << endl;
+ Log::Info << "Normalizing data before coding..." << endl;
+ for (size_t i = 0; i < matX.n_cols; ++i)
+ matX.col(i) /= norm(matX.col(i), 2);
}
- if (lcc.Dictionary().n_rows != input.n_rows)
+ lcc.Lambda() = CLI::GetParam<double>("lambda");
+ lcc.Atoms() = (size_t) CLI::GetParam<int>("atoms");
+ lcc.MaxIterations() = (size_t) CLI::GetParam<int>("max_iterations");
+ lcc.Tolerance() = CLI::GetParam<double>("tolerance");
+
+ // Inform the user if we are overwriting their model.
+ if (CLI::HasParam("input_model_file"))
{
- Log::Fatal << "The initial dictionary has " << lcc.Dictionary().n_rows
- << " dimensions, but the data has " << input.n_rows << " dimensions!"
- << endl;
+ Log::Info << "Using dictionary from existing model in '"
+ << CLI::GetParam<string>("input_model_file") << "' as initial "
+ << "dictionary for training." << endl;
+ lcc.Train<NothingInitializer>(matX);
}
+ else if (CLI::HasParam("initial_dictionary"))
+ {
+ // Load initial dictionary directly into LCC object.
+ data::Load(CLI::GetParam<string>("initial_dictionary"), lcc.Dictionary(),
+ true);
+
+ // Validate the size of the initial dictionary.
+ if (lcc.Dictionary().n_cols != lcc.Atoms())
+ {
+ Log::Fatal << "The initial dictionary has " << lcc.Dictionary().n_cols
+ << " atoms, but the number of atoms was specified to be "
+ << lcc.Atoms() << "!" << endl;
+ }
+
+ if (lcc.Dictionary().n_rows != matX.n_rows)
+ {
+ Log::Fatal << "The initial dictionary has " << lcc.Dictionary().n_rows
+ << " dimensions, but the data has " << matX.n_rows << " dimensions!"
+ << endl;
+ }
+
+ // Train the model.
+ lcc.Train<NothingInitializer>(matX);
+ }
+ else
+ {
+ // Run with the default initialization.
+ lcc.Train(matX);
+ }
+ }
+
+ // Now, do we have any matrix to encode?
+ if (CLI::HasParam("test_file"))
+ {
+ mat matY;
+ data::Load(CLI::GetParam<string>("test_file"), matY, true);
- // Run LCC.
- lcc.Train<NothingInitializer>(input);
+ if (matY.n_rows != lcc.Dictionary().n_rows)
+ Log::Fatal << "Model was trained with a dimensionality of "
+ << lcc.Dictionary().n_rows << ", but data in test file "
+ << CLI::GetParam<string>("test_file") << " has a dimensionality of "
+ << matY.n_rows << "!" << endl;
- // Save the results.
- Log::Info << "Saving dictionary matrix to '" << dictionaryFile << "'.\n";
- data::Save(dictionaryFile, lcc.Dictionary());
+ // Normalize each point if the user asked for it.
+ if (CLI::HasParam("normalize"))
+ {
+ Log::Info << "Normalizing test data before coding..." << endl;
+ for (size_t i = 0; i < matY.n_cols; ++i)
+ matY.col(i) /= norm(matY.col(i), 2);
+ }
mat codes;
- lcc.Encode(input, codes);
+ lcc.Encode(matY, codes);
- Log::Info << "Saving sparse codes to '" << codesFile << "'.\n";
- data::Save(codesFile, codes);
+ if (CLI::HasParam("codes_file"))
+ data::Save(CLI::GetParam<string>("codes_file"), codes);
}
- else
- {
- // No initial dictionary.
- LocalCoordinateCoding lcc(input, atoms, lambda, maxIterations,
- objTolerance);
- // Save the results.
- Log::Info << "Saving dictionary matrix to '" << dictionaryFile << "'.\n";
- data::Save(dictionaryFile, lcc.Dictionary());
+ // Did the user want to save the dictionary?
+ if (CLI::HasParam("dictionary_file"))
+ data::Save(CLI::GetParam<string>("dictionary_file"), lcc.Dictionary());
- mat codes;
- lcc.Encode(input, codes);
-
- Log::Info << "Saving sparse codes to '" << codesFile << "'.\n";
- data::Save(codesFile, codes);
- }
+ // Did the user want to save the model?
+ if (CLI::HasParam("output_model_file"))
+ data::Save(CLI::GetParam<string>("output_model_file"), "lcc_model", lcc,
+ false); // Non-fatal on failure.
}
More information about the mlpack-git mailing list