[mlpack-git] master: Dual tree traverser bug fix. (35af6e7)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:59:08 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40
>---------------------------------------------------------------
commit 35af6e74e1adb0564369ebf754519b6d29e17183
Author: andrewmw94 <andrewmw94 at gmail.com>
Date: Wed Aug 20 21:07:09 2014 +0000
Dual tree traverser bug fix.
>---------------------------------------------------------------
35af6e74e1adb0564369ebf754519b6d29e17183
.../rectangle_tree/dual_tree_traverser_impl.hpp | 6 +-
src/mlpack/methods/neighbor_search/allkfn_main.cpp | 219 +++++++++++++++------
src/mlpack/methods/neighbor_search/allknn_main.cpp | 67 ++++---
3 files changed, 200 insertions(+), 92 deletions(-)
diff --git a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
index 8bdad30..7214986 100644
--- a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
@@ -128,7 +128,7 @@ DualTreeTraverser<RuleType>::Traverse(
{
rule.TraversalInfo() = traversalInfo;
nodesAndScores[i].node = referenceNode.Children()[i];
- nodesAndScores[i].score = rule.Score(queryNode, *nodesAndScores[i].node);
+ nodesAndScores[i].score = rule.Score(queryNode.Child(j), *nodesAndScores[i].node);
nodesAndScores[i].travInfo = rule.TraversalInfo();
}
std::sort(nodesAndScores.begin(), nodesAndScores.end(), nodeComparator);
@@ -137,8 +137,8 @@ DualTreeTraverser<RuleType>::Traverse(
for(int i = 0; i < nodesAndScores.size(); i++)
{
rule.TraversalInfo() = nodesAndScores[i].travInfo;
- if(rule.Rescore(queryNode, *(nodesAndScores[i].node), nodesAndScores[i].score) < DBL_MAX) {
- Traverse(queryNode, *(nodesAndScores[i].node));
+ if(rule.Rescore(queryNode.Child(j), *(nodesAndScores[i].node), nodesAndScores[i].score) < DBL_MAX) {
+ Traverse(queryNode.Child(j), *(nodesAndScores[i].node));
} else {
numPrunes += nodesAndScores.size() - i;
break;
diff --git a/src/mlpack/methods/neighbor_search/allkfn_main.cpp b/src/mlpack/methods/neighbor_search/allkfn_main.cpp
index 5949bda..2eebeb2 100644
--- a/src/mlpack/methods/neighbor_search/allkfn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/allkfn_main.cpp
@@ -52,6 +52,8 @@ PARAM_INT("leaf_size", "Leaf size for tree building.", "l", 20);
PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "N");
PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
"dual-tree search).", "s");
+PARAM_FLAG("r_tree", "If true, use an R-Tree to perform the search "
+ "(experimental, may be slow.).", "T");
int main(int argc, char *argv[])
{
@@ -107,91 +109,186 @@ int main(int argc, char *argv[])
arma::Mat<size_t> neighbors;
arma::mat distances;
- AllkFN* allkfn = NULL;
+ if(!CLI::HasParam("r_tree"))
+ {
+ AllkFN* allkfn = NULL;
- std::vector<size_t> oldFromNewRefs;
+ std::vector<size_t> oldFromNewRefs;
- // Build trees by hand, so we can save memory: if we pass a tree to
- // NeighborSearch, it does not copy the matrix.
- Log::Info << "Building reference tree..." << endl;
- Timer::Start("reference_tree_building");
+ // Build trees by hand, so we can save memory: if we pass a tree to
+ // NeighborSearch, it does not copy the matrix.
+ Log::Info << "Building reference tree..." << endl;
+ Timer::Start("reference_tree_building");
- BinarySpaceTree<bound::HRectBound<2>,
- NeighborSearchStat<FurthestNeighborSort> >
- refTree(referenceData, oldFromNewRefs, leafSize);
- BinarySpaceTree<bound::HRectBound<2>,
- NeighborSearchStat<FurthestNeighborSort> >*
- queryTree = NULL; // Empty for now.
+ BinarySpaceTree<bound::HRectBound<2>,
+ NeighborSearchStat<FurthestNeighborSort> >
+ refTree(referenceData, oldFromNewRefs, leafSize);
+ BinarySpaceTree<bound::HRectBound<2>,
+ NeighborSearchStat<FurthestNeighborSort> >*
+ queryTree = NULL; // Empty for now.
- Timer::Stop("reference_tree_building");
+ Timer::Stop("reference_tree_building");
- std::vector<size_t> oldFromNewQueries;
+ std::vector<size_t> oldFromNewQueries;
- if (CLI::GetParam<string>("query_file") != "")
- {
- string queryFile = CLI::GetParam<string>("query_file");
+ if (CLI::GetParam<string>("query_file") != "")
+ {
+ string queryFile = CLI::GetParam<string>("query_file");
- data::Load(queryFile, queryData, true);
+ data::Load(queryFile, queryData, true);
- Log::Info << "Loaded query data from '" << queryFile << "' ("
- << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+ Log::Info << "Loaded query data from '" << queryFile << "' ("
+ << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
- Log::Info << "Building query tree..." << endl;
+ Log::Info << "Building query tree..." << endl;
- if (naive && leafSize < queryData.n_cols)
- leafSize = queryData.n_cols;
+ if (naive && leafSize < queryData.n_cols)
+ leafSize = queryData.n_cols;
- // Build trees by hand, so we can save memory: if we pass a tree to
- // NeighborSearch, it does not copy the matrix.
- Timer::Start("query_tree_building");
+ // Build trees by hand, so we can save memory: if we pass a tree to
+ // NeighborSearch, it does not copy the matrix.
+ Timer::Start("query_tree_building");
- queryTree = new BinarySpaceTree<bound::HRectBound<2>,
- NeighborSearchStat<FurthestNeighborSort> >(queryData, oldFromNewQueries,
- leafSize);
+ queryTree = new BinarySpaceTree<bound::HRectBound<2>,
+ NeighborSearchStat<FurthestNeighborSort> >(queryData, oldFromNewQueries,
+ leafSize);
- Timer::Stop("query_tree_building");
+ Timer::Stop("query_tree_building");
- allkfn = new AllkFN(&refTree, queryTree, referenceData, queryData,
- singleMode);
+ allkfn = new AllkFN(&refTree, queryTree, referenceData, queryData,
+ singleMode);
- Log::Info << "Tree built." << endl;
- }
- else
- {
- allkfn = new AllkFN(&refTree, referenceData, singleMode);
+ Log::Info << "Tree built." << endl;
+ }
+ else
+ {
+ allkfn = new AllkFN(&refTree, referenceData, singleMode);
- Log::Info << "Trees built." << endl;
- }
+ Log::Info << "Trees built." << endl;
+ }
- Log::Info << "Computing " << k << " furthest neighbors..." << endl;
- allkfn->Search(k, neighbors, distances);
+ Log::Info << "Computing " << k << " furthest neighbors..." << endl;
+ allkfn->Search(k, neighbors, distances);
- Log::Info << "Neighbors computed." << endl;
+ Log::Info << "Neighbors computed." << endl;
- // We have to map back to the original indices from before the tree
- // construction.
- Log::Info << "Re-mapping indices..." << endl;
+ // We have to map back to the original indices from before the tree
+ // construction.
+ Log::Info << "Re-mapping indices..." << endl;
- arma::mat distancesOut(distances.n_rows, distances.n_cols);
- arma::Mat<size_t> neighborsOut(neighbors.n_rows, neighbors.n_cols);
+ arma::mat distancesOut(distances.n_rows, distances.n_cols);
+ arma::Mat<size_t> neighborsOut(neighbors.n_rows, neighbors.n_cols);
- // Map the points back to their original locations.
- if ((CLI::GetParam<string>("query_file") != "") && !singleMode)
- Unmap(neighbors, distances, oldFromNewRefs, oldFromNewQueries, neighborsOut,
- distancesOut);
- else if ((CLI::GetParam<string>("query_file") != "") && singleMode)
- Unmap(neighbors, distances, oldFromNewRefs, neighborsOut, distancesOut);
- else
- Unmap(neighbors, distances, oldFromNewRefs, oldFromNewRefs, neighborsOut,
- distancesOut);
+ // Map the points back to their original locations.
+ if ((CLI::GetParam<string>("query_file") != "") && !singleMode)
+ Unmap(neighbors, distances, oldFromNewRefs, oldFromNewQueries, neighborsOut,
+ distancesOut);
+ else if ((CLI::GetParam<string>("query_file") != "") && singleMode)
+ Unmap(neighbors, distances, oldFromNewRefs, neighborsOut, distancesOut);
+ else
+ Unmap(neighbors, distances, oldFromNewRefs, oldFromNewRefs, neighborsOut,
+ distancesOut);
- // Clean up.
- if (queryTree)
- delete queryTree;
+ // Clean up.
+ if (queryTree)
+ delete queryTree;
- // Save output.
+ delete allkfn;
+
+ // Save output.
data::Save(distancesFile, distancesOut);
data::Save(neighborsFile, neighborsOut);
- delete allkfn;
+ } else { // Use the R tree.
+ Log::Info << "Using R tree for furthest-neighbor calculation." << endl;
+
+ // Because we may construct it differently, we need a pointer.
+ NeighborSearch<FurthestNeighborSort, metric::LMetric<2, true>,
+ RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<FurthestNeighborSort>, arma::mat>,
+ tree::RStarTreeDescentHeuristic,
+ NeighborSearchStat<FurthestNeighborSort>,
+ arma::mat> >* allkfn = NULL;
+
+
+ // Build trees by hand, so we can save memory: if we pass a tree to
+ // NeighborSearch, it does not copy the matrix.
+ Log::Info << "Building reference tree..." << endl;
+ Timer::Start("tree_building");
+
+ RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<FurthestNeighborSort>, arma::mat>,
+ tree::RStarTreeDescentHeuristic,
+ NeighborSearchStat<FurthestNeighborSort>,
+ arma::mat>
+ refTree(referenceData, leafSize, leafSize * 0.4, 5, 2, 0);
+
+ RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<FurthestNeighborSort>, arma::mat>,
+ tree::RStarTreeDescentHeuristic,
+ NeighborSearchStat<FurthestNeighborSort>,
+ arma::mat>*
+ queryTree = NULL; // Empty for now.
+
+ Timer::Stop("tree_building");
+ if (CLI::GetParam<string>("query_file") != "")
+ {
+ string queryFile = CLI::GetParam<string>("query_file");
+
+ data::Load(queryFile, queryData, true);
+
+ Log::Info << "Loaded query data from '" << queryFile << "' ("
+ << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+
+ // Build trees by hand, so we can save memory: if we pass a tree to
+ // NeighborSearch, it does not copy the matrix.
+ if (!singleMode)
+ {
+ Timer::Start("tree_building");
+
+ queryTree = new RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<FurthestNeighborSort>, arma::mat>,
+ tree::RStarTreeDescentHeuristic,
+ NeighborSearchStat<FurthestNeighborSort>,
+ arma::mat>(queryData, leafSize, leafSize * 0.4, 5, 2, 0);
+
+ Timer::Stop("tree_building");
+ }
+
+
+ allkfn = new NeighborSearch<FurthestNeighborSort, metric::LMetric<2, true>,
+ RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<FurthestNeighborSort>, arma::mat>,
+ tree::RStarTreeDescentHeuristic,
+ NeighborSearchStat<FurthestNeighborSort>,
+ arma::mat> >(&refTree, queryTree,
+ referenceData, queryData, singleMode);
+ } else
+ {
+ allkfn = new NeighborSearch<FurthestNeighborSort, metric::LMetric<2, true>,
+ RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<FurthestNeighborSort>, arma::mat>,
+ tree::RStarTreeDescentHeuristic,
+ NeighborSearchStat<FurthestNeighborSort>,
+ arma::mat> >(&refTree,
+ referenceData, singleMode);
+ }
+ Log::Info << "Tree built." << endl;
+
+ //arma::mat distancesOut;
+ //arma::Mat<size_t> neighborsOut;
+
+ Log::Info << "Computing " << k << " nearest neighbors..." << endl;
+ allkfn->Search(k, neighbors, distances);
+
+ Log::Info << "Neighbors computed." << endl;
+
+
+ if(queryTree)
+ delete queryTree;
+
+ delete allkfn;
+
+ // Save output.
+ data::Save(distancesFile, distances);
+ data::Save(neighborsFile, neighbors);
+
+ }
+
+
+
}
diff --git a/src/mlpack/methods/neighbor_search/allknn_main.cpp b/src/mlpack/methods/neighbor_search/allknn_main.cpp
index e644f73..d7dcbbe 100644
--- a/src/mlpack/methods/neighbor_search/allknn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/allknn_main.cpp
@@ -58,7 +58,7 @@ PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
PARAM_FLAG("cover_tree", "If true, use cover trees to perform the search "
"(experimental, may be slow).", "c");
PARAM_FLAG("r_tree", "If true, use an R-Tree to perform the search "
- "(experimental, may be slow. Currently automatically sets single_mode.).", "T");
+ "(experimental, may be slow.).", "T");
PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
"random orthogonal basis.", "R");
PARAM_INT("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0);
@@ -129,10 +129,6 @@ int main(int argc, char *argv[])
if (CLI::HasParam("cover_tree") && CLI::HasParam("r_tree"))
{
Log::Warn << "--cover_tree overrides --r_tree." << endl;
- } else if (!singleMode && CLI::HasParam("r_tree")) // R_tree requires single mode.
- {
-// Log::Warn << "--single_mode assumed because --r_tree is present." << endl;
-// singleMode = true;
}
if (naive)
@@ -268,14 +264,14 @@ int main(int argc, char *argv[])
delete allknn;
} else { // R tree.
// Make sure to notify the user that they are using an r tree.
- Log::Info << "Using r tree for nearest-neighbor calculation." << endl;
+ Log::Info << "Using R tree for nearest-neighbor calculation." << endl;
// Because we may construct it differently, we need a pointer.
NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
- tree::RStarTreeDescentHeuristic,
- NeighborSearchStat<NearestNeighborSort>,
- arma::mat> >* allknn = NULL;
+ tree::RStarTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat> >* allknn = NULL;
// Build trees by hand, so we can save memory: if we pass a tree to
// NeighborSearch, it does not copy the matrix.
@@ -283,38 +279,53 @@ int main(int argc, char *argv[])
Timer::Start("tree_building");
RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
- tree::RStarTreeDescentHeuristic,
- NeighborSearchStat<NearestNeighborSort>,
- arma::mat>
+ tree::RStarTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat>
refTree(referenceData, leafSize, leafSize * 0.4, 5, 2, 0);
RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
- tree::RStarTreeDescentHeuristic,
- NeighborSearchStat<NearestNeighborSort>,
- arma::mat>*
+ tree::RStarTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat>*
queryTree = NULL; // Empty for now.
Timer::Stop("tree_building");
if (CLI::GetParam<string>("query_file") != "")
{
- Log::Info << "Loaded query data from '" << queryFile << "' ("
- << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+ Log::Info << "Loaded query data from '" << queryFile << "' ("
+ << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
+
+ // Build trees by hand, so we can save memory: if we pass a tree to
+ // NeighborSearch, it does not copy the matrix.
+ if (!singleMode)
+ {
+ Timer::Start("tree_building");
+
+ queryTree = new RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+ tree::RStarTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat>(queryData, leafSize, leafSize * 0.4, 5, 2, 0);
+
+ Timer::Stop("tree_building");
+ }
+
allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
- tree::RStarTreeDescentHeuristic,
- NeighborSearchStat<NearestNeighborSort>,
- arma::mat> >(&refTree, queryTree,
+ tree::RStarTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat> >(&refTree, queryTree,
referenceData, queryData, singleMode);
} else
{
- allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
- RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
- tree::RStarTreeDescentHeuristic,
- NeighborSearchStat<NearestNeighborSort>,
- arma::mat> >(&refTree,
- referenceData, singleMode);
+ allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
+ RectangleTree<tree::RStarTreeSplit<tree::RStarTreeDescentHeuristic, NeighborSearchStat<NearestNeighborSort>, arma::mat>,
+ tree::RStarTreeDescentHeuristic,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat> >(&refTree,
+ referenceData, singleMode);
}
Log::Info << "Tree built." << endl;
@@ -326,8 +337,8 @@ int main(int argc, char *argv[])
Log::Info << "Neighbors computed." << endl;
-
-
+ if(queryTree)
+ delete queryTree;
delete allknn;
}
}
More information about the mlpack-git mailing list