[mlpack-git] master: Use std::pair instead of Candidate struct. (d1eadad)
gitdub at mlpack.org
gitdub at mlpack.org
Tue Jul 26 21:22:08 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/ef51b032f275266f781d42b9bd0aa50aa26a3077...8522b04c3d9a82fb7e964bafd72e70f0cd30bf4b
>---------------------------------------------------------------
commit d1eadad6908d4d95a147455d7a6e60e2cca238f8
Author: MarcosPividori <marcos.pividori at gmail.com>
Date: Fri Jul 22 16:48:31 2016 -0300
Use std::pair instead of Candidate struct.
>---------------------------------------------------------------
d1eadad6908d4d95a147455d7a6e60e2cca238f8
src/mlpack/methods/cf/cf.cpp | 18 ++++++-------
src/mlpack/methods/cf/cf.hpp | 22 +++++----------
src/mlpack/methods/fastmks/fastmks.hpp | 24 ++++++-----------
src/mlpack/methods/fastmks/fastmks_impl.hpp | 24 ++++++++---------
src/mlpack/methods/fastmks/fastmks_rules.hpp | 22 +++++----------
src/mlpack/methods/fastmks/fastmks_rules_impl.hpp | 31 +++++++++++-----------
src/mlpack/methods/lsh/lsh_search.hpp | 27 +++++++------------
src/mlpack/methods/lsh/lsh_search_impl.hpp | 26 +++++++++---------
.../neighbor_search/neighbor_search_rules.hpp | 26 +++++++-----------
.../neighbor_search/neighbor_search_rules_impl.hpp | 20 +++++++-------
src/mlpack/methods/rann/ra_search_rules.hpp | 26 +++++++-----------
src/mlpack/methods/rann/ra_search_rules_impl.hpp | 26 +++++++++---------
12 files changed, 125 insertions(+), 167 deletions(-)
diff --git a/src/mlpack/methods/cf/cf.cpp b/src/mlpack/methods/cf/cf.cpp
index 121cf13..01a4c84 100644
--- a/src/mlpack/methods/cf/cf.cpp
+++ b/src/mlpack/methods/cf/cf.cpp
@@ -95,11 +95,11 @@ void CF::GetRecommendations(const size_t numRecs,
// Let's build the list of candidate recomendations for the given user.
// Default candidate: the smallest possible value and invalid item number.
- const Candidate def(-DBL_MAX, cleanedData.n_rows);
+ const Candidate def = std::make_pair(-DBL_MAX, cleanedData.n_rows);
std::vector<Candidate> vect(numRecs, def);
- typedef std::priority_queue<Candidate, std::vector<Candidate>,
- std::greater<Candidate>> CandidateList;
- CandidateList pqueue(std::greater<Candidate>(), std::move(vect));
+ typedef std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp>
+ CandidateList;
+ CandidateList pqueue(CandidateCmp(), std::move(vect));
// Look through the averages column corresponding to the current user.
for (size_t j = 0; j < averages.n_rows; ++j)
@@ -108,11 +108,11 @@ void CF::GetRecommendations(const size_t numRecs,
if (cleanedData(j, users(i)) != 0.0)
continue; // The user already rated the item.
- Candidate c(averages[j], j);
// Is the estimated value better than the worst candidate?
- if (c > pqueue.top())
+ if (averages[j] > pqueue.top().first)
{
+ Candidate c = std::make_pair(averages[j], j);
pqueue.pop();
pqueue.push(c);
}
@@ -120,14 +120,14 @@ void CF::GetRecommendations(const size_t numRecs,
for (size_t p = 1; p <= numRecs; p++)
{
- recommendations(numRecs - p, i) = pqueue.top().item;
- values(numRecs - p, i) = pqueue.top().value;
+ recommendations(numRecs - p, i) = pqueue.top().second;
+ values(numRecs - p, i) = pqueue.top().first;
pqueue.pop();
}
// If we were not able to come up with enough recommendations, issue a
// warning.
- if (recommendations(numRecs - 1, i) == def.item)
+ if (recommendations(numRecs - 1, i) == def.second)
Log::Warn << "Could not provide " << numRecs << " recommendations "
<< "for user " << users(i) << " (not enough un-rated items)!"
<< std::endl;
diff --git a/src/mlpack/methods/cf/cf.hpp b/src/mlpack/methods/cf/cf.hpp
index 42b1a4b..82d624c 100644
--- a/src/mlpack/methods/cf/cf.hpp
+++ b/src/mlpack/methods/cf/cf.hpp
@@ -258,22 +258,14 @@ class CF
//! Cleaned data matrix.
arma::sp_mat cleanedData;
- //! Candidate represents a possible recommendation.
- struct Candidate
- {
- //! Value of this recommendation.
- double value;
- //! Item of this recommendation.
- size_t item;
- //! Trivial constructor.
- Candidate(double value, size_t item) :
- value(value),
- item(item)
- {};
- //! Compare the value of two candidates.
- friend bool operator>(const Candidate& l, const Candidate& r)
+ //! Candidate represents a possible recommendation (value, item).
+ typedef std::pair<double, size_t> Candidate;
+
+ //! Compare two candidates by value (greater-than), so the candidate with the smallest value is at the top of the priority queue.
+ struct CandidateCmp {
+ bool operator()(const Candidate& c1, const Candidate& c2) const
{
- return l.value > r.value;
+ return c1.first > c2.first;
};
};
}; // class CF
diff --git a/src/mlpack/methods/fastmks/fastmks.hpp b/src/mlpack/methods/fastmks/fastmks.hpp
index 796f0db..031fe43 100644
--- a/src/mlpack/methods/fastmks/fastmks.hpp
+++ b/src/mlpack/methods/fastmks/fastmks.hpp
@@ -251,28 +251,20 @@ class FastMKS
//! The instantiated inner-product metric induced by the given kernel.
metric::IPMetric<KernelType> metric;
- //! Candidate point from the reference set.
- struct Candidate
- {
- //! Kernel value calculated between a reference point and the query point.
- double product;
- //! Index of the reference point.
- size_t index;
- //! Trivial constructor.
- Candidate(double p, size_t i) :
- product(p),
- index(i)
- {};
- //! Compare two candidates.
- friend bool operator>(const Candidate& l, const Candidate& r)
+ //! Candidate represents a possible candidate point (value, index).
+ typedef std::pair<double, size_t> Candidate;
+
+ //! Compare two candidates by kernel value (greater-than), so the candidate with the smallest value is at the top of the priority queue.
+ struct CandidateCmp {
+ bool operator()(const Candidate& c1, const Candidate& c2) const
{
- return l.product > r.product;
+ return c1.first > c2.first;
};
};
//! Use a priority queue to represent the list of candidate points.
typedef std::priority_queue<Candidate, std::vector<Candidate>,
- std::greater<Candidate>> CandidateList;
+ CandidateCmp> CandidateList;
};
} // namespace fastmks
diff --git a/src/mlpack/methods/fastmks/fastmks_impl.hpp b/src/mlpack/methods/fastmks/fastmks_impl.hpp
index f0f321f..993ba54 100644
--- a/src/mlpack/methods/fastmks/fastmks_impl.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_impl.hpp
@@ -223,18 +223,18 @@ void FastMKS<KernelType, MatType, TreeType>::Search(
// Simple double loop. Stupid, slow, but a good benchmark.
for (size_t q = 0; q < querySet.n_cols; ++q)
{
- const Candidate def(-DBL_MAX, size_t() - 1);
+ const Candidate def = std::make_pair(-DBL_MAX, size_t() - 1);
std::vector<Candidate> cList(k, def);
- CandidateList pqueue(std::greater<Candidate>(), std::move(cList));
+ CandidateList pqueue(CandidateCmp(), std::move(cList));
for (size_t r = 0; r < referenceSet->n_cols; ++r)
{
const double eval = metric.Kernel().Evaluate(querySet.col(q),
referenceSet->col(r));
- Candidate c(eval, r);
- if (c > pqueue.top())
+ if (eval > pqueue.top().first)
{
+ Candidate c = std::make_pair(eval, r);
pqueue.pop();
pqueue.push(c);
}
@@ -242,8 +242,8 @@ void FastMKS<KernelType, MatType, TreeType>::Search(
for (size_t j = 1; j <= k; j++)
{
- indices(k - j, q) = pqueue.top().index;
- kernels(k - j, q) = pqueue.top().product;
+ indices(k - j, q) = pqueue.top().second;
+ kernels(k - j, q) = pqueue.top().first;
pqueue.pop();
}
}
@@ -352,9 +352,9 @@ void FastMKS<KernelType, MatType, TreeType>::Search(
// Simple double loop. Stupid, slow, but a good benchmark.
for (size_t q = 0; q < referenceSet->n_cols; ++q)
{
- const Candidate def(-DBL_MAX, size_t() - 1);
+ const Candidate def = std::make_pair(-DBL_MAX, size_t() - 1);
std::vector<Candidate> cList(k, def);
- CandidateList pqueue(std::greater<Candidate>(), std::move(cList));
+ CandidateList pqueue(CandidateCmp(), std::move(cList));
for (size_t r = 0; r < referenceSet->n_cols; ++r)
{
@@ -364,9 +364,9 @@ void FastMKS<KernelType, MatType, TreeType>::Search(
const double eval = metric.Kernel().Evaluate(referenceSet->col(q),
referenceSet->col(r));
- Candidate c(eval, r);
- if (c > pqueue.top())
+ if (eval > pqueue.top().first)
{
+ Candidate c = std::make_pair(eval, r);
pqueue.pop();
pqueue.push(c);
}
@@ -374,8 +374,8 @@ void FastMKS<KernelType, MatType, TreeType>::Search(
for (size_t j = 1; j <= k; j++)
{
- indices(k - j, q) = pqueue.top().index;
- kernels(k - j, q) = pqueue.top().product;
+ indices(k - j, q) = pqueue.top().second;
+ kernels(k - j, q) = pqueue.top().first;
pqueue.pop();
}
}
diff --git a/src/mlpack/methods/fastmks/fastmks_rules.hpp b/src/mlpack/methods/fastmks/fastmks_rules.hpp
index 11b9604..9aca42d 100644
--- a/src/mlpack/methods/fastmks/fastmks_rules.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_rules.hpp
@@ -124,22 +124,14 @@ class FastMKSRules
//! The query dataset.
const typename TreeType::Mat& querySet;
- //! Candidate point from the reference set.
- struct Candidate
- {
- //! Kernel value calculated between a reference point and the query point.
- double product;
- //! Index of the reference point.
- size_t index;
- //! Trivial constructor.
- Candidate(double p, size_t i) :
- product(p),
- index(i)
- {};
- //! Compare two candidates.
- friend bool operator>(const Candidate& l, const Candidate& r)
+ //! Candidate represents a possible candidate point (value, index).
+ typedef std::pair<double, size_t> Candidate;
+
+ //! Compare two candidates by kernel value (greater-than); used with the std heap functions, this keeps the smallest value at the front of the heap.
+ struct CandidateCmp {
+ bool operator()(const Candidate& c1, const Candidate& c2) const
{
- return l.product > r.product;
+ return c1.first > c2.first;
};
};
diff --git a/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp b/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp
index a5cf681..6efc9a7 100644
--- a/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp
@@ -50,7 +50,7 @@ FastMKSRules<KernelType, TreeType>::FastMKSRules(
// It will be initialized with k candidates: (-DBL_MAX, size_t() - 1)
// The list of candidates will be updated when visiting new points with the
// BaseCase() method.
- const Candidate def(-DBL_MAX, size_t() - 1);
+ const Candidate def = std::make_pair(-DBL_MAX, size_t() - 1);
CandidateList cList(k, def);
std::vector<CandidateList> tmp(querySet.n_cols, cList);
@@ -68,16 +68,15 @@ void FastMKSRules<KernelType, TreeType>::GetResults(
for (size_t i = 0; i < querySet.n_cols; i++)
{
CandidateList& pqueue = candidates[i];
- std::greater<Candidate> greater;
typedef typename CandidateList::iterator Iterator;
for (Iterator end = pqueue.end(); end != pqueue.begin(); --end)
- std::pop_heap(pqueue.begin(), end, greater);
+ std::pop_heap(pqueue.begin(), end, CandidateCmp());
for (size_t j = 0; j < k; j++)
{
- indices(j, i) = pqueue[j].index;
- products(j, i) = pqueue[j].product;
+ indices(j, i) = pqueue[j].second;
+ products(j, i) = pqueue[j].first;
}
}
}
@@ -127,7 +126,7 @@ double FastMKSRules<KernelType, TreeType>::Score(const size_t queryIndex,
TreeType& referenceNode)
{
// Compare with the current best.
- const double bestKernel = candidates[queryIndex].front().product;
+ const double bestKernel = candidates[queryIndex].front().first;
// See if we can perform a parent-child prune.
const double furthestDist = referenceNode.FurthestDescendantDistance();
@@ -410,7 +409,7 @@ double FastMKSRules<KernelType, TreeType>::Rescore(const size_t queryIndex,
TreeType& /*referenceNode*/,
const double oldScore) const
{
- const double bestKernel = candidates[queryIndex].front().product;
+ const double bestKernel = candidates[queryIndex].front().first;
return ((1.0 / oldScore) >= bestKernel) ? oldScore : DBL_MAX;
}
@@ -458,10 +457,10 @@ double FastMKSRules<KernelType, TreeType>::CalculateBound(TreeType& queryNode)
{
const size_t point = queryNode.Point(i);
const CandidateList& candidatesPoints = candidates[point];
- if (candidatesPoints.front().product < worstPointKernel)
- worstPointKernel = candidatesPoints.front().product;
+ if (candidatesPoints.front().first < worstPointKernel)
+ worstPointKernel = candidatesPoints.front().first;
- if (candidatesPoints.front().product == -DBL_MAX)
+ if (candidatesPoints.front().first == -DBL_MAX)
continue; // Avoid underflow.
// This should be (queryDescendantDistance + centroidDistance) for any tree
@@ -478,8 +477,8 @@ double FastMKSRules<KernelType, TreeType>::CalculateBound(TreeType& queryNode)
double worstPointCandidateKernel = DBL_MAX;
for (size_t j = 0; j < candidatesPoints.size(); ++j)
{
- const double candidateKernel = candidatesPoints[j].product -
- queryDescendantDistance * referenceKernels[candidatesPoints[j].index];
+ const double candidateKernel = candidatesPoints[j].first -
+ queryDescendantDistance * referenceKernels[candidatesPoints[j].second];
if (candidateKernel < worstPointCandidateKernel)
worstPointCandidateKernel = candidateKernel;
}
@@ -526,13 +525,13 @@ inline void FastMKSRules<KernelType, TreeType>::InsertNeighbor(
const size_t index,
const double product)
{
- Candidate c(product, index);
CandidateList& pqueue = candidates[queryIndex];
- if (c > pqueue.front())
+ if (product > pqueue.front().first)
{
- std::pop_heap(pqueue.begin(), pqueue.end(), std::greater<Candidate>());
+ Candidate c = std::make_pair(product, index);
+ std::pop_heap(pqueue.begin(), pqueue.end(), CandidateCmp());
pqueue.back() = c;
- std::push_heap(pqueue.begin(), pqueue.end(), std::greater<Candidate>());
+ std::push_heap(pqueue.begin(), pqueue.end(), CandidateCmp());
}
}
diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index 45284ba..c622fbb 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -428,28 +428,21 @@ class LSHSearch
//! The number of distance evaluations.
size_t distanceEvaluations;
- //! Candidate represents a possible candidate neighbor (from the reference
- // set).
- struct Candidate
- {
- //! Distance between the reference point and the query point.
- double dist;
- //! Index of the reference point.
- size_t index;
- //! Trivial constructor.
- Candidate(double d, size_t i) :
- dist(d),
- index(i)
- {};
- //! Compare the distance of two candidates.
- friend bool operator<(const Candidate& l, const Candidate& r)
+ //! Candidate represents a possible candidate neighbor (distance, index).
+ typedef std::pair<double, size_t> Candidate;
+
+ //! Compare two candidates by distance according to SortPolicy, so the worst candidate is at the top of the priority queue.
+ struct CandidateCmp {
+ bool operator()(const Candidate& c1, const Candidate& c2) const
{
- return !SortPolicy::IsBetter(r.dist, l.dist);
+ return !SortPolicy::IsBetter(c2.first, c1.first);
};
};
//! Use a priority queue to represent the list of candidate neighbors.
- typedef std::priority_queue<Candidate> CandidateList;
+ typedef std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp>
+ CandidateList;
+
}; // class LSHSearch
} // namespace neighbor
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index d0b53ae..bdef1be 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -275,9 +275,10 @@ void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
// Let's build the list of candidate neighbors for the given query point.
// It will be initialized with k candidates:
// (WorstDistance, referenceSet->n_cols)
- const Candidate def(SortPolicy::WorstDistance(), referenceSet->n_cols);
+ const Candidate def = std::make_pair(SortPolicy::WorstDistance(),
+ referenceSet->n_cols);
std::vector<Candidate> vect(k, def);
- CandidateList pqueue(std::less<Candidate>(), std::move(vect));
+ CandidateList pqueue(CandidateCmp(), std::move(vect));
for (size_t j = 0; j < referenceIndices.n_elem; ++j)
{
@@ -290,9 +291,9 @@ void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
referenceSet->unsafe_col(queryIndex),
referenceSet->unsafe_col(referenceIndex));
- Candidate c(distance, referenceIndex);
+ Candidate c = std::make_pair(distance, referenceIndex);
// If this distance is better than the worst candidate, let's insert it.
- if (c < pqueue.top())
+ if (CandidateCmp()(c, pqueue.top()))
{
pqueue.pop();
pqueue.push(c);
@@ -301,8 +302,8 @@ void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
for (size_t j = 1; j <= k; j++)
{
- neighbors(k - j, queryIndex) = pqueue.top().index;
- distances(k - j, queryIndex) = pqueue.top().dist;
+ neighbors(k - j, queryIndex) = pqueue.top().second;
+ distances(k - j, queryIndex) = pqueue.top().first;
pqueue.pop();
}
}
@@ -320,9 +321,10 @@ void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
// Let's build the list of candidate neighbors for the given query point.
// It will be initialized with k candidates:
// (WorstDistance, referenceSet->n_cols)
- const Candidate def(SortPolicy::WorstDistance(), referenceSet->n_cols);
+ const Candidate def = std::make_pair(SortPolicy::WorstDistance(),
+ referenceSet->n_cols);
std::vector<Candidate> vect(k, def);
- CandidateList pqueue(std::less<Candidate>(), std::move(vect));
+ CandidateList pqueue(CandidateCmp(), std::move(vect));
for (size_t j = 0; j < referenceIndices.n_elem; ++j)
{
@@ -331,9 +333,9 @@ void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
querySet.unsafe_col(queryIndex),
referenceSet->unsafe_col(referenceIndex));
- Candidate c(distance, referenceIndex);
+ Candidate c = std::make_pair(distance, referenceIndex);
// If this distance is better than the worst candidate, let's insert it.
- if (c < pqueue.top())
+ if (CandidateCmp()(c, pqueue.top()))
{
pqueue.pop();
pqueue.push(c);
@@ -342,8 +344,8 @@ void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
for (size_t j = 1; j <= k; j++)
{
- neighbors(k - j, queryIndex) = pqueue.top().index;
- distances(k - j, queryIndex) = pqueue.top().dist;
+ neighbors(k - j, queryIndex) = pqueue.top().second;
+ distances(k - j, queryIndex) = pqueue.top().first;
pqueue.pop();
}
}
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
index 25e7175..a44d06a 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
@@ -143,28 +143,20 @@ class NeighborSearchRules
//! The query set.
const typename TreeType::Mat& querySet;
- //! Candidate represents a possible candidate neighbor (from the reference
- // set).
- struct Candidate
- {
- //! Distance between the reference point and the query point.
- double dist;
- //! Index of the reference point.
- size_t index;
- //! Trivial constructor.
- Candidate(double d, size_t i) :
- dist(d),
- index(i)
- {};
- //! Compare the distance of two candidates.
- friend bool operator<(const Candidate& l, const Candidate& r)
+ //! Candidate represents a possible candidate neighbor (distance, index).
+ typedef std::pair<double, size_t> Candidate;
+
+ //! Compare two candidates by distance according to SortPolicy, so the worst candidate is at the top of the priority queue.
+ struct CandidateCmp {
+ bool operator()(const Candidate& c1, const Candidate& c2) const
{
- return !SortPolicy::IsBetter(r.dist, l.dist);
+ return !SortPolicy::IsBetter(c2.first, c1.first);
};
};
//! Use a priority queue to represent the list of candidate neighbors.
- typedef std::priority_queue<Candidate> CandidateList;
+ typedef std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp>
+ CandidateList;
//! Set of candidate neighbors for each point.
std::vector<CandidateList> candidates;
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
index 65d258e..e40d09e 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
@@ -42,10 +42,11 @@ NeighborSearchRules<SortPolicy, MetricType, TreeType>::NeighborSearchRules(
// It will be initialized with k candidates: (WorstDistance, size_t() - 1)
// The list of candidates will be updated when visiting new points with the
// BaseCase() method.
- const Candidate def(SortPolicy::WorstDistance(), size_t() - 1);
+ const Candidate def = std::make_pair(SortPolicy::WorstDistance(),
+ size_t() - 1);
std::vector<Candidate> vect(k, def);
- CandidateList pqueue(std::less<Candidate>(), std::move(vect));
+ CandidateList pqueue(CandidateCmp(), std::move(vect));
candidates.reserve(querySet.n_cols);
for (size_t i = 0; i < querySet.n_cols; i++)
@@ -65,8 +66,8 @@ void NeighborSearchRules<SortPolicy, MetricType, TreeType>::GetResults(
CandidateList& pqueue = candidates[i];
for (size_t j = 1; j <= k; j++)
{
- neighbors(k - j, i) = pqueue.top().index;
- distances(k - j, i) = pqueue.top().dist;
+ neighbors(k - j, i) = pqueue.top().second;
+ distances(k - j, i) = pqueue.top().first;
pqueue.pop();
}
}
@@ -136,7 +137,7 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
}
// Compare against the best k'th distance for this query point so far.
- double bestDistance = candidates[queryIndex].top().dist;
+ double bestDistance = candidates[queryIndex].top().first;
bestDistance = SortPolicy::Relax(bestDistance, epsilon);
return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
@@ -153,7 +154,7 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
return oldScore;
// Just check the score again against the distances.
- double bestDistance = candidates[queryIndex].top().dist;
+ double bestDistance = candidates[queryIndex].top().first;
bestDistance = SortPolicy::Relax(bestDistance, epsilon);
return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
@@ -376,7 +377,7 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
// Loop over points held in the node.
for (size_t i = 0; i < queryNode.NumPoints(); ++i)
{
- const double distance = candidates[queryNode.Point(i)].top().dist;
+ const double distance = candidates[queryNode.Point(i)].top().first;
if (SortPolicy::IsBetter(worstDistance, distance))
worstDistance = distance;
if (SortPolicy::IsBetter(distance, bestPointDistance))
@@ -467,9 +468,10 @@ InsertNeighbor(
const size_t neighbor,
const double distance)
{
- Candidate c(distance, neighbor);
CandidateList& pqueue = candidates[queryIndex];
- if (c < pqueue.top())
+ Candidate c = std::make_pair(distance, neighbor);
+
+ if (CandidateCmp()(c, pqueue.top()))
{
pqueue.pop();
pqueue.push(c);
diff --git a/src/mlpack/methods/rann/ra_search_rules.hpp b/src/mlpack/methods/rann/ra_search_rules.hpp
index c7edeba..93b7a8c 100644
--- a/src/mlpack/methods/rann/ra_search_rules.hpp
+++ b/src/mlpack/methods/rann/ra_search_rules.hpp
@@ -243,28 +243,20 @@ class RASearchRules
//! The query set.
const arma::mat& querySet;
- //! Candidate represents a possible candidate neighbor (from the reference
- // set).
- struct Candidate
- {
- //! Distance between the reference point and the query point.
- double dist;
- //! Index of the reference point.
- size_t index;
- //! Trivial constructor.
- Candidate(double d, size_t i) :
- dist(d),
- index(i)
- {};
- //! Compare the distance of two candidates.
- friend bool operator<(const Candidate& l, const Candidate& r)
+ //! Candidate represents a possible candidate neighbor (distance, index).
+ typedef std::pair<double, size_t> Candidate;
+
+ //! Compare two candidates by distance according to SortPolicy, so the worst candidate is at the top of the priority queue.
+ struct CandidateCmp {
+ bool operator()(const Candidate& c1, const Candidate& c2) const
{
- return !SortPolicy::IsBetter(r.dist, l.dist);
+ return !SortPolicy::IsBetter(c2.first, c1.first);
};
};
//! Use a priority queue to represent the list of candidate neighbors.
- typedef std::priority_queue<Candidate> CandidateList;
+ typedef std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp>
+ CandidateList;
//! Set of candidate neighbors for each point.
std::vector<CandidateList> candidates;
diff --git a/src/mlpack/methods/rann/ra_search_rules_impl.hpp b/src/mlpack/methods/rann/ra_search_rules_impl.hpp
index bad3e24..9a886db 100644
--- a/src/mlpack/methods/rann/ra_search_rules_impl.hpp
+++ b/src/mlpack/methods/rann/ra_search_rules_impl.hpp
@@ -69,10 +69,11 @@ RASearchRules(const arma::mat& referenceSet,
// It will be initialized with k candidates: (WorstDistance, size_t() - 1)
// The list of candidates will be updated when visiting new points with the
// BaseCase() method.
- const Candidate def(SortPolicy::WorstDistance(), size_t() - 1);
+ const Candidate def = std::make_pair(SortPolicy::WorstDistance(),
+ size_t() - 1);
std::vector<Candidate> vect(k, def);
- CandidateList pqueue(std::less<Candidate>(), std::move(vect));
+ CandidateList pqueue(CandidateCmp(), std::move(vect));
candidates.reserve(querySet.n_cols);
for (size_t i = 0; i < querySet.n_cols; i++)
@@ -104,8 +105,8 @@ void RASearchRules<SortPolicy, MetricType, TreeType>::GetResults(
CandidateList& pqueue = candidates[i];
for (size_t j = 1; j <= k; j++)
{
- neighbors(k - j, i) = pqueue.top().index;
- distances(k - j, i) = pqueue.top().dist;
+ neighbors(k - j, i) = pqueue.top().second;
+ distances(k - j, i) = pqueue.top().first;
pqueue.pop();
}
}
@@ -143,7 +144,7 @@ inline double RASearchRules<SortPolicy, MetricType, TreeType>::Score(
const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
const double distance = SortPolicy::BestPointToNodeDistance(queryPoint,
&referenceNode);
- const double bestDistance = candidates[queryIndex].top().dist;
+ const double bestDistance = candidates[queryIndex].top().first;
return Score(queryIndex, referenceNode, distance, bestDistance);
}
@@ -157,7 +158,7 @@ inline double RASearchRules<SortPolicy, MetricType, TreeType>::Score(
const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
const double distance = SortPolicy::BestPointToNodeDistance(queryPoint,
&referenceNode, baseCaseResult);
- const double bestDistance = candidates[queryIndex].top().dist;
+ const double bestDistance = candidates[queryIndex].top().first;
return Score(queryIndex, referenceNode, distance, bestDistance);
}
@@ -271,7 +272,7 @@ Rescore(const size_t queryIndex,
return oldScore;
// Just check the score again against the distances.
- const double bestDistance = candidates[queryIndex].top().dist;
+ const double bestDistance = candidates[queryIndex].top().first;
// If this is better than the best distance we've seen so far,
// maybe there will be something down this node.
@@ -371,7 +372,7 @@ inline double RASearchRules<SortPolicy, MetricType, TreeType>::Score(
for (size_t i = 0; i < queryNode.NumPoints(); i++)
{
- const double bound = candidates[queryNode.Point(i)].top().dist
+ const double bound = candidates[queryNode.Point(i)].top().first
+ maxDescendantDistance;
if (bound < pointBound)
pointBound = bound;
@@ -410,7 +411,7 @@ inline double RASearchRules<SortPolicy, MetricType, TreeType>::Score(
for (size_t i = 0; i < queryNode.NumPoints(); i++)
{
- const double bound = candidates[queryNode.Point(i)].top().dist
+ const double bound = candidates[queryNode.Point(i)].top().first
+ maxDescendantDistance;
if (bound < pointBound)
pointBound = bound;
@@ -624,7 +625,7 @@ Rescore(TreeType& queryNode,
for (size_t i = 0; i < queryNode.NumPoints(); i++)
{
- const double bound = candidates[queryNode.Point(i)].top().dist
+ const double bound = candidates[queryNode.Point(i)].top().first
+ maxDescendantDistance;
if (bound < pointBound)
pointBound = bound;
@@ -809,9 +810,10 @@ InsertNeighbor(
const size_t neighbor,
const double distance)
{
- Candidate c(distance, neighbor);
CandidateList& pqueue = candidates[queryIndex];
- if (c < pqueue.top())
+ Candidate c = std::make_pair(distance, neighbor);
+
+ if (CandidateCmp()(c, pqueue.top()))
{
pqueue.pop();
pqueue.push(c);
More information about the mlpack-git
mailing list