[mlpack-svn] r17084 - mlpack/trunk/src/mlpack/core/tree/rectangle_tree
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Aug 20 13:23:38 EDT 2014
Author: andrewmw94
Date: Wed Aug 20 13:23:37 2014
New Revision: 17084
Log:
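Dual-tree traverser bug fixes: score all query children in a single loop, and save each reference child's traversal info at scoring time so it can be restored before rescoring and recursing.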
Modified:
mlpack/trunk/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
mlpack/trunk/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
Modified: mlpack/trunk/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp (original)
+++ mlpack/trunk/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp Wed Aug 20 13:23:37 2014
@@ -67,6 +67,7 @@
public:
RectangleTree<SplitType, DescentType, StatisticType, MatType>* node;
double score;
+ typename RuleType::TraversalInfoType travInfo;
};
static bool nodeComparator(const NodeAndScore& obj1,
Modified: mlpack/trunk/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp Wed Aug 20 13:23:37 2014
@@ -76,13 +76,7 @@
else if(!queryNode.IsLeaf() && referenceNode.IsLeaf())
{
// We only need to traverse down the query node. Order doesn't matter here.
- ++numScores;
- if(rule.Score(queryNode.Child(0), referenceNode) < DBL_MAX)
- Traverse(queryNode.Child(0), referenceNode);
- else
- numPrunes++;
-
- for(size_t i = 1; i < queryNode.NumChildren(); ++i)
+ for(size_t i = 0; i < queryNode.NumChildren(); ++i)
{
// Before recursing, we have to set the traversal information correctly.
rule.TraversalInfo() = traversalInfo;
@@ -101,16 +95,18 @@
std::vector<NodeAndScore> nodesAndScores(referenceNode.NumChildren());
for(int i = 0; i < referenceNode.NumChildren(); i++)
{
+ rule.TraversalInfo() = traversalInfo;
nodesAndScores[i].node = referenceNode.Children()[i];
nodesAndScores[i].score = rule.Score(queryNode, *(nodesAndScores[i].node));
+ nodesAndScores[i].travInfo = rule.TraversalInfo();
}
std::sort(nodesAndScores.begin(), nodesAndScores.end(), nodeComparator);
numScores += nodesAndScores.size();
for(int i = 0; i < nodesAndScores.size(); i++)
{
+ rule.TraversalInfo() = nodesAndScores[i].travInfo;
if(rule.Rescore(queryNode, *(nodesAndScores[i].node), nodesAndScores[i].score) < DBL_MAX) {
- rule.TraversalInfo() = traversalInfo;
Traverse(queryNode, *(nodesAndScores[i].node));
} else {
numPrunes += nodesAndScores.size() - i;
@@ -130,16 +126,18 @@
std::vector<NodeAndScore> nodesAndScores(referenceNode.NumChildren());
for(int i = 0; i < referenceNode.NumChildren(); i++)
{
+ rule.TraversalInfo() = traversalInfo;
nodesAndScores[i].node = referenceNode.Children()[i];
nodesAndScores[i].score = rule.Score(queryNode, *nodesAndScores[i].node);
+ nodesAndScores[i].travInfo = rule.TraversalInfo();
}
std::sort(nodesAndScores.begin(), nodesAndScores.end(), nodeComparator);
numScores += nodesAndScores.size();
for(int i = 0; i < nodesAndScores.size(); i++)
{
+ rule.TraversalInfo() = nodesAndScores[i].travInfo;
if(rule.Rescore(queryNode, *(nodesAndScores[i].node), nodesAndScores[i].score) < DBL_MAX) {
- rule.TraversalInfo() = traversalInfo;
Traverse(queryNode, *(nodesAndScores[i].node));
} else {
numPrunes += nodesAndScores.size() - i;
More information about the mlpack-svn mailing list