[mlpack-svn] r15780 - mlpack/trunk/src/mlpack/core/tree/binary_space_tree

fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Sep 13 16:10:12 EDT 2013


Author: rcurtin
Date: Fri Sep 13 16:10:12 2013
New Revision: 15780

Log:
Add functions to get the number of base cases calculated, the number of node combinations visited, and the number of node combinations scored.


Modified:
   mlpack/trunk/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp
   mlpack/trunk/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp

Modified: mlpack/trunk/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser.hpp	Fri Sep 13 16:10:12 2013
@@ -40,12 +40,36 @@
   //! Modify the number of prunes.
   size_t& NumPrunes() { return numPrunes; }
 
+  //! Get the number of visited combinations.
+  size_t NumVisited() const { return numVisited; }
+  //! Modify the number of visited combinations.
+  size_t& NumVisited() { return numVisited; }
+
+  //! Get the number of times a node combination was scored.
+  size_t NumScores() const { return numScores; }
+  //! Modify the number of times a node combination was scored.
+  size_t& NumScores() { return numScores; }
+
+  //! Get the number of times a base case was calculated.
+  size_t NumBaseCases() const { return numBaseCases; }
+  //! Modify the number of times a base case was calculated.
+  size_t& NumBaseCases() { return numBaseCases; }
+
  private:
   //! Reference to the rules with which the trees will be traversed.
   RuleType& rule;
 
-  //! The number of nodes which have been pruned during traversal.
+  //! The number of prunes.
   size_t numPrunes;
+
+  //! The number of node combinations that have been visited during traversal.
+  size_t numVisited;
+
+  //! The number of times a node combination was scored.
+  size_t numScores;
+
+  //! The number of times a base case was calculated.
+  size_t numBaseCases;
 };
 
 }; // namespace tree

Modified: mlpack/trunk/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/tree/binary_space_tree/dual_tree_traverser_impl.hpp	Fri Sep 13 16:10:12 2013
@@ -20,7 +20,10 @@
 BinarySpaceTree<BoundType, StatisticType, MatType>::
 DualTreeTraverser<RuleType>::DualTreeTraverser(RuleType& rule) :
     rule(rule),
-    numPrunes(0)
+    numPrunes(0),
+    numVisited(0),
+    numScores(0),
+    numBaseCases(0)
 { /* Nothing to do. */ }
 
 template<typename BoundType, typename StatisticType, typename MatType>
@@ -30,6 +33,9 @@
     BinarySpaceTree<BoundType, StatisticType, MatType>& queryNode,
     BinarySpaceTree<BoundType, StatisticType, MatType>& referenceNode)
 {
+  // Increment the visit counter.
+  ++numVisited;
+
   // If both are leaves, we must evaluate the base case.
   if (queryNode.IsLeaf() && referenceNode.IsLeaf())
   {
@@ -45,6 +51,8 @@
 
       for (size_t ref = referenceNode.Begin(); ref < referenceNode.End(); ++ref)
         rule.BaseCase(query, ref);
+
+      numBaseCases += referenceNode.Count();
     }
   }
   else if ((!queryNode.IsLeaf()) && referenceNode.IsLeaf())
@@ -52,6 +60,7 @@
     // We have to recurse down the query node.  In this case the recursion order
     // does not matter.
     double leftScore = rule.Score(*queryNode.Left(), referenceNode);
+    ++numScores;
 
     if (leftScore != DBL_MAX)
       Traverse(*queryNode.Left(), referenceNode);
@@ -59,6 +68,7 @@
       ++numPrunes;
 
     double rightScore = rule.Score(*queryNode.Right(), referenceNode);
+    ++numScores;
 
     if (rightScore != DBL_MAX)
       Traverse(*queryNode.Right(), referenceNode);
@@ -71,6 +81,7 @@
     // order does matter.
     double leftScore = rule.Score(queryNode, *referenceNode.Left());
     double rightScore = rule.Score(queryNode, *referenceNode.Right());
+    numScores += 2;
 
     if (leftScore < rightScore)
     {
@@ -126,6 +137,7 @@
     // first.
     double leftScore = rule.Score(*queryNode.Left(), *referenceNode.Left());
     double rightScore = rule.Score(*queryNode.Left(), *referenceNode.Right());
+    numScores += 2;
 
     if (leftScore < rightScore)
     {
@@ -180,6 +192,7 @@
     // Now recurse down the right query node.
     leftScore = rule.Score(*queryNode.Right(), *referenceNode.Left());
     rightScore = rule.Score(*queryNode.Right(), *referenceNode.Right());
+    numScores += 2;
 
     if (leftScore < rightScore)
     {



More information about the mlpack-svn mailing list