[mlpack-git] master: Complete refactoring of allknn program. Allows saving models and selection of different tree types. (fecf119)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Mon Oct 19 16:04:56 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/09cd0d67f2fdae252a8ab85324e71dbb4dfe0010...fecf1194c123ced12d56e7daad761c7b9aaac262

>---------------------------------------------------------------

commit fecf1194c123ced12d56e7daad761c7b9aaac262
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Oct 19 16:02:01 2015 -0400

    Complete refactoring of allknn program.
    Allows saving models and selection of different tree types.
    
    My plan is to deploy this idea to all of mlpack's dual-tree algorithms; I'll
    also add ball trees.


>---------------------------------------------------------------

fecf1194c123ced12d56e7daad761c7b9aaac262
 src/mlpack/methods/neighbor_search/allknn_main.cpp | 421 ++++++++-------------
 1 file changed, 161 insertions(+), 260 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/allknn_main.cpp b/src/mlpack/methods/neighbor_search/allknn_main.cpp
index c7c2e9c..f6d0413 100644
--- a/src/mlpack/methods/neighbor_search/allknn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/allknn_main.cpp
@@ -14,6 +14,7 @@
 
 #include "neighbor_search.hpp"
 #include "unmap.hpp"
+#include "ns_model.hpp"
 
 using namespace std;
 using namespace mlpack;
@@ -22,8 +23,8 @@ using namespace mlpack::tree;
 using namespace mlpack::metric;
 
 // Information about the program itself.
-PROGRAM_INFO("All K-Nearest-Neighbors",
-    "This program will calculate the all k-nearest-neighbors of a set of "
+PROGRAM_INFO("k-Nearest-Neighbors",
+    "This program will calculate the k-nearest-neighbors of a set of "
     "points using kd-trees or cover trees (cover tree support is experimental "
     "and may be slow). You may specify a separate set of "
     "reference points and query points, or just a reference set which will be "
@@ -43,26 +44,39 @@ PROGRAM_INFO("All K-Nearest-Neighbors",
     "corresponds to the distance between those two points.");
 
 // Define our input parameters that this program will take.
-PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
-    "r");
-PARAM_STRING_REQ("distances_file", "File to output distances into.", "d");
-PARAM_STRING_REQ("neighbors_file", "File to output neighbors into.", "n");
-
-PARAM_INT_REQ("k", "Number of nearest neighbors to find.", "k");
-
+PARAM_STRING("reference_file", "File containing the reference dataset.", "r",
+    "");
+PARAM_STRING("distances_file", "File to output distances into.", "d", "");
+PARAM_STRING("neighbors_file", "File to output neighbors into.", "n", "");
+
+// The option exists to load or save models.
+PARAM_STRING("input_model_file", "File containing pre-trained kNN model.", "m",
+    "");
+PARAM_STRING("output_model_file", "If specified, the kNN model will be saved "
+    "to the given file.", "M", "");
+
+// The user may specify a query file of query points and a number of nearest
+// neighbors to search for.
 PARAM_STRING("query_file", "File containing query points (optional).", "q", "");
+PARAM_INT("k", "Number of nearest neighbors to find.", "k", 0);
+
+// The user may specify the type of tree to use, and a few parameters for tree
+// building.
+PARAM_STRING("tree_type", "Type of tree to use: 'kd', 'cover', 'r', 'r-star'.",
+    "t", "kd");
+PARAM_INT("leaf_size", "Leaf size for tree building (used for kd-trees, R "
+    "trees, and R* trees).", "l", 20);
+PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
+    "random orthogonal basis.", "R");
+PARAM_INT("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0);
 
-PARAM_INT("leaf_size", "Leaf size for tree building.", "l", 20);
+// Search settings.
 PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "N");
 PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
     "dual-tree search).", "S");
-PARAM_FLAG("cover_tree", "If true, use cover trees to perform the search "
-    "(experimental, may be slow).", "c");
-PARAM_FLAG("r_tree", "If true, use an R*-Tree to perform the search "
-    "(experimental, may be slow.).", "T");
-PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
-    "random orthogonal basis.", "R");
-PARAM_INT("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0);
+
+// Convenience typedef.
+typedef NSModel<NearestNeighborSort> KNNModel;
 
 int main(int argc, char *argv[])
 {
@@ -74,279 +88,166 @@ int main(int argc, char *argv[])
   else
     math::RandomSeed((size_t) std::time(NULL));
 
-  // Get all the parameters.
-  const string referenceFile = CLI::GetParam<string>("reference_file");
-  const string queryFile = CLI::GetParam<string>("query_file");
-
-  const string distancesFile = CLI::GetParam<string>("distances_file");
-  const string neighborsFile = CLI::GetParam<string>("neighbors_file");
-
-  int lsInt = CLI::GetParam<int>("leaf_size");
-
-  size_t k = CLI::GetParam<int>("k");
+  // A user cannot specify both reference data and a model.
+  if (CLI::HasParam("reference_file") && CLI::HasParam("input_model_file"))
+    Log::Fatal << "Only one of --reference_file (-r) or --input_model_file (-m)"
+        << " may be specified!" << endl;
 
-  bool naive = CLI::HasParam("naive");
-  bool singleMode = CLI::HasParam("single_mode");
-  const bool randomBasis = CLI::HasParam("random_basis");
+  // A user must specify one of them...
+  if (!CLI::HasParam("reference_file") && !CLI::HasParam("input_model_file"))
+    Log::Fatal << "No model specified (--input_model_file) and no reference "
+        << "data specified (--reference_file)!  One must be provided." << endl;
 
-  arma::mat referenceData;
-  arma::mat queryData; // So it doesn't go out of scope.
-  data::Load(referenceFile, referenceData, true);
-
-  Log::Info << "Loaded reference data from '" << referenceFile << "' ("
-      << referenceData.n_rows << " x " << referenceData.n_cols << ")." << endl;
-
-  if (queryFile != "")
+  if (CLI::HasParam("input_model_file"))
   {
-    data::Load(queryFile, queryData, true);
-    Log::Info << "Loaded query data from '" << queryFile << "' ("
-      << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+    // Notify the user of parameters that will be ignored.
+    if (CLI::HasParam("tree_type"))
+      Log::Warn << "--tree_type (-t) will be ignored because --input_model_file"
+          << " is specified." << endl;
+    if (CLI::HasParam("leaf_size"))
+      Log::Warn << "--leaf_size (-l) will be ignored because --input_model_file"
+          << " is specified." << endl;
+    if (CLI::HasParam("random_basis"))
+      Log::Warn << "--random_basis (-R) will be ignored because "
+          << "--input_model_file is specified." << endl;
+    if (CLI::HasParam("naive"))
+      Log::Warn << "--naive (-N) will be ignored because --input_model_file "
+          << "is specified" << endl;
   }
 
-  // Sanity check on k value: must be greater than 0, must be less than the
-  // number of reference points.  Since it is unsigned, we only test the upper bound.
-  if (k > referenceData.n_cols)
-  {
-    Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
-    Log::Fatal << "than or equal to the number of reference points (";
-    Log::Fatal << referenceData.n_cols << ")." << endl;
-  }
+  // The user should give something to do...
+  if (!CLI::HasParam("k") && !CLI::HasParam("output_model_file"))
+    Log::Warn << "Neither -k nor --output_model_file are specified, so no "
+        << "results from this program will be saved!" << endl;
+
+  // If the user specifies k but no output files, they should be warned.
+  if (CLI::HasParam("k") &&
+      !(CLI::HasParam("neighbors_file") || CLI::HasParam("distances_file")))
+    Log::Warn << "Neither --neighbors_file nor --distances_file is specified, "
+        << "so the nearest neighbor search results will not be saved!" << endl;
+
+  // If the user specifies output files but no k, they should be warned.
+  if ((CLI::HasParam("neighbors_file") || CLI::HasParam("distances_file")) &&
+      !CLI::HasParam("k"))
+    Log::Warn << "An output file for nearest neighbor search is given ("
+        << "--neighbors_file or --distances_file), but nearest neighbor search "
+        << "is not being performed because k (--k) is not specified!" << endl;
 
   // Sanity check on leaf size.
+  const int lsInt = CLI::GetParam<int>("leaf_size");
   if (lsInt < 1)
   {
     Log::Fatal << "Invalid leaf size: " << lsInt << ".  Must be greater "
         "than 0." << endl;
   }
-  size_t leafSize = lsInt;
 
-  // Naive mode overrides single mode.
-  if (singleMode && naive)
+  // We either have to load the reference data, or we have to load the model.
+  NSModel<NearestNeighborSort> knn;
+  const bool naive = CLI::HasParam("naive");
+  const bool singleMode = CLI::HasParam("single_mode");
+  if (CLI::HasParam("reference_file"))
   {
-    Log::Warn << "--single_mode ignored because --naive is present." << endl;
-  }
+    // Get all the parameters.
+    const string referenceFile = CLI::GetParam<string>("reference_file");
+    const string treeType = CLI::GetParam<string>("tree_type");
+    const bool randomBasis = CLI::HasParam("random_basis");
+
+    int tree = 0;
+    if (treeType == "kd")
+      tree = KNNModel::KD_TREE;
+    else if (treeType == "cover")
+      tree = KNNModel::COVER_TREE;
+    else if (treeType == "r")
+      tree = KNNModel::R_TREE;
+    else if (treeType == "r-star")
+      tree = KNNModel::R_STAR_TREE;
+    else
+      Log::Fatal << "Unknown tree type '" << treeType << "'; valid choices are "
+          << "'kd', 'cover', 'r', and 'r-star'." << endl;
 
-   // cover_tree overrides r_tree.
-  if (CLI::HasParam("cover_tree") && CLI::HasParam("r_tree"))
-  {
-    Log::Warn << "--cover_tree overrides --r_tree." << endl;
-  }
+    knn.TreeType() = tree;
+    knn.RandomBasis() = randomBasis;
 
-  // See if we want to project onto a random basis.
-  if (randomBasis)
-  {
-    // Generate the random basis.
-    while (true)
-    {
-      // [Q, R] = qr(randn(d, d));
-      // Q = Q * diag(sign(diag(R)));
-      arma::mat q, r;
-      if (arma::qr(q, r, arma::randn<arma::mat>(referenceData.n_rows,
-          referenceData.n_rows)))
-      {
-        arma::vec rDiag(r.n_rows);
-        for (size_t i = 0; i < rDiag.n_elem; ++i)
-        {
-          if (r(i, i) < 0)
-            rDiag(i) = -1;
-          else if (r(i, i) > 0)
-            rDiag(i) = 1;
-          else
-            rDiag(i) = 0;
-        }
-
-        q *= arma::diagmat(rDiag);
-
-        // Check if the determinant is positive.
-        if (arma::det(q) >= 0)
-        {
-          referenceData = q * referenceData;
-          if (queryFile != "")
-            queryData = q * queryData;
-          break;
-        }
-      }
-    }
-  }
+    arma::mat referenceSet;
+    data::Load(referenceFile, referenceSet, true);
 
-  arma::Mat<size_t> neighbors;
-  arma::mat distances;
+    Log::Info << "Loaded reference data from '" << referenceFile << "' ("
+        << referenceSet.n_rows << " x " << referenceSet.n_cols << ")." 
+        << endl;
 
-  if (naive)
-  {
-    AllkNN allknn(referenceData, false, naive);
+    const size_t leafSize = (size_t) CLI::GetParam<int>("leaf_size");
 
-    if (CLI::GetParam<string>("query_file") != "")
-      allknn.Search(queryData, k, neighbors, distances);
-    else
-      allknn.Search(k, neighbors, distances);
+    knn.BuildModel(std::move(referenceSet), leafSize, naive, singleMode);
   }
-  else if (!CLI::HasParam("cover_tree"))
+  else
   {
-    if (!CLI::HasParam("r_tree"))
-    {
-      // We're using the kd-tree.
-      // Mappings for when we build the tree.
-      std::vector<size_t> oldFromNewRefs;
-
-      // Convenience typedef.
-      typedef KDTree<EuclideanDistance, NeighborSearchStat<NearestNeighborSort>,
-          arma::mat> TreeType;
-
-      // Build trees by hand, so we can save memory: if we pass a tree to
-      // NeighborSearch, it does not copy the matrix.
-      Log::Info << "Building reference tree..." << endl;
-      Timer::Start("tree_building");
-      TreeType refTree(referenceData, oldFromNewRefs, leafSize);
-      Timer::Stop("tree_building");
-
-      AllkNN allknn(&refTree, singleMode);
-
-      std::vector<size_t> oldFromNewQueries;
-
-      arma::mat distancesOut;
-      arma::Mat<size_t> neighborsOut;
-
-      if (CLI::GetParam<string>("query_file") != "")
-      {
-        // Build trees by hand, so we can save memory: if we pass a tree to
-        // NeighborSearch, it does not copy the matrix.
-        if (!singleMode)
-        {
-          Log::Info << "Building query tree..." << endl;
-          Timer::Start("tree_building");
-          TreeType queryTree(queryData, oldFromNewQueries, leafSize);
-          Timer::Stop("tree_building");
-          Log::Info << "Tree built." << endl;
-
-          Log::Info << "Computing " << k << " nearest neighbors..." << endl;
-          allknn.Search(&queryTree, k, neighborsOut, distancesOut);
-        }
-        else
-        {
-          Log::Info << "Computing " << k << " nearest neighbors..." << endl;
-          allknn.Search(queryData, k, neighborsOut, distancesOut);
-        }
-      }
-      else
-      {
-        Log::Info << "Computing " << k << " nearest neighbors..." << endl;
-        allknn.Search(k, neighborsOut, distancesOut);
-      }
-
-      Log::Info << "Neighbors computed." << endl;
-
-      // We have to map back to the original indices from before the tree
-      // construction.
-      Log::Info << "Re-mapping indices..." << endl;
-
-      // Map the results back to the correct places.
-      if ((CLI::GetParam<string>("query_file") != "") && !singleMode)
-        Unmap(neighborsOut, distancesOut, oldFromNewRefs, oldFromNewQueries,
-            neighbors, distances);
-      else if ((CLI::GetParam<string>("query_file") != "") && singleMode)
-        Unmap(neighborsOut, distancesOut, oldFromNewRefs, neighbors, distances);
-      else
-        Unmap(neighborsOut, distancesOut, oldFromNewRefs, oldFromNewRefs,
-            neighbors, distances);
-    }
-    else
-    {
-      // Make sure to notify the user that they are using an r tree.
-      Log::Info << "Using R tree for nearest-neighbor calculation." << endl;
-
-      // Convenience typedef.
-      typedef RStarTree<EuclideanDistance,
-          NeighborSearchStat<NearestNeighborSort>, arma::mat> TreeType;
-
-      // Build tree by hand in order to apply user options.
-      Log::Info << "Building reference tree..." << endl;
-      Timer::Start("tree_building");
-      TreeType refTree(referenceData, leafSize, leafSize * 0.4, 5, 2, 0);
-      Timer::Stop("tree_building");
-      Log::Info << "Tree built." << endl;
-
-      typedef NeighborSearch<NearestNeighborSort, EuclideanDistance, arma::mat,
-          RStarTree> AllkNNType;
-      AllkNNType allknn(&refTree, singleMode);
-
-      if (CLI::GetParam<string>("query_file") != "")
-      {
-        // Build trees by hand, so we can save memory: if we pass a tree to
-        // NeighborSearch, it does not copy the matrix.
-        if (!singleMode)
-        {
-          Log::Info << "Building query tree..." << endl;
-          Timer::Start("tree_building");
-          TreeType queryTree(queryData, leafSize, leafSize * 0.4, 5, 2, 0);
-          Timer::Stop("tree_building");
-          Log::Info << "Tree built." << endl;
-
-          Log::Info << "Computing " << k << " nearest neighbors..." << endl;
-          allknn.Search(&queryTree, k, neighbors, distances);
-        }
-        else
-        {
-          Log::Info << "Computing " << k << " nearest neighbors..." << endl;
-          allknn.Search(queryData, k, neighbors, distances);
-        }
-      }
-      else
-      {
-        Log::Info << "Computing " << k << " nearest neighbors..." << endl;
-        allknn.Search(k, neighbors, distances);
-      }
-    }
+    // Load the model from file.
+    const string inputModelFile = CLI::GetParam<string>("input_model_file");
+    data::Load(inputModelFile, "knn_model", knn, true); // Fatal on failure.
+
+    Log::Info << "Loaded kNN model from '" << inputModelFile << "' (trained on "
+        << knn.Dataset().n_rows << "x" << knn.Dataset().n_cols << " dataset)."
+        << endl;
+
+    // Adjust singleMode and naive if necessary.
+    if (CLI::HasParam("single_mode"))
+      knn.SingleMode() = CLI::HasParam("single_mode");
+    if (CLI::HasParam("naive"))
+      knn.Naive() = CLI::HasParam("naive");
+    if (CLI::HasParam("leaf_size"))
+      knn.LeafSize() = (size_t) CLI::GetParam<int>("leaf_size");
   }
-  else // Cover trees.
-  {
-    // Make sure to notify the user that they are using cover trees.
-    Log::Info << "Using cover trees for nearest-neighbor calculation." << endl;
-
-    // Convenience typedef.
-    typedef StandardCoverTree<metric::EuclideanDistance,
-        NeighborSearchStat<NearestNeighborSort>, arma::mat> TreeType;
 
-    // Build our reference tree.
-    Log::Info << "Building reference tree..." << endl;
-    Timer::Start("tree_building");
-    TreeType refTree(referenceData, 1.3);
-    Timer::Stop("tree_building");
+  // Perform search, if desired.
+  if (CLI::HasParam("k"))
+  {
+    const string queryFile = CLI::GetParam<string>("query_file");
+    const size_t k = (size_t) CLI::GetParam<int>("k");
 
-    typedef NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
-        arma::mat, StandardCoverTree> AllkNNType;
-    AllkNNType allknn(&refTree, singleMode);
+    arma::mat queryData;
+    if (queryFile != "")
+    {
+      data::Load(queryFile, queryData, true);
+      Log::Info << "Loaded query data from '" << queryFile << "' ("
+          << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+    }
 
-    // See if we have query data.
-    if (CLI::HasParam("query_file"))
+    // Sanity check on k value: must be greater than 0, must be less than or
+    // equal to the number of reference points.  Since it is unsigned, we only
+    // test the upper bound.
+    if (k > knn.Dataset().n_cols)
     {
-      // Build query tree.
-      if (!singleMode)
-      {
-        Log::Info << "Building query tree..." << endl;
-        Timer::Start("tree_building");
-        TreeType queryTree(queryData, 1.3);
-        Timer::Stop("tree_building");
-
-        Log::Info << "Computing " << k << " nearest neighbors..." << endl;
-        allknn.Search(&queryTree, k, neighbors, distances);
-      }
-      else
-      {
-        Log::Info << "Computing " << k << " nearest neighbors..." << endl;
-        allknn.Search(queryData, k, neighbors, distances);
-      }
+      Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
+      Log::Fatal << "than or equal to the number of reference points (";
+      Log::Fatal << knn.Dataset().n_cols << ")." << endl;
     }
-    else
+
+    // Naive mode overrides single mode.
+    if (singleMode && naive)
     {
-      Log::Info << "Computing " << k << " nearest neighbors..." << endl;
-      allknn.Search(k, neighbors, distances);
+      Log::Warn << "--single_mode ignored because --naive is present." << endl;
     }
 
-    Log::Info << "Neighbors computed." << endl;
+    // Now run the search.
+    arma::Mat<size_t> neighbors;
+    arma::mat distances;
+
+    if (CLI::HasParam("query_file"))
+      knn.Search(std::move(queryData), k, neighbors, distances);
+    else
+      knn.Search(k, neighbors, distances);
+    Log::Info << "Search complete." << endl;
+
+    // Save output, if desired.
+    if (CLI::HasParam("neighbors_file"))
+      data::Save(CLI::GetParam<string>("neighbors_file"), neighbors);
+    if (CLI::HasParam("distances_file"))
+      data::Save(CLI::GetParam<string>("distances_file"), distances);
   }
 
-  // Save put.
-  data::Save(distancesFile, distances);
-  data::Save(neighborsFile, neighbors);
+  if (CLI::HasParam("output_model_file"))
+  {
+    const string outputModelFile = CLI::GetParam<string>("output_model_file");
+    data::Save(outputModelFile, "knn_model", knn);
+  }
 }



More information about the mlpack-git mailing list