[mlpack-git] master: Use a priority queue (heap) to store the list of candidates while searching. (31aeb5e)
gitdub at mlpack.org
gitdub at mlpack.org
Tue Jul 26 21:22:08 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/ef51b032f275266f781d42b9bd0aa50aa26a3077...8522b04c3d9a82fb7e964bafd72e70f0cd30bf4b
>---------------------------------------------------------------
commit 31aeb5e813413505150a14be5c08979920c0d261
Author: MarcosPividori <marcos.pividori at gmail.com>
Date: Thu Jul 21 12:52:40 2016 -0300
Use a priority queue (heap) to store the list of candidates while searching.
>---------------------------------------------------------------
31aeb5e813413505150a14be5c08979920c0d261
src/mlpack/methods/rann/ra_search_impl.hpp | 35 +++++----
src/mlpack/methods/rann/ra_search_rules.hpp | 52 +++++++++++---
src/mlpack/methods/rann/ra_search_rules_impl.hpp | 90 ++++++++++++++----------
3 files changed, 112 insertions(+), 65 deletions(-)
diff --git a/src/mlpack/methods/rann/ra_search_impl.hpp b/src/mlpack/methods/rann/ra_search_impl.hpp
index aa8daa5..16360b5 100644
--- a/src/mlpack/methods/rann/ra_search_impl.hpp
+++ b/src/mlpack/methods/rann/ra_search_impl.hpp
@@ -360,9 +360,8 @@ Search(const MatType& querySet,
if (naive)
{
- RuleType rules(*referenceSet, querySet, *neighborPtr, *distancePtr, metric,
- tau, alpha, naive, sampleAtLeaves, firstLeafExact,
- singleSampleLimit, false);
+ RuleType rules(*referenceSet, querySet, k, metric, tau, alpha, naive,
+ sampleAtLeaves, firstLeafExact, singleSampleLimit, false);
// Find how many samples from the reference set we need and sample uniformly
// from the reference set without replacement.
@@ -377,12 +376,13 @@ Search(const MatType& querySet,
for (size_t i = 0; i < querySet.n_cols; ++i)
for (size_t j = 0; j < distinctSamples.n_elem; ++j)
rules.BaseCase(i, (size_t) distinctSamples[j]);
+
+ rules.GetResults(*neighborPtr, *distancePtr);
}
else if (singleMode)
{
- RuleType rules(*referenceSet, querySet, *neighborPtr, *distancePtr, metric,
- tau, alpha, naive, sampleAtLeaves, firstLeafExact,
- singleSampleLimit, false);
+ RuleType rules(*referenceSet, querySet, k, metric, tau, alpha, naive,
+ sampleAtLeaves, firstLeafExact, singleSampleLimit, false);
// If the reference root node is a leaf, then the sampling has already been
// done in the RASearchRules constructor. This happens when naive = true.
@@ -402,6 +402,8 @@ Search(const MatType& querySet,
<< (rules.NumDistComputations() / querySet.n_cols) << "."
<< std::endl;
}
+
+ rules.GetResults(*neighborPtr, *distancePtr);
}
else // Dual-tree recursion.
{
@@ -415,9 +417,8 @@ Search(const MatType& querySet,
Timer::Stop("tree_building");
Timer::Start("computing_neighbors");
- RuleType rules(*referenceSet, queryTree->Dataset(), *neighborPtr,
- *distancePtr, metric, tau, alpha, naive, sampleAtLeaves,
- firstLeafExact, singleSampleLimit, false);
+ RuleType rules(*referenceSet, queryTree->Dataset(), k, metric, tau, alpha,
+ naive, sampleAtLeaves, firstLeafExact, singleSampleLimit, false);
typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
Log::Info << "Query statistic pre-search: "
@@ -429,6 +430,8 @@ Search(const MatType& querySet,
Log::Info << "Average number of distance calculations per query point: "
<< (rules.NumDistComputations() / querySet.n_cols) << "." << std::endl;
+ rules.GetResults(*neighborPtr, *distancePtr);
+
delete queryTree;
}
@@ -529,14 +532,15 @@ void RASearch<SortPolicy, MetricType, MatType, TreeType>::Search(
// Create the helper object for the tree traversal.
typedef RASearchRules<SortPolicy, MetricType, Tree> RuleType;
- RuleType rules(*referenceSet, queryTree->Dataset(), *neighborPtr, distances,
- metric, tau, alpha, naive, sampleAtLeaves, firstLeafExact,
- singleSampleLimit, false);
+ RuleType rules(*referenceSet, queryTree->Dataset(), k, metric, tau, alpha,
+ naive, sampleAtLeaves, firstLeafExact, singleSampleLimit, false);
// Create the traverser.
typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
traverser.Traverse(*queryTree, *referenceTree);
+ rules.GetResults(*neighborPtr, distances);
+
Timer::Stop("computing_neighbors");
// Do we need to map indices?
@@ -586,9 +590,8 @@ void RASearch<SortPolicy, MetricType, MatType, TreeType>::Search(
// Create the helper object for the tree traversal.
typedef RASearchRules<SortPolicy, MetricType, Tree> RuleType;
- RuleType rules(*referenceSet, *referenceSet, *neighborPtr, *distancePtr,
- metric, tau, alpha, naive, sampleAtLeaves, firstLeafExact,
- singleSampleLimit, true /* sets are the same */);
+ RuleType rules(*referenceSet, *referenceSet, k, metric, tau, alpha, naive,
+ sampleAtLeaves, firstLeafExact, singleSampleLimit, true /* same sets */);
if (naive)
{
@@ -622,6 +625,8 @@ void RASearch<SortPolicy, MetricType, MatType, TreeType>::Search(
traverser.Traverse(*referenceTree, *referenceTree);
}
+ rules.GetResults(*neighborPtr, *distancePtr);
+
Timer::Stop("computing_neighbors");
// Do we need to map the reference indices?
diff --git a/src/mlpack/methods/rann/ra_search_rules.hpp b/src/mlpack/methods/rann/ra_search_rules.hpp
index 1037af4..b04f9bc 100644
--- a/src/mlpack/methods/rann/ra_search_rules.hpp
+++ b/src/mlpack/methods/rann/ra_search_rules.hpp
@@ -10,6 +10,8 @@
#define MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP
#include <mlpack/core/tree/traversal_info.hpp>
+#include <vector>
+#include <queue>
namespace mlpack {
namespace neighbor {
@@ -20,8 +22,7 @@ class RASearchRules
public:
RASearchRules(const arma::mat& referenceSet,
const arma::mat& querySet,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances,
+ const size_t k,
MetricType& metric,
const double tau = 5,
const double alpha = 0.95,
@@ -31,6 +32,15 @@ class RASearchRules
const size_t singleSampleLimit = 20,
const bool sameSet = false);
+ /**
+ * Store the list of candidates for each query point in the given matrices.
+ *
+ * @param neighbors Matrix storing lists of neighbors for each query point.
+ * @param distances Matrix storing distances of neighbors for each query
+ * point.
+ */
+ void GetResults(arma::Mat<size_t>& neighbors, arma::mat& distances);
+
double BaseCase(const size_t queryIndex, const size_t referenceIndex);
/**
@@ -197,11 +207,34 @@ class RASearchRules
//! The query set.
const arma::mat& querySet;
- //! The matrix the resultant neighbor indices should be stored in.
- arma::Mat<size_t>& neighbors;
-
- //! The matrix the resultant neighbor distances should be stored in.
- arma::mat& distances;
+ //! Candidate represents a possible candidate neighbor (from the reference
+ // set).
+ struct Candidate
+ {
+ //! Distance between the reference point and the query point.
+ double dist;
+ //! Index of the reference point.
+ size_t index;
+ //! Trivial constructor.
+ Candidate(double d, size_t i) :
+ dist(d),
+ index(i)
+ {};
+ //! Compare the distance of two candidates.
+ friend bool operator<(const Candidate& l, const Candidate& r)
+ {
+ return !SortPolicy::IsBetter(r.dist, l.dist);
+ };
+ };
+
+ //! Use a priority queue to represent the list of candidate neighbors.
+ typedef std::priority_queue<Candidate> CandidateList;
+
+ //! Set of candidate neighbors for each point.
+ std::vector<CandidateList> candidates;
+
+ //! Number of neighbors to search for.
+ const size_t k;
//! The instantiated metric.
MetricType& metric;
@@ -233,16 +266,13 @@ class RASearchRules
TraversalInfoType traversalInfo;
/**
- * Insert a point into the neighbors and distances matrices; this is a helper
- * function.
+ * Helper function to insert a point into the list of candidate points.
*
* @param queryIndex Index of point whose neighbors we are inserting into.
- * @param pos Position in list to insert into.
* @param neighbor Index of reference point which is being inserted.
* @param distance Distance from query point to reference point.
*/
void InsertNeighbor(const size_t queryIndex,
- const size_t pos,
const size_t neighbor,
const double distance);
diff --git a/src/mlpack/methods/rann/ra_search_rules_impl.hpp b/src/mlpack/methods/rann/ra_search_rules_impl.hpp
index 2071de1..bad3e24 100644
--- a/src/mlpack/methods/rann/ra_search_rules_impl.hpp
+++ b/src/mlpack/methods/rann/ra_search_rules_impl.hpp
@@ -17,8 +17,7 @@ template<typename SortPolicy, typename MetricType, typename TreeType>
RASearchRules<SortPolicy, MetricType, TreeType>::
RASearchRules(const arma::mat& referenceSet,
const arma::mat& querySet,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances,
+ const size_t k,
MetricType& metric,
const double tau,
const double alpha,
@@ -29,8 +28,7 @@ RASearchRules(const arma::mat& referenceSet,
const bool sameSet) :
referenceSet(referenceSet),
querySet(querySet),
- neighbors(neighbors),
- distances(distances),
+ k(k),
metric(metric),
sampleAtLeaves(sampleAtLeaves),
firstLeafExact(firstLeafExact),
@@ -42,7 +40,6 @@ RASearchRules(const arma::mat& referenceSet,
// The rank approximation.
const size_t n = referenceSet.n_cols;
- const size_t k = neighbors.n_rows;
const size_t t = (size_t) std::ceil(tau * (double) n / 100.0);
if (t < k)
{
@@ -68,7 +65,20 @@ RASearchRules(const arma::mat& referenceSet,
Log::Info << "Minimum samples required per query: " << numSamplesReqd <<
", sampling ratio: " << samplingRatio << std::endl;
- if (naive) // No tree traversal; just do naive sampling here.
+ // Let's build the list of candidate neighbors for each query point.
+ // It will be initialized with k candidates: (WorstDistance, size_t() - 1)
+ // The list of candidates will be updated when visiting new points with the
+ // BaseCase() method.
+ const Candidate def(SortPolicy::WorstDistance(), size_t() - 1);
+
+ std::vector<Candidate> vect(k, def);
+ CandidateList pqueue(std::less<Candidate>(), std::move(vect));
+
+ candidates.reserve(querySet.n_cols);
+ for (size_t i = 0; i < querySet.n_cols; i++)
+ candidates.push_back(pqueue);
+
+ if (naive)// No tree traversal; just do naive sampling here.
{
// Sample enough points.
for (size_t i = 0; i < querySet.n_cols; ++i)
@@ -82,6 +92,26 @@ RASearchRules(const arma::mat& referenceSet,
}
template<typename SortPolicy, typename MetricType, typename TreeType>
+void RASearchRules<SortPolicy, MetricType, TreeType>::GetResults(
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances)
+{
+ neighbors.set_size(k, querySet.n_cols);
+ distances.set_size(k, querySet.n_cols);
+
+ for (size_t i = 0; i < querySet.n_cols; i++)
+ {
+ CandidateList& pqueue = candidates[i];
+ for (size_t j = 1; j <= k; j++)
+ {
+ neighbors(k - j, i) = pqueue.top().index;
+ distances(k - j, i) = pqueue.top().dist;
+ pqueue.pop();
+ }
+ }
+};
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
inline force_inline
double RASearchRules<SortPolicy, MetricType, TreeType>::BaseCase(
const size_t queryIndex,
@@ -95,16 +125,7 @@ double RASearchRules<SortPolicy, MetricType, TreeType>::BaseCase(
double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
referenceSet.unsafe_col(referenceIndex));
- // If this distance is better than any of the current candidates, the
- // SortDistance() function will give us the position to insert it into.
- arma::vec queryDist = distances.unsafe_col(queryIndex);
- arma::Col<size_t> queryIndices = neighbors.unsafe_col(queryIndex);
- size_t insertPosition = SortPolicy::SortDistance(queryDist, queryIndices,
- distance);
-
- // SortDistance() returns (size_t() - 1) if we shouldn't add it.
- if (insertPosition != (size_t() - 1))
- InsertNeighbor(queryIndex, insertPosition, referenceIndex, distance);
+ InsertNeighbor(queryIndex, referenceIndex, distance);
numSamplesMade[queryIndex]++;
@@ -122,7 +143,7 @@ inline double RASearchRules<SortPolicy, MetricType, TreeType>::Score(
const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
const double distance = SortPolicy::BestPointToNodeDistance(queryPoint,
&referenceNode);
- const double bestDistance = distances(distances.n_rows - 1, queryIndex);
+ const double bestDistance = candidates[queryIndex].top().dist;
return Score(queryIndex, referenceNode, distance, bestDistance);
}
@@ -136,7 +157,7 @@ inline double RASearchRules<SortPolicy, MetricType, TreeType>::Score(
const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
const double distance = SortPolicy::BestPointToNodeDistance(queryPoint,
&referenceNode, baseCaseResult);
- const double bestDistance = distances(distances.n_rows - 1, queryIndex);
+ const double bestDistance = candidates[queryIndex].top().dist;
return Score(queryIndex, referenceNode, distance, bestDistance);
}
@@ -250,7 +271,7 @@ Rescore(const size_t queryIndex,
return oldScore;
// Just check the score again against the distances.
- const double bestDistance = distances(distances.n_rows - 1, queryIndex);
+ const double bestDistance = candidates[queryIndex].top().dist;
// If this is better than the best distance we've seen so far,
// maybe there will be something down this node.
@@ -350,7 +371,7 @@ inline double RASearchRules<SortPolicy, MetricType, TreeType>::Score(
for (size_t i = 0; i < queryNode.NumPoints(); i++)
{
- const double bound = distances(distances.n_rows - 1, queryNode.Point(i))
+ const double bound = candidates[queryNode.Point(i)].top().dist
+ maxDescendantDistance;
if (bound < pointBound)
pointBound = bound;
@@ -389,7 +410,7 @@ inline double RASearchRules<SortPolicy, MetricType, TreeType>::Score(
for (size_t i = 0; i < queryNode.NumPoints(); i++)
{
- const double bound = distances(distances.n_rows - 1, queryNode.Point(i))
+ const double bound = candidates[queryNode.Point(i)].top().dist
+ maxDescendantDistance;
if (bound < pointBound)
pointBound = bound;
@@ -603,7 +624,7 @@ Rescore(TreeType& queryNode,
for (size_t i = 0; i < queryNode.NumPoints(); i++)
{
- const double bound = distances(distances.n_rows - 1, queryNode.Point(i))
+ const double bound = candidates[queryNode.Point(i)].top().dist
+ maxDescendantDistance;
if (bound < pointBound)
pointBound = bound;
@@ -775,35 +796,26 @@ Rescore(TreeType& queryNode,
} // Rescore(node, node, oldScore)
/**
- * Helper function to insert a point into the neighbors and distances matrices.
+ * Helper function to insert a point into the list of candidate points.
*
* @param queryIndex Index of point whose neighbors we are inserting into.
- * @param pos Position in list to insert into.
* @param neighbor Index of reference point which is being inserted.
* @param distance Distance from query point to reference point.
*/
template<typename SortPolicy, typename MetricType, typename TreeType>
-void RASearchRules<SortPolicy, MetricType, TreeType>::InsertNeighbor(
+inline void RASearchRules<SortPolicy, MetricType, TreeType>::
+InsertNeighbor(
const size_t queryIndex,
- const size_t pos,
const size_t neighbor,
const double distance)
{
- // We only memmove() if there is actually a need to shift something.
- if (pos < (distances.n_rows - 1))
+ Candidate c(distance, neighbor);
+ CandidateList& pqueue = candidates[queryIndex];
+ if (c < pqueue.top())
{
- int len = (distances.n_rows - 1) - pos;
- memmove(distances.colptr(queryIndex) + (pos + 1),
- distances.colptr(queryIndex) + pos,
- sizeof(double) * len);
- memmove(neighbors.colptr(queryIndex) + (pos + 1),
- neighbors.colptr(queryIndex) + pos,
- sizeof(size_t) * len);
+ pqueue.pop();
+ pqueue.push(c);
}
-
- // Now put the new information in the right index.
- distances(pos, queryIndex) = distance;
- neighbors(pos, queryIndex) = neighbor;
}
} // namespace neighbor
More information about the mlpack-git mailing list