[mlpack-svn] r14559 - mlpack/trunk/src/mlpack/methods/fastmks

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Mar 14 14:50:06 EDT 2013


Author: rcurtin
Date: 2013-03-14 14:50:06 -0400 (Thu, 14 Mar 2013)
New Revision: 14559

Added:
   mlpack/trunk/src/mlpack/methods/fastmks/fastmks_stat.hpp
Modified:
   mlpack/trunk/src/mlpack/methods/fastmks/fastmks.hpp
   mlpack/trunk/src/mlpack/methods/fastmks/fastmks_impl.hpp
   mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules.hpp
   mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp
Log:
Make FastMKS work in the dual-tree setting.  Hooray!


Modified: mlpack/trunk/src/mlpack/methods/fastmks/fastmks.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/fastmks/fastmks.hpp	2013-03-14 02:26:39 UTC (rev 14558)
+++ mlpack/trunk/src/mlpack/methods/fastmks/fastmks.hpp	2013-03-14 18:50:06 UTC (rev 14559)
@@ -9,7 +9,8 @@
 
 #include <mlpack/core.hpp>
 #include "ip_metric.hpp"
-#include <mlpack/core/tree/cover_tree/cover_tree.hpp>
+#include "fastmks_stat.hpp"
+#include <mlpack/core/tree/cover_tree.hpp>
 
 namespace mlpack {
 namespace fastmks {
@@ -42,9 +43,11 @@
 
   const arma::mat& querySet;
 
-  tree::CoverTree<IPMetric<KernelType> >* referenceTree;
+  tree::CoverTree<IPMetric<KernelType>, tree::FirstPointIsRoot, FastMKSStat>*
+      referenceTree;
 
-  tree::CoverTree<IPMetric<KernelType> >* queryTree;
+  tree::CoverTree<IPMetric<KernelType>, tree::FirstPointIsRoot, FastMKSStat>*
+      queryTree;
 
   bool single;
 

Modified: mlpack/trunk/src/mlpack/methods/fastmks/fastmks_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/fastmks/fastmks_impl.hpp	2013-03-14 02:26:39 UTC (rev 14558)
+++ mlpack/trunk/src/mlpack/methods/fastmks/fastmks_impl.hpp	2013-03-14 18:50:06 UTC (rev 14559)
@@ -18,191 +18,6 @@
 namespace mlpack {
 namespace fastmks {
 
-template<typename TreeType>
-void RecurseTreeCountLeaves(const TreeType& node, arma::vec& counts)
-{
-  for (size_t i = 0; i < node.NumChildren(); ++i)
-  {
-    if (node.Child(i).NumChildren() == 0)
-      counts[node.Child(i).Point()]++;
-    else
-      RecurseTreeCountLeaves<TreeType>(node.Child(i), counts);
-  }
-}
-
-template<typename TreeType>
-void CheckSelfChild(const TreeType& node)
-{
-  if (node.NumChildren() == 0)
-    return; // No self-child applicable here.
-
-  bool found = false;
-  for (size_t i = 0; i < node.NumChildren(); ++i)
-  {
-    if (node.Child(i).Point() == node.Point())
-      found = true;
-
-    // Recursively check the children.
-    CheckSelfChild(node.Child(i));
-  }
-
-  // Ensure this has its own self-child.
-  Log::Assert(found == true);
-}
-
-template<typename TreeType, typename MetricType>
-void CheckCovering(const TreeType& node)
-{
-  // Return if a leaf.  No checking necessary.
-  if (node.NumChildren() == 0)
-    return;
-
-  const arma::mat& dataset = node.Dataset();
-  const size_t nodePoint = node.Point();
-
-  // To ensure that this node satisfies the covering principle, we must ensure
-  // that the distance to each child is less than pow(expansionConstant, scale).
-  double maxDistance = pow(node.Base(), node.Scale());
-  for (size_t i = 0; i < node.NumChildren(); ++i)
-  {
-    const size_t childPoint = node.Child(i).Point();
-
-    double distance = MetricType::Evaluate(dataset.col(nodePoint),
-        dataset.col(childPoint));
-
-    Log::Assert(distance <= maxDistance);
-
-    // Check the child.
-    CheckCovering<TreeType, MetricType>(node.Child(i));
-  }
-}
-
-template<typename TreeType, typename MetricType>
-void CheckIndividualSeparation(const TreeType& constantNode,
-                               const TreeType& node)
-{
-  // Don't check points at a lower scale.
-  if (node.Scale() < constantNode.Scale())
-    return;
-
-  // If at a higher scale, recurse.
-  if (node.Scale() > constantNode.Scale())
-  {
-    for (size_t i = 0; i < node.NumChildren(); ++i)
-    {
-      // Don't recurse into leaves.
-      if (node.Child(i).NumChildren() > 0)
-        CheckIndividualSeparation<TreeType, MetricType>(constantNode,
-            node.Child(i));
-    }
-
-    return;
-  }
-
-  // Don't compare the same point against itself.
-  if (node.Point() == constantNode.Point())
-    return;
-
-  // Now we know we are at the same scale, so make the comparison.
-  const arma::mat& dataset = constantNode.Dataset();
-  const size_t constantPoint = constantNode.Point();
-  const size_t nodePoint = node.Point();
-
-  // Make sure the distance is at least the following value (in accordance with
-  // the separation principle of cover trees).
-  double minDistance = pow(constantNode.ExpansionConstant(),
-      constantNode.Scale());
-
-  double distance = MetricType::Evaluate(dataset.col(constantPoint),
-      dataset.col(nodePoint));
-
-}
-
-template<typename TreeType, typename MetricType>
-void CheckSeparation(const TreeType& node, const TreeType& root)
-{
-  // Check the separation between this point and all other points on this scale.
-  CheckIndividualSeparation<TreeType, MetricType>(node, root);
-
-  // Check the children, but only if they are not leaves.  Leaves don't need to
-  // be checked.
-  for (size_t i = 0; i < node.NumChildren(); ++i)
-    if (node.Child(i).NumChildren() > 0)
-      CheckSeparation<TreeType, MetricType>(node.Child(i), root);
-}
-
-template<typename TreeType, typename MetricType>
-void GetMaxDistance(TreeType& node,
-                    TreeType& constantNode,
-                    double& best,
-                    size_t& index)
-{
-  const arma::mat& dataset = node.Dataset();
-  const double eval = MetricType::Evaluate(dataset.unsafe_col(node.Point()),
-      dataset.unsafe_col(constantNode.Point()));
-  if (eval > best)
-  {
-    best = eval;
-    index = node.Point();
-  }
-
-  // Recurse into children.
-  for (size_t i = 0; i < node.NumChildren(); ++i)
-    GetMaxDistance<TreeType, MetricType>(node.Child(i), constantNode, best,
-        index);
-}
-
-template<typename TreeType, typename MetricType>
-void CheckMaxDistances(TreeType& node)
-{
-  // Check child distances.
-  for (size_t i = 0; i < node.NumChildren(); ++i)
-  {
-    const arma::mat& dataset = node.Dataset();
-    double eval = MetricType::Evaluate(dataset.unsafe_col(node.Point()),
-        dataset.unsafe_col(node.Child(i).Point()));
-
-    Log::Assert(std::abs(eval - node.Child(i).ParentDistance()) < 1e-10);
-  }
-
-  // Check all descendants.
-  double maxDescendantDistance = 0;
-  size_t maxIndex = 0;
-  GetMaxDistance<TreeType, MetricType>(node, node, maxDescendantDistance,
-      maxIndex);
-
-  Log::Assert(std::abs(maxDescendantDistance -
-      node.FurthestDescendantDistance()) < 1e-10);
-
-  for (size_t i = 0; i < node.NumChildren(); ++i)
-    CheckMaxDistances<TreeType, MetricType>(node.Child(i));
-}
-
-template<typename TreeType>
-struct SearchFrame
-{
-  TreeType* node;
-  double eval;
-};
-
-template<typename TreeType>
-class SearchFrameCompare
-{
- public:
-  bool operator()(const SearchFrame<TreeType>& lhs,
-                  const SearchFrame<TreeType>& rhs)
-  {
-    // Compare scale.
-    if (lhs.node->Scale() != rhs.node->Scale())
-      return (lhs.node->Scale() < rhs.node->Scale());
-    else
-    {
-      // Now we have to compare by evaluation.
-      return (lhs.eval < rhs.eval);
-    }
-  }
-};
-
 template<typename KernelType>
 FastMKS<KernelType>::FastMKS(const arma::mat& referenceSet,
                              KernelType& kernel,
@@ -223,8 +38,9 @@
   if (naive)
     referenceTree = NULL;
   else
-    referenceTree = new tree::CoverTree<IPMetric<KernelType> >(referenceSet,
-        expansionConstant, &metric);
+    referenceTree = new tree::CoverTree<IPMetric<KernelType>,
+        tree::FirstPointIsRoot, FastMKSStat>(referenceSet, expansionConstant,
+        &metric);
 
   Timer::Stop("tree_building");
 }
@@ -248,42 +64,17 @@
   if (naive)
     referenceTree = NULL;
   else
-    referenceTree = new tree::CoverTree<IPMetric<KernelType> >(referenceSet,
-        expansionConstant, &metric);
+    referenceTree = new tree::CoverTree<IPMetric<KernelType>,
+        tree::FirstPointIsRoot, FastMKSStat>(referenceSet, expansionConstant,
+        &metric);
 
   if (single || naive)
     queryTree = NULL;
   else
-    queryTree = new tree::CoverTree<IPMetric<KernelType> >(querySet,
-        expansionConstant, &metric);
+    queryTree = new tree::CoverTree<IPMetric<KernelType>,
+        tree::FirstPointIsRoot, FastMKSStat>(querySet, expansionConstant,
+        &metric);
 
-/*  if (referenceTree != NULL)
-  {
-    Log::Debug << "Check counts" << std::endl;
-    // Now loop through the tree and ensure that each leaf is only created once.
-    arma::vec counts;
-    counts.zeros(referenceSet.n_elem);
-    RecurseTreeCountLeaves(*referenceTree, counts);
-
-    // Each point should only have one leaf node representing it.
-    for (size_t i = 0; i < 20; ++i)
-      Log::Assert(counts[i] == 1);
-
-    Log::Debug << "Check self child\n";
-    // Each non-leaf should have a self-child.
-    CheckSelfChild<tree::CoverTree<IPMetric<KernelType> > >(*referenceTree);
-
-    Log::Debug << "Check covering\n";
-    // Each node must satisfy the covering principle (its children must be less
-    // than or equal to a certain distance apart).
-    CheckCovering<tree::CoverTree<IPMetric<KernelType> >, IPMetric<KernelType> >(*referenceTree);
-
-    Log::Debug << "Check max distances\n";
-    // Check maximum distance of children and grandchildren.
-    CheckMaxDistances<tree::CoverTree<IPMetric<KernelType> >, IPMetric<KernelType> >(*referenceTree);
-    Log::Debug << "Done\n";
-  }*/
-
   Timer::Stop("tree_building");
 }
 
@@ -301,10 +92,10 @@
                                arma::Mat<size_t>& indices,
                                arma::mat& products)
 {
-  // No remapping will be necessary.
+  // No remapping will be necessary because we are using the cover tree.
   indices.set_size(k, querySet.n_cols);
   products.set_size(k, querySet.n_cols);
-  products.fill(-1.0);
+  products.fill(-DBL_MAX);
 
   Timer::Start("computing_products");
 
@@ -318,6 +109,9 @@
     {
       for (size_t r = 0; r < referenceSet.n_cols; ++r)
       {
+        if ((&querySet == &referenceSet) && (q == r))
+          continue;
+
         const double eval = metric.Kernel().Evaluate(querySet.unsafe_col(q),
             referenceSet.unsafe_col(r));
         ++kernelEvaluations;
@@ -343,144 +137,21 @@
   // Single-tree implementation.
   if (single)
   {
-    // Calculate number of pruned nodes.
-    size_t numPrunes = 0;
+    // Create rules object (this will store the results).  This constructor
+    // precalculates each self-kernel value.
+    typedef tree::CoverTree<IPMetric<KernelType>, tree::FirstPointIsRoot,
+        FastMKSStat> TreeType;
+    typedef FastMKSRules<KernelType, TreeType> RuleType;
+    RuleType rules(referenceSet, querySet, indices, products, metric.Kernel());
 
-    // Precalculate query products ( || q || for all q).
-    arma::vec queryProducts(querySet.n_cols);
-    for (size_t queryIndex = 0; queryIndex < querySet.n_cols; ++queryIndex)
-      queryProducts[queryIndex] = sqrt(metric.Kernel().Evaluate(
-          querySet.unsafe_col(queryIndex), querySet.unsafe_col(queryIndex)));
-    kernelEvaluations += querySet.n_cols;
+    typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
 
-    // Screw the CoverTreeTraverser, we'll implement it by hand.
-    for (size_t queryIndex = 0; queryIndex < querySet.n_cols; ++queryIndex)
-    {
-      // Use an array of priority queues?
-      std::priority_queue<
-          SearchFrame<tree::CoverTree<IPMetric<KernelType> > >,
-          std::vector<SearchFrame<tree::CoverTree<IPMetric<KernelType> > > >,
-          SearchFrameCompare<tree::CoverTree<IPMetric<KernelType> > > >
-          frameQueue;
+    for (size_t i = 0; i < querySet.n_cols; ++i)
+      traverser.Traverse(i, *referenceTree);
 
-      // Add initial frame.
-      SearchFrame<tree::CoverTree<IPMetric<KernelType> > > nextFrame;
-      nextFrame.node = referenceTree;
-      nextFrame.eval = metric.Kernel().Evaluate(querySet.unsafe_col(queryIndex),
-          referenceSet.unsafe_col(referenceTree->Point()));
-      ++kernelEvaluations;
+    // Save the number of pruned nodes.
+    const size_t numPrunes = traverser.NumPrunes();
 
-      // The initial evaluation will be the best so far.
-      indices(0, queryIndex) = referenceTree->Point();
-      products(0, queryIndex) = nextFrame.eval;
-
-      frameQueue.push(nextFrame);
-
-      tree::CoverTree<IPMetric<KernelType> >* referenceNode;
-      double eval;
-      double maxProduct;
-
-      while (!frameQueue.empty())
-      {
-        // Get the information for this node.
-        const SearchFrame<tree::CoverTree<IPMetric<KernelType> > >& frame =
-            frameQueue.top();
-
-        referenceNode = frame.node;
-        eval = frame.eval;
-
-        // Loop through the children, seeing if we can prune them; if not, add
-        // them to the queue.  The self-child is different -- it has the same
-        // parent (and therefore the same kernel evaluation).
-        if (referenceNode->NumChildren() > 0)
-        {
-          SearchFrame<tree::CoverTree<IPMetric<KernelType> > > childFrame;
-
-          // We must handle the self-child differently, to avoid adding it to
-          // the results twice.
-          childFrame.node = &(referenceNode->Child(0));
-          childFrame.eval = eval;
-
-//          maxProduct = eval + std::pow(childFrame.node->ExpansionConstant(),
-//              childFrame.node->Scale() + 1) * queryProducts[queryIndex];
-          // Alternate pruning rule.
-          maxProduct = eval + childFrame.node->FurthestDescendantDistance() *
-              queryProducts[queryIndex];
-
-          // Add self-child if we can't prune it.
-          if (maxProduct > products(products.n_rows - 1, queryIndex))
-          {
-            // But only if it has children of its own.
-            if (childFrame.node->NumChildren() > 0)
-              frameQueue.push(childFrame);
-          }
-          else
-            ++numPrunes;
-
-          for (size_t i = 1; i < referenceNode->NumChildren(); ++i)
-          {
-            // Before we evaluate the child, let's see if it can possibly have
-            // a better evaluation.
-            double maxChildEval = eval + queryProducts[queryIndex] *
-                (referenceNode->Child(i).ParentDistance() +
-                referenceNode->Child(i).FurthestDescendantDistance());
-
-            if (maxChildEval <= products(products.n_rows - 1, queryIndex))
-            {
-              ++numPrunes;
-              continue; // Skip this child; it can't be any better.
-            }
-
-            // Evaluate child.
-            childFrame.node = &(referenceNode->Child(i));
-            childFrame.eval = metric.Kernel().Evaluate(
-                querySet.unsafe_col(queryIndex),
-                referenceSet.unsafe_col(referenceNode->Child(i).Point()));
-            ++kernelEvaluations;
-
-            // Can we prune it?  If we can, we can avoid putting it in the queue
-            // (saves time).
-//            double maxProduct = childFrame.eval +
-//                std::pow(childFrame.node->ExpansionConstant(),
-//                childFrame.node->Scale() + 1) * queryProducts[queryIndex];
-            // Alternate pruning rule.
-            maxProduct = childFrame.eval + queryProducts[queryIndex] *
-                childFrame.node->FurthestDescendantDistance();
-
-            if (maxProduct > products(products.n_rows - 1, queryIndex))
-            {
-              // Good enough to recurse into.  While we're at it, check the
-              // actual evaluation and see if it's an improvement.
-              if (childFrame.eval > products(products.n_rows - 1, queryIndex))
-              {
-                // This is a better result.  Find out where to insert.
-                size_t insertPosition = 0;
-                for ( ; insertPosition < products.n_rows - 1; ++insertPosition)
-                  if (childFrame.eval > products(insertPosition, queryIndex))
-                    break;
-
-                // Insert into the correct position; we are guaranteed that
-                // insertPosition is valid.
-                InsertNeighbor(indices, products, queryIndex, insertPosition,
-                    childFrame.node->Point(), childFrame.eval);
-              }
-
-              // Now add this to the queue (if it has any children which may
-              // need to be recursed into).
-              if (childFrame.node->NumChildren() > 0)
-                frameQueue.push(childFrame);
-            }
-            else
-            {
-              ++numPrunes;
-            }
-          }
-        }
-
-        frameQueue.pop();
-      }
-    }
-
     Log::Info << "Pruned " << numPrunes << " nodes." << std::endl;
     Log::Info << "Kernel evaluations: " << kernelEvaluations << "."
         << std::endl;
@@ -491,8 +162,28 @@
     return;
   }
 
-  // Double-tree implementation.
-  Log::Fatal << "Dual-tree search not implemented yet... oops..." << std::endl;
+  // Dual-tree implementation.
+  typedef tree::CoverTree<IPMetric<KernelType>, tree::FirstPointIsRoot,
+      FastMKSStat> TreeType;
+  typedef FastMKSRules<KernelType, TreeType> RuleType;
+  RuleType rules(referenceSet, querySet, indices, products, metric.Kernel());
+
+  typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+
+  if (queryTree)
+    traverser.Traverse(*queryTree, *referenceTree);
+  else
+    traverser.Traverse(*referenceTree, *referenceTree);
+
+  const size_t numPrunes = traverser.NumPrunes();
+
+  Log::Info << "Pruned " << numPrunes << " nodes." << std::endl;
+  Log::Info << "Kernel evaluations: " << kernelEvaluations << "." << std::endl;
+  Log::Info << "Distance evaluations: " << distanceEvaluations << "."
+      << std::endl;
+
+  Timer::Stop("computing_products");
+  return;
 }
 
 /**
@@ -529,6 +220,7 @@
 }
 
 // Specialized implementation for tighter bounds for Gaussian.
+/*
 template<>
 void FastMKS<kernel::GaussianKernel>::Search(const size_t k,
                                              arma::Mat<size_t>& indices,
@@ -741,6 +433,7 @@
   Log::Fatal << "Dual-tree search not implemented yet... oops..." << std::endl;
 
 }
+*/
 
 }; // namespace fastmks
 }; // namespace mlpack

Modified: mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules.hpp	2013-03-14 02:26:39 UTC (rev 14558)
+++ mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules.hpp	2013-03-14 18:50:06 UTC (rev 14559)
@@ -13,21 +13,97 @@
 namespace mlpack {
 namespace fastmks {
 
-template<typename MetricType>
+template<typename KernelType, typename TreeType>
 class FastMKSRules
 {
  public:
   FastMKSRules(const arma::mat& referenceSet,
                const arma::mat& querySet,
                arma::Mat<size_t>& indices,
-               arma::mat& products);
+               arma::mat& products,
+               KernelType& kernel);
 
-  void BaseCase(const size_t queryIndex, const size_t referenceIndex);
+  //! Compute the base case (kernel value) between two points.
+  double BaseCase(const size_t queryIndex, const size_t referenceIndex);
 
-  bool CanPrune(const size_t queryIndex,
-                tree::CoverTree<MetricType>& referenceNode,
-                const size_t parentIndex);
+  /**
+   * Get the score for recursion order.  A low score indicates priority for
+   * recursion, while DBL_MAX indicates that the node should not be recursed
+   * into at all (it should be pruned).
+   *
+   * @param queryIndex Index of query point.
+   * @param referenceNode Candidate to be recursed into.
+   */
+  double Score(const size_t queryIndex, TreeType& referenceNode) const;
 
+  /**
+   * Get the score for recursion order, passing the base case result (in the
+   * situation where it may be needed to calculate the recursion order).  A low
+   * score indicates priority for recursion, while DBL_MAX indicates that the
+   * node should not be recursed into at all (it should be pruned).
+   *
+   * @param queryIndex Index of query point.
+   * @param referenceNode Candidate node to be recursed into.
+   * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode).
+   */
+  double Score(const size_t queryIndex,
+               TreeType& referenceNode,
+               const double baseCaseResult) const;
+
+  /**
+   * Get the score for recursion order.  A low score indicates priority for
+   * recursion, while DBL_MAX indicates that the node should not be recursed
+   * into at all (it should be pruned).
+   *
+   * @param queryNode Candidate query node to be recursed into.
+   * @param referenceNode Candidate reference node to be recursed into.
+   */
+  double Score(TreeType& queryNode, TreeType& referenceNode) const;
+
+  /**
+   * Get the score for recursion order, passing the base case result (in the
+   * situation where it may be needed to calculate the recursion order).  A low
+   * score indicates priority for recursion, while DBL_MAX indicates that the
+   * node should not be recursed into at all (it should be pruned).
+   *
+   * @param queryNode Candidate query node to be recursed into.
+   * @param referenceNode Candidate reference node to be recursed into.
+   * @param baseCaseResult Result of BaseCase(queryNode, referenceNode).
+   */
+  double Score(TreeType& queryNode,
+               TreeType& referenceNode,
+               const double baseCaseResult) const;
+
+  /**
+   * Re-evaluate the score for recursion order.  A low score indicates priority
+   * for recursion, while DBL_MAX indicates that a node should not be recursed
+   * into at all (it should be pruned).  This is used when the score has already
+   * been calculated, but another recursion may have modified the bounds for
+   * pruning.  So the old score is checked against the new pruning bound.
+   *
+   * @param queryIndex Index of query point.
+   * @param referenceNode Candidate node to be recursed into.
+   * @param oldScore Old score produced by Score() (or Rescore()).
+   */
+  double Rescore(const size_t queryIndex,
+                 TreeType& referenceNode,
+                 const double oldScore) const;
+
+  /**
+   * Re-evaluate the score for recursion order.  A low score indicates priority
+   * for recursion, while DBL_MAX indicates that a node should not be recursed
+   * into at all (it should be pruned).  This is used when the score has already
+   * been calculated, but another recursion may have modified the bounds for
+   * pruning.  So the old score is checked against the new pruning bound.
+   *
+   * @param queryNode Candidate query node to be recursed into.
+   * @param referenceNode Candidate reference node to be recursed into.
+   * @param oldScore Old score produced by Score() (or Rescore()).
+   */
+  double Rescore(TreeType& queryNode,
+                 TreeType& referenceNode,
+                 const double oldScore) const;
+
  private:
   const arma::mat& referenceSet;
 
@@ -38,7 +114,14 @@
   arma::mat& products;
 
   arma::vec queryKernels; // || q || for each q.
+  arma::vec referenceKernels;
 
+  //! The instantiated kernel.
+  KernelType& kernel;
+
+  //! Calculate the bound for a given query node.
+  double CalculateBound(TreeType& queryNode) const;
+
   void InsertNeighbor(const size_t queryIndex,
                       const size_t pos,
                       const size_t neighbor,

Modified: mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp	2013-03-14 02:26:39 UTC (rev 14558)
+++ mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp	2013-03-14 18:50:06 UTC (rev 14559)
@@ -13,57 +13,237 @@
 namespace mlpack {
 namespace fastmks {
 
-template<typename MetricType>
-FastMKSRules<MetricType>::FastMKSRules(const arma::mat& referenceSet,
-                                       const arma::mat& querySet,
-                                       arma::Mat<size_t>& indices,
-                                       arma::mat& products) :
+template<typename KernelType, typename TreeType>
+FastMKSRules<KernelType, TreeType>::FastMKSRules(const arma::mat& referenceSet,
+                                                 const arma::mat& querySet,
+                                                 arma::Mat<size_t>& indices,
+                                                 arma::mat& products,
+                                                 KernelType& kernel) :
     referenceSet(referenceSet),
     querySet(querySet),
     indices(indices),
-    products(products)
+    products(products),
+    kernel(kernel)
 {
   // Precompute each self-kernel.
-//  queryKernels.set_size(querySet.n_cols);
-//  for (size_t i = 0; i < querySet.n_cols; ++i)
-//   queryKernels[i] = sqrt(MetricType::Kernel::Evaluate(querySet.unsafe_col(i),
-//        querySet.unsafe_col(i)));
+  queryKernels.set_size(querySet.n_cols);
+  for (size_t i = 0; i < querySet.n_cols; ++i)
+    queryKernels[i] = sqrt(kernel.Evaluate(querySet.unsafe_col(i),
+                                           querySet.unsafe_col(i)));
+
+  referenceKernels.set_size(referenceSet.n_cols);
+  for (size_t i = 0; i < referenceSet.n_cols; ++i)
+    referenceKernels[i] = sqrt(kernel.Evaluate(referenceSet.unsafe_col(i),
+                                               referenceSet.unsafe_col(i)));
 }
 
-template<typename MetricType>
-bool FastMKSRules<MetricType>::CanPrune(const size_t queryIndex,
-    tree::CoverTree<MetricType>& referenceNode,
-    const size_t parentIndex)
+template<typename KernelType, typename TreeType>
+double FastMKSRules<KernelType, TreeType>::BaseCase(
+    const size_t queryIndex,
+    const size_t referenceIndex)
 {
-  // The maximum possible inner product is given by
-  //   <q, p_0> + R_p || q ||
-  // and since we are using cover trees, p_0 is the point referred to by the
-  // node, and R_p will be the expansion constant to the power of the scale plus
-  // one.
-  const double eval = MetricType::Kernel::Evaluate(querySet.col(queryIndex),
-      referenceSet.col(referenceNode.Point()));
 
-  // See if base case can be added.
-  if (eval > products(products.n_rows - 1, queryIndex))
+  double kernelEval = kernel.Evaluate(querySet.unsafe_col(queryIndex),
+                                      referenceSet.unsafe_col(referenceIndex));
+
+  // If the reference and query sets are identical, we still need to compute the
+  // base case (so that things can be bounded properly), but we won't add it to
+  // the results.
+  if ((&querySet == &referenceSet) && (queryIndex == referenceIndex))
+    return kernelEval;
+
+  // If this is a better candidate, insert it into the list.
+  if (kernelEval < products(products.n_rows - 1, queryIndex))
+    return kernelEval;
+
+  size_t insertPosition = 0;
+  for ( ; insertPosition < products.n_rows; ++insertPosition)
+    if (kernelEval >= products(insertPosition, queryIndex))
+      break;
+
+  InsertNeighbor(queryIndex, insertPosition, referenceIndex, kernelEval);
+
+  return kernelEval;
+}
+
+template<typename MetricType, typename TreeType>
+double FastMKSRules<MetricType, TreeType>::Score(const size_t queryIndex,
+                                                 TreeType& referenceNode) const
+{
+  // Calculate the maximum possible kernel value.
+  const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
+  const arma::vec refCentroid;
+  referenceNode.Bound().Centroid(refCentroid);
+
+  const double maxKernel = kernel.Evaluate(queryPoint, refCentroid) +
+      referenceNode.FurthestDescendantDistance() * queryKernels[queryIndex];
+
+  // Compare with the current best.
+  const double bestKernel = products(products.n_rows - 1, queryIndex);
+
+  // We return the inverse of the maximum kernel so that larger kernels are
+  // recursed into first.
+  return (maxKernel > bestKernel) ? (1.0 / maxKernel) : DBL_MAX;
+}
+
+template<typename MetricType, typename TreeType>
+double FastMKSRules<MetricType, TreeType>::Score(
+    const size_t queryIndex,
+    TreeType& referenceNode,
+    const double baseCaseResult) const
+{
+  // We already have the base case result.  Add the bound.
+  const double maxKernel = baseCaseResult +
+      referenceNode.FurthestDescendantDistance() * queryKernels[queryIndex];
+  const double bestKernel = products(products.n_rows - 1, queryIndex);
+
+  // We return the inverse of the maximum kernel so that larger kernels are
+  // recursed into first.
+  return (maxKernel > bestKernel) ? (1.0 / maxKernel) : DBL_MAX;
+}
+
+template<typename MetricType, typename TreeType>
+double FastMKSRules<MetricType, TreeType>::Score(TreeType& queryNode,
+                                                 TreeType& referenceNode) const
+{
+  // Calculate the maximum possible kernel value.
+  const arma::vec queryCentroid;
+  const arma::vec refCentroid;
+  queryNode.Bound().Centroid(queryCentroid);
+  referenceNode.Bound().Centroid(refCentroid);
+
+  const double refKernelTerm = queryNode.FurthestDescendantDistance() *
+      referenceNode.Stat().SelfKernel();
+  const double queryKernelTerm = referenceNode.FurthestDescendantDistance() *
+      queryNode.Stat().SelfKernel();
+
+  const double maxKernel = kernel.Evaluate(queryCentroid, refCentroid) +
+      refKernelTerm + queryKernelTerm +
+      (queryNode.FurthestDescendantDistance() *
+       referenceNode.FurthestDescendantDistance());
+
+  // The existing bound.
+  queryNode.Stat().Bound() = CalculateBound(queryNode);
+  const double bestKernel = queryNode.Stat().Bound();
+
+  // We return the inverse of the maximum kernel so that larger kernels are
+  // recursed into first.
+  return (maxKernel > bestKernel) ? (1.0 / maxKernel) : DBL_MAX;
+}
+
+template<typename MetricType, typename TreeType>
+double FastMKSRules<MetricType, TreeType>::Score(
+    TreeType& queryNode,
+    TreeType& referenceNode,
+    const double baseCaseResult) const
+{
+  // We already have the base case, so we need to add the bounds.
+  const double refKernelTerm = queryNode.FurthestDescendantDistance() *
+      referenceNode.Stat().SelfKernel();
+  const double queryKernelTerm = referenceNode.FurthestDescendantDistance() *
+      queryNode.Stat().SelfKernel();
+
+  const double maxKernel = baseCaseResult + refKernelTerm + queryKernelTerm +
+      (queryNode.FurthestDescendantDistance() *
+       referenceNode.FurthestDescendantDistance());
+
+  // The existing bound.
+  queryNode.Stat().Bound() = CalculateBound(queryNode);
+  const double bestKernel = queryNode.Stat().Bound();
+
+  // We return the inverse of the maximum kernel so that larger kernels are
+  // recursed into first.
+  return (maxKernel > bestKernel) ? (1.0 / maxKernel) : DBL_MAX;
+}
+
+template<typename MetricType, typename TreeType>
+double FastMKSRules<MetricType, TreeType>::Rescore(const size_t queryIndex,
+                                                   TreeType& /*referenceNode*/,
+                                                   const double oldScore) const
+{
+  const double bestKernel = products(products.n_rows - 1, queryIndex);
+
+  return ((1.0 / oldScore) > bestKernel) ? oldScore : DBL_MAX;
+}
+
+template<typename MetricType, typename TreeType>
+double FastMKSRules<MetricType, TreeType>::Rescore(TreeType& queryNode,
+                                                   TreeType& /*referenceNode*/,
+                                                   const double oldScore) const
+{
+  queryNode.Stat().Bound() = CalculateBound(queryNode);
+  const double bestKernel = queryNode.Stat().Bound();
+
+  return ((1.0 / oldScore) > bestKernel) ? oldScore : DBL_MAX;
+}
+
+/**
+ * Calculate the bound for the given query node.  This bound represents the
+ * minimum value which a node combination must achieve to guarantee an
+ * improvement in the results.
+ *
+ * @param queryNode Query node to calculate bound for.
+ */
+template<typename MetricType, typename TreeType>
+double FastMKSRules<MetricType, TreeType>::CalculateBound(TreeType& queryNode)
+    const
+{
+  // We have four possible bounds -- just like NeighborSearchRules, but they are
+  // slightly different in this context.
+  //
+  // (1) min ( min_{all points p in queryNode} P_p[k],
+  //           min_{all children c in queryNode} B(c) );
+  // (2) max_{all points p in queryNode} P_p[k] + (worst child distance + worst
+  //           descendant distance) sqrt(K(I_p[k], I_p[k]));
+  // (3) max_{all children c in queryNode} B(c) + <-- not done yet.  ignored.
+  // (4) B(parent of queryNode);
+  double worstPointKernel = DBL_MAX;
+  double bestAdjustedPointKernel = -DBL_MAX;
+  double bestPointSelfKernel = -DBL_MAX;
+  const double queryDescendantDistance = queryNode.FurthestDescendantDistance();
+
+  // Loop over all points in this node to find the best and worst.
+  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
   {
-    size_t insertPosition;
-    for (insertPosition = 0; insertPosition < indices.n_rows; ++insertPosition)
-      if (eval > products(insertPosition, queryIndex))
-        break;
+    const size_t point = queryNode.Point(i);
+    if (products(products.n_rows - 1, point) < worstPointKernel)
+      worstPointKernel = products(products.n_rows - 1, point);
 
-    // We are guaranteed insertPosition is in the valid range.
-    InsertNeighbor(queryIndex, insertPosition, referenceNode.Point(), eval);
+    if (products(products.n_rows - 1, point) == -DBL_MAX)
+      continue; // Avoid underflow.
+
+    const double candidateKernel = products(products.n_rows - 1, point) -
+        (2 * queryDescendantDistance) *
+        referenceKernels[indices(indices.n_rows - 1, point)];
+
+    if (candidateKernel > bestAdjustedPointKernel)
+      bestAdjustedPointKernel = candidateKernel;
   }
 
-  double maxProduct = eval + std::pow(referenceNode.ExpansionConstant(),
-      referenceNode.Scale() + 1) *
-      sqrt(MetricType::Kernel::Evaluate(querySet.col(queryIndex),
-      querySet.col(queryIndex)));
+  // Loop over all the children in the node.
+  double worstChildKernel = DBL_MAX;
 
-  if (maxProduct > products(products.n_rows - 1, queryIndex))
-    return false;
-  else
-    return true;
+  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
+  {
+    if (queryNode.Child(i).Stat().Bound() < worstChildKernel)
+      worstChildKernel = queryNode.Child(i).Stat().Bound();
+  }
+
+  // Now assemble bound (1).
+  const double firstBound = (worstPointKernel < worstChildKernel) ?
+      worstPointKernel : worstChildKernel;
+
+  // Bound (2) is bestAdjustedPointKernel.
+  const double fourthBound = (queryNode.Parent() == NULL) ? -DBL_MAX :
+      queryNode.Parent()->Stat().Bound();
+
+  // Pick the best of these bounds.
+  const double interA = (firstBound > bestAdjustedPointKernel) ? firstBound :
+      bestAdjustedPointKernel;
+//  const double interA = 0.0;
+  const double interB = fourthBound;
+
+  return (interA > interB) ? interA : interB;
 }
 
 /**
@@ -74,11 +254,11 @@
  * @param neighbor Index of reference point which is being inserted.
  * @param distance Distance from query point to reference point.
  */
-template<typename MetricType>
-void FastMKSRules<MetricType>::InsertNeighbor(const size_t queryIndex,
-                                              const size_t pos,
-                                              const size_t neighbor,
-                                              const double distance)
+template<typename MetricType, typename TreeType>
+void FastMKSRules<MetricType, TreeType>::InsertNeighbor(const size_t queryIndex,
+                                                        const size_t pos,
+                                                        const size_t neighbor,
+                                                        const double distance)
 {
   // We only memmove() if there is actually a need to shift something.
   if (pos < (products.n_rows - 1))

Added: mlpack/trunk/src/mlpack/methods/fastmks/fastmks_stat.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/fastmks/fastmks_stat.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/fastmks/fastmks_stat.hpp	2013-03-14 18:50:06 UTC (rev 14559)
@@ -0,0 +1,89 @@
+/**
+ * @file fastmks_stat.hpp
+ * @author Ryan Curtin
+ *
+ * The statistic used in trees with FastMKS.
+ */
+#ifndef __MLPACK_METHODS_FASTMKS_FASTMKS_STAT_HPP
+#define __MLPACK_METHODS_FASTMKS_FASTMKS_STAT_HPP
+
+#include <mlpack/core.hpp>
+#include <mlpack/core/tree/tree_traits.hpp>
+
+namespace mlpack {
+namespace fastmks {
+
+/**
+ * Per-node tree statistic for FastMKS: a pruning bound (initially -DBL_MAX)
+ * and the self-kernel sqrt(K(p, p)) of the node's first point or centroid.
+ */
+class FastMKSStat
+{
+ public:
+  /**
+   * Default initialization: the bound is -DBL_MAX and the self-kernel is 0.
+   */
+  FastMKSStat() : bound(-DBL_MAX), selfKernel(0.0) { }
+
+  /**
+   * Build the statistic for a tree node: the bound starts at -DBL_MAX and the
+   * self-kernel is sqrt(K(p, p)), with p the node's first point or centroid.
+   * The TreeType's metric must provide Kernel() (for example, IPMetric).
+   *
+   * @param node Node that this statistic is built for.
+   */
+  template<typename TreeType>
+  FastMKSStat(const TreeType& node) :
+      bound(-DBL_MAX)
+  {
+    // When the first point is the centroid, no centroid has to be computed.
+    if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
+    {
+      // If the tree has self-children and the first child holds this node's
+      // first point, reuse that child's self-kernel rather than re-evaluating.
+      // Statistics are built bottom-up, so the child's should already exist.
+      if ((tree::TreeTraits<TreeType>::HasSelfChildren) &&
+          (node.NumChildren() > 0) &&
+          (node.Point(0) == node.Child(0).Point(0)))
+      {
+        selfKernel = node.Child(0).Stat().SelfKernel();
+      }
+      else
+      {
+        selfKernel = sqrt(node.Metric().Kernel().Evaluate(
+            node.Dataset().unsafe_col(node.Point(0)),
+            node.Dataset().unsafe_col(node.Point(0))));
+      }
+    }
+    else
+    {
+      // The first point is not the centroid, so compute the centroid itself.
+      arma::vec centroid;
+      node.Centroid(centroid);
+
+      selfKernel = sqrt(node.Metric().Kernel().Evaluate(centroid, centroid));
+    }
+  }
+
+  //! Get the self-kernel (by value).
+  double SelfKernel() const { return selfKernel; }
+  //! Modify the self-kernel (returns a mutable reference).
+  double& SelfKernel() { return selfKernel; }
+
+  //! Get the bound (by value).
+  double Bound() const { return bound; }
+  //! Modify the bound (returns a mutable reference).
+  double& Bound() { return bound; }
+
+ private:
+  //! The bound for pruning; both constructors initialize it to -DBL_MAX.
+  double bound;
+
+  //! The self-kernel: sqrt(K(p, p)), p being the first point or the centroid.
+  double selfKernel;
+};
+
+}; // namespace fastmks
+}; // namespace mlpack
+
+#endif




More information about the mlpack-svn mailing list