[mlpack-git] master, mlpack-1.0.x: Modify BinarySpaceTree::DualTreeTraverser to properly handle TraversalInfo objects. (014b4f4)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:42:29 EST 2015
Repository : https://github.com/mlpack/mlpack
On branches: master,mlpack-1.0.x
Link : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40
>---------------------------------------------------------------
commit 014b4f43b514c39b137b0d0035104046604d1a86
Author: Ryan Curtin <ryan at ratml.org>
Date: Thu Feb 6 20:18:58 2014 +0000
Modify BinarySpaceTree::DualTreeTraverser to properly handle TraversalInfo
objects.
>---------------------------------------------------------------
014b4f43b514c39b137b0d0035104046604d1a86
.../tree/binary_space_tree/dual_tree_traverser.hpp | 13 +-
.../binary_space_tree/dual_tree_traverser_impl.hpp | 139 ++++++++++++++++-----
2 files changed, 117 insertions(+), 35 deletions(-)
diff --git a/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp b/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp
index ac90de5..222105f 100644
--- a/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp
@@ -28,12 +28,17 @@ class BinarySpaceTree<BoundType, StatisticType, MatType>::DualTreeTraverser
DualTreeTraverser(RuleType& rule);
/**
- * Traverse the two trees. This does not reset the number of prunes.
+ * Traverse the two trees. This does not reset the number of prunes. If you
+ * are starting a traversal, the score for the parent node combination is
+ * irrelevant and can be left as 0 and thus does not need to be specified.
*
* @param queryNode The query node to be traversed.
* @param referenceNode The reference node to be traversed.
+ * @param score The score of the current node combination.
*/
- void Traverse(BinarySpaceTree& queryNode, BinarySpaceTree& referenceNode);
+ void Traverse(BinarySpaceTree& queryNode,
+ BinarySpaceTree& referenceNode,
+ const double score = 0.0);
//! Get the number of prunes.
size_t NumPrunes() const { return numPrunes; }
@@ -70,6 +75,10 @@ class BinarySpaceTree<BoundType, StatisticType, MatType>::DualTreeTraverser
//! The number of times a base case was calculated.
size_t numBaseCases;
+
+ //! Traversal information, held in the class so that it isn't continually
+ //! being reallocated.
+ typename RuleType::TraversalInfoType traversalInfo;
};
}; // namespace tree
diff --git a/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp
index 42e0707..de92ea5 100644
--- a/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp
@@ -31,11 +31,15 @@ template<typename RuleType>
void BinarySpaceTree<BoundType, StatisticType, MatType>::
DualTreeTraverser<RuleType>::Traverse(
BinarySpaceTree<BoundType, StatisticType, MatType>& queryNode,
- BinarySpaceTree<BoundType, StatisticType, MatType>& referenceNode)
+ BinarySpaceTree<BoundType, StatisticType, MatType>& referenceNode,
+ const double score /* = 0.0 */)
{
// Increment the visit counter.
++numVisited;
+ // Store the current traversal info.
+ traversalInfo = rule.TraversalInfo();
+
// If both are leaves, we must evaluate the base case.
if (queryNode.IsLeaf() && referenceNode.IsLeaf())
{
@@ -43,10 +47,12 @@ DualTreeTraverser<RuleType>::Traverse(
for (size_t query = queryNode.Begin(); query < queryNode.End(); ++query)
{
// See if we need to investigate this point (this function should be
- // implemented for the single-tree recursion too).
- const double score = rule.Score(query, referenceNode);
+ // implemented for the single-tree recursion too). Restore the traversal
+ // information first.
+ rule.TraversalInfo() = traversalInfo;
+ const double childScore = rule.Score(query, referenceNode);
- if (score == DBL_MAX)
+ if (childScore == DBL_MAX)
continue; // We can't improve this particular point.
for (size_t ref = referenceNode.Begin(); ref < referenceNode.End(); ++ref)
@@ -59,53 +65,69 @@ DualTreeTraverser<RuleType>::Traverse(
{
// We have to recurse down the query node. In this case the recursion order
// does not matter.
- double leftScore = rule.Score(*queryNode.Left(), referenceNode);
+ const double leftScore = rule.Score(*queryNode.Left(), referenceNode);
++numScores;
if (leftScore != DBL_MAX)
- Traverse(*queryNode.Left(), referenceNode);
+ Traverse(*queryNode.Left(), referenceNode, leftScore);
else
++numPrunes;
- double rightScore = rule.Score(*queryNode.Right(), referenceNode);
+ // Before recursing, we have to set the traversal information correctly.
+ rule.TraversalInfo() = traversalInfo;
+ const double rightScore = rule.Score(*queryNode.Right(), referenceNode);
++numScores;
if (rightScore != DBL_MAX)
- Traverse(*queryNode.Right(), referenceNode);
+ Traverse(*queryNode.Right(), referenceNode, rightScore);
else
++numPrunes;
}
else if (queryNode.IsLeaf() && (!referenceNode.IsLeaf()))
{
// We have to recurse down the reference node. In this case the recursion
- // order does matter.
+ // order does matter. Before recursing, though, we have to set the
+ // traversal information correctly.
double leftScore = rule.Score(queryNode, *referenceNode.Left());
+ typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo();
+ rule.TraversalInfo() = traversalInfo;
double rightScore = rule.Score(queryNode, *referenceNode.Right());
numScores += 2;
if (leftScore < rightScore)
{
- // Recurse to the left.
- Traverse(queryNode, *referenceNode.Left());
+ // Recurse to the left. Restore the left traversal info. Store the right
+ // traversal info.
+ traversalInfo = rule.TraversalInfo();
+ rule.TraversalInfo() = leftInfo;
+ Traverse(queryNode, *referenceNode.Left(), leftScore);
// Is it still valid to recurse to the right?
rightScore = rule.Rescore(queryNode, *referenceNode.Right(), rightScore);
if (rightScore != DBL_MAX)
- Traverse(queryNode, *referenceNode.Right());
+ {
+ // Restore the right traversal info.
+ rule.TraversalInfo() = traversalInfo;
+ Traverse(queryNode, *referenceNode.Right(), rightScore);
+ }
else
++numPrunes;
}
else if (rightScore < leftScore)
{
// Recurse to the right.
- Traverse(queryNode, *referenceNode.Right());
+ Traverse(queryNode, *referenceNode.Right(), rightScore);
// Is it still valid to recurse to the left?
leftScore = rule.Rescore(queryNode, *referenceNode.Left(), leftScore);
if (leftScore != DBL_MAX)
- Traverse(queryNode, *referenceNode.Left());
+ {
+ // Restore the left traversal info.
+ rule.TraversalInfo() = leftInfo;
+ Traverse(queryNode, *referenceNode.Left(), leftScore);
+ }
else
++numPrunes;
}
@@ -117,14 +139,21 @@ DualTreeTraverser<RuleType>::Traverse(
}
else
{
- // Choose the left first.
- Traverse(queryNode, *referenceNode.Left());
+ // Choose the left first. Restore the left traversal info. Store the
+ // right traversal info.
+ traversalInfo = rule.TraversalInfo();
+ rule.TraversalInfo() = leftInfo;
+ Traverse(queryNode, *referenceNode.Left(), leftScore);
rightScore = rule.Rescore(queryNode, *referenceNode.Right(),
rightScore);
if (rightScore != DBL_MAX)
- Traverse(queryNode, *referenceNode.Right());
+ {
+ // Restore the right traversal info.
+ rule.TraversalInfo() = traversalInfo;
+ Traverse(queryNode, *referenceNode.Right(), rightScore);
+ }
else
++numPrunes;
}
@@ -134,36 +163,51 @@ DualTreeTraverser<RuleType>::Traverse(
{
// We have to recurse down both query and reference nodes. Because the
// query descent order does not matter, we will go to the left query child
- // first.
+ // first. Before recursing, we have to set the traversal information
+ // correctly.
double leftScore = rule.Score(*queryNode.Left(), *referenceNode.Left());
+ typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo();
+ rule.TraversalInfo() = traversalInfo;
double rightScore = rule.Score(*queryNode.Left(), *referenceNode.Right());
+ typename RuleType::TraversalInfoType rightInfo;
numScores += 2;
if (leftScore < rightScore)
{
- // Recurse to the left.
- Traverse(*queryNode.Left(), *referenceNode.Left());
+ // Recurse to the left. Restore the left traversal info. Store the right
+ // traversal info.
+ rightInfo = rule.TraversalInfo();
+ rule.TraversalInfo() = leftInfo;
+ Traverse(*queryNode.Left(), *referenceNode.Left(), leftScore);
// Is it still valid to recurse to the right?
rightScore = rule.Rescore(*queryNode.Left(), *referenceNode.Right(),
rightScore);
if (rightScore != DBL_MAX)
- Traverse(*queryNode.Left(), *referenceNode.Right());
+ {
+ // Restore the right traversal info.
+ rule.TraversalInfo() = rightInfo;
+ Traverse(*queryNode.Left(), *referenceNode.Right(), rightScore);
+ }
else
++numPrunes;
}
else if (rightScore < leftScore)
{
// Recurse to the right.
- Traverse(*queryNode.Left(), *referenceNode.Right());
+ Traverse(*queryNode.Left(), *referenceNode.Right(), rightScore);
// Is it still valid to recurse to the left?
leftScore = rule.Rescore(*queryNode.Left(), *referenceNode.Left(),
leftScore);
if (leftScore != DBL_MAX)
- Traverse(*queryNode.Left(), *referenceNode.Left());
+ {
+ // Restore the left traversal info.
+ rule.TraversalInfo() = leftInfo;
+ Traverse(*queryNode.Left(), *referenceNode.Left(), leftScore);
+ }
else
++numPrunes;
}
@@ -175,7 +219,10 @@ DualTreeTraverser<RuleType>::Traverse(
}
else
{
- // Choose the left first.
+ // Choose the left first. Restore the left traversal info and store the
+ // right traversal info.
+ rightInfo = rule.TraversalInfo();
+ rule.TraversalInfo() = leftInfo;
Traverse(*queryNode.Left(), *referenceNode.Left());
// Is it still valid to recurse to the right?
@@ -183,42 +230,62 @@ DualTreeTraverser<RuleType>::Traverse(
rightScore);
if (rightScore != DBL_MAX)
+ {
+ // Restore the right traversal information.
+ rule.TraversalInfo() = rightInfo;
Traverse(*queryNode.Left(), *referenceNode.Right());
+ }
else
++numPrunes;
}
}
+ // Restore the main traversal information.
+ rule.TraversalInfo() = traversalInfo;
+
// Now recurse down the right query node.
leftScore = rule.Score(*queryNode.Right(), *referenceNode.Left());
+ leftInfo = rule.TraversalInfo();
+ rule.TraversalInfo() = traversalInfo;
rightScore = rule.Score(*queryNode.Right(), *referenceNode.Right());
numScores += 2;
if (leftScore < rightScore)
{
- // Recurse to the left.
- Traverse(*queryNode.Right(), *referenceNode.Left());
+ // Recurse to the left. Restore the left traversal info. Store the right
+ // traversal info.
+ rightInfo = rule.TraversalInfo();
+ rule.TraversalInfo() = leftInfo;
+ Traverse(*queryNode.Right(), *referenceNode.Left(), leftScore);
// Is it still valid to recurse to the right?
rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(),
rightScore);
if (rightScore != DBL_MAX)
- Traverse(*queryNode.Right(), *referenceNode.Right());
+ {
+ // Restore the right traversal info.
+ rule.TraversalInfo() = rightInfo;
+ Traverse(*queryNode.Right(), *referenceNode.Right(), rightScore);
+ }
else
++numPrunes;
}
else if (rightScore < leftScore)
{
// Recurse to the right.
- Traverse(*queryNode.Right(), *referenceNode.Right());
+ Traverse(*queryNode.Right(), *referenceNode.Right(), rightScore);
// Is it still valid to recurse to the left?
leftScore = rule.Rescore(*queryNode.Right(), *referenceNode.Left(),
leftScore);
if (leftScore != DBL_MAX)
- Traverse(*queryNode.Right(), *referenceNode.Left());
+ {
+ // Restore the left traversal info.
+ rule.TraversalInfo() = leftInfo;
+ Traverse(*queryNode.Right(), *referenceNode.Left(), leftScore);
+ }
else
++numPrunes;
}
@@ -230,15 +297,22 @@ DualTreeTraverser<RuleType>::Traverse(
}
else
{
- // Choose the left first.
- Traverse(*queryNode.Right(), *referenceNode.Left());
+ // Choose the left first. Restore the left traversal info. Store the
+ // right traversal info.
+ rightInfo = rule.TraversalInfo();
+ rule.TraversalInfo() = leftInfo;
+ Traverse(*queryNode.Right(), *referenceNode.Left(), leftScore);
// Is it still valid to recurse to the right?
rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(),
rightScore);
if (rightScore != DBL_MAX)
- Traverse(*queryNode.Right(), *referenceNode.Right());
+ {
+ // Restore the right traversal info.
+ rule.TraversalInfo() = rightInfo;
+ Traverse(*queryNode.Right(), *referenceNode.Right(), rightScore);
+ }
else
++numPrunes;
}
@@ -250,4 +324,3 @@ DualTreeTraverser<RuleType>::Traverse(
}; // namespace mlpack
#endif // __MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
-
More information about the mlpack-git mailing list