[mlpack-git] master: Update parameter types. (48e781b)
gitdub at mlpack.org
gitdub at mlpack.org
Wed Nov 2 10:25:27 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/981ffa2d67d8fe38df6c699589005835fef710ea...04551164d9950dbdb3738f0c9d87e2d498fd8192
>---------------------------------------------------------------
commit 48e781bc5f3daaec8eee7e79e9ba3b8895897d85
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Nov 2 10:25:27 2016 -0400
Update parameter types.
>---------------------------------------------------------------
48e781bc5f3daaec8eee7e79e9ba3b8895897d85
src/mlpack/methods/approx_kfn/approx_kfn_main.cpp | 60 ++++++++++-------------
1 file changed, 27 insertions(+), 33 deletions(-)
diff --git a/src/mlpack/methods/approx_kfn/approx_kfn_main.cpp b/src/mlpack/methods/approx_kfn/approx_kfn_main.cpp
index 0342495..0467b59 100644
--- a/src/mlpack/methods/approx_kfn/approx_kfn_main.cpp
+++ b/src/mlpack/methods/approx_kfn/approx_kfn_main.cpp
@@ -54,8 +54,8 @@ PROGRAM_INFO("Approximate furthest neighbor search",
"mlpack_kfn and mlpack_knn programs: each row holds the k distances or "
"neighbor indices for each query point.");
-PARAM_STRING_IN("reference_file", "File containing reference points.", "r", "");
-PARAM_STRING_IN("query_file", "File containing query points.", "q", "");
+PARAM_MATRIX_IN("reference", "Matrix containing the reference dataset.", "r");
+PARAM_MATRIX_IN("query", "Matrix containing query points.", "q");
// Model loading and saving.
PARAM_STRING_IN("input_model_file", "File containing input model.", "m", "");
@@ -68,16 +68,15 @@ PARAM_INT_IN("num_projections", "Number of projections to use in each hash "
"table.", "p", 5);
PARAM_STRING_IN("algorithm", "Algorithm to use: 'ds' or 'qdafn'.", "a", "ds");
-PARAM_STRING_IN("neighbors_file", "File to save furthest neighbor indices to.",
- "n", "");
-PARAM_STRING_IN("distances_file", "File to save furthest neighbor distances to.",
- "d", "");
+PARAM_UMATRIX_OUT("neighbors", "Matrix to save neighbor indices to.", "n");
+PARAM_MATRIX_OUT("distances", "Matrix to save furthest neighbor distances to.",
+ "d");
PARAM_FLAG("calculate_error", "If set, calculate the average distance error for"
" the first furthest neighbor only.", "e");
-PARAM_STRING_IN("exact_distances_file", "File containing exact distances to "
+PARAM_MATRIX_IN("exact_distances", "Matrix containing exact distances to "
"furthest neighbors; this can be used to avoid explicit calculation when "
- "--calculate_error is set.", "x", "");
+ "--calculate_error is set.", "x");
// If we save a model we must also save what type it is.
class ApproxKFNModel
@@ -110,16 +109,16 @@ int main(int argc, char** argv)
{
CLI::ParseCommandLine(argc, argv);
- if (!CLI::HasParam("reference_file") && !CLI::HasParam("input_model_file"))
+ if (!CLI::HasParam("reference") && !CLI::HasParam("input_model_file"))
Log::Fatal << "Either --reference_file (-r) or --input_model_file (-m) must"
<< " be specified!" << endl;
- if (CLI::HasParam("reference_file") && CLI::HasParam("input_model_file"))
+ if (CLI::HasParam("reference") && CLI::HasParam("input_model_file"))
Log::Fatal << "Only one of --reference_file (-r) or --input_model_file (-m)"
<< " can be specified!" << endl;
if (!CLI::HasParam("output_model_file") && !CLI::HasParam("k"))
Log::Warn << "Neither --output_model_file (-M) nor --k (-k) are specified;"
<< " no task will be performed." << endl;
- if (!CLI::HasParam("neighbors_file") && !CLI::HasParam("distances_file") &&
+ if (!CLI::HasParam("neighbors") && !CLI::HasParam("distances") &&
!CLI::HasParam("output_model_file"))
Log::Warn << "None of --output_model_file (-M), --neighbors_file (-n), or "
<< "--distances_file (-d) are specified; no output will be saved!"
@@ -128,8 +127,8 @@ int main(int argc, char** argv)
CLI::GetParam<string>("algorithm") != "qdafn")
Log::Fatal << "Unknown algorithm '" << CLI::GetParam<string>("algorithm")
<< "'; must be 'ds' or 'qdafn'!" << endl;
- if (CLI::HasParam("k") && !(CLI::HasParam("reference_file") ||
- CLI::HasParam("query_file")))
+ if (CLI::HasParam("k") && !(CLI::HasParam("reference") ||
+ CLI::HasParam("query")))
Log::Fatal << "If search is being performed, then either --query_file "
<< "or --reference_file must be specified!" << endl;
@@ -145,26 +144,25 @@ int main(int argc, char** argv)
if (CLI::HasParam("calculate_error") && !CLI::HasParam("k"))
Log::Warn << "--calculate_error ignored because --k is not specified."
<< endl;
- if (CLI::HasParam("exact_distances_file") &&
- !CLI::HasParam("calculate_error"))
- Log::Warn << "--exact_distances_file ignored beceause --calculate_error is "
+ if (CLI::HasParam("exact_distances") && !CLI::HasParam("calculate_error"))
+ Log::Warn << "--exact_distances_file ignored because --calculate_error is "
<< "not specified." << endl;
if (CLI::HasParam("calculate_error") &&
- !CLI::HasParam("exact_distances_file") &&
- !CLI::HasParam("reference_file"))
+ !CLI::HasParam("exact_distances") &&
+ !CLI::HasParam("reference"))
Log::Fatal << "Cannot calculate error without either --exact_distances_file"
<< " or --reference_file specified!" << endl;
// Do the building of a model, if necessary.
ApproxKFNModel m;
arma::mat referenceSet; // This may be used at query time.
- if (CLI::HasParam("reference_file"))
+ if (CLI::HasParam("reference"))
{
- const string referenceFile = CLI::GetParam<string>("reference_file");
- data::Load(referenceFile, referenceSet);
+ referenceSet = std::move(CLI::GetParam<arma::mat>("reference"));
const size_t numTables = (size_t) CLI::GetParam<int>("num_tables");
- const size_t numProjections = (size_t) CLI::GetParam<int>("num_projections");
+ const size_t numProjections =
+ (size_t) CLI::GetParam<int>("num_projections");
const string algorithm = CLI::GetParam<string>("algorithm");
if (algorithm == "ds")
@@ -201,12 +199,9 @@ int main(int argc, char** argv)
arma::Mat<size_t> neighbors;
arma::mat distances;
- arma::mat& set = CLI::HasParam("query_file") ? querySet : referenceSet;
- if (CLI::HasParam("query_file"))
- {
- const string queryFile = CLI::GetParam<string>("query_file");
- data::Load(queryFile, querySet);
- }
+ arma::mat& set = CLI::HasParam("query") ? querySet : referenceSet;
+ if (CLI::HasParam("query"))
+ querySet = std::move(CLI::GetParam<arma::mat>("query"));
if (m.type == 0)
{
@@ -230,10 +225,9 @@ int main(int argc, char** argv)
if (CLI::HasParam("calculate_error"))
{
arma::mat exactDistances;
- if (CLI::HasParam("exact_distances_file"))
+ if (CLI::HasParam("exact_distances"))
{
- data::Load(CLI::GetParam<string>("exact_distances_file"),
- exactDistances);
+ exactDistances = std::move(CLI::GetParam<arma::mat>("exact_distances"));
}
else
{
@@ -259,10 +253,10 @@ int main(int argc, char** argv)
}
// Save results, if desired.
- if (CLI::HasParam("neighbors_file"))
- data::Save(CLI::GetParam<string>("neighbors_file"), neighbors, false);
- if (CLI::HasParam("distances_file"))
- data::Save(CLI::GetParam<string>("distances_file"), distances, false);
+ if (CLI::HasParam("neighbors"))
+ CLI::GetParam<arma::Mat<size_t>>("neighbors") = std::move(neighbors);
+ if (CLI::HasParam("distances"))
+ CLI::GetParam<arma::mat>("distances") = std::move(distances);
}
// Should we save the model?
More information about the mlpack-git mailing list