[mlpack-git] master: Use a priority queue (heap) to store the list of candidates while searching. This makes the code more efficient, especially when k is large. For example, for knn, given a list of k candidate neighbors, we need two fast operations: - find the furthest of them. - insert a new candidate. This is the appropriate situation for using a heap. (198cec8)

gitdub at mlpack.org gitdub at mlpack.org
Tue Jul 26 21:22:08 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/ef51b032f275266f781d42b9bd0aa50aa26a3077...8522b04c3d9a82fb7e964bafd72e70f0cd30bf4b

>---------------------------------------------------------------

commit 198cec80a434b3d88a993e3d67a18b778ebc07f1
Author: MarcosPividori <marcos.pividori at gmail.com>
Date:   Thu Jul 21 10:52:06 2016 -0300

    Use a priority queue (heap) to store the list of candidates while searching.
    This makes the code more efficient, especially when k is large.
    For example, for knn, given a list of k candidate neighbors, we need
    two fast operations:
     - find the furthest of them.
     - insert a new candidate.
    This is the appropriate situation for using a heap.


>---------------------------------------------------------------

198cec80a434b3d88a993e3d67a18b778ebc07f1
 .../neighbor_search/neighbor_search_impl.hpp       | 26 ++++---
 .../neighbor_search/neighbor_search_rules.hpp      | 53 +++++++++++---
 .../neighbor_search/neighbor_search_rules_impl.hpp | 81 +++++++++++++---------
 3 files changed, 105 insertions(+), 55 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index 73560e2..8d0c694 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -376,8 +376,7 @@ Search(const MatType& querySet,
   if (naive)
   {
     // Create the helper object for the tree traversal.
-    RuleType rules(*referenceSet, querySet, *neighborPtr, *distancePtr, metric,
-        epsilon);
+    RuleType rules(*referenceSet, querySet, k, metric, epsilon);
 
     // The naive brute-force traversal.
     for (size_t i = 0; i < querySet.n_cols; ++i)
@@ -385,12 +384,13 @@ Search(const MatType& querySet,
         rules.BaseCase(i, j);
 
     baseCases += querySet.n_cols * referenceSet->n_cols;
+
+    rules.GetResults(*neighborPtr, *distancePtr);
   }
   else if (singleMode)
   {
     // Create the helper object for the tree traversal.
-    RuleType rules(*referenceSet, querySet, *neighborPtr, *distancePtr, metric,
-        epsilon);
+    RuleType rules(*referenceSet, querySet, k, metric, epsilon);
 
     // Create the traverser.
     typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
@@ -404,6 +404,8 @@ Search(const MatType& querySet,
 
     Log::Info << rules.Scores() << " node combinations were scored.\n";
     Log::Info << rules.BaseCases() << " base cases were calculated.\n";
+
+    rules.GetResults(*neighborPtr, *distancePtr);
   }
   else // Dual-tree recursion.
   {
@@ -415,8 +417,7 @@ Search(const MatType& querySet,
     Timer::Start("computing_neighbors");
 
     // Create the helper object for the tree traversal.
-    RuleType rules(*referenceSet, queryTree->Dataset(), *neighborPtr,
-        *distancePtr, metric, epsilon);
+    RuleType rules(*referenceSet, queryTree->Dataset(), k, metric, epsilon);
 
     // Create the traverser.
     TraversalType<RuleType> traverser(rules);
@@ -429,6 +430,8 @@ Search(const MatType& querySet,
     Log::Info << rules.Scores() << " node combinations were scored.\n";
     Log::Info << rules.BaseCases() << " base cases were calculated.\n";
 
+    rules.GetResults(*neighborPtr, *distancePtr);
+
     delete queryTree;
   }
 
@@ -541,8 +544,7 @@ Search(Tree* queryTree,
 
   // Create the helper object for the traversal.
   typedef NeighborSearchRules<SortPolicy, MetricType, Tree> RuleType;
-  RuleType rules(*referenceSet, querySet, *neighborPtr, distances, metric,
-      epsilon);
+  RuleType rules(*referenceSet, querySet, k, metric, epsilon);
 
   // Create the traverser.
   TraversalType<RuleType> traverser(rules);
@@ -551,6 +553,8 @@ Search(Tree* queryTree,
   scores += rules.Scores();
   baseCases += rules.BaseCases();
 
+  rules.GetResults(*neighborPtr, distances);
+
   Timer::Stop("computing_neighbors");
 
   // Do we need to map indices?
@@ -612,8 +616,8 @@ Search(const size_t k,
 
   // Create the helper object for the traversal.
   typedef NeighborSearchRules<SortPolicy, MetricType, Tree> RuleType;
-  RuleType rules(*referenceSet, *referenceSet, *neighborPtr, *distancePtr,
-      metric, epsilon, true /* don't return the same point as nearest neighbor */);
+  RuleType rules(*referenceSet, *referenceSet, k, metric, epsilon,
+      true /* don't return the same point as nearest neighbor */);
 
   if (naive)
   {
@@ -676,6 +680,8 @@ Search(const size_t k,
     treeNeedsReset = true;
   }
 
+  rules.GetResults(*neighborPtr, *distancePtr);
+
   Timer::Stop("computing_neighbors");
 
   // Do we need to map the reference indices?
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
index 47a7933..0bcdc49 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
@@ -9,6 +9,8 @@
 #define MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
 
 #include <mlpack/core/tree/traversal_info.hpp>
+#include <vector>
+#include <queue>
 
 namespace mlpack {
 namespace neighbor {
@@ -19,11 +21,20 @@ class NeighborSearchRules
  public:
   NeighborSearchRules(const typename TreeType::Mat& referenceSet,
                       const typename TreeType::Mat& querySet,
-                      arma::Mat<size_t>& neighbors,
-                      arma::mat& distances,
+                      const size_t k,
                       MetricType& metric,
                       const double epsilon = 0,
                       const bool sameSet = false);
+
+  /**
+   * Store the list of candidates for each query point in the given matrices.
+   *
+   * @param neighbors Matrix storing lists of neighbors for each query point.
+   * @param distances Matrix storing distances of neighbors for each query
+   *     point.
+   */
+  void GetResults(arma::Mat<size_t>& neighbors, arma::mat& distances);
+
   /**
    * Get the distance from the query point to the reference point.
    * This will update the "neighbor" matrix with the new point if appropriate
@@ -109,11 +120,34 @@ class NeighborSearchRules
   //! The query set.
   const typename TreeType::Mat& querySet;
 
-  //! The matrix the resultant neighbor indices should be stored in.
-  arma::Mat<size_t>& neighbors;
-
-  //! The matrix the resultant neighbor distances should be stored in.
-  arma::mat& distances;
+  //! Candidate represents a possible candidate neighbor (from the reference
+  // set).
+  struct Candidate
+  {
+    //! Distance between the reference point and the query point.
+    double dist;
+    //! Index of the reference point.
+    size_t index;
+    //! Trivial constructor.
+    Candidate(double d, size_t i) :
+        dist(d),
+        index(i)
+    {};
+    //! Compare the distance of two candidates.
+    friend bool operator<(const Candidate& l, const Candidate& r)
+    {
+      return !SortPolicy::IsBetter(r.dist, l.dist);
+    };
+  };
+
+  //! Use a priority queue to represent the list of candidate neighbors.
+  typedef std::priority_queue<Candidate> CandidateList;
+
+  //! Set of candidate neighbors for each point.
+  std::vector<CandidateList> candidates;
+
+  //! Number of neighbors to search for.
+  const size_t k;
 
   //! The instantiated metric.
   MetricType& metric;
@@ -146,16 +180,13 @@ class NeighborSearchRules
   double CalculateBound(TreeType& queryNode) const;
 
   /**
-   * Insert a point into the neighbors and distances matrices; this is a helper
-   * function.
+   * Helper function to insert a point into the list of candidate points.
    *
    * @param queryIndex Index of point whose neighbors we are inserting into.
-   * @param pos Position in list to insert into.
    * @param neighbor Index of reference point which is being inserted.
    * @param distance Distance from query point to reference point.
    */
   void InsertNeighbor(const size_t queryIndex,
-                      const size_t pos,
                       const size_t neighbor,
                       const double distance);
 };
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
index 24f9485..65d258e 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
@@ -17,15 +17,13 @@ template<typename SortPolicy, typename MetricType, typename TreeType>
 NeighborSearchRules<SortPolicy, MetricType, TreeType>::NeighborSearchRules(
     const typename TreeType::Mat& referenceSet,
     const typename TreeType::Mat& querySet,
-    arma::Mat<size_t>& neighbors,
-    arma::mat& distances,
+    const size_t k,
     MetricType& metric,
     const double epsilon,
     const bool sameSet) :
     referenceSet(referenceSet),
     querySet(querySet),
-    neighbors(neighbors),
-    distances(distances),
+    k(k),
     metric(metric),
     sameSet(sameSet),
     epsilon(epsilon),
@@ -39,9 +37,42 @@ NeighborSearchRules<SortPolicy, MetricType, TreeType>::NeighborSearchRules(
   // use the this pointer.
   traversalInfo.LastQueryNode() = (TreeType*) this;
   traversalInfo.LastReferenceNode() = (TreeType*) this;
+
+  // Let's build the list of candidate neighbors for each query point.
+  // It will be initialized with k candidates: (WorstDistance, size_t() - 1)
+  // The list of candidates will be updated when visiting new points with the
+  // BaseCase() method.
+  const Candidate def(SortPolicy::WorstDistance(), size_t() - 1);
+
+  std::vector<Candidate> vect(k, def);
+  CandidateList pqueue(std::less<Candidate>(), std::move(vect));
+
+  candidates.reserve(querySet.n_cols);
+  for (size_t i = 0; i < querySet.n_cols; i++)
+    candidates.push_back(pqueue);
 }
 
 template<typename SortPolicy, typename MetricType, typename TreeType>
+void NeighborSearchRules<SortPolicy, MetricType, TreeType>::GetResults(
+    arma::Mat<size_t>& neighbors,
+    arma::mat& distances)
+{
+  neighbors.set_size(k, querySet.n_cols);
+  distances.set_size(k, querySet.n_cols);
+
+  for (size_t i = 0; i < querySet.n_cols; i++)
+  {
+    CandidateList& pqueue = candidates[i];
+    for (size_t j = 1; j <= k; j++)
+    {
+      neighbors(k - j, i) = pqueue.top().index;
+      distances(k - j, i) = pqueue.top().dist;
+      pqueue.pop();
+    }
+  }
+};
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
 inline force_inline // Absolutely MUST be inline so optimizations can happen.
 double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
 BaseCase(const size_t queryIndex, const size_t referenceIndex)
@@ -59,16 +90,7 @@ BaseCase(const size_t queryIndex, const size_t referenceIndex)
                                     referenceSet.col(referenceIndex));
   ++baseCases;
 
-  // If this distance is better than any of the current candidates, the
-  // SortDistance() function will give us the position to insert it into.
-  arma::vec queryDist = distances.unsafe_col(queryIndex);
-  arma::Col<size_t> queryIndices = neighbors.unsafe_col(queryIndex);
-  const size_t insertPosition = SortPolicy::SortDistance(queryDist,
-      queryIndices, distance);
-
-  // SortDistance() returns (size_t() - 1) if we shouldn't add it.
-  if (insertPosition != (size_t() - 1))
-    InsertNeighbor(queryIndex, insertPosition, referenceIndex, distance);
+  InsertNeighbor(queryIndex, referenceIndex, distance);
 
   // Cache this information for the next time BaseCase() is called.
   lastQueryIndex = queryIndex;
@@ -114,7 +136,7 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
   }
 
   // Compare against the best k'th distance for this query point so far.
-  double bestDistance = distances(distances.n_rows - 1, queryIndex);
+  double bestDistance = candidates[queryIndex].top().dist;
   bestDistance = SortPolicy::Relax(bestDistance, epsilon);
 
   return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
@@ -131,7 +153,7 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
     return oldScore;
 
   // Just check the score again against the distances.
-  double bestDistance = distances(distances.n_rows - 1, queryIndex);
+  double bestDistance = candidates[queryIndex].top().dist;
   bestDistance = SortPolicy::Relax(bestDistance, epsilon);
 
   return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
@@ -354,7 +376,7 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
   // Loop over points held in the node.
   for (size_t i = 0; i < queryNode.NumPoints(); ++i)
   {
-    const double distance = distances(distances.n_rows - 1, queryNode.Point(i));
+    const double distance = candidates[queryNode.Point(i)].top().dist;
     if (SortPolicy::IsBetter(worstDistance, distance))
       worstDistance = distance;
     if (SortPolicy::IsBetter(distance, bestPointDistance))
@@ -432,35 +454,26 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
 }
 
 /**
- * Helper function to insert a point into the neighbors and distances matrices.
+ * Helper function to insert a point into the list of candidate points.
  *
  * @param queryIndex Index of point whose neighbors we are inserting into.
- * @param pos Position in list to insert into.
  * @param neighbor Index of reference point which is being inserted.
  * @param distance Distance from query point to reference point.
  */
 template<typename SortPolicy, typename MetricType, typename TreeType>
-void NeighborSearchRules<SortPolicy, MetricType, TreeType>::InsertNeighbor(
+inline void NeighborSearchRules<SortPolicy, MetricType, TreeType>::
+InsertNeighbor(
     const size_t queryIndex,
-    const size_t pos,
     const size_t neighbor,
     const double distance)
 {
-  // We only memmove() if there is actually a need to shift something.
-  if (pos < (distances.n_rows - 1))
+  Candidate c(distance, neighbor);
+  CandidateList& pqueue = candidates[queryIndex];
+  if (c < pqueue.top())
   {
-    int len = (distances.n_rows - 1) - pos;
-    memmove(distances.colptr(queryIndex) + (pos + 1),
-        distances.colptr(queryIndex) + pos,
-        sizeof(double) * len);
-    memmove(neighbors.colptr(queryIndex) + (pos + 1),
-        neighbors.colptr(queryIndex) + pos,
-        sizeof(size_t) * len);
+    pqueue.pop();
+    pqueue.push(c);
   }
-
-  // Now put the new information in the right index.
-  distances(pos, queryIndex) = distance;
-  neighbors(pos, queryIndex) = neighbor;
 }
 
 } // namespace neighbor




More information about the mlpack-git mailing list