[mlpack-git] master: Make traversal depth-first in the queries. This results in less memory usage, and some amount of speedup. (43ab267)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:04:44 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44
>---------------------------------------------------------------
commit 43ab2675140761ffd7450e6e2056d4d6cd0e87bc
Author: Ryan Curtin <ryan at ratml.org>
Date: Fri Feb 27 17:09:30 2015 -0500
Make traversal depth-first in the queries. This results in less memory usage, and some amount of speedup.
>---------------------------------------------------------------
43ab2675140761ffd7450e6e2056d4d6cd0e87bc
.../breadth_first_dual_tree_traverser.hpp | 16 +++++
.../breadth_first_dual_tree_traverser_impl.hpp | 69 ++++++++++++----------
2 files changed, 55 insertions(+), 30 deletions(-)
diff --git a/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser.hpp b/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser.hpp
index 8a22a70..4af0372 100644
--- a/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser.hpp
@@ -11,12 +11,23 @@
#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_BREADTH_FIRST_DUAL_TREE_TRAVERSER_HPP
#include <mlpack/core.hpp>
+#include <queue>
#include "binary_space_tree.hpp"
namespace mlpack {
namespace tree {
+template<typename TreeType, typename TraversalInfoType>
+struct QueueFrame
+{
+ TreeType* queryNode;
+ TreeType* referenceNode;
+ size_t queryDepth;
+ double score;
+ TraversalInfoType traversalInfo;
+};
+
template<typename BoundType,
typename StatisticType,
typename MatType,
@@ -31,6 +42,9 @@ class BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
*/
BreadthFirstDualTreeTraverser(RuleType& rule);
+ typedef QueueFrame<BinarySpaceTree, typename RuleType::TraversalInfoType>
+ QueueFrameType;
+
/**
* Traverse the two trees. This does not reset the number of prunes.
*
@@ -40,6 +54,8 @@ class BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
*/
void Traverse(BinarySpaceTree& queryNode,
BinarySpaceTree& referenceNode);
+ void Traverse(BinarySpaceTree& queryNode,
+ std::priority_queue<QueueFrameType>& referenceQueue);
//! Get the number of prunes.
size_t NumPrunes() const { return numPrunes; }
diff --git a/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp
index 5a32bfc..f48bf95 100644
--- a/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp
@@ -12,8 +12,6 @@
// In case it hasn't been included yet.
#include "breadth_first_dual_tree_traverser.hpp"
-#include <queue>
-
namespace mlpack {
namespace tree {
@@ -33,16 +31,6 @@ BreadthFirstDualTreeTraverser<RuleType>::BreadthFirstDualTreeTraverser(
{ /* Nothing to do. */ }
template<typename TreeType, typename TraversalInfoType>
-struct QueueFrame
-{
- TreeType* queryNode;
- TreeType* referenceNode;
- size_t queryDepth;
- double score;
- TraversalInfoType traversalInfo;
-};
-
-template<typename TreeType, typename TraversalInfoType>
bool operator<(const QueueFrame<TreeType, TraversalInfoType>& a,
const QueueFrame<TreeType, TraversalInfoType>& b)
{
@@ -70,16 +58,11 @@ BreadthFirstDualTreeTraverser<RuleType>::Traverse(
// Store the current traversal info.
traversalInfo = rule.TraversalInfo();
- typedef BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>
- TreeType;
-
// Must score the root combination.
const double rootScore = rule.Score(queryRoot, referenceRoot);
if (rootScore == DBL_MAX)
return; // This probably means something is wrong.
- typedef QueueFrame<TreeType, typename RuleType::TraversalInfoType>
- QueueFrameType;
std::priority_queue<QueueFrameType> queue;
QueueFrameType rootFrame;
@@ -91,13 +74,32 @@ BreadthFirstDualTreeTraverser<RuleType>::Traverse(
queue.push(rootFrame);
- while (!queue.empty())
+ // Start the traversal.
+ Traverse(queryRoot, queue);
+}
+
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
+template<typename RuleType>
+void BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+BreadthFirstDualTreeTraverser<RuleType>::Traverse(
+ BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>& queryNode,
+ std::priority_queue<QueueFrameType>& referenceQueue)
+{
+ // Store queues for the children. We will recurse into the children once our
+ // queue is empty.
+ std::priority_queue<QueueFrameType> leftChildQueue;
+ std::priority_queue<QueueFrameType> rightChildQueue;
+
+ while (!referenceQueue.empty())
{
- QueueFrameType currentFrame = queue.top();
- queue.pop();
+ QueueFrameType currentFrame = referenceQueue.top();
+ referenceQueue.pop();
- TreeType& queryNode = *currentFrame.queryNode;
- TreeType& referenceNode = *currentFrame.referenceNode;
+ BinarySpaceTree& queryNode = *currentFrame.queryNode;
+ BinarySpaceTree& referenceNode = *currentFrame.referenceNode;
typename RuleType::TraversalInfoType ti = currentFrame.traversalInfo;
rule.TraversalInfo() = ti;
const size_t queryDepth = currentFrame.queryDepth;
@@ -137,11 +139,11 @@ BreadthFirstDualTreeTraverser<RuleType>::Traverse(
// We have to recurse down the query node.
QueueFrameType fl = { queryNode.Left(), &referenceNode, queryDepth + 1,
score, rule.TraversalInfo() };
- queue.push(fl);
+ leftChildQueue.push(fl);
QueueFrameType fr = { queryNode.Right(), &referenceNode, queryDepth + 1,
score, ti };
- queue.push(fr);
+ rightChildQueue.push(fr);
}
else if (queryNode.IsLeaf() && (!referenceNode.IsLeaf()))
{
@@ -150,11 +152,11 @@ BreadthFirstDualTreeTraverser<RuleType>::Traverse(
// traversal information correctly.
QueueFrameType fl = { &queryNode, referenceNode.Left(), queryDepth,
score, rule.TraversalInfo() };
- queue.push(fl);
+ referenceQueue.push(fl);
QueueFrameType fr = { &queryNode, referenceNode.Right(), queryDepth,
score, ti };
- queue.push(fr);
+ referenceQueue.push(fr);
}
else
{
@@ -164,21 +166,28 @@ BreadthFirstDualTreeTraverser<RuleType>::Traverse(
// correctly.
QueueFrameType fll = { queryNode.Left(), referenceNode.Left(),
queryDepth + 1, score, rule.TraversalInfo() };
- queue.push(fll);
+ leftChildQueue.push(fll);
QueueFrameType flr = { queryNode.Left(), referenceNode.Right(),
queryDepth + 1, score, rule.TraversalInfo() };
- queue.push(flr);
+ leftChildQueue.push(flr);
QueueFrameType frl = { queryNode.Right(), referenceNode.Left(),
queryDepth + 1, score, rule.TraversalInfo() };
- queue.push(frl);
+ rightChildQueue.push(frl);
QueueFrameType frr = { queryNode.Right(), referenceNode.Right(),
queryDepth + 1, score, rule.TraversalInfo() };
- queue.push(frr);
+ rightChildQueue.push(frr);
}
}
+
+ // Now, recurse into the left and right children queues. The order doesn't
+ // matter.
+ if (leftChildQueue.size() > 0)
+ Traverse(*queryNode.Left(), leftChildQueue);
+ if (rightChildQueue.size() > 0)
+ Traverse(*queryNode.Right(), rightChildQueue);
}
}; // namespace tree
More information about the mlpack-git mailing list