[mlpack-git] master: Use a priority queue (heap) to store the list of candidates while searching fastmks. (bccf3e0)
gitdub at mlpack.org
gitdub at mlpack.org
Tue Jul 26 21:22:08 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/ef51b032f275266f781d42b9bd0aa50aa26a3077...8522b04c3d9a82fb7e964bafd72e70f0cd30bf4b
>---------------------------------------------------------------
commit bccf3e0d442ba554a2f0276822f102fcf2a2218a
Author: MarcosPividori <marcos.pividori at gmail.com>
Date: Fri Jul 22 01:39:22 2016 -0300
Use a priority queue (heap) to store the list of candidates while searching fastmks.
>---------------------------------------------------------------
bccf3e0d442ba554a2f0276822f102fcf2a2218a
src/mlpack/methods/fastmks/fastmks.hpp | 30 ++++--
src/mlpack/methods/fastmks/fastmks_impl.hpp | 107 +++++++++-------------
src/mlpack/methods/fastmks/fastmks_rules.hpp | 59 ++++++++++--
src/mlpack/methods/fastmks/fastmks_rules_impl.hpp | 104 ++++++++++++---------
4 files changed, 174 insertions(+), 126 deletions(-)
diff --git a/src/mlpack/methods/fastmks/fastmks.hpp b/src/mlpack/methods/fastmks/fastmks.hpp
index e849635..796f0db 100644
--- a/src/mlpack/methods/fastmks/fastmks.hpp
+++ b/src/mlpack/methods/fastmks/fastmks.hpp
@@ -12,6 +12,7 @@
#include <mlpack/core/metrics/ip_metric.hpp>
#include "fastmks_stat.hpp"
#include <mlpack/core/tree/cover_tree.hpp>
+#include <queue>
namespace mlpack {
namespace fastmks /** Fast max-kernel search. */ {
@@ -250,13 +251,28 @@ class FastMKS
//! The instantiated inner-product metric induced by the given kernel.
metric::IPMetric<KernelType> metric;
- //! Utility function. Copied too many times from too many places.
- void InsertNeighbor(arma::Mat<size_t>& indices,
- arma::mat& products,
- const size_t queryIndex,
- const size_t pos,
- const size_t neighbor,
- const double distance);
+ //! Candidate point from the reference set.
+ struct Candidate
+ {
+ //! Kernel value calculated between a reference point and the query point.
+ double product;
+ //! Index of the reference point.
+ size_t index;
+ //! Trivial constructor.
+ Candidate(double p, size_t i) :
+ product(p),
+ index(i)
+ {};
+ //! Compare two candidates.
+ friend bool operator>(const Candidate& l, const Candidate& r)
+ {
+ return l.product > r.product;
+ };
+ };
+
+ //! Use a priority queue to represent the list of candidate points.
+ typedef std::priority_queue<Candidate, std::vector<Candidate>,
+ std::greater<Candidate>> CandidateList;
};
} // namespace fastmks
diff --git a/src/mlpack/methods/fastmks/fastmks_impl.hpp b/src/mlpack/methods/fastmks/fastmks_impl.hpp
index de85cdc..f0f321f 100644
--- a/src/mlpack/methods/fastmks/fastmks_impl.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_impl.hpp
@@ -13,7 +13,6 @@
#include "fastmks_rules.hpp"
#include <mlpack/core/kernels/gaussian_kernel.hpp>
-#include <queue>
namespace mlpack {
namespace fastmks {
@@ -221,25 +220,31 @@ void FastMKS<KernelType, MatType, TreeType>::Search(
// Naive implementation.
if (naive)
{
- // Fill kernels.
- kernels.fill(-DBL_MAX);
-
// Simple double loop. Stupid, slow, but a good benchmark.
for (size_t q = 0; q < querySet.n_cols; ++q)
{
+ const Candidate def(-DBL_MAX, size_t() - 1);
+ std::vector<Candidate> cList(k, def);
+ CandidateList pqueue(std::greater<Candidate>(), std::move(cList));
+
for (size_t r = 0; r < referenceSet->n_cols; ++r)
{
const double eval = metric.Kernel().Evaluate(querySet.col(q),
referenceSet->col(r));
- size_t insertPosition;
- for (insertPosition = 0; insertPosition < indices.n_rows;
- ++insertPosition)
- if (eval > kernels(insertPosition, q))
- break;
+ Candidate c(eval, r);
+ if (c > pqueue.top())
+ {
+ pqueue.pop();
+ pqueue.push(c);
+ }
+ }
- if (insertPosition < indices.n_rows)
- InsertNeighbor(indices, kernels, q, insertPosition, r, eval);
+ for (size_t j = 1; j <= k; j++)
+ {
+ indices(k - j, q) = pqueue.top().index;
+ kernels(k - j, q) = pqueue.top().product;
+ pqueue.pop();
}
}
@@ -251,13 +256,10 @@ void FastMKS<KernelType, MatType, TreeType>::Search(
// Single-tree implementation.
if (singleMode)
{
- // Fill kernels.
- kernels.fill(-DBL_MAX);
-
// Create rules object (this will store the results). This constructor
// precalculates each self-kernel value.
typedef FastMKSRules<KernelType, Tree> RuleType;
- RuleType rules(*referenceSet, querySet, indices, kernels, metric.Kernel());
+ RuleType rules(*referenceSet, querySet, k, metric.Kernel());
typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
@@ -267,6 +269,8 @@ void FastMKS<KernelType, MatType, TreeType>::Search(
Log::Info << rules.BaseCases() << " base cases." << std::endl;
Log::Info << rules.Scores() << " scores." << std::endl;
+ rules.GetResults(indices, kernels);
+
Timer::Stop("computing_products");
return;
}
@@ -310,12 +314,10 @@ void FastMKS<KernelType, MatType, TreeType>::Search(
// No remapping will be necessary because we are using the cover tree.
indices.set_size(k, queryTree->Dataset().n_cols);
kernels.set_size(k, queryTree->Dataset().n_cols);
- kernels.fill(-DBL_MAX);
Timer::Start("computing_products");
typedef FastMKSRules<KernelType, Tree> RuleType;
- RuleType rules(*referenceSet, queryTree->Dataset(), indices, kernels,
- metric.Kernel());
+ RuleType rules(*referenceSet, queryTree->Dataset(), k, metric.Kernel());
typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
@@ -324,6 +326,8 @@ void FastMKS<KernelType, MatType, TreeType>::Search(
Log::Info << rules.BaseCases() << " base cases." << std::endl;
Log::Info << rules.Scores() << " scores." << std::endl;
+ rules.GetResults(indices, kernels);
+
Timer::Stop("computing_products");
}
@@ -341,7 +345,6 @@ void FastMKS<KernelType, MatType, TreeType>::Search(
Timer::Start("computing_products");
indices.set_size(k, referenceSet->n_cols);
kernels.set_size(k, referenceSet->n_cols);
- kernels.fill(-DBL_MAX);
// Naive implementation.
if (naive)
@@ -349,6 +352,10 @@ void FastMKS<KernelType, MatType, TreeType>::Search(
// Simple double loop. Stupid, slow, but a good benchmark.
for (size_t q = 0; q < referenceSet->n_cols; ++q)
{
+ const Candidate def(-DBL_MAX, size_t() - 1);
+ std::vector<Candidate> cList(k, def);
+ CandidateList pqueue(std::greater<Candidate>(), std::move(cList));
+
for (size_t r = 0; r < referenceSet->n_cols; ++r)
{
if (q == r)
@@ -357,14 +364,19 @@ void FastMKS<KernelType, MatType, TreeType>::Search(
const double eval = metric.Kernel().Evaluate(referenceSet->col(q),
referenceSet->col(r));
- size_t insertPosition;
- for (insertPosition = 0; insertPosition < indices.n_rows;
- ++insertPosition)
- if (eval > kernels(insertPosition, q))
- break;
+ Candidate c(eval, r);
+ if (c > pqueue.top())
+ {
+ pqueue.pop();
+ pqueue.push(c);
+ }
+ }
- if (insertPosition < indices.n_rows)
- InsertNeighbor(indices, kernels, q, insertPosition, r, eval);
+ for (size_t j = 1; j <= k; j++)
+ {
+ indices(k - j, q) = pqueue.top().index;
+ kernels(k - j, q) = pqueue.top().product;
+ pqueue.pop();
}
}
@@ -379,8 +391,7 @@ void FastMKS<KernelType, MatType, TreeType>::Search(
// Create rules object (this will store the results). This constructor
// precalculates each self-kernel value.
typedef FastMKSRules<KernelType, Tree> RuleType;
- RuleType rules(*referenceSet, *referenceSet, indices, kernels,
- metric.Kernel());
+ RuleType rules(*referenceSet, *referenceSet, k, metric.Kernel());
typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
@@ -395,6 +406,8 @@ void FastMKS<KernelType, MatType, TreeType>::Search(
Log::Info << rules.BaseCases() << " base cases." << std::endl;
Log::Info << rules.Scores() << " scores." << std::endl;
+ rules.GetResults(indices, kernels);
+
Timer::Stop("computing_products");
return;
}
@@ -405,44 +418,6 @@ void FastMKS<KernelType, MatType, TreeType>::Search(
Search(referenceTree, k, indices, kernels);
}
-/**
- * Helper function to insert a point into the neighbors and distances matrices.
- *
- * @param queryIndex Index of point whose neighbors we are inserting into.
- * @param pos Position in list to insert into.
- * @param neighbor Index of reference point which is being inserted.
- * @param distance Distance from query point to reference point.
- */
-template<typename KernelType,
- typename MatType,
- template<typename TreeMetricType,
- typename TreeStatType,
- typename TreeMatType> class TreeType>
-void FastMKS<KernelType, MatType, TreeType>::InsertNeighbor(
- arma::Mat<size_t>& indices,
- arma::mat& products,
- const size_t queryIndex,
- const size_t pos,
- const size_t neighbor,
- const double distance)
-{
- // We only memmove() if there is actually a need to shift something.
- if (pos < (products.n_rows - 1))
- {
- int len = (products.n_rows - 1) - pos;
- memmove(products.colptr(queryIndex) + (pos + 1),
- products.colptr(queryIndex) + pos,
- sizeof(double) * len);
- memmove(indices.colptr(queryIndex) + (pos + 1),
- indices.colptr(queryIndex) + pos,
- sizeof(size_t) * len);
- }
-
- // Now put the new information in the right index.
- products(pos, queryIndex) = distance;
- indices(pos, queryIndex) = neighbor;
-}
-
//! Serialize the model.
template<typename KernelType,
typename MatType,
diff --git a/src/mlpack/methods/fastmks/fastmks_rules.hpp b/src/mlpack/methods/fastmks/fastmks_rules.hpp
index 0f4ad34..13be5f5 100644
--- a/src/mlpack/methods/fastmks/fastmks_rules.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_rules.hpp
@@ -10,6 +10,7 @@
#include <mlpack/core.hpp>
#include <mlpack/core/tree/cover_tree/cover_tree.hpp>
#include <mlpack/core/tree/traversal_info.hpp>
+#include <vector>
namespace mlpack {
namespace fastmks {
@@ -23,10 +24,17 @@ class FastMKSRules
public:
FastMKSRules(const typename TreeType::Mat& referenceSet,
const typename TreeType::Mat& querySet,
- arma::Mat<size_t>& indices,
- arma::mat& products,
+ const size_t k,
KernelType& kernel);
+ /**
+ * Store the list of candidates for each query point in the given matrices.
+ *
+ * @param indices Matrix storing lists of candidate points for each query point.
+ * @param products Matrix storing kernel value for each candidate.
+ */
+ void GetResults(arma::Mat<size_t>& indices, arma::mat& products);
+
//! Compute the base case (kernel value) between two points.
double BaseCase(const size_t queryIndex, const size_t referenceIndex);
@@ -101,10 +109,36 @@ class FastMKSRules
//! The query dataset.
const typename TreeType::Mat& querySet;
- //! The indices of the maximum kernel results.
- arma::Mat<size_t>& indices;
- //! The maximum kernels.
- arma::mat& products;
+ //! Candidate point from the reference set.
+ struct Candidate
+ {
+ //! Kernel value calculated between a reference point and the query point.
+ double product;
+ //! Index of the reference point.
+ size_t index;
+ //! Trivial constructor.
+ Candidate(double p, size_t i) :
+ product(p),
+ index(i)
+ {};
+ //! Compare two candidates.
+ friend bool operator>(const Candidate& l, const Candidate& r)
+ {
+ return l.product > r.product;
+ };
+ };
+
+ //! Use a min heap to represent the list of candidate points.
+ //! We will use a vector and the std functions: push_heap() pop_heap().
+ //! We can not use a priority queue because we need to iterate over all the
+ //! candidates and std::priority_queue doesn't provide that interface.
+ typedef std::vector<Candidate> CandidateList;
+
+ //! Set of candidates for each point.
+ std::vector<CandidateList> candidates;
+
+ //! Number of points to search for.
+ const size_t k;
//! Cached query set self-kernels (|| q || for each q).
arma::vec queryKernels;
@@ -124,11 +158,16 @@ class FastMKSRules
//! Calculate the bound for a given query node.
double CalculateBound(TreeType& queryNode) const;
- //! Utility function to insert neighbor into list of results.
+ /**
+ * Helper function to insert a point into the list of candidate points.
+ *
+ * @param queryIndex Index of point whose neighbors we are inserting into.
+ * @param index Index of reference point which is being inserted.
+ * @param product Kernel value for given candidate.
+ */
void InsertNeighbor(const size_t queryIndex,
- const size_t pos,
- const size_t neighbor,
- const double distance);
+ const size_t index,
+ const double product);
//! For benchmarking.
size_t baseCases;
diff --git a/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp b/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp
index 27abacf..a5cf681 100644
--- a/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp
@@ -9,6 +9,7 @@
// In case it hasn't already been included.
#include "fastmks_rules.hpp"
+#include <algorithm>
namespace mlpack {
namespace fastmks {
@@ -17,13 +18,11 @@ template<typename KernelType, typename TreeType>
FastMKSRules<KernelType, TreeType>::FastMKSRules(
const typename TreeType::Mat& referenceSet,
const typename TreeType::Mat& querySet,
- arma::Mat<size_t>& indices,
- arma::mat& products,
+ const size_t k,
KernelType& kernel) :
referenceSet(referenceSet),
querySet(querySet),
- indices(indices),
- products(products),
+ k(k),
kernel(kernel),
lastQueryIndex(-1),
lastReferenceIndex(-1),
@@ -46,6 +45,41 @@ FastMKSRules<KernelType, TreeType>::FastMKSRules(
// dereference null pointers.
traversalInfo.LastQueryNode() = (TreeType*) this;
traversalInfo.LastReferenceNode() = (TreeType*) this;
+
+ // Let's build the list of candidate points for each query point.
+ // It will be initialized with k candidates: (-DBL_MAX, size_t() - 1)
+ // The list of candidates will be updated when visiting new points with the
+ // BaseCase() method.
+ const Candidate def(-DBL_MAX, size_t() - 1);
+
+ CandidateList cList(k, def);
+ std::vector<CandidateList> tmp(querySet.n_cols, cList);
+ candidates.swap(tmp);
+}
+
+template<typename KernelType, typename TreeType>
+void FastMKSRules<KernelType, TreeType>::GetResults(
+ arma::Mat<size_t>& indices,
+ arma::mat& products)
+{
+ indices.set_size(k, querySet.n_cols);
+ products.set_size(k, querySet.n_cols);
+
+ for (size_t i = 0; i < querySet.n_cols; i++)
+ {
+ CandidateList& pqueue = candidates[i];
+ std::greater<Candidate> greater;
+ typedef typename CandidateList::iterator Iterator;
+
+ for (Iterator end = pqueue.end(); end != pqueue.begin(); --end)
+ std::pop_heap(pqueue.begin(), end, greater);
+
+ for (size_t j = 0; j < k; j++)
+ {
+ indices(j, i) = pqueue[j].index;
+ products(j, i) = pqueue[j].product;
+ }
+ }
}
template<typename KernelType, typename TreeType>
@@ -83,16 +117,7 @@ double FastMKSRules<KernelType, TreeType>::BaseCase(
if ((&querySet == &referenceSet) && (queryIndex == referenceIndex))
return kernelEval;
- // If this is a better candidate, insert it into the list.
- if (kernelEval < products(products.n_rows - 1, queryIndex))
- return kernelEval;
-
- size_t insertPosition = 0;
- for ( ; insertPosition < products.n_rows; ++insertPosition)
- if (kernelEval >= products(insertPosition, queryIndex))
- break;
-
- InsertNeighbor(queryIndex, insertPosition, referenceIndex, kernelEval);
+ InsertNeighbor(queryIndex, referenceIndex, kernelEval);
return kernelEval;
}
@@ -102,7 +127,7 @@ double FastMKSRules<KernelType, TreeType>::Score(const size_t queryIndex,
TreeType& referenceNode)
{
// Compare with the current best.
- const double bestKernel = products(products.n_rows - 1, queryIndex);
+ const double bestKernel = candidates[queryIndex].front().product;
// See if we can perform a parent-child prune.
const double furthestDist = referenceNode.FurthestDescendantDistance();
@@ -385,7 +410,7 @@ double FastMKSRules<KernelType, TreeType>::Rescore(const size_t queryIndex,
TreeType& /*referenceNode*/,
const double oldScore) const
{
- const double bestKernel = products(products.n_rows - 1, queryIndex);
+ const double bestKernel = candidates[queryIndex].front().product;
return ((1.0 / oldScore) >= bestKernel) ? oldScore : DBL_MAX;
}
@@ -432,10 +457,11 @@ double FastMKSRules<KernelType, TreeType>::CalculateBound(TreeType& queryNode)
for (size_t i = 0; i < queryNode.NumPoints(); ++i)
{
const size_t point = queryNode.Point(i);
- if (products(products.n_rows - 1, point) < worstPointKernel)
- worstPointKernel = products(products.n_rows - 1, point);
+ const CandidateList& candidatesPoints = candidates[point];
+ if (candidatesPoints.front().product < worstPointKernel)
+ worstPointKernel = candidatesPoints.front().product;
- if (products(products.n_rows - 1, point) == -DBL_MAX)
+ if (candidatesPoints.front().product == -DBL_MAX)
continue; // Avoid underflow.
// This should be (queryDescendantDistance + centroidDistance) for any tree
@@ -450,10 +476,10 @@ double FastMKSRules<KernelType, TreeType>::CalculateBound(TreeType& queryNode)
// where p_j^*(p_q) is the j'th kernel candidate for query point p_q and
// k_j^*(p_q) is K(p_q, p_j^*(p_q)).
double worstPointCandidateKernel = DBL_MAX;
- for (size_t j = 0; j < products.n_rows; ++j)
+ for (size_t j = 0; j < candidatesPoints.size(); ++j)
{
- const double candidateKernel = products(j, point) -
- queryDescendantDistance * referenceKernels[indices(j, point)];
+ const double candidateKernel = candidatesPoints[j].product -
+ queryDescendantDistance * referenceKernels[candidatesPoints[j].index];
if (candidateKernel < worstPointCandidateKernel)
worstPointCandidateKernel = candidateKernel;
}
@@ -488,34 +514,26 @@ double FastMKSRules<KernelType, TreeType>::CalculateBound(TreeType& queryNode)
}
/**
- * Helper function to insert a point into the neighbors and distances matrices.
+ * Helper function to insert a point into the list of candidate points.
*
* @param queryIndex Index of point whose neighbors we are inserting into.
- * @param pos Position in list to insert into.
- * @param neighbor Index of reference point which is being inserted.
- * @param distance Distance from query point to reference point.
+ * @param index Index of reference point which is being inserted.
+ * @param product Kernel value for given candidate.
*/
template<typename KernelType, typename TreeType>
-void FastMKSRules<KernelType, TreeType>::InsertNeighbor(const size_t queryIndex,
- const size_t pos,
- const size_t neighbor,
- const double distance)
+inline void FastMKSRules<KernelType, TreeType>::InsertNeighbor(
+ const size_t queryIndex,
+ const size_t index,
+ const double product)
{
- // We only memmove() if there is actually a need to shift something.
- if (pos < (products.n_rows - 1))
+ Candidate c(product, index);
+ CandidateList& pqueue = candidates[queryIndex];
+ if (c > pqueue.front())
{
- int len = (products.n_rows - 1) - pos;
- memmove(products.colptr(queryIndex) + (pos + 1),
- products.colptr(queryIndex) + pos,
- sizeof(double) * len);
- memmove(indices.colptr(queryIndex) + (pos + 1),
- indices.colptr(queryIndex) + pos,
- sizeof(size_t) * len);
+ std::pop_heap(pqueue.begin(), pqueue.end(), std::greater<Candidate>());
+ pqueue.back() = c;
+ std::push_heap(pqueue.begin(), pqueue.end(), std::greater<Candidate>());
}
-
- // Now put the new information in the right index.
- products(pos, queryIndex) = distance;
- indices(pos, queryIndex) = neighbor;
}
} // namespace fastmks
More information about the mlpack-git mailing list