[mlpack-git] master: Refactor main program to allow saving/loading LSH models. (6f16ab3)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Mon Nov 30 17:24:09 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/10b9d45b806a3e879b0564d78ccb183ebc7051ba...31c557d9cc7e4da57fd8a246085c19e076d12271
>---------------------------------------------------------------
commit 6f16ab373eec67c8767be55cc9ca4e25d24c1625
Author: Ryan Curtin <ryan at ratml.org>
Date: Fri Nov 20 23:44:24 2015 +0000
Refactor main program to allow saving/loading LSH models.
>---------------------------------------------------------------
6f16ab373eec67c8767be55cc9ca4e25d24c1625
src/mlpack/methods/lsh/lsh_main.cpp | 138 ++++++++++++++++++++++++------------
1 file changed, 91 insertions(+), 47 deletions(-)
diff --git a/src/mlpack/methods/lsh/lsh_main.cpp b/src/mlpack/methods/lsh/lsh_main.cpp
index 6fe63e3..0d12a11 100644
--- a/src/mlpack/methods/lsh/lsh_main.cpp
+++ b/src/mlpack/methods/lsh/lsh_main.cpp
@@ -44,13 +44,17 @@ PROGRAM_INFO("All K-Approximate-Nearest-Neighbor Search with LSH",
"set the random seed.");
// Define our input parameters that this program will take.
-PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
- "r");
+PARAM_STRING("reference_file", "File containing the reference dataset.", "r",
+ "");
PARAM_STRING("distances_file", "File to output distances into.", "d", "");
PARAM_STRING("neighbors_file", "File to output neighbors into.", "n", "");
-PARAM_INT_REQ("k", "Number of nearest neighbors to find.", "k");
+// We can load or save models.
+PARAM_STRING("input_model_file", "File to load LSH model from. (Cannot be "
+ "specified with --reference_file.)", "m", "");
+PARAM_STRING("output_model_file", "File to save LSH model to.", "M", "");
+PARAM_INT("k", "Number of nearest neighbors to find.", "k", 0);
PARAM_STRING("query_file", "File containing query points (optional).", "q", "");
PARAM_INT("projections", "The number of hash functions for each table", "K",
@@ -59,7 +63,7 @@ PARAM_INT("tables", "The number of hash tables to be used.", "L", 30);
PARAM_DOUBLE("hash_width", "The hash width for the first-level hashing in the "
"LSH preprocessing. By default, the LSH class automatically estimates a "
"hash width for its use.", "H", 0.0);
-PARAM_INT("second_hash_size", "The size of the second level hash table.", "M",
+PARAM_INT("second_hash_size", "The size of the second level hash table.", "S",
99901);
PARAM_INT("bucket_size", "The size of a bucket in the second level hash.", "B",
500);
@@ -76,30 +80,56 @@ int main(int argc, char *argv[])
math::RandomSeed((size_t) time(NULL));
// Get all the parameters.
- string referenceFile = CLI::GetParam<string>("reference_file");
- string distancesFile = CLI::GetParam<string>("distances_file");
- string neighborsFile = CLI::GetParam<string>("neighbors_file");
+ const string referenceFile = CLI::GetParam<string>("reference_file");
+ const string distancesFile = CLI::GetParam<string>("distances_file");
+ const string neighborsFile = CLI::GetParam<string>("neighbors_file");
+ const string inputModelFile = CLI::GetParam<string>("input_model_file");
+ const string outputModelFile = CLI::GetParam<string>("output_model_file");
size_t k = CLI::GetParam<int>("k");
size_t secondHashSize = CLI::GetParam<int>("second_hash_size");
size_t bucketSize = CLI::GetParam<int>("bucket_size");
- arma::mat referenceData;
- arma::mat queryData; // So it doesn't go out of scope.
- data::Load(referenceFile, referenceData, true);
+ if (CLI::HasParam("input_model_file") && CLI::HasParam("reference_file"))
+ {
+ Log::Fatal << "Cannot specify both --reference_file and --input_model_file!"
+ << " Either create a new model with --reference_file or use an existing"
+ << " model with --input_model_file." << endl;
+ }
+
+ if (!CLI::HasParam("input_model_file") && !CLI::HasParam("reference_file"))
+ {
+ Log::Fatal << "Must specify either --input_model_file or --reference_file!"
+ << endl;
+ }
- Log::Info << "Loaded reference data from '" << referenceFile << "' ("
- << referenceData.n_rows << " x " << referenceData.n_cols << ")." << endl;
+ if (!CLI::HasParam("neighbors_file") && !CLI::HasParam("distances_file") &&
+ !CLI::HasParam("output_model_file"))
+ {
+ Log::Warn << "Neither --neighbors_file, --distances_file, nor "
+ << "--output_model_file are specified; no results will be saved."
+ << endl;
+ }
- // Sanity check on k value: must be greater than 0, must be less than the
- // number of reference points.
- if (k > referenceData.n_cols)
+ if ((CLI::HasParam("query_file") && !CLI::HasParam("k")) ||
+ (!CLI::HasParam("query_file") && CLI::HasParam("k")))
{
- Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
- Log::Fatal << "than or equal to the number of reference points (";
- Log::Fatal << referenceData.n_cols << ")." << endl;
+ Log::Fatal << "Both --query_file and --k must be specified if search is to "
+ << "be done!" << endl;
}
+ if (CLI::HasParam("input_model_file") && CLI::HasParam("k") &&
+ !CLI::HasParam("query_file"))
+ {
+ Log::Info << "Performing LSH-based approximate nearest neighbor search on "
+ << "the reference dataset in the model stored in '" << inputModelFile
+ << "'." << endl;
+ }
+
+ // These declarations are here so that the matrices don't go out of scope.
+ arma::mat referenceData;
+ arma::mat queryData;
+
// Pick up the LSH-specific parameters.
const size_t numProj = CLI::GetParam<int>("projections");
const size_t numTables = CLI::GetParam<int>("tables");
@@ -108,15 +138,6 @@ int main(int argc, char *argv[])
arma::Mat<size_t> neighbors;
arma::mat distances;
- if (CLI::GetParam<string>("query_file") != "")
- {
- string queryFile = CLI::GetParam<string>("query_file");
-
- data::Load(queryFile, queryData, true);
- Log::Info << "Loaded query data from '" << queryFile << "' ("
- << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
- }
-
if (hashWidth == 0.0)
Log::Info << "Using LSH with " << numProj << " projections (K) and " <<
numTables << " tables (L) with default hash width." << endl;
@@ -124,30 +145,53 @@ int main(int argc, char *argv[])
Log::Info << "Using LSH with " << numProj << " projections (K) and " <<
numTables << " tables (L) with hash width(r): " << hashWidth << endl;
- Timer::Start("hash_building");
-
- LSHSearch<>* allkann;
-
- allkann = new LSHSearch<>(referenceData, numProj, numTables, hashWidth,
- secondHashSize, bucketSize);
-
- Timer::Stop("hash_building");
+ LSHSearch<> allkann;
+ if (CLI::HasParam("reference_file"))
+ {
+ data::Load(referenceFile, referenceData, true);
+ Log::Info << "Loaded reference data from '" << referenceFile << "' ("
+ << referenceData.n_rows << " x " << referenceData.n_cols << ")."
+ << endl;
+
+ Timer::Start("hash_building");
+ allkann.Train(referenceData, numProj, numTables, hashWidth, secondHashSize,
+ bucketSize);
+ Timer::Stop("hash_building");
+ }
+ else if (CLI::HasParam("input_model_file"))
+ {
+ data::Load(inputModelFile, "lsh_model", allkann, true); // Fatal on fail.
+ }
- Log::Info << "Computing " << k << " distance approximate nearest neighbors "
- << endl;
- if (CLI::HasParam("query_file"))
- allkann->Search(queryData, k, neighbors, distances);
- else
- allkann->Search(k, neighbors, distances);
+ if (CLI::HasParam("k"))
+ {
+ Log::Info << "Computing " << k << " distance approximate nearest neighbors."
+ << endl;
+ if (CLI::HasParam("query_file"))
+ {
+ if (CLI::GetParam<string>("query_file") != "")
+ {
+ string queryFile = CLI::GetParam<string>("query_file");
+
+ data::Load(queryFile, queryData, true);
+ Log::Info << "Loaded query data from '" << queryFile << "' ("
+ << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+ }
+ allkann.Search(queryData, k, neighbors, distances);
+ }
+ else
+ {
+ allkann.Search(k, neighbors, distances);
+ }
+ }
Log::Info << "Neighbors computed." << endl;
- // Save output.
- if (distancesFile != "")
+ // Save output, if desired.
+ if (CLI::HasParam("distances_file"))
data::Save(distancesFile, distances);
-
- if (neighborsFile != "")
+ if (CLI::HasParam("neighbors_file"))
data::Save(neighborsFile, neighbors);
-
- delete allkann;
+ if (CLI::HasParam("output_model_file"))
+ data::Save(outputModelFile, "lsh_model", allkann);
}
More information about the mlpack-git mailing list