[mlpack-svn] r16738 - mlpack/trunk/src/mlpack/core/tree/rectangle_tree

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Jul 2 13:50:26 EDT 2014


Author: andrewmw94
Date: Wed Jul  2 13:50:26 2014
New Revision: 16738

Log:
fix build

Added:
   mlpack/trunk/src/mlpack/core/tree/rectangle_tree/single_tree_traverser.hpp
   mlpack/trunk/src/mlpack/core/tree/rectangle_tree/single_tree_traverser_impl.hpp
Removed:
   mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree_traverser.hpp
   mlpack/trunk/src/mlpack/core/tree/rectangle_tree/rectangle_tree_traverser_impl.hpp

Added: mlpack/trunk/src/mlpack/core/tree/rectangle_tree/single_tree_traverser.hpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/core/tree/rectangle_tree/single_tree_traverser.hpp	Wed Jul  2 13:50:26 2014
@@ -0,0 +1,74 @@
+/**
+  * @file single_tree_traverser.hpp
+  * @author Andrew Wells
+  *
+  * A nested class of RectangleTree for traversing rectangle-type trees
+  * with a given set of rules, which indicate the branches to prune and the
+  * order in which to recurse.  This is a depth-first traverser.
+  */
+#ifndef __MLPACK_CORE_TREE_RECTANGLE_TREE_SINGLE_TREE_TRAVERSER_HPP
+#define __MLPACK_CORE_TREE_RECTANGLE_TREE_SINGLE_TREE_TRAVERSER_HPP
+
+#include <mlpack/core.hpp>
+
+#include "rectangle_tree.hpp"
+
+namespace mlpack {
+namespace tree {
+
+template<typename SplitType,
+         typename DescentType,
+         typename StatisticType,
+         typename MatType>
+template<typename RuleType>
+class RectangleTree<SplitType, DescentType, StatisticType, MatType>::
+    SingleTreeTraverser
+{
+ public:
+  /**
+   * Instantiate the traverser with the given rule set.
+   *
+   * @param rule The rules used to score nodes and run base cases.  Only a
+   *     reference is stored, so the rules object must outlive the traverser.
+   */
+  SingleTreeTraverser(RuleType& rule);
+
+  /**
+   * Traverse the tree with the given point.
+   *
+   * @param queryIndex The index of the point in the query set which is being
+   *     used as the query point.
+   * @param referenceNode The tree node to be traversed.
+   */
+  void Traverse(const size_t queryIndex, const RectangleTree& referenceNode);
+
+  //! Get the number of prunes.
+  size_t NumPrunes() const { return numPrunes; }
+  //! Modify the number of prunes.
+  size_t& NumPrunes() { return numPrunes; }
+
+  /**
+   * A child node paired with the score the rules assigned to it.  Traverse()
+   * sorts a vector of these so that children are visited best (lowest) score
+   * first.
+   */
+  class NodeAndScore {
+   public:
+    //! The child node being scored.
+    RectangleTree<SplitType, DescentType, StatisticType, MatType>* node;
+    //! The score returned by the rules for this node.
+    double score;
+  };
+
+  /**
+   * Comparator for std::sort: orders NodeAndScore objects by ascending score.
+   *
+   * @param obj1 The first node/score pair.
+   * @param obj2 The second node/score pair.
+   * @return true if obj1 has a strictly smaller score than obj2.
+   */
+  static bool nodeComparator(const NodeAndScore& obj1,
+                             const NodeAndScore& obj2)
+  {
+    return obj1.score < obj2.score;
+  }
+
+ private:
+  //! Reference to the rules with which the tree will be traversed.
+  RuleType& rule;
+
+  //! The number of nodes which have been pruned during traversal.
+  size_t numPrunes;
+};
+
+}; // namespace tree
+}; // namespace mlpack
+
+// Include implementation.
+#include "single_tree_traverser_impl.hpp"
+
+#endif

Added: mlpack/trunk/src/mlpack/core/tree/rectangle_tree/single_tree_traverser_impl.hpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/core/tree/rectangle_tree/single_tree_traverser_impl.hpp	Wed Jul  2 13:50:26 2014
@@ -0,0 +1,76 @@
+/**
+  * @file single_tree_traverser_impl.hpp
+  * @author Andrew Wells
+  *
+  * A class for traversing rectangle-type trees with a given set of rules,
+  * which indicate the branches to prune and the order in which to recurse.
+  * This is a depth-first traverser.
+  */
+#ifndef __MLPACK_CORE_TREE_RECTANGLE_TREE_SINGLE_TREE_TRAVERSER_IMPL_HPP
+#define __MLPACK_CORE_TREE_RECTANGLE_TREE_SINGLE_TREE_TRAVERSER_IMPL_HPP
+
+#include "single_tree_traverser.hpp"
+
+#include <algorithm>
+#include <stack>
+
+namespace mlpack {
+namespace tree {
+  
+/**
+ * Construct the traverser.  The rules are held by reference (not copied) and
+ * the prune counter starts at zero.
+ */
+template<typename SplitType,
+         typename DescentType,
+         typename StatisticType,
+         typename MatType>
+template<typename RuleType>
+RectangleTree<SplitType, DescentType, StatisticType, MatType>::
+SingleTreeTraverser<RuleType>::SingleTreeTraverser(RuleType& rule)
+    : rule(rule), numPrunes(0)
+{
+  // All state is set up in the initializer list above.
+}
+
+/**
+ * Depth-first traversal of the subtree rooted at referenceNode for a single
+ * query point.
+ *
+ * At a leaf, rule.BaseCase() is called for every point the leaf holds.  At an
+ * internal node, every child is scored with rule.Score(), the children are
+ * sorted by ascending score, and they are visited in that order.  Before each
+ * visit the child is rescored with rule.Rescore(); as soon as a child's
+ * rescore is DBL_MAX, that child and all remaining children are counted as
+ * pruned and the traversal of this node stops.
+ *
+ * @param queryIndex Index of the query point in the query set.
+ * @param referenceNode Root of the (sub)tree to traverse.
+ */
+template<typename SplitType,
+         typename DescentType,
+         typename StatisticType,
+         typename MatType>
+template<typename RuleType>
+void RectangleTree<SplitType, DescentType, StatisticType, MatType>::
+SingleTreeTraverser<RuleType>::Traverse(
+    const size_t queryIndex,
+    const RectangleTree<SplitType, DescentType, StatisticType, MatType>&
+        referenceNode)
+{
+  // If we reach a leaf node, we need to run the base case.
+  if (referenceNode.IsLeaf())
+  {
+    for (size_t i = 0; i < referenceNode.Count(); i++)
+      rule.BaseCase(queryIndex, referenceNode.Points()[i]);
+    return;
+  }
+
+  // This is not a leaf node, so score every child and sort the children by
+  // their scores (lowest, i.e. best, first).
+  const size_t numChildren = referenceNode.NumChildren();
+  std::vector<NodeAndScore> nodesAndScores(numChildren);
+  for (size_t i = 0; i < numChildren; i++)
+  {
+    nodesAndScores[i].node = referenceNode.Children()[i];
+    nodesAndScores[i].score = rule.Score(queryIndex, *nodesAndScores[i].node);
+  }
+
+  std::sort(nodesAndScores.begin(), nodesAndScores.end(), nodeComparator);
+
+  // Now iterate through them starting with the best and stopping when we reach
+  // one that isn't good enough.  Since the children are sorted, the remaining
+  // (worse-scored) children are counted as pruned along with it.
+  for (size_t i = 0; i < numChildren; i++)
+  {
+    if (rule.Rescore(queryIndex, *nodesAndScores[i].node,
+        nodesAndScores[i].score) == DBL_MAX)
+    {
+      numPrunes += numChildren - i;
+      return;
+    }
+
+    // Traverse() takes the node by reference, so dereference the stored
+    // pointer (passing the pointer itself does not compile).
+    Traverse(queryIndex, *nodesAndScores[i].node);
+  }
+}
+
+}; // namespace tree
+}; // namespace mlpack
+
+#endif
\ No newline at end of file



More information about the mlpack-svn mailing list