[mlpack-git] master: Allow access to number of base cases and scores. (5584483)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Oct 22 11:11:19 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/c81893381e80d4ecae4283cec5fe5264bdf4f677...d1dfaa8e0978e01c240660a3217e68c4fa7c3e0a
>---------------------------------------------------------------
commit 5584483a1a041a86757cf006a45317dd7975d3ec
Author: Ryan Curtin <ryan at ratml.org>
Date: Thu Oct 22 15:09:10 2015 +0000
Allow access to number of base cases and scores.
>---------------------------------------------------------------
5584483a1a041a86757cf006a45317dd7975d3ec
src/mlpack/methods/range_search/range_search.hpp | 14 ++++++++++++
.../methods/range_search/range_search_impl.hpp | 25 +++++++++++++++++++++-
.../methods/range_search/range_search_rules.hpp | 14 ++++++++++--
.../range_search/range_search_rules_impl.hpp | 11 +++++++---
4 files changed, 58 insertions(+), 6 deletions(-)
diff --git a/src/mlpack/methods/range_search/range_search.hpp b/src/mlpack/methods/range_search/range_search.hpp
index 3ebf17e..d1a7c02 100644
--- a/src/mlpack/methods/range_search/range_search.hpp
+++ b/src/mlpack/methods/range_search/range_search.hpp
@@ -199,6 +199,15 @@ class RangeSearch
std::vector<std::vector<size_t>>& neighbors,
std::vector<std::vector<double>>& distances);
+ //! Get the number of base cases during the last search.
+ size_t BaseCases() const { return baseCases; }
+ //! Get the number of scores during the last search.
+ size_t Scores() const { return scores; }
+
+ //! Serialize the model.
+ template<typename Archive>
+ void Serialize(Archive& ar, const unsigned int version);
+
//! Returns a string representation of this object.
std::string ToString() const;
@@ -223,6 +232,11 @@ class RangeSearch
//! Instantiated distance metric.
MetricType metric;
+
+ //! The total number of base cases during the last search.
+ size_t baseCases;
+ //! The total number of scores during the last search.
+ size_t scores;
};
} // namespace range
diff --git a/src/mlpack/methods/range_search/range_search_impl.hpp b/src/mlpack/methods/range_search/range_search_impl.hpp
index 8a156e6..517dc69 100644
--- a/src/mlpack/methods/range_search/range_search_impl.hpp
+++ b/src/mlpack/methods/range_search/range_search_impl.hpp
@@ -55,7 +55,9 @@ RangeSearch<MetricType, MatType, TreeType>::RangeSearch(
treeOwner(!naive), // If in naive mode, we are not building any trees.
naive(naive),
singleMode(!naive && singleMode), // Naive overrides single mode.
- metric(metric)
+ metric(metric),
+ baseCases(0),
+ scores(0)
{
// Nothing to do.
}
@@ -136,6 +138,10 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
// Create the helper object for the traversal.
typedef RangeSearchRules<MetricType, Tree> RuleType;
+ // Reset counts.
+ baseCases = 0;
+ scores = 0;
+
if (naive)
{
RuleType rules(referenceSet, querySet, range, *neighborPtr, *distancePtr,
@@ -145,6 +151,8 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
for (size_t i = 0; i < querySet.n_cols; ++i)
for (size_t j = 0; j < referenceSet.n_cols; ++j)
rules.BaseCase(i, j);
+
+ baseCases += (querySet.n_cols * referenceSet->n_cols);
}
else if (singleMode)
{
@@ -156,6 +164,9 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
// Now have it traverse for each point.
for (size_t i = 0; i < querySet.n_cols; ++i)
traverser.Traverse(i, *referenceTree);
+
+ baseCases += rules.BaseCases();
+ scores += rules.Scores();
}
else // Dual-tree recursion.
{
@@ -174,6 +185,9 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
traverser.Traverse(*queryTree, *referenceTree);
+ baseCases += rules.BaseCases();
+ scores += rules.Scores();
+
// Clean up tree memory.
delete queryTree;
}
@@ -292,6 +306,9 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
Timer::Stop("range_search/computing_neighbors");
+ baseCases = rules.BaseCases();
+ scores = rules.Scores();
+
// Do we need to map indices?
if (treeOwner && tree::TreeTraits<Tree>::RearrangesDataset)
{
@@ -360,6 +377,9 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
// Now have it traverse for each point.
for (size_t i = 0; i < referenceSet.n_cols; ++i)
traverser.Traverse(i, *referenceTree);
+
+ baseCases = rules.BaseCases();
+ scores = rules.Scores();
}
else // Dual-tree recursion.
{
@@ -367,6 +387,9 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
traverser.Traverse(*referenceTree, *referenceTree);
+
+ baseCases = rules.BaseCases();
+ scores = rules.Scores();
}
Timer::Stop("range_search/computing_neighbors");
diff --git a/src/mlpack/methods/range_search/range_search_rules.hpp b/src/mlpack/methods/range_search/range_search_rules.hpp
index 8d4af1a..78eff00 100644
--- a/src/mlpack/methods/range_search/range_search_rules.hpp
+++ b/src/mlpack/methods/range_search/range_search_rules.hpp
@@ -101,6 +101,11 @@ class RangeSearchRules
const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
TraversalInfoType& TraversalInfo() { return traversalInfo; }
+ //! Get the number of base cases.
+ size_t BaseCases() const { return baseCases; }
+ //! Get the number of scores (that is, calls to RangeDistance()).
+ size_t Scores() const { return scores; }
+
private:
//! The reference set.
const arma::mat& referenceSet;
@@ -135,10 +140,15 @@ class RangeSearchRules
TreeType& referenceNode);
TraversalInfoType traversalInfo;
+
+ //! The number of base cases.
+ size_t baseCases;
+ //! The number of scores.
+ size_t scores;
};
-}; // namespace range
-}; // namespace mlpack
+} // namespace range
+} // namespace mlpack
// Include implementation.
#include "range_search_rules_impl.hpp"
diff --git a/src/mlpack/methods/range_search/range_search_rules_impl.hpp b/src/mlpack/methods/range_search/range_search_rules_impl.hpp
index afa6306..30c507d 100644
--- a/src/mlpack/methods/range_search/range_search_rules_impl.hpp
+++ b/src/mlpack/methods/range_search/range_search_rules_impl.hpp
@@ -30,7 +30,9 @@ RangeSearchRules<MetricType, TreeType>::RangeSearchRules(
metric(metric),
sameSet(sameSet),
lastQueryIndex(querySet.n_cols),
- lastReferenceIndex(referenceSet.n_cols)
+ lastReferenceIndex(referenceSet.n_cols),
+ baseCases(0),
+ scores(0)
{
// Nothing to do.
}
@@ -53,6 +55,7 @@ double RangeSearchRules<MetricType, TreeType>::BaseCase(
const double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
referenceSet.unsafe_col(referenceIndex));
+ ++baseCases;
// Update last indices, so we don't accidentally perform a base case twice.
lastQueryIndex = queryIndex;
@@ -107,6 +110,7 @@ double RangeSearchRules<MetricType, TreeType>::Score(const size_t queryIndex,
else
{
distances = referenceNode.RangeDistance(querySet.unsafe_col(queryIndex));
+ ++scores;
}
// If the ranges do not overlap, prune this node.
@@ -176,6 +180,7 @@ double RangeSearchRules<MetricType, TreeType>::Score(TreeType& queryNode,
{
// Just perform the calculation.
distances = referenceNode.RangeDistance(&queryNode);
+ ++scores;
}
// If the ranges do not overlap, prune this node.
@@ -249,7 +254,7 @@ void RangeSearchRules<MetricType, TreeType>::AddResult(const size_t queryIndex,
}
}
-}; // namespace range
-}; // namespace mlpack
+} // namespace range
+} // namespace mlpack
#endif
More information about the mlpack-git
mailing list