[mlpack-git] master: Refactor main executable. (fbc045a)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Tue Dec 22 23:26:46 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/977afbec0648056124dcb206e0bf972a161d9b51...c0a1500de16d920eb0689e8c052e1aa7e1375c38
>---------------------------------------------------------------
commit fbc045abcd94a95e29b957d8b8fd64725b0cea5d
Author: ryan <ryan at ratml.org>
Date: Tue Dec 22 23:25:26 2015 -0500
Refactor main executable.
>---------------------------------------------------------------
fbc045abcd94a95e29b957d8b8fd64725b0cea5d
src/mlpack/methods/fastmks/fastmks_main.cpp | 342 +++++++++++-----------------
1 file changed, 132 insertions(+), 210 deletions(-)
diff --git a/src/mlpack/methods/fastmks/fastmks_main.cpp b/src/mlpack/methods/fastmks/fastmks_main.cpp
index ada3eb7..73d81bd 100644
--- a/src/mlpack/methods/fastmks/fastmks_main.cpp
+++ b/src/mlpack/methods/fastmks/fastmks_main.cpp
@@ -7,6 +7,7 @@
#include <mlpack/core.hpp>
#include "fastmks.hpp"
+#include "fastmks_model.hpp"
using namespace std;
using namespace mlpack;
@@ -39,25 +40,11 @@ PROGRAM_INFO("FastMKS (Fast Max-Kernel Search)",
"This executable performs FastMKS using a cover tree. The base used to "
"build the cover tree can be specified with the --base option.");
-// Define our input parameters.
-PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
- "r");
-PARAM_STRING("query_file", "File containing the query dataset.", "q", "");
-
-PARAM_INT_REQ("k", "Number of maximum kernels to find.", "k");
-
-PARAM_STRING("kernels_file", "File to save kernels into.", "p", "");
-PARAM_STRING("indices_file", "File to save indices of kernels into.",
- "i", "");
-
+// Parameters used when building a new model from a reference set.
+PARAM_STRING("reference_file", "File containing the reference dataset.",
+ "r", "");
PARAM_STRING("kernel", "Kernel type to use: 'linear', 'polynomial', 'cosine', "
"'gaussian', 'epanechnikov', 'triangular', 'hyptan'.", "K", "linear");
-
-PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "N");
-PARAM_FLAG("single", "If true, single-tree search is used (as opposed to "
- "dual-tree search.", "S");
-
-// Cover tree parameter.
PARAM_DOUBLE("base", "Base to use during cover tree construction.", "b", 2.0);
// Kernel parameters.
@@ -68,262 +55,197 @@ PARAM_DOUBLE("bandwidth", "Bandwidth (for Gaussian, Epanechnikov, and "
"triangular kernels).", "w", 1.0);
PARAM_DOUBLE("scale", "Scale of kernel (for hyptan kernel).", "s", 1.0);
-//! Run FastMKS on a single dataset for the given kernel type.
-template<typename KernelType>
-void RunFastMKS(const arma::mat& referenceData,
- const bool single,
- const bool naive,
- const double base,
- const size_t k,
- arma::Mat<size_t>& indices,
- arma::mat& kernels,
- KernelType& kernel)
-{
- if (naive)
- {
- // No need for trees.
- FastMKS<KernelType> fastmks(referenceData, kernel, false, naive);
- fastmks.Search(k, indices, kernels);
- }
- else
- {
- // Create the tree with the specified base.
- typedef CoverTree<IPMetric<KernelType>, FastMKSStat, arma::mat,
- FirstPointIsRoot> TreeType;
- IPMetric<KernelType> metric(kernel);
- TreeType tree(referenceData, metric, base);
-
- // Create FastMKS object.
- FastMKS<KernelType> fastmks(&tree, single);
-
- // Now search with it.
- fastmks.Search(k, indices, kernels);
- }
-}
-
-//! Run FastMKS for a given query and reference set using the given kernel type.
-template<typename KernelType>
-void RunFastMKS(const arma::mat& referenceData,
- const arma::mat& queryData,
- const bool single,
- const bool naive,
- const double base,
- const size_t k,
- arma::Mat<size_t>& indices,
- arma::mat& kernels,
- KernelType& kernel)
-{
- if (naive)
- {
- // No need for trees.
- FastMKS<KernelType> fastmks(referenceData, kernel, false, naive);
- fastmks.Search(queryData, k, indices, kernels);
- }
- else
- {
- // Create the tree with the specified base.
- typedef CoverTree<IPMetric<KernelType>, FastMKSStat, arma::mat,
- FirstPointIsRoot> TreeType;
- IPMetric<KernelType> metric(kernel);
- TreeType referenceTree(referenceData, metric, base);
+// Load/save models.
+PARAM_STRING("input_model_file", "File containing FastMKS model.", "m", "");
+PARAM_STRING("output_model_file", "File to save FastMKS model to.", "M", "");
- // Create FastMKS object.
- FastMKS<KernelType, arma::mat, StandardCoverTree> fastmks(&referenceTree,
- single);
+// Search preferences.
+PARAM_STRING("query_file", "File containing the query dataset.", "q", "");
+PARAM_INT("k", "Number of maximum kernels to find.", "k", 0);
+PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "N");
+PARAM_FLAG("single", "If true, single-tree search is used (as opposed to "
+ "dual-tree search).", "S");
- // Now search with it.
- if (single)
- {
- fastmks.Search(queryData, k, indices, kernels);
- }
- else
- {
- TreeType queryTree(queryData, metric, base);
- fastmks.Search(&queryTree, k, indices, kernels);
- }
- }
-}
+PARAM_STRING("kernels_file", "File to save kernels into.", "p", "");
+PARAM_STRING("indices_file", "File to save indices of kernels into.",
+ "i", "");
int main(int argc, char** argv)
{
CLI::ParseCommandLine(argc, argv);
- // Get reference dataset filename.
- const string referenceFile = CLI::GetParam<string>("reference_file");
-
- // The number of max kernel values to find.
- const size_t k = CLI::GetParam<int>("k");
-
- // Runtime parameters.
- const bool naive = CLI::HasParam("naive");
- const bool single = CLI::HasParam("single");
+ // Validate command-line parameters.
+ if (CLI::HasParam("reference_file") && CLI::HasParam("input_model_file"))
+ Log::Fatal << "Cannot specify both --reference_file (-r) and "
+ << "--input_model_file (-m)!" << endl;
- // For cover tree construction.
- const double base = CLI::GetParam<double>("base");
+ if (!CLI::HasParam("reference_file") && !CLI::HasParam("input_model_file"))
+ Log::Fatal << "Must specify either --reference_file (-r) or "
+ << "--input_model_file (-m)!" << endl;
- // Kernel parameters.
- const string kernelType = CLI::GetParam<string>("kernel");
- const double degree = CLI::GetParam<double>("degree");
- const double offset = CLI::GetParam<double>("offset");
- const double bandwidth = CLI::GetParam<double>("bandwidth");
- const double scale = CLI::GetParam<double>("scale");
-
- // The datasets. The query matrix may never be used.
- arma::mat referenceData;
- arma::mat queryData;
-
- data::Load(referenceFile, referenceData, true);
-
- Log::Info << "Loaded reference data from '" << referenceFile << "' ("
- << referenceData.n_rows << " x " << referenceData.n_cols << ")." << endl;
-
- // Sanity check on k value.
- if (k > referenceData.n_cols)
+ if (CLI::HasParam("input_model_file"))
{
- Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
- Log::Fatal << "than or equal to the number of reference points (";
- Log::Fatal << referenceData.n_cols << ")." << endl;
+ // Kernel parameters only take effect when a new model is built from
+ // --reference_file (-r); the kernel of a model loaded with -m is unchanged.
+ const char* kernelParams[][2] = { { "kernel", "K" }, { "bandwidth", "w" },
+ { "degree", "d" }, { "offset", "o" }, { "scale", "s" } };
+ const size_t numKernelParams = sizeof(kernelParams) / sizeof(kernelParams[0]);
+ for (size_t i = 0; i < numKernelParams; ++i)
+ {
+ if (CLI::HasParam(kernelParams[i][0]))
+ Log::Warn << "--" << kernelParams[i][0] << " (-" << kernelParams[i][1]
+ << ") ignored because --input_model_file (-m) is specified."
+ << endl;
+ }
}
+ if (CLI::HasParam("k") && CLI::GetParam<int>("k") <= 0)
+ Log::Fatal << "Invalid k: must be greater than 0!" << endl;
+ if (!CLI::HasParam("k") &&
+ (CLI::HasParam("indices_file") || CLI::HasParam("kernels_file")))
+ Log::Warn << "--indices_file and --kernels_file ignored, because no search "
+ << "task is specified (i.e., --k is not specified)!" << endl;
+ if (CLI::HasParam("k") &&
+ !(CLI::HasParam("indices_file") || CLI::HasParam("kernels_file")))
+ Log::Warn << "Search specified with --k, but no output will be saved "
+ << "because neither --indices_file nor --kernels_file are specified!"
+ << endl;
// Check on kernel type.
+ const string kernelType = CLI::GetParam<string>("kernel");
if ((kernelType != "linear") && (kernelType != "polynomial") &&
(kernelType != "cosine") && (kernelType != "gaussian") &&
- (kernelType != "graph") && (kernelType != "approxGraph") &&
(kernelType != "triangular") && (kernelType != "hyptan") &&
- (kernelType != "inv-mq") && (kernelType != "epanechnikov"))
+ (kernelType != "epanechnikov"))
{
- Log::Fatal << "Invalid kernel type: '" << kernelType << "'; must be ";
- Log::Fatal << "'linear' or 'polynomial'." << endl;
+ Log::Fatal << "Invalid kernel type: '" << kernelType << "'; must be "
+ << "'linear', 'polynomial', 'cosine', 'gaussian', 'triangular', "
+ << "'epanechnikov', or 'hyptan'." << endl;
}
- // Load the query matrix, if we can.
- if (CLI::HasParam("query_file"))
- {
- const string queryFile = CLI::GetParam<string>("query_file");
- data::Load(queryFile, queryData, true);
+ // Naive mode overrides single mode.
+ if (CLI::HasParam("naive") && CLI::HasParam("single"))
+ Log::Warn << "--single ignored because --naive is present." << endl;
- Log::Info << "Loaded query data from '" << queryFile << "' ("
- << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
- }
- else
+ FastMKSModel model;
+ arma::mat referenceData;
+ if (CLI::HasParam("reference_file"))
{
- Log::Info << "Using reference dataset as query dataset (--query_file not "
- << "specified)." << endl;
- }
+ data::Load(CLI::GetParam<string>("reference_file"), referenceData, true);
- // Naive mode overrides single mode.
- if (naive && single)
- {
- Log::Warn << "--single ignored because --naive is present." << endl;
- }
+ Log::Info << "Loaded reference data from '"
+ << CLI::GetParam<string>("reference_file") << "' ("
+ << referenceData.n_rows << " x " << referenceData.n_cols << ")."
+ << endl;
- // Matrices for output storage.
- arma::Mat<size_t> indices;
- arma::mat kernels;
+ // For cover tree construction.
+ const double base = CLI::GetParam<double>("base");
+
+ // Kernel parameters. (kernelType was already read and validated above, so
+ // it is reused here rather than re-read into a shadowing local.)
+ const double degree = CLI::GetParam<double>("degree");
+ const double offset = CLI::GetParam<double>("offset");
+ const double bandwidth = CLI::GetParam<double>("bandwidth");
+ const double scale = CLI::GetParam<double>("scale");
+
+ // Search preferences.
+ const bool naive = CLI::HasParam("naive");
+ const bool single = CLI::HasParam("single");
- // Construct FastMKS object.
- if (queryData.n_elem == 0)
- {
if (kernelType == "linear")
{
LinearKernel lk;
- RunFastMKS<LinearKernel>(referenceData, single, naive, base, k, indices,
- kernels, lk);
+ model.KernelType() = FastMKSModel::LINEAR_KERNEL;
+ model.BuildModel(referenceData, lk, single, naive, base);
}
else if (kernelType == "polynomial")
{
-
PolynomialKernel pk(degree, offset);
- RunFastMKS<PolynomialKernel>(referenceData, single, naive, base, k,
- indices, kernels, pk);
+ model.KernelType() = FastMKSModel::POLYNOMIAL_KERNEL;
+ model.BuildModel(referenceData, pk, single, naive, base);
}
else if (kernelType == "cosine")
{
CosineDistance cd;
- RunFastMKS<CosineDistance>(referenceData, single, naive, base, k, indices,
- kernels, cd);
+ model.KernelType() = FastMKSModel::COSINE_DISTANCE;
+ model.BuildModel(referenceData, cd, single, naive, base);
}
else if (kernelType == "gaussian")
{
GaussianKernel gk(bandwidth);
- RunFastMKS<GaussianKernel>(referenceData, single, naive, base, k, indices,
- kernels, gk);
+ model.KernelType() = FastMKSModel::GAUSSIAN_KERNEL;
+ model.BuildModel(referenceData, gk, single, naive, base);
}
else if (kernelType == "epanechnikov")
{
EpanechnikovKernel ek(bandwidth);
- RunFastMKS<EpanechnikovKernel>(referenceData, single, naive, base, k,
- indices, kernels, ek);
+ model.KernelType() = FastMKSModel::EPANECHNIKOV_KERNEL;
+ model.BuildModel(referenceData, ek, single, naive, base);
}
else if (kernelType == "triangular")
{
TriangularKernel tk(bandwidth);
- RunFastMKS<TriangularKernel>(referenceData, single, naive, base, k,
- indices, kernels, tk);
+ model.KernelType() = FastMKSModel::TRIANGULAR_KERNEL;
+ model.BuildModel(referenceData, tk, single, naive, base);
}
else if (kernelType == "hyptan")
{
HyperbolicTangentKernel htk(scale, offset);
- RunFastMKS<HyperbolicTangentKernel>(referenceData, single, naive, base, k,
- indices, kernels, htk);
+ model.KernelType() = FastMKSModel::HYPTAN_KERNEL;
+ model.BuildModel(referenceData, htk, single, naive, base);
}
}
else
{
- if (kernelType == "linear")
- {
- LinearKernel lk;
- RunFastMKS<LinearKernel>(referenceData, queryData, single, naive, base, k,
- indices, kernels, lk);
- }
- else if (kernelType == "polynomial")
- {
- PolynomialKernel pk(degree, offset);
- RunFastMKS<PolynomialKernel>(referenceData, queryData, single, naive,
- base, k, indices, kernels, pk);
- }
- else if (kernelType == "cosine")
- {
- CosineDistance cd;
- RunFastMKS<CosineDistance>(referenceData, queryData, single, naive, base,
- k, indices, kernels, cd);
- }
- else if (kernelType == "gaussian")
+ // Load model from file, then do whatever is necessary.
+ data::Load(CLI::GetParam<string>("input_model_file"), "fastmks_model",
+ model, true);
+ }
+
+ // Set search preferences.
+ model.Naive() = CLI::HasParam("naive");
+ model.SingleMode() = CLI::HasParam("single");
+
+ // Should we do search?
+ if (CLI::HasParam("k"))
+ {
+ arma::mat kernels;
+ arma::Mat<size_t> indices;
+
+ if (CLI::HasParam("query_file"))
{
- GaussianKernel gk(bandwidth);
- RunFastMKS<GaussianKernel>(referenceData, queryData, single, naive, base,
- k, indices, kernels, gk);
+ const string queryFile = CLI::GetParam<string>("query_file");
+ const double base = CLI::GetParam<double>("base");
+
+ arma::mat queryData;
+ data::Load(queryFile, queryData, true);
+
+ Log::Info << "Loaded query data from '" << queryFile << "' ("
+ << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+
+ model.Search(queryData, (size_t) CLI::GetParam<int>("k"), indices,
+ kernels, base);
}
- else if (kernelType == "epanechnikov")
+ else
{
- EpanechnikovKernel ek(bandwidth);
- RunFastMKS<EpanechnikovKernel>(referenceData, queryData, single, naive,
- base, k, indices, kernels, ek);
+ model.Search((size_t) CLI::GetParam<int>("k"), indices, kernels);
}
- else if (kernelType == "triangular")
+
+ // Save output, if we were asked to.
+ if (CLI::HasParam("kernels_file"))
{
- TriangularKernel tk(bandwidth);
- RunFastMKS<TriangularKernel>(referenceData, queryData, single, naive,
- base, k, indices, kernels, tk);
+ const string kernelsFile = CLI::GetParam<string>("kernels_file");
+ data::Save(kernelsFile, kernels, false);
}
- else if (kernelType == "hyptan")
+
+ if (CLI::HasParam("indices_file"))
{
- HyperbolicTangentKernel htk(scale, offset);
- RunFastMKS<HyperbolicTangentKernel>(referenceData, queryData, single,
- naive, base, k, indices, kernels, htk);
+ const string indicesFile = CLI::GetParam<string>("indices_file");
+ data::Save(indicesFile, indices, false);
}
}
- // Save output, if we were asked to.
- if (CLI::HasParam("kernels_file"))
- {
- const string kernelsFile = CLI::GetParam<string>("kernels_file");
- data::Save(kernelsFile, kernels, false);
- }
-
- if (CLI::HasParam("indices_file"))
- {
- const string indicesFile = CLI::GetParam<string>("indices_file");
- data::Save(indicesFile, indices, false);
- }
+ // Save the model, if requested.
+ if (CLI::HasParam("output_model_file"))
+ data::Save(CLI::GetParam<string>("output_model_file"), "fastmks_model",
+ model);
}
More information about the mlpack-git
mailing list