[mlpack-git] master: Use std::pair instead of Candidate struct. (d1eadad)

gitdub at mlpack.org gitdub at mlpack.org
Tue Jul 26 21:22:08 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/ef51b032f275266f781d42b9bd0aa50aa26a3077...8522b04c3d9a82fb7e964bafd72e70f0cd30bf4b

>---------------------------------------------------------------

commit d1eadad6908d4d95a147455d7a6e60e2cca238f8
Author: MarcosPividori <marcos.pividori at gmail.com>
Date:   Fri Jul 22 16:48:31 2016 -0300

    Use std::pair instead of Candidate struct.


>---------------------------------------------------------------

d1eadad6908d4d95a147455d7a6e60e2cca238f8
 src/mlpack/methods/cf/cf.cpp                       | 18 ++++++-------
 src/mlpack/methods/cf/cf.hpp                       | 22 +++++----------
 src/mlpack/methods/fastmks/fastmks.hpp             | 24 ++++++-----------
 src/mlpack/methods/fastmks/fastmks_impl.hpp        | 24 ++++++++---------
 src/mlpack/methods/fastmks/fastmks_rules.hpp       | 22 +++++----------
 src/mlpack/methods/fastmks/fastmks_rules_impl.hpp  | 31 +++++++++++-----------
 src/mlpack/methods/lsh/lsh_search.hpp              | 27 +++++++------------
 src/mlpack/methods/lsh/lsh_search_impl.hpp         | 26 +++++++++---------
 .../neighbor_search/neighbor_search_rules.hpp      | 26 +++++++-----------
 .../neighbor_search/neighbor_search_rules_impl.hpp | 20 +++++++-------
 src/mlpack/methods/rann/ra_search_rules.hpp        | 26 +++++++-----------
 src/mlpack/methods/rann/ra_search_rules_impl.hpp   | 26 +++++++++---------
 12 files changed, 125 insertions(+), 167 deletions(-)

diff --git a/src/mlpack/methods/cf/cf.cpp b/src/mlpack/methods/cf/cf.cpp
index 121cf13..01a4c84 100644
--- a/src/mlpack/methods/cf/cf.cpp
+++ b/src/mlpack/methods/cf/cf.cpp
@@ -95,11 +95,11 @@ void CF::GetRecommendations(const size_t numRecs,
 
     // Let's build the list of candidate recomendations for the given user.
     // Default candidate: the smallest possible value and invalid item number.
-    const Candidate def(-DBL_MAX, cleanedData.n_rows);
+    const Candidate def = std::make_pair(-DBL_MAX, cleanedData.n_rows);
     std::vector<Candidate> vect(numRecs, def);
-    typedef std::priority_queue<Candidate, std::vector<Candidate>,
-        std::greater<Candidate>> CandidateList;
-    CandidateList pqueue(std::greater<Candidate>(), std::move(vect));
+    typedef std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp>
+        CandidateList;
+    CandidateList pqueue(CandidateCmp(), std::move(vect));
 
     // Look through the averages column corresponding to the current user.
     for (size_t j = 0; j < averages.n_rows; ++j)
@@ -108,11 +108,11 @@ void CF::GetRecommendations(const size_t numRecs,
       if (cleanedData(j, users(i)) != 0.0)
         continue; // The user already rated the item.
 
-      Candidate c(averages[j], j);

       // Is the estimated value better than the worst candidate?
-      if (c > pqueue.top())
+      if (averages[j] > pqueue.top().first)
       {
+        Candidate c = std::make_pair(averages[j], j);
         pqueue.pop();
         pqueue.push(c);
       }
@@ -120,14 +120,14 @@ void CF::GetRecommendations(const size_t numRecs,
 
     for (size_t p = 1; p <= numRecs; p++)
     {
-      recommendations(numRecs - p, i) = pqueue.top().item;
-      values(numRecs - p, i) = pqueue.top().value;
+      recommendations(numRecs - p, i) = pqueue.top().second;
+      values(numRecs - p, i) = pqueue.top().first;
       pqueue.pop();
     }
 
     // If we were not able to come up with enough recommendations, issue a
     // warning.
-    if (recommendations(numRecs - 1, i) == def.item)
+    if (recommendations(numRecs - 1, i) == def.second)
       Log::Warn << "Could not provide " << numRecs << " recommendations "
           << "for user " << users(i) << " (not enough un-rated items)!"
           << std::endl;
diff --git a/src/mlpack/methods/cf/cf.hpp b/src/mlpack/methods/cf/cf.hpp
index 42b1a4b..82d624c 100644
--- a/src/mlpack/methods/cf/cf.hpp
+++ b/src/mlpack/methods/cf/cf.hpp
@@ -258,22 +258,14 @@ class CF
   //! Cleaned data matrix.
   arma::sp_mat cleanedData;
 
-  //! Candidate represents a possible recommendation.
-  struct Candidate
-  {
-    //! Value of this recommendation.
-    double value;
-    //! Item of this recommendation.
-    size_t item;
-    //! Trivial constructor.
-    Candidate(double value, size_t item) :
-        value(value),
-        item(item)
-    {};
-    //! Compare the value of two candidates.
-    friend bool operator>(const Candidate& l, const Candidate& r)
+  //! Candidate represents a possible recommendation (value, item).
+  typedef std::pair<double, size_t> Candidate;
+
+  //! Compare two candidates based on the value.
+  struct CandidateCmp {
+    bool operator()(const Candidate& c1, const Candidate& c2)
     {
-      return l.value > r.value;
+      return c1.first > c2.first;
     };
   };
 }; // class CF
diff --git a/src/mlpack/methods/fastmks/fastmks.hpp b/src/mlpack/methods/fastmks/fastmks.hpp
index 796f0db..031fe43 100644
--- a/src/mlpack/methods/fastmks/fastmks.hpp
+++ b/src/mlpack/methods/fastmks/fastmks.hpp
@@ -251,28 +251,20 @@ class FastMKS
   //! The instantiated inner-product metric induced by the given kernel.
   metric::IPMetric<KernelType> metric;
 
-  //! Candidate point from the reference set.
-  struct Candidate
-  {
-    //! Kernel value calculated between a reference point and the query point.
-    double product;
-    //! Index of the reference point.
-    size_t index;
-    //! Trivial constructor.
-    Candidate(double p, size_t i) :
-        product(p),
-        index(i)
-    {};
-    //! Compare two candidates.
-    friend bool operator>(const Candidate& l, const Candidate& r)
+  //! Candidate represents a possible candidate point (value, index).
+  typedef std::pair<double, size_t> Candidate;
+
+  //! Compare two candidates based on the value.
+  struct CandidateCmp {
+    bool operator()(const Candidate& c1, const Candidate& c2)
     {
-      return l.product > r.product;
+      return c1.first > c2.first;
     };
   };
 
   //! Use a priority queue to represent the list of candidate points.
   typedef std::priority_queue<Candidate, std::vector<Candidate>,
-      std::greater<Candidate>> CandidateList;
+      CandidateCmp> CandidateList;
 };
 
 } // namespace fastmks
diff --git a/src/mlpack/methods/fastmks/fastmks_impl.hpp b/src/mlpack/methods/fastmks/fastmks_impl.hpp
index f0f321f..993ba54 100644
--- a/src/mlpack/methods/fastmks/fastmks_impl.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_impl.hpp
@@ -223,18 +223,18 @@ void FastMKS<KernelType, MatType, TreeType>::Search(
     // Simple double loop.  Stupid, slow, but a good benchmark.
     for (size_t q = 0; q < querySet.n_cols; ++q)
     {
-      const Candidate def(-DBL_MAX, size_t() - 1);
+      const Candidate def = std::make_pair(-DBL_MAX, size_t() - 1);
       std::vector<Candidate> cList(k, def);
-      CandidateList pqueue(std::greater<Candidate>(), std::move(cList));
+      CandidateList pqueue(CandidateCmp(), std::move(cList));
 
       for (size_t r = 0; r < referenceSet->n_cols; ++r)
       {
         const double eval = metric.Kernel().Evaluate(querySet.col(q),
                                                      referenceSet->col(r));
 
-        Candidate c(eval, r);
-        if (c > pqueue.top())
+        if (eval > pqueue.top().first)
         {
+          Candidate c = std::make_pair(eval, r);
           pqueue.pop();
           pqueue.push(c);
         }
@@ -242,8 +242,8 @@ void FastMKS<KernelType, MatType, TreeType>::Search(
 
       for (size_t j = 1; j <= k; j++)
       {
-        indices(k - j, q) = pqueue.top().index;
-        kernels(k - j, q) = pqueue.top().product;
+        indices(k - j, q) = pqueue.top().second;
+        kernels(k - j, q) = pqueue.top().first;
         pqueue.pop();
       }
     }
@@ -352,9 +352,9 @@ void FastMKS<KernelType, MatType, TreeType>::Search(
     // Simple double loop.  Stupid, slow, but a good benchmark.
     for (size_t q = 0; q < referenceSet->n_cols; ++q)
     {
-      const Candidate def(-DBL_MAX, size_t() - 1);
+      const Candidate def = std::make_pair(-DBL_MAX, size_t() - 1);
       std::vector<Candidate> cList(k, def);
-      CandidateList pqueue(std::greater<Candidate>(), std::move(cList));
+      CandidateList pqueue(CandidateCmp(), std::move(cList));
 
       for (size_t r = 0; r < referenceSet->n_cols; ++r)
       {
@@ -364,9 +364,9 @@ void FastMKS<KernelType, MatType, TreeType>::Search(
         const double eval = metric.Kernel().Evaluate(referenceSet->col(q),
                                                      referenceSet->col(r));
 
-        Candidate c(eval, r);
-        if (c > pqueue.top())
+        if (eval > pqueue.top().first)
         {
+          Candidate c = std::make_pair(eval, r);
           pqueue.pop();
           pqueue.push(c);
         }
@@ -374,8 +374,8 @@ void FastMKS<KernelType, MatType, TreeType>::Search(
 
       for (size_t j = 1; j <= k; j++)
       {
-        indices(k - j, q) = pqueue.top().index;
-        kernels(k - j, q) = pqueue.top().product;
+        indices(k - j, q) = pqueue.top().second;
+        kernels(k - j, q) = pqueue.top().first;
         pqueue.pop();
       }
     }
diff --git a/src/mlpack/methods/fastmks/fastmks_rules.hpp b/src/mlpack/methods/fastmks/fastmks_rules.hpp
index 11b9604..9aca42d 100644
--- a/src/mlpack/methods/fastmks/fastmks_rules.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_rules.hpp
@@ -124,22 +124,14 @@ class FastMKSRules
   //! The query dataset.
   const typename TreeType::Mat& querySet;
 
-  //! Candidate point from the reference set.
-  struct Candidate
-  {
-    //! Kernel value calculated between a reference point and the query point.
-    double product;
-    //! Index of the reference point.
-    size_t index;
-    //! Trivial constructor.
-    Candidate(double p, size_t i) :
-        product(p),
-        index(i)
-    {};
-    //! Compare two candidates.
-    friend bool operator>(const Candidate& l, const Candidate& r)
+  //! Candidate represents a possible candidate point (value, index).
+  typedef std::pair<double, size_t> Candidate;
+
+  //! Compare two candidates based on the value.
+  struct CandidateCmp {
+    bool operator()(const Candidate& c1, const Candidate& c2)
     {
-      return l.product > r.product;
+      return c1.first > c2.first;
     };
   };
 
diff --git a/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp b/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp
index a5cf681..6efc9a7 100644
--- a/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp
@@ -50,7 +50,7 @@ FastMKSRules<KernelType, TreeType>::FastMKSRules(
   // It will be initialized with k candidates: (-DBL_MAX, size_t() - 1)
   // The list of candidates will be updated when visiting new points with the
   // BaseCase() method.
-  const Candidate def(-DBL_MAX, size_t() - 1);
+  const Candidate def = std::make_pair(-DBL_MAX, size_t() - 1);
 
   CandidateList cList(k, def);
   std::vector<CandidateList> tmp(querySet.n_cols, cList);
@@ -68,16 +68,15 @@ void FastMKSRules<KernelType, TreeType>::GetResults(
   for (size_t i = 0; i < querySet.n_cols; i++)
   {
     CandidateList& pqueue = candidates[i];
-    std::greater<Candidate> greater;
     typedef typename CandidateList::iterator Iterator;
 
     for (Iterator end = pqueue.end(); end != pqueue.begin(); --end)
-      std::pop_heap(pqueue.begin(), end, greater);
+      std::pop_heap(pqueue.begin(), end, CandidateCmp());
 
     for (size_t j = 0; j < k; j++)
     {
-      indices(j, i) = pqueue[j].index;
-      products(j, i) = pqueue[j].product;
+      indices(j, i) = pqueue[j].second;
+      products(j, i) = pqueue[j].first;
     }
   }
 }
@@ -127,7 +126,7 @@ double FastMKSRules<KernelType, TreeType>::Score(const size_t queryIndex,
                                                  TreeType& referenceNode)
 {
   // Compare with the current best.
-  const double bestKernel = candidates[queryIndex].front().product;
+  const double bestKernel = candidates[queryIndex].front().first;
 
   // See if we can perform a parent-child prune.
   const double furthestDist = referenceNode.FurthestDescendantDistance();
@@ -410,7 +409,7 @@ double FastMKSRules<KernelType, TreeType>::Rescore(const size_t queryIndex,
                                                    TreeType& /*referenceNode*/,
                                                    const double oldScore) const
 {
-  const double bestKernel = candidates[queryIndex].front().product;
+  const double bestKernel = candidates[queryIndex].front().first;
 
   return ((1.0 / oldScore) >= bestKernel) ? oldScore : DBL_MAX;
 }
@@ -458,10 +457,10 @@ double FastMKSRules<KernelType, TreeType>::CalculateBound(TreeType& queryNode)
   {
     const size_t point = queryNode.Point(i);
     const CandidateList& candidatesPoints = candidates[point];
-    if (candidatesPoints.front().product < worstPointKernel)
-      worstPointKernel = candidatesPoints.front().product;
+    if (candidatesPoints.front().first < worstPointKernel)
+      worstPointKernel = candidatesPoints.front().first;
 
-    if (candidatesPoints.front().product == -DBL_MAX)
+    if (candidatesPoints.front().first == -DBL_MAX)
       continue; // Avoid underflow.
 
     // This should be (queryDescendantDistance + centroidDistance) for any tree
@@ -478,8 +477,8 @@ double FastMKSRules<KernelType, TreeType>::CalculateBound(TreeType& queryNode)
     double worstPointCandidateKernel = DBL_MAX;
     for (size_t j = 0; j < candidatesPoints.size(); ++j)
     {
-      const double candidateKernel = candidatesPoints[j].product -
-          queryDescendantDistance * referenceKernels[candidatesPoints[j].index];
+      const double candidateKernel = candidatesPoints[j].first -
+          queryDescendantDistance * referenceKernels[candidatesPoints[j].second];
       if (candidateKernel < worstPointCandidateKernel)
         worstPointCandidateKernel = candidateKernel;
     }
@@ -526,13 +525,13 @@ inline void FastMKSRules<KernelType, TreeType>::InsertNeighbor(
     const size_t index,
     const double product)
 {
-  Candidate c(product, index);
   CandidateList& pqueue = candidates[queryIndex];
-  if (c > pqueue.front())
+  if (product > pqueue.front().first)
   {
-    std::pop_heap(pqueue.begin(), pqueue.end(), std::greater<Candidate>());
+    Candidate c = std::make_pair(product, index);
+    std::pop_heap(pqueue.begin(), pqueue.end(), CandidateCmp());
     pqueue.back() = c;
-    std::push_heap(pqueue.begin(), pqueue.end(), std::greater<Candidate>());
+    std::push_heap(pqueue.begin(), pqueue.end(), CandidateCmp());
   }
 }
 
diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index 45284ba..c622fbb 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -428,28 +428,21 @@ class LSHSearch
   //! The number of distance evaluations.
   size_t distanceEvaluations;
 
-  //! Candidate represents a possible candidate neighbor (from the reference
-  // set).
-  struct Candidate
-  {
-    //! Distance between the reference point and the query point.
-    double dist;
-    //! Index of the reference point.
-    size_t index;
-    //! Trivial constructor.
-    Candidate(double d, size_t i) :
-        dist(d),
-        index(i)
-    {};
-    //! Compare the distance of two candidates.
-    friend bool operator<(const Candidate& l, const Candidate& r)
+  //! Candidate represents a possible candidate neighbor (distance, index).
+  typedef std::pair<double, size_t> Candidate;
+
+  //! Compare two candidates based on the distance.
+  struct CandidateCmp {
+    bool operator()(const Candidate& c1, const Candidate& c2)
     {
-      return !SortPolicy::IsBetter(r.dist, l.dist);
+      return !SortPolicy::IsBetter(c2.first, c1.first);
     };
   };
 
   //! Use a priority queue to represent the list of candidate neighbors.
-  typedef std::priority_queue<Candidate> CandidateList;
+  typedef std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp>
+      CandidateList;
+
 }; // class LSHSearch
 
 } // namespace neighbor
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index d0b53ae..bdef1be 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -275,9 +275,10 @@ void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
   // Let's build the list of candidate neighbors for the given query point.
   // It will be initialized with k candidates:
   // (WorstDistance, referenceSet->n_cols)
-  const Candidate def(SortPolicy::WorstDistance(), referenceSet->n_cols);
+  const Candidate def = std::make_pair(SortPolicy::WorstDistance(),
+      referenceSet->n_cols);
   std::vector<Candidate> vect(k, def);
-  CandidateList pqueue(std::less<Candidate>(), std::move(vect));
+  CandidateList pqueue(CandidateCmp(), std::move(vect));
 
   for (size_t j = 0; j < referenceIndices.n_elem; ++j)
   {
@@ -290,9 +291,9 @@ void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
         referenceSet->unsafe_col(queryIndex),
         referenceSet->unsafe_col(referenceIndex));
 
-    Candidate c(distance, referenceIndex);
+    Candidate c = std::make_pair(distance, referenceIndex);
     // If this distance is better than the worst candidate, let's insert it.
-    if (c < pqueue.top())
+    if (CandidateCmp()(c, pqueue.top()))
     {
       pqueue.pop();
       pqueue.push(c);
@@ -301,8 +302,8 @@ void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
 
   for (size_t j = 1; j <= k; j++)
   {
-    neighbors(k - j, queryIndex) = pqueue.top().index;
-    distances(k - j, queryIndex) = pqueue.top().dist;
+    neighbors(k - j, queryIndex) = pqueue.top().second;
+    distances(k - j, queryIndex) = pqueue.top().first;
     pqueue.pop();
   }
 }
@@ -320,9 +321,10 @@ void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
   // Let's build the list of candidate neighbors for the given query point.
   // It will be initialized with k candidates:
   // (WorstDistance, referenceSet->n_cols)
-  const Candidate def(SortPolicy::WorstDistance(), referenceSet->n_cols);
+  const Candidate def = std::make_pair(SortPolicy::WorstDistance(),
+      referenceSet->n_cols);
   std::vector<Candidate> vect(k, def);
-  CandidateList pqueue(std::less<Candidate>(), std::move(vect));
+  CandidateList pqueue(CandidateCmp(), std::move(vect));
 
   for (size_t j = 0; j < referenceIndices.n_elem; ++j)
   {
@@ -331,9 +333,9 @@ void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
         querySet.unsafe_col(queryIndex),
         referenceSet->unsafe_col(referenceIndex));
 
-    Candidate c(distance, referenceIndex);
+    Candidate c = std::make_pair(distance, referenceIndex);
     // If this distance is better than the worst candidate, let's insert it.
-    if (c < pqueue.top())
+    if (CandidateCmp()(c, pqueue.top()))
     {
       pqueue.pop();
       pqueue.push(c);
@@ -342,8 +344,8 @@ void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
 
   for (size_t j = 1; j <= k; j++)
   {
-    neighbors(k - j, queryIndex) = pqueue.top().index;
-    distances(k - j, queryIndex) = pqueue.top().dist;
+    neighbors(k - j, queryIndex) = pqueue.top().second;
+    distances(k - j, queryIndex) = pqueue.top().first;
     pqueue.pop();
   }
 }
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
index 25e7175..a44d06a 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
@@ -143,28 +143,20 @@ class NeighborSearchRules
   //! The query set.
   const typename TreeType::Mat& querySet;
 
-  //! Candidate represents a possible candidate neighbor (from the reference
-  // set).
-  struct Candidate
-  {
-    //! Distance between the reference point and the query point.
-    double dist;
-    //! Index of the reference point.
-    size_t index;
-    //! Trivial constructor.
-    Candidate(double d, size_t i) :
-        dist(d),
-        index(i)
-    {};
-    //! Compare the distance of two candidates.
-    friend bool operator<(const Candidate& l, const Candidate& r)
+  //! Candidate represents a possible candidate neighbor (distance, index).
+  typedef std::pair<double, size_t> Candidate;
+
+  //! Compare two candidates based on the distance.
+  struct CandidateCmp {
+    bool operator()(const Candidate& c1, const Candidate& c2)
     {
-      return !SortPolicy::IsBetter(r.dist, l.dist);
+      return !SortPolicy::IsBetter(c2.first, c1.first);
     };
   };
 
   //! Use a priority queue to represent the list of candidate neighbors.
-  typedef std::priority_queue<Candidate> CandidateList;
+  typedef std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp>
+      CandidateList;
 
   //! Set of candidate neighbors for each point.
   std::vector<CandidateList> candidates;
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
index 65d258e..e40d09e 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
@@ -42,10 +42,11 @@ NeighborSearchRules<SortPolicy, MetricType, TreeType>::NeighborSearchRules(
   // It will be initialized with k candidates: (WorstDistance, size_t() - 1)
   // The list of candidates will be updated when visiting new points with the
   // BaseCase() method.
-  const Candidate def(SortPolicy::WorstDistance(), size_t() - 1);
+  const Candidate def = std::make_pair(SortPolicy::WorstDistance(),
+      size_t() - 1);
 
   std::vector<Candidate> vect(k, def);
-  CandidateList pqueue(std::less<Candidate>(), std::move(vect));
+  CandidateList pqueue(CandidateCmp(), std::move(vect));
 
   candidates.reserve(querySet.n_cols);
   for (size_t i = 0; i < querySet.n_cols; i++)
@@ -65,8 +66,8 @@ void NeighborSearchRules<SortPolicy, MetricType, TreeType>::GetResults(
     CandidateList& pqueue = candidates[i];
     for (size_t j = 1; j <= k; j++)
     {
-      neighbors(k - j, i) = pqueue.top().index;
-      distances(k - j, i) = pqueue.top().dist;
+      neighbors(k - j, i) = pqueue.top().second;
+      distances(k - j, i) = pqueue.top().first;
       pqueue.pop();
     }
   }
@@ -136,7 +137,7 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
   }
 
   // Compare against the best k'th distance for this query point so far.
-  double bestDistance = candidates[queryIndex].top().dist;
+  double bestDistance = candidates[queryIndex].top().first;
   bestDistance = SortPolicy::Relax(bestDistance, epsilon);
 
   return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
@@ -153,7 +154,7 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
     return oldScore;
 
   // Just check the score again against the distances.
-  double bestDistance = candidates[queryIndex].top().dist;
+  double bestDistance = candidates[queryIndex].top().first;
   bestDistance = SortPolicy::Relax(bestDistance, epsilon);
 
   return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
@@ -376,7 +377,7 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
   // Loop over points held in the node.
   for (size_t i = 0; i < queryNode.NumPoints(); ++i)
   {
-    const double distance = candidates[queryNode.Point(i)].top().dist;
+    const double distance = candidates[queryNode.Point(i)].top().first;
     if (SortPolicy::IsBetter(worstDistance, distance))
       worstDistance = distance;
     if (SortPolicy::IsBetter(distance, bestPointDistance))
@@ -467,9 +468,10 @@ InsertNeighbor(
     const size_t neighbor,
     const double distance)
 {
-  Candidate c(distance, neighbor);
   CandidateList& pqueue = candidates[queryIndex];
-  if (c < pqueue.top())
+  Candidate c = std::make_pair(distance, neighbor);
+
+  if (CandidateCmp()(c, pqueue.top()))
   {
     pqueue.pop();
     pqueue.push(c);
diff --git a/src/mlpack/methods/rann/ra_search_rules.hpp b/src/mlpack/methods/rann/ra_search_rules.hpp
index c7edeba..93b7a8c 100644
--- a/src/mlpack/methods/rann/ra_search_rules.hpp
+++ b/src/mlpack/methods/rann/ra_search_rules.hpp
@@ -243,28 +243,20 @@ class RASearchRules
   //! The query set.
   const arma::mat& querySet;
 
-  //! Candidate represents a possible candidate neighbor (from the reference
-  // set).
-  struct Candidate
-  {
-    //! Distance between the reference point and the query point.
-    double dist;
-    //! Index of the reference point.
-    size_t index;
-    //! Trivial constructor.
-    Candidate(double d, size_t i) :
-        dist(d),
-        index(i)
-    {};
-    //! Compare the distance of two candidates.
-    friend bool operator<(const Candidate& l, const Candidate& r)
+  //! Candidate represents a possible candidate neighbor (distance, index).
+  typedef std::pair<double, size_t> Candidate;
+
+  //! Compare two candidates based on the distance.
+  struct CandidateCmp {
+    bool operator()(const Candidate& c1, const Candidate& c2)
     {
-      return !SortPolicy::IsBetter(r.dist, l.dist);
+      return !SortPolicy::IsBetter(c2.first, c1.first);
     };
   };
 
   //! Use a priority queue to represent the list of candidate neighbors.
-  typedef std::priority_queue<Candidate> CandidateList;
+  typedef std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp>
+      CandidateList;
 
   //! Set of candidate neighbors for each point.
   std::vector<CandidateList> candidates;
diff --git a/src/mlpack/methods/rann/ra_search_rules_impl.hpp b/src/mlpack/methods/rann/ra_search_rules_impl.hpp
index bad3e24..9a886db 100644
--- a/src/mlpack/methods/rann/ra_search_rules_impl.hpp
+++ b/src/mlpack/methods/rann/ra_search_rules_impl.hpp
@@ -69,10 +69,11 @@ RASearchRules(const arma::mat& referenceSet,
   // It will be initialized with k candidates: (WorstDistance, size_t() - 1)
   // The list of candidates will be updated when visiting new points with the
   // BaseCase() method.
-  const Candidate def(SortPolicy::WorstDistance(), size_t() - 1);
+  const Candidate def = std::make_pair(SortPolicy::WorstDistance(),
+      size_t() - 1);
 
   std::vector<Candidate> vect(k, def);
-  CandidateList pqueue(std::less<Candidate>(), std::move(vect));
+  CandidateList pqueue(CandidateCmp(), std::move(vect));
 
   candidates.reserve(querySet.n_cols);
   for (size_t i = 0; i < querySet.n_cols; i++)
@@ -104,8 +105,8 @@ void RASearchRules<SortPolicy, MetricType, TreeType>::GetResults(
     CandidateList& pqueue = candidates[i];
     for (size_t j = 1; j <= k; j++)
     {
-      neighbors(k - j, i) = pqueue.top().index;
-      distances(k - j, i) = pqueue.top().dist;
+      neighbors(k - j, i) = pqueue.top().second;
+      distances(k - j, i) = pqueue.top().first;
       pqueue.pop();
     }
   }
@@ -143,7 +144,7 @@ inline double RASearchRules<SortPolicy, MetricType, TreeType>::Score(
   const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
   const double distance = SortPolicy::BestPointToNodeDistance(queryPoint,
       &referenceNode);
-  const double bestDistance = candidates[queryIndex].top().dist;
+  const double bestDistance = candidates[queryIndex].top().first;
 
   return Score(queryIndex, referenceNode, distance, bestDistance);
 }
@@ -157,7 +158,7 @@ inline double RASearchRules<SortPolicy, MetricType, TreeType>::Score(
   const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
   const double distance = SortPolicy::BestPointToNodeDistance(queryPoint,
       &referenceNode, baseCaseResult);
-  const double bestDistance = candidates[queryIndex].top().dist;
+  const double bestDistance = candidates[queryIndex].top().first;
 
   return Score(queryIndex, referenceNode, distance, bestDistance);
 }
@@ -271,7 +272,7 @@ Rescore(const size_t queryIndex,
     return oldScore;
 
   // Just check the score again against the distances.
-  const double bestDistance = candidates[queryIndex].top().dist;
+  const double bestDistance = candidates[queryIndex].top().first;
 
   // If this is better than the best distance we've seen so far,
   // maybe there will be something down this node.
@@ -371,7 +372,7 @@ inline double RASearchRules<SortPolicy, MetricType, TreeType>::Score(
 
   for (size_t i = 0; i < queryNode.NumPoints(); i++)
   {
-    const double bound = candidates[queryNode.Point(i)].top().dist
+    const double bound = candidates[queryNode.Point(i)].top().first
         + maxDescendantDistance;
     if (bound < pointBound)
       pointBound = bound;
@@ -410,7 +411,7 @@ inline double RASearchRules<SortPolicy, MetricType, TreeType>::Score(
 
   for (size_t i = 0; i < queryNode.NumPoints(); i++)
   {
-    const double bound = candidates[queryNode.Point(i)].top().dist
+    const double bound = candidates[queryNode.Point(i)].top().first
         + maxDescendantDistance;
     if (bound < pointBound)
       pointBound = bound;
@@ -624,7 +625,7 @@ Rescore(TreeType& queryNode,
 
   for (size_t i = 0; i < queryNode.NumPoints(); i++)
   {
-    const double bound = candidates[queryNode.Point(i)].top().dist
+    const double bound = candidates[queryNode.Point(i)].top().first
         + maxDescendantDistance;
     if (bound < pointBound)
       pointBound = bound;
@@ -809,9 +810,10 @@ InsertNeighbor(
     const size_t neighbor,
     const double distance)
 {
-  Candidate c(distance, neighbor);
   CandidateList& pqueue = candidates[queryIndex];
-  if (c < pqueue.top())
+  Candidate c = std::make_pair(distance, neighbor);
+
+  if (CandidateCmp()(c, pqueue.top()))
   {
     pqueue.pop();
     pqueue.push(c);




More information about the mlpack-git mailing list