[mlpack-git] master: Refactor main program to include QDAFN. (6c317b8)
gitdub at mlpack.org
gitdub at mlpack.org
Tue Oct 25 03:15:17 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/31995784e651e1c17c988c79d9f53c9dbad620f8...81fce4edfc8bfb4c26b48ed388f559ec1cee26dd
>---------------------------------------------------------------
commit 6c317b8b9a4a2a9fea729222878a549bc72bc93b
Author: Ryan Curtin <ryan at ratml.org>
Date: Tue Oct 25 16:15:17 2016 +0900
Refactor main program to include QDAFN.
>---------------------------------------------------------------
6c317b8b9a4a2a9fea729222878a549bc72bc93b
.../methods/approx_kfn/drusilla_select_main.cpp | 289 ++++++++++++++++-----
1 file changed, 226 insertions(+), 63 deletions(-)
diff --git a/src/mlpack/methods/approx_kfn/drusilla_select_main.cpp b/src/mlpack/methods/approx_kfn/drusilla_select_main.cpp
index 9e55ec7..4d6ef67 100644
--- a/src/mlpack/methods/approx_kfn/drusilla_select_main.cpp
+++ b/src/mlpack/methods/approx_kfn/drusilla_select_main.cpp
@@ -5,96 +5,259 @@
* Command-line program for approximate furthest neighbor search (DrusillaSelect and QDAFN).
*/
#include <mlpack/core.hpp>
-#include "smarthash_fn.hpp"
#include <mlpack/methods/neighbor_search/neighbor_search.hpp>
+#include "drusilla_select.hpp"
+#include "qdafn.hpp"
-using namespace smarthash;
using namespace mlpack;
+using namespace mlpack::neighbor;
using namespace std;
-PROGRAM_INFO("Query-dependent approximate furthest neighbor search",
- "This program implements the algorithm from the SISAP 2015 paper titled "
- "'Approximate Furthest Neighbor in High Dimensions' by R. Pagh, F. "
- "Silvestri, J. Sivertsen, and M. Skala. Specify a reference set (set to "
- "search in) with --reference_file, specify a query set (set to search for) "
- "with --query_file, and specify algorithm parameters with --num_tables and "
- "--num_projections (or don't, and defaults will be used). Also specify "
- "the number of points to search for with --k. Each of those options has "
- "short names too; see the detailed parameter documentation below."
+PROGRAM_INFO("Approximate furthest neighbor search",
+ "This program implements two strategies for furthest neighbor search. "
+ "These strategies are:"
+ "\n\n"
+ " - The 'qdafn' algorithm from 'Approximate Furthest Neighbor in High "
+ "Dimensions' by R. Pagh, F. Silvestri, J. Sivertsen, and M. Skala, in "
+ "Similarity Search and Applications 2015 (SISAP)."
+ "\n"
+ " - The 'DrusillaSelect' algorithm from 'Fast approximate furthest "
+ "neighbors with data-dependent candidate selection' by R.R. Curtin and A.B."
+ " Gardner, in Similarity Search and Applications 2016 (SISAP)."
+ "\n\n"
+ "These two strategies give approximate results for the furthest neighbor "
+ "search problem and can be used as fast replacements for other furthest "
+ "neighbor techniques such as those found in the mlpack_kfn program. Note "
+ "that typically, the 'ds' algorithm requires far fewer tables and "
+ "projections than the 'qdafn' algorithm."
+ "\n\n"
+ "Specify a reference set (set to search in) with --reference_file, "
+ "specify a query set with --query_file, and specify algorithm parameters "
+ "with --num_tables (-l) and --num_projections (-p) (or don't and defaults "
+ "will be used). The algorithm to be used (either 'ds'---the default---or "
+ "'qdafn') may be specified with --algorithm. Also specify the number of "
+ "neighbors to search for with --k. Each of those options also has short "
+ "names; see the detailed parameter documentation below."
+ "\n\n"
+ "If no query file is specified, the reference set will be used as the "
+ "query set. A model may be saved with --output_model_file (-M), and an "
+ "input model may be loaded instead of specifying a reference set with "
+ "--input_model_file (-m)."
"\n\n"
"Results for each query point are stored in the files specified by "
"--neighbors_file and --distances_file. This is in the same format as the "
- "mlpack KFN and KNN programs: each row holds the k distances or neighbor "
- "indices for each query point.");
+ "mlpack_kfn and mlpack_knn programs: each row holds the k distances or "
+ "neighbor indices for each query point.");
+
+PARAM_STRING_IN("reference_file", "File containing reference points.", "r", "");
+PARAM_STRING_IN("query_file", "File containing query points.", "q", "");
-PARAM_STRING_REQ("reference_file", "File containing reference points.", "r");
-PARAM_STRING_REQ("query_file", "File containing query points.", "q");
+// Model loading and saving.
+PARAM_STRING_IN("input_model_file", "File containing input model.", "m", "");
+PARAM_STRING_OUT("output_model_file", "File to save output model to.", "M", "");
-PARAM_INT_REQ("k", "Number of furthest neighbors to search for.", "k");
+// Search size. main() only searches when CLI::HasParam("k") is true, so the
+// default value given here is a placeholder that is never searched with.
+PARAM_INT_IN("k", "Number of furthest neighbors to search for.", "k", 0);
-PARAM_INT("num_tables", "Number of hash tables to use.", "t", 10);
-PARAM_INT("num_projections", "Number of projections to use in each hash table.",
- "p", 30);
+// Algorithm parameters. The short name of --num_projections is "p" because
+// "m" is already taken by --input_model_file above.
+PARAM_INT_IN("num_tables", "Number of hash tables to use.", "l", 5);
+PARAM_INT_IN("num_projections", "Number of projections to use in each hash "
+ "table.", "p", 5);
+PARAM_STRING_IN("algorithm", "Algorithm to use: 'ds' or 'qdafn'.", "a", "ds");
-PARAM_STRING("neighbors_file", "File to save furthest neighbor indices to.",
+PARAM_STRING_IN("neighbors_file", "File to save furthest neighbor indices to.",
"n", "");
-PARAM_STRING("distances_file", "File to save furthest neighbor distances to.",
+PARAM_STRING_IN("distances_file", "File to save furthest neighbor distances to.",
"d", "");
PARAM_FLAG("calculate_error", "If set, calculate the average distance error.",
"e");
-PARAM_STRING("exact_distances_file", "File containing exact distances", "x", "");
+PARAM_STRING_IN("exact_distances_file", "File containing exact distances to "
+ "furthest neighbors; this can be used to avoid explicit calculation when "
+ "--calculate_error is set.", "x", "");
+
+// If we save a model we must also save what type it is.
+/**
+ * Holder for a trained approximate furthest neighbor model, tagged with which
+ * algorithm produced it so that the right type can be restored on load.
+ *
+ *  - type == 0: `model` holds a DrusillaSelect<>.
+ *  - type == 1: `model` holds a QDAFN<>.
+ */
+class ApproxKFNModel
+{
+ public:
+  //! Which algorithm `model` holds (0 = DrusillaSelect, 1 = QDAFN).
+  int type;
+  //! The trained model itself.
+  boost::any model;
+
+  //! Constructor, which does nothing.
+  ApproxKFNModel() : type(0) { /* Nothing to do. */ }
+
+  //! Serialize the model.
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */)
+  {
+    ar & data::CreateNVP(type, "type");
+
+    // When loading, `model` is still empty, so an (untrained) object of the
+    // right type must exist before the archive can fill it in; otherwise the
+    // any_cast below would throw boost::bad_any_cast.
+    // NOTE(review): relies on DrusillaSelect<>(l, m) and QDAFN<>(l, m)
+    // constructors that take no reference set -- confirm they exist.
+    if (Archive::is_loading::value)
+    {
+      if (type == 0)
+        model = DrusillaSelect<>(1, 1);
+      else
+        model = QDAFN<>(1, 1);
+    }
+
+    // Cast to a reference so that the held object itself is (de)serialized,
+    // rather than a temporary copy of it.
+    if (type == 0)
+      ar & data::CreateNVP(boost::any_cast<DrusillaSelect<>&>(model), "model");
+    else
+      ar & data::CreateNVP(boost::any_cast<QDAFN<>&>(model), "model");
+  }
+};
int main(int argc, char** argv)
{
CLI::ParseCommandLine(argc, argv);
- const string referenceFile = CLI::GetParam<string>("reference_file");
- const string queryFile = CLI::GetParam<string>("query_file");
- const size_t k = (size_t) CLI::GetParam<int>("k");
- const size_t numTables = (size_t) CLI::GetParam<int>("num_tables");
- const size_t numProjections = (size_t) CLI::GetParam<int>("num_projections");
-
- // Load the data.
- arma::mat referenceData, queryData;
- data::Load(referenceFile, referenceData, true);
- data::Load(queryFile, queryData, true);
-
- // Construct the object.
- Timer::Start("smarthash_construct");
- SmartHash<> q(referenceData, numTables, numProjections);
- Timer::Stop("smarthash_construct");
-
- // Do the search.
- arma::Mat<size_t> neighbors;
- arma::mat distances;
- Timer::Start("smarthash_search");
- q.Search(queryData, k, neighbors, distances);
- Timer::Stop("smarthash_search");
-
- if (CLI::HasParam("calculate_error"))
+ if (!CLI::HasParam("reference_file") && !CLI::HasParam("input_model_file"))
+ Log::Fatal << "Either --reference_file (-r) or --input_model_file (-m) must"
+ << " be specified!" << endl;
+ if (CLI::HasParam("reference_file") && CLI::HasParam("input_model_file"))
+ Log::Fatal << "Only one of --reference_file (-r) or --input_model_file (-m)"
+ << " can be specified!" << endl;
+ if (!CLI::HasParam("output_model_file") && !CLI::HasParam("k"))
+ Log::Warn << "Neither --output_model_file (-M) nor --k (-k) are specified;"
+ << " no task will be performed." << endl;
+ if (!CLI::HasParam("neighbors_file") && !CLI::HasParam("distances_file") &&
+ !CLI::HasParam("output_model_file"))
+ Log::Warn << "None of --output_model_file (-M), --neighbors_file (-n), or "
+ << "--distances_file (-d) are specified; no output will be saved!"
+ << endl;
+ if (CLI::GetParam<string>("algorithm") != "ds" &&
+ CLI::GetParam<string>("algorithm") != "qdafn")
+ Log::Fatal << "Unknown algorithm '" << CLI::GetParam<string>("algorithm")
+ << "'; must be 'ds' or 'qdafn'!" << endl;
+ if (CLI::HasParam("k") && !(CLI::HasParam("reference_file") ||
+ CLI::HasParam("query_file")))
+ Log::Fatal << "If search is being performed, then either --query_file "
+ << "or --reference_file must be specified!" << endl;
+
+ if (CLI::GetParam<int>("num_tables") <= 0)
+ Log::Fatal << "Invalid --num_tables value ("
+ << CLI::GetParam<int>("num_tables") << "); must be greater than 0!"
+ << endl;
+ if (CLI::GetParam<int>("num_projections") <= 0)
+ Log::Fatal << "Invalid --num_projections value ("
+ << CLI::GetParam<int>("num_projections") << "); must be greater than 0!"
+ << endl;
+
+ if (CLI::HasParam("calculate_error") && !CLI::HasParam("k"))
+ Log::Warn << "--calculate_error ignored because --k is not specified."
+ << endl;
+ if (CLI::HasParam("exact_distances_file") &&
+ !CLI::HasParam("calculate_error"))
+ Log::Warn << "--exact_distances_file ignored beceause --calculate_error is "
+ << "not specified." << endl;
+ if (CLI::HasParam("calculate_error") &&
+ !CLI::HasParam("exact_distances_file") &&
+ !CLI::HasParam("reference_file"))
+ Log::Fatal << "Cannot calculate error without either --exact_distances_file"
+ << " or --reference_file specified!" << endl;
+
+ // Do the building of a model, if necessary.
+ ApproxKFNModel m;
+ arma::mat referenceSet; // This may be used at query time.
+ if (CLI::HasParam("reference_file"))
+ {
+ const string referenceFile = CLI::GetParam<string>("reference_file");
+ data::Load(referenceFile, referenceSet);
+
+ const size_t numTables = (size_t) CLI::GetParam<int>("num_tables");
+ const size_t numProjections = (size_t) CLI::GetParam<int>("num_projections");
+ const string algorithm = CLI::GetParam<string>("algorithm");
+
+ if (algorithm == "ds")
+ {
+ Timer::Start("drusilla_select_construct");
+ Log::Info << "Building DrusillaSelect model..." << endl;
+ m.type = 0;
+ m.model = boost::any(DrusillaSelect<>(referenceSet, numTables,
+ numProjections));
+ Timer::Stop("drusilla_select_construct");
+ }
+ else
+ {
+ Timer::Start("qdafn_construct");
+ Log::Info << "Building QDAFN model..." << endl;
+ m.type = 1;
+ m.model = boost::any(QDAFN<>(referenceSet, numTables, numProjections));
+ Timer::Stop("qdafn_construct");
+ }
+ }
+ else
{
-// neighbor::AllkFN kfn(referenceData);
+ // We must load the model from file.
+ const string inputModelFile = CLI::GetParam<string>("input_model_file");
+ data::Load(inputModelFile, m);
+ }
+
+ // Now, do we need to do any queries?
+ if (CLI::HasParam("k"))
+ {
+ const size_t k = (size_t) CLI::GetParam<int>("k");
+
+ arma::Mat<size_t> neighbors;
+ arma::mat distances;
+
+ if (CLI::HasParam("query_file"))
+ {
+ const string queryFile = CLI::GetParam<string>("query_file");
+ arma::mat querySet;
+ data::Load(querySet, queryFile);
+
+ if (m.type == 0)
+ {
+ Timer::Start("drusilla_select_search");
+ boost::any_cast<DrusillaSelect<>>(m.model).Search(querySet, k,
+ neighbors, distances);
+ Timer::Stop("drusilla_select_search");
+ }
+ else
+ {
+ Timer::Start("qdafn_search");
+ boost::any_cast<QDAFN<>>(m.model).Search(querySet, k, neighbors,
+ distances);
+ Timer::Stop("qdafn_search");
+ }
+ }
+ else
+ {
+ // We will do search with the reference set.
+ if (m.type == 0)
+ boost::any_cast<DrusillaSelect<>>(m.model).Search(k, neighbors,
+ distances);
+ else
+ boost::any_cast<QDAFN<>>(m.model).Search(k, neighbors, distances);
+ }
-// arma::Mat<size_t> trueNeighbors;
- arma::mat trueDistances;
- data::Load(CLI::GetParam<string>("exact_distances_file"), trueDistances);
+ // Should we calculate error?
+ if (CLI::HasParam("calculate_error"))
+ {
+ arma::mat& set = CLI::HasParam("query_file") ? querySet : referenceSet;
+ arma::mat exactDistances;
+ if (CLI::HasParam("exact_distances_file"))
+ {
+ data::Load(CLI::GetParam<string>("exact_distances_file"),
+ exactDistances);
+ }
+ else
+ {
+ // Calculate exact distances. We are guaranteed the reference set is
+ // available.
+ AllkFN kfn(referenceSet);
+ arma::Mat<size_t> exactNeighbors;
+ kfn.Search(set, k, exactNeighbors, exactDistances);
-// kfn.Search(queryData, 1, trueNeighbors, trueDistances);
+ const double averageError = arma::sum(trueDistances / distances.row(0))
+ / distances.n_cols;
+ const double minError = arma::min(trueDistances / distances.row(0));
+ const double maxError = arma::max(trueDistances / distances.row(0));
- const double averageError = arma::sum(trueDistances / distances.row(0)) /
- distances.n_cols;
- const double minError = arma::min(trueDistances / distances.row(0));
- const double maxError = arma::max(trueDistances / distances.row(0));
+ Log::Info << "Average error: " << averageError << "." << endl;
+ Log::Info << "Maximum error: " << maxError << "." << endl;
+ Log::Info << "Minimum error: " << minError << "." << endl;
+ }
+ }
- Log::Info << "Average error: " << averageError << "." << endl;
- Log::Info << "Maximum error: " << maxError << "." << endl;
- Log::Info << "Minimum error: " << minError << "." << endl;
+ // Save results, if desired.
+ if (CLI::HasParam("neighbors_file"))
+ data::Save(CLI::GetParam<string>("neighbors_file"), neighbors, false);
+ if (CLI::HasParam("distances_file"))
+ data::Save(CLI::GetParam<string>("distances_file"), distances, false);
}
- // Save the results.
- if (CLI::HasParam("neighbors_file"))
- data::Save(CLI::GetParam<string>("neighbors_file"), neighbors);
- if (CLI::HasParam("distances_file"))
- data::Save(CLI::GetParam<string>("distances_file"), distances);
+ // Should we save the model?
+ if (CLI::HasParam("output_model_file"))
+ data::Save(CLI::GetParam<string>("output_model_file"), m);
}
More information about the mlpack-git
mailing list