[mlpack-git] master: Allow access to number of base cases and scores. (5584483)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Oct 22 11:11:19 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/c81893381e80d4ecae4283cec5fe5264bdf4f677...d1dfaa8e0978e01c240660a3217e68c4fa7c3e0a

>---------------------------------------------------------------

commit 5584483a1a041a86757cf006a45317dd7975d3ec
Author: Ryan Curtin <ryan at ratml.org>
Date:   Thu Oct 22 15:09:10 2015 +0000

    Allow access to number of base cases and scores.


>---------------------------------------------------------------

5584483a1a041a86757cf006a45317dd7975d3ec
 src/mlpack/methods/range_search/range_search.hpp   | 14 ++++++++++++
 .../methods/range_search/range_search_impl.hpp     | 25 +++++++++++++++++++++-
 .../methods/range_search/range_search_rules.hpp    | 14 ++++++++++--
 .../range_search/range_search_rules_impl.hpp       | 11 +++++++---
 4 files changed, 58 insertions(+), 6 deletions(-)

diff --git a/src/mlpack/methods/range_search/range_search.hpp b/src/mlpack/methods/range_search/range_search.hpp
index 3ebf17e..d1a7c02 100644
--- a/src/mlpack/methods/range_search/range_search.hpp
+++ b/src/mlpack/methods/range_search/range_search.hpp
@@ -199,6 +199,15 @@ class RangeSearch
               std::vector<std::vector<size_t>>& neighbors,
               std::vector<std::vector<double>>& distances);
 
+  //! Get the number of base cases during the last search.
+  size_t BaseCases() const { return baseCases; }
+  //! Get the number of scores during the last search.
+  size_t Scores() const { return scores; }
+
+  //! Serialize the model.
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int version);
+
   //! Returns a string representation of this object.
   std::string ToString() const;
 
@@ -223,6 +232,11 @@ class RangeSearch
 
   //! Instantiated distance metric.
   MetricType metric;
+
+  //! The total number of base cases during the last search.
+  size_t baseCases;
+  //! The total number of scores during the last search.
+  size_t scores;
 };
 
 } // namespace range
diff --git a/src/mlpack/methods/range_search/range_search_impl.hpp b/src/mlpack/methods/range_search/range_search_impl.hpp
index 8a156e6..517dc69 100644
--- a/src/mlpack/methods/range_search/range_search_impl.hpp
+++ b/src/mlpack/methods/range_search/range_search_impl.hpp
@@ -55,7 +55,9 @@ RangeSearch<MetricType, MatType, TreeType>::RangeSearch(
     treeOwner(!naive), // If in naive mode, we are not building any trees.
     naive(naive),
     singleMode(!naive && singleMode), // Naive overrides single mode.
-    metric(metric)
+    metric(metric),
+    baseCases(0),
+    scores(0)
 {
   // Nothing to do.
 }
@@ -136,6 +138,10 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
   // Create the helper object for the traversal.
   typedef RangeSearchRules<MetricType, Tree> RuleType;
 
+  // Reset counts.
+  baseCases = 0;
+  scores = 0;
+
   if (naive)
   {
     RuleType rules(referenceSet, querySet, range, *neighborPtr, *distancePtr,
@@ -145,6 +151,9 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
     for (size_t i = 0; i < querySet.n_cols; ++i)
       for (size_t j = 0; j < referenceSet.n_cols; ++j)
         rules.BaseCase(i, j);
+
+    // The loop above calls BaseCase() once for every (query, reference) pair.
+    baseCases += (querySet.n_cols * referenceSet.n_cols);
   }
   else if (singleMode)
   {
@@ -156,6 +164,9 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
     // Now have it traverse for each point.
     for (size_t i = 0; i < querySet.n_cols; ++i)
       traverser.Traverse(i, *referenceTree);
+
+    baseCases += rules.BaseCases();
+    scores += rules.Scores();
   }
   else // Dual-tree recursion.
   {
@@ -174,6 +185,9 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
 
     traverser.Traverse(*queryTree, *referenceTree);
 
+    baseCases += rules.BaseCases();
+    scores += rules.Scores();
+
     // Clean up tree memory.
     delete queryTree;
   }
@@ -292,6 +306,9 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
 
   Timer::Stop("range_search/computing_neighbors");
 
+  baseCases = rules.BaseCases();
+  scores = rules.Scores();
+
   // Do we need to map indices?
   if (treeOwner && tree::TreeTraits<Tree>::RearrangesDataset)
   {
@@ -360,6 +377,9 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
     // Now have it traverse for each point.
     for (size_t i = 0; i < referenceSet.n_cols; ++i)
       traverser.Traverse(i, *referenceTree);
+
+    baseCases = rules.BaseCases();
+    scores = rules.Scores();
   }
   else // Dual-tree recursion.
   {
@@ -367,6 +387,9 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
     typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
 
     traverser.Traverse(*referenceTree, *referenceTree);
+
+    baseCases = rules.BaseCases();
+    scores = rules.Scores();
   }
 
   Timer::Stop("range_search/computing_neighbors");
diff --git a/src/mlpack/methods/range_search/range_search_rules.hpp b/src/mlpack/methods/range_search/range_search_rules.hpp
index 8d4af1a..78eff00 100644
--- a/src/mlpack/methods/range_search/range_search_rules.hpp
+++ b/src/mlpack/methods/range_search/range_search_rules.hpp
@@ -101,6 +101,11 @@ class RangeSearchRules
   const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
   TraversalInfoType& TraversalInfo() { return traversalInfo; }
 
+  //! Get the number of base cases.
+  size_t BaseCases() const { return baseCases; }
+  //! Get the number of scores (that is, calls to RangeDistance()).
+  size_t Scores() const { return scores; }
+
  private:
   //! The reference set.
   const arma::mat& referenceSet;
@@ -135,10 +140,15 @@ class RangeSearchRules
                  TreeType& referenceNode);
 
   TraversalInfoType traversalInfo;
+
+  //! The number of base cases.
+  size_t baseCases;
+  //! The number of scores.
+  size_t scores;
 };
 
-}; // namespace range
-}; // namespace mlpack
+} // namespace range
+} // namespace mlpack
 
 // Include implementation.
 #include "range_search_rules_impl.hpp"
diff --git a/src/mlpack/methods/range_search/range_search_rules_impl.hpp b/src/mlpack/methods/range_search/range_search_rules_impl.hpp
index afa6306..30c507d 100644
--- a/src/mlpack/methods/range_search/range_search_rules_impl.hpp
+++ b/src/mlpack/methods/range_search/range_search_rules_impl.hpp
@@ -30,7 +30,9 @@ RangeSearchRules<MetricType, TreeType>::RangeSearchRules(
     metric(metric),
     sameSet(sameSet),
     lastQueryIndex(querySet.n_cols),
-    lastReferenceIndex(referenceSet.n_cols)
+    lastReferenceIndex(referenceSet.n_cols),
+    baseCases(0),
+    scores(0)
 {
   // Nothing to do.
 }
@@ -53,6 +55,7 @@ double RangeSearchRules<MetricType, TreeType>::BaseCase(
 
   const double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
       referenceSet.unsafe_col(referenceIndex));
+  ++baseCases;
 
   // Update last indices, so we don't accidentally perform a base case twice.
   lastQueryIndex = queryIndex;
@@ -107,6 +110,7 @@ double RangeSearchRules<MetricType, TreeType>::Score(const size_t queryIndex,
   else
   {
     distances = referenceNode.RangeDistance(querySet.unsafe_col(queryIndex));
+    ++scores;
   }
 
   // If the ranges do not overlap, prune this node.
@@ -176,6 +180,7 @@ double RangeSearchRules<MetricType, TreeType>::Score(TreeType& queryNode,
   {
     // Just perform the calculation.
     distances = referenceNode.RangeDistance(&queryNode);
+    ++scores;
   }
 
   // If the ranges do not overlap, prune this node.
@@ -249,7 +254,7 @@ void RangeSearchRules<MetricType, TreeType>::AddResult(const size_t queryIndex,
   }
 }
 
-}; // namespace range
-}; // namespace mlpack
+} // namespace range
+} // namespace mlpack
 
 #endif



More information about the mlpack-git mailing list