[mlpack-svn] r16738 - mlpack/trunk/src/mlpack/core/tree/rectangle_tree
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Jul 2 13:50:26 EDT 2014
Author: andrewmw94
Date: Wed Jul 2 13:50:26 2014
New Revision: 16738
Log:
fix build
Added:
mlpack/trunk/src/mlpack/core/tree/rectangle_tree/single_tree_traverser.hpp
mlpack/trunk/src/mlpack/core/tree/rectangle_tree/single_tree_traverser_impl.hpp
Removed:
mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree_traverser.hpp
mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree_traverser_impl.hpp
Added: mlpack/trunk/src/mlpack/core/tree/rectangle_tree/single_tree_traverser.hpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/core/tree/rectangle_tree/single_tree_traverser.hpp Wed Jul 2 13:50:26 2014
@@ -0,0 +1,74 @@
+/**
+ * @file single_tree_traverser.hpp
+ * @author Andrew Wells
+ *
+ * A nested class of Rectangle Tree for traversing rectangle type trees
+ * with a given set of rules which indicate the branches to prune and the
+ * order in which to recurse. This is a depth-first traverser.
+ */
+#ifndef __MLPACK_CORE_TREE_RECTANGLE_TREE_SINGLE_TREE_TRAVERSER_HPP
+#define __MLPACK_CORE_TREE_RECTANGLE_TREE_SINGLE_TREE_TRAVERSER_HPP
+
+#include <mlpack/core.hpp>
+
+#include "rectangle_tree.hpp"
+
+namespace mlpack {
+namespace tree {
+
+template<typename SplitType,
+         typename DescentType,
+         typename StatisticType,
+         typename MatType>
+template<typename RuleType>
+class RectangleTree<SplitType, DescentType, StatisticType, MatType>::
+    SingleTreeTraverser
+{
+ public:
+  /**
+   * Instantiate the traverser with the given rule set.
+   *
+   * @param rule The rules used to score, prune, and evaluate base cases during
+   *     traversal.  Only a reference is stored, so the rule object must outlive
+   *     this traverser.
+   */
+  SingleTreeTraverser(RuleType& rule);
+
+  /**
+   * Traverse the tree with the given point.
+   *
+   * @param queryIndex The index of the point in the query set which is being
+   *     used as the query point.
+   * @param referenceNode The tree node to be traversed.
+   */
+  void Traverse(const size_t queryIndex, const RectangleTree& referenceNode);
+
+  //! Get the number of prunes.
+  size_t NumPrunes() const { return numPrunes; }
+  //! Modify the number of prunes.
+  size_t& NumPrunes() { return numPrunes; }
+
+  /**
+   * A child node paired with the score the rule assigned to it.  Grouping the
+   * two lets the children of a node be sorted by score and then visited in
+   * that order.
+   */
+  class NodeAndScore {
+   public:
+    //! The (child) node that was scored.
+    RectangleTree<SplitType, DescentType, StatisticType, MatType>* node;
+    //! The score the rule assigned to the node.
+    double score;
+  };
+
+  /**
+   * Strict-weak-ordering predicate that orders NodeAndScore objects by
+   * ascending score; suitable for std::sort().
+   */
+  static bool nodeComparator(const NodeAndScore& obj1,
+                             const NodeAndScore& obj2)
+  {
+    return obj1.score < obj2.score;
+  }
+
+ private:
+  //! Reference to the rules with which the tree will be traversed.
+  RuleType& rule;
+
+  //! The number of nodes which have been pruned during traversal.
+  size_t numPrunes;
+};
+
+}; // namespace tree
+}; // namespace mlpack
+
+// Include implementation.
+#include "single_tree_traverser_impl.hpp"
+
+#endif
Added: mlpack/trunk/src/mlpack/core/tree/rectangle_tree/single_tree_traverser_impl.hpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/core/tree/rectangle_tree/single_tree_traverser_impl.hpp Wed Jul 2 13:50:26 2014
@@ -0,0 +1,76 @@
+/**
+ * @file single_tree_traverser_impl.hpp
+ * @author Andrew Wells
+ *
+ * A class for traversing rectangle type trees with a given set of rules
+ * which indicate the branches to prune and the order in which to recurse.
+ * This is a depth-first traverser.
+ */
+#ifndef __MLPAC_CORE_TREE_RECTANGLE_TREE_SINGLE_TREE_TRAVERSER_IMPL_HPP
+#define __MLPAC_CORE_TREE_RECTANGLE_TREE_SINGLE_TREE_TRAVERSER_IMPL_HPP
+
+#include "single_tree_traverser.hpp"
+
+#include <algorithm>
+#include <stack>
+
+namespace mlpack {
+namespace tree {
+
+// Build a traverser bound to a caller-owned rule set; the prune counter starts
+// at zero.
+template<typename SplitType,
+         typename DescentType,
+         typename StatisticType,
+         typename MatType>
+template<typename RuleType>
+RectangleTree<SplitType, DescentType, StatisticType, MatType>::
+SingleTreeTraverser<RuleType>::SingleTreeTraverser(RuleType& rule) :
+    rule(rule),
+    numPrunes(0)
+{
+  // All members are set in the initializer list; the body is intentionally
+  // empty.
+}
+
+/**
+ * Depth-first traversal of referenceNode for a single query point.
+ *
+ * At a leaf, BaseCase() is run for every point in the leaf.  At an internal
+ * node, each child is scored with Score(), the children are sorted by
+ * ascending score, and they are visited in that order.  Before each visit the
+ * score is refreshed with Rescore().  If the refreshed score is DBL_MAX, that
+ * child and every remaining (higher-scored) child are counted as pruned, and
+ * the traversal of this node stops.
+ *
+ * @param queryIndex Index of the query point; it is forwarded to the rule's
+ *     BaseCase(), Score(), and Rescore().
+ * @param referenceNode The tree node to be traversed.
+ */
+template<typename SplitType,
+         typename DescentType,
+         typename StatisticType,
+         typename MatType>
+template<typename RuleType>
+void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
+SingleTreeTraverser<RuleType>::Traverse(
+    const size_t queryIndex,
+    const RectangleTree<SplitType, DescentType, StatisticType, MatType>&
+        referenceNode)
+{
+  // If we reach a leaf node, we need to run the base case for each point it
+  // holds.
+  if (referenceNode.IsLeaf())
+  {
+    for (size_t i = 0; i < referenceNode.Count(); ++i)
+      rule.BaseCase(queryIndex, referenceNode.Points()[i]);
+    return;
+  }
+
+  // This is not a leaf node, so score every child and sort the children so
+  // that the lowest (most promising) score is visited first.
+  const size_t numChildren = referenceNode.NumChildren();
+  std::vector<NodeAndScore> nodesAndScores(numChildren);
+  for (size_t i = 0; i < numChildren; ++i)
+  {
+    nodesAndScores[i].node = referenceNode.Children()[i];
+    nodesAndScores[i].score = rule.Score(queryIndex, *nodesAndScores[i].node);
+  }
+
+  std::sort(nodesAndScores.begin(), nodesAndScores.end(), nodeComparator);
+
+  // Now iterate through them starting with the best and stopping when we reach
+  // one that isn't good enough.  Because the children are sorted, the first
+  // pruned child causes it and all children after it to be counted as pruned.
+  for (size_t i = 0; i < numChildren; ++i)
+  {
+    const double newScore = rule.Rescore(queryIndex, *nodesAndScores[i].node,
+        nodesAndScores[i].score);
+    if (newScore == DBL_MAX)
+    {
+      numPrunes += numChildren - i;
+      return;
+    }
+
+    // Traverse() takes a node reference, so dereference the stored pointer.
+    Traverse(queryIndex, *nodesAndScores[i].node);
+  }
+}
+
+}; // namespace tree
+}; // namespace mlpack
+
+#endif
\ No newline at end of file
More information about the mlpack-svn
mailing list