[mlpack-git] master: dual tree traverser bug fixes. (b88ae3b)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:59:02 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40
>---------------------------------------------------------------
commit b88ae3b5fd041a1bbe05732d17a5e70be5d9097c
Author: andrewmw94 <andrewmw94 at gmail.com>
Date: Wed Aug 20 17:23:37 2014 +0000
dual tree traverser bug fixes.
>---------------------------------------------------------------
b88ae3b5fd041a1bbe05732d17a5e70be5d9097c
.../core/tree/rectangle_tree/dual_tree_traverser.hpp | 1 +
.../tree/rectangle_tree/dual_tree_traverser_impl.hpp | 16 +++++++---------
2 files changed, 8 insertions(+), 9 deletions(-)
diff --git a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
index 6610da5..50f971d 100644
--- a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
@@ -67,6 +67,7 @@ class RectangleTree<SplitType, DescentType, StatisticType, MatType>::
public:
RectangleTree<SplitType, DescentType, StatisticType, MatType>* node;
double score;
+ typename RuleType::TraversalInfoType travInfo;
};
static bool nodeComparator(const NodeAndScore& obj1,
diff --git a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
index 7ac9b20..8bdad30 100644
--- a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
@@ -76,13 +76,7 @@ DualTreeTraverser<RuleType>::Traverse(
else if(!queryNode.IsLeaf() && referenceNode.IsLeaf())
{
// We only need to traverse down the query node. Order doesn't matter here.
- ++numScores;
- if(rule.Score(queryNode.Child(0), referenceNode) < DBL_MAX)
- Traverse(queryNode.Child(0), referenceNode);
- else
- numPrunes++;
-
- for(size_t i = 1; i < queryNode.NumChildren(); ++i)
+ for(size_t i = 0; i < queryNode.NumChildren(); ++i)
{
// Before recursing, we have to set the traversal information correctly.
rule.TraversalInfo() = traversalInfo;
@@ -101,16 +95,18 @@ DualTreeTraverser<RuleType>::Traverse(
std::vector<NodeAndScore> nodesAndScores(referenceNode.NumChildren());
for(int i = 0; i < referenceNode.NumChildren(); i++)
{
+ rule.TraversalInfo() = traversalInfo;
nodesAndScores[i].node = referenceNode.Children()[i];
nodesAndScores[i].score = rule.Score(queryNode, *(nodesAndScores[i].node));
+ nodesAndScores[i].travInfo = rule.TraversalInfo();
}
std::sort(nodesAndScores.begin(), nodesAndScores.end(), nodeComparator);
numScores += nodesAndScores.size();
for(int i = 0; i < nodesAndScores.size(); i++)
{
+ rule.TraversalInfo() = nodesAndScores[i].travInfo;
if(rule.Rescore(queryNode, *(nodesAndScores[i].node), nodesAndScores[i].score) < DBL_MAX) {
- rule.TraversalInfo() = traversalInfo;
Traverse(queryNode, *(nodesAndScores[i].node));
} else {
numPrunes += nodesAndScores.size() - i;
@@ -130,16 +126,18 @@ DualTreeTraverser<RuleType>::Traverse(
std::vector<NodeAndScore> nodesAndScores(referenceNode.NumChildren());
for(int i = 0; i < referenceNode.NumChildren(); i++)
{
+ rule.TraversalInfo() = traversalInfo;
nodesAndScores[i].node = referenceNode.Children()[i];
nodesAndScores[i].score = rule.Score(queryNode, *nodesAndScores[i].node);
+ nodesAndScores[i].travInfo = rule.TraversalInfo();
}
std::sort(nodesAndScores.begin(), nodesAndScores.end(), nodeComparator);
numScores += nodesAndScores.size();
for(int i = 0; i < nodesAndScores.size(); i++)
{
+ rule.TraversalInfo() = nodesAndScores[i].travInfo;
if(rule.Rescore(queryNode, *(nodesAndScores[i].node), nodesAndScores[i].score) < DBL_MAX) {
- rule.TraversalInfo() = traversalInfo;
Traverse(queryNode, *(nodesAndScores[i].node));
} else {
numPrunes += nodesAndScores.size() - i;
More information about the mlpack-git mailing list