[mlpack-git] master: Dual tree traverser. (f83ac7d)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:58:57 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40
>---------------------------------------------------------------
commit f83ac7d54c66830634d4391ea1f33bb7af57afa6
Author: andrewmw94 <andrewmw94 at gmail.com>
Date: Wed Aug 20 03:21:23 2014 +0000
Dual tree traverser.
>---------------------------------------------------------------
f83ac7d54c66830634d4391ea1f33bb7af57afa6
.../tree/rectangle_tree/dual_tree_traverser.hpp | 14 +++
.../rectangle_tree/dual_tree_traverser_impl.hpp | 121 +++++++++++++++++++--
.../tree/rectangle_tree/single_tree_traverser.hpp | 9 +-
src/mlpack/methods/neighbor_search/allknn_main.cpp | 4 +-
4 files changed, 135 insertions(+), 13 deletions(-)
diff --git a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
index 1091224..6610da5 100644
--- a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
@@ -61,6 +61,20 @@ class RectangleTree<SplitType, DescentType, StatisticType, MatType>::
size_t& NumBaseCases() { return numBaseCases; }
private:
+
+ //! A child node paired with the score the rule assigned to it, so that the
+ //! children of a reference node can be sorted and visited best-first.
+ class NodeAndScore {
+ public:
+ //! Pointer to the child node (taken from the parent's Children()).
+ RectangleTree<SplitType, DescentType, StatisticType, MatType>* node;
+ //! Score returned by rule.Score() for this node.
+ double score;
+ };
+
+ //! Orders NodeAndScore objects by ascending score; used with std::sort.
+ static bool nodeComparator(const NodeAndScore& obj1,
+ const NodeAndScore& obj2)
+ {
+ return obj1.score < obj2.score;
+ }
+
//! Reference to the rules with which the trees will be traversed.
RuleType& rule;
diff --git a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
index 8af988e..7ac9b20 100644
--- a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
@@ -25,7 +25,10 @@ template<typename RuleType>
RectangleTree<SplitType, DescentType, StatisticType, MatType>::
DualTreeTraverser<RuleType>::DualTreeTraverser(RuleType& rule) :
rule(rule),
- numPrunes(0)
+ numPrunes(0),
+ numVisited(0),
+ numScores(0),
+ numBaseCases(0) // All statistics counters start at zero.
{ /* Nothing to do */ }
template<typename SplitType,
@@ -34,13 +37,117 @@ template<typename SplitType,
typename MatType>
template<typename RuleType>
void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
-DualTreeTraverser<RuleType>::Traverse(RectangleTree<SplitType, DescentType, StatisticType, MatType>& queryNode,
- RectangleTree<SplitType, DescentType, StatisticType, MatType>& referenceNode)
+DualTreeTraverser<RuleType>::Traverse(
+    RectangleTree<SplitType, DescentType, StatisticType, MatType>& queryNode,
+    RectangleTree<SplitType, DescentType, StatisticType, MatType>& referenceNode)
{
- //Do nothing. Just here to prevent warnings.
- if(queryNode.NumDescendants() > referenceNode.NumDescendants())
- return;
- return;
+  // Increment the visit counter.
+  ++numVisited;
+
+  // Store the traversal info we were called with.  The member variable is
+  // overwritten by every recursive call to Traverse(), so keep a local copy to
+  // restore before each child is scored and before each recursion.
+  traversalInfo = rule.TraversalInfo();
+  const typename RuleType::TraversalInfoType parentInfo = traversalInfo;
+
+  // We now have four options.
+  // 1) Both nodes are leaf nodes.
+  // 2) Only the reference node is a leaf node.
+  // 3) Only the query node is a leaf node.
+  // 4) Neither node is a leaf node.
+  // We go through those options in that order.
+
+  if (queryNode.IsLeaf() && referenceNode.IsLeaf())
+  {
+    // Evaluate the base case.  Do the query points on the outside so we can
+    // possibly prune the reference node for that particular point.
+    for (size_t query = 0; query < queryNode.Count(); ++query)
+    {
+      // Restore the traversal information.
+      rule.TraversalInfo() = parentInfo;
+      ++numScores;
+      const double childScore = rule.Score(queryNode.Points()[query],
+          referenceNode);
+
+      if (childScore == DBL_MAX)
+        continue; // This point doesn't require a search in this reference node.
+
+      for (size_t ref = 0; ref < referenceNode.Count(); ++ref)
+        rule.BaseCase(queryNode.Points()[query], referenceNode.Points()[ref]);
+
+      numBaseCases += referenceNode.Count();
+    }
+  }
+  else if (!queryNode.IsLeaf() && referenceNode.IsLeaf())
+  {
+    // We only need to traverse down the query node.  Order doesn't matter here.
+    for (size_t i = 0; i < queryNode.NumChildren(); ++i)
+    {
+      // Before scoring and recursing, restore this level's traversal info.
+      rule.TraversalInfo() = parentInfo;
+      ++numScores;
+      if (rule.Score(queryNode.Child(i), referenceNode) < DBL_MAX)
+        Traverse(queryNode.Child(i), referenceNode);
+      else
+        ++numPrunes;
+    }
+  }
+  else if (queryNode.IsLeaf() && !referenceNode.IsLeaf())
+  {
+    // We only need to traverse down the reference node.  Order does matter
+    // here: the reference children are visited in order of increasing score.
+    std::vector<NodeAndScore> nodesAndScores(referenceNode.NumChildren());
+    for (size_t i = 0; i < nodesAndScores.size(); ++i)
+    {
+      rule.TraversalInfo() = parentInfo;
+      nodesAndScores[i].node = referenceNode.Children()[i];
+      nodesAndScores[i].score = rule.Score(queryNode,
+          *(nodesAndScores[i].node));
+    }
+    std::sort(nodesAndScores.begin(), nodesAndScores.end(), nodeComparator);
+    numScores += nodesAndScores.size();
+
+    for (size_t i = 0; i < nodesAndScores.size(); ++i)
+    {
+      if (rule.Rescore(queryNode, *(nodesAndScores[i].node),
+          nodesAndScores[i].score) < DBL_MAX)
+      {
+        rule.TraversalInfo() = parentInfo;
+        Traverse(queryNode, *(nodesAndScores[i].node));
+      }
+      else
+      {
+        // The children are sorted by score, so this loop treats every
+        // remaining (worse-scored) child as pruned as well.
+        numPrunes += nodesAndScores.size() - i;
+        break;
+      }
+    }
+  }
+  else
+  {
+    // We need to traverse down both the reference and the query trees.  For
+    // each child of the query node, score every child of the reference node
+    // against that query child, then recurse in order of increasing score.
+    // The buffer is reused for every query child; all entries are rewritten.
+    std::vector<NodeAndScore> nodesAndScores(referenceNode.NumChildren());
+    for (size_t j = 0; j < queryNode.NumChildren(); ++j)
+    {
+      RectangleTree<SplitType, DescentType, StatisticType, MatType>& queryChild =
+          queryNode.Child(j);
+
+      for (size_t i = 0; i < nodesAndScores.size(); ++i)
+      {
+        rule.TraversalInfo() = parentInfo;
+        nodesAndScores[i].node = referenceNode.Children()[i];
+        nodesAndScores[i].score = rule.Score(queryChild,
+            *(nodesAndScores[i].node));
+      }
+      std::sort(nodesAndScores.begin(), nodesAndScores.end(), nodeComparator);
+      numScores += nodesAndScores.size();
+
+      for (size_t i = 0; i < nodesAndScores.size(); ++i)
+      {
+        if (rule.Rescore(queryChild, *(nodesAndScores[i].node),
+            nodesAndScores[i].score) < DBL_MAX)
+        {
+          rule.TraversalInfo() = parentInfo;
+          Traverse(queryChild, *(nodesAndScores[i].node));
+        }
+        else
+        {
+          // Remaining reference children have scores at least this bad.
+          numPrunes += nodesAndScores.size() - i;
+          break;
+        }
+      }
+    }
+  }
}
}; // namespace tree
diff --git a/src/mlpack/core/tree/rectangle_tree/single_tree_traverser.hpp b/src/mlpack/core/tree/rectangle_tree/single_tree_traverser.hpp
index 6eaf452..b34e148 100644
--- a/src/mlpack/core/tree/rectangle_tree/single_tree_traverser.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/single_tree_traverser.hpp
@@ -44,10 +44,12 @@ class RectangleTree<SplitType, DescentType, StatisticType, MatType>::
//! Modify the number of prunes.
size_t& NumPrunes() { return numPrunes; }
- // We use this struct and this function to make the sorting and scoring easy
+ private:
+
+ // We use this class and this function to make the sorting and scoring easy
// and efficient:
- struct NodeAndScore
- {
+ class NodeAndScore {
+ public:
+ //! The node being scored.
RectangleTree<SplitType, DescentType, StatisticType, MatType>* node;
+ //! The node's score; nodeComparator orders NodeAndScore objects by this.
double score;
};
@@ -57,7 +59,6 @@ class RectangleTree<SplitType, DescentType, StatisticType, MatType>::
return obj1.score < obj2.score;
}
- private:
//! Reference to the rules with which the tree will be traversed.
RuleType& rule;
diff --git a/src/mlpack/methods/neighbor_search/allknn_main.cpp b/src/mlpack/methods/neighbor_search/allknn_main.cpp
index 9401e9f..e644f73 100644
--- a/src/mlpack/methods/neighbor_search/allknn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/allknn_main.cpp
@@ -131,8 +131,8 @@ int main(int argc, char *argv[])
Log::Warn << "--cover_tree overrides --r_tree." << endl;
-} else if (!singleMode && CLI::HasParam("r_tree")) // R_tree requires single mode.
-{
- Log::Warn << "--single_mode assumed because --r_tree is present." << endl;
- singleMode = true;
}
if (naive)
More information about the mlpack-git
mailing list