[mlpack-svn] r16644 - mlpack/trunk/src/mlpack/core/tree/rectangle_tree
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Jun 5 16:52:29 EDT 2014
Author: andrewmw94
Date: Thu Jun 5 16:52:29 2014
New Revision: 16644
Log:
Rectangle Tree Traversal implementation.
Modified:
mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree_traverser.hpp
mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree_traverser_impl.hpp
Modified: mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree_traverser.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree_traverser.hpp (original)
+++ mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree_traverser.hpp Thu Jun 5 16:52:29 2014
@@ -18,9 +18,10 @@
template<typename StatisticType,
typename MatType,
- typename SplitType>
+ typename SplitType
+ typename DescentType>
template<typename RuleType>
-class RectangleTree<StatisticType, MatType, SplitType>::
+class RectangleTree<StatisticType, MatType, SplitType, DescentType>::
RectangleTreeTraverser
{
public:
Modified: mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree_traverser_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree_traverser_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree_traverser_impl.hpp Thu Jun 5 16:52:29 2014
@@ -1,5 +1,5 @@
/**
- * @file rectangle_tree_traverser.hpp
+ * @file rectangle_tree_traverser_impl.hpp
* @author Andrew Wells
*
* A class for traversing rectangle type trees with a given set of rules
@@ -11,6 +11,7 @@
#include "rectangle_tree_traverser.hpp"
+#include <algorithm>
#include <stack>
namespace mlpack {
@@ -38,7 +39,36 @@
RectangeTree<StatisticType, MatyType, SplitType, DescentType>&
referenceNode)
{
-
+ // If we reach a leaf node, we need to run the base case.
+ if(referenceNode.IsLeaf()) {
+ for(size_t i = 0; i < referenceNode.Count(); i++) {
+ rule.BaseCase(queryIndex, i);
+ }
+ return;
+ }
+
+ // This is not a leaf node so we:
+ // Sort the children of this node by their scores.
+ std::vector<RectangleTree*> nodes = new std::vector<RectangleTree*>(referenceNode.NumChildren());
+ std::vector<double> scores = new std::vector<double>(referenceNode.NumChildren());
+ for(int i = 0; i < referenceNode.NumChildren(); i++) {
+ nodes[i] = referenceNode.Children()[i];
+ scores[i] = Rule.Score(nodes[i]);
+ }
+ Rule.sortNodesAndScores(&nodes, &scores);
+
+ // Iterate through them starting with the best and stopping when we reach
+ // one that isn't good enough.
+ for(int i = 0; i < referenceNode.NumChildren(); i++) {
+ if(Rule.Rescore(queryIndex, nodes[i], scores[i]) != DBL_MAX)
+ Traverse(queryIndex, nodes[i]);
+ else {
+ numPrunes += referenceNode.NumChildren - i;
+ return;
+ }
+ }
+ // We only get here if we couldn't prune any of them.
+ return;
}
}; // namespace tree
More information about the mlpack-svn mailing list