[mlpack-git] master: Update parameter types. (48e781b)

gitdub at mlpack.org gitdub at mlpack.org
Wed Nov 2 10:25:27 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/981ffa2d67d8fe38df6c699589005835fef710ea...04551164d9950dbdb3738f0c9d87e2d498fd8192

>---------------------------------------------------------------

commit 48e781bc5f3daaec8eee7e79e9ba3b8895897d85
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Nov 2 10:25:27 2016 -0400

    Update parameter types.


>---------------------------------------------------------------

48e781bc5f3daaec8eee7e79e9ba3b8895897d85
 src/mlpack/methods/approx_kfn/approx_kfn_main.cpp | 60 ++++++++++-------------
 1 file changed, 27 insertions(+), 33 deletions(-)

diff --git a/src/mlpack/methods/approx_kfn/approx_kfn_main.cpp b/src/mlpack/methods/approx_kfn/approx_kfn_main.cpp
index 0342495..0467b59 100644
--- a/src/mlpack/methods/approx_kfn/approx_kfn_main.cpp
+++ b/src/mlpack/methods/approx_kfn/approx_kfn_main.cpp
@@ -54,8 +54,8 @@ PROGRAM_INFO("Approximate furthest neighbor search",
     "mlpack_kfn and mlpack_knn programs: each row holds the k distances or "
     "neighbor indices for each query point.");
 
-PARAM_STRING_IN("reference_file", "File containing reference points.", "r", "");
-PARAM_STRING_IN("query_file", "File containing query points.", "q", "");
+PARAM_MATRIX_IN("reference", "Matrix containing the reference dataset.", "r");
+PARAM_MATRIX_IN("query", "Matrix containing query points.", "q");
 
 // Model loading and saving.
 PARAM_STRING_IN("input_model_file", "File containing input model.", "m", "");
@@ -68,16 +68,15 @@ PARAM_INT_IN("num_projections", "Number of projections to use in each hash "
     "table.", "p", 5);
 PARAM_STRING_IN("algorithm", "Algorithm to use: 'ds' or 'qdafn'.", "a", "ds");
 
-PARAM_STRING_IN("neighbors_file", "File to save furthest neighbor indices to.",
-    "n", "");
-PARAM_STRING_IN("distances_file", "File to save furthest neighbor distances to.",
-    "d", "");
+PARAM_MATRIX_OUT("neighbors", "Matrix to save neighbor indices to.", "n");
+PARAM_UMATRIX_OUT("distances", "Matrix to save furthest neighbor distances to.",
+    "d");
 
 PARAM_FLAG("calculate_error", "If set, calculate the average distance error for"
     " the first furthest neighbor only.", "e");
-PARAM_STRING_IN("exact_distances_file", "File containing exact distances to "
+PARAM_MATRIX_IN("exact_distances", "Matrix containing exact distances to "
     "furthest neighbors; this can be used to avoid explicit calculation when "
-    "--calculate_error is set.", "x", "");
+    "--calculate_error is set.", "x");
 
 // If we save a model we must also save what type it is.
 class ApproxKFNModel
@@ -110,16 +109,16 @@ int main(int argc, char** argv)
 {
   CLI::ParseCommandLine(argc, argv);
 
-  if (!CLI::HasParam("reference_file") && !CLI::HasParam("input_model_file"))
+  if (!CLI::HasParam("reference") && !CLI::HasParam("input_model_file"))
     Log::Fatal << "Either --reference_file (-r) or --input_model_file (-m) must"
         << " be specified!" << endl;
-  if (CLI::HasParam("reference_file") && CLI::HasParam("input_model_file"))
+  if (CLI::HasParam("reference") && CLI::HasParam("input_model_file"))
     Log::Fatal << "Only one of --reference_file (-r) or --input_model_file (-m)"
         << " can be specified!" << endl;
   if (!CLI::HasParam("output_model_file") && !CLI::HasParam("k"))
     Log::Warn << "Neither --output_model_file (-M) nor --k (-k) are specified;"
         << " no task will be performed." << endl;
-  if (!CLI::HasParam("neighbors_file") && !CLI::HasParam("distances_file") &&
+  if (!CLI::HasParam("neighbors") && !CLI::HasParam("distances") &&
       !CLI::HasParam("output_model_file"))
     Log::Warn << "None of --output_model_file (-M), --neighbors_file (-n), or "
         << "--distances_file (-d) are specified; no output will be saved!"
@@ -128,8 +127,8 @@ int main(int argc, char** argv)
       CLI::GetParam<string>("algorithm") != "qdafn")
     Log::Fatal << "Unknown algorithm '" << CLI::GetParam<string>("algorithm")
         << "'; must be 'ds' or 'qdafn'!" << endl;
-  if (CLI::HasParam("k") && !(CLI::HasParam("reference_file") ||
-                              CLI::HasParam("query_file")))
+  if (CLI::HasParam("k") && !(CLI::HasParam("reference") ||
+                              CLI::HasParam("query")))
     Log::Fatal << "If search is being performed, then either --query_file "
         << "or --reference_file must be specified!" << endl;
 
@@ -145,26 +144,25 @@ int main(int argc, char** argv)
   if (CLI::HasParam("calculate_error") && !CLI::HasParam("k"))
     Log::Warn << "--calculate_error ignored because --k is not specified."
         << endl;
-  if (CLI::HasParam("exact_distances_file") &&
-      !CLI::HasParam("calculate_error"))
+  if (CLI::HasParam("exact_distances") && !CLI::HasParam("calculate_error"))
     Log::Warn << "--exact_distances_file ignored beceause --calculate_error is "
         << "not specified." << endl;
   if (CLI::HasParam("calculate_error") &&
-      !CLI::HasParam("exact_distances_file") &&
-      !CLI::HasParam("reference_file"))
+      !CLI::HasParam("exact_distances") &&
+      !CLI::HasParam("reference"))
     Log::Fatal << "Cannot calculate error without either --exact_distances_file"
         << " or --reference_file specified!" << endl;
 
   // Do the building of a model, if necessary.
   ApproxKFNModel m;
   arma::mat referenceSet; // This may be used at query time.
-  if (CLI::HasParam("reference_file"))
+  if (CLI::HasParam("reference"))
   {
-    const string referenceFile = CLI::GetParam<string>("reference_file");
-    data::Load(referenceFile, referenceSet);
+    referenceSet = std::move(CLI::GetParam<arma::mat>("reference"));
 
     const size_t numTables = (size_t) CLI::GetParam<int>("num_tables");
-    const size_t numProjections = (size_t) CLI::GetParam<int>("num_projections");
+    const size_t numProjections =
+        (size_t) CLI::GetParam<int>("num_projections");
     const string algorithm = CLI::GetParam<string>("algorithm");
 
     if (algorithm == "ds")
@@ -201,12 +199,9 @@ int main(int argc, char** argv)
     arma::Mat<size_t> neighbors;
     arma::mat distances;
 
-    arma::mat& set = CLI::HasParam("query_file") ? querySet : referenceSet;
-    if (CLI::HasParam("query_file"))
-    {
-      const string queryFile = CLI::GetParam<string>("query_file");
-      data::Load(queryFile, querySet);
-    }
+    arma::mat& set = CLI::HasParam("query") ? querySet : referenceSet;
+    if (CLI::HasParam("query"))
+      querySet = std::move(CLI::GetParam<arma::mat>("query"));
 
     if (m.type == 0)
     {
@@ -230,10 +225,9 @@ int main(int argc, char** argv)
     if (CLI::HasParam("calculate_error"))
     {
       arma::mat exactDistances;
-      if (CLI::HasParam("exact_distances_file"))
+      if (CLI::HasParam("exact_distances"))
       {
-        data::Load(CLI::GetParam<string>("exact_distances_file"),
-            exactDistances);
+        exactDistances = std::move(CLI::GetParam<string>("exact_distances"));
       }
       else
       {
@@ -259,10 +253,10 @@ int main(int argc, char** argv)
     }
 
     // Save results, if desired.
-    if (CLI::HasParam("neighbors_file"))
-      data::Save(CLI::GetParam<string>("neighbors_file"), neighbors, false);
+    if (CLI::HasParam("neighbors"))
+      CLI::GetParam<arma::Mat<size_t>>("neighbors") = std::move(neighbors);
     if (CLI::HasParam("distances_file"))
-      data::Save(CLI::GetParam<string>("distances_file"), distances, false);
+      CLI::GetParam<arma::mat>("distances") = std::move(distances);
   }
 
   // Should we save the model?




More information about the mlpack-git mailing list