[mlpack-svn] r15779 - mlpack/trunk/src/mlpack/core/tree/cover_tree
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Sep 13 16:09:48 EDT 2013
Author: rcurtin
Date: Fri Sep 13 16:09:48 2013
New Revision: 15779
Log:
This should fix the broken build; I will clean it up in the coming days. The
traversers are reimplemented so that the overloads of Score() and Rescore()
that take base case values are no longer needed, and the dual-tree traverser
now correctly satisfies the definition given in the ICML paper.
Modified:
mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree.hpp
mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
mlpack/trunk/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp
mlpack/trunk/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp
mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp
Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree.hpp (original)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree.hpp Fri Sep 13 16:09:48 2013
@@ -458,6 +458,12 @@
* Returns a string representation of this object.
*/
std::string ToString() const;
+
+ size_t DistanceComps() const { return distanceComps; }
+ size_t& DistanceComps() { return distanceComps; }
+
+ private:
+ size_t distanceComps;
};
}; // namespace tree
Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp Fri Sep 13 16:09:48 2013
@@ -31,7 +31,8 @@
parentDistance(0),
furthestDescendantDistance(0),
localMetric(metric == NULL),
- metric(metric)
+ metric(metric),
+ distanceComps(0)
{
// If we need to create a metric, do that. We'll just do it on the heap.
if (localMetric)
@@ -66,6 +67,9 @@
// Initialize statistic.
stat = StatisticType(*this);
+
+ Log::Info << distanceComps << " distance computations during tree "
+ << "construction." << std::endl;
}
template<typename MetricType, typename RootPointPolicy, typename StatisticType>
@@ -82,7 +86,8 @@
parentDistance(0),
furthestDescendantDistance(0),
localMetric(false),
- metric(&metric)
+ metric(&metric),
+ distanceComps(0)
{
// If there is only one point in the dataset, uh, we're done.
if (dataset.n_cols == 1)
@@ -113,6 +118,9 @@
// Initialize statistic.
stat = StatisticType(*this);
+
+ Log::Info << distanceComps << " distance computations during tree "
+ << "construction." << std::endl;
}
template<typename MetricType, typename RootPointPolicy, typename StatisticType>
@@ -138,7 +146,8 @@
parentDistance(parentDistance),
furthestDescendantDistance(0),
localMetric(false),
- metric(&metric)
+ metric(&metric),
+ distanceComps(0)
{
// If the size of the near set is 0, this is a leaf.
if (nearSetSize == 0)
@@ -176,7 +185,8 @@
parentDistance(parentDistance),
furthestDescendantDistance(furthestDescendantDistance),
localMetric(metric == NULL),
- metric(metric)
+ metric(metric),
+ distanceComps(0)
{
// If necessary, create a local metric.
if (localMetric)
@@ -199,7 +209,8 @@
parentDistance(other.parentDistance),
furthestDescendantDistance(other.furthestDescendantDistance),
localMetric(false),
- metric(other.metric)
+ metric(other.metric),
+ distanceComps(0)
{
// Copy each child by hand.
for (size_t i = 0; i < other.NumChildren(); ++i)
@@ -409,6 +420,7 @@
size_t tempSize = 0;
children.push_back(new CoverTree(dataset, base, point, INT_MIN, this, 0,
indices, distances, 0, tempSize, usedSetSize, *metric));
+ distanceComps += children.back()->DistanceComps();
// Every point in the near set should be a leaf.
for (size_t i = 0; i < nearSetSize; ++i)
@@ -417,6 +429,7 @@
children.push_back(new CoverTree(dataset, base, indices[i],
INT_MIN, this, distances[i], indices, distances, 0, tempSize,
usedSetSize, *metric));
+ distanceComps += children.back()->DistanceComps();
usedSetSize++;
}
@@ -458,6 +471,8 @@
// Remove any implicit nodes we may have created.
RemoveNewImplicitNodes();
+ distanceComps += children[0]->DistanceComps();
+
// Now the arrays, in memory, look like this:
// [ childFar | childUsed | far | used ]
// but we need to move the used points past our far set:
@@ -504,7 +519,8 @@
children.push_back(new CoverTree(dataset, base, indices[0], nextScale,
this, distances[0], indices, distances, childNearSetSize, farSetSize,
usedSetSize, *metric));
- numDescendants += children[children.size() - 1]->NumDescendants();
+ distanceComps += children.back()->DistanceComps();
+ numDescendants += children.back()->NumDescendants();
// Because the far set size is 0, we don't have to do any swapping to
// move the point into the used set.
@@ -545,11 +561,13 @@
children.push_back(new CoverTree(dataset, base, indices[0], nextScale,
this, distances[0], childIndices, childDistances, childNearSetSize,
childFarSetSize, childUsedSetSize, *metric));
- numDescendants += children[children.size() - 1]->NumDescendants();
+ numDescendants += children.back()->NumDescendants();
// Remove any implicit nodes.
RemoveNewImplicitNodes();
+ distanceComps += children.back()->DistanceComps();
+
// Now with the child created, it returns the childIndices and
// childDistances vectors in this form:
// [ childFar | childUsed ]
@@ -628,6 +646,7 @@
{
// For each point, rebuild the distances. The indices do not need to be
// modified.
+ distanceComps += pointSetSize;
for (size_t i = 0; i < pointSetSize; ++i)
{
distances[i] = metric->Evaluate(dataset.unsafe_col(pointIndex),
@@ -894,6 +913,7 @@
// Set its parent correctly.
old->Child(0).Parent() = this;
old->Child(0).ParentDistance() = old->ParentDistance();
+ old->Child(0).DistanceComps() = old->DistanceComps();
// Remove its child (so it doesn't delete it).
old->Children().erase(old->Children().begin() + old->Children().size() - 1);
Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp (original)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp Fri Sep 13 16:09:48 2013
@@ -48,6 +48,12 @@
//! Modify the number of pruned nodes.
size_t& NumPrunes() { return numPrunes; }
+ ///// These are placeholder stubs that always return 0: this patch targets
+ ///// kd-trees only, but the cover tree traverser must still compile!
+ size_t NumVisited() const { return 0; }
+ size_t NumScores() const { return 0; }
+ size_t NumBaseCases() const { return 0; }
+
private:
//! The instantiated rule set for pruning branches.
RuleType& rule;
@@ -57,18 +63,12 @@
//! Prepare map for recursion.
void PruneMap(CoverTree& queryNode,
- CoverTree& candidateQueryNode,
std::map<int, std::vector<DualCoverTreeMapEntry<
MetricType, RootPointPolicy, StatisticType> > >&
referenceMap,
std::map<int, std::vector<DualCoverTreeMapEntry<
MetricType, RootPointPolicy, StatisticType> > >& childMap);
- void PruneMapForSelfChild(CoverTree& candidateQueryNode,
- std::map<int, std::vector<DualCoverTreeMapEntry<
- MetricType, RootPointPolicy, StatisticType> > >&
- referenceMap);
-
void ReferenceRecursion(CoverTree& queryNode,
std::map<int, std::vector<DualCoverTreeMapEntry<
MetricType, RootPointPolicy, StatisticType> > >&
Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp Fri Sep 13 16:09:48 2013
@@ -21,12 +21,6 @@
CoverTree<MetricType, RootPointPolicy, StatisticType>* referenceNode;
//! The score of the node.
double score;
- //! The index of the reference node used for the base case evaluation.
- size_t referenceIndex;
- //! The index of the query node used for the base case evaluation.
- size_t queryIndex;
- //! The base case evaluation.
- double baseCase;
//! Comparison operator, for sorting within the map.
bool operator<(const DualCoverTreeMapEntry& other) const
@@ -53,18 +47,13 @@
typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
MapEntryType;
- // Start by creating a map and adding the reference node to it.
+ // Start by creating a map and adding the reference root node to it.
std::map<int, std::vector<MapEntryType> > refMap;
MapEntryType rootRefEntry;
rootRefEntry.referenceNode = &referenceNode;
rootRefEntry.score = 0.0; // Must recurse into.
- rootRefEntry.referenceIndex = referenceNode.Point();
- rootRefEntry.queryIndex = queryNode.Point();
- rootRefEntry.baseCase = rule.BaseCase(queryNode.Point(),
- referenceNode.Point());
-// rule.UpdateAfterRecursion(queryNode, referenceNode);
refMap[referenceNode.Scale()].push_back(rootRefEntry);
@@ -75,13 +64,10 @@
template<typename RuleType>
void CoverTree<MetricType, RootPointPolicy, StatisticType>::
DualTreeTraverser<RuleType>::Traverse(
- CoverTree<MetricType, RootPointPolicy, StatisticType>& queryNode,
- std::map<int, std::vector<DualCoverTreeMapEntry<MetricType, RootPointPolicy,
- StatisticType> > >& referenceMap)
+ CoverTree<MetricType, RootPointPolicy, StatisticType>& queryNode,
+ std::map<int, std::vector<DualCoverTreeMapEntry<MetricType, RootPointPolicy,
+ StatisticType> > >& referenceMap)
{
-// Log::Debug << "Recursed into query node " << queryNode.Point() << ", scale "
-// << queryNode.Scale() << "\n";
-
// Convenience typedef.
typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
MapEntryType;
@@ -99,26 +85,17 @@
{
// Recurse into the non-self-children first. The recursion order cannot
// affect the runtime of the algorithm, because each query child recursion's
- // results are separate and independent.
- for (size_t i = 1; i < queryNode.NumChildren(); ++i)
+ // results are separate and independent. I don't think this is true in
+ // every case, and we may have to modify this section to consider scores in
+ // the future.
+ std::map<int, std::vector<MapEntryType> > childMap;
+ PruneMap(queryNode, referenceMap, childMap);
+ for (size_t i = 0; i < queryNode.NumChildren(); ++i)
{
- std::map<int, std::vector<MapEntryType> > childMap;
- PruneMap(queryNode, queryNode.Child(i), referenceMap, childMap);
-
-// Log::Debug << "Recurse into query child " << i << ": " <<
-// queryNode.Child(i).Point() << " scale " << queryNode.Child(i).Scale()
-// << "; this parent is " << queryNode.Point() << " scale " <<
-// queryNode.Scale() << std::endl;
- Traverse(queryNode.Child(i), childMap);
+ // We need a copy of the map for this child.
+ std::map<int, std::vector<MapEntryType> > thisChildMap = childMap;
+ Traverse(queryNode.Child(i), thisChildMap);
}
-
- PruneMapForSelfChild(queryNode.Child(0), referenceMap);
-
- // Now we can use the existing map (without a copy) for the self-child.
-// Log::Warn << "Recurse into query self-child " << queryNode.Child(0).Point()
-// << " scale " << queryNode.Child(0).Scale() << "; this parent is "
-// << queryNode.Point() << " scale " << queryNode.Scale() << std::endl;
- Traverse(queryNode.Child(0), referenceMap);
}
if (queryNode.Scale() != INT_MIN)
@@ -129,7 +106,6 @@
Log::Assert((*referenceMap.begin()).first == INT_MIN);
Log::Assert(queryNode.Scale() == INT_MIN);
std::vector<MapEntryType>& pointVector = (*referenceMap.begin()).second;
-// Log::Debug << "Onto base case evaluations\n";
for (size_t i = 0; i < pointVector.size(); ++i)
{
@@ -138,36 +114,26 @@
CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
frame.referenceNode;
- const double oldScore = frame.score;
- const size_t refIndex = frame.referenceIndex;
- const size_t queryIndex = frame.queryIndex;
-// Log::Debug << "Consider query " << queryNode.Point() << ", reference "
-// << refNode->Point() << "\n";
-// Log::Debug << "Old score " << oldScore << " with refParent " << refIndex
-// << " and parent query node " << queryIndex << "\n";
- // First, ensure that we have not already calculated the base case.
- if ((refIndex == refNode->Point()) && (queryIndex == queryNode.Point()))
+ // If the point is the same as both parents, then we have already done this
+ // base case.
+ if ((refNode->Point() == refNode->Parent()->Point()) &&
+ (queryNode.Point() == queryNode.Parent()->Point()))
{
-// Log::Debug << "Pruned because we already did the base case and its value "
-// << " was " << frame.baseCase << std::endl;
++numPrunes;
continue;
}
- // Now, check if we can prune it.
- const double rescore = rule.Rescore(queryNode, *refNode, oldScore);
+ // Score the node, to see if we can prune it.
+ double score = rule.Score(queryNode, *refNode);
- if (rescore == DBL_MAX)
+ if (score == DBL_MAX)
{
-// Log::Debug << "Pruned after rescoring\n";
++numPrunes;
continue;
}
// If not, compute the base case.
-// Log::Debug << "Not pruned, performing base case " << queryNode.Point() <<
-// " " << pointVector[i].referenceNode->Point() << "\n";
rule.BaseCase(queryNode.Point(), pointVector[i].referenceNode->Point());
}
}
@@ -176,16 +142,12 @@
template<typename RuleType>
void CoverTree<MetricType, RootPointPolicy, StatisticType>::
DualTreeTraverser<RuleType>::PruneMap(
- CoverTree& /* queryNode */,
- CoverTree& candidateQueryNode,
+ CoverTree& queryNode,
std::map<int, std::vector<DualCoverTreeMapEntry<MetricType,
RootPointPolicy, StatisticType> > >& referenceMap,
std::map<int, std::vector<DualCoverTreeMapEntry<MetricType,
RootPointPolicy, StatisticType> > >& childMap)
{
-// Log::Debug << "Prep for recurse into query child point " <<
-// candidateQueryNode.Point() << " scale " <<
-// candidateQueryNode.Scale() << std::endl;
typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
MapEntryType;
@@ -194,15 +156,17 @@
typename std::map<int, std::vector<MapEntryType> >::reverse_iterator it =
referenceMap.rbegin();
- while ((it != referenceMap.rend()) && ((*it).first != INT_MIN))
+ while ((it != referenceMap.rend()))
{
// Get a reference to the vector representing the entries at this scale.
- const std::vector<MapEntryType>& scaleVector = (*it).second;
+ std::vector<MapEntryType>& scaleVector = (*it).second;
+
+ // Before traversing all the points in this scale, sort by score.
+ std::sort(scaleVector.begin(), scaleVector.end());
const int thisScale = (*it).first;
childMap[thisScale].reserve(scaleVector.size());
std::vector<MapEntryType>& newScaleVector = childMap[thisScale];
-// newScaleVector.reserve(scaleVector.size()); // Maximum possible size.
// Loop over each entry in the vector.
for (size_t j = 0; j < scaleVector.size(); ++j)
@@ -212,27 +176,9 @@
// First evaluate if we can prune without performing the base case.
CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
frame.referenceNode;
- const double oldScore = frame.score;
-
- // Try to prune based on shell(). This is hackish and will need to be
- // refined or cleaned at some point.
-// double score = rule.PrescoreQ(queryNode, candidateQueryNode, *refNode,
-// frame.baseCase);
-
-// if (score == DBL_MAX)
-// {
-// ++numPrunes;
-// continue;
-// }
-
-// Log::Debug << "Recheck reference node " << refNode->Point() <<
-// " scale " << refNode->Scale() << " which has old score " <<
-// oldScore << " with old reference index " << frame.referenceIndex
-// << " and old query index " << frame.queryIndex << std::endl;
-
- double score = rule.Rescore(candidateQueryNode, *refNode, oldScore);
-// Log::Debug << "Rescored as " << score << std::endl;
+ // Perform the actual scoring.
+ double score = rule.Score(queryNode, *refNode);
if (score == DBL_MAX)
{
@@ -241,29 +187,12 @@
continue;
}
- // Evaluate base case.
-// Log::Debug << "Must evaluate base case "
-// << candidateQueryNode.Point() << " " << refNode->Point()
-// << "\n";
- double baseCase = rule.BaseCase(candidateQueryNode.Point(),
- refNode->Point());
-// Log::Debug << "Base case was " << baseCase << std::endl;
-
- score = rule.Score(candidateQueryNode, *refNode, baseCase);
-
- if (score == DBL_MAX)
- {
- // Pruned. Move on.
- ++numPrunes;
- continue;
- }
+ // If it isn't pruned, we must evaluate the base case.
+ rule.BaseCase(queryNode.Point(), refNode->Point());
// Add to child map.
newScaleVector.push_back(frame);
newScaleVector.back().score = score;
- newScaleVector.back().baseCase = baseCase;
- newScaleVector.back().referenceIndex = refNode->Point();
- newScaleVector.back().queryIndex = candidateQueryNode.Point();
}
// If we didn't add anything, then strike this vector from the map.
@@ -272,108 +201,6 @@
++it; // Advance to next scale.
}
-
- childMap[INT_MIN] = referenceMap[INT_MIN];
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-template<typename RuleType>
-void CoverTree<MetricType, RootPointPolicy, StatisticType>::
-DualTreeTraverser<RuleType>::PruneMapForSelfChild(
- CoverTree& candidateQueryNode,
- std::map<int, std::vector<DualCoverTreeMapEntry<MetricType, RootPointPolicy,
- StatisticType> > >& referenceMap)
-{
-// Log::Debug << "Prep for recurse into query self-child point " <<
-// candidateQueryNode.Point() << " scale " <<
-// candidateQueryNode.Scale() << std::endl;
- typedef DualCoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
- MapEntryType;
-
- // Create the child reference map. We will do this by recursing through
- // every entry in the reference map and evaluating (or pruning) it. But
- // in this setting we do not recurse into any children of the reference
- // entries.
- if (referenceMap.empty())
- return; // Nothing to do.
- typename std::map<int, std::vector<MapEntryType> >::reverse_iterator it =
- referenceMap.rbegin();
-
- while (it != referenceMap.rend() && (*it).first != INT_MIN)
- {
- // Get a reference to the vector representing the entries at this scale.
- std::vector<MapEntryType>& newScaleVector = (*it).second;
- const std::vector<MapEntryType> scaleVector = newScaleVector;
-
- newScaleVector.clear();
- newScaleVector.reserve(scaleVector.size());
-
- // Loop over each entry in the vector.
- for (size_t j = 0; j < scaleVector.size(); ++j)
- {
- const MapEntryType& frame = scaleVector[j];
-
- // First evaluate if we can prune without performing the base case.
- CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
- frame.referenceNode;
- const double oldScore = frame.score;
- double baseCase = frame.baseCase;
- const size_t queryIndex = frame.queryIndex;
- const size_t refIndex = frame.referenceIndex;
-
-// Log::Debug << "Recheck reference node " << refNode->Point() << " scale "
-// << refNode->Scale() << " which has old score " << oldScore
-// << std::endl;
-
- // Have we performed the base case yet?
- double score;
- if ((refIndex != refNode->Point()) ||
- (queryIndex != candidateQueryNode.Point()))
- {
- // Attempt to rescore before performing the base case.
- score = rule.Rescore(candidateQueryNode, *refNode, oldScore);
-
- if (score == DBL_MAX)
- {
- ++numPrunes;
- continue;
- }
-
- // If not pruned, we have to perform the base case.
- baseCase = rule.BaseCase(candidateQueryNode.Point(), refNode->Point());
- }
-
- score = rule.Score(candidateQueryNode, *refNode, baseCase);
-
-// Log::Debug << "Rescored as " << score << std::endl;
-
- if (score == DBL_MAX)
- {
- // Pruned. Move on.
- ++numPrunes;
- continue;
- }
-
-// Log::Debug << "Kept in map\n";
-
- // Add to child map.
- newScaleVector.push_back(frame);
- newScaleVector.back().score = score;
- newScaleVector.back().baseCase = baseCase;
- newScaleVector.back().queryIndex = candidateQueryNode.Point();
- newScaleVector.back().referenceIndex = refNode->Point();
- }
-
- // If we didn't add anything, then strike this vector from the map.
- if (newScaleVector.size() == 0)
- {
- referenceMap.erase((*it).first);
- if (referenceMap.empty())
- break;
- }
-
- ++it; // Advance to next scale.
- }
}
template<typename MetricType, typename RootPointPolicy, typename StatisticType>
@@ -411,115 +238,30 @@
CoverTree<MetricType, RootPointPolicy, StatisticType>* refNode =
frame.referenceNode;
- const double score = frame.score;
- const size_t refIndex = frame.referenceIndex;
- const size_t refPoint = refNode->Point();
- const size_t queryIndex = frame.queryIndex;
- const size_t queryPoint = queryNode.Point();
- double baseCase = frame.baseCase;
-
-// Log::Debug << "Currently inspecting reference node " << refNode->Point()
-// << " scale " << refNode->Scale() << " earlier query index " <<
-// queryIndex << std::endl;
-
-// Log::Debug << "Old score " << score << " with refParent " << refIndex
-// << " and queryIndex " << queryIndex << "\n";
-
- // First we recalculate the score of this node to find if we can prune it.
- if (rule.Rescore(queryNode, *refNode, score) == DBL_MAX)
- {
-// Log::Warn << "Pruned after rescore\n";
- ++numPrunes;
- continue;
- }
-
- // If this is a self-child, the base case has already been evaluated.
- // We also must ensure that the base case was evaluated with this query
- // point.
- if ((refPoint != refIndex) || (queryPoint != queryIndex))
- {
-// Log::Warn << "Must evaluate base case " << queryNode.Point() << " "
-// << refPoint << "\n";
- baseCase = rule.BaseCase(queryPoint, refPoint);
-// Log::Debug << "Base case " << baseCase << std::endl;
- }
// Create the score for the children.
- double childScore = rule.Score(queryNode, *refNode, baseCase);
+ double score = rule.Score(queryNode, *refNode);
// Now if this childScore is DBL_MAX we can prune all children. In this
// recursion setup pruning is all or nothing for children.
- if (childScore == DBL_MAX)
+ if (score == DBL_MAX)
{
-// Log::Warn << "Pruned all children.\n";
- numPrunes += refNode->NumChildren();
+ ++numPrunes;
continue;
}
- // We must treat the self-leaf differently. The base case has already
- // been performed.
- childScore = rule.Score(queryNode, refNode->Child(0), baseCase);
+ // If it is not pruned, we must evaluate the base case.
+ rule.BaseCase(queryNode.Point(), refNode->Point());
- if (childScore != DBL_MAX)
+ // Add the children.
+ for (size_t j = 0; j < refNode->NumChildren(); ++j)
{
- MapEntryType newFrame;
- newFrame.referenceNode = &refNode->Child(0);
- newFrame.score = childScore;
- newFrame.baseCase = baseCase;
- newFrame.referenceIndex = refPoint;
- newFrame.queryIndex = queryNode.Point();
-
- referenceMap[newFrame.referenceNode->Scale()].push_back(newFrame);
- }
- else
- {
- ++numPrunes;
- }
-
- // Add the non-self-leaf children.
- for (size_t j = 1; j < refNode->NumChildren(); ++j)
- {
- const size_t queryIndex = queryNode.Point();
- const size_t refIndex = refNode->Child(j).Point();
-
- // We need to incorporate shell() here to try and avoid base case
- // computations. TODO
-// Log::Debug << "Prescore query " << queryNode.Point() << " scale "
-// << queryNode.Scale() << ", reference " << refNode->Point() <<
-// " scale " << refNode->Scale() << ", reference child " <<
-// refNode->Child(j).Point() << " scale " << refNode->Child(j).Scale()
-// << " with base case " << baseCase;
-// childScore = rule.Prescore(queryNode, *refNode, refNode->Child(j),
-// frame.baseCase);
-// Log::Debug << " and result " << childScore << ".\n";
-
-// if (childScore == DBL_MAX)
-// {
-// ++numPrunes;
-// continue;
-// }
-
- // Calculate the base case of each child.
- baseCase = rule.BaseCase(queryIndex, refIndex);
-
- // See if we can prune it.
- double childScore = rule.Score(queryNode, refNode->Child(j), baseCase);
-
- if (childScore == DBL_MAX)
- {
- ++numPrunes;
- continue;
- }
+// const size_t queryIndex = queryNode.Point();
+// const size_t refIndex = refNode->Child(j).Point();
MapEntryType newFrame;
newFrame.referenceNode = &refNode->Child(j);
- newFrame.score = childScore;
- newFrame.baseCase = baseCase;
- newFrame.referenceIndex = refIndex;
- newFrame.queryIndex = queryIndex;
-
-// Log::Debug << "Push onto map child " << refNode->Child(j).Point() <<
-// " scale " << refNode->Child(j).Scale() << std::endl;
+ newFrame.score = score; // Use the score of the parent.
referenceMap[newFrame.referenceNode->Scale()].push_back(newFrame);
}
@@ -528,7 +270,6 @@
// Now clear the memory for this scale; it isn't needed anymore.
referenceMap.erase((*referenceMap.rbegin()).first);
}
-
}
}; // namespace tree
Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp Fri Sep 13 16:09:48 2013
@@ -64,12 +64,8 @@
// largest scale.
std::map<int, std::vector<MapEntryType> > mapQueue;
- // Manually add the children of the first node. These cannot be pruned
- // anyway.
- double rootBaseCase = rule.BaseCase(queryIndex, referenceNode.Point());
-
// Create the score for the children.
- double rootChildScore = rule.Score(queryIndex, referenceNode, rootBaseCase);
+ double rootChildScore = rule.Score(queryIndex, referenceNode);
if (rootChildScore == DBL_MAX)
{
@@ -77,6 +73,12 @@
}
else
{
+ // Manually add the children of the first node.
+ // Often, a ruleset will return without doing any computation on cover trees
+ // using TreeTraits::FirstPointIsCentroid; this is an optimization that
+ // (theoretically) the compiler should get right.
+ double rootBaseCase = rule.BaseCase(queryIndex, referenceNode.Point());
+
// Don't add the self-leaf.
size_t i = 0;
if (referenceNode.Child(0).NumChildren() == 0)
@@ -132,12 +134,8 @@
continue;
}
- // If we are a self-child, the base case has already been evaluated.
- if (point != parent)
- baseCase = rule.BaseCase(queryIndex, point);
-
// Create the score for the children.
- const double childScore = rule.Score(queryIndex, *node, baseCase);
+ const double childScore = rule.Score(queryIndex, *node);
// Now if this childScore is DBL_MAX we can prune all children. In this
// recursion setup pruning is all or nothing for children.
@@ -147,6 +145,13 @@
continue;
}
+ // If we are a self-child, the base case has already been evaluated.
+ // Often, a ruleset will return without doing any computation on cover
+ // trees using TreeTraits::FirstPointIsCentroid; this is an optimization
+ // that (theoretically) the compiler should get right.
+ if (point != parent)
+ baseCase = rule.BaseCase(queryIndex, point);
+
// Don't add the self-leaf.
size_t j = 0;
if (node->Child(0).NumChildren() == 0)
@@ -181,16 +186,32 @@
const size_t point = node->Point();
// First, recalculate the score of this node to find if we can prune it.
- double actualScore = rule.Rescore(queryIndex, *node, score);
+ double rescore = rule.Rescore(queryIndex, *node, score);
- if (actualScore == DBL_MAX)
+ if (rescore == DBL_MAX)
{
++numPrunes;
continue;
}
- // There are no self-leaves; evaluate the base case.
- rule.BaseCase(queryIndex, point);
+ // For this to be a valid dual-tree algorithm, we *must* evaluate the
+ // combination, even if pruning it will make no difference. It's the
+ // definition.
+ const double actualScore = rule.Score(queryIndex, *node);
+
+ if (actualScore == DBL_MAX)
+ {
+ ++numPrunes;
+ continue;
+ }
+ else
+ {
+ // Evaluate the base case, since the combination was not pruned.
+ // Often, a ruleset will return without doing any computation on cover
+ // trees using TreeTraits::FirstPointIsCentroid; this is an optimization
+ // that (theoretically) the compiler should get right.
+ rule.BaseCase(queryIndex, point);
+ }
}
}
More information about the mlpack-svn mailing list