[mlpack-git] master: Refactor main program to allow saving/loading LSH models. (6f16ab3)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Mon Nov 30 17:24:09 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/10b9d45b806a3e879b0564d78ccb183ebc7051ba...31c557d9cc7e4da57fd8a246085c19e076d12271

>---------------------------------------------------------------

commit 6f16ab373eec67c8767be55cc9ca4e25d24c1625
Author: Ryan Curtin <ryan at ratml.org>
Date:   Fri Nov 20 23:44:24 2015 +0000

    Refactor main program to allow saving/loading LSH models.


>---------------------------------------------------------------

6f16ab373eec67c8767be55cc9ca4e25d24c1625
 src/mlpack/methods/lsh/lsh_main.cpp | 138 ++++++++++++++++++++++++------------
 1 file changed, 91 insertions(+), 47 deletions(-)

diff --git a/src/mlpack/methods/lsh/lsh_main.cpp b/src/mlpack/methods/lsh/lsh_main.cpp
index 6fe63e3..0d12a11 100644
--- a/src/mlpack/methods/lsh/lsh_main.cpp
+++ b/src/mlpack/methods/lsh/lsh_main.cpp
@@ -44,13 +44,17 @@ PROGRAM_INFO("All K-Approximate-Nearest-Neighbor Search with LSH",
     "set the random seed.");
 
 // Define our input parameters that this program will take.
-PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
-    "r");
+PARAM_STRING("reference_file", "File containing the reference dataset.", "r",
+    "");
 PARAM_STRING("distances_file", "File to output distances into.", "d", "");
 PARAM_STRING("neighbors_file", "File to output neighbors into.", "n", "");
 
-PARAM_INT_REQ("k", "Number of nearest neighbors to find.", "k");
+// We can load or save models.
+PARAM_STRING("input_model_file", "File to load LSH model from.  (Cannot be "
+    "specified with --reference_file.)", "m", "");
+PARAM_STRING("output_model_file", "File to save LSH model to.", "M", "");
 
+PARAM_INT("k", "Number of nearest neighbors to find.", "k", 0);
 PARAM_STRING("query_file", "File containing query points (optional).", "q", "");
 
 PARAM_INT("projections", "The number of hash functions for each table", "K",
@@ -59,7 +63,7 @@ PARAM_INT("tables", "The number of hash tables to be used.", "L", 30);
 PARAM_DOUBLE("hash_width", "The hash width for the first-level hashing in the "
     "LSH preprocessing. By default, the LSH class automatically estimates a "
     "hash width for its use.", "H", 0.0);
-PARAM_INT("second_hash_size", "The size of the second level hash table.", "M",
+PARAM_INT("second_hash_size", "The size of the second level hash table.", "S",
     99901);
 PARAM_INT("bucket_size", "The size of a bucket in the second level hash.", "B",
     500);
@@ -76,30 +80,56 @@ int main(int argc, char *argv[])
     math::RandomSeed((size_t) time(NULL));
 
   // Get all the parameters.
-  string referenceFile = CLI::GetParam<string>("reference_file");
-  string distancesFile = CLI::GetParam<string>("distances_file");
-  string neighborsFile = CLI::GetParam<string>("neighbors_file");
+  const string referenceFile = CLI::GetParam<string>("reference_file");
+  const string distancesFile = CLI::GetParam<string>("distances_file");
+  const string neighborsFile = CLI::GetParam<string>("neighbors_file");
+  const string inputModelFile = CLI::GetParam<string>("input_model_file");
+  const string outputModelFile = CLI::GetParam<string>("output_model_file");
 
   size_t k = CLI::GetParam<int>("k");
   size_t secondHashSize = CLI::GetParam<int>("second_hash_size");
   size_t bucketSize = CLI::GetParam<int>("bucket_size");
 
-  arma::mat referenceData;
-  arma::mat queryData; // So it doesn't go out of scope.
-  data::Load(referenceFile, referenceData, true);
+  if (CLI::HasParam("input_model_file") && CLI::HasParam("reference_file"))
+  {
+    Log::Fatal << "Cannot specify both --reference_file and --input_model_file!"
+        << " Either create a new model with --reference_file or use an existing"
+        << " model with --input_model_file." << endl;
+  }
+
+  if (!CLI::HasParam("input_model_file") && !CLI::HasParam("reference_file"))
+  {
+    Log::Fatal << "Must specify either --input_model_file or --reference_file!"
+        << endl;
+  }
 
-  Log::Info << "Loaded reference data from '" << referenceFile << "' ("
-      << referenceData.n_rows << " x " << referenceData.n_cols << ")." << endl;
+  if (!CLI::HasParam("neighbors_file") && !CLI::HasParam("distances_file") &&
+      !CLI::HasParam("output_model_file"))
+  {
+    Log::Warn << "Neither --neighbors_file, --distances_file, nor "
+        << "--output_model_file are specified; no results will be saved."
+        << endl;
+  }
 
-  // Sanity check on k value: must be greater than 0, must be less than the
-  // number of reference points.
-  if (k > referenceData.n_cols)
+  if ((CLI::HasParam("query_file") && !CLI::HasParam("k")) ||
+      (!CLI::HasParam("query_file") && CLI::HasParam("k")))
   {
-    Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
-    Log::Fatal << "than or equal to the number of reference points (";
-    Log::Fatal << referenceData.n_cols << ")." << endl;
+    Log::Fatal << "Both --query_file and --k must be specified if search is to "
+        << "be done!" << endl;
   }
 
+  if (CLI::HasParam("input_model_file") && CLI::HasParam("k") &&
+      !CLI::HasParam("query_file"))
+  {
+    Log::Info << "Performing LSH-based approximate nearest neighbor search on "
+        << "the reference dataset in the model stored in '" << inputModelFile
+        << "'." << endl;
+  }
+
+  // These declarations are here so that the matrices don't go out of scope.
+  arma::mat referenceData;
+  arma::mat queryData;
+
   // Pick up the LSH-specific parameters.
   const size_t numProj = CLI::GetParam<int>("projections");
   const size_t numTables = CLI::GetParam<int>("tables");
@@ -108,15 +138,6 @@ int main(int argc, char *argv[])
   arma::Mat<size_t> neighbors;
   arma::mat distances;
 
-  if (CLI::GetParam<string>("query_file") != "")
-  {
-    string queryFile = CLI::GetParam<string>("query_file");
-
-    data::Load(queryFile, queryData, true);
-    Log::Info << "Loaded query data from '" << queryFile << "' ("
-              << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
-  }
-
   if (hashWidth == 0.0)
     Log::Info << "Using LSH with " << numProj << " projections (K) and " <<
         numTables << " tables (L) with default hash width." << endl;
@@ -124,30 +145,53 @@ int main(int argc, char *argv[])
     Log::Info << "Using LSH with " << numProj << " projections (K) and " <<
         numTables << " tables (L) with hash width(r): " << hashWidth << endl;
 
-  Timer::Start("hash_building");
-
-  LSHSearch<>* allkann;
-
-  allkann = new LSHSearch<>(referenceData, numProj, numTables, hashWidth,
-                            secondHashSize, bucketSize);
-
-  Timer::Stop("hash_building");
+  LSHSearch<> allkann;
+  if (CLI::HasParam("reference_file"))
+  {
+    data::Load(referenceFile, referenceData, true);
+    Log::Info << "Loaded reference data from '" << referenceFile << "' ("
+        << referenceData.n_rows << " x " << referenceData.n_cols << ")."
+        << endl;
+
+    Timer::Start("hash_building");
+    allkann.Train(referenceData, numProj, numTables, hashWidth, secondHashSize,
+        bucketSize);
+    Timer::Stop("hash_building");
+  }
+  else if (CLI::HasParam("input_model_file"))
+  {
+    data::Load(inputModelFile, "lsh_model", allkann, true); // Fatal on fail.
+  }
 
-  Log::Info << "Computing " << k << " distance approximate nearest neighbors "
-      << endl;
-  if (CLI::HasParam("query_file"))
-    allkann->Search(queryData, k, neighbors, distances);
-  else
-    allkann->Search(k, neighbors, distances);
+  if (CLI::HasParam("k"))
+  {
+    Log::Info << "Computing " << k << " distance approximate nearest neighbors."
+        << endl;
+    if (CLI::HasParam("query_file"))
+    {
+      if (CLI::GetParam<string>("query_file") != "")
+      {
+        string queryFile = CLI::GetParam<string>("query_file");
+
+        data::Load(queryFile, queryData, true);
+        Log::Info << "Loaded query data from '" << queryFile << "' ("
+            << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+      }
+      allkann.Search(queryData, k, neighbors, distances);
+    }
+    else
+    {
+      allkann.Search(k, neighbors, distances);
+    }
+  }
 
   Log::Info << "Neighbors computed." << endl;
 
-  // Save output.
-  if (distancesFile != "")
+  // Save output, if desired.
+  if (CLI::HasParam("distances_file"))
     data::Save(distancesFile, distances);
-
-  if (neighborsFile != "")
+  if (CLI::HasParam("neighbors_file"))
     data::Save(neighborsFile, neighbors);
-
-  delete allkann;
+  if (CLI::HasParam("output_model_file"))
+    data::Save(outputModelFile, "lsh_model", allkann);
 }



More information about the mlpack-git mailing list