[mlpack-git] master: Refactor main program. (bb63250)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Dec 9 16:54:50 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/4a39d474593067343b4972d4a5217bcfae84ca5d...dd7c8b93fe5f299cb534cda70c1c786456f9a78f
>---------------------------------------------------------------
commit bb63250ea78839602b02d3a70ea41cd33849c5ee
Author: ryan <ryan at ratml.org>
Date: Wed Dec 9 16:54:11 2015 -0500
Refactor main program.
>---------------------------------------------------------------
bb63250ea78839602b02d3a70ea41cd33849c5ee
src/mlpack/methods/rann/allkrann_main.cpp | 313 +++++++++++++++++-------------
1 file changed, 174 insertions(+), 139 deletions(-)
diff --git a/src/mlpack/methods/rann/allkrann_main.cpp b/src/mlpack/methods/rann/allkrann_main.cpp
index c5add09..363aec1 100644
--- a/src/mlpack/methods/rann/allkrann_main.cpp
+++ b/src/mlpack/methods/rann/allkrann_main.cpp
@@ -14,6 +14,7 @@
#include <iostream>
#include "ra_search.hpp"
+#include "ra_model.hpp"
#include <mlpack/methods/neighbor_search/unmap.hpp>
using namespace std;
@@ -50,187 +51,221 @@ PROGRAM_INFO("All K-Rank-Approximate-Nearest-Neighbors",
"corresponds to the distance between those two points.");
// Define our input parameters that this program will take.
-PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
- "r");
+PARAM_STRING("reference_file", "File containing the reference dataset.",
+ "r", "");
PARAM_STRING("distances_file", "File to output distances into.", "d", "");
PARAM_STRING("neighbors_file", "File to output neighbors into.", "n", "");
-PARAM_INT_REQ("k", "Number of nearest neighbors to find.", "k");
-
-PARAM_STRING("query_file", "File containing query points (optional).",
- "q", "");
-
+// Options to load a pre-trained model or save the trained model.
+PARAM_STRING("input_model_file", "File containing pre-trained kNN model.", "m",
+ "");
+PARAM_STRING("output_model_file", "If specified, the kNN model will be saved "
+ "to the given file.", "M", "");
+
+// The user may specify a query file of query points and a number of nearest
+// neighbors to search for.
+PARAM_STRING("query_file", "File containing query points (optional).", "q", "");
+PARAM_INT("k", "Number of nearest neighbors to find.", "k", 0);
+
+// Tree type and tree-building options. -T and -e avoid clashing with the
+// existing -t (--tau) and -s (--single_mode) short options.
+PARAM_STRING("tree_type", "Type of tree to use: 'kd', 'cover', 'r', or "
+ "'r-star'.", "T", "kd");
+PARAM_INT("leaf_size", "Leaf size for tree building (used for kd-trees, R "
+ "trees, and R* trees).", "l", 20);
+PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
+ "random orthogonal basis.", "R");
+PARAM_INT("seed", "Random seed (if 0, std::time(NULL) is used).", "e", 0);
+
+// Search options.
PARAM_DOUBLE("tau", "The allowed rank-error in terms of the percentile of "
"the data.", "t", 5);
PARAM_DOUBLE("alpha", "The desired success probability.", "a", 0.95);
-
-PARAM_INT("leaf_size", "Leaf size for tree building.", "l", 20);
PARAM_FLAG("naive", "If true, sampling will be done without using a tree.",
"N");
PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
"dual-tree search.", "s");
-
PARAM_FLAG("sample_at_leaves", "The flag to trigger sampling at leaves.", "L");
PARAM_FLAG("first_leaf_exact", "The flag to trigger sampling only after "
"exactly exploring the first leaf.", "X");
PARAM_INT("single_sample_limit", "The limit on the maximum number of "
"samples (and hence the largest node you can approximate).", "S", 20);
+// Convenience typedef.
+typedef RAModel<NearestNeighborSort> RANNModel;
+
int main(int argc, char *argv[])
{
// Give CLI the command line parameters the user passed in.
CLI::ParseCommandLine(argc, argv);
- math::RandomSeed(time(NULL));
-
- // Get all the parameters.
- string referenceFile = CLI::GetParam<string>("reference_file");
- string distancesFile = CLI::GetParam<string>("distances_file");
- string neighborsFile = CLI::GetParam<string>("neighbors_file");
-
- int lsInt = CLI::GetParam<int>("leaf_size");
- size_t singleSampleLimit = CLI::GetParam<int>("single_sample_limit");
-
- size_t k = CLI::GetParam<int>("k");
-
- double tau = CLI::GetParam<double>("tau");
- double alpha = CLI::GetParam<double>("alpha");
-
- bool naive = CLI::HasParam("naive");
- bool singleMode = CLI::HasParam("single_mode");
- bool sampleAtLeaves = CLI::HasParam("sample_at_leaves");
- bool firstLeafExact = CLI::HasParam("first_leaf_exact");
-
- arma::mat referenceData;
- arma::mat queryData; // So it doesn't go out of scope.
- data::Load(referenceFile, referenceData, true);
-
- Log::Info << "Loaded reference data from '" << referenceFile << "' ("
- << referenceData.n_rows << " x " << referenceData.n_cols << ")." << endl;
-
- // Sanity check on k value: must be greater than 0, must be less than the
- // number of reference points.
- if (k > referenceData.n_cols)
- {
- Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
- Log::Fatal << "than or equal to the number of reference points (";
- Log::Fatal << referenceData.n_cols << ")." << endl;
- }
-
- // Load query data, if necessary.
- if (CLI::HasParam("query_file"))
+ if (CLI::GetParam<int>("seed") != 0)
+ math::RandomSeed((size_t) CLI::GetParam<int>("seed"));
+ else
+ math::RandomSeed((size_t) std::time(NULL));
+ // A user cannot specify both reference data and a model.
+ if (CLI::HasParam("reference_file") && CLI::HasParam("input_model_file"))
+ Log::Fatal << "Only one of --reference_file (-r) or --input_model_file (-m)"
+ << " may be specified!" << endl;
+
+ // A user must specify one of them...
+ if (!CLI::HasParam("reference_file") && !CLI::HasParam("input_model_file"))
+ Log::Fatal << "No model specified (--input_model_file) and no reference "
+ << "data specified (--reference_file)! One must be provided." << endl;
+
+ if (CLI::HasParam("input_model_file"))
{
- const string queryFile = CLI::GetParam<string>("query_file");
- data::Load(queryFile, queryData, true);
- Log::Info << "Loaded query data from '" << queryFile << "' ("
- << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+ // Notify the user of parameters that will be ignored.
+ if (CLI::HasParam("tree_type"))
+ Log::Warn << "--tree_type (-T) will be ignored because --input_model_file"
+ << " is specified." << endl;
+ if (CLI::HasParam("leaf_size"))
+ Log::Warn << "--leaf_size (-l) will be ignored because --input_model_file"
+ << " is specified." << endl;
+ if (CLI::HasParam("random_basis"))
+ Log::Warn << "--random_basis (-R) will be ignored because "
+ << "--input_model_file is specified." << endl;
+ if (CLI::HasParam("naive"))
+ Log::Warn << "--naive (-N) will be ignored because --input_model_file is "
+ << "specified." << endl;
}
- // Sanity check on the value of 'tau' with respect to 'k' so that
- // 'k' neighbors are not requested from the top-'rank_error' neighbors
- // where 'rank_error' <= 'k'.
- size_t rank_error = (size_t) ceil(tau *
- (double) referenceData.n_cols / 100.0);
- if (rank_error <= k)
- Log::Fatal << "Invalid 'tau' (" << tau << ") - k (" << k << ") " <<
- "combination. Increase 'tau' or decrease 'k'." << endl;
+ // The user should give something to do...
+ if (!CLI::HasParam("k") && !CLI::HasParam("output_model_file"))
+ Log::Warn << "Neither -k nor --output_model_file are specified, so no "
+ << "results from this program will be saved!" << endl;
+
+ // If the user specifies k but no output files, they should be warned.
+ if (CLI::HasParam("k") &&
+ !(CLI::HasParam("neighbors_file") || CLI::HasParam("distances_file")))
+ Log::Warn << "Neither --neighbors_file nor --distances_file is specified, "
+ << "so the nearest neighbor search results will not be saved!" << endl;
+
+ // If the user specifies output files but no k, they should be warned.
+ if ((CLI::HasParam("neighbors_file") || CLI::HasParam("distances_file")) &&
+ !CLI::HasParam("k"))
+ Log::Warn << "An output file for nearest neighbor search is given ("
+ << "--neighbors_file or --distances_file), but nearest neighbor search "
+ << "is not being performed because k (--k) is not specified! No "
+ << "results will be saved." << endl;
// Sanity check on leaf size.
- if (lsInt < 0)
+ const int lsInt = CLI::GetParam<int>("leaf_size");
+ if (lsInt < 1)
+ {
Log::Fatal << "Invalid leaf size: " << lsInt << ". Must be greater "
- "than or equal to 0." << endl;
- size_t leafSize = lsInt;
-
- // Naive mode overrides single mode.
- if (singleMode && naive)
- Log::Warn << "--single_mode ignored because --naive is present." << endl;
-
- // The actual output after the remapping.
- arma::Mat<size_t> neighbors;
- arma::mat distances;
+ "than 0." << endl;
+ }
- if (naive)
+ // We either have to load the reference data, or we have to load the model.
+ RANNModel rann;
+ const bool naive = CLI::HasParam("naive");
+ const bool singleMode = CLI::HasParam("single_mode");
+ if (CLI::HasParam("reference_file"))
{
- AllkRANN allkrann(referenceData, naive, false, tau, alpha);
+ // Get all the parameters.
+ const string referenceFile = CLI::GetParam<string>("reference_file");
+ const string treeType = CLI::GetParam<string>("tree_type");
+ const bool randomBasis = CLI::HasParam("random_basis");
+
+ int tree = 0;
+ if (treeType == "kd")
+ tree = RANNModel::KD_TREE;
+ else if (treeType == "cover")
+ tree = RANNModel::COVER_TREE;
+ else if (treeType == "r")
+ tree = RANNModel::R_TREE;
+ else if (treeType == "r-star")
+ tree = RANNModel::R_STAR_TREE;
+ else
+ Log::Fatal << "Unknown tree type '" << treeType << "'; valid choices are "
+ << "'kd', 'cover', 'r', and 'r-star'." << endl;
- Log::Info << "Computing " << k << " nearest neighbors " << "with "
- << tau << "% rank approximation..." << endl;
+ rann.TreeType() = tree;
+ rann.RandomBasis() = randomBasis;
- if (CLI::GetParam<string>("query_file") != "")
- allkrann.Search(queryData, k, neighbors, distances);
- else
- allkrann.Search(k, neighbors, distances);
+ arma::mat referenceSet;
+ data::Load(referenceFile, referenceSet, true);
- Log::Info << "Neighbors computed." << endl;
+ Log::Info << "Loaded reference data from '" << referenceFile << "' ("
+ << referenceSet.n_rows << " x " << referenceSet.n_cols << ")."
+ << endl;
+
+ rann.BuildModel(std::move(referenceSet), size_t(lsInt), naive, singleMode);
}
else
{
- // The results output by the AllkRANN class are
- // shuffled if the tree construction shuffles the point sets.
- arma::Mat<size_t> neighborsOut;
- arma::mat distancesOut;
-
- // Mappings for when we build the tree.
- std::vector<size_t> oldFromNewRefs;
- std::vector<size_t> oldFromNewQueries;
-
- // Build trees by hand, so we can save memory: if we pass a tree to
- // NeighborSearch, it does not copy the matrix.
- Log::Info << "Building reference tree..." << endl;
- Timer::Start("tree_building");
- typedef KDTree<EuclideanDistance, RAQueryStat<NearestNeighborSort>,
- arma::mat> TreeType;
- TreeType refTree(referenceData, oldFromNewRefs, leafSize);
- Timer::Stop("tree_building");
-
- // Because we may construct it differently, we need a pointer.
- AllkRANN allkrann(&refTree, singleMode, tau, alpha, sampleAtLeaves,
- firstLeafExact, singleSampleLimit);
-
- if (CLI::HasParam("query_file") && !singleMode)
+ // Load the model from file.
+ const string inputModelFile = CLI::GetParam<string>("input_model_file");
+ data::Load(inputModelFile, "rann_model", rann, true); // Fatal on failure.
+
+ Log::Info << "Loaded rank-approximate kNN model from '" << inputModelFile
+ << "' (trained on " << rann.Dataset().n_rows << "x"
+ << rann.Dataset().n_cols << " dataset)." << endl;
+
+ // NOTE(review): --naive and --leaf_size were warned as ignored, yet are set.
+ rann.SingleMode() = CLI::HasParam("single_mode");
+ rann.Naive() = CLI::HasParam("naive");
+ rann.LeafSize() = size_t(lsInt);
+ }
+
+ // Apply the parameters for search.
+ if (CLI::HasParam("tau"))
+ rann.Tau() = CLI::GetParam<double>("tau");
+ if (CLI::HasParam("alpha"))
+ rann.Alpha() = CLI::GetParam<double>("alpha");
+ if (CLI::HasParam("single_sample_limit"))
+ rann.SingleSampleLimit() = CLI::GetParam<int>("single_sample_limit");
+ rann.SampleAtLeaves() = CLI::HasParam("sample_at_leaves");
+ rann.FirstLeafExact() = CLI::HasParam("first_leaf_exact");
+
+ // Perform search, if desired.
+ if (CLI::HasParam("k"))
+ {
+ const string queryFile = CLI::GetParam<string>("query_file");
+ const size_t k = (size_t) CLI::GetParam<int>("k");
+
+ arma::mat queryData;
+ if (queryFile != "")
{
- Log::Info << "Building query tree..." << endl;
- Timer::Start("tree_building");
- TreeType queryTree(queryData, oldFromNewQueries, leafSize);
- Timer::Stop("tree_building");
- Log::Info << "Tree built." << endl;
-
- Log::Info << "Computing " << k << " nearest neighbors " << "with " <<
- tau << "% rank approximation..." << endl;
- allkrann.Search(&queryTree, k, neighborsOut, distancesOut);
+ data::Load(queryFile, queryData, true);
+ Log::Info << "Loaded query data from '" << queryFile << "' ("
+ << queryData.n_rows << "x" << queryData.n_cols << ")." << endl;
}
- else if (CLI::HasParam("query_file") && singleMode)
+
+ // Sanity check on k value: must be greater than 0 and no larger than the
+ // number of reference points. A negative --k wraps around to a huge
+ // size_t when cast, so the upper-bound test rejects it too.
+ if (k == 0 || k > rann.Dataset().n_cols)
{
- Log::Info << "Computing " << k << " nearest neighbors " << "with " <<
- tau << "% rank approximation..." << endl;
- allkrann.Search(queryData, k, neighborsOut, distancesOut);
+ Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
+ Log::Fatal << "than or equal to the number of reference points (";
+ Log::Fatal << rann.Dataset().n_cols << ")." << endl;
}
- else
+
+ // Naive mode overrides single mode.
+ if (singleMode && naive)
{
- Log::Info << "Computing " << k << " nearest neighbors " << "with " <<
- tau << "% rank approximation..." << endl;
- allkrann.Search(k, neighborsOut, distancesOut);
+ Log::Warn << "--single_mode ignored because --naive is present." << endl;
}
- Log::Info << "Neighbors computed." << endl;
-
- // We have to map back to the original indices from before the tree
- // construction.
- Log::Info << "Re-mapping indices..." << endl;
-
- // Map the results back to the correct places.
- if ((CLI::GetParam<string>("query_file") != "") && !singleMode)
- Unmap(neighborsOut, distancesOut, oldFromNewRefs, oldFromNewQueries,
- neighbors, distances);
- else if ((CLI::GetParam<string>("query_file") != "") && singleMode)
- Unmap(neighborsOut, distancesOut, oldFromNewRefs, neighbors, distances);
+ arma::Mat<size_t> neighbors;
+ arma::mat distances;
+ if (queryFile != "")
+ rann.Search(std::move(queryData), k, neighbors, distances);
else
- Unmap(neighborsOut, distancesOut, oldFromNewRefs, oldFromNewRefs,
- neighbors, distances);
+ rann.Search(k, neighbors, distances);
+ Log::Info << "Search complete." << endl;
+
+ // Save output, if desired.
+ if (CLI::HasParam("neighbors_file"))
+ data::Save(CLI::GetParam<string>("neighbors_file"), neighbors);
+ if (CLI::HasParam("distances_file"))
+ data::Save(CLI::GetParam<string>("distances_file"), distances);
}
- // Save output.
- if (distancesFile != "")
- data::Save(distancesFile, distances);
- if (neighborsFile != "")
- data::Save(neighborsFile, neighbors);
+ if (CLI::HasParam("output_model_file"))
+ {
+ const string outputModelFile = CLI::GetParam<string>("output_model_file");
+ data::Save(outputModelFile, "rann_model", rann);
+ }
}
More information about the mlpack-git
mailing list