[mlpack-git] master: Use a priority queue (heap) to store the list of candidates while searching. (31aeb5e)

gitdub at mlpack.org gitdub at mlpack.org
Tue Jul 26 21:22:08 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/ef51b032f275266f781d42b9bd0aa50aa26a3077...8522b04c3d9a82fb7e964bafd72e70f0cd30bf4b

>---------------------------------------------------------------

commit 31aeb5e813413505150a14be5c08979920c0d261
Author: MarcosPividori <marcos.pividori at gmail.com>
Date:   Thu Jul 21 12:52:40 2016 -0300

    Use a priority queue (heap) to store the list of candidates while searching.


>---------------------------------------------------------------

31aeb5e813413505150a14be5c08979920c0d261
 src/mlpack/methods/rann/ra_search_impl.hpp       | 35 +++++----
 src/mlpack/methods/rann/ra_search_rules.hpp      | 52 +++++++++++---
 src/mlpack/methods/rann/ra_search_rules_impl.hpp | 90 ++++++++++++++----------
 3 files changed, 112 insertions(+), 65 deletions(-)

diff --git a/src/mlpack/methods/rann/ra_search_impl.hpp b/src/mlpack/methods/rann/ra_search_impl.hpp
index aa8daa5..16360b5 100644
--- a/src/mlpack/methods/rann/ra_search_impl.hpp
+++ b/src/mlpack/methods/rann/ra_search_impl.hpp
@@ -360,9 +360,8 @@ Search(const MatType& querySet,
 
   if (naive)
   {
-    RuleType rules(*referenceSet, querySet, *neighborPtr, *distancePtr, metric,
-                   tau, alpha, naive, sampleAtLeaves, firstLeafExact,
-                   singleSampleLimit, false);
+    RuleType rules(*referenceSet, querySet, k, metric, tau, alpha, naive,
+        sampleAtLeaves, firstLeafExact, singleSampleLimit, false);
 
     // Find how many samples from the reference set we need and sample uniformly
     // from the reference set without replacement.
@@ -377,12 +376,13 @@ Search(const MatType& querySet,
     for (size_t i = 0; i < querySet.n_cols; ++i)
       for (size_t j = 0; j < distinctSamples.n_elem; ++j)
         rules.BaseCase(i, (size_t) distinctSamples[j]);
+
+    rules.GetResults(*neighborPtr, *distancePtr);
   }
   else if (singleMode)
   {
-    RuleType rules(*referenceSet, querySet, *neighborPtr, *distancePtr, metric,
-                   tau, alpha, naive, sampleAtLeaves, firstLeafExact,
-                   singleSampleLimit, false);
+    RuleType rules(*referenceSet, querySet, k, metric, tau, alpha, naive,
+        sampleAtLeaves, firstLeafExact, singleSampleLimit, false);
 
     // If the reference root node is a leaf, then the sampling has already been
     // done in the RASearchRules constructor.  This happens when naive = true.
@@ -402,6 +402,8 @@ Search(const MatType& querySet,
           << (rules.NumDistComputations() / querySet.n_cols) << "."
           << std::endl;
     }
+
+    rules.GetResults(*neighborPtr, *distancePtr);
   }
   else // Dual-tree recursion.
   {
@@ -415,9 +417,8 @@ Search(const MatType& querySet,
     Timer::Stop("tree_building");
     Timer::Start("computing_neighbors");
 
-    RuleType rules(*referenceSet, queryTree->Dataset(), *neighborPtr,
-                   *distancePtr, metric, tau, alpha, naive, sampleAtLeaves,
-                   firstLeafExact, singleSampleLimit, false);
+    RuleType rules(*referenceSet, queryTree->Dataset(), k, metric, tau, alpha,
+        naive, sampleAtLeaves, firstLeafExact, singleSampleLimit, false);
     typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
 
     Log::Info << "Query statistic pre-search: "
@@ -429,6 +430,8 @@ Search(const MatType& querySet,
     Log::Info << "Average number of distance calculations per query point: "
         << (rules.NumDistComputations() / querySet.n_cols) << "." << std::endl;
 
+    rules.GetResults(*neighborPtr, *distancePtr);
+
     delete queryTree;
   }
 
@@ -529,14 +532,15 @@ void RASearch<SortPolicy, MetricType, MatType, TreeType>::Search(
 
   // Create the helper object for the tree traversal.
   typedef RASearchRules<SortPolicy, MetricType, Tree> RuleType;
-  RuleType rules(*referenceSet, queryTree->Dataset(), *neighborPtr, distances,
-                 metric, tau, alpha, naive, sampleAtLeaves, firstLeafExact,
-                 singleSampleLimit, false);
+  RuleType rules(*referenceSet, queryTree->Dataset(), k, metric, tau, alpha,
+      naive, sampleAtLeaves, firstLeafExact, singleSampleLimit, false);
 
   // Create the traverser.
   typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
   traverser.Traverse(*queryTree, *referenceTree);
 
+  rules.GetResults(*neighborPtr, distances);
+
   Timer::Stop("computing_neighbors");
 
   // Do we need to map indices?
@@ -586,9 +590,8 @@ void RASearch<SortPolicy, MetricType, MatType, TreeType>::Search(
 
   // Create the helper object for the tree traversal.
   typedef RASearchRules<SortPolicy, MetricType, Tree> RuleType;
-  RuleType rules(*referenceSet, *referenceSet, *neighborPtr, *distancePtr,
-                 metric, tau, alpha, naive, sampleAtLeaves, firstLeafExact,
-                 singleSampleLimit, true /* sets are the same */);
+  RuleType rules(*referenceSet, *referenceSet, k, metric, tau, alpha, naive,
+      sampleAtLeaves, firstLeafExact, singleSampleLimit, true /* same sets */);
 
   if (naive)
   {
@@ -622,6 +625,8 @@ void RASearch<SortPolicy, MetricType, MatType, TreeType>::Search(
     traverser.Traverse(*referenceTree, *referenceTree);
   }
 
+  rules.GetResults(*neighborPtr, *distancePtr);
+
   Timer::Stop("computing_neighbors");
 
   // Do we need to map the reference indices?
diff --git a/src/mlpack/methods/rann/ra_search_rules.hpp b/src/mlpack/methods/rann/ra_search_rules.hpp
index 1037af4..b04f9bc 100644
--- a/src/mlpack/methods/rann/ra_search_rules.hpp
+++ b/src/mlpack/methods/rann/ra_search_rules.hpp
@@ -10,6 +10,8 @@
 #define MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP
 
 #include <mlpack/core/tree/traversal_info.hpp>
+#include <vector>
+#include <queue>
 
 namespace mlpack {
 namespace neighbor {
@@ -20,8 +22,7 @@ class RASearchRules
  public:
   RASearchRules(const arma::mat& referenceSet,
                 const arma::mat& querySet,
-                arma::Mat<size_t>& neighbors,
-                arma::mat& distances,
+                const size_t k,
                 MetricType& metric,
                 const double tau = 5,
                 const double alpha = 0.95,
@@ -31,6 +32,15 @@ class RASearchRules
                 const size_t singleSampleLimit = 20,
                 const bool sameSet = false);
 
+  /**
+   * Store the k candidates of each query point, best first, in the given
+   * matrices.  This empties the internal candidate lists, so call it only
+   * once, after the search has finished.
+   * @param neighbors Output: k x (number of queries) neighbor indices.
+   * @param distances Output: k x (number of queries) neighbor distances.
+   */
+  void GetResults(arma::Mat<size_t>& neighbors, arma::mat& distances);
+
   double BaseCase(const size_t queryIndex, const size_t referenceIndex);
 
   /**
@@ -197,11 +207,34 @@ class RASearchRules
   //! The query set.
   const arma::mat& querySet;
 
-  //! The matrix the resultant neighbor indices should be stored in.
-  arma::Mat<size_t>& neighbors;
-
-  //! The matrix the resultant neighbor distances should be stored in.
-  arma::mat& distances;
+  //! A possible candidate neighbor (a point from the reference set).
+  struct Candidate
+  {
+    //! Distance between the reference point and the query point.
+    double dist;
+    //! Index of the reference point.
+    size_t index;
+    //! Construct a candidate from its distance and reference index.
+    Candidate(double d, size_t i) :
+        dist(d),
+        index(i)
+    { }
+    //! Heap ordering: the candidate at top() is the worst one kept.
+    // NOTE(review): strict ordering requires IsBetter() to be non-strict (<=).
+    friend bool operator<(const Candidate& l, const Candidate& r)
+    {
+      return !SortPolicy::IsBetter(r.dist, l.dist);
+    }
+  };
+
+  //! Max-heap of candidates; top() is the current worst of the k kept.
+  typedef std::priority_queue<Candidate> CandidateList;
+
+  //! One heap of k candidate neighbors per query point.
+  std::vector<CandidateList> candidates;
+
+  //! Number of neighbors to search for.
+  const size_t k;
 
   //! The instantiated metric.
   MetricType& metric;
@@ -233,16 +266,13 @@ class RASearchRules
   TraversalInfoType traversalInfo;
 
   /**
-   * Insert a point into the neighbors and distances matrices; this is a helper
-   * function.
+   * Replace a query's worst candidate with the given point if it is better.
    *
    * @param queryIndex Index of point whose neighbors we are inserting into.
-   * @param pos Position in list to insert into.
    * @param neighbor Index of reference point which is being inserted.
    * @param distance Distance from query point to reference point.
    */
   void InsertNeighbor(const size_t queryIndex,
-                      const size_t pos,
                        const size_t neighbor,
                        const double distance);
 
diff --git a/src/mlpack/methods/rann/ra_search_rules_impl.hpp b/src/mlpack/methods/rann/ra_search_rules_impl.hpp
index 2071de1..bad3e24 100644
--- a/src/mlpack/methods/rann/ra_search_rules_impl.hpp
+++ b/src/mlpack/methods/rann/ra_search_rules_impl.hpp
@@ -17,8 +17,7 @@ template<typename SortPolicy, typename MetricType, typename TreeType>
 RASearchRules<SortPolicy, MetricType, TreeType>::
 RASearchRules(const arma::mat& referenceSet,
               const arma::mat& querySet,
-              arma::Mat<size_t>& neighbors,
-              arma::mat& distances,
+              const size_t k,
               MetricType& metric,
               const double tau,
               const double alpha,
@@ -29,8 +28,7 @@ RASearchRules(const arma::mat& referenceSet,
               const bool sameSet) :
     referenceSet(referenceSet),
     querySet(querySet),
-    neighbors(neighbors),
-    distances(distances),
+    k(k),
     metric(metric),
     sampleAtLeaves(sampleAtLeaves),
     firstLeafExact(firstLeafExact),
@@ -42,7 +40,6 @@ RASearchRules(const arma::mat& referenceSet,
 
   // The rank approximation.
   const size_t n = referenceSet.n_cols;
-  const size_t k = neighbors.n_rows;
   const size_t t = (size_t) std::ceil(tau * (double) n / 100.0);
   if (t < k)
   {
@@ -68,7 +65,20 @@ RASearchRules(const arma::mat& referenceSet,
   Log::Info << "Minimum samples required per query: " << numSamplesReqd <<
     ", sampling ratio: " << samplingRatio << std::endl;
 
-  if (naive) // No tree traversal; just do naive sampling here.
+  // Let's build the list of candidate neighbors for each query point.
+  // It will be initialized with k candidates: (WorstDistance, size_t() - 1)
+  // The list of candidates will be updated when visiting new points with the
+  // BaseCase() method.
+  const Candidate def(SortPolicy::WorstDistance(), size_t() - 1);
+
+  std::vector<Candidate> vect(k, def);
+  CandidateList pqueue(std::less<Candidate>(), std::move(vect));
+
+  candidates.reserve(querySet.n_cols);
+  for (size_t i = 0; i < querySet.n_cols; i++)
+    candidates.push_back(pqueue);
+
+  if (naive)// No tree traversal; just do naive sampling here.
   {
     // Sample enough points.
     for (size_t i = 0; i < querySet.n_cols; ++i)
@@ -82,6 +92,26 @@ RASearchRules(const arma::mat& referenceSet,
 }
 
 template<typename SortPolicy, typename MetricType, typename TreeType>
+void RASearchRules<SortPolicy, MetricType, TreeType>::GetResults(
+    arma::Mat<size_t>& neighbors, arma::mat& distances)
+{
+  // Each heap holds exactly k candidates, worst on top, so popping fills each
+  // column from the last row up (best ends in row 0); the heaps end up empty.
+  neighbors.set_size(k, querySet.n_cols);
+  distances.set_size(k, querySet.n_cols);
+  for (size_t i = 0; i < querySet.n_cols; i++)
+  {
+    CandidateList& pqueue = candidates[i];
+    for (size_t j = 1; j <= k; j++)
+    {
+      neighbors(k - j, i) = pqueue.top().index;
+      distances(k - j, i) = pqueue.top().dist;
+      pqueue.pop();
+    }
+  }
+}
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
 inline force_inline
 double RASearchRules<SortPolicy, MetricType, TreeType>::BaseCase(
     const size_t queryIndex,
@@ -95,16 +125,7 @@ double RASearchRules<SortPolicy, MetricType, TreeType>::BaseCase(
   double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
                                     referenceSet.unsafe_col(referenceIndex));
 
-  // If this distance is better than any of the current candidates, the
-  // SortDistance() function will give us the position to insert it into.
-  arma::vec queryDist = distances.unsafe_col(queryIndex);
-  arma::Col<size_t> queryIndices = neighbors.unsafe_col(queryIndex);
-  size_t insertPosition = SortPolicy::SortDistance(queryDist, queryIndices,
-      distance);
-
-  // SortDistance() returns (size_t() - 1) if we shouldn't add it.
-  if (insertPosition != (size_t() - 1))
-    InsertNeighbor(queryIndex, insertPosition, referenceIndex, distance);
+  InsertNeighbor(queryIndex, referenceIndex, distance);
 
   numSamplesMade[queryIndex]++;
 
@@ -122,7 +143,7 @@ inline double RASearchRules<SortPolicy, MetricType, TreeType>::Score(
   const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
   const double distance = SortPolicy::BestPointToNodeDistance(queryPoint,
       &referenceNode);
-  const double bestDistance = distances(distances.n_rows - 1, queryIndex);
+  const double bestDistance = candidates[queryIndex].top().dist;
 
   return Score(queryIndex, referenceNode, distance, bestDistance);
 }
@@ -136,7 +157,7 @@ inline double RASearchRules<SortPolicy, MetricType, TreeType>::Score(
   const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
   const double distance = SortPolicy::BestPointToNodeDistance(queryPoint,
       &referenceNode, baseCaseResult);
-  const double bestDistance = distances(distances.n_rows - 1, queryIndex);
+  const double bestDistance = candidates[queryIndex].top().dist;
 
   return Score(queryIndex, referenceNode, distance, bestDistance);
 }
@@ -250,7 +271,7 @@ Rescore(const size_t queryIndex,
     return oldScore;
 
   // Just check the score again against the distances.
-  const double bestDistance = distances(distances.n_rows - 1, queryIndex);
+  const double bestDistance = candidates[queryIndex].top().dist;
 
   // If this is better than the best distance we've seen so far,
   // maybe there will be something down this node.
@@ -350,7 +371,7 @@ inline double RASearchRules<SortPolicy, MetricType, TreeType>::Score(
 
   for (size_t i = 0; i < queryNode.NumPoints(); i++)
   {
-    const double bound = distances(distances.n_rows - 1, queryNode.Point(i))
+    const double bound = candidates[queryNode.Point(i)].top().dist
         + maxDescendantDistance;
     if (bound < pointBound)
       pointBound = bound;
@@ -389,7 +410,7 @@ inline double RASearchRules<SortPolicy, MetricType, TreeType>::Score(
 
   for (size_t i = 0; i < queryNode.NumPoints(); i++)
   {
-    const double bound = distances(distances.n_rows - 1, queryNode.Point(i))
+    const double bound = candidates[queryNode.Point(i)].top().dist
         + maxDescendantDistance;
     if (bound < pointBound)
       pointBound = bound;
@@ -603,7 +624,7 @@ Rescore(TreeType& queryNode,
 
   for (size_t i = 0; i < queryNode.NumPoints(); i++)
   {
-    const double bound = distances(distances.n_rows - 1, queryNode.Point(i))
+    const double bound = candidates[queryNode.Point(i)].top().dist
         + maxDescendantDistance;
     if (bound < pointBound)
       pointBound = bound;
@@ -775,35 +796,26 @@ Rescore(TreeType& queryNode,
 } // Rescore(node, node, oldScore)
 
 /**
- * Helper function to insert a point into the neighbors and distances matrices.
+ * Replace a query's worst candidate with the given point if it is better.
  *
  * @param queryIndex Index of point whose neighbors we are inserting into.
- * @param pos Position in list to insert into.
  * @param neighbor Index of reference point which is being inserted.
  * @param distance Distance from query point to reference point.
  */
 template<typename SortPolicy, typename MetricType, typename TreeType>
-void RASearchRules<SortPolicy, MetricType, TreeType>::InsertNeighbor(
+inline void RASearchRules<SortPolicy, MetricType, TreeType>::
+InsertNeighbor(
     const size_t queryIndex,
-    const size_t pos,
     const size_t neighbor,
     const double distance)
 {
-  // We only memmove() if there is actually a need to shift something.
-  if (pos < (distances.n_rows - 1))
+  Candidate c(distance, neighbor);
+  CandidateList& pqueue = candidates[queryIndex];
+  if (c < pqueue.top()) // Better than the current worst (the heap top)?
   {
-    int len = (distances.n_rows - 1) - pos;
-    memmove(distances.colptr(queryIndex) + (pos + 1),
-        distances.colptr(queryIndex) + pos,
-        sizeof(double) * len);
-    memmove(neighbors.colptr(queryIndex) + (pos + 1),
-        neighbors.colptr(queryIndex) + pos,
-        sizeof(size_t) * len);
+    pqueue.pop();
+    pqueue.push(c);
   }
-
-  // Now put the new information in the right index.
-  distances(pos, queryIndex) = distance;
-  neighbors(pos, queryIndex) = neighbor;
 }
 
 } // namespace neighbor




More information about the mlpack-git mailing list