[mlpack-svn] r15776 - mlpack/trunk/src/mlpack/methods/fastmks
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Sep 13 15:53:43 EDT 2013
Author: rcurtin
Date: Fri Sep 13 15:53:43 2013
New Revision: 15776
Log:
This was the version of code used for the FastMKS benchmarks in the recently
submitted paper, "Dual-tree Fast Exact Max-Kernel Search".
Modified:
mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules.hpp
mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp
Modified: mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules.hpp Fri Sep 13 15:53:43 2013
@@ -37,21 +37,7 @@
* @param queryIndex Index of query point.
* @param referenceNode Candidate to be recursed into.
*/
- double Score(const size_t queryIndex, TreeType& referenceNode) const;
-
- /**
- * Get the score for recursion order, passing the base case result (in the
- * situation where it may be needed to calculate the recursion order). A low
- * score indicates priority for recursion, while DBL_MAX indicates that the
- * node should not be recursed into at all (it should be pruned).
- *
- * @param queryIndex Index of query point.
- * @param referenceNode Candidate node to be recursed into.
- * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode).
- */
- double Score(const size_t queryIndex,
- TreeType& referenceNode,
- const double baseCaseResult) const;
+ double Score(const size_t queryIndex, TreeType& referenceNode);
/**
* Get the score for recursion order. A low score indicates priority for
@@ -61,21 +47,7 @@
* @param queryNode Candidate query node to be recursed into.
* @param referenceNode Candidate reference node to be recursed into.
*/
- double Score(TreeType& queryNode, TreeType& referenceNode) const;
-
- /**
- * Get the score for recursion order, passing the base case result (in the
- * situation where it may be needed to calculate the recursion order). A low
- * score indicates priority for recursion, while DBL_MAX indicates that the
- * node should not be recursed into at all (it should be pruned).
- *
- * @param queryNode Candidate query node to be recursed into.
- * @param referenceNode Candidate reference node to be recursed into.
- * @param baseCaseResult Result of BaseCase(queryNode, referenceNode).
- */
- double Score(TreeType& queryNode,
- TreeType& referenceNode,
- const double baseCaseResult) const;
+ double Score(TreeType& queryNode, TreeType& referenceNode);
/**
* Re-evaluate the score for recursion order. A low score indicates priority
@@ -107,6 +79,16 @@
TreeType& referenceNode,
const double oldScore) const;
+ //! Get the number of times BaseCase() was called.
+ size_t BaseCases() const { return baseCases; }
+ //! Modify the number of times BaseCase() was called.
+ size_t& BaseCases() { return baseCases; }
+
+ //! Get the number of times Score() was called.
+ size_t Scores() const { return scores; }
+ //! Modify the number of times Score() was called.
+ size_t& Scores() { return scores; }
+
private:
//! The reference dataset.
const arma::mat& referenceSet;
@@ -126,6 +108,13 @@
//! The instantiated kernel.
KernelType& kernel;
+ //! The last query index BaseCase() was called with.
+ size_t lastQueryIndex;
+ //! The last reference index BaseCase() was called with.
+ size_t lastReferenceIndex;
+ //! The last kernel evaluation resulting from BaseCase().
+ double lastKernel;
+
//! Calculate the bound for a given query node.
double CalculateBound(TreeType& queryNode) const;
@@ -134,6 +123,11 @@
const size_t pos,
const size_t neighbor,
const double distance);
+
+ //! For benchmarking.
+ size_t baseCases;
+ //! For benchmarking.
+ size_t scores;
};
}; // namespace fastmks
Modified: mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp Fri Sep 13 15:53:43 2013
@@ -23,7 +23,12 @@
querySet(querySet),
indices(indices),
products(products),
- kernel(kernel)
+ kernel(kernel),
+ lastQueryIndex(-1),
+ lastReferenceIndex(-1),
+ lastKernel(0.0),
+ baseCases(0),
+ scores(0)
{
// Precompute each self-kernel.
queryKernels.set_size(querySet.n_cols);
@@ -43,10 +48,29 @@
const size_t queryIndex,
const size_t referenceIndex)
{
+ // Score() always happens before BaseCase() for a given node combination. For
+ // cover trees, the kernel evaluation between the two centroid points already
+ // happened. So we don't need to do it. Note that this optimizes out if the
+ // first conditional is false (its result is known at compile time).
+ if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
+ {
+ if ((queryIndex == lastQueryIndex) &&
+ (referenceIndex == lastReferenceIndex))
+ return lastKernel;
+
+ // Store new values.
+ lastQueryIndex = queryIndex;
+ lastReferenceIndex = referenceIndex;
+ }
+ ++baseCases;
double kernelEval = kernel.Evaluate(querySet.unsafe_col(queryIndex),
referenceSet.unsafe_col(referenceIndex));
+ // Update the last kernel value, if we need to.
+ if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
+ lastKernel = kernelEval;
+
// If the reference and query sets are identical, we still need to compute the
// base case (so that things can be bounded properly), but we won't add it to
// the results.
@@ -67,98 +91,304 @@
return kernelEval;
}
-template<typename MetricType, typename TreeType>
-double FastMKSRules<MetricType, TreeType>::Score(const size_t queryIndex,
- TreeType& referenceNode) const
+template<typename KernelType, typename TreeType>
+double FastMKSRules<KernelType, TreeType>::Score(const size_t queryIndex,
+ TreeType& referenceNode)
{
- // Calculate the maximum possible kernel value.
- const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
- const arma::vec refCentroid;
- referenceNode.Bound().Centroid(refCentroid);
-
- const double maxKernel = kernel.Evaluate(queryPoint, refCentroid) +
- referenceNode.FurthestDescendantDistance() * queryKernels[queryIndex];
-
// Compare with the current best.
const double bestKernel = products(products.n_rows - 1, queryIndex);
- // We return the inverse of the maximum kernel so that larger kernels are
- // recursed into first.
- return (maxKernel > bestKernel) ? (1.0 / maxKernel) : DBL_MAX;
-}
+ // See if we can perform a parent-child prune.
+ const double furthestDist = referenceNode.FurthestDescendantDistance();
+ if (referenceNode.Parent() != NULL)
+ {
+ double maxKernelBound;
+ const double parentDist = referenceNode.ParentDistance();
+ const double combinedDistBound = parentDist + furthestDist;
+ const double lastKernel = referenceNode.Parent()->Stat().LastKernel();
+ if (kernel::KernelTraits<KernelType>::IsNormalized)
+ {
+ const double squaredDist = std::pow(combinedDistBound, 2.0);
+ const double delta = (1 - 0.5 * squaredDist);
+ if (lastKernel <= delta)
+ {
+ const double gamma = combinedDistBound * sqrt(1 - 0.25 * squaredDist);
+ maxKernelBound = lastKernel * delta +
+ gamma * sqrt(1 - std::pow(lastKernel, 2.0));
+ }
+ else
+ {
+ maxKernelBound = 1.0;
+ }
+ }
+ else
+ {
+ maxKernelBound = lastKernel +
+ combinedDistBound * queryKernels[queryIndex];
+ }
-template<typename MetricType, typename TreeType>
-double FastMKSRules<MetricType, TreeType>::Score(
- const size_t queryIndex,
- TreeType& referenceNode,
- const double baseCaseResult) const
-{
- // We already have the base case result. Add the bound.
- const double maxKernel = baseCaseResult +
- referenceNode.FurthestDescendantDistance() * queryKernels[queryIndex];
- const double bestKernel = products(products.n_rows - 1, queryIndex);
+ if (maxKernelBound < bestKernel)
+ return DBL_MAX;
+ }
- // We return the inverse of the maximum kernel so that larger kernels are
- // recursed into first.
- return (maxKernel > bestKernel) ? (1.0 / maxKernel) : DBL_MAX;
-}
+ // Calculate the maximum possible kernel value, either by calculating the
+ // centroid or, if the centroid is a point, use that.
+ ++scores;
+ double kernelEval;
+ if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
+ {
+ // Could it be that this kernel evaluation has already been calculated?
+ if (tree::TreeTraits<TreeType>::HasSelfChildren &&
+ referenceNode.Parent() != NULL &&
+ referenceNode.Point(0) == referenceNode.Parent()->Point(0))
+ {
+ kernelEval = referenceNode.Parent()->Stat().LastKernel();
+ }
+ else
+ {
+ kernelEval = BaseCase(queryIndex, referenceNode.Point(0));
+ }
+ }
+ else
+ {
+ const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
+ arma::vec refCentroid;
+ referenceNode.Centroid(refCentroid);
-template<typename MetricType, typename TreeType>
-double FastMKSRules<MetricType, TreeType>::Score(TreeType& queryNode,
- TreeType& referenceNode) const
-{
- // Calculate the maximum possible kernel value.
- const arma::vec queryCentroid;
- const arma::vec refCentroid;
- queryNode.Bound().Centroid(queryCentroid);
- referenceNode.Bound().Centroid(refCentroid);
-
- const double refKernelTerm = queryNode.FurthestDescendantDistance() *
- referenceNode.Stat().SelfKernel();
- const double queryKernelTerm = referenceNode.FurthestDescendantDistance() *
- queryNode.Stat().SelfKernel();
-
- const double maxKernel = kernel.Evaluate(queryCentroid, refCentroid) +
- refKernelTerm + queryKernelTerm +
- (queryNode.FurthestDescendantDistance() *
- referenceNode.FurthestDescendantDistance());
+ kernelEval = kernel.Evaluate(queryPoint, refCentroid);
+ }
- // The existing bound.
- queryNode.Stat().Bound() = CalculateBound(queryNode);
- const double bestKernel = queryNode.Stat().Bound();
+ referenceNode.Stat().LastKernel() = kernelEval;
+
+ double maxKernel;
+ if (kernel::KernelTraits<KernelType>::IsNormalized)
+ {
+ const double squaredDist = std::pow(furthestDist, 2.0);
+ const double delta = (1 - 0.5 * squaredDist);
+ if (kernelEval <= delta)
+ {
+ const double gamma = furthestDist * sqrt(1 - 0.25 * squaredDist);
+ maxKernel = kernelEval * delta +
+ gamma * sqrt(1 - std::pow(kernelEval, 2.0));
+ }
+ else
+ {
+ maxKernel = 1.0;
+ }
+ }
+ else
+ {
+ maxKernel = kernelEval + furthestDist * queryKernels[queryIndex];
+ }
// We return the inverse of the maximum kernel so that larger kernels are
// recursed into first.
return (maxKernel > bestKernel) ? (1.0 / maxKernel) : DBL_MAX;
}
-template<typename MetricType, typename TreeType>
-double FastMKSRules<MetricType, TreeType>::Score(
- TreeType& queryNode,
- TreeType& referenceNode,
- const double baseCaseResult) const
+template<typename KernelType, typename TreeType>
+double FastMKSRules<KernelType, TreeType>::Score(TreeType& queryNode,
+ TreeType& referenceNode)
{
- // We already have the base case, so we need to add the bounds.
- const double refKernelTerm = queryNode.FurthestDescendantDistance() *
- referenceNode.Stat().SelfKernel();
- const double queryKernelTerm = referenceNode.FurthestDescendantDistance() *
- queryNode.Stat().SelfKernel();
-
- const double maxKernel = baseCaseResult + refKernelTerm + queryKernelTerm +
- (queryNode.FurthestDescendantDistance() *
- referenceNode.FurthestDescendantDistance());
-
- // The existing bound.
+ // Update and get the query node's bound.
queryNode.Stat().Bound() = CalculateBound(queryNode);
const double bestKernel = queryNode.Stat().Bound();
+ // First, see if we can make a parent-child or parent-parent prune. These
+ // four bounds on the maximum kernel value are looser than the bound normally
+ // used, but they can prevent a base case from needing to be calculated.
+ const TreeType* queryParent = queryNode.Parent();
+ const TreeType* refParent = referenceNode.Parent();
+
+ // Convenience caching so lines are shorter.
+ const double queryParentDist = queryNode.ParentDistance();
+ const double queryDescDist = queryNode.FurthestDescendantDistance();
+ const double refParentDist = referenceNode.ParentDistance();
+ const double refDescDist = referenceNode.FurthestDescendantDistance();
+
+ const double queryDistBound = (queryParentDist + queryDescDist);
+ const double refDistBound = (refParentDist + refDescDist);
+
+ if ((queryParent != NULL) &&
+ (queryParent->Stat().LastKernelNode() == (void*) &referenceNode))
+ {
+ // Query parent was last evaluated with reference node.
+ const double maxKernelBound = queryParent->Stat().LastKernel() +
+ queryDistBound * referenceNode.Stat().SelfKernel();
+
+ if (maxKernelBound < bestKernel)
+ return DBL_MAX;
+ }
+ else if ((refParent != NULL) &&
+ (refParent->Stat().LastKernelNode() == (void*) &queryNode))
+ {
+ // Reference parent was last evaluated with query node.
+ const double maxKernelBound = refParent->Stat().LastKernel() +
+ (refParentDist + refDescDist) * queryNode.Stat().SelfKernel();
+
+ if (maxKernelBound < bestKernel)
+ return DBL_MAX;
+ }
+ else if ((refParent != NULL) && (queryParent != NULL) &&
+ (queryParent->Stat().LastKernelNode() == (void*) refParent))
+ {
+ // Query parent was last calculated with reference parent.
+ const double queryKernelTerm = (refParentDist + refDescDist) *
+ queryParent->Stat().SelfKernel();
+ const double refKernelTerm = (queryParentDist + queryDescDist) *
+ refParent->Stat().SelfKernel();
+ const double dualTerm = (queryParentDist + queryDescDist) * (refParentDist +
+ refDescDist);
+
+ const double maxKernelBound = queryParent->Stat().LastKernel() +
+ queryKernelTerm + refKernelTerm + dualTerm;
+
+ if (maxKernelBound < bestKernel)
+ return DBL_MAX;
+ }
+ else if ((refParent != NULL) && (queryParent != NULL) &&
+ (refParent->Stat().LastKernelNode() == (void*) queryParent))
+ {
+ // Reference parent was last calculated with query parent.
+ const double queryKernelTerm = (refParentDist + refDescDist) *
+ queryParent->Stat().SelfKernel();
+ const double refKernelTerm = (queryParentDist + queryDescDist) *
+ refParent->Stat().SelfKernel();
+ const double dualTerm = (queryParentDist + queryDescDist) *
+ (refParentDist + refDescDist);
+
+ const double maxKernelBound = refParent->Stat().LastKernel() +
+ queryKernelTerm + refKernelTerm + dualTerm;
+
+ if (maxKernelBound < bestKernel)
+ return DBL_MAX;
+ }
+
+ // Calculate kernel evaluation, if necessary.
+ double kernelEval = 0.0;
+ if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
+ {
+ bool alreadyDone = false;
+ if ((queryNode.Parent() != NULL) &&
+ (queryNode.Parent()->Point(0) == queryNode.Point(0)))
+ {
+ TreeType* lastRef = (TreeType*)
+ queryNode.Parent()->Stat().LastKernelNode();
+ if (lastRef->Point(0) == referenceNode.Point(0))
+ {
+ // The query node parent was evaluated with the reference node.
+ kernelEval = queryNode.Parent()->Stat().LastKernel();
+ alreadyDone = true;
+ }
+ }
+
+ if ((referenceNode.Parent() != NULL) &&
+ (referenceNode.Parent()->Point(0) == referenceNode.Point(0)))
+ {
+ TreeType* lastQuery = (TreeType*)
+ referenceNode.Parent()->Stat().LastKernelNode();
+ if (lastQuery->Point(0) == queryNode.Point(0))
+ {
+ // The reference node parent was evaluated with the query node.
+ kernelEval = referenceNode.Parent()->Stat().LastKernel();
+ alreadyDone = true;
+ }
+ }
+
+ TreeType* lastRefNode = (TreeType*) referenceNode.Stat().LastKernelNode();
+ if ((lastRefNode != NULL) && (queryNode.Point(0) == lastRefNode->Point(0)))
+ {
+ // The kernel evaluation was already performed and is saved by the
+ // reference node.
+ kernelEval = referenceNode.Stat().LastKernel();
+ alreadyDone = true;
+ }
+
+ TreeType* lastQueryNode = (TreeType*) queryNode.Stat().LastKernelNode();
+ if ((lastQueryNode != NULL) &&
+ (referenceNode.Point(0) == lastQueryNode->Point(0)))
+ {
+ // The kernel evaluation was already performed and is saved by the query
+ // node.
+ kernelEval = queryNode.Stat().LastKernel();
+ alreadyDone = true;
+ }
+
+ if (!alreadyDone)
+ {
+ // The kernel must be evaluated, but it is between points in the dataset,
+ // so we can call BaseCase(). BaseCase() will set lastQueryIndex and
+ // lastReferenceIndex correctly.
+ kernelEval = BaseCase(queryNode.Point(0), referenceNode.Point(0));
+ }
+ else
+ {
+ // When BaseCase() is called after Score(), these must be correct so that
+ // another kernel evaluation is not performed.
+ lastQueryIndex = queryNode.Point(0);
+ lastReferenceIndex = referenceNode.Point(0);
+ }
+ }
+ else
+ {
+ // Calculate the maximum possible kernel value.
+ arma::vec queryCentroid;
+ arma::vec refCentroid;
+ queryNode.Centroid(queryCentroid);
+ referenceNode.Centroid(refCentroid);
+
+ kernelEval = kernel.Evaluate(queryCentroid, refCentroid);
+ }
+ ++scores;
+
+ double maxKernel;
+ if (kernel::KernelTraits<KernelType>::IsNormalized)
+ {
+ // We have a tighter bound for normalized kernels.
+ const double querySqDist = std::pow(queryDescDist, 2.0);
+ const double refSqDist = std::pow(refDescDist, 2.0);
+ const double bothSqDist = std::pow((queryDescDist + refDescDist), 2.0);
+
+ if (kernelEval <= (1 - 0.5 * bothSqDist))
+ {
+ const double queryDelta = (1 - 0.5 * querySqDist);
+ const double queryGamma = queryDescDist * sqrt(1 - 0.25 * querySqDist);
+ const double refDelta = (1 - 0.5 * refSqDist);
+ const double refGamma = refDescDist * sqrt(1 - 0.25 * refSqDist);
+
+ maxKernel = kernelEval * (queryDelta * refDelta - queryGamma * refGamma) +
+ sqrt(1 - std::pow(kernelEval, 2.0)) *
+ (queryGamma * refDelta + queryDelta * refGamma);
+ }
+ else
+ {
+ maxKernel = 1.0;
+ }
+ }
+ else
+ {
+ // Use standard bound; kernel is not normalized.
+ const double refKernelTerm = queryDescDist *
+ referenceNode.Stat().SelfKernel();
+ const double queryKernelTerm = refDescDist * queryNode.Stat().SelfKernel();
+
+ maxKernel = kernelEval + refKernelTerm + queryKernelTerm +
+ (queryDescDist * refDescDist);
+ }
+
+ // Store relevant information for parent-child pruning.
+ queryNode.Stat().LastKernel() = kernelEval;
+ queryNode.Stat().LastKernelNode() = (void*) &referenceNode;
+ referenceNode.Stat().LastKernel() = kernelEval;
+ referenceNode.Stat().LastKernelNode() = (void*) &queryNode;
+
// We return the inverse of the maximum kernel so that larger kernels are
// recursed into first.
return (maxKernel > bestKernel) ? (1.0 / maxKernel) : DBL_MAX;
}
-template<typename MetricType, typename TreeType>
-double FastMKSRules<MetricType, TreeType>::Rescore(const size_t queryIndex,
+template<typename KernelType, typename TreeType>
+double FastMKSRules<KernelType, TreeType>::Rescore(const size_t queryIndex,
TreeType& /*referenceNode*/,
const double oldScore) const
{
@@ -167,8 +397,8 @@
return ((1.0 / oldScore) > bestKernel) ? oldScore : DBL_MAX;
}
-template<typename MetricType, typename TreeType>
-double FastMKSRules<MetricType, TreeType>::Rescore(TreeType& queryNode,
+template<typename KernelType, typename TreeType>
+double FastMKSRules<KernelType, TreeType>::Rescore(TreeType& queryNode,
TreeType& /*referenceNode*/,
const double oldScore) const
{
@@ -185,8 +415,8 @@
*
* @param queryNode Query node to calculate bound for.
*/
-template<typename MetricType, typename TreeType>
-double FastMKSRules<MetricType, TreeType>::CalculateBound(TreeType& queryNode)
+template<typename KernelType, typename TreeType>
+double FastMKSRules<KernelType, TreeType>::CalculateBound(TreeType& queryNode)
const
{
// We have four possible bounds -- just like NeighborSearchRules, but they are
@@ -200,7 +430,7 @@
// (4) B(parent of queryNode);
double worstPointKernel = DBL_MAX;
double bestAdjustedPointKernel = -DBL_MAX;
-// double bestPointSelfKernel = -DBL_MAX;
+
const double queryDescendantDistance = queryNode.FurthestDescendantDistance();
// Loop over all points in this node to find the best and worst.
@@ -213,8 +443,10 @@
if (products(products.n_rows - 1, point) == -DBL_MAX)
continue; // Avoid underflow.
+ // This should be (queryDescendantDistance + centroidDistance) for any tree
+ // but it works for cover trees since centroidDistance = 0 for cover trees.
const double candidateKernel = products(products.n_rows - 1, point) -
- (2 * queryDescendantDistance) *
+ queryDescendantDistance *
referenceKernels[indices(indices.n_rows - 1, point)];
if (candidateKernel > bestAdjustedPointKernel)
@@ -255,8 +487,8 @@
* @param neighbor Index of reference point which is being inserted.
* @param distance Distance from query point to reference point.
*/
-template<typename MetricType, typename TreeType>
-void FastMKSRules<MetricType, TreeType>::InsertNeighbor(const size_t queryIndex,
+template<typename KernelType, typename TreeType>
+void FastMKSRules<KernelType, TreeType>::InsertNeighbor(const size_t queryIndex,
const size_t pos,
const size_t neighbor,
const double distance)
More information about the mlpack-svn
mailing list