[mlpack-svn] r15776 - mlpack/trunk/src/mlpack/methods/fastmks

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Sep 13 15:53:43 EDT 2013


Author: rcurtin
Date: Fri Sep 13 15:53:43 2013
New Revision: 15776

Log:
This was the version of code used for the FastMKS benchmarks in the recently
submitted paper, "Dual-tree Fast Exact Max-Kernel Search".


Modified:
   mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules.hpp
   mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp

Modified: mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules.hpp	Fri Sep 13 15:53:43 2013
@@ -37,21 +37,7 @@
    * @param queryIndex Index of query point.
    * @param referenceNode Candidate to be recursed into.
    */
-  double Score(const size_t queryIndex, TreeType& referenceNode) const;
-
-  /**
-   * Get the score for recursion order, passing the base case result (in the
-   * situation where it may be needed to calculate the recursion order).  A low
-   * score indicates priority for recursion, while DBL_MAX indicates that the
-   * node should not be recursed into at all (it should be pruned).
-   *
-   * @param queryIndex Index of query point.
-   * @param referenceNode Candidate node to be recursed into.
-   * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode).
-   */
-  double Score(const size_t queryIndex,
-               TreeType& referenceNode,
-               const double baseCaseResult) const;
+  double Score(const size_t queryIndex, TreeType& referenceNode);
 
   /**
    * Get the score for recursion order.  A low score indicates priority for
@@ -61,21 +47,7 @@
    * @param queryNode Candidate query node to be recursed into.
    * @param referenceNode Candidate reference node to be recursed into.
    */
-  double Score(TreeType& queryNode, TreeType& referenceNode) const;
-
-  /**
-   * Get the score for recursion order, passing the base case result (in the
-   * situation where it may be needed to calculate the recursion order).  A low
-   * score indicates priority for recursion, while DBL_MAX indicates that the
-   * node should not be recursed into at all (it should be pruned).
-   *
-   * @param queryNode Candidate query node to be recursed into.
-   * @param referenceNode Candidate reference node to be recursed into.
-   * @param baseCaseResult Result of BaseCase(queryNode, referenceNode).
-   */
-  double Score(TreeType& queryNode,
-               TreeType& referenceNode,
-               const double baseCaseResult) const;
+  double Score(TreeType& queryNode, TreeType& referenceNode);
 
   /**
    * Re-evaluate the score for recursion order.  A low score indicates priority
@@ -107,6 +79,16 @@
                  TreeType& referenceNode,
                  const double oldScore) const;
 
+  //! Get the number of times BaseCase() was called.
+  size_t BaseCases() const { return baseCases; }
+  //! Modify the number of times BaseCase() was called.
+  size_t& BaseCases() { return baseCases; }
+
+  //! Get the number of times Score() was called.
+  size_t Scores() const { return scores; }
+  //! Modify the number of times Score() was called.
+  size_t& Scores() { return scores; }
+
  private:
   //! The reference dataset.
   const arma::mat& referenceSet;
@@ -126,6 +108,13 @@
   //! The instantiated kernel.
   KernelType& kernel;
 
+  //! The last query index BaseCase() was called with.
+  size_t lastQueryIndex;
+  //! The last reference index BaseCase() was called with.
+  size_t lastReferenceIndex;
+  //! The last kernel evaluation resulting from BaseCase().
+  double lastKernel;
+
   //! Calculate the bound for a given query node.
   double CalculateBound(TreeType& queryNode) const;
 
@@ -134,6 +123,11 @@
                       const size_t pos,
                       const size_t neighbor,
                       const double distance);
+
+  //! For benchmarking.
+  size_t baseCases;
+  //! For benchmarking.
+  size_t scores;
 };
 
 }; // namespace fastmks

Modified: mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp	Fri Sep 13 15:53:43 2013
@@ -23,7 +23,12 @@
     querySet(querySet),
     indices(indices),
     products(products),
-    kernel(kernel)
+    kernel(kernel),
+    lastQueryIndex(-1),
+    lastReferenceIndex(-1),
+    lastKernel(0.0),
+    baseCases(0),
+    scores(0)
 {
   // Precompute each self-kernel.
   queryKernels.set_size(querySet.n_cols);
@@ -43,10 +48,29 @@
     const size_t queryIndex,
     const size_t referenceIndex)
 {
+  // Score() always happens before BaseCase() for a given node combination.  For
+  // cover trees, the kernel evaluation between the two centroid points already
+  // happened.  So we don't need to do it.  Note that this optimizes out if the
+  // first conditional is false (its result is known at compile time).
+  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
+  {
+    if ((queryIndex == lastQueryIndex) &&
+        (referenceIndex == lastReferenceIndex))
+      return lastKernel;
+
+    // Store new values.
+    lastQueryIndex = queryIndex;
+    lastReferenceIndex = referenceIndex;
+  }
 
+  ++baseCases;
   double kernelEval = kernel.Evaluate(querySet.unsafe_col(queryIndex),
                                       referenceSet.unsafe_col(referenceIndex));
 
+  // Update the last kernel value, if we need to.
+  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
+    lastKernel = kernelEval;
+
   // If the reference and query sets are identical, we still need to compute the
   // base case (so that things can be bounded properly), but we won't add it to
   // the results.
@@ -67,98 +91,304 @@
   return kernelEval;
 }
 
-template<typename MetricType, typename TreeType>
-double FastMKSRules<MetricType, TreeType>::Score(const size_t queryIndex,
-                                                 TreeType& referenceNode) const
+template<typename KernelType, typename TreeType>
+double FastMKSRules<KernelType, TreeType>::Score(const size_t queryIndex,
+                                                 TreeType& referenceNode)
 {
-  // Calculate the maximum possible kernel value.
-  const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
-  const arma::vec refCentroid;
-  referenceNode.Bound().Centroid(refCentroid);
-
-  const double maxKernel = kernel.Evaluate(queryPoint, refCentroid) +
-      referenceNode.FurthestDescendantDistance() * queryKernels[queryIndex];
-
   // Compare with the current best.
   const double bestKernel = products(products.n_rows - 1, queryIndex);
 
-  // We return the inverse of the maximum kernel so that larger kernels are
-  // recursed into first.
-  return (maxKernel > bestKernel) ? (1.0 / maxKernel) : DBL_MAX;
-}
+  // See if we can perform a parent-child prune.
+  const double furthestDist = referenceNode.FurthestDescendantDistance();
+  if (referenceNode.Parent() != NULL)
+  {
+    double maxKernelBound;
+    const double parentDist = referenceNode.ParentDistance();
+    const double combinedDistBound = parentDist + furthestDist;
+    const double lastKernel = referenceNode.Parent()->Stat().LastKernel();
+    if (kernel::KernelTraits<KernelType>::IsNormalized)
+    {
+      const double squaredDist = std::pow(combinedDistBound, 2.0);
+      const double delta = (1 - 0.5 * squaredDist);
+      if (lastKernel <= delta)
+      {
+        const double gamma = combinedDistBound * sqrt(1 - 0.25 * squaredDist);
+        maxKernelBound = lastKernel * delta +
+             gamma * sqrt(1 - std::pow(lastKernel, 2.0));
+      }
+      else
+      {
+        maxKernelBound = 1.0;
+      }
+    }
+    else
+    {
+      maxKernelBound = lastKernel +
+          combinedDistBound * queryKernels[queryIndex];
+    }
 
-template<typename MetricType, typename TreeType>
-double FastMKSRules<MetricType, TreeType>::Score(
-    const size_t queryIndex,
-    TreeType& referenceNode,
-    const double baseCaseResult) const
-{
-  // We already have the base case result.  Add the bound.
-  const double maxKernel = baseCaseResult +
-      referenceNode.FurthestDescendantDistance() * queryKernels[queryIndex];
-  const double bestKernel = products(products.n_rows - 1, queryIndex);
+    if (maxKernelBound < bestKernel)
+      return DBL_MAX;
+  }
 
-  // We return the inverse of the maximum kernel so that larger kernels are
-  // recursed into first.
-  return (maxKernel > bestKernel) ? (1.0 / maxKernel) : DBL_MAX;
-}
+  // Calculate the maximum possible kernel value, either by calculating the
+  // centroid or, if the centroid is a point, use that.
+  ++scores;
+  double kernelEval;
+  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
+  {
+    // Could it be that this kernel evaluation has already been calculated?
+    if (tree::TreeTraits<TreeType>::HasSelfChildren &&
+        referenceNode.Parent() != NULL &&
+        referenceNode.Point(0) == referenceNode.Parent()->Point(0))
+    {
+      kernelEval = referenceNode.Parent()->Stat().LastKernel();
+    }
+    else
+    {
+      kernelEval = BaseCase(queryIndex, referenceNode.Point(0));
+    }
+  }
+  else
+  {
+    const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
+    arma::vec refCentroid;
+    referenceNode.Centroid(refCentroid);
 
-template<typename MetricType, typename TreeType>
-double FastMKSRules<MetricType, TreeType>::Score(TreeType& queryNode,
-                                                 TreeType& referenceNode) const
-{
-  // Calculate the maximum possible kernel value.
-  const arma::vec queryCentroid;
-  const arma::vec refCentroid;
-  queryNode.Bound().Centroid(queryCentroid);
-  referenceNode.Bound().Centroid(refCentroid);
-
-  const double refKernelTerm = queryNode.FurthestDescendantDistance() *
-      referenceNode.Stat().SelfKernel();
-  const double queryKernelTerm = referenceNode.FurthestDescendantDistance() *
-      queryNode.Stat().SelfKernel();
-
-  const double maxKernel = kernel.Evaluate(queryCentroid, refCentroid) +
-      refKernelTerm + queryKernelTerm +
-      (queryNode.FurthestDescendantDistance() *
-       referenceNode.FurthestDescendantDistance());
+    kernelEval = kernel.Evaluate(queryPoint, refCentroid);
+  }
 
-  // The existing bound.
-  queryNode.Stat().Bound() = CalculateBound(queryNode);
-  const double bestKernel = queryNode.Stat().Bound();
+  referenceNode.Stat().LastKernel() = kernelEval;
+
+  double maxKernel;
+  if (kernel::KernelTraits<KernelType>::IsNormalized)
+  {
+    const double squaredDist = std::pow(furthestDist, 2.0);
+    const double delta = (1 - 0.5 * squaredDist);
+    if (kernelEval <= delta)
+    {
+      const double gamma = furthestDist * sqrt(1 - 0.25 * squaredDist);
+      maxKernel = kernelEval * delta +
+          gamma * sqrt(1 - std::pow(kernelEval, 2.0));
+    }
+    else
+    {
+      maxKernel = 1.0;
+    }
+  }
+  else
+  {
+    maxKernel = kernelEval + furthestDist * queryKernels[queryIndex];
+  }
 
   // We return the inverse of the maximum kernel so that larger kernels are
   // recursed into first.
   return (maxKernel > bestKernel) ? (1.0 / maxKernel) : DBL_MAX;
 }
 
-template<typename MetricType, typename TreeType>
-double FastMKSRules<MetricType, TreeType>::Score(
-    TreeType& queryNode,
-    TreeType& referenceNode,
-    const double baseCaseResult) const
+template<typename KernelType, typename TreeType>
+double FastMKSRules<KernelType, TreeType>::Score(TreeType& queryNode,
+                                                 TreeType& referenceNode)
 {
-  // We already have the base case, so we need to add the bounds.
-  const double refKernelTerm = queryNode.FurthestDescendantDistance() *
-      referenceNode.Stat().SelfKernel();
-  const double queryKernelTerm = referenceNode.FurthestDescendantDistance() *
-      queryNode.Stat().SelfKernel();
-
-  const double maxKernel = baseCaseResult + refKernelTerm + queryKernelTerm +
-      (queryNode.FurthestDescendantDistance() *
-       referenceNode.FurthestDescendantDistance());
-
-  // The existing bound.
+  // Update and get the query node's bound.
   queryNode.Stat().Bound() = CalculateBound(queryNode);
   const double bestKernel = queryNode.Stat().Bound();
 
+  // First, see if we can make a parent-child or parent-parent prune.  These
+  // four bounds on the maximum kernel value are looser than the bound normally
+  // used, but they can prevent a base case from needing to be calculated.
+  const TreeType* queryParent = queryNode.Parent();
+  const TreeType* refParent = referenceNode.Parent();
+
+  // Convenience caching so lines are shorter.
+  const double queryParentDist = queryNode.ParentDistance();
+  const double queryDescDist = queryNode.FurthestDescendantDistance();
+  const double refParentDist = referenceNode.ParentDistance();
+  const double refDescDist = referenceNode.FurthestDescendantDistance();
+
+  const double queryDistBound = (queryParentDist + queryDescDist);
+  const double refDistBound = (refParentDist + refDescDist);
+
+  if ((queryParent != NULL) &&
+      (queryParent->Stat().LastKernelNode() == (void*) &referenceNode))
+  {
+    // Query parent was last evaluated with reference node.
+    const double maxKernelBound = queryParent->Stat().LastKernel() +
+        queryDistBound * referenceNode.Stat().SelfKernel();
+
+    if (maxKernelBound < bestKernel)
+      return DBL_MAX;
+  }
+  else if ((refParent != NULL) &&
+      (refParent->Stat().LastKernelNode() == (void*) &queryNode))
+  {
+    // Reference parent was last evaluated with query node.
+    const double maxKernelBound = refParent->Stat().LastKernel() +
+        (refParentDist + refDescDist) * queryNode.Stat().SelfKernel();
+
+    if (maxKernelBound < bestKernel)
+      return DBL_MAX;
+  }
+  else if ((refParent != NULL) && (queryParent != NULL) &&
+      (queryParent->Stat().LastKernelNode() == (void*) refParent))
+  {
+    // Query parent was last calculated with reference parent.
+    const double queryKernelTerm = (refParentDist + refDescDist) *
+        queryParent->Stat().SelfKernel();
+    const double refKernelTerm = (queryParentDist + queryDescDist) *
+        refParent->Stat().SelfKernel();
+    const double dualTerm = (queryParentDist + queryDescDist) * (refParentDist +
+        refDescDist);
+
+    const double maxKernelBound = queryParent->Stat().LastKernel() +
+        queryKernelTerm + refKernelTerm + dualTerm;
+
+    if (maxKernelBound < bestKernel)
+      return DBL_MAX;
+  }
+  else if ((refParent != NULL) && (queryParent != NULL) &&
+      (refParent->Stat().LastKernelNode() == (void*) queryParent))
+  {
+    // Reference parent was last calculated with query parent.
+    const double queryKernelTerm = (refParentDist + refDescDist) *
+        queryParent->Stat().SelfKernel();
+    const double refKernelTerm = (queryParentDist + queryDescDist) *
+        refParent->Stat().SelfKernel();
+    const double dualTerm = (queryParentDist + queryDescDist) *
+        (refParentDist + refDescDist);
+
+    const double maxKernelBound = refParent->Stat().LastKernel() +
+        queryKernelTerm + refKernelTerm + dualTerm;
+
+    if (maxKernelBound < bestKernel)
+      return DBL_MAX;
+  }
+
+  // Calculate kernel evaluation, if necessary.
+  double kernelEval = 0.0;
+  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
+  {
+    bool alreadyDone = false;
+    if ((queryNode.Parent() != NULL) &&
+        (queryNode.Parent()->Point(0) == queryNode.Point(0)))
+    {
+      TreeType* lastRef = (TreeType*)
+          queryNode.Parent()->Stat().LastKernelNode();
+      if (lastRef->Point(0) == referenceNode.Point(0))
+      {
+        // The query node parent was evaluated with the reference node.
+        kernelEval = queryNode.Parent()->Stat().LastKernel();
+        alreadyDone = true;
+      }
+    }
+
+    if ((referenceNode.Parent() != NULL) &&
+        (referenceNode.Parent()->Point(0) == referenceNode.Point(0)))
+    {
+      TreeType* lastQuery = (TreeType*)
+          referenceNode.Parent()->Stat().LastKernelNode();
+      if (lastQuery->Point(0) == queryNode.Point(0))
+      {
+        // The reference node parent was evaluated with the query node.
+        kernelEval = referenceNode.Parent()->Stat().LastKernel();
+        alreadyDone = true;
+      }
+    }
+
+    TreeType* lastRefNode = (TreeType*) referenceNode.Stat().LastKernelNode();
+    if ((lastRefNode != NULL) && (queryNode.Point(0) == lastRefNode->Point(0)))
+    {
+      // The kernel evaluation was already performed and is saved by the
+      // reference node.
+      kernelEval = referenceNode.Stat().LastKernel();
+      alreadyDone = true;
+    }
+
+    TreeType* lastQueryNode = (TreeType*) queryNode.Stat().LastKernelNode();
+    if ((lastQueryNode != NULL) &&
+        (referenceNode.Point(0) == lastQueryNode->Point(0)))
+    {
+      // The kernel evaluation was already performed and is saved by the query
+      // node.
+      kernelEval = queryNode.Stat().LastKernel();
+      alreadyDone = true;
+    }
+
+    if (!alreadyDone)
+    {
+      // The kernel must be evaluated, but it is between points in the dataset,
+      // so we can call BaseCase().  BaseCase() will set lastQueryIndex and
+      // lastReferenceIndex correctly.
+      kernelEval = BaseCase(queryNode.Point(0), referenceNode.Point(0));
+    }
+    else
+    {
+      // When BaseCase() is called after Score(), these must be correct so that
+      // another kernel evaluation is not performed.
+      lastQueryIndex = queryNode.Point(0);
+      lastReferenceIndex = referenceNode.Point(0);
+    }
+  }
+  else
+  {
+    // Calculate the maximum possible kernel value.
+    arma::vec queryCentroid;
+    arma::vec refCentroid;
+    queryNode.Centroid(queryCentroid);
+    referenceNode.Centroid(refCentroid);
+
+    kernelEval = kernel.Evaluate(queryCentroid, refCentroid);
+  }
+  ++scores;
+
+  double maxKernel;
+  if (kernel::KernelTraits<KernelType>::IsNormalized)
+  {
+    // We have a tighter bound for normalized kernels.
+    const double querySqDist = std::pow(queryDescDist, 2.0);
+    const double refSqDist = std::pow(refDescDist, 2.0);
+    const double bothSqDist = std::pow((queryDescDist + refDescDist), 2.0);
+
+    if (kernelEval <= (1 - 0.5 * bothSqDist))
+    {
+      const double queryDelta = (1 - 0.5 * querySqDist);
+      const double queryGamma = queryDescDist * sqrt(1 - 0.25 * querySqDist);
+      const double refDelta = (1 - 0.5 * refSqDist);
+      const double refGamma = refDescDist * sqrt(1 - 0.25 * refSqDist);
+
+      maxKernel = kernelEval * (queryDelta * refDelta - queryGamma * refGamma) +
+          sqrt(1 - std::pow(kernelEval, 2.0)) *
+          (queryGamma * refDelta + queryDelta * refGamma);
+    }
+    else
+    {
+      maxKernel = 1.0;
+    }
+  }
+  else
+  {
+    // Use standard bound; kernel is not normalized.
+    const double refKernelTerm = queryDescDist *
+        referenceNode.Stat().SelfKernel();
+    const double queryKernelTerm = refDescDist * queryNode.Stat().SelfKernel();
+
+    maxKernel = kernelEval + refKernelTerm + queryKernelTerm +
+        (queryDescDist * refDescDist);
+  }
+
+  // Store relevant information for parent-child pruning.
+  queryNode.Stat().LastKernel() = kernelEval;
+  queryNode.Stat().LastKernelNode() = (void*) &referenceNode;
+  referenceNode.Stat().LastKernel() = kernelEval;
+  referenceNode.Stat().LastKernelNode() = (void*) &queryNode;
+
   // We return the inverse of the maximum kernel so that larger kernels are
   // recursed into first.
   return (maxKernel > bestKernel) ? (1.0 / maxKernel) : DBL_MAX;
 }
 
-template<typename MetricType, typename TreeType>
-double FastMKSRules<MetricType, TreeType>::Rescore(const size_t queryIndex,
+template<typename KernelType, typename TreeType>
+double FastMKSRules<KernelType, TreeType>::Rescore(const size_t queryIndex,
                                                    TreeType& /*referenceNode*/,
                                                    const double oldScore) const
 {
@@ -167,8 +397,8 @@
   return ((1.0 / oldScore) > bestKernel) ? oldScore : DBL_MAX;
 }
 
-template<typename MetricType, typename TreeType>
-double FastMKSRules<MetricType, TreeType>::Rescore(TreeType& queryNode,
+template<typename KernelType, typename TreeType>
+double FastMKSRules<KernelType, TreeType>::Rescore(TreeType& queryNode,
                                                    TreeType& /*referenceNode*/,
                                                    const double oldScore) const
 {
@@ -185,8 +415,8 @@
  *
  * @param queryNode Query node to calculate bound for.
  */
-template<typename MetricType, typename TreeType>
-double FastMKSRules<MetricType, TreeType>::CalculateBound(TreeType& queryNode)
+template<typename KernelType, typename TreeType>
+double FastMKSRules<KernelType, TreeType>::CalculateBound(TreeType& queryNode)
     const
 {
   // We have four possible bounds -- just like NeighborSearchRules, but they are
@@ -200,7 +430,7 @@
   // (4) B(parent of queryNode);
   double worstPointKernel = DBL_MAX;
   double bestAdjustedPointKernel = -DBL_MAX;
-//  double bestPointSelfKernel = -DBL_MAX;
+
   const double queryDescendantDistance = queryNode.FurthestDescendantDistance();
 
   // Loop over all points in this node to find the best and worst.
@@ -213,8 +443,10 @@
     if (products(products.n_rows - 1, point) == -DBL_MAX)
       continue; // Avoid underflow.
 
+    // This should be (queryDescendantDistance + centroidDistance) for any tree
+    // but it works for cover trees since centroidDistance = 0 for cover trees.
     const double candidateKernel = products(products.n_rows - 1, point) -
-        (2 * queryDescendantDistance) *
+        queryDescendantDistance *
         referenceKernels[indices(indices.n_rows - 1, point)];
 
     if (candidateKernel > bestAdjustedPointKernel)
@@ -255,8 +487,8 @@
  * @param neighbor Index of reference point which is being inserted.
  * @param distance Distance from query point to reference point.
  */
-template<typename MetricType, typename TreeType>
-void FastMKSRules<MetricType, TreeType>::InsertNeighbor(const size_t queryIndex,
+template<typename KernelType, typename TreeType>
+void FastMKSRules<KernelType, TreeType>::InsertNeighbor(const size_t queryIndex,
                                                         const size_t pos,
                                                         const size_t neighbor,
                                                         const double distance)



More information about the mlpack-svn mailing list