[mlpack-git] master: Refactor main program to include QDAFN. (6c317b8)

gitdub at mlpack.org gitdub at mlpack.org
Tue Oct 25 03:15:17 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/31995784e651e1c17c988c79d9f53c9dbad620f8...81fce4edfc8bfb4c26b48ed388f559ec1cee26dd

>---------------------------------------------------------------

commit 6c317b8b9a4a2a9fea729222878a549bc72bc93b
Author: Ryan Curtin <ryan at ratml.org>
Date:   Tue Oct 25 16:15:17 2016 +0900

    Refactor main program to include QDAFN.


>---------------------------------------------------------------

6c317b8b9a4a2a9fea729222878a549bc72bc93b
 .../methods/approx_kfn/drusilla_select_main.cpp    | 289 ++++++++++++++++-----
 1 file changed, 226 insertions(+), 63 deletions(-)

diff --git a/src/mlpack/methods/approx_kfn/drusilla_select_main.cpp b/src/mlpack/methods/approx_kfn/drusilla_select_main.cpp
index 9e55ec7..4d6ef67 100644
--- a/src/mlpack/methods/approx_kfn/drusilla_select_main.cpp
+++ b/src/mlpack/methods/approx_kfn/drusilla_select_main.cpp
@@ -5,96 +5,259 @@
  * Command-line program for the SmartHash algorithm.
  */
 #include <mlpack/core.hpp>
-#include "smarthash_fn.hpp"
 #include <mlpack/methods/neighbor_search/neighbor_search.hpp>
+#include "drusilla_select.hpp"
+#include "qdafn.hpp"
 
-using namespace smarthash;
 using namespace mlpack;
+using namespace mlpack::neighbor;
 using namespace std;
 
-PROGRAM_INFO("Query-dependent approximate furthest neighbor search",
-    "This program implements the algorithm from the SISAP 2015 paper titled "
-    "'Approximate Furthest Neighbor in High Dimensions' by R. Pagh, F. "
-    "Silvestri, J. Sivertsen, and M. Skala.  Specify a reference set (set to "
-    "search in) with --reference_file, specify a query set (set to search for) "
-    "with --query_file, and specify algorithm parameters with --num_tables and "
-    "--num_projections (or don't, and defaults will be used).  Also specify "
-    "the number of points to search for with --k.  Each of those options has "
-    "short names too; see the detailed parameter documentation below."
+PROGRAM_INFO("Approximate furthest neighbor search",
+    "This program implements two strategies for furthest neighbor search. "
+    "These strategies are:"
+    "\n\n"
+    " - The 'qdafn' algorithm from 'Approximate Furthest Neighbor in High "
+    "Dimensions' by R. Pagh, F. Silvestri, J. Sivertsen, and M. Skala, in "
+    "Similarity Search and Applications 2015 (SISAP)."
+    "\n"
+    " - The 'DrusillaSelect' algorithm from 'Fast approximate furthest "
+    "neighbors with data-dependent candidate selection' by R.R. Curtin and "
+    "A.B. Gardner, in Similarity Search and Applications 2016 (SISAP)."
+    "\n\n"
+    "These two strategies give approximate results for the furthest neighbor "
+    "search problem and can be used as fast replacements for other furthest "
+    "neighbor techniques such as those found in the mlpack_kfn program.  Note "
+    "that typically, the 'ds' algorithm requires far fewer tables and "
+    "projections than the 'qdafn' algorithm."
+    "\n\n"
+    "Specify a reference set (set to search in) with --reference_file, "
+    "specify a query set with --query_file, and specify algorithm parameters "
+    "with --num_tables (-t) and --num_projections (-p) (or don't and defaults "
+    "will be used).  The algorithm to be used (either 'ds'---the default---or "
+    "'qdafn') may be specified with --algorithm.  Also specify the number of "
+    "neighbors to search for with --k.  Each of those options also has short "
+    "names; see the detailed parameter documentation below."
+    "\n\n"
+    "If no query file is specified, the reference set will be used as the "
+    "query set.  A model may be saved with --output_model_file (-M), and an "
+    "input model may be loaded instead of specifying a reference set with "
+    "--input_model_file (-m)."
     "\n\n"
     "Results for each query point are stored in the files specified by "
     "--neighbors_file and --distances_file.  This is in the same format as the "
-    "mlpack KFN and KNN programs: each row holds the k distances or neighbor "
-    "indices for each query point.");
+    "mlpack_kfn and mlpack_knn programs: each row holds the k distances or "
+    "neighbor indices for each query point.");
+
+PARAM_STRING_IN("reference_file", "File containing reference points.", "r", "");
+PARAM_STRING_IN("query_file", "File containing query points.", "q", "");
 
-PARAM_STRING_REQ("reference_file", "File containing reference points.", "r");
-PARAM_STRING_REQ("query_file", "File containing query points.", "q");
+// Model loading and saving.
+PARAM_STRING_IN("input_model_file", "File containing input model.", "m", "");
+PARAM_STRING_OUT("output_model_file", "File to save output model to.", "M");
 
-PARAM_INT_REQ("k", "Number of furthest neighbors to search for.", "k");
+PARAM_INT_IN("k", "Number of furthest neighbors to search for.", "k", 0);
 
-PARAM_INT("num_tables", "Number of hash tables to use.", "t", 10);
-PARAM_INT("num_projections", "Number of projections to use in each hash table.",
-    "p", 30);
+PARAM_INT_IN("num_tables", "Number of hash tables to use.", "t", 5);
+PARAM_INT_IN("num_projections", "Number of projections to use in each hash "
+    "table.", "p", 5);
+PARAM_STRING_IN("algorithm", "Algorithm to use: 'ds' or 'qdafn'.", "a", "ds");
 
-PARAM_STRING("neighbors_file", "File to save furthest neighbor indices to.",
+PARAM_STRING_IN("neighbors_file", "File to save furthest neighbor indices to.",
     "n", "");
-PARAM_STRING("distances_file", "File to save furthest neighbor distances to.",
+PARAM_STRING_IN("distances_file", "File to save furthest neighbor distances to.",
     "d", "");
 
 PARAM_FLAG("calculate_error", "If set, calculate the average distance error.",
     "e");
-PARAM_STRING("exact_distances_file", "File containing exact distances", "x", "");
+PARAM_STRING_IN("exact_distances_file", "File containing exact distances to "
+    "furthest neighbors; this can be used to avoid explicit calculation when "
+    "--calculate_error is set.", "x", "");
+
+// If we save a model we must also save what type it is.
+class ApproxKFNModel
+{
+ public:
+  //! Which model is in use: 0 for DrusillaSelect, 1 for QDAFN.
+  int type;
+  //! The two models; only the one selected by 'type' is meaningful.  They are
+  //! held by value (not type-erased) so that loading deserializes in place.
+  DrusillaSelect<> ds;
+  QDAFN<> qdafn;
+
+  //! Constructor; the (1, 1) models are placeholders until built or loaded.
+  ApproxKFNModel() : type(0), ds(1, 1), qdafn(1, 1) { /* Nothing to do. */ }
+
+  //! Serialize the model.
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */)
+  {
+    ar & data::CreateNVP(type, "type");
+    if (type == 0)
+      ar & data::CreateNVP(ds, "model");
+    else
+      ar & data::CreateNVP(qdafn, "model");
+  }
+};
 
 int main(int argc, char** argv)
 {
   CLI::ParseCommandLine(argc, argv);
 
-  const string referenceFile = CLI::GetParam<string>("reference_file");
-  const string queryFile = CLI::GetParam<string>("query_file");
-  const size_t k = (size_t) CLI::GetParam<int>("k");
-  const size_t numTables = (size_t) CLI::GetParam<int>("num_tables");
-  const size_t numProjections = (size_t) CLI::GetParam<int>("num_projections");
-
-  // Load the data.
-  arma::mat referenceData, queryData;
-  data::Load(referenceFile, referenceData, true);
-  data::Load(queryFile, queryData, true);
-
-  // Construct the object.
-  Timer::Start("smarthash_construct");
-  SmartHash<> q(referenceData, numTables, numProjections);
-  Timer::Stop("smarthash_construct");
-
-  // Do the search.
-  arma::Mat<size_t> neighbors;
-  arma::mat distances;
-  Timer::Start("smarthash_search");
-  q.Search(queryData, k, neighbors, distances);
-  Timer::Stop("smarthash_search");
-
-  if (CLI::HasParam("calculate_error"))
+  if (!CLI::HasParam("reference_file") && !CLI::HasParam("input_model_file"))
+    Log::Fatal << "Either --reference_file (-r) or --input_model_file (-m) must"
+        << " be specified!" << endl;
+  if (CLI::HasParam("reference_file") && CLI::HasParam("input_model_file"))
+    Log::Fatal << "Only one of --reference_file (-r) or --input_model_file (-m)"
+        << " can be specified!" << endl;
+  if (!CLI::HasParam("output_model_file") && !CLI::HasParam("k"))
+    Log::Warn << "Neither --output_model_file (-M) nor --k (-k) are specified;"
+        << " no task will be performed." << endl;
+  if (!CLI::HasParam("neighbors_file") && !CLI::HasParam("distances_file") &&
+      !CLI::HasParam("output_model_file"))
+    Log::Warn << "None of --output_model_file (-M), --neighbors_file (-n), or "
+        << "--distances_file (-d) are specified; no output will be saved!"
+        << endl;
+  if (CLI::GetParam<string>("algorithm") != "ds" &&
+      CLI::GetParam<string>("algorithm") != "qdafn")
+    Log::Fatal << "Unknown algorithm '" << CLI::GetParam<string>("algorithm")
+        << "'; must be 'ds' or 'qdafn'!" << endl;
+  if (CLI::HasParam("k") && !(CLI::HasParam("reference_file") ||
+                              CLI::HasParam("query_file")))
+    Log::Fatal << "If search is being performed, then either --query_file "
+        << "or --reference_file must be specified!" << endl;
+
+  if (CLI::GetParam<int>("num_tables") <= 0)
+    Log::Fatal << "Invalid --num_tables value ("
+        << CLI::GetParam<int>("num_tables") << "); must be greater than 0!"
+        << endl;
+  if (CLI::GetParam<int>("num_projections") <= 0)
+    Log::Fatal << "Invalid --num_projections value ("
+        << CLI::GetParam<int>("num_projections") << "); must be greater than 0!"
+        << endl;
+  if (CLI::HasParam("k") && CLI::GetParam<int>("k") <= 0)
+    Log::Fatal << "Invalid --k value (" << CLI::GetParam<int>("k")
+        << "); must be greater than 0!" << endl;
+
+  if (CLI::HasParam("calculate_error") && !CLI::HasParam("k"))
+    Log::Warn << "--calculate_error ignored because --k is not specified."
+        << endl;
+  if (CLI::HasParam("exact_distances_file") &&
+      !CLI::HasParam("calculate_error"))
+    Log::Warn << "--exact_distances_file ignored because --calculate_error is "
+        << "not specified." << endl;
+  if (CLI::HasParam("calculate_error") &&
+      !CLI::HasParam("exact_distances_file") &&
+      !CLI::HasParam("reference_file"))
+    Log::Fatal << "Cannot calculate error without either --exact_distances_file"
+        << " or --reference_file specified!" << endl;
+
+  // Do the building of a model, if necessary.
+  ApproxKFNModel m;
+  arma::mat referenceSet; // This may be used at query time.
+  if (CLI::HasParam("reference_file"))
+  {
+    const string referenceFile = CLI::GetParam<string>("reference_file");
+    data::Load(referenceFile, referenceSet, true);
+
+    const size_t numTables = (size_t) CLI::GetParam<int>("num_tables");
+    const size_t numProjections = (size_t) CLI::GetParam<int>("num_projections");
+    const string algorithm = CLI::GetParam<string>("algorithm");
+
+    if (algorithm == "ds")
+    {
+      Timer::Start("drusilla_select_construct");
+      Log::Info << "Building DrusillaSelect model..." << endl;
+      m.type = 0;
+      m.ds = DrusillaSelect<>(referenceSet, numTables, numProjections);
+      Timer::Stop("drusilla_select_construct");
+    }
+    else
+    {
+      Timer::Start("qdafn_construct");
+      Log::Info << "Building QDAFN model..." << endl;
+      m.type = 1;
+      m.qdafn = QDAFN<>(referenceSet, numTables, numProjections);
+      Timer::Stop("qdafn_construct");
+    }
+  }
+  else
   {
-//    neighbor::AllkFN kfn(referenceData);
+    // We must load the model from file.
+    const string inputModelFile = CLI::GetParam<string>("input_model_file");
+    data::Load(inputModelFile, "approx_kfn", m, true);
+  }
+
+  // Now, do we need to do any queries?
+  if (CLI::HasParam("k"))
+  {
+    const size_t k = (size_t) CLI::GetParam<int>("k");
+
+    arma::Mat<size_t> neighbors;
+    arma::mat distances;
+
+    // Search the query set if one was given, otherwise the reference set.
+    arma::mat querySet;
+    if (CLI::HasParam("query_file"))
+      data::Load(CLI::GetParam<string>("query_file"), querySet, true);
+    const arma::mat& set = CLI::HasParam("query_file") ? querySet :
+        referenceSet;
+
+    if (m.type == 0)
+    {
+      Timer::Start("drusilla_select_search");
+      m.ds.Search(set, k, neighbors, distances);
+      Timer::Stop("drusilla_select_search");
+    }
+    else
+    {
+      Timer::Start("qdafn_search");
+      m.qdafn.Search(set, k, neighbors, distances);
+      Timer::Stop("qdafn_search");
+    }
 
-//    arma::Mat<size_t> trueNeighbors;
-    arma::mat trueDistances;
-    data::Load(CLI::GetParam<string>("exact_distances_file"), trueDistances);
+    // Should we calculate error?
+    if (CLI::HasParam("calculate_error"))
+    {
+      arma::mat exactDistances;
+      if (CLI::HasParam("exact_distances_file"))
+      {
+        data::Load(CLI::GetParam<string>("exact_distances_file"),
+            exactDistances, true);
+        if (exactDistances.n_cols != distances.n_cols)
+          Log::Fatal << "--exact_distances_file has " << exactDistances.n_cols
+              << " columns, but there are " << distances.n_cols << " query "
+              << "points!" << endl;
+      }
+      else
+      {
+        // Calculate exact distances.  We are guaranteed the reference set is
+        // available.  Only the true furthest neighbor is needed.
+        AllkFN kfn(referenceSet);
+        arma::Mat<size_t> exactNeighbors;
+        kfn.Search(set, 1, exactNeighbors, exactDistances);
+      }
 
-//    kfn.Search(queryData, 1, trueNeighbors, trueDistances);
+      // The error for each query point is the ratio of the true furthest
+      // neighbor distance to the approximate furthest neighbor distance.
+      const arma::rowvec ratios = exactDistances.row(0) / distances.row(0);
 
-    const double averageError = arma::sum(trueDistances / distances.row(0)) /
-        distances.n_cols;
-    const double minError = arma::min(trueDistances / distances.row(0));
-    const double maxError = arma::max(trueDistances / distances.row(0));
+      const double averageError = arma::mean(ratios);
+      const double minError = arma::min(ratios);
+      const double maxError = arma::max(ratios);
+      Log::Info << "Average error: " << averageError << "." << endl;
+      Log::Info << "Maximum error: " << maxError << "." << endl;
+      Log::Info << "Minimum error: " << minError << "." << endl;
+    }
 
-    Log::Info << "Average error: " << averageError << "." << endl;
-    Log::Info << "Maximum error: " << maxError << "." << endl;
-    Log::Info << "Minimum error: " << minError << "." << endl;
+    // Save results, if desired.
+    if (CLI::HasParam("neighbors_file"))
+      data::Save(CLI::GetParam<string>("neighbors_file"), neighbors, false);
+    if (CLI::HasParam("distances_file"))
+      data::Save(CLI::GetParam<string>("distances_file"), distances, false);
   }
 
-  // Save the results.
-  if (CLI::HasParam("neighbors_file"))
-    data::Save(CLI::GetParam<string>("neighbors_file"), neighbors);
-  if (CLI::HasParam("distances_file"))
-    data::Save(CLI::GetParam<string>("distances_file"), distances);
+  // Should we save the model?
+  if (CLI::HasParam("output_model_file"))
+    data::Save(CLI::GetParam<string>("output_model_file"), "approx_kfn", m);
 }




More information about the mlpack-git mailing list