[mlpack-git] master, mlpack-1.0.x: Modify CoverTree::DualTreeTraverser to properly handle TraversalInfo objects. (41a730e)
gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:42:31 EST 2015
Repository : https://github.com/mlpack/mlpack
On branches: master,mlpack-1.0.x
Link : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40
>---------------------------------------------------------------
commit 41a730ee43311e15bacae3e35dac9328193d97d4
Author: Ryan Curtin <ryan at ratml.org>
Date: Thu Feb 6 20:21:02 2014 +0000
Modify CoverTree::DualTreeTraverser to properly handle TraversalInfo objects.
>---------------------------------------------------------------
41a730ee43311e15bacae3e35dac9328193d97d4
.../core/tree/cover_tree/dual_tree_traverser.hpp | 51 +++---
.../tree/cover_tree/dual_tree_traverser_impl.hpp | 180 +++++++++++++--------
2 files changed, 147 insertions(+), 84 deletions(-)
diff --git a/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp b/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp
index 0e7142e..e09d2d0 100644
--- a/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp
+++ b/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp
@@ -13,10 +13,6 @@
namespace mlpack {
namespace tree {
-//! Forward declaration of struct to be used for traversal.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-struct DualCoverTreeMapEntry;
-
template<typename MetricType, typename RootPointPolicy, typename StatisticType>
template<typename RuleType>
class CoverTree<MetricType, RootPointPolicy, StatisticType>::DualTreeTraverser
@@ -35,14 +31,6 @@ class CoverTree<MetricType, RootPointPolicy, StatisticType>::DualTreeTraverser
*/
void Traverse(CoverTree& queryNode, CoverTree& referenceNode);
- /**
- * Helper function for traversal of the two trees.
- */
- void Traverse(CoverTree& queryNode,
- std::map<int, std::vector<DualCoverTreeMapEntry<
- MetricType, RootPointPolicy, StatisticType> > >&
- referenceMap);
-
//! Get the number of pruned nodes.
size_t NumPrunes() const { return numPrunes; }
//! Modify the number of pruned nodes.
@@ -61,17 +49,44 @@ class CoverTree<MetricType, RootPointPolicy, StatisticType>::DualTreeTraverser
//! The number of pruned nodes.
size_t numPrunes;
+ //! Struct used for traversal.
+ struct DualCoverTreeMapEntry
+ {
+ //! The node this entry refers to.
+ CoverTree<MetricType, RootPointPolicy, StatisticType>* referenceNode;
+ //! The score of the node.
+ double score;
+ //! The base case.
+ double baseCase;
+ //! The traversal info associated with the call to Score() for this entry.
+ typename RuleType::TraversalInfoType traversalInfo;
+
+ //! Comparison operator, for sorting within the map.
+ bool operator<(const DualCoverTreeMapEntry& other) const
+ {
+ if (score == other.score)
+ return (baseCase < other.baseCase);
+ else
+ return (score < other.score);
+ }
+ };
+
+ /**
+ * Helper function for traversal of the two trees.
+ */
+ void Traverse(CoverTree& queryNode,
+ std::map<int, std::vector<DualCoverTreeMapEntry> >&
+ referenceMap);
+
//! Prepare map for recursion.
void PruneMap(CoverTree& queryNode,
- std::map<int, std::vector<DualCoverTreeMapEntry<
- MetricType, RootPointPolicy, StatisticType> > >&
+ std::map<int, std::vector<DualCoverTreeMapEntry> >&
referenceMap,
- std::map<int, std::vector<DualCoverTreeMapEntry<
- MetricType, RootPointPolicy, StatisticType> > >& childMap);
+ std::map<int, std::vector<DualCoverTreeMapEntry> >&
+ childMap);
void ReferenceRecursion(CoverTree& queryNode,
- std::map<int, std::vector<DualCoverTreeMapEntry<
- MetricType, RootPointPolicy, StatisticType> > >&
+ std::map<int, std::vector<DualCoverTreeMapEntry> >&
referenceMap);
};
diff --git a/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp
index 55d88a7..741e257 100644
--- a/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp
@@ -13,22 +13,6 @@
namespace mlpack {
namespace tree {
-//! The object placed in the map for tree traversal.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-struct DualCoverTreeMapEntry
-{
- //! The node this entry refers to.
- CoverTree<MetricType, RootPointPolicy, StatisticType>* referenceNode;
- //! The score of the node.
- double score;
-
- //! Comparison operator, for sorting within the map.
- bool operator<(const DualCoverTreeMapEntry& other) const
- {
- return (score < other.score);
- }
-};
-
template<typename MetricType, typename RootPointPolicy, typename StatisticType>
template<typename RuleType>
CoverTree<MetricType, RootPointPolicy, StatisticType>::
@@ -44,16 +28,15 @@ DualTreeTraverser<RuleType>::Traverse(
CoverTree<MetricType, RootPointPolicy, StatisticType>& queryNode,
CoverTree<MetricType, RootPointPolicy, StatisticType>& referenceNode)
{
- typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
- MapEntryType;
-
// Start by creating a map and adding the reference root node to it.
- std::map<int, std::vector<MapEntryType> > refMap;
+ std::map<int, std::vector<DualCoverTreeMapEntry> > refMap;
- MapEntryType rootRefEntry;
+ DualCoverTreeMapEntry rootRefEntry;
rootRefEntry.referenceNode = &referenceNode;
rootRefEntry.score = 0.0; // Must recurse into.
+ rootRefEntry.baseCase = 0.0;
+ rootRefEntry.traversalInfo = rule.TraversalInfo();
refMap[referenceNode.Scale()].push_back(rootRefEntry);
@@ -65,13 +48,8 @@ template<typename RuleType>
void CoverTree<MetricType, RootPointPolicy, StatisticType>::
DualTreeTraverser<RuleType>::Traverse(
CoverTree<MetricType, RootPointPolicy, StatisticType>& queryNode,
- std::map<int, std::vector<DualCoverTreeMapEntry<MetricType, RootPointPolicy,
- StatisticType> > >& referenceMap)
+ std::map<int, std::vector<DualCoverTreeMapEntry> >& referenceMap)
{
- // Convenience typedef.
- typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
- MapEntryType;
-
if (referenceMap.size() == 0)
return; // Nothing to do!
@@ -92,14 +70,16 @@ DualTreeTraverser<RuleType>::Traverse(
// results are separate and independent. I don't think this is true in
// every case, and we may have to modify this section to consider scores in
// the future.
- std::map<int, std::vector<MapEntryType> > childMap;
- PruneMap(queryNode, referenceMap, childMap);
- for (size_t i = 0; i < queryNode.NumChildren(); ++i)
+ for (size_t i = 1; i < queryNode.NumChildren(); ++i)
{
// We need a copy of the map for this child.
- std::map<int, std::vector<MapEntryType> > thisChildMap = childMap;
- Traverse(queryNode.Child(i), thisChildMap);
+ std::map<int, std::vector<DualCoverTreeMapEntry> > childMap;
+ PruneMap(queryNode.Child(i), referenceMap, childMap);
+ Traverse(queryNode.Child(i), childMap);
}
+ std::map<int, std::vector<DualCoverTreeMapEntry> > selfChildMap;
+ PruneMap(queryNode.Child(0), referenceMap, selfChildMap);
+ Traverse(queryNode.Child(0), selfChildMap);
}
if (queryNode.Scale() != INT_MIN)
@@ -109,12 +89,13 @@ DualTreeTraverser<RuleType>::Traverse(
// evaluations to do.
Log::Assert((*referenceMap.begin()).first == INT_MIN);
Log::Assert(queryNode.Scale() == INT_MIN);
- std::vector<MapEntryType>& pointVector = (*referenceMap.begin()).second;
+ std::vector<DualCoverTreeMapEntry>& pointVector =
+ (*referenceMap.begin()).second;
for (size_t i = 0; i < pointVector.size(); ++i)
{
// Get a reference to the frame.
- const MapEntryType& frame = pointVector[i];
+ const DualCoverTreeMapEntry& frame = pointVector[i];
CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
frame.referenceNode;
@@ -128,7 +109,9 @@ DualTreeTraverser<RuleType>::Traverse(
continue;
}
- // Score the node, to see if we can prune it.
+ // Score the node, to see if we can prune it, after restoring the traversal
+ // info.
+ rule.TraversalInfo() = frame.traversalInfo;
double score = rule.Score(queryNode, *refNode);
if (score == DBL_MAX)
@@ -147,41 +130,91 @@ template<typename RuleType>
void CoverTree<MetricType, RootPointPolicy, StatisticType>::
DualTreeTraverser<RuleType>::PruneMap(
CoverTree& queryNode,
- std::map<int, std::vector<DualCoverTreeMapEntry<MetricType,
- RootPointPolicy, StatisticType> > >& referenceMap,
- std::map<int, std::vector<DualCoverTreeMapEntry<MetricType,
- RootPointPolicy, StatisticType> > >& childMap)
+ std::map<int, std::vector<DualCoverTreeMapEntry> >& referenceMap,
+ std::map<int, std::vector<DualCoverTreeMapEntry> >& childMap)
{
- typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
- MapEntryType;
-
if (referenceMap.empty())
return; // Nothing to do.
- typename std::map<int, std::vector<MapEntryType> >::reverse_iterator it =
- referenceMap.rbegin();
- while ((it != referenceMap.rend()))
+ // Copy the zero set first.
+ if ((*referenceMap.begin()).first == INT_MIN)
{
// Get a reference to the vector representing the entries at this scale.
- std::vector<MapEntryType>& scaleVector = (*it).second;
+ std::vector<DualCoverTreeMapEntry>& scaleVector =
+ (*referenceMap.begin()).second;
// Before traversing all the points in this scale, sort by score.
std::sort(scaleVector.begin(), scaleVector.end());
+ const int thisScale = (*referenceMap.begin()).first;
+ childMap[thisScale].reserve(scaleVector.size());
+ std::vector<DualCoverTreeMapEntry>& newScaleVector = childMap[thisScale];
+
+ // Loop over each entry in the vector.
+ for (size_t j = 0; j < scaleVector.size(); ++j)
+ {
+ const DualCoverTreeMapEntry& frame = scaleVector[j];
+
+ // First evaluate if we can prune without performing the base case.
+ CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
+ frame.referenceNode;
+
+ // Perform the actual scoring, after restoring the traversal info.
+ rule.TraversalInfo() = frame.traversalInfo;
+ double score = rule.Score(queryNode, *refNode);
+
+ if (score == DBL_MAX)
+ {
+ // Pruned. Move on.
+ ++numPrunes;
+ continue;
+ }
+
+ // If it isn't pruned, we must evaluate the base case.
+ const double baseCase = rule.BaseCase(queryNode.Point(),
+ refNode->Point());
+
+ // Add to child map.
+ newScaleVector.push_back(frame);
+ newScaleVector.back().score = score;
+ newScaleVector.back().baseCase = baseCase;
+ newScaleVector.back().traversalInfo = rule.TraversalInfo();
+ }
+
+ // If we didn't add anything, then strike this vector from the map.
+ if (newScaleVector.size() == 0)
+ childMap.erase((*referenceMap.begin()).first);
+ }
+
+ typename std::map<int, std::vector<DualCoverTreeMapEntry> >::reverse_iterator
+ it = referenceMap.rbegin();
+
+ while ((it != referenceMap.rend()))
+ {
const int thisScale = (*it).first;
+ if (thisScale == INT_MIN) // We already did it.
+ break;
+
+ // Get a reference to the vector representing the entries at this scale.
+ std::vector<DualCoverTreeMapEntry>& scaleVector = (*it).second;
+
+ // Before traversing all the points in this scale, sort by score.
+ std::sort(scaleVector.begin(), scaleVector.end());
+
childMap[thisScale].reserve(scaleVector.size());
- std::vector<MapEntryType>& newScaleVector = childMap[thisScale];
+ std::vector<DualCoverTreeMapEntry>& newScaleVector = childMap[thisScale];
// Loop over each entry in the vector.
for (size_t j = 0; j < scaleVector.size(); ++j)
{
- const MapEntryType& frame = scaleVector[j];
+ const DualCoverTreeMapEntry& frame = scaleVector[j];
// First evaluate if we can prune without performing the base case.
CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
frame.referenceNode;
- // Perform the actual scoring.
+ // Perform the actual scoring, after restoring the traversal info.
+ rule.TraversalInfo() = frame.traversalInfo;
double score = rule.Score(queryNode, *refNode);
if (score == DBL_MAX)
@@ -192,11 +225,14 @@ DualTreeTraverser<RuleType>::PruneMap(
}
// If it isn't pruned, we must evaluate the base case.
- rule.BaseCase(queryNode.Point(), refNode->Point());
+ const double baseCase = rule.BaseCase(queryNode.Point(),
+ refNode->Point());
// Add to child map.
newScaleVector.push_back(frame);
newScaleVector.back().score = score;
+ newScaleVector.back().baseCase = baseCase;
+ newScaleVector.back().traversalInfo = rule.TraversalInfo();
}
// If we didn't add anything, then strike this vector from the map.
@@ -212,17 +248,19 @@ template<typename RuleType>
void CoverTree<MetricType, RootPointPolicy, StatisticType>::
DualTreeTraverser<RuleType>::ReferenceRecursion(
CoverTree& queryNode,
- std::map<int, std::vector<DualCoverTreeMapEntry<MetricType, RootPointPolicy,
- StatisticType> > >& referenceMap)
+ std::map<int, std::vector<DualCoverTreeMapEntry> >& referenceMap)
{
- typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
- MapEntryType;
-
// First, reduce the maximum scale in the reference map down to the scale of
// the query node.
- while (!referenceMap.empty() &&
- ((*referenceMap.rbegin()).first > queryNode.Scale()))
+ while (!referenceMap.empty())
{
+ // Hack to imitate jl cover tree: a parentless query node stops only at strictly smaller scales.
+ if (queryNode.Parent() == NULL && (*referenceMap.rbegin()).first <
+ queryNode.Scale())
+ break;
+ if (queryNode.Parent() != NULL && (*referenceMap.rbegin()).first <=
+ queryNode.Scale())
+ break;
// If the query node's scale is INT_MIN and the reference map's maximum
// scale is INT_MIN, don't try to recurse...
if ((queryNode.Scale() == INT_MIN) &&
@@ -230,7 +268,7 @@ DualTreeTraverser<RuleType>::ReferenceRecursion(
break;
// Get a reference to the current largest scale.
- std::vector<MapEntryType>& scaleVector = (*referenceMap.rbegin()).second;
+ std::vector<DualCoverTreeMapEntry>& scaleVector = (*referenceMap.rbegin()).second;
// Before traversing all the points in this scale, sort by score.
std::sort(scaleVector.begin(), scaleVector.end());
@@ -239,13 +277,13 @@ DualTreeTraverser<RuleType>::ReferenceRecursion(
for (size_t i = 0; i < scaleVector.size(); ++i)
{
// Get a reference to the current element.
- const MapEntryType& frame = scaleVector.at(i);
+ const DualCoverTreeMapEntry& frame = scaleVector.at(i);
CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
frame.referenceNode;
// Create the score for the children.
- double score = rule.Score(queryNode, *refNode);
+ double score = rule.Rescore(queryNode, *refNode, frame.score);
// Now if this childScore is DBL_MAX we can prune all children. In this
// recursion setup pruning is all or nothing for children.
@@ -256,17 +294,27 @@ DualTreeTraverser<RuleType>::ReferenceRecursion(
}
// If it is not pruned, we must evaluate the base case.
- rule.BaseCase(queryNode.Point(), refNode->Point());
// Add the children.
for (size_t j = 0; j < refNode->NumChildren(); ++j)
{
-// const size_t queryIndex = queryNode.Point();
-// const size_t refIndex = refNode->Child(j).Point();
-
- MapEntryType newFrame;
+ rule.TraversalInfo() = frame.traversalInfo;
+ double childScore = rule.Score(queryNode, refNode->Child(j));
+ if (childScore == DBL_MAX)
+ {
+ ++numPrunes;
+ continue;
+ }
+
+ // It wasn't pruned; evaluate the base case.
+ const double baseCase = rule.BaseCase(queryNode.Point(),
+ refNode->Child(j).Point());
+
+ DualCoverTreeMapEntry newFrame;
newFrame.referenceNode = &refNode->Child(j);
- newFrame.score = score; // Use the score of the parent.
+ newFrame.score = childScore; // Use the child's own score.
+ newFrame.baseCase = baseCase;
+ newFrame.traversalInfo = rule.TraversalInfo();
referenceMap[newFrame.referenceNode->Scale()].push_back(newFrame);
}
More information about the mlpack-git mailing list