[mlpack-svn] r12663 - mlpack/trunk/src/mlpack/methods/neighbor_search

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed May 9 16:25:47 EDT 2012


Author: rcurtin
Date: 2012-05-09 16:25:47 -0400 (Wed, 09 May 2012)
New Revision: 12663

Modified:
   mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp
Log:
Add a `cover_tree` (`-c`) flag to allknn that performs the nearest-neighbor search with cover trees instead of binary space trees.


Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp	2012-05-09 20:25:01 UTC (rev 12662)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp	2012-05-09 20:25:47 UTC (rev 12663)
@@ -6,6 +6,7 @@
  * options.
  */
 #include <mlpack/core.hpp>
+#include <mlpack/core/tree/cover_tree.hpp>
 
 #include <string>
 #include <fstream>
@@ -52,6 +53,8 @@
 PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "N");
 PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
     "dual-tree search.", "s");
+PARAM_FLAG("cover_tree", "If true, use cover trees to perform the search.",
+    "c");
 
 int main(int argc, char *argv[])
 {
@@ -107,111 +110,170 @@
   arma::Mat<size_t> neighbors;
   arma::mat distances;
 
-  // Because we may construct it differently, we need a pointer.
-  AllkNN* allknn = NULL;
+  if (!CLI::HasParam("cover_tree"))
+  {
+    // Because we may construct it differently, we need a pointer.
+    AllkNN* allknn = NULL;
 
-  // Mappings for when we build the tree.
-  std::vector<size_t> oldFromNewRefs;
+    // Mappings for when we build the tree.
+    std::vector<size_t> oldFromNewRefs;
 
-  // Build trees by hand, so we can save memory: if we pass a tree to
-  // NeighborSearch, it does not copy the matrix.
-  Log::Info << "Building reference tree..." << endl;
-  Timer::Start("tree_building");
+    // Build trees by hand, so we can save memory: if we pass a tree to
+    // NeighborSearch, it does not copy the matrix.
+    Log::Info << "Building reference tree..." << endl;
+    Timer::Start("tree_building");
 
-  BinarySpaceTree<bound::HRectBound<2>, QueryStat<NearestNeighborSort> >
+    BinarySpaceTree<bound::HRectBound<2>, QueryStat<NearestNeighborSort> >
       refTree(referenceData, oldFromNewRefs, leafSize);
-  BinarySpaceTree<bound::HRectBound<2>, QueryStat<NearestNeighborSort> >*
+    BinarySpaceTree<bound::HRectBound<2>, QueryStat<NearestNeighborSort> >*
       queryTree = NULL; // Empty for now.
 
-  Timer::Stop("tree_building");
+    Timer::Stop("tree_building");
 
-  std::vector<size_t> oldFromNewQueries;
+    std::vector<size_t> oldFromNewQueries;
 
-  if (CLI::GetParam<string>("query_file") != "")
-  {
-    string queryFile = CLI::GetParam<string>("query_file");
+    if (CLI::GetParam<string>("query_file") != "")
+    {
+      string queryFile = CLI::GetParam<string>("query_file");
 
-    data::Load(queryFile.c_str(), queryData, true);
+      data::Load(queryFile.c_str(), queryData, true);
 
-    if (naive && leafSize < queryData.n_cols)
-      leafSize = queryData.n_cols;
+      if (naive && leafSize < queryData.n_cols)
+        leafSize = queryData.n_cols;
 
-    Log::Info << "Loaded query data from '" << queryFile << "' ("
+      Log::Info << "Loaded query data from '" << queryFile << "' ("
         << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
 
-    Log::Info << "Building query tree..." << endl;
+      Log::Info << "Building query tree..." << endl;
 
-    // Build trees by hand, so we can save memory: if we pass a tree to
-    // NeighborSearch, it does not copy the matrix.
-    Timer::Start("tree_building");
+      // Build trees by hand, so we can save memory: if we pass a tree to
+      // NeighborSearch, it does not copy the matrix.
+      Timer::Start("tree_building");
 
-    queryTree = new BinarySpaceTree<bound::HRectBound<2>,
-        QueryStat<NearestNeighborSort> >(queryData, oldFromNewQueries,
-        leafSize);
+      queryTree = new BinarySpaceTree<bound::HRectBound<2>,
+                QueryStat<NearestNeighborSort> >(queryData, oldFromNewQueries,
+                    leafSize);
 
-    Timer::Stop("tree_building");
+      Timer::Stop("tree_building");
 
-    allknn = new AllkNN(&refTree, queryTree, referenceData, queryData,
-        singleMode);
+      allknn = new AllkNN(&refTree, queryTree, referenceData, queryData,
+          singleMode);
 
-    Log::Info << "Tree built." << endl;
-  }
-  else
-  {
-    allknn = new AllkNN(&refTree, referenceData, singleMode);
+      Log::Info << "Tree built." << endl;
+    }
+    else
+    {
+      allknn = new AllkNN(&refTree, referenceData, singleMode);
 
-    Log::Info << "Trees built." << endl;
-  }
+      Log::Info << "Trees built." << endl;
+    }
 
-  Log::Info << "Computing " << k << " nearest neighbors..." << endl;
-  allknn->Search(k, neighbors, distances);
+    arma::mat distancesOut;
+    arma::Mat<size_t> neighborsOut;
 
-  Log::Info << "Neighbors computed." << endl;
+    Log::Info << "Computing " << k << " nearest neighbors..." << endl;
+    allknn->Search(k, neighborsOut, distancesOut);
 
-  // We have to map back to the original indices from before the tree
-  // construction.
-  Log::Info << "Re-mapping indices..." << endl;
+    Log::Info << "Neighbors computed." << endl;
 
-  arma::mat distancesOut(distances.n_rows, distances.n_cols);
-  arma::Mat<size_t> neighborsOut(neighbors.n_rows, neighbors.n_cols);
+    // We have to map back to the original indices from before the tree
+    // construction.
+    Log::Info << "Re-mapping indices..." << endl;
 
-  // Do the actual remapping.
-  if (CLI::GetParam<string>("query_file") != "")
-  {
-    for (size_t i = 0; i < distances.n_cols; ++i)
+    neighbors.set_size(neighborsOut.n_rows, neighborsOut.n_cols);
+    distances.set_size(distancesOut.n_rows, distancesOut.n_cols);
+
+    // Do the actual remapping.
+    if (CLI::GetParam<string>("query_file") != "")
     {
-      // Map distances (copy a column).
-      distancesOut.col(oldFromNewQueries[i]) = distances.col(i);
+      for (size_t i = 0; i < distancesOut.n_cols; ++i)
+      {
+        // Map distances (copy a column).
+        distances.col(oldFromNewQueries[i]) = distancesOut.col(i);
 
-      // Map indices of neighbors.
-      for (size_t j = 0; j < distances.n_rows; ++j)
+        // Map indices of neighbors.
+        for (size_t j = 0; j < distancesOut.n_rows; ++j)
+        {
+          neighbors(j, oldFromNewQueries[i]) =
+              oldFromNewRefs[neighborsOut(j, i)];
+        }
+      }
+    }
+    else
+    {
+      for (size_t i = 0; i < distancesOut.n_cols; ++i)
       {
-        neighborsOut(j, oldFromNewQueries[i]) = oldFromNewRefs[neighbors(j, i)];
+        // Map distances (copy a column).
+        distances.col(oldFromNewRefs[i]) = distancesOut.col(i);
+
+        // Map indices of neighbors.
+        for (size_t j = 0; j < distancesOut.n_rows; ++j)
+        {
+          neighbors(j, oldFromNewRefs[i]) = oldFromNewRefs[neighborsOut(j, i)];
+        }
       }
     }
+
+    // Clean up.
+    if (queryTree)
+      delete queryTree;
+
+    delete allknn;
   }
-  else
+  else // Cover trees.
   {
-    for (size_t i = 0; i < distances.n_cols; ++i)
+    // Build our reference tree.
+    Log::Info << "Building reference tree..." << endl;
+    Timer::Start("tree_building");
+    CoverTree<metric::LMetric<2>, tree::FirstPointIsRoot,
+        QueryStat<NearestNeighborSort> > referenceTree(referenceData);
+    CoverTree<metric::LMetric<2>, tree::FirstPointIsRoot,
+        QueryStat<NearestNeighborSort> >* queryTree = NULL;
+    Timer::Stop("tree_building");
+
+    NeighborSearch<NearestNeighborSort, metric::LMetric<2>,
+        CoverTree<metric::LMetric<2>, tree::FirstPointIsRoot,
+        QueryStat<NearestNeighborSort> > >* allknn = NULL;
+
+    // See if we have query data.
+    if (CLI::HasParam("query_file"))
     {
-      // Map distances (copy a column).
-      distancesOut.col(oldFromNewRefs[i]) = distances.col(i);
+      string queryFile = CLI::GetParam<string>("query_file");
 
-      // Map indices of neighbors.
-      for (size_t j = 0; j < distances.n_rows; ++j)
-      {
-        neighborsOut(j, oldFromNewRefs[i]) = oldFromNewRefs[neighbors(j, i)];
-      }
+      data::Load(queryFile, queryData, true);
+
+      // Build query tree.
+      Log::Info << "Building query tree..." << endl;
+      Timer::Start("tree_building");
+      queryTree = new CoverTree<metric::LMetric<2>, tree::FirstPointIsRoot,
+          QueryStat<NearestNeighborSort> >(queryData);
+      Timer::Stop("tree_building");
+
+      allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2>,
+          CoverTree<metric::LMetric<2>, tree::FirstPointIsRoot,
+          QueryStat<NearestNeighborSort> > >(&referenceTree, queryTree,
+          referenceData, queryData, true);
     }
+    else
+    {
+      allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2>,
+          CoverTree<metric::LMetric<2>, tree::FirstPointIsRoot,
+          QueryStat<NearestNeighborSort> > >(&referenceTree, referenceData,
+          true);
+    }
+
+    Log::Info << "Computing " << k << " nearest neighbors..." << endl;
+    allknn->Search(k, neighbors, distances);
+
+    Log::Info << "Neighbors computed." << endl;
+
+    delete allknn;
+
+    if (queryTree)
+      delete queryTree;
   }
 
-  // Clean up.
-  if (queryTree)
-    delete queryTree;
-
   // Save output.
-  data::Save(distancesFile, distancesOut);
-  data::Save(neighborsFile, neighborsOut);
-
-  delete allknn;
+  data::Save(distancesFile, distances);
+  data::Save(neighborsFile, neighbors);
 }




More information about the mlpack-svn mailing list