[mlpack-git] master: Dual tree traverser bug fix. (35af6e7)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:59:08 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit 35af6e74e1adb0564369ebf754519b6d29e17183
Author: andrewmw94 <andrewmw94 at gmail.com>
Date:   Wed Aug 20 21:07:09 2014 +0000

    Dual tree traverser bug fix.


>---------------------------------------------------------------

35af6e74e1adb0564369ebf754519b6d29e17183
 .../rectangle_tree/dual_tree_traverser_impl.hpp    |   6 +-
 src/mlpack/methods/neighbor_search/allkfn_main.cpp | 219 +++++++++++++++------
 src/mlpack/methods/neighbor_search/allknn_main.cpp |  67 ++++---
 3 files changed, 200 insertions(+), 92 deletions(-)

diff --git a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
index 8bdad30..7214986 100644
--- a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
@@ -128,7 +128,7 @@ DualTreeTraverser<RuleType>::Traverse(
       {
         rule.TraversalInfo() = traversalInfo;
         nodesAndScores[i].node = referenceNode.Children()[i];
-        nodesAndScores[i].score = rule.Score(queryNode, *nodesAndScores[i].node);
+        nodesAndScores[i].score = rule.Score(queryNode.Child(j), *nodesAndScores[i].node);
         nodesAndScores[i].travInfo = rule.TraversalInfo();
       }
       std::sort(nodesAndScores.begin(), nodesAndScores.end(), nodeComparator);
@@ -137,8 +137,8 @@ DualTreeTraverser<RuleType>::Traverse(
       for(int i = 0; i < nodesAndScores.size(); i++)
       {
         rule.TraversalInfo() = nodesAndScores[i].travInfo;
-        if(rule.Rescore(queryNode, *(nodesAndScores[i].node), nodesAndScores[i].score) < DBL_MAX) {
-          Traverse(queryNode, *(nodesAndScores[i].node));
+        if(rule.Rescore(queryNode.Child(j), *(nodesAndScores[i].node), nodesAndScores[i].score) < DBL_MAX) {
+          Traverse(queryNode.Child(j), *(nodesAndScores[i].node));
         } else {
           numPrunes += nodesAndScores.size() - i;
           break;
diff --git a/src/mlpack/methods/neighbor_search/allkfn_main.cpp b/src/mlpack/methods/neighbor_search/allkfn_main.cpp
index 5949bda..2eebeb2 100644
--- a/src/mlpack/methods/neighbor_search/allkfn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/allkfn_main.cpp
@@ -52,6 +52,8 @@ PARAM_INT("leaf_size", "Leaf size for tree building.", "l", 20);
 PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "N");
 PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
     "dual-tree search).", "s");
+PARAM_FLAG("r_tree", "If true, use an R-Tree to perform the search "
+    "(experimental, may be slow.).", "T");
 
 int main(int argc, char *argv[])
 {
@@ -107,91 +109,186 @@ int main(int argc, char *argv[])
   arma::Mat<size_t> neighbors;
   arma::mat distances;
 
-  AllkFN* allkfn = NULL;
+  if(!CLI::HasParam("r_tree"))
+  {
+    AllkFN* allkfn = NULL;
 
-  std::vector<size_t> oldFromNewRefs;
+    std::vector<size_t> oldFromNewRefs;
 
-  // Build trees by hand, so we can save memory: if we pass a tree to
-  // NeighborSearch, it does not copy the matrix.
-  Log::Info << "Building reference tree..." << endl;
-  Timer::Start("reference_tree_building");
+    // Build trees by hand, so we can save memory: if we pass a tree to
+    // NeighborSearch, it does not copy the matrix.
+    Log::Info << "Building reference tree..." << endl;
+    Timer::Start("reference_tree_building");
 
-  BinarySpaceTree<bound::HRectBound<2>,
-      NeighborSearchStat<FurthestNeighborSort> >
-      refTree(referenceData, oldFromNewRefs, leafSize);
-  BinarySpaceTree<bound::HRectBound<2>,
-      NeighborSearchStat<FurthestNeighborSort> >*
-      queryTree = NULL; // Empty for now.
+    BinarySpaceTree<bound::HRectBound<2>,
+        NeighborSearchStat<FurthestNeighborSort> >
+        refTree(referenceData, oldFromNewRefs, leafSize);
+    BinarySpaceTree<bound::HRectBound<2>,
+        NeighborSearchStat<FurthestNeighborSort> >*
+        queryTree = NULL; // Empty for now.
 
-  Timer::Stop("reference_tree_building");
+    Timer::Stop("reference_tree_building");
 
-  std::vector<size_t> oldFromNewQueries;
+    std::vector<size_t> oldFromNewQueries;
 
-  if (CLI::GetParam<string>("query_file") != "")
-  {
-    string queryFile = CLI::GetParam<string>("query_file");
+    if (CLI::GetParam<string>("query_file") != "")
+    {
+      string queryFile = CLI::GetParam<string>("query_file");
 
-    data::Load(queryFile, queryData, true);
+      data::Load(queryFile, queryData, true);
 
-    Log::Info << "Loaded query data from '" << queryFile << "' ("
-        << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+      Log::Info << "Loaded query data from '" << queryFile << "' ("
+          << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
 
-    Log::Info << "Building query tree..." << endl;
+      Log::Info << "Building query tree..." << endl;
 
-    if (naive && leafSize < queryData.n_cols)
-      leafSize = queryData.n_cols;
+      if (naive && leafSize < queryData.n_cols)
+        leafSize = queryData.n_cols;
 
-    // Build trees by hand, so we can save memory: if we pass a tree to
-    // NeighborSearch, it does not copy the matrix.
-    Timer::Start("query_tree_building");
+      // Build trees by hand, so we can save memory: if we pass a tree to
+      // NeighborSearch, it does not copy the matrix.
+      Timer::Start("query_tree_building");
 
-    queryTree = new BinarySpaceTree<bound::HRectBound<2>,
-        NeighborSearchStat<FurthestNeighborSort> >(queryData, oldFromNewQueries,
-        leafSize);
+      queryTree = new BinarySpaceTree<bound::HRectBound<2>,
+          NeighborSearchStat<FurthestNeighborSort> >(queryData, oldFromNewQueries,
+          leafSize);
 
-    Timer::Stop("query_tree_building");
+      Timer::Stop("query_tree_building");
 
-    allkfn = new AllkFN(&refTree, queryTree, referenceData, queryData,
-        singleMode);
+      allkfn = new AllkFN(&refTree, queryTree, referenceData, queryData,
+          singleMode);
 
-    Log::Info << "Tree built." << endl;
-  }
-  else
-  {
-    allkfn = new AllkFN(&refTree, referenceData, singleMode);
+      Log::Info << "Tree built." << endl;
+    }
+    else
+    {
+      allkfn = new AllkFN(&refTree, referenceData, singleMode);
 
-    Log::Info << "Trees built." << endl;
-  }
+      Log::Info << "Trees built." << endl;
+    }
 
-  Log::Info << "Computing " << k << " furthest neighbors..." << endl;
-  allkfn->Search(k, neighbors, distances);
+    Log::Info << "Computing " << k << " furthest neighbors..." << endl;
+    allkfn->Search(k, neighbors, distances);
 
-  Log::Info << "Neighbors computed." << endl;
+    Log::Info << "Neighbors computed." << endl;
 
-  // We have to map back to the original indices from before the tree
-  // construction.
-  Log::Info << "Re-mapping indices..." << endl;
+    // We have to map back to the original indices from before the tree
+    // construction.
+    Log::Info << "Re-mapping indices..." << endl;
 
-  arma::mat distancesOut(distances.n_rows, distances.n_cols);
-  arma::Mat<size_t> neighborsOut(neighbors.n_rows, neighbors.n_cols);
+    arma::mat distancesOut(distances.n_rows, distances.n_cols);
+    arma::Mat<size_t> neighborsOut(neighbors.n_rows, neighbors.n_cols);
 
-  // Map the points back to their original locations.
-  if ((CLI::GetParam<string>("query_file") != "") && !singleMode)
-    Unmap(neighbors, distances, oldFromNewRefs, oldFromNewQueries, neighborsOut,
-        distancesOut);
-  else if ((CLI::GetParam<string>("query_file") != "") && singleMode)
-    Unmap(neighbors, distances, oldFromNewRefs, neighborsOut, distancesOut);
-  else
-    Unmap(neighbors, distances, oldFromNewRefs, oldFromNewRefs, neighborsOut,
-        distancesOut);
+    // Map the points back to their original locations.
+    if ((CLI::GetParam<string>("query_file") != "") && !singleMode)
+      Unmap(neighbors, distances, oldFromNewRefs, oldFromNewQueries, neighborsOut,
+          distancesOut);
+    else if ((CLI::GetParam<string>("query_file") != "") && singleMode)
+      Unmap(neighbors, distances, oldFromNewRefs, neighborsOut, distancesOut);
+    else
+      Unmap(neighbors, distances, oldFromNewRefs, oldFromNewRefs, neighborsOut,
+          distancesOut);
 
-  // Clean up.
-  if (queryTree)
-    delete queryTree;
+    // Clean up.
+    if (queryTree)
+      delete queryTree;
 
-  // Save output.
+    delete allkfn;
+    
+      // Save output.
   data::Save(distancesFile, distancesOut);
   data::Save(neighborsFile, neighborsOut);
     
-  delete allkfn;
+  } else {  // Use the R tree.
+    Log::Info << "Using R tree for furthest-neighbor calculation." << endl;
+
+    // Because we may construct it differently, we need a pointer.
+    NeighborSearch<FurthestNeighborSort, metric::LMetric<2, true>,
+    RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<FurthestNeighborSort>, arma::mat>,
+       tree::RStarTreeDescentHeuristic,
+       NeighborSearchStat<FurthestNeighborSort>,
+       arma::mat> >* allkfn = NULL;
+
+
+    // Build trees by hand, so we can save memory: if we pass a tree to
+    // NeighborSearch, it does not copy the matrix.
+    Log::Info << "Building reference tree..." << endl;
+    Timer::Start("tree_building");
+
+    RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<FurthestNeighborSort>, arma::mat>,
+       tree::RStarTreeDescentHeuristic,
+       NeighborSearchStat<FurthestNeighborSort>,
+       arma::mat>
+    refTree(referenceData, leafSize, leafSize * 0.4, 5, 2, 0);
+
+    RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<FurthestNeighborSort>, arma::mat>,
+       tree::RStarTreeDescentHeuristic,
+       NeighborSearchStat<FurthestNeighborSort>,
+       arma::mat>*
+    queryTree = NULL; // Empty for now.
+
+    Timer::Stop("tree_building");
+    if (CLI::GetParam<string>("query_file") != "")
+    {
+      string queryFile = CLI::GetParam<string>("query_file");
+
+      data::Load(queryFile, queryData, true);
+      
+      Log::Info << "Loaded query data from '" << queryFile << "' ("
+      << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+      
+      // Build trees by hand, so we can save memory: if we pass a tree to
+      // NeighborSearch, it does not copy the matrix.
+      if (!singleMode)
+      {
+        Timer::Start("tree_building");
+        
+        queryTree = new RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<FurthestNeighborSort>, arma::mat>,
+        tree::RStarTreeDescentHeuristic,
+        NeighborSearchStat<FurthestNeighborSort>,
+        arma::mat>(queryData, leafSize, leafSize * 0.4, 5, 2, 0);
+        
+        Timer::Stop("tree_building");
+      }
+      
+      
+      allkfn = new NeighborSearch<FurthestNeighborSort, metric::LMetric<2, true>,
+      RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<FurthestNeighborSort>, arma::mat>,
+      tree::RStarTreeDescentHeuristic,
+      NeighborSearchStat<FurthestNeighborSort>,
+      arma::mat> >(&refTree, queryTree,
+                   referenceData, queryData, singleMode);
+    } else
+    {
+      allkfn = new NeighborSearch<FurthestNeighborSort, metric::LMetric<2, true>,
+      RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<FurthestNeighborSort>, arma::mat>,
+      tree::RStarTreeDescentHeuristic,
+      NeighborSearchStat<FurthestNeighborSort>,
+      arma::mat> >(&refTree,
+                   referenceData, singleMode);
+    }
+    Log::Info << "Tree built." << endl;
+    
+    //arma::mat distancesOut;
+    //arma::Mat<size_t> neighborsOut;
+    
+    Log::Info << "Computing " << k << " nearest neighbors..." << endl;
+    allkfn->Search(k, neighbors, distances);
+    
+    Log::Info << "Neighbors computed." << endl;
+    
+    
+    if(queryTree)
+      delete queryTree;
+    
+    delete allkfn;
+      
+    // Save output.
+    data::Save(distancesFile, distances);
+    data::Save(neighborsFile, neighbors);
+    
+  }
+
+
+
 }
diff --git a/src/mlpack/methods/neighbor_search/allknn_main.cpp b/src/mlpack/methods/neighbor_search/allknn_main.cpp
index e644f73..d7dcbbe 100644
--- a/src/mlpack/methods/neighbor_search/allknn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/allknn_main.cpp
@@ -58,7 +58,7 @@ PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
 PARAM_FLAG("cover_tree", "If true, use cover trees to perform the search "
     "(experimental, may be slow).", "c");
 PARAM_FLAG("r_tree", "If true, use an R-Tree to perform the search "
-    "(experimental, may be slow.  Currently automatically sets single_mode.).", "T");
+    "(experimental, may be slow.).", "T");
 PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
     "random orthogonal basis.", "R");
 PARAM_INT("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0);
@@ -129,10 +129,6 @@ int main(int argc, char *argv[])
   if (CLI::HasParam("cover_tree") && CLI::HasParam("r_tree"))
   {
     Log::Warn << "--cover_tree overrides --r_tree." << endl;
-  } else if (!singleMode && CLI::HasParam("r_tree"))  // R_tree requires single mode.
-  {
-//     Log::Warn << "--single_mode assumed because --r_tree is present." << endl;
-//     singleMode = true;
   }
   
   if (naive)
@@ -268,14 +264,14 @@ int main(int argc, char *argv[])
       delete allknn;
     } else { // R tree.
       // Make sure to notify the user that they are using an r tree.
-      Log::Info << "Using r tree for nearest-neighbor calculation." << endl;
+      Log::Info << "Using R tree for nearest-neighbor calculation." << endl;
       
       // Because we may construct it differently, we need a pointer.
       NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
       RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
-		    tree::RStarTreeDescentHeuristic,
-		    NeighborSearchStat<NearestNeighborSort>,
-		    arma::mat> >* allknn = NULL;
+         tree::RStarTreeDescentHeuristic,
+         NeighborSearchStat<NearestNeighborSort>,
+         arma::mat> >* allknn = NULL;
 
       // Build trees by hand, so we can save memory: if we pass a tree to
       // NeighborSearch, it does not copy the matrix.
@@ -283,38 +279,53 @@ int main(int argc, char *argv[])
       Timer::Start("tree_building");
 
       RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
-		    tree::RStarTreeDescentHeuristic,
-		    NeighborSearchStat<NearestNeighborSort>,
-		    arma::mat>
+         tree::RStarTreeDescentHeuristic,
+         NeighborSearchStat<NearestNeighborSort>,
+         arma::mat>
       refTree(referenceData, leafSize, leafSize * 0.4, 5, 2, 0);
 
       RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
-		    tree::RStarTreeDescentHeuristic,
-		    NeighborSearchStat<NearestNeighborSort>,
-		    arma::mat>*
+         tree::RStarTreeDescentHeuristic,
+         NeighborSearchStat<NearestNeighborSort>,
+         arma::mat>*
       queryTree = NULL; // Empty for now.
 
       Timer::Stop("tree_building");
       
       if (CLI::GetParam<string>("query_file") != "")
       {
-	Log::Info << "Loaded query data from '" << queryFile << "' ("
-	    << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+        Log::Info << "Loaded query data from '" << queryFile << "' ("
+          << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+          
+        // Build trees by hand, so we can save memory: if we pass a tree to
+        // NeighborSearch, it does not copy the matrix.
+        if (!singleMode)
+        {
+          Timer::Start("tree_building");
+
+          queryTree = new RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+          tree::RStarTreeDescentHeuristic,
+          NeighborSearchStat<NearestNeighborSort>,
+          arma::mat>(queryData, leafSize, leafSize * 0.4, 5, 2, 0);
+
+          Timer::Stop("tree_building");
+        }
+          
 
         allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
         RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
-	  	      tree::RStarTreeDescentHeuristic,
-  		      NeighborSearchStat<NearestNeighborSort>,
-  		      arma::mat> >(&refTree, queryTree,
+          tree::RStarTreeDescentHeuristic,
+          NeighborSearchStat<NearestNeighborSort>,
+          arma::mat> >(&refTree, queryTree,
           referenceData, queryData, singleMode);
       } else
       {
-	      allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
-      RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
-		    tree::RStarTreeDescentHeuristic,
-		    NeighborSearchStat<NearestNeighborSort>,
-		    arma::mat> >(&refTree,
-        referenceData, singleMode);
+        allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
+        RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+          tree::RStarTreeDescentHeuristic,
+          NeighborSearchStat<NearestNeighborSort>,
+          arma::mat> >(&refTree,
+          referenceData, singleMode);
       }
       Log::Info << "Tree built." << endl;
       
@@ -326,8 +337,8 @@ int main(int argc, char *argv[])
 
       Log::Info << "Neighbors computed." << endl;
 
-
-
+      if(queryTree)
+        delete queryTree;
       delete allknn;
     }
   }



More information about the mlpack-git mailing list