[mlpack-git] master: Refactor main executable. (fbc045a)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Tue Dec 22 23:26:46 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/977afbec0648056124dcb206e0bf972a161d9b51...c0a1500de16d920eb0689e8c052e1aa7e1375c38

>---------------------------------------------------------------

commit fbc045abcd94a95e29b957d8b8fd64725b0cea5d
Author: ryan <ryan at ratml.org>
Date:   Tue Dec 22 23:25:26 2015 -0500

    Refactor main executable.


>---------------------------------------------------------------

fbc045abcd94a95e29b957d8b8fd64725b0cea5d
 src/mlpack/methods/fastmks/fastmks_main.cpp | 342 +++++++++++-----------------
 1 file changed, 132 insertions(+), 210 deletions(-)

diff --git a/src/mlpack/methods/fastmks/fastmks_main.cpp b/src/mlpack/methods/fastmks/fastmks_main.cpp
index ada3eb7..73d81bd 100644
--- a/src/mlpack/methods/fastmks/fastmks_main.cpp
+++ b/src/mlpack/methods/fastmks/fastmks_main.cpp
@@ -7,6 +7,7 @@
 #include <mlpack/core.hpp>
 
 #include "fastmks.hpp"
+#include "fastmks_model.hpp"
 
 using namespace std;
 using namespace mlpack;
@@ -39,25 +40,11 @@ PROGRAM_INFO("FastMKS (Fast Max-Kernel Search)",
     "This executable performs FastMKS using a cover tree.  The base used to "
     "build the cover tree can be specified with the --base option.");
 
-// Define our input parameters.
-PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
-    "r");
-PARAM_STRING("query_file", "File containing the query dataset.", "q", "");
-
-PARAM_INT_REQ("k", "Number of maximum kernels to find.", "k");
-
-PARAM_STRING("kernels_file", "File to save kernels into.", "p", "");
-PARAM_STRING("indices_file", "File to save indices of kernels into.",
-    "i", "");
-
+// Model-building parameters (used when a model is built, not loaded).
+PARAM_STRING("reference_file", "File containing the reference dataset.", "r",
+    "");  // Optional; exactly one of this or --input_model_file is required.
 PARAM_STRING("kernel", "Kernel type to use: 'linear', 'polynomial', 'cosine', "
     "'gaussian', 'epanechnikov', 'triangular', 'hyptan'.", "K", "linear");
-
-PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "N");
-PARAM_FLAG("single", "If true, single-tree search is used (as opposed to "
-    "dual-tree search.", "S");
-
-// Cover tree parameter.
 PARAM_DOUBLE("base", "Base to use during cover tree construction.", "b", 2.0);
 
 // Kernel parameters.
@@ -68,262 +55,197 @@ PARAM_DOUBLE("bandwidth", "Bandwidth (for Gaussian, Epanechnikov, and "
     "triangular kernels).", "w", 1.0);
 PARAM_DOUBLE("scale", "Scale of kernel (for hyptan kernel).", "s", 1.0);
 
-//! Run FastMKS on a single dataset for the given kernel type.
-template<typename KernelType>
-void RunFastMKS(const arma::mat& referenceData,
-                const bool single,
-                const bool naive,
-                const double base,
-                const size_t k,
-                arma::Mat<size_t>& indices,
-                arma::mat& kernels,
-                KernelType& kernel)
-{
-  if (naive)
-  {
-    // No need for trees.
-    FastMKS<KernelType> fastmks(referenceData, kernel, false, naive);
-    fastmks.Search(k, indices, kernels);
-  }
-  else
-  {
-    // Create the tree with the specified base.
-    typedef CoverTree<IPMetric<KernelType>, FastMKSStat, arma::mat,
-        FirstPointIsRoot> TreeType;
-    IPMetric<KernelType> metric(kernel);
-    TreeType tree(referenceData, metric, base);
-
-    // Create FastMKS object.
-    FastMKS<KernelType> fastmks(&tree, single);
-
-    // Now search with it.
-    fastmks.Search(k, indices, kernels);
-  }
-}
-
-//! Run FastMKS for a given query and reference set using the given kernel type.
-template<typename KernelType>
-void RunFastMKS(const arma::mat& referenceData,
-                const arma::mat& queryData,
-                const bool single,
-                const bool naive,
-                const double base,
-                const size_t k,
-                arma::Mat<size_t>& indices,
-                arma::mat& kernels,
-                KernelType& kernel)
-{
-  if (naive)
-  {
-    // No need for trees.
-    FastMKS<KernelType> fastmks(referenceData, kernel, false, naive);
-    fastmks.Search(queryData, k, indices, kernels);
-  }
-  else
-  {
-    // Create the tree with the specified base.
-    typedef CoverTree<IPMetric<KernelType>, FastMKSStat, arma::mat,
-        FirstPointIsRoot> TreeType;
-    IPMetric<KernelType> metric(kernel);
-    TreeType referenceTree(referenceData, metric, base);
+// Load/save models (--input_model_file cannot be used with --reference_file).
+PARAM_STRING("input_model_file", "File containing FastMKS model.", "m", "");
+PARAM_STRING("output_model_file", "File to save FastMKS model to.", "M", "");
 
-    // Create FastMKS object.
-    FastMKS<KernelType, arma::mat, StandardCoverTree> fastmks(&referenceTree,
-        single);
+// Search preferences; a search is only performed when --k is given.
+PARAM_STRING("query_file", "File containing the query dataset.", "q", "");
+PARAM_INT("k", "Number of maximum kernels to find.", "k", 0);
+PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "N");
+PARAM_FLAG("single", "If true, single-tree search is used (as opposed to "
+    "dual-tree search).", "S");
 
-    // Now search with it.
-    if (single)
-    {
-      fastmks.Search(queryData, k, indices, kernels);
-    }
-    else
-    {
-      TreeType queryTree(queryData, metric, base);
-      fastmks.Search(&queryTree, k, indices, kernels);
-    }
-  }
-}
+PARAM_STRING("kernels_file", "File to save kernels into.", "p", "");
+PARAM_STRING("indices_file", "File to save indices of kernels into.",
+    "i", "");
 
 int main(int argc, char** argv)
 {
   CLI::ParseCommandLine(argc, argv);
 
-  // Get reference dataset filename.
-  const string referenceFile = CLI::GetParam<string>("reference_file");
-
-  // The number of max kernel values to find.
-  const size_t k = CLI::GetParam<int>("k");
-
-  // Runtime parameters.
-  const bool naive = CLI::HasParam("naive");
-  const bool single = CLI::HasParam("single");
+  // Validate command-line parameters.
+  if (CLI::HasParam("reference_file") && CLI::HasParam("input_model_file"))
+    Log::Fatal << "Cannot specify both --reference_file (-r) and "
+        << "--input_model_file (-m)!" << endl;
 
-  // For cover tree construction.
-  const double base = CLI::GetParam<double>("base");
+  if (!CLI::HasParam("reference_file") && !CLI::HasParam("input_model_file"))
+    Log::Fatal << "Must specify either --reference_file (-r) or "
+        << "--input_model_file (-m)!" << endl;
 
-  // Kernel parameters.
-  const string kernelType = CLI::GetParam<string>("kernel");
-  const double degree = CLI::GetParam<double>("degree");
-  const double offset = CLI::GetParam<double>("offset");
-  const double bandwidth = CLI::GetParam<double>("bandwidth");
-  const double scale = CLI::GetParam<double>("scale");
-
-  // The datasets.  The query matrix may never be used.
-  arma::mat referenceData;
-  arma::mat queryData;
-
-  data::Load(referenceFile, referenceData, true);
-
-  Log::Info << "Loaded reference data from '" << referenceFile << "' ("
-      << referenceData.n_rows << " x " << referenceData.n_cols << ")." << endl;
-
-  // Sanity check on k value.
-  if (k > referenceData.n_cols)
+  if (CLI::HasParam("input_model_file"))
   {
-    Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
-    Log::Fatal << "than or equal to the number of reference points (";
-    Log::Fatal << referenceData.n_cols << ")." << endl;
+    if (CLI::HasParam("kernel"))
+      Log::Warn << "--kernel (-K) ignored because --input_model_file (-m) is "
+          << "specified." << endl;
+    if (CLI::HasParam("bandwidth"))
+      Log::Warn << "--bandwidth (-w) ignored because --input_model_file (-m) is"
+          << " specified." << endl;
+    if (CLI::HasParam("degree"))
+      Log::Warn << "--degree (-d) ignored because --input_model_file (-m) is "
+          << "specified." << endl;
+    if (CLI::HasParam("offset"))
+      Log::Warn << "--offset (-o) ignored because --input_model_file (-m) is "
+          << "specified." << endl;
   }
 
+  if (CLI::HasParam("k") && CLI::GetParam<int>("k") <= 0)
+    Log::Fatal << "Invalid --k: must be greater than 0." << endl;
+  if (!CLI::HasParam("k") &&
+      (CLI::HasParam("indices_file") || CLI::HasParam("kernels_file")))
+    Log::Warn << "--indices_file and --kernels_file ignored, because no search "
+        << "task is specified (i.e., --k is not specified)!" << endl;
+  if (CLI::HasParam("k") &&
+      !(CLI::HasParam("indices_file") || CLI::HasParam("kernels_file")))
+    Log::Warn << "Search specified with --k, but no output will be saved "
+        << "because neither --indices_file nor --kernels_file are specified!"
+        << endl;
   // Check on kernel type.
+  const string kernelType = CLI::GetParam<string>("kernel");
   if ((kernelType != "linear") && (kernelType != "polynomial") &&
       (kernelType != "cosine") && (kernelType != "gaussian") &&
-      (kernelType != "graph") && (kernelType != "approxGraph") &&
       (kernelType != "triangular") && (kernelType != "hyptan") &&
-      (kernelType != "inv-mq") && (kernelType != "epanechnikov"))
+      (kernelType != "epanechnikov"))
   {
-    Log::Fatal << "Invalid kernel type: '" << kernelType << "'; must be ";
-    Log::Fatal << "'linear' or 'polynomial'." << endl;
+    Log::Fatal << "Invalid kernel type: '" << kernelType << "'; must be "
+        << "'linear', 'polynomial', 'cosine', 'gaussian', 'triangular', "
+        << "'hyptan', or 'epanechnikov'." << endl;
   }
 
-  // Load the query matrix, if we can.
-  if (CLI::HasParam("query_file"))
-  {
-    const string queryFile = CLI::GetParam<string>("query_file");
-    data::Load(queryFile, queryData, true);
+  // Naive mode overrides single mode.
+  if (CLI::HasParam("naive") && CLI::HasParam("single"))
+    Log::Warn << "--single ignored because --naive is present." << endl;
 
-    Log::Info << "Loaded query data from '" << queryFile << "' ("
-        << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
-  }
-  else
+  FastMKSModel model;
+  arma::mat referenceData;
+  if (CLI::HasParam("reference_file"))
   {
-    Log::Info << "Using reference dataset as query dataset (--query_file not "
-        << "specified)." << endl;
-  }
+    data::Load(CLI::GetParam<string>("reference_file"), referenceData, true);
 
-  // Naive mode overrides single mode.
-  if (naive && single)
-  {
-    Log::Warn << "--single ignored because --naive is present." << endl;
-  }
+    Log::Info << "Loaded reference data from '"
+        << CLI::GetParam<string>("reference_file") << "' ("
+        << referenceData.n_rows << " x " << referenceData.n_cols << ")."
+        << endl;
 
-  // Matrices for output storage.
-  arma::Mat<size_t> indices;
-  arma::mat kernels;
+    // For cover tree construction.
+    const double base = CLI::GetParam<double>("base");
+
+    // Kernel parameters (the kernel type itself was already read and
+    // validated above).
+    const double degree = CLI::GetParam<double>("degree");
+    const double offset = CLI::GetParam<double>("offset");
+    const double bandwidth = CLI::GetParam<double>("bandwidth");
+    const double scale = CLI::GetParam<double>("scale");
+
+    // Search preferences.
+    const bool naive = CLI::HasParam("naive");
+    const bool single = CLI::HasParam("single");
 
-  // Construct FastMKS object.
-  if (queryData.n_elem == 0)
-  {
     if (kernelType == "linear")
     {
       LinearKernel lk;
-      RunFastMKS<LinearKernel>(referenceData, single, naive, base, k, indices,
-          kernels, lk);
+      model.KernelType() = FastMKSModel::LINEAR_KERNEL;
+      model.BuildModel(referenceData, lk, single, naive, base);
     }
     else if (kernelType == "polynomial")
     {
-
       PolynomialKernel pk(degree, offset);
-      RunFastMKS<PolynomialKernel>(referenceData, single, naive, base, k,
-          indices, kernels, pk);
+      model.KernelType() = FastMKSModel::POLYNOMIAL_KERNEL;
+      model.BuildModel(referenceData, pk, single, naive, base);
     }
     else if (kernelType == "cosine")
     {
       CosineDistance cd;
-      RunFastMKS<CosineDistance>(referenceData, single, naive, base, k, indices,
-          kernels, cd);
+      model.KernelType() = FastMKSModel::COSINE_DISTANCE;
+      model.BuildModel(referenceData, cd, single, naive, base);
     }
     else if (kernelType == "gaussian")
     {
       GaussianKernel gk(bandwidth);
-      RunFastMKS<GaussianKernel>(referenceData, single, naive, base, k, indices,
-          kernels, gk);
+      model.KernelType() = FastMKSModel::GAUSSIAN_KERNEL;
+      model.BuildModel(referenceData, gk, single, naive, base);
     }
     else if (kernelType == "epanechnikov")
     {
       EpanechnikovKernel ek(bandwidth);
-      RunFastMKS<EpanechnikovKernel>(referenceData, single, naive, base, k,
-          indices, kernels, ek);
+      model.KernelType() = FastMKSModel::EPANECHNIKOV_KERNEL;
+      model.BuildModel(referenceData, ek, single, naive, base);
     }
     else if (kernelType == "triangular")
     {
       TriangularKernel tk(bandwidth);
-      RunFastMKS<TriangularKernel>(referenceData, single, naive, base, k,
-          indices, kernels, tk);
+      model.KernelType() = FastMKSModel::TRIANGULAR_KERNEL;
+      model.BuildModel(referenceData, tk, single, naive, base);
     }
     else if (kernelType == "hyptan")
     {
       HyperbolicTangentKernel htk(scale, offset);
-      RunFastMKS<HyperbolicTangentKernel>(referenceData, single, naive, base, k,
-          indices, kernels, htk);
+      model.KernelType() = FastMKSModel::HYPTAN_KERNEL;
+      model.BuildModel(referenceData, htk, single, naive, base);
     }
   }
   else
   {
-    if (kernelType == "linear")
-    {
-      LinearKernel lk;
-      RunFastMKS<LinearKernel>(referenceData, queryData, single, naive, base, k,
-          indices, kernels, lk);
-    }
-    else if (kernelType == "polynomial")
-    {
-      PolynomialKernel pk(degree, offset);
-      RunFastMKS<PolynomialKernel>(referenceData, queryData, single, naive,
-          base, k, indices, kernels, pk);
-    }
-    else if (kernelType == "cosine")
-    {
-      CosineDistance cd;
-      RunFastMKS<CosineDistance>(referenceData, queryData, single, naive, base,
-          k, indices, kernels, cd);
-    }
-    else if (kernelType == "gaussian")
+    // Load a previously saved FastMKSModel instead of building a new one.
+    data::Load(CLI::GetParam<string>("input_model_file"), "fastmks_model",
+        model, true);
+  }
+
+  // Search preferences come from the command line, even for a loaded model.
+  model.Naive() = CLI::HasParam("naive");
+  model.SingleMode() = CLI::HasParam("single");
+
+  // Search only if --k was given; otherwise just build/load/save the model.
+  if (CLI::HasParam("k"))
+  {
+    arma::mat kernels;
+    arma::Mat<size_t> indices;
+
+    if (CLI::HasParam("query_file"))
     {
-      GaussianKernel gk(bandwidth);
-      RunFastMKS<GaussianKernel>(referenceData, queryData, single, naive, base,
-          k, indices, kernels, gk);
+      const string queryFile = CLI::GetParam<string>("query_file");
+      const double base = CLI::GetParam<double>("base");
+
+      arma::mat queryData;
+      data::Load(queryFile, queryData, true);
+
+      Log::Info << "Loaded query data from '" << queryFile << "' ("
+          << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+
+      model.Search(queryData, (size_t) CLI::GetParam<int>("k"), indices,
+          kernels, base);
     }
-    else if (kernelType == "epanechnikov")
+    else
     {
-      EpanechnikovKernel ek(bandwidth);
-      RunFastMKS<EpanechnikovKernel>(referenceData, queryData, single, naive,
-          base, k, indices, kernels, ek);
+      model.Search((size_t) CLI::GetParam<int>("k"), indices, kernels);
     }
-    else if (kernelType == "triangular")
+
+    // Save output, if we were asked to.
+    if (CLI::HasParam("kernels_file"))
     {
-      TriangularKernel tk(bandwidth);
-      RunFastMKS<TriangularKernel>(referenceData, queryData, single, naive,
-          base, k, indices, kernels, tk);
+      const string kernelsFile = CLI::GetParam<string>("kernels_file");
+      data::Save(kernelsFile, kernels, false);
     }
-    else if (kernelType == "hyptan")
+
+    if (CLI::HasParam("indices_file"))
     {
-      HyperbolicTangentKernel htk(scale, offset);
-      RunFastMKS<HyperbolicTangentKernel>(referenceData, queryData, single,
-          naive, base, k, indices, kernels, htk);
+      const string indicesFile = CLI::GetParam<string>("indices_file");
+      data::Save(indicesFile, indices, false);
     }
   }
 
-  // Save output, if we were asked to.
-  if (CLI::HasParam("kernels_file"))
-  {
-    const string kernelsFile = CLI::GetParam<string>("kernels_file");
-    data::Save(kernelsFile, kernels, false);
-  }
-
-  if (CLI::HasParam("indices_file"))
-  {
-    const string indicesFile = CLI::GetParam<string>("indices_file");
-    data::Save(indicesFile, indices, false);
-  }
+  // Save the model, if requested.
+  if (CLI::HasParam("output_model_file"))
+    data::Save(CLI::GetParam<string>("output_model_file"), "fastmks_model",
+        model);
 }



More information about the mlpack-git mailing list