[mlpack-git] master: Add a semi-hackish breadth-first traverser. The tree abstractions will need to change to support arbitrary traverser types (probably by adding a template parameter) but for now this works to make DualTreeKMeans work. (019ecbc)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 22:02:34 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40
>---------------------------------------------------------------
commit 019ecbc045b09c000504eb633585109053c63bae
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Nov 5 21:40:37 2014 +0000
Add a semi-hackish breadth-first traverser. The tree abstractions will need to
change to support arbitrary traverser types (probably by adding a template
parameter) but for now this works to make DualTreeKMeans work.
>---------------------------------------------------------------
019ecbc045b09c000504eb633585109053c63bae
.../tree/binary_space_tree/binary_space_tree.hpp | 3 +
...r.hpp => breadth_first_dual_tree_traverser.hpp} | 22 +-
.../breadth_first_dual_tree_traverser_impl.hpp | 442 +++++++++++++++++++++
3 files changed, 456 insertions(+), 11 deletions(-)
diff --git a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
index eadd300..db1ece7 100644
--- a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
@@ -87,6 +87,9 @@ class BinarySpaceTree
template<typename RuleType>
class DualTreeTraverser;
+ template<typename RuleType>
+ class BreadthFirstDualTreeTraverser;
+
/**
* Construct this as the root node of a binary space tree using the given
* dataset. This will modify the ordering of the points in the dataset!
diff --git a/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp b/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser.hpp
similarity index 75%
copy from src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp
copy to src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser.hpp
index 7cd1871..8a22a70 100644
--- a/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser.hpp
@@ -1,14 +1,14 @@
/**
- * @file dual_tree_traverser.hpp
+ * @file breadth_first_dual_tree_traverser.hpp
* @author Ryan Curtin
*
- * Defines the DualTreeTraverser for the BinarySpaceTree tree type. This is a
- * nested class of BinarySpaceTree which traverses two trees in a depth-first
- * manner with a given set of rules which indicate the branches which can be
- * pruned and the order in which to recurse.
+ * Defines the BreadthFirstDualTreeTraverser for the BinarySpaceTree tree type.
+ * This is a nested class of BinarySpaceTree which traverses two trees in a
+ * breadth-first manner with a given set of rules which indicate the branches
+ * which can be pruned and the order in which to recurse.
*/
-#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_HPP
-#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_HPP
+#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_BREADTH_FIRST_DUAL_TREE_TRAVERSER_HPP
+#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_BREADTH_FIRST_DUAL_TREE_TRAVERSER_HPP
#include <mlpack/core.hpp>
@@ -23,13 +23,13 @@ template<typename BoundType,
typename SplitType>
template<typename RuleType>
class BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
- DualTreeTraverser
+ BreadthFirstDualTreeTraverser
{
public:
/**
* Instantiate the dual-tree traverser with the given rule set.
*/
- DualTreeTraverser(RuleType& rule);
+ BreadthFirstDualTreeTraverser(RuleType& rule);
/**
* Traverse the two trees. This does not reset the number of prunes.
@@ -86,7 +86,7 @@ class BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
}; // namespace mlpack
// Include implementation.
-#include "dual_tree_traverser_impl.hpp"
+#include "breadth_first_dual_tree_traverser_impl.hpp"
-#endif // __MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_HPP
+#endif // __MLPACK_CORE_TREE_BINARY_SPACE_TREE_BREADTH_FIRST_DUAL_TREE_TRAVERSER_HPP
diff --git a/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp
new file mode 100644
index 0000000..bd81df2
--- /dev/null
+++ b/src/mlpack/core/tree/binary_space_tree/breadth_first_dual_tree_traverser_impl.hpp
@@ -0,0 +1,442 @@
+/**
+ * @file breadth_first_dual_tree_traverser_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of the BreadthFirstDualTreeTraverser for BinarySpaceTree.
+ * This is a way to perform a dual-tree traversal of two trees. The trees must
+ * be the same type.
+ */
+#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_BREADTH_FIRST_DUAL_TREE_TRAVERSER_IMPL_HPP
+#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_BREADTH_FIRST_DUAL_TREE_TRAVERSER_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "breadth_first_dual_tree_traverser.hpp"
+
+#include <queue>
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * Build a breadth-first dual-tree traverser around the given rule set.  The
+ * rule member is initialized from the argument and every statistic counter
+ * (prunes, visited node combinations, scores, base cases) starts at zero.
+ *
+ * @param rule Rule set the traverser is initialized with.
+ */
+template<typename BoundType, typename StatisticType, typename MatType,
+         typename SplitType>
+template<typename RuleType>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+BreadthFirstDualTreeTraverser<RuleType>::
+BreadthFirstDualTreeTraverser(RuleType& rule) :
+    rule(rule), numPrunes(0), numVisited(0), numScores(0), numBaseCases(0)
+{
+  // The initializer list does all the work; the body is intentionally empty.
+}
+
+/**
+ * Traverse the query and reference trees breadth-first, starting from the
+ * given pair of root nodes.
+ *
+ * Pending (query node, reference node) combinations are held in FIFO queues
+ * together with a copy of the rule set's traversal information taken right
+ * after the combination was scored.  Each combination taken off the queues is
+ * handled according to which of its nodes are leaves:
+ *
+ *  - Both are leaves: rule.BaseCase() is called for every (query point,
+ *    reference point) pair.
+ *  - Only the reference node is a leaf: each query child is scored against
+ *    the reference node and enqueued unless its score is DBL_MAX.
+ *  - The reference node is not a leaf: for the query node itself (if it is a
+ *    leaf) or for each of its children, both reference children are scored.
+ *    The child with the strictly lower score is enqueued first (the left one
+ *    on a tie); the other child is enqueued only if rule.Rescore() does not
+ *    return DBL_MAX.  If both scores are DBL_MAX, both children are pruned.
+ *
+ * A score of DBL_MAX is treated as a prune.  The counters numVisited,
+ * numScores, numPrunes and numBaseCases are accumulated here and never reset.
+ *
+ * @param queryRoot Root of the query tree.
+ * @param referenceRoot Root of the reference tree.
+ */
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+template<typename RuleType>
+void BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+BreadthFirstDualTreeTraverser<RuleType>::Traverse(
+    BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>& queryRoot,
+    BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>&
+        referenceRoot)
+{
+  typedef BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>
+      TreeType;
+  typedef typename RuleType::TraversalInfoType TraversalInfoType;
+
+  // Store the current traversal info.
+  traversalInfo = rule.TraversalInfo();
+
+  // These three queues are always pushed and popped together, so the fronts
+  // of all three describe the same pending node combination.
+  std::queue<TreeType*> queryList;
+  std::queue<TreeType*> referenceList;
+  std::queue<TraversalInfoType> traversalInfos;
+  queryList.push(&queryRoot);
+  referenceList.push(&referenceRoot);
+  traversalInfos.push(rule.TraversalInfo());
+
+  while (!queryList.empty())
+  {
+    TreeType& queryNode = *queryList.front();
+    TreeType& referenceNode = *referenceList.front();
+    const TraversalInfoType ti = traversalInfos.front();
+
+    queryList.pop();
+    referenceList.pop();
+    traversalInfos.pop();
+
+    // Count every node combination that is actually visited.  (Incrementing
+    // only once per Traverse() call would always report a single visit,
+    // because this traverser iterates instead of recursing.)
+    ++numVisited;
+
+    // Give the rule set the traversal information that belongs to this
+    // combination before asking it anything.
+    rule.TraversalInfo() = ti;
+
+    if (referenceNode.IsLeaf())
+    {
+      if (queryNode.IsLeaf())
+      {
+        // Both nodes are leaves, so we must evaluate the base case for each
+        // pair of points.
+        for (size_t query = queryNode.Begin(); query < queryNode.End();
+             ++query)
+        {
+          for (size_t ref = referenceNode.Begin(); ref < referenceNode.End();
+               ++ref)
+            rule.BaseCase(query, ref);
+
+          numBaseCases += referenceNode.Count();
+        }
+      }
+      else
+      {
+        // Only the query node can be descended.  The order in which its
+        // children are scored does not matter, so go left and then right.
+        TreeType* queryChildren[2] = { queryNode.Left(), queryNode.Right() };
+
+        for (size_t i = 0; i < 2; ++i)
+        {
+          // Each child must be scored starting from the parent combination's
+          // traversal information.
+          rule.TraversalInfo() = ti;
+          const double score = rule.Score(*queryChildren[i], referenceNode);
+          ++numScores;
+
+          if (score != DBL_MAX)
+          {
+            queryList.push(queryChildren[i]);
+            referenceList.push(&referenceNode);
+            traversalInfos.push(rule.TraversalInfo());
+          }
+          else
+          {
+            ++numPrunes;
+          }
+        }
+      }
+    }
+    else
+    {
+      // The reference node has children.  They are paired with the query
+      // node itself when it is a leaf, and otherwise with its left child and
+      // then its right child (the query descent order does not matter).
+      TreeType* queryTargets[2];
+      size_t numQueryTargets;
+      if (queryNode.IsLeaf())
+      {
+        queryTargets[0] = &queryNode;
+        numQueryTargets = 1;
+      }
+      else
+      {
+        queryTargets[0] = queryNode.Left();
+        queryTargets[1] = queryNode.Right();
+        numQueryTargets = 2;
+      }
+
+      TreeType* const referenceLeft = referenceNode.Left();
+      TreeType* const referenceRight = referenceNode.Right();
+
+      for (size_t i = 0; i < numQueryTargets; ++i)
+      {
+        TreeType* const queryTarget = queryTargets[i];
+
+        // Score both reference children, each time starting from the parent
+        // combination's traversal information.  The left child's traversal
+        // information has to be copied out, because scoring the right child
+        // overwrites it inside the rule set.
+        rule.TraversalInfo() = ti;
+        double leftScore = rule.Score(*queryTarget, *referenceLeft);
+        const TraversalInfoType leftInfo = rule.TraversalInfo();
+        rule.TraversalInfo() = ti;
+        double rightScore = rule.Score(*queryTarget, *referenceRight);
+        numScores += 2;
+
+        if (rightScore < leftScore)
+        {
+          // The right child is more promising: enqueue it first.  The rule
+          // set still holds the right child's traversal information.
+          queryList.push(queryTarget);
+          referenceList.push(referenceRight);
+          traversalInfos.push(rule.TraversalInfo());
+
+          // Is it still valid to visit the left child?
+          leftScore = rule.Rescore(*queryTarget, *referenceLeft, leftScore);
+
+          if (leftScore != DBL_MAX)
+          {
+            queryList.push(queryTarget);
+            referenceList.push(referenceLeft);
+            traversalInfos.push(leftInfo);
+          }
+          else
+          {
+            ++numPrunes;
+          }
+        }
+        else if ((leftScore < rightScore) || (leftScore != DBL_MAX))
+        {
+          // The left child is more promising, or the scores are tied without
+          // being prunes: enqueue the left child first.
+          queryList.push(queryTarget);
+          referenceList.push(referenceLeft);
+          traversalInfos.push(leftInfo);
+
+          // Is it still valid to visit the right child?
+          rightScore = rule.Rescore(*queryTarget, *referenceRight,
+              rightScore);
+
+          if (rightScore != DBL_MAX)
+          {
+            queryList.push(queryTarget);
+            referenceList.push(referenceRight);
+            traversalInfos.push(rule.TraversalInfo());
+          }
+          else
+          {
+            ++numPrunes;
+          }
+        }
+        else
+        {
+          // The scores are tied at DBL_MAX, so both children are pruned.
+          numPrunes += 2;
+        }
+      }
+    }
+  }
+}
+
+}; // namespace tree
+}; // namespace mlpack
+
+#endif // __MLPACK_CORE_TREE_BINARY_SPACE_TREE_BREADTH_FIRST_DUAL_TREE_TRAVERSER_IMPL_HPP
More information about the mlpack-git
mailing list