[mlpack-git] master: Refactor executables for new NeighborSearch API. (57d0567)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Apr 22 16:32:33 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/8f85309ae9be40e819b301b39c9a940aa28f3bb2...57d0567dddff01feea73b348f38cc040dc3cf8e3

>---------------------------------------------------------------

commit 57d0567dddff01feea73b348f38cc040dc3cf8e3
Author: ryan <ryan at ratml.org>
Date:   Wed Apr 22 16:30:08 2015 -0400

    Refactor executables for new NeighborSearch API.


>---------------------------------------------------------------

57d0567dddff01feea73b348f38cc040dc3cf8e3
 src/mlpack/methods/neighbor_search/allkfn_main.cpp | 216 ++++++++-----------
 src/mlpack/methods/neighbor_search/allknn_main.cpp | 229 ++++++++-------------
 2 files changed, 168 insertions(+), 277 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/allkfn_main.cpp b/src/mlpack/methods/neighbor_search/allkfn_main.cpp
index dcc29cd..28d79bc 100644
--- a/src/mlpack/methods/neighbor_search/allkfn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/allkfn_main.cpp
@@ -89,6 +89,12 @@ int main(int argc, char *argv[])
     Log::Fatal << referenceData.n_cols << ")." << endl;
   }
 
+  if (CLI::GetParam<string>("query_file") != "")
+  {
+    string queryFile = CLI::GetParam<string>("query_file");
+    data::Load(queryFile, queryData, true);
+  }
+
   // Sanity check on leaf size.
   if (lsInt < 0)
   {
@@ -103,192 +109,134 @@ int main(int argc, char *argv[])
     Log::Warn << "--single_mode ignored because --naive is present." << endl;
   }
 
-  if (naive)
-    leafSize = referenceData.n_cols;
-
   arma::Mat<size_t> neighbors;
   arma::mat distances;
 
-  if(!CLI::HasParam("r_tree"))
+  if (naive)
   {
-    AllkFN* allkfn = NULL;
+    AllkFN allkfn(referenceData, false, naive);
 
+    if (CLI::HasParam("query_file"))
+      allkfn.Search(queryData, k, neighbors, distances);
+    else
+      allkfn.Search(k, neighbors, distances);
+  }
+  if (!CLI::HasParam("r_tree"))
+  {
+    // Use default kd-tree.
     std::vector<size_t> oldFromNewRefs;
 
+    typedef BinarySpaceTree<bound::HRectBound<2>,
+        NeighborSearchStat<FurthestNeighborSort>> TreeType;
+
     // Build trees by hand, so we can save memory: if we pass a tree to
     // NeighborSearch, it does not copy the matrix.
     Log::Info << "Building reference tree..." << endl;
     Timer::Start("reference_tree_building");
-
-    BinarySpaceTree<bound::HRectBound<2>,
-        NeighborSearchStat<FurthestNeighborSort> >
-        refTree(referenceData, oldFromNewRefs, leafSize);
-    BinarySpaceTree<bound::HRectBound<2>,
-        NeighborSearchStat<FurthestNeighborSort> >*
-        queryTree = NULL; // Empty for now.
-
+    TreeType refTree(referenceData, oldFromNewRefs, leafSize);
     Timer::Stop("reference_tree_building");
 
     std::vector<size_t> oldFromNewQueries;
 
-    if (CLI::GetParam<string>("query_file") != "")
-    {
-      string queryFile = CLI::GetParam<string>("query_file");
-
-      data::Load(queryFile, queryData, true);
-
-      Log::Info << "Loaded query data from '" << queryFile << "' ("
-          << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
-
-      Log::Info << "Building query tree..." << endl;
-
-      if (naive && leafSize < queryData.n_cols)
-        leafSize = queryData.n_cols;
-
-      // Build trees by hand, so we can save memory: if we pass a tree to
-      // NeighborSearch, it does not copy the matrix.
-      Timer::Start("query_tree_building");
+    AllkFN allkfn(&refTree, singleMode);
 
-      queryTree = new BinarySpaceTree<bound::HRectBound<2>,
-          NeighborSearchStat<FurthestNeighborSort> >(queryData, oldFromNewQueries,
-          leafSize);
-
-      Timer::Stop("query_tree_building");
+    arma::mat distancesOut(distances.n_rows, distances.n_cols);
+    arma::Mat<size_t> neighborsOut(neighbors.n_rows, neighbors.n_cols);
 
-      allkfn = new AllkFN(&refTree, queryTree, referenceData, queryData,
-          singleMode);
+    if (CLI::HasParam("query_file"))
+    {
+      if (!singleMode)
+      {
+        // Build trees by hand, so we can save memory: if we pass a tree to
+        // NeighborSearch, it does not copy the matrix.
+        Log::Info << "Building query tree..." << endl;
+        Timer::Start("tree_building");
+        TreeType queryTree(queryData, oldFromNewQueries, leafSize);
+        Timer::Stop("tree_building");
+        Log::Info << "Tree built." << endl;
 
-      Log::Info << "Tree built." << endl;
+        Log::Info << "Computing " << k << " furthest neighbors..." << endl;
+        allkfn.Search(&queryTree, k, neighborsOut, distancesOut);
+      }
+      else
+      {
+        Log::Info << "Computing " << k << " furthest neighbors..." << endl;
+        allkfn.Search(queryData, k, neighborsOut, distancesOut);
+      }
     }
     else
     {
-      allkfn = new AllkFN(&refTree, referenceData, singleMode);
-
-      Log::Info << "Trees built." << endl;
+      Log::Info << "Computing " << k << " furthest neighbors..." << endl;
+      allkfn.Search(k, neighborsOut, distancesOut);
     }
 
-    Log::Info << "Computing " << k << " furthest neighbors..." << endl;
-    allkfn->Search(k, neighbors, distances);
-
     Log::Info << "Neighbors computed." << endl;
 
     // We have to map back to the original indices from before the tree
     // construction.
     Log::Info << "Re-mapping indices..." << endl;
 
-    arma::mat distancesOut(distances.n_rows, distances.n_cols);
-    arma::Mat<size_t> neighborsOut(neighbors.n_rows, neighbors.n_cols);
-
     // Map the points back to their original locations.
     if ((CLI::GetParam<string>("query_file") != "") && !singleMode)
-      Unmap(neighbors, distances, oldFromNewRefs, oldFromNewQueries, neighborsOut,
-          distancesOut);
+      Unmap(neighborsOut, distancesOut, oldFromNewRefs, oldFromNewQueries, neighbors,
+          distances);
     else if ((CLI::GetParam<string>("query_file") != "") && singleMode)
-      Unmap(neighbors, distances, oldFromNewRefs, neighborsOut, distancesOut);
+      Unmap(neighborsOut, distancesOut, oldFromNewRefs, neighbors, distances);
     else
-      Unmap(neighbors, distances, oldFromNewRefs, oldFromNewRefs, neighborsOut,
-          distancesOut);
-
-    // Clean up.
-    if (queryTree)
-      delete queryTree;
-
-    delete allkfn;
-
-      // Save output.
-  data::Save(distancesFile, distancesOut);
-  data::Save(neighborsFile, neighborsOut);
-
-  } else {  // Use the R tree.
+      Unmap(neighborsOut, distancesOut, oldFromNewRefs, oldFromNewRefs, neighbors,
+          distances);
+  }
+  else
+  {
+    // Use the R tree.
     Log::Info << "Using R tree for furthest-neighbor calculation." << endl;
 
-    // Because we may construct it differently, we need a pointer.
-    NeighborSearch<FurthestNeighborSort, metric::LMetric<2, true>,
-    RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<FurthestNeighborSort>, arma::mat>,
-       tree::RStarTreeDescentHeuristic,
-       NeighborSearchStat<FurthestNeighborSort>,
-       arma::mat> >* allkfn = NULL;
-
+    // Convenience typedef.
+    typedef RectangleTree<
+        tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic,
+            NeighborSearchStat<FurthestNeighborSort>, arma::mat>,
+        tree::RStarTreeDescentHeuristic,
+        NeighborSearchStat<FurthestNeighborSort>,
+        arma::mat> TreeType;
 
     // Build trees by hand, so we can save memory: if we pass a tree to
     // NeighborSearch, it does not copy the matrix.
     Log::Info << "Building reference tree..." << endl;
     Timer::Start("tree_building");
+    TreeType refTree(referenceData, leafSize, leafSize * 0.4, 5, 2, 0);
+    Timer::Stop("tree_building");
+    Log::Info << "Tree built." << endl;
 
-    RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<FurthestNeighborSort>, arma::mat>,
-       tree::RStarTreeDescentHeuristic,
-       NeighborSearchStat<FurthestNeighborSort>,
-       arma::mat>
-    refTree(referenceData, leafSize, leafSize * 0.4, 5, 2, 0);
-
-    RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<FurthestNeighborSort>, arma::mat>,
-       tree::RStarTreeDescentHeuristic,
-       NeighborSearchStat<FurthestNeighborSort>,
-       arma::mat>*
-    queryTree = NULL; // Empty for now.
+    typedef NeighborSearch<FurthestNeighborSort, metric::LMetric<2, true>,
+        TreeType> AllkFNType;
+    AllkFNType allkfn(&refTree, singleMode);
 
-    Timer::Stop("tree_building");
     if (CLI::GetParam<string>("query_file") != "")
     {
-      string queryFile = CLI::GetParam<string>("query_file");
-
-      data::Load(queryFile, queryData, true);
-
-      Log::Info << "Loaded query data from '" << queryFile << "' ("
-      << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
-
-      // Build trees by hand, so we can save memory: if we pass a tree to
-      // NeighborSearch, it does not copy the matrix.
       if (!singleMode)
       {
         Timer::Start("tree_building");
-
-        queryTree = new RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<FurthestNeighborSort>, arma::mat>,
-        tree::RStarTreeDescentHeuristic,
-        NeighborSearchStat<FurthestNeighborSort>,
-        arma::mat>(queryData, leafSize, leafSize * 0.4, 5, 2, 0);
-
+        TreeType queryTree(queryData, leafSize, leafSize * 0.4, 5, 2, 0);
         Timer::Stop("tree_building");
-      }
-
 
-      allkfn = new NeighborSearch<FurthestNeighborSort, metric::LMetric<2, true>,
-      RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<FurthestNeighborSort>, arma::mat>,
-      tree::RStarTreeDescentHeuristic,
-      NeighborSearchStat<FurthestNeighborSort>,
-      arma::mat> >(&refTree, queryTree,
-                   referenceData, queryData, singleMode);
-    } else
+        Log::Info << "Computing " << k << " nearest neighbors..." << endl;
+        allkfn.Search(&queryTree, k, neighbors, distances);
+      }
+      else
+      {
+        Log::Info << "Computing " << k << " nearest neighbors..." << endl;
+        allkfn.Search(queryData, k, neighbors, distances);
+      }
+    }
+    else
     {
-      allkfn = new NeighborSearch<FurthestNeighborSort, metric::LMetric<2, true>,
-      RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<FurthestNeighborSort>, arma::mat>,
-      tree::RStarTreeDescentHeuristic,
-      NeighborSearchStat<FurthestNeighborSort>,
-      arma::mat> >(&refTree,
-                   referenceData, singleMode);
+      Log::Info << "Computing " << k << " nearest neighbors..." << endl;
+      allkfn.Search(k, neighbors, distances);
     }
-    Log::Info << "Tree built." << endl;
-
-    //arma::mat distancesOut;
-    //arma::Mat<size_t> neighborsOut;
-
-    Log::Info << "Computing " << k << " nearest neighbors..." << endl;
-    allkfn->Search(k, neighbors, distances);
-
     Log::Info << "Neighbors computed." << endl;
-
-
-    if(queryTree)
-      delete queryTree;
-
-    delete allkfn;
-
-    // Save output.
-    data::Save(distancesFile, distances);
-    data::Save(neighborsFile, neighbors);
-
   }
 
-
-
+  // Save output.
+  data::Save(distancesFile, distances);
+  data::Save(neighborsFile, neighbors);
 }
diff --git a/src/mlpack/methods/neighbor_search/allknn_main.cpp b/src/mlpack/methods/neighbor_search/allknn_main.cpp
index 3aaf45e..5a83b4d 100644
--- a/src/mlpack/methods/neighbor_search/allknn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/allknn_main.cpp
@@ -68,13 +68,6 @@ int main(int argc, char *argv[])
   // Give CLI the command line parameters the user passed in.
   CLI::ParseCommandLine(argc, argv);
 
-      Log::Info << "sizeof(BinarySpaceTree<>): " << sizeof(BinarySpaceTree<bound::HRectBound<2>>) << ".\n";
-      Log::Info << "sizeof(HRectBound<2>): " << sizeof(bound::HRectBound<2>) << ".\n";
-      Log::Info << "sizeof(NeighborSearchStat): " << sizeof(NeighborSearchStat<NearestNeighborSort>) << ".\n";
-      Log::Info << "sizeof(TreeType): " <<
-sizeof(BinarySpaceTree<bound::HRectBound<2>,
-NeighborSearchStat<NearestNeighborSort>>) << ".\n";
-
   if (CLI::GetParam<int>("seed") != 0)
     math::RandomSeed((size_t) CLI::GetParam<int>("seed"));
   else
@@ -138,9 +131,6 @@ NeighborSearchStat<NearestNeighborSort>>) << ".\n";
     Log::Warn << "--cover_tree overrides --r_tree." << endl;
   }
 
-  if (naive)
-    leafSize = referenceData.n_cols;
-
   // See if we want to project onto a random basis.
   if (randomBasis)
   {
@@ -181,73 +171,68 @@ NeighborSearchStat<NearestNeighborSort>>) << ".\n";
   arma::Mat<size_t> neighbors;
   arma::mat distances;
 
-  if (!CLI::HasParam("cover_tree"))
+  if (naive)
   {
-    if(!CLI::HasParam("r_tree"))
-    {
-      // Because we may construct it differently, we need a pointer.
-      AllkNN* allknn = NULL;
+    AllkNN allknn(referenceData, false, naive);
 
+    if (CLI::GetParam<string>("query_file") != "")
+      allknn.Search(queryData, k, neighbors, distances);
+    else
+      allknn.Search(k, neighbors, distances);
+  }
+  else if (!CLI::HasParam("cover_tree"))
+  {
+    if (!CLI::HasParam("r_tree"))
+    {
+      // We're using the kd-tree.
       // Mappings for when we build the tree.
       std::vector<size_t> oldFromNewRefs;
 
+      // Convenience typedef.
+      typedef BinarySpaceTree<bound::HRectBound<2>,
+          NeighborSearchStat<NearestNeighborSort>> TreeType;
+
       // Build trees by hand, so we can save memory: if we pass a tree to
       // NeighborSearch, it does not copy the matrix.
       Log::Info << "Building reference tree..." << endl;
       Timer::Start("tree_building");
-
-      BinarySpaceTree<bound::HRectBound<2>,
-          NeighborSearchStat<NearestNeighborSort> >
-          refTree(referenceData, oldFromNewRefs, leafSize);
-      BinarySpaceTree<bound::HRectBound<2>,
-          NeighborSearchStat<NearestNeighborSort> >*
-          queryTree = NULL; // Empty for now.
-
+      TreeType refTree(referenceData, oldFromNewRefs, leafSize);
       Timer::Stop("tree_building");
 
+      AllkNN allknn(&refTree, singleMode);
+
       std::vector<size_t> oldFromNewQueries;
 
+      arma::mat distancesOut;
+      arma::Mat<size_t> neighborsOut;
+
       if (CLI::GetParam<string>("query_file") != "")
       {
-        if (naive && leafSize < queryData.n_cols)
-          leafSize = queryData.n_cols;
-
-        Log::Info << "Loaded query data from '" << queryFile << "' ("
-            << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
-
-        Log::Info << "Building query tree..." << endl;
-
         // Build trees by hand, so we can save memory: if we pass a tree to
         // NeighborSearch, it does not copy the matrix.
         if (!singleMode)
         {
+          Log::Info << "Building query tree..." << endl;
           Timer::Start("tree_building");
-
-          queryTree = new BinarySpaceTree<bound::HRectBound<2>,
-              NeighborSearchStat<NearestNeighborSort> >(queryData,
-              oldFromNewQueries, leafSize);
-
+          TreeType queryTree(queryData, oldFromNewQueries, leafSize);
           Timer::Stop("tree_building");
-        }
+          Log::Info << "Tree built." << endl;
 
-        allknn = new AllkNN(&refTree, queryTree, referenceData, queryData,
-            singleMode);
-
-        Log::Info << "Tree built." << endl;
+          Log::Info << "Computing " << k << " nearest neighbors..." << endl;
+          allknn.Search(&queryTree, k, neighborsOut, distancesOut);
+        }
+        else
+        {
+          Log::Info << "Computing " << k << " nearest neighbors..." << endl;
+          allknn.Search(queryData, k, neighborsOut, distancesOut);
+        }
       }
       else
       {
-        allknn = new AllkNN(&refTree, referenceData, singleMode);
-
-        Log::Info << "Trees built." << endl;
+        Log::Info << "Computing " << k << " nearest neighbors..." << endl;
+        allknn.Search(k, neighborsOut, distancesOut);
       }
 
-      arma::mat distancesOut;
-      arma::Mat<size_t> neighborsOut;
-
-      Log::Info << "Computing " << k << " nearest neighbors..." << endl;
-      allknn->Search(k, neighborsOut, distancesOut);
-
       Log::Info << "Neighbors computed." << endl;
 
       // We have to map back to the original indices from before the tree
@@ -263,90 +248,57 @@ NeighborSearchStat<NearestNeighborSort>>) << ".\n";
       else
         Unmap(neighborsOut, distancesOut, oldFromNewRefs, oldFromNewRefs,
             neighbors, distances);
-
-      // Clean up.
-      if (queryTree)
-        delete queryTree;
-
-      delete allknn;
-    } else { // R tree.
+    }
+    else
+    {
       // Make sure to notify the user that they are using an r tree.
       Log::Info << "Using R tree for nearest-neighbor calculation." << endl;
 
-      // Because we may construct it differently, we need a pointer.
-      NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
-      RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
-         tree::RStarTreeDescentHeuristic,
-         NeighborSearchStat<NearestNeighborSort>,
-         arma::mat> >* allknn = NULL;
+      // Convenience typedef.
+      typedef RectangleTree<
+          tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic,
+              NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+          tree::RStarTreeDescentHeuristic,
+          NeighborSearchStat<NearestNeighborSort>,
+          arma::mat> TreeType;
 
-      // Build trees by hand, so we can save memory: if we pass a tree to
-      // NeighborSearch, it does not copy the matrix.
+      // Build tree by hand in order to apply user options.
       Log::Info << "Building reference tree..." << endl;
       Timer::Start("tree_building");
-
-      RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
-         tree::RStarTreeDescentHeuristic,
-         NeighborSearchStat<NearestNeighborSort>,
-         arma::mat>
-      refTree(referenceData, leafSize, leafSize * 0.4, 5, 2, 0);
-
-      RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
-         tree::RStarTreeDescentHeuristic,
-         NeighborSearchStat<NearestNeighborSort>,
-         arma::mat>*
-      queryTree = NULL; // Empty for now.
-
+      TreeType refTree(referenceData, leafSize, leafSize * 0.4, 5, 2, 0);
       Timer::Stop("tree_building");
+      Log::Info << "Tree built." << endl;
+
+      typedef NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
+          TreeType> AllkNNType;
+      AllkNNType allknn(&refTree, singleMode);
 
       if (CLI::GetParam<string>("query_file") != "")
       {
-        Log::Info << "Loaded query data from '" << queryFile << "' ("
-          << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
-
         // Build trees by hand, so we can save memory: if we pass a tree to
         // NeighborSearch, it does not copy the matrix.
         if (!singleMode)
         {
+          Log::Info << "Building query tree..." << endl;
           Timer::Start("tree_building");
-
-          queryTree = new RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
-          tree::RStarTreeDescentHeuristic,
-          NeighborSearchStat<NearestNeighborSort>,
-          arma::mat>(queryData, leafSize, leafSize * 0.4, 5, 2, 0);
-
+          TreeType queryTree(queryData, leafSize, leafSize * 0.4, 5, 2, 0);
           Timer::Stop("tree_building");
-        }
+          Log::Info << "Tree built." << endl;
 
-
-        allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
-        RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
-          tree::RStarTreeDescentHeuristic,
-          NeighborSearchStat<NearestNeighborSort>,
-          arma::mat> >(&refTree, queryTree,
-          referenceData, queryData, singleMode);
-      } else
+          Log::Info << "Computing " << k << " nearest neighbors..." << endl;
+          allknn.Search(&queryTree, k, neighbors, distances);
+        }
+        else
+        {
+          Log::Info << "Computing " << k << " nearest neighbors..." << endl;
+          allknn.Search(queryData, k, neighbors, distances);
+        }
+      }
+      else
       {
-        allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
-        RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
-          tree::RStarTreeDescentHeuristic,
-          NeighborSearchStat<NearestNeighborSort>,
-          arma::mat> >(&refTree,
-          referenceData, singleMode);
+        Log::Info << "Computing " << k << " nearest neighbors..." << endl;
+        allknn.Search(k, neighbors, distances);
       }
-      Log::Info << "Tree built." << endl;
-
-      //arma::mat distancesOut;
-      //arma::Mat<size_t> neighborsOut;
-
-      Log::Info << "Computing " << k << " nearest neighbors..." << endl;
-      allknn->Search(k, neighbors, distances);
-
-      Log::Info << "Neighbors computed." << endl;
-
-      if(queryTree)
-        delete queryTree;
-      delete allknn;
     }
   }
   else // Cover trees.
@@ -354,19 +306,19 @@ NeighborSearchStat<NearestNeighborSort>>) << ".\n";
     // Make sure to notify the user that they are using cover trees.
     Log::Info << "Using cover trees for nearest-neighbor calculation." << endl;
 
+    // Convenience typedef.
+    typedef CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
+        NeighborSearchStat<NearestNeighborSort>> TreeType;
+
     // Build our reference tree.
     Log::Info << "Building reference tree..." << endl;
     Timer::Start("tree_building");
-    CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
-        NeighborSearchStat<NearestNeighborSort> > referenceTree(referenceData,
-        1.3);
-    CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
-        NeighborSearchStat<NearestNeighborSort> >* queryTree = NULL;
+    TreeType refTree(referenceData, 1.3);
     Timer::Stop("tree_building");
 
-    NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
-        CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
-        NeighborSearchStat<NearestNeighborSort> > >* allknn = NULL;
+    typedef NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
+        TreeType> AllkNNType;
+    AllkNNType allknn(&refTree, singleMode);
 
     // See if we have query data.
     if (CLI::HasParam("query_file"))
@@ -376,34 +328,25 @@ NeighborSearchStat<NearestNeighborSort>>) << ".\n";
       {
         Log::Info << "Building query tree..." << endl;
         Timer::Start("tree_building");
-        queryTree = new CoverTree<metric::LMetric<2, true>,
-            tree::FirstPointIsRoot, NeighborSearchStat<NearestNeighborSort> >(
-            queryData, 1.3);
+        TreeType queryTree(queryData, 1.3);
         Timer::Stop("tree_building");
-      }
 
-      allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
-          CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
-          NeighborSearchStat<NearestNeighborSort> > >(&referenceTree, queryTree,
-          referenceData, queryData, singleMode);
+        Log::Info << "Computing " << k << " nearest neighbors..." << endl;
+        allknn.Search(&queryTree, k, neighbors, distances);
+      }
+      else
+      {
+        Log::Info << "Computing " << k << " nearest neighbors..." << endl;
+        allknn.Search(queryData, k, neighbors, distances);
+      }
     }
     else
     {
-      allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
-          CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
-          NeighborSearchStat<NearestNeighborSort> > >(&referenceTree,
-          referenceData, singleMode);
+      Log::Info << "Computing " << k << " nearest neighbors..." << endl;
+      allknn.Search(k, neighbors, distances);
     }
 
-    Log::Info << "Computing " << k << " nearest neighbors..." << endl;
-    allknn->Search(k, neighbors, distances);
-
     Log::Info << "Neighbors computed." << endl;
-
-    delete allknn;
-
-    if (queryTree)
-      delete queryTree;
   }
 
   // Save put.



More information about the mlpack-git mailing list