[mlpack-svn] r14540 - mlpack/trunk/src/mlpack/methods/neighbor_search
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Mar 13 15:40:37 EDT 2013
Author: rcurtin
Date: 2013-03-13 15:40:36 -0400 (Wed, 13 Mar 2013)
New Revision: 14540
Modified:
mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp
mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
Log:
Fix bug in bound function for neighbor search. It is necessary to keep two
separate bounds for each node.
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp 2013-03-13 18:44:09 UTC (rev 14539)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/allkfn_main.cpp 2013-03-13 19:40:36 UTC (rev 14540)
@@ -176,13 +176,12 @@
// Map the points back to their original locations.
if ((CLI::GetParam<string>("query_file") != "") && !singleMode)
Unmap(neighbors, distances, oldFromNewRefs, oldFromNewQueries, neighborsOut,
- distancesOut, true);
+ distancesOut);
else if ((CLI::GetParam<string>("query_file") != "") && singleMode)
- Unmap(neighbors, distances, oldFromNewRefs, neighborsOut, distancesOut,
- true);
+ Unmap(neighbors, distances, oldFromNewRefs, neighborsOut, distancesOut);
else
Unmap(neighbors, distances, oldFromNewRefs, oldFromNewRefs, neighborsOut,
- distancesOut, true);
+ distancesOut);
// Clean up.
if (queryTree)
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp 2013-03-13 18:44:09 UTC (rev 14539)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/allknn_main.cpp 2013-03-13 19:40:36 UTC (rev 14540)
@@ -238,13 +238,12 @@
// Map the results back to the correct places.
if ((CLI::GetParam<string>("query_file") != "") && !singleMode)
Unmap(neighborsOut, distancesOut, oldFromNewRefs, oldFromNewQueries,
- neighbors, distances, true);
+ neighbors, distances);
else if ((CLI::GetParam<string>("query_file") != "") && singleMode)
- Unmap(neighborsOut, distancesOut, oldFromNewRefs, neighbors, distances,
- true);
+ Unmap(neighborsOut, distancesOut, oldFromNewRefs, neighbors, distances);
else
Unmap(neighborsOut, distancesOut, oldFromNewRefs, oldFromNewRefs,
- neighbors, distances, true);
+ neighbors, distances);
// Clean up.
if (queryTree)
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp 2013-03-13 18:44:09 UTC (rev 14539)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search.hpp 2013-03-13 19:40:36 UTC (rev 14540)
@@ -30,7 +30,15 @@
class QueryStat
{
private:
- //! The bound on the node's neighbor distances.
+ //! The first bound on the node's neighbor distances (B_1). This represents
+ //! the worst candidate distance of any descendants of this node.
+ double firstBound;
+ //! The second bound on the node's neighbor distances (B_2). This represents
+ //! a bound on the worst distance of any descendants of this node assembled
+ //! using the best descendant candidate distance modified by the furthest
+ //! descendant distance.
+ double secondBound;
+ //! The better of the two bounds.
double bound;
public:
@@ -38,19 +46,32 @@
* Initialize the statistic with the worst possible distance according to
* our sorting policy.
*/
- QueryStat() : bound(SortPolicy::WorstDistance()) { }
+ QueryStat() :
+ firstBound(SortPolicy::WorstDistance()),
+ secondBound(SortPolicy::WorstDistance()),
+ bound(SortPolicy::WorstDistance()) { }
/**
* Initialization for a fully initialized node. In this case, we don't need
* to worry about the node.
*/
template<typename TreeType>
- QueryStat(TreeType& /* node */)
- : bound(SortPolicy::WorstDistance()) { }
+ QueryStat(TreeType& /* node */) :
+ firstBound(SortPolicy::WorstDistance()),
+ secondBound(SortPolicy::WorstDistance()),
+ bound(SortPolicy::WorstDistance()) { }
- //! Get the bound.
+ //! Get the first bound.
+ double FirstBound() const { return firstBound; }
+ //! Modify the first bound.
+ double& FirstBound() { return firstBound; }
+ //! Get the second bound.
+ double SecondBound() const { return secondBound; }
+ //! Modify the second bound.
+ double& SecondBound() { return secondBound; }
+ //! Get the overall bound (the better of the two bounds).
double Bound() const { return bound; }
- //! Modify the bound.
+ //! Modify the overall bound (it should be the better of the two bounds).
double& Bound() { return bound; }
};
Modified: mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp 2013-03-13 18:44:09 UTC (rev 14539)
+++ mlpack/trunk/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp 2013-03-13 19:40:36 UTC (rev 14540)
@@ -43,7 +43,7 @@
// If this distance is better than any of the current candidates, the
// SortDistance() function will give us the position to insert it into.
arma::vec queryDist = distances.unsafe_col(queryIndex);
- size_t insertPosition = SortPolicy::SortDistance(queryDist, distance);
+ const size_t insertPosition = SortPolicy::SortDistance(queryDist, distance);
// SortDistance() returns (size_t() - 1) if we shouldn't add it.
if (insertPosition != (size_t() - 1))
@@ -63,8 +63,7 @@
&referenceNode, &referenceChildNode, baseCaseResult);
// Update our bound.
- queryNode.Stat().Bound() = CalculateBound(queryNode);
- const double bestDistance = queryNode.Stat().Bound();
+ const double bestDistance = CalculateBound(queryNode);
return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
}
@@ -80,8 +79,7 @@
&queryNode, &queryChildNode, baseCaseResult);
// Update our bound.
- queryNode.Stat().Bound() = CalculateBound(queryNode);
- const double bestDistance = queryNode.Stat().Bound();
+ const double bestDistance = CalculateBound(queryNode);
return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
}
@@ -138,8 +136,7 @@
&referenceNode);
// Update our bound.
- queryNode.Stat().Bound() = CalculateBound(queryNode);
- const double bestDistance = queryNode.Stat().Bound();
+ const double bestDistance = CalculateBound(queryNode);
return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
}
@@ -154,8 +151,7 @@
&referenceNode, baseCaseResult);
// Update our bound.
- queryNode.Stat().Bound() = CalculateBound(queryNode);
- const double bestDistance = queryNode.Stat().Bound();
+ const double bestDistance = CalculateBound(queryNode);
return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
}
@@ -170,40 +166,18 @@
return oldScore;
// Update our bound.
- queryNode.Stat().Bound() = CalculateBound(queryNode);
- const double bestDistance = queryNode.Stat().Bound();
+ const double bestDistance = CalculateBound(queryNode);
return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
}
-/*
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::FinishNode(
- TreeType& queryNode) const
-{
- // Find the bound of points contained in this node.
- double pointBound = SortPolicy::BestDistance();
- const double maxDescendantDistance = queryNode.FurthestDescendantDistance();
- for (size_t i = 0; i < queryNode.NumPoints(); ++i)
- {
- // The bound for this point is the k-th best distance plus the maximum
- // distance to a child of this node.
- const double bound = distances(distances.n_rows - 1, queryNode.Point(i)) +
- maxDescendantDistance;
- if (SortPolicy::IsBetter(pointBound, bound))
- pointBound = bound;
- }
-
- // Push bound to parent.
-}
-*/
-
-// Calculate the bound for a given query node in its current state.
+// Calculate the bound for a given query node in its current state and store
+// the updated first, second, and overall bounds in the node's statistic.
template<typename SortPolicy, typename MetricType, typename TreeType>
inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
CalculateBound(TreeType& queryNode) const
{
- // We have four possible bounds, and we must take the best of them all. We
+ // We have five possible bounds, and we must take the best of them all. We
// don't use min/max here, but instead "best/worst", because this is general
// to the nearest-neighbors/furthest-neighbors cases. For nearest neighbors,
// min = best, max = worst.
@@ -215,11 +189,12 @@
// (3) best_{all children c in queryNode} B(c) +
// 2 ( worst descendant distance of queryNode -
// worst descendant distance of c );
- // (4) B(parent of queryNode);
+ // (4) B_1(parent of queryNode);
+ // (5) B_2(parent of queryNode);
//
// D_p[k] is the current k'th candidate distance for point p.
// So we will loop over the points in queryNode and the children in queryNode
- // to calculate all four of these quantities.
+ // to calculate all five of these quantities.
double worstPointDistance = SortPolicy::BestDistance();
double bestPointDistance = SortPolicy::WorstDistance();
@@ -245,15 +220,16 @@
for (size_t i = 0; i < queryNode.NumChildren(); ++i)
{
- const double bound = queryNode.Child(i).Stat().Bound();
+ const double firstBound = queryNode.Child(i).Stat().FirstBound();
+ const double secondBound = queryNode.Child(i).Stat().SecondBound();
const double childMaxDescendantDistance =
queryNode.Child(i).FurthestDescendantDistance();
- if (SortPolicy::IsBetter(worstChildBound, bound))
- worstChildBound = bound;
+ if (SortPolicy::IsBetter(worstChildBound, firstBound))
+ worstChildBound = firstBound;
// Now calculate adjustment for maximum descendant distances.
- const double adjustedBound = SortPolicy::CombineWorst(bound,
+ const double adjustedBound = SortPolicy::CombineWorst(secondBound,
2 * (queryMaxDescendantDistance - childMaxDescendantDistance));
if (SortPolicy::IsBetter(adjustedBound, bestAdjustedChildBound))
bestAdjustedChildBound = adjustedBound;
@@ -269,18 +245,38 @@
2 * queryMaxDescendantDistance);
// Bound (3) is bestAdjustedChildBound.
+
+ // Bounds (4) and (5) are the parent bounds.
const double fourthBound = (queryNode.Parent() != NULL) ?
- queryNode.Parent()->Stat().Bound() : SortPolicy::WorstDistance();
+ queryNode.Parent()->Stat().FirstBound() : SortPolicy::WorstDistance();
+ const double fifthBound = (queryNode.Parent() != NULL) ?
+ queryNode.Parent()->Stat().SecondBound() : SortPolicy::WorstDistance();
// Now, we will take the best of these. Unfortunately due to the way
- // IsBetter() is defined, this sort of has to be a little ugly...
- const double interA = (SortPolicy::IsBetter(firstBound, secondBound)) ?
- firstBound : secondBound;
+ // IsBetter() is defined, this sort of has to be a little ugly.
+ // The variable interA represents the first bound (B_1), which is the worst
+ // candidate distance of any descendants of this node.
+ // The variable interC represents the second bound (B_2), which is a bound on
+ // the worst distance of any descendants of this node assembled using the best
+ // descendant candidate distance modified using the furthest descendant
+ // distance.
+ const double interA = (SortPolicy::IsBetter(firstBound, fourthBound)) ?
+ firstBound : fourthBound;
const double interB =
- (SortPolicy::IsBetter(bestAdjustedChildBound, fourthBound)) ?
- bestAdjustedChildBound : fourthBound;
+ (SortPolicy::IsBetter(bestAdjustedChildBound, secondBound)) ?
+ bestAdjustedChildBound : secondBound;
+ const double interC = (SortPolicy::IsBetter(interB, fifthBound)) ? interB :
+ fifthBound;
- return (SortPolicy::IsBetter(interA, interB)) ? interA : interB;
+ // Update the first and second bounds of the node.
+ queryNode.Stat().FirstBound() = interA;
+ queryNode.Stat().SecondBound() = interC;
+
+ // Update the actual bound of the node.
+ queryNode.Stat().Bound() = (SortPolicy::IsBetter(interA, interC)) ? interA :
+ interC;
+
+ return queryNode.Stat().Bound();
}
/**
More information about the mlpack-svn
mailing list