[mlpack-git] master: Use a priority queue (heap) to store the list of candidates while searching. This makes the code more efficient, especially for larger values of k. For example, for k-nearest-neighbor search, given a list of k candidate neighbors, we need two operations to be fast: - find the furthest of them. - insert a new candidate. This is exactly the situation a heap is designed for. (198cec8)
gitdub at mlpack.org
gitdub at mlpack.org
Tue Jul 26 21:22:08 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/ef51b032f275266f781d42b9bd0aa50aa26a3077...8522b04c3d9a82fb7e964bafd72e70f0cd30bf4b
>---------------------------------------------------------------
commit 198cec80a434b3d88a993e3d67a18b778ebc07f1
Author: MarcosPividori <marcos.pividori at gmail.com>
Date: Thu Jul 21 10:52:06 2016 -0300
Use a priority queue (heap) to store the list of candidates while searching.
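This makes the code more efficient, especially for larger values of k.
For example, for k-nearest-neighbor search, given a list of k candidate
neighbors, we need two operations to be fast:
- find the furthest of them;
- insert a new candidate (replacing the furthest one).
A max-heap is the appropriate structure for this: the furthest candidate
is always at the top (O(1) access), and replacing it costs O(log k),
instead of the O(k) shift required to keep a sorted array in order.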
>---------------------------------------------------------------
198cec80a434b3d88a993e3d67a18b778ebc07f1
.../neighbor_search/neighbor_search_impl.hpp | 26 ++++---
.../neighbor_search/neighbor_search_rules.hpp | 53 +++++++++++---
.../neighbor_search/neighbor_search_rules_impl.hpp | 81 +++++++++++++---------
3 files changed, 105 insertions(+), 55 deletions(-)
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index 73560e2..8d0c694 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -376,8 +376,7 @@ Search(const MatType& querySet,
if (naive)
{
// Create the helper object for the tree traversal.
- RuleType rules(*referenceSet, querySet, *neighborPtr, *distancePtr, metric,
- epsilon);
+ RuleType rules(*referenceSet, querySet, k, metric, epsilon);
// The naive brute-force traversal.
for (size_t i = 0; i < querySet.n_cols; ++i)
@@ -385,12 +384,13 @@ Search(const MatType& querySet,
rules.BaseCase(i, j);
baseCases += querySet.n_cols * referenceSet->n_cols;
+
+ rules.GetResults(*neighborPtr, *distancePtr);
}
else if (singleMode)
{
// Create the helper object for the tree traversal.
- RuleType rules(*referenceSet, querySet, *neighborPtr, *distancePtr, metric,
- epsilon);
+ RuleType rules(*referenceSet, querySet, k, metric, epsilon);
// Create the traverser.
typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
@@ -404,6 +404,8 @@ Search(const MatType& querySet,
Log::Info << rules.Scores() << " node combinations were scored.\n";
Log::Info << rules.BaseCases() << " base cases were calculated.\n";
+
+ rules.GetResults(*neighborPtr, *distancePtr);
}
else // Dual-tree recursion.
{
@@ -415,8 +417,7 @@ Search(const MatType& querySet,
Timer::Start("computing_neighbors");
// Create the helper object for the tree traversal.
- RuleType rules(*referenceSet, queryTree->Dataset(), *neighborPtr,
- *distancePtr, metric, epsilon);
+ RuleType rules(*referenceSet, queryTree->Dataset(), k, metric, epsilon);
// Create the traverser.
TraversalType<RuleType> traverser(rules);
@@ -429,6 +430,8 @@ Search(const MatType& querySet,
Log::Info << rules.Scores() << " node combinations were scored.\n";
Log::Info << rules.BaseCases() << " base cases were calculated.\n";
+ rules.GetResults(*neighborPtr, *distancePtr);
+
delete queryTree;
}
@@ -541,8 +544,7 @@ Search(Tree* queryTree,
// Create the helper object for the traversal.
typedef NeighborSearchRules<SortPolicy, MetricType, Tree> RuleType;
- RuleType rules(*referenceSet, querySet, *neighborPtr, distances, metric,
- epsilon);
+ RuleType rules(*referenceSet, querySet, k, metric, epsilon);
// Create the traverser.
TraversalType<RuleType> traverser(rules);
@@ -551,6 +553,8 @@ Search(Tree* queryTree,
scores += rules.Scores();
baseCases += rules.BaseCases();
+ rules.GetResults(*neighborPtr, distances);
+
Timer::Stop("computing_neighbors");
// Do we need to map indices?
@@ -612,8 +616,8 @@ Search(const size_t k,
// Create the helper object for the traversal.
typedef NeighborSearchRules<SortPolicy, MetricType, Tree> RuleType;
- RuleType rules(*referenceSet, *referenceSet, *neighborPtr, *distancePtr,
- metric, epsilon, true /* don't return the same point as nearest neighbor */);
+ RuleType rules(*referenceSet, *referenceSet, k, metric, epsilon,
+ true /* don't return the same point as nearest neighbor */);
if (naive)
{
@@ -676,6 +680,8 @@ Search(const size_t k,
treeNeedsReset = true;
}
+ rules.GetResults(*neighborPtr, *distancePtr);
+
Timer::Stop("computing_neighbors");
// Do we need to map the reference indices?
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
index 47a7933..0bcdc49 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
@@ -9,6 +9,8 @@
#define MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
#include <mlpack/core/tree/traversal_info.hpp>
+#include <vector>
+#include <queue>
namespace mlpack {
namespace neighbor {
@@ -19,11 +21,20 @@ class NeighborSearchRules
public:
NeighborSearchRules(const typename TreeType::Mat& referenceSet,
const typename TreeType::Mat& querySet,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances,
+ const size_t k,
MetricType& metric,
const double epsilon = 0,
const bool sameSet = false);
+
+ /**
+ * Store the list of candidates for each query point in the given matrices.
+ *
+ * @param neighbors Matrix storing lists of neighbors for each query point.
+ * @param distances Matrix storing distances of neighbors for each query
+ * point.
+ */
+ void GetResults(arma::Mat<size_t>& neighbors, arma::mat& distances);
+
/**
* Get the distance from the query point to the reference point.
* This will update the "neighbor" matrix with the new point if appropriate
@@ -109,11 +120,34 @@ class NeighborSearchRules
//! The query set.
const typename TreeType::Mat& querySet;
- //! The matrix the resultant neighbor indices should be stored in.
- arma::Mat<size_t>& neighbors;
-
- //! The matrix the resultant neighbor distances should be stored in.
- arma::mat& distances;
+ //! Candidate represents a possible candidate neighbor (from the reference
+ // set).
+ struct Candidate
+ {
+ //! Distance between the reference point and the query point.
+ double dist;
+ //! Index of the reference point.
+ size_t index;
+ //! Trivial constructor.
+ Candidate(double d, size_t i) :
+ dist(d),
+ index(i)
+ {};
+ //! Compare the distance of two candidates.
+ friend bool operator<(const Candidate& l, const Candidate& r)
+ {
+ return !SortPolicy::IsBetter(r.dist, l.dist);
+ };
+ };
+
+ //! Use a priority queue to represent the list of candidate neighbors.
+ typedef std::priority_queue<Candidate> CandidateList;
+
+ //! Set of candidate neighbors for each point.
+ std::vector<CandidateList> candidates;
+
+ //! Number of neighbors to search for.
+ const size_t k;
//! The instantiated metric.
MetricType& metric;
@@ -146,16 +180,13 @@ class NeighborSearchRules
double CalculateBound(TreeType& queryNode) const;
/**
- * Insert a point into the neighbors and distances matrices; this is a helper
- * function.
+ * Helper function to insert a point into the list of candidate points.
*
* @param queryIndex Index of point whose neighbors we are inserting into.
- * @param pos Position in list to insert into.
* @param neighbor Index of reference point which is being inserted.
* @param distance Distance from query point to reference point.
*/
void InsertNeighbor(const size_t queryIndex,
- const size_t pos,
const size_t neighbor,
const double distance);
};
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
index 24f9485..65d258e 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
@@ -17,15 +17,13 @@ template<typename SortPolicy, typename MetricType, typename TreeType>
NeighborSearchRules<SortPolicy, MetricType, TreeType>::NeighborSearchRules(
const typename TreeType::Mat& referenceSet,
const typename TreeType::Mat& querySet,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances,
+ const size_t k,
MetricType& metric,
const double epsilon,
const bool sameSet) :
referenceSet(referenceSet),
querySet(querySet),
- neighbors(neighbors),
- distances(distances),
+ k(k),
metric(metric),
sameSet(sameSet),
epsilon(epsilon),
@@ -39,9 +37,42 @@ NeighborSearchRules<SortPolicy, MetricType, TreeType>::NeighborSearchRules(
// use the this pointer.
traversalInfo.LastQueryNode() = (TreeType*) this;
traversalInfo.LastReferenceNode() = (TreeType*) this;
+
+ // Let's build the list of candidate neighbors for each query point.
+ // It will be initialized with k candidates: (WorstDistance, size_t() - 1)
+ // The list of candidates will be updated when visiting new points with the
+ // BaseCase() method.
+ const Candidate def(SortPolicy::WorstDistance(), size_t() - 1);
+
+ std::vector<Candidate> vect(k, def);
+ CandidateList pqueue(std::less<Candidate>(), std::move(vect));
+
+ candidates.reserve(querySet.n_cols);
+ for (size_t i = 0; i < querySet.n_cols; i++)
+ candidates.push_back(pqueue);
}
template<typename SortPolicy, typename MetricType, typename TreeType>
+void NeighborSearchRules<SortPolicy, MetricType, TreeType>::GetResults(
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances)
+{
+ neighbors.set_size(k, querySet.n_cols);
+ distances.set_size(k, querySet.n_cols);
+
+ for (size_t i = 0; i < querySet.n_cols; i++)
+ {
+ CandidateList& pqueue = candidates[i];
+ for (size_t j = 1; j <= k; j++)
+ {
+ neighbors(k - j, i) = pqueue.top().index;
+ distances(k - j, i) = pqueue.top().dist;
+ pqueue.pop();
+ }
+ }
+};
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
inline force_inline // Absolutely MUST be inline so optimizations can happen.
double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
BaseCase(const size_t queryIndex, const size_t referenceIndex)
@@ -59,16 +90,7 @@ BaseCase(const size_t queryIndex, const size_t referenceIndex)
referenceSet.col(referenceIndex));
++baseCases;
- // If this distance is better than any of the current candidates, the
- // SortDistance() function will give us the position to insert it into.
- arma::vec queryDist = distances.unsafe_col(queryIndex);
- arma::Col<size_t> queryIndices = neighbors.unsafe_col(queryIndex);
- const size_t insertPosition = SortPolicy::SortDistance(queryDist,
- queryIndices, distance);
-
- // SortDistance() returns (size_t() - 1) if we shouldn't add it.
- if (insertPosition != (size_t() - 1))
- InsertNeighbor(queryIndex, insertPosition, referenceIndex, distance);
+ InsertNeighbor(queryIndex, referenceIndex, distance);
// Cache this information for the next time BaseCase() is called.
lastQueryIndex = queryIndex;
@@ -114,7 +136,7 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
}
// Compare against the best k'th distance for this query point so far.
- double bestDistance = distances(distances.n_rows - 1, queryIndex);
+ double bestDistance = candidates[queryIndex].top().dist;
bestDistance = SortPolicy::Relax(bestDistance, epsilon);
return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
@@ -131,7 +153,7 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
return oldScore;
// Just check the score again against the distances.
- double bestDistance = distances(distances.n_rows - 1, queryIndex);
+ double bestDistance = candidates[queryIndex].top().dist;
bestDistance = SortPolicy::Relax(bestDistance, epsilon);
return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
@@ -354,7 +376,7 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
// Loop over points held in the node.
for (size_t i = 0; i < queryNode.NumPoints(); ++i)
{
- const double distance = distances(distances.n_rows - 1, queryNode.Point(i));
+ const double distance = candidates[queryNode.Point(i)].top().dist;
if (SortPolicy::IsBetter(worstDistance, distance))
worstDistance = distance;
if (SortPolicy::IsBetter(distance, bestPointDistance))
@@ -432,35 +454,26 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
}
/**
- * Helper function to insert a point into the neighbors and distances matrices.
+ * Helper function to insert a point into the list of candidate points.
*
* @param queryIndex Index of point whose neighbors we are inserting into.
- * @param pos Position in list to insert into.
* @param neighbor Index of reference point which is being inserted.
* @param distance Distance from query point to reference point.
*/
template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearchRules<SortPolicy, MetricType, TreeType>::InsertNeighbor(
+inline void NeighborSearchRules<SortPolicy, MetricType, TreeType>::
+InsertNeighbor(
const size_t queryIndex,
- const size_t pos,
const size_t neighbor,
const double distance)
{
- // We only memmove() if there is actually a need to shift something.
- if (pos < (distances.n_rows - 1))
+ Candidate c(distance, neighbor);
+ CandidateList& pqueue = candidates[queryIndex];
+ if (c < pqueue.top())
{
- int len = (distances.n_rows - 1) - pos;
- memmove(distances.colptr(queryIndex) + (pos + 1),
- distances.colptr(queryIndex) + pos,
- sizeof(double) * len);
- memmove(neighbors.colptr(queryIndex) + (pos + 1),
- neighbors.colptr(queryIndex) + pos,
- sizeof(size_t) * len);
+ pqueue.pop();
+ pqueue.push(c);
}
-
- // Now put the new information in the right index.
- distances(pos, queryIndex) = distance;
- neighbors(pos, queryIndex) = neighbor;
}
} // namespace neighbor
More information about the mlpack-git
mailing list