[mlpack-git] master: Refactor range search program. Allow model saving/loading. (a9d64fc)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Nov 5 12:08:32 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/c3b6a9b6cf3907e737f19b544f339690f2098ace...9bd2063f96de9430b387974e7ce7204a1e57a803

>---------------------------------------------------------------

commit a9d64fcd5e78891cf795bf5252a03eed738a33be
Author: ryan <ryan at ratml.org>
Date:   Thu Nov 5 12:07:53 2015 -0500

    Refactor range search program.  Allow model saving/loading.


>---------------------------------------------------------------

a9d64fcd5e78891cf795bf5252a03eed738a33be
 .../methods/range_search/range_search_main.cpp     | 388 ++++++++++-----------
 1 file changed, 189 insertions(+), 199 deletions(-)

diff --git a/src/mlpack/methods/range_search/range_search_main.cpp b/src/mlpack/methods/range_search/range_search_main.cpp
index 783132c..f012ce8 100644
--- a/src/mlpack/methods/range_search/range_search_main.cpp
+++ b/src/mlpack/methods/range_search/range_search_main.cpp
@@ -11,6 +11,7 @@
 #include <mlpack/core/tree/cover_tree.hpp>
 
 #include "range_search.hpp"
+#include "rs_model.hpp"
 
 using namespace std;
 using namespace mlpack;
@@ -49,22 +50,36 @@ PROGRAM_INFO("Range Search",
     "regardless of the given extension.");
 
 // Define our input parameters that this program will take.
-PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
-    "r");
-PARAM_STRING_REQ("distances_file", "File to output distances into.", "d");
-PARAM_STRING_REQ("neighbors_file", "File to output neighbors into.", "n");
-
-PARAM_DOUBLE_REQ("max", "Upper bound in range.", "M");
-PARAM_DOUBLE("min", "Lower bound in range.", "m", 0.0);
-
+PARAM_STRING("reference_file", "File containing the reference dataset.", "r",
+    "");
+PARAM_STRING("distances_file", "File to output distances into.", "d", "");
+PARAM_STRING("neighbors_file", "File to output neighbors into.", "n", "");
+
+// The user may load a pre-trained model and/or save the trained model.
+PARAM_STRING("input_model_file", "File containing pre-trained range search "
+    "model.", "m", "");
+PARAM_STRING("output_model_file", "If specified, the range search model will be"
+    " saved to the given file.", "M", "");
+
+// The user may specify a query file of query points and a range to search for.
 PARAM_STRING("query_file", "File containing query points (optional).", "q", "");
-
+PARAM_DOUBLE("max", "Upper bound in range (if not specified, +inf will be "
+    "used).", "U", 0.0);
+PARAM_DOUBLE("min", "Lower bound in range.", "L", 0.0);
+
+// The user may specify the type of tree to use, and a few parameters for tree
+// building.
+PARAM_STRING("tree_type", "Type of tree to use: 'kd', 'cover', 'r', 'r-star', "
+    "'ball'.", "t", "kd");
 PARAM_INT("leaf_size", "Leaf size for tree building.", "l", 20);
+PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
+    "random orthogonal basis.", "R");
+PARAM_INT("seed", "Random seed (if 0, std::time(NULL) is used).", "S", 0);
+
+// Search settings.
 PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "N");
 PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
     "dual-tree search).", "s");
-PARAM_FLAG("cover_tree", "If true, use a cover tree for range searching "
-    "(instead of a kd-tree).", "c");
 
 typedef RangeSearch<> RSType;
 typedef CoverTree<EuclideanDistance, RangeSearchStat> CoverTreeType;
@@ -76,241 +91,216 @@ int main(int argc, char *argv[])
   // Give CLI the command line parameters the user passed in.
   CLI::ParseCommandLine(argc, argv);
 
-  // Get all the parameters.
-  string referenceFile = CLI::GetParam<string>("reference_file");
+  if (CLI::GetParam<int>("seed") != 0)
+    math::RandomSeed((size_t) CLI::GetParam<int>("seed"));
+  else
+    math::RandomSeed((size_t) std::time(NULL));
+
+  // A user cannot specify both reference data and a model.
+  if (CLI::HasParam("reference_file") && CLI::HasParam("input_model_file"))
+    Log::Fatal << "Only one of --reference_file (-r) or --input_model_file (-m)"
+        << " may be specified!" << endl;
 
-  string distancesFile = CLI::GetParam<string>("distances_file");
-  string neighborsFile = CLI::GetParam<string>("neighbors_file");
+  // A user must specify one of them...
+  if (!CLI::HasParam("reference_file") && !CLI::HasParam("input_model_file"))
+    Log::Fatal << "No model specified (--input_model_file) and no reference "
+        << "data specified (--reference_file)!  One must be provided." << endl;
 
-  int lsInt = CLI::GetParam<int>("leaf_size");
+  if (CLI::HasParam("input_model_file"))
+  {
+    // Notify the user of parameters that will be ignored.
+    if (CLI::HasParam("tree_type"))
+      Log::Warn << "--tree_type (-t) will be ignored because --input_model_file"
+          << " is specified." << endl;
+    if (CLI::HasParam("leaf_size"))
+      Log::Warn << "--leaf_size (-l) will be ignored because --input_model_file"
+          << " is specified." << endl;
+    if (CLI::HasParam("random_basis"))
+      Log::Warn << "--random_basis (-R) will be ignored because "
+          << "--input_model_file is specified." << endl;
+    if (CLI::HasParam("naive"))
+      Log::Warn << "--naive (-N) will be ignored because --input_model_file is "
+          << "specified." << endl;
+  }
+
+  // The user must give something to do...
+  if (!CLI::HasParam("min") && !CLI::HasParam("max") &&
+      !CLI::HasParam("output_model_file"))
+    Log::Warn << "Neither --min, --max, nor --output_model_file are specified, "
+        << "so no results from this program will be saved!" << endl;
+
+  // If the user specifies a range but not output files, they should be warned.
+  if ((CLI::HasParam("min") || CLI::HasParam("max")) &&
+      !(CLI::HasParam("neighbors_file") || CLI::HasParam("distances_file")))
+    Log::Warn << "Neither --neighbors_file nor --distances_file is specified, "
+        << "so the range search results will not be saved!" << endl;
+
+  // If the user specifies output files but no range, they should be warned.
+  if ((CLI::HasParam("neighbors_file") || CLI::HasParam("distances_file")) &&
+      !(CLI::HasParam("min") || CLI::HasParam("max")))
+    Log::Warn << "An output file for range search is given (--neighbors_file "
+        << "or --distances_file), but range search is not being performed "
+        << "because neither --min nor --max are specified!  No results will be "
+        << "saved." << endl;
 
-  double max = CLI::GetParam<double>("max");
-  double min = CLI::GetParam<double>("min");
+  // Sanity check on leaf size.
+  int lsInt = CLI::GetParam<int>("leaf_size");
+  if (lsInt < 1)
+    Log::Fatal << "Invalid leaf size: " << lsInt << ".  Must be greater than 0."
+        << endl;
 
+  // We either have to load the reference data, or we have to load the model.
+  RSModel rs;
   const bool naive = CLI::HasParam("naive");
   const bool singleMode = CLI::HasParam("single_mode");
-  bool coverTree = CLI::HasParam("cover_tree");
+  if (CLI::HasParam("reference_file"))
+  {
+    // Get all the parameters.
+    const string referenceFile = CLI::GetParam<string>("reference_file");
+    const string treeType = CLI::GetParam<string>("tree_type");
+    const bool randomBasis = CLI::HasParam("random_basis");
+
+    int tree = 0;
+    if (treeType == "kd")
+      tree = RSModel::KD_TREE;
+    else if (treeType == "cover")
+      tree = RSModel::COVER_TREE;
+    else if (treeType == "r")
+      tree = RSModel::R_TREE;
+    else if (treeType == "r-star")
+      tree = RSModel::R_STAR_TREE;
+    else if (treeType == "ball")
+      tree = RSModel::BALL_TREE;
+    else
+      Log::Fatal << "Unknown tree type '" << treeType << "'; valid choices are "
+          << "'kd', 'cover', 'r', 'r-star', and 'ball'." << endl;
 
-  arma::mat referenceData;
-  arma::mat queryData; // So it doesn't go out of scope.
-  if (!data::Load(referenceFile, referenceData))
-    Log::Fatal << "Reference file " << referenceFile << "not found." << endl;
+    rs.TreeType() = tree;
+    rs.RandomBasis() = randomBasis;
 
-  Log::Info << "Loaded reference data from '" << referenceFile << "'." << endl;
+    arma::mat referenceSet;
+    data::Load(referenceFile, referenceSet, true);
 
-  // Sanity check on range value: max must be greater than min.
-  if (max <= min)
-  {
-    Log::Fatal << "Invalid range: maximum (" << max << ") must be greater than "
-        << "minimum (" << min << ")." << endl;
-  }
-  const math::Range r(min, max);
+    Log::Info << "Loaded reference data from '" << referenceFile << "' ("
+        << referenceSet.n_rows << "x" << referenceSet.n_cols << ")." << endl;
 
-  // Sanity check on leaf size.
-  if (lsInt < 0)
-  {
-    Log::Fatal << "Invalid leaf size: " << lsInt << ".  Must be greater "
-        "than or equal to 0." << endl;
-  }
-  const size_t leafSize = lsInt;
+    const size_t leafSize = size_t(lsInt);
 
-  // Naive mode overrides single mode.
-  if (singleMode && naive)
-  {
-    Log::Warn << "--single_mode ignored because --naive is present." << endl;
+    rs.BuildModel(std::move(referenceSet), leafSize, naive, singleMode);
   }
-
-  if (coverTree && naive)
+  else
   {
-    Log::Warn << "--cover_tree ignored because --naive is present." << endl;
-    coverTree = false;
+    // Load the model from file.
+    const string inputModelFile = CLI::GetParam<string>("input_model_file");
+    data::Load(inputModelFile, "rs_model", rs, true); // Fatal on failure.
+
+    Log::Info << "Loaded range search model from '" << inputModelFile << "' ("
+        << "trained on " << rs.Dataset().n_rows << "x" << rs.Dataset().n_cols
+        << " dataset)." << endl;
+
+    // Adjust singleMode and naive if necessary.
+    rs.SingleMode() = CLI::HasParam("single_mode");
+    rs.Naive() = CLI::HasParam("naive");
+    rs.LeafSize() = size_t(lsInt);
   }
 
-  vector<vector<size_t> > neighbors;
-  vector<vector<double> > distances;
-
-  // The cover tree implies different types, so we must split this section.
-  if (naive)
-  {
-    Log::Info << "Performing naive search (no trees)." << endl;
-
-    // Trees don't matter.
-    RangeSearch<> rangeSearch(referenceData, singleMode, naive);
-    rangeSearch.Search(queryData, r, neighbors, distances);
-  }
-  else if (coverTree)
+  // Perform search, if desired.
+  if (CLI::HasParam("min") || CLI::HasParam("max"))
   {
-    Log::Info << "Using cover trees." << endl;
+    const string queryFile = CLI::GetParam<string>("query_file");
+    const double min = CLI::GetParam<double>("min");
+    const double max = CLI::HasParam("max") ? CLI::GetParam<double>("max") :
+        DBL_MAX;
 
-    // This is significantly simpler than kd-tree construction because the data
-    // matrix is not modified.
-    RSCoverType rangeSearch(referenceData, singleMode);
+    math::Range r(min, max);
 
-    if (CLI::GetParam<string>("query_file") == "")
+    arma::mat queryData;
+    if (queryFile != "")
     {
-      // Single dataset.
-      rangeSearch.Search(r, neighbors, distances);
-    }
-    else
-    {
-      // Two datasets.
-      const string queryFile = CLI::GetParam<string>("query_file");
       data::Load(queryFile, queryData, true);
-
-      // Query tree is automatically built if needed.
-      rangeSearch.Search(queryData, r, neighbors, distances);
+      Log::Info << "Loaded query data from '" << queryFile << "' ("
+          << queryData.n_rows << "x" << queryData.n_cols << ")." << endl;
     }
-  }
-  else
-  {
-    typedef KDTree<EuclideanDistance, RangeSearchStat, arma::mat> TreeType;
 
-    // Track mappings.
-    Log::Info << "Building reference tree..." << endl;
-    Timer::Start("tree_building");
-    vector<size_t> oldFromNewRefs;
-    vector<size_t> oldFromNewQueries; // Not used yet.
-    TreeType refTree(referenceData, oldFromNewRefs, leafSize);
-    Timer::Stop("tree_building");
+    // Naive mode overrides single mode.
+    if (singleMode && naive)
+      Log::Warn << "--single_mode ignored because --naive is present." << endl;
 
-    // Collect the results in these vectors before remapping.
-    vector<vector<double> > distancesOut;
-    vector<vector<size_t> > neighborsOut;
+    // Now run the search.
+    vector<vector<size_t>> neighbors;
+    vector<vector<double>> distances;
 
-    RSType rangeSearch(&refTree, singleMode);
-
-    if (CLI::GetParam<string>("query_file") != "")
-    {
-      const string queryFile = CLI::GetParam<string>("query_file");
-      data::Load(queryFile, queryData, true);
+    if (queryFile != "")
+      rs.Search(std::move(queryData), r, neighbors, distances);
+    else
+      rs.Search(r, neighbors, distances);
 
-      Log::Info << "Loaded query data from '" << queryFile << "'." << endl;
+    Log::Info << "Search complete." << endl;
 
-      if (singleMode)
+    // Save output, if desired.  We have to do this by hand.
+    if (CLI::HasParam("distances_file"))
+    {
+      const string distancesFile = CLI::GetParam<string>("distances_file");
+      fstream distancesStr(distancesFile.c_str(), fstream::out);
+      if (!distancesStr.is_open())
       {
-        Log::Info << "Computing neighbors within range [" << min << ", " << max
-            << "]." << endl;
-        rangeSearch.Search(queryData, r, neighborsOut, distancesOut);
+        Log::Warn << "Cannot open file '" << distancesFile << "' to save output"
+            << " distances to!" << endl;
       }
       else
       {
-        Log::Info << "Building query tree..." << endl;
+        // Loop over each point.
+        for (size_t i = 0; i < distances.size(); ++i)
+        {
+          // Store the distances of each point.  We may have 0 points to store,
+          // so we must account for that possibility.
+          for (size_t j = 0; j + 1 < distances[i].size(); ++j)
+            distancesStr << distances[i][j] << ", ";
 
-        // Build trees by hand, so we can save memory: if we pass a tree to
-        // NeighborSearch, it does not copy the matrix.
-        Timer::Start("tree_building");
-        TreeType queryTree(queryData, oldFromNewQueries, leafSize);
-        Timer::Stop("tree_building");
+          if (distances[i].size() > 0)
+            distancesStr << distances[i][distances[i].size() - 1];
 
-        Log::Info << "Tree built." << endl;
+          distancesStr << endl;
+        }
 
-        Log::Info << "Computing neighbors within range [" << min << ", " << max
-            << "]." << endl;
-        rangeSearch.Search(&queryTree, r, neighborsOut, distancesOut);
+        distancesStr.close();
       }
     }
-    else
-    {
-      Log::Info << "Computing neighbors within range [" << min << ", " << max
-          << "]." << endl;
-      rangeSearch.Search(r, neighborsOut, distancesOut);
-    }
-
-    Log::Info << "Neighbors computed." << endl;
 
-    // We have to map back to the original indices from before the tree
-    // construction.
-    Log::Info << "Re-mapping indices..." << endl;
-
-    distances.resize(distancesOut.size());
-    neighbors.resize(neighborsOut.size());
-
-    // Do the actual remapping.
-    if (CLI::GetParam<string>("query_file") != "")
+    if (CLI::HasParam("neighbors_file"))
     {
-      for (size_t i = 0; i < distances.size(); ++i)
+      const string neighborsFile = CLI::GetParam<string>("neighbors_file");
+      fstream neighborsStr(neighborsFile.c_str(), fstream::out);
+      if (!neighborsStr.is_open())
       {
-        // Map distances (copy a column).
-        distances[oldFromNewQueries[i]] = distancesOut[i];
-
-        // Map indices of neighbors.
-        neighbors[oldFromNewQueries[i]].resize(neighborsOut[i].size());
-        for (size_t j = 0; j < distancesOut[i].size(); ++j)
-        {
-          neighbors[oldFromNewQueries[i]][j] =
-              oldFromNewRefs[neighborsOut[i][j]];
-        }
+        Log::Warn << "Cannot open file '" << neighborsFile << "' to save output"
+            << " neighbor indices to!" << endl;
       }
-    }
-    else
-    {
-      for (size_t i = 0; i < distances.size(); ++i)
+      else
       {
-        // Map distances (copy a column).
-        distances[oldFromNewRefs[i]] = distancesOut[i];
-
-        // Map indices of neighbors.
-        neighbors[oldFromNewRefs[i]].resize(neighborsOut[i].size());
-        for (size_t j = 0; j < distancesOut[i].size(); ++j)
+        // Loop over each point.
+        for (size_t i = 0; i < neighbors.size(); ++i)
         {
-          neighbors[oldFromNewRefs[i]][j] = oldFromNewRefs[neighborsOut[i][j]];
-        }
-      }
-    }
-  }
+          // Store the neighbors of each point.  We may have 0 points to store,
+          // so we must account for that possibility.
+          for (size_t j = 0; j + 1 < neighbors[i].size(); ++j)
+            neighborsStr << neighbors[i][j] << ", ";
 
-  // Save output.  We have to do this by hand.
-  fstream distancesStr(distancesFile.c_str(), fstream::out);
-  if (!distancesStr.is_open())
-  {
-    Log::Warn << "Cannot open file '" << distancesFile << "' to save output "
-        << "distances to!" << endl;
-  }
-  else
-  {
-    // Loop over each point.
-    for (size_t i = 0; i < distances.size(); ++i)
-    {
-      // Store the distances of each point.  We may have 0 points to store, so
-      // we must account for that possibility.
-      for (size_t j = 0; j + 1 < distances[i].size(); ++j)
-      {
-        distancesStr << distances[i][j] << ", ";
-      }
+          if (neighbors[i].size() > 0)
+            neighborsStr << neighbors[i][neighbors[i].size() - 1];
 
-      if (distances[i].size() > 0)
-        distancesStr << distances[i][distances[i].size() - 1];
+          neighborsStr << endl;
+        }
 
-      distancesStr << endl;
+        neighborsStr.close();
+      }
     }
-
-    distancesStr.close();
   }
 
-  fstream neighborsStr(neighborsFile.c_str(), fstream::out);
-  if (!neighborsStr.is_open())
+  // Save the output model, if desired.
+  if (CLI::HasParam("output_model_file"))
   {
-    Log::Warn << "Cannot open file '" << neighborsFile << "' to save output "
-        << "neighbor indices to!" << endl;
-  }
-  else
-  {
-    // Loop over each point.
-    for (size_t i = 0; i < neighbors.size(); ++i)
-    {
-      // Store the neighbors of each point.  We may have 0 points to store, so
-      // we must account for that possibility.
-      for (size_t j = 0; j + 1 < neighbors[i].size(); ++j)
-      {
-        neighborsStr << neighbors[i][j] << ", ";
-      }
-
-      if (neighbors[i].size() > 0)
-        neighborsStr << neighbors[i][neighbors[i].size() - 1];
-
-      neighborsStr << endl;
-    }
-
-    neighborsStr.close();
+    const string outputModelFile = CLI::GetParam<string>("output_model_file");
+    data::Save(outputModelFile, "rs_model", rs, true); // Fatal on failure.
   }
 }



More information about the mlpack-git mailing list