[mlpack-git] master: Add dual tree traverser for Spill Trees. (b4d5a1d)
gitdub at mlpack.org
gitdub at mlpack.org
Thu Aug 18 13:39:35 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/0f4b25acd6aaa14294c044874ba6cc0751712baa...0a19d07bd39e6223991976474bc79671ba8aa0f0
>---------------------------------------------------------------
commit b4d5a1d389a87936365ef9f556c7fbd943a055d4
Author: MarcosPividori <marcos.pividori at gmail.com>
Date: Wed Jul 27 00:03:48 2016 -0300
Add dual tree traverser for Spill Trees.
>---------------------------------------------------------------
b4d5a1d389a87936365ef9f556c7fbd943a055d4
.../tree/spill_tree/dual_tree_traverser_impl.hpp | 346 ++++++++++++++++++++-
1 file changed, 337 insertions(+), 9 deletions(-)
diff --git a/src/mlpack/core/tree/spill_tree/dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/spill_tree/dual_tree_traverser_impl.hpp
index b2eae9f..ff4b90c 100644
--- a/src/mlpack/core/tree/spill_tree/dual_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/spill_tree/dual_tree_traverser_impl.hpp
@@ -34,20 +34,348 @@ DualTreeTraverser<RuleType>::DualTreeTraverser(RuleType& rule) :
template<typename MetricType,
typename StatisticType,
typename MatType,
- template<typename BoundMetricType, typename...> class BoundType,
template<typename SplitBoundType, typename SplitMatType>
class SplitType>
template<typename RuleType>
-void SpillTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
+void SpillTree<MetricType, StatisticType, MatType, SplitType>::
DualTreeTraverser<RuleType>::Traverse(
- SpillTree<MetricType, StatisticType, MatType, BoundType, SplitType>&
- /* queryNode */,
- SpillTree<MetricType, StatisticType, MatType, BoundType, SplitType>&
- /* referenceNode */)
+ SpillTree<MetricType, StatisticType, MatType, SplitType>& queryNode,
+ SpillTree<MetricType, StatisticType, MatType, SplitType>& referenceNode)
{
- // TODO: Add support for dual tree traverser.
- throw std::runtime_error("Dual tree traverser not implemented for "
- "spill trees.");
+ // Traverse the (queryNode, referenceNode) pair; first count this visit.
+ ++numVisited;
+
+ // Store the current traversal info.
+ traversalInfo = rule.TraversalInfo();
+
+ // If both are leaves, we must evaluate the base case.
+ if (queryNode.IsLeaf() && referenceNode.IsLeaf())
+ {
+ // Loop through each of the points in each node.
+ const size_t queryEnd = queryNode.NumPoints();
+ const size_t refEnd = referenceNode.NumPoints();
+ for (size_t query = 0; query < queryEnd; ++query)
+ {
+ const size_t queryIndex = queryNode.Point(query);
+ // See if we need to investigate this point. Restore the traversal
+ // information first.
+ rule.TraversalInfo() = traversalInfo;
+ const double childScore = rule.Score(queryIndex, referenceNode);
+
+ if (childScore == DBL_MAX)
+ continue; // We can't improve this particular point.
+
+ for (size_t ref = 0; ref < refEnd; ++ref)
+ rule.BaseCase(queryIndex, referenceNode.Point(ref));
+
+ numBaseCases += refEnd;
+ }
+ }
+ else if (((!queryNode.IsLeaf()) && referenceNode.IsLeaf()) ||
+ (queryNode.NumDescendants() > 3 * referenceNode.NumDescendants() &&
+ !queryNode.IsLeaf() && !referenceNode.IsLeaf()))
+ {
+ // Recurse down the query node only: either the reference node is a leaf or
+ // the query node has over 3x its descendants. Recursion order does not matter.
+ const double leftScore = rule.Score(*queryNode.Left(), referenceNode);
+ ++numScores;
+
+ if (leftScore != DBL_MAX)
+ Traverse(*queryNode.Left(), referenceNode);
+ else
+ ++numPrunes;
+
+ // Before recursing, we have to set the traversal information correctly.
+ rule.TraversalInfo() = traversalInfo;
+ const double rightScore = rule.Score(*queryNode.Right(), referenceNode);
+ ++numScores;
+
+ if (rightScore != DBL_MAX)
+ Traverse(*queryNode.Right(), referenceNode);
+ else
+ ++numPrunes;
+ }
+ else if (queryNode.IsLeaf() && (!referenceNode.IsLeaf()))
+ {
+ // We have to recurse down the reference node. In this case the recursion
+ // order does matter. Before recursing, though, we have to set the
+ // traversal information correctly.
+ double leftScore = rule.Score(queryNode, *referenceNode.Left());
+ typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo();
+ rule.TraversalInfo() = traversalInfo;
+ double rightScore = rule.Score(queryNode, *referenceNode.Right());
+ numScores += 2;
+
+ if (leftScore < rightScore)
+ {
+ // Recurse to the left. Restore the left traversal info. Store the right
+ // traversal info.
+ traversalInfo = rule.TraversalInfo();
+ rule.TraversalInfo() = leftInfo;
+ Traverse(queryNode, *referenceNode.Left());
+
+ // Is it still valid to recurse to the right?
+ rightScore = rule.Rescore(queryNode, *referenceNode.Right(), rightScore);
+
+ if (rightScore != DBL_MAX)
+ {
+ // Restore the right traversal info.
+ rule.TraversalInfo() = traversalInfo;
+ Traverse(queryNode, *referenceNode.Right());
+ }
+ else
+ ++numPrunes;
+ }
+ else if (rightScore < leftScore)
+ {
+ // Recurse to the right; the rule's traversal info is unchanged since scoring it.
+ Traverse(queryNode, *referenceNode.Right());
+
+ // Is it still valid to recurse to the left?
+ leftScore = rule.Rescore(queryNode, *referenceNode.Left(), leftScore);
+
+ if (leftScore != DBL_MAX)
+ {
+ // Restore the left traversal info.
+ rule.TraversalInfo() = leftInfo;
+ Traverse(queryNode, *referenceNode.Left());
+ }
+ else
+ ++numPrunes;
+ }
+ else // leftScore is equal to rightScore.
+ {
+ if (leftScore == DBL_MAX)
+ {
+ numPrunes += 2; // Both reference children are pruned.
+ }
+ else
+ {
+ if(referenceNode.Overlap())
+ {
+ // If referenceNode is an overlapping node and we can't decide which
+ // child node to traverse, this means that queryNode is at both sides
+ // of the splitting hyperplane. So, as queryNode is a leaf node, all we
+ // can do is a single-tree search for each point in the query node.
+ const size_t queryEnd = queryNode.NumPoints();
+ SingleTreeTraverser<RuleType> st(rule);
+ // Loop through each of the points in query node.
+ for (size_t query = 0; query < queryEnd; ++query)
+ {
+ const size_t queryIndex = queryNode.Point(query);
+ // See if this point needs investigating. NOTE(review): no TraversalInfo restore here, unlike the leaf-leaf loop; confirm.
+ const double childScore = rule.Score(queryIndex, referenceNode);
+
+ if (childScore == DBL_MAX)
+ continue; // We can't improve this particular point.
+
+ st.Traverse(queryIndex, referenceNode);
+ }
+ }
+ else
+ {
+ // Choose the left first. Restore the left traversal info. Store the
+ // right traversal info.
+ traversalInfo = rule.TraversalInfo();
+ rule.TraversalInfo() = leftInfo;
+ Traverse(queryNode, *referenceNode.Left());
+
+ rightScore = rule.Rescore(queryNode, *referenceNode.Right(),
+ rightScore);
+
+ if (rightScore != DBL_MAX)
+ {
+ // Restore the right traversal info.
+ rule.TraversalInfo() = traversalInfo;
+ Traverse(queryNode, *referenceNode.Right());
+ }
+ else
+ ++numPrunes;
+ }
+ }
+ }
+ }
+ else
+ {
+ // We have to recurse down both query and reference nodes. Because the
+ // query descent order does not matter, we will go to the left query child
+ // first. Before recursing, we have to set the traversal information
+ // correctly.
+ double leftScore = rule.Score(*queryNode.Left(), *referenceNode.Left());
+ typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo();
+ rule.TraversalInfo() = traversalInfo;
+ double rightScore = rule.Score(*queryNode.Left(), *referenceNode.Right());
+ typename RuleType::TraversalInfoType rightInfo;
+ numScores += 2;
+
+ if (leftScore < rightScore)
+ {
+ // Recurse to the left. Restore the left traversal info. Store the right
+ // traversal info.
+ rightInfo = rule.TraversalInfo();
+ rule.TraversalInfo() = leftInfo;
+ Traverse(*queryNode.Left(), *referenceNode.Left());
+
+ // Is it still valid to recurse to the right?
+ rightScore = rule.Rescore(*queryNode.Left(), *referenceNode.Right(),
+ rightScore);
+
+ if (rightScore != DBL_MAX)
+ {
+ // Restore the right traversal info.
+ rule.TraversalInfo() = rightInfo;
+ Traverse(*queryNode.Left(), *referenceNode.Right());
+ }
+ else
+ ++numPrunes;
+ }
+ else if (rightScore < leftScore)
+ {
+ // Recurse to the right; the rule's traversal info is unchanged since scoring it.
+ Traverse(*queryNode.Left(), *referenceNode.Right());
+
+ // Is it still valid to recurse to the left?
+ leftScore = rule.Rescore(*queryNode.Left(), *referenceNode.Left(),
+ leftScore);
+
+ if (leftScore != DBL_MAX)
+ {
+ // Restore the left traversal info.
+ rule.TraversalInfo() = leftInfo;
+ Traverse(*queryNode.Left(), *referenceNode.Left());
+ }
+ else
+ ++numPrunes;
+ }
+ else // leftScore is equal to rightScore.
+ {
+ if (leftScore == DBL_MAX)
+ {
+ numPrunes += 2; // Both reference children are pruned.
+ }
+ else
+ {
+ if(referenceNode.Overlap())
+ {
+ // If referenceNode is an overlapping node and we can't decide which
+ // child node to traverse, this means that queryNode.Left() is at both
+ // sides of the splitting hyperplane. So, let's recurse down only the
+ // query node.
+ Traverse(*queryNode.Left(), referenceNode);
+ }
+ else
+ {
+ // Choose the left first. Restore the left traversal info and store
+ // the right traversal info.
+ rightInfo = rule.TraversalInfo();
+ rule.TraversalInfo() = leftInfo;
+ Traverse(*queryNode.Left(), *referenceNode.Left());
+
+ // Is it still valid to recurse to the right?
+ rightScore = rule.Rescore(*queryNode.Left(), *referenceNode.Right(),
+ rightScore);
+
+ if (rightScore != DBL_MAX)
+ {
+ // Restore the right traversal information.
+ rule.TraversalInfo() = rightInfo;
+ Traverse(*queryNode.Left(), *referenceNode.Right());
+ }
+ else
+ ++numPrunes;
+ }
+ }
+ }
+
+ // Restore the main traversal information.
+ rule.TraversalInfo() = traversalInfo;
+
+ // Now recurse down the right query node.
+ leftScore = rule.Score(*queryNode.Right(), *referenceNode.Left());
+ leftInfo = rule.TraversalInfo();
+ rule.TraversalInfo() = traversalInfo;
+ rightScore = rule.Score(*queryNode.Right(), *referenceNode.Right());
+ numScores += 2;
+
+ if (leftScore < rightScore)
+ {
+ // Recurse to the left. Restore the left traversal info. Store the right
+ // traversal info.
+ rightInfo = rule.TraversalInfo();
+ rule.TraversalInfo() = leftInfo;
+ Traverse(*queryNode.Right(), *referenceNode.Left());
+
+ // Is it still valid to recurse to the right?
+ rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(),
+ rightScore);
+
+ if (rightScore != DBL_MAX)
+ {
+ // Restore the right traversal info.
+ rule.TraversalInfo() = rightInfo;
+ Traverse(*queryNode.Right(), *referenceNode.Right());
+ }
+ else
+ ++numPrunes;
+ }
+ else if (rightScore < leftScore)
+ {
+ // Recurse to the right; the rule's traversal info is unchanged since scoring it.
+ Traverse(*queryNode.Right(), *referenceNode.Right());
+
+ // Is it still valid to recurse to the left?
+ leftScore = rule.Rescore(*queryNode.Right(), *referenceNode.Left(),
+ leftScore);
+
+ if (leftScore != DBL_MAX)
+ {
+ // Restore the left traversal info.
+ rule.TraversalInfo() = leftInfo;
+ Traverse(*queryNode.Right(), *referenceNode.Left());
+ }
+ else
+ ++numPrunes;
+ }
+ else // leftScore is equal to rightScore.
+ {
+ if (leftScore == DBL_MAX)
+ {
+ numPrunes += 2; // Both reference children are pruned.
+ }
+ else
+ {
+ if(referenceNode.Overlap())
+ {
+ // If referenceNode is an overlapping node and we can't decide which
+ // child node to traverse, this means that queryNode.Right() is at
+ // both sides of the splitting hyperplane. So, let's recurse down only
+ // the query node.
+ Traverse(*queryNode.Right(), referenceNode);
+ }
+ else
+ {
+ // Choose the left first. Restore the left traversal info. Store the
+ // right traversal info.
+ rightInfo = rule.TraversalInfo();
+ rule.TraversalInfo() = leftInfo;
+ Traverse(*queryNode.Right(), *referenceNode.Left());
+
+ // Is it still valid to recurse to the right?
+ rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(),
+ rightScore);
+
+ if (rightScore != DBL_MAX)
+ {
+ // Restore the right traversal info.
+ rule.TraversalInfo() = rightInfo;
+ Traverse(*queryNode.Right(), *referenceNode.Right());
+ }
+ else
+ ++numPrunes;
+ }
+ }
+ }
+ }
}
} // namespace tree
More information about the mlpack-git
mailing list