[mlpack-svn] r15624 - mlpack/trunk/src/mlpack/methods/range_search

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Aug 15 15:17:27 EDT 2013


Author: rcurtin
Date: Thu Aug 15 15:17:27 2013
New Revision: 15624

Log:
Add --cover_tree option to range_search.


Modified:
   mlpack/trunk/src/mlpack/methods/range_search/range_search_main.cpp

Modified: mlpack/trunk/src/mlpack/methods/range_search/range_search_main.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/range_search/range_search_main.cpp	(original)
+++ mlpack/trunk/src/mlpack/methods/range_search/range_search_main.cpp	Thu Aug 15 15:17:27 2013
@@ -8,6 +8,7 @@
  */
 #include <mlpack/core.hpp>
 #include <mlpack/core/metrics/lmetric.hpp>
+#include <mlpack/core/tree/cover_tree.hpp>
 
 #include "range_search.hpp"
 
@@ -60,10 +61,12 @@
 PARAM_INT("leaf_size", "Leaf size for tree building.", "l", 20);
 PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "N");
 PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
-    "dual-tree search.", "s");
+    "dual-tree search).", "s");
+PARAM_FLAG("cover_tree", "If true, use a cover tree for range searching "
+    "(instead of a kd-tree).", "c");
 
-typedef RangeSearch<metric::SquaredEuclideanDistance,
-    BinarySpaceTree<bound::HRectBound<2>, EmptyStatistic> > RSType;
+typedef RangeSearch<> RSType;
+typedef RangeSearch<metric::EuclideanDistance, CoverTree<> > RSCoverType;
 
 int main(int argc, char *argv[])
 {
@@ -81,8 +84,9 @@
   double max = CLI::GetParam<double>("max");
   double min = CLI::GetParam<double>("min");
 
-  bool naive = CLI::HasParam("naive");
-  bool singleMode = CLI::HasParam("single_mode");
+  const bool naive = CLI::HasParam("naive");
+  const bool singleMode = CLI::HasParam("single_mode");
+  bool coverTree = CLI::HasParam("cover_tree");
 
   arma::mat referenceData;
   arma::mat queryData; // So it doesn't go out of scope.
@@ -115,117 +119,164 @@
   if (naive)
     leafSize = referenceData.n_cols;
 
+  if (coverTree && naive)
+  {
+    Log::Warn << "--cover_tree ignored because --naive is present." << endl;
+    coverTree = false;
+  }
+
   vector<vector<size_t> > neighbors;
   vector<vector<double> > distances;
 
-  // Because we may construct it differently, we need a pointer.
-  RSType* rangeSearch = NULL;
+  // The cover tree implies different types, so we must split this section.
+  if (coverTree) // Not CLI::HasParam(): --naive may have cleared the flag.
+  {
+    Log::Info << "Using cover trees." << endl;
 
-  // Mappings for when we build the tree.
-  vector<size_t> oldFromNewRefs;
+    // This is significantly simpler than kd-tree construction because the data
+    // matrix is not modified.
+    RSCoverType* rangeSearch = NULL;
+    CoverTree<> referenceTree(referenceData);
+    CoverTree<>* queryTree = NULL;
 
-  // Build trees by hand, so we can save memory: if we pass a tree to
-  // NeighborSearch, it does not copy the matrix.
-  Log::Info << "Building reference tree..." << endl;
-  Timer::Start("tree_building");
+    if (CLI::GetParam<string>("query_file") == "")
+    {
+      // Single dataset.
+      rangeSearch = new RSCoverType(&referenceTree, referenceData, singleMode);
+    }
+    else
+    {
+      // Two datasets.
+      const string queryFile = CLI::GetParam<string>("query_file");
+      data::Load(queryFile, queryData, true);
+      queryTree = new CoverTree<>(queryData);
 
-  BinarySpaceTree<bound::HRectBound<2>, tree::EmptyStatistic>
-      refTree(referenceData, oldFromNewRefs, leafSize);
-  BinarySpaceTree<bound::HRectBound<2>, tree::EmptyStatistic>*
-      queryTree = NULL; // Empty for now.
+      rangeSearch = new RSCoverType(&referenceTree, queryTree, referenceData,
+          queryData, singleMode);
+    }
 
-  Timer::Stop("tree_building");
+    Log::Info << "Trees built." << endl;
 
-  std::vector<size_t> oldFromNewQueries;
+    const math::Range r(min, max);
+    rangeSearch->Search(r, neighbors, distances);
 
-  if (CLI::GetParam<string>("query_file") != "")
+    if (queryTree)
+      delete queryTree;
+    delete rangeSearch;
+  }
+  else
   {
-    string queryFile = CLI::GetParam<string>("query_file");
+    // Because we may construct it differently, we need a pointer.
+    RSType* rangeSearch = NULL;
 
-    if (!data::Load(queryFile, queryData))
-      Log::Fatal << "Query file " << queryFile << " not found" << endl;
-
-    if (naive && leafSize < queryData.n_cols)
-      leafSize = queryData.n_cols;
-
-    Log::Info << "Loaded query data from '" << queryFile << "'." << endl;
-
-    Log::Info << "Building query tree..." << endl;
+    // Mappings for when we build the tree.
+    vector<size_t> oldFromNewRefs;
 
     // Build trees by hand, so we can save memory: if we pass a tree to
     // NeighborSearch, it does not copy the matrix.
+    Log::Info << "Building reference tree..." << endl;
     Timer::Start("tree_building");
 
-    queryTree = new BinarySpaceTree<bound::HRectBound<2>,
-        tree::EmptyStatistic >(queryData, oldFromNewQueries,
-        leafSize);
+    BinarySpaceTree<bound::HRectBound<2>, tree::EmptyStatistic>
+        refTree(referenceData, oldFromNewRefs, leafSize);
+    BinarySpaceTree<bound::HRectBound<2>, tree::EmptyStatistic>*
+        queryTree = NULL; // Empty for now.
 
     Timer::Stop("tree_building");
 
-    rangeSearch = new RSType(&refTree, queryTree, referenceData, queryData,
-        singleMode);
+    vector<size_t> oldFromNewQueries;
 
-    Log::Info << "Tree built." << endl;
-  }
-  else
-  {
-    rangeSearch = new RSType(&refTree, referenceData, singleMode);
+    if (CLI::GetParam<string>("query_file") != "")
+    {
+      const string queryFile = CLI::GetParam<string>("query_file");
+      data::Load(queryFile, queryData, true);
 
-    Log::Info << "Trees built." << endl;
-  }
+      if (naive && leafSize < queryData.n_cols)
+        leafSize = queryData.n_cols;
 
-  Log::Info << "Computing neighbors within range [" << min << ", " << max
-      << "]." << endl;
+      Log::Info << "Loaded query data from '" << queryFile << "'." << endl;
 
-  math::Range r = math::Range(min, max);
-  rangeSearch->Search(r, neighbors, distances);
+      Log::Info << "Building query tree..." << endl;
 
-  Log::Info << "Neighbors computed." << endl;
+      // Build trees by hand, so we can save memory: if we pass a tree to
+      // RangeSearch, it does not copy the matrix.
+      Timer::Start("tree_building");
 
-  // We have to map back to the original indices from before the tree
-  // construction.
-  Log::Info << "Re-mapping indices..." << endl;
+      queryTree = new BinarySpaceTree<bound::HRectBound<2>,
+          tree::EmptyStatistic >(queryData, oldFromNewQueries,
+          leafSize);
 
-  vector<vector<double> > distancesOut;
-  distancesOut.resize(distances.size());
-  vector<vector<size_t> > neighborsOut;
-  neighborsOut.resize(neighbors.size());
+      Timer::Stop("tree_building");
 
-  // Do the actual remapping.
-  if (CLI::GetParam<string>("query_file") != "")
-  {
-    for (size_t i = 0; i < distances.size(); ++i)
+      rangeSearch = new RSType(&refTree, queryTree, referenceData, queryData,
+          singleMode);
+
+      Log::Info << "Tree built." << endl;
+    }
+    else
     {
-      // Map distances (copy a column).
-      distancesOut[oldFromNewQueries[i]] = distances[i];
+      rangeSearch = new RSType(&refTree, referenceData, singleMode);
+
+      Log::Info << "Trees built." << endl;
+    }
+
+    Log::Info << "Computing neighbors within range [" << min << ", " << max
+        << "]." << endl;
+
+    // Collect the results in these vectors before remapping.
+    vector<vector<double> > distancesOut;
+    vector<vector<size_t> > neighborsOut;
 
-      // Map indices of neighbors.
-      neighborsOut[oldFromNewQueries[i]].resize(neighbors[i].size());
-      for (size_t j = 0; j < distances[i].size(); ++j)
+    const math::Range r(min, max);
+    rangeSearch->Search(r, neighborsOut, distancesOut);
+
+    Log::Info << "Neighbors computed." << endl;
+
+    // We have to map back to the original indices from before the tree
+    // construction.
+    Log::Info << "Re-mapping indices..." << endl;
+
+    distances.resize(distancesOut.size());
+    neighbors.resize(neighborsOut.size());
+
+    // Do the actual remapping.
+    if (CLI::GetParam<string>("query_file") != "")
+    {
+      for (size_t i = 0; i < distances.size(); ++i)
       {
-        neighborsOut[oldFromNewQueries[i]][j] = oldFromNewRefs[neighbors[i][j]];
+        // Map distances (copy a column).
+        distances[oldFromNewQueries[i]] = distancesOut[i];
+
+        // Map indices of neighbors.
+        neighbors[oldFromNewQueries[i]].resize(neighborsOut[i].size());
+        for (size_t j = 0; j < distancesOut[i].size(); ++j)
+        {
+          neighbors[oldFromNewQueries[i]][j] =
+              oldFromNewRefs[neighborsOut[i][j]];
+        }
       }
     }
-  }
-  else
-  {
-    for (size_t i = 0; i < distances.size(); ++i)
+    else
     {
-      // Map distances (copy a column).
-      distancesOut[oldFromNewRefs[i]] = distances[i];
-
-      // Map indices of neighbors.
-      neighborsOut[oldFromNewRefs[i]].resize(neighbors[i].size());
-      for (size_t j = 0; j < distances[i].size(); ++j)
+      for (size_t i = 0; i < distances.size(); ++i)
       {
-        neighborsOut[oldFromNewRefs[i]][j] = oldFromNewRefs[neighbors[i][j]];
+        // Map distances (copy a column).
+        distances[oldFromNewRefs[i]] = distancesOut[i];
+
+        // Map indices of neighbors.
+        neighbors[oldFromNewRefs[i]].resize(neighborsOut[i].size());
+        for (size_t j = 0; j < distancesOut[i].size(); ++j)
+        {
+          neighbors[oldFromNewRefs[i]][j] = oldFromNewRefs[neighborsOut[i][j]];
+        }
       }
     }
-  }
 
-  // Clean up.
-  if (queryTree)
-    delete queryTree;
+    // Clean up.
+    if (queryTree)
+      delete queryTree;
+    delete rangeSearch;
+  }
 
   // Save output.  We have to do this by hand.
   fstream distancesStr(distancesFile.c_str(), fstream::out);
@@ -237,17 +288,17 @@
   else
   {
     // Loop over each point.
-    for (size_t i = 0; i < distancesOut.size(); ++i)
+    for (size_t i = 0; i < distances.size(); ++i)
     {
       // Store the distances of each point.  We may have 0 points to store, so
       // we must account for that possibility.
-      for (size_t j = 0; j + 1 < distancesOut[i].size(); ++j)
+      for (size_t j = 0; j + 1 < distances[i].size(); ++j)
       {
-        distancesStr << distancesOut[i][j] << ", ";
+        distancesStr << distances[i][j] << ", ";
       }
 
-      if (distancesOut[i].size() > 0)
-        distancesStr << distancesOut[i][distancesOut[i].size() - 1];
+      if (distances[i].size() > 0)
+        distancesStr << distances[i][distances[i].size() - 1];
 
       distancesStr << endl;
     }
@@ -264,23 +315,21 @@
   else
   {
     // Loop over each point.
-    for (size_t i = 0; i < neighborsOut.size(); ++i)
+    for (size_t i = 0; i < neighbors.size(); ++i)
     {
       // Store the neighbors of each point.  We may have 0 points to store, so
       // we must account for that possibility.
-      for (size_t j = 0; j + 1 < neighborsOut[i].size(); ++j)
+      for (size_t j = 0; j + 1 < neighbors[i].size(); ++j)
       {
-        neighborsStr << neighborsOut[i][j] << ", ";
+        neighborsStr << neighbors[i][j] << ", ";
       }
 
-      if (neighborsOut[i].size() > 0)
-        neighborsStr << neighborsOut[i][neighborsOut[i].size() - 1];
+      if (neighbors[i].size() > 0)
+        neighborsStr << neighbors[i][neighbors[i].size() - 1];
 
       neighborsStr << endl;
     }
 
     neighborsStr.close();
   }
-
-  delete rangeSearch;
 }



More information about the mlpack-svn mailing list