[mlpack-git] master: Refactor FastMKS to take individual queries in Search(). (3ad3877)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Apr 23 10:40:29 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/57d0567dddff01feea73b348f38cc040dc3cf8e3...3ad38770911c7b9840901f0934bd1a81c2249046

>---------------------------------------------------------------

commit 3ad38770911c7b9840901f0934bd1a81c2249046
Author: ryan <ryan at ratml.org>
Date:   Thu Apr 23 10:40:10 2015 -0400

    Refactor FastMKS to take individual queries in Search().


>---------------------------------------------------------------

3ad38770911c7b9840901f0934bd1a81c2249046
 src/mlpack/methods/fastmks/fastmks.hpp      | 150 ++++-----
 src/mlpack/methods/fastmks/fastmks_impl.hpp | 490 ++++++++--------------------
 src/mlpack/methods/fastmks/fastmks_main.cpp | 111 ++++---
 3 files changed, 270 insertions(+), 481 deletions(-)

diff --git a/src/mlpack/methods/fastmks/fastmks.hpp b/src/mlpack/methods/fastmks/fastmks.hpp
index 90ca995..72ea34b 100644
--- a/src/mlpack/methods/fastmks/fastmks.hpp
+++ b/src/mlpack/methods/fastmks/fastmks.hpp
@@ -55,106 +55,102 @@ class FastMKS
 {
  public:
   /**
-   * Create the FastMKS object using the reference set as the query set.
-   * Optionally, specify whether or not single-tree search or naive
-   * (brute-force) search should be used.
+   * Create the FastMKS object with the given reference set (this is the set
+   * that is searched).  Optionally, specify whether or not single-tree search
+   * or naive (brute-force) search should be used.
    *
-   * @param referenceSet Set of data to run FastMKS on.
+   * @param referenceSet Set of reference data.
    * @param single Whether or not to run single-tree search.
    * @param naive Whether or not to run brute-force (naive) search.
    */
   FastMKS(const typename TreeType::Mat& referenceSet,
-          const bool single = false,
+          const bool singleMode = false,
           const bool naive = false);
 
   /**
-   * Create the FastMKS object using separate reference and query sets.
-   * Optionally, specify whether or not single-tree search or naive
-   * (brute-force) search should be used.
+   * Create the FastMKS object using the reference set (this is the set that is
+   * searched) with an initialized kernel.  This is useful for when the kernel
+   * stores state.  Optionally, specify whether or not single-tree search or
+   * naive (brute-force) search should be used.
    *
    * @param referenceSet Reference set of data for FastMKS.
-   * @param querySet Set of query points for FastMKS.
-   * @param single Whether or not to run single-tree search.
-   * @param naive Whether or not to run brute-force (naive) search.
-   */
-  FastMKS(const typename TreeType::Mat& referenceSet,
-          const typename TreeType::Mat& querySet,
-          const bool single = false,
-          const bool naive = false);
-
-  /**
-   * Create the FastMKS object using the reference set as the query set, and
-   * with an initialized kernel.  This is useful for when the kernel stores
-   * state.  Optionally, specify whether or not single-tree search or naive
-   * (brute-force) search should be used.
-   *
-   * @param referenceSet Reference set of data for FastMKS.
-   * @param kernel Initialized kernel.
-   * @param single Whether or not to run single-tree search.
-   * @param naive Whether or not to run brute-force (naive) search.
-   */
-  FastMKS(const typename TreeType::Mat& referenceSet,
-          KernelType& kernel,
-          const bool single = false,
-          const bool naive = false);
-
-  /**
-   * Create the FastMKS object using separate reference and query sets, and with
-   * an initialized kernel.  This is useful for when the kernel stores state.
-   * Optionally, specify whether or not single-tree search or naive
-   * (brute-force) search should be used.
-   *
-   * @param referenceSet Reference set of data for FastMKS.
-   * @param querySet Set of query points for FastMKS.
    * @param kernel Initialized kernel.
    * @param single Whether or not to run single-tree search.
    * @param naive Whether or not to run brute-force (naive) search.
    */
   FastMKS(const typename TreeType::Mat& referenceSet,
-          const typename TreeType::Mat& querySet,
           KernelType& kernel,
-          const bool single = false,
+          const bool singleMode = false,
           const bool naive = false);
 
   /**
    * Create the FastMKS object with an already-initialized tree built on the
    * reference points.  Be sure that the tree is built with the metric type
-   * IPMetric<KernelType>.  For this constructor, the reference set and the
-   * query set are the same points.  Optionally, whether or not to run
-   * single-tree search or brute-force (naive) search can be specified.
+   * IPMetric<KernelType>.  Optionally, whether or not to run single-tree search
+   * can be specified.  Brute-force search is not available with this
+   * constructor since a tree is given (use one of the other constructors).
    *
-   * @param referenceSet Reference set of data for FastMKS.
    * @param referenceTree Tree built on reference data.
    * @param single Whether or not to run single-tree search.
    * @param naive Whether or not to run brute-force (naive) search.
    */
-  FastMKS(const typename TreeType::Mat& referenceSet,
-          TreeType* referenceTree,
-          const bool single = false,
-          const bool naive = false);
+  FastMKS(TreeType* referenceTree,
+          const bool singleMode = false);
+
+  //! Destructor for the FastMKS object.
+  ~FastMKS();
 
   /**
-   * Create the FastMKS object with already-initialized trees built on the
-   * reference and query points.  Be sure that the trees are built with the
-   * metric type IPMetric<KernelType>.  Optionally, whether or not to run
-   * single-tree search or naive (brute-force) search can be specified.
+   * Search for the points in the reference set with maximum kernel evaluation
+   * to each point in the given query set.  The resulting kernel evaluations are
+   * stored in the kernels matrix, and the corresponding point indices are
+   * stored in the indices matrix.  The results for each point in the query set
+   * are stored in the corresponding column of the kernels and indices
+   * matrices; for instance, the index of the point with maximum kernel
+   * evaluation to point 4 in the query set will be stored in row 0 and column 4
+   * of the indices matrix.
    *
-   * @param referenceSet Reference set of data for FastMKS.
-   * @param referenceTree Tree built on reference data.
-   * @param querySet Set of query points for FastMKS.
-   * @param queryTree Tree built on query data.
-   * @param single Whether or not to use single-tree search.
-   * @param naive Whether or not to use naive (brute-force) search.
+   * If querySet only contains a few points, the extra overhead of building a
+   * tree to perform dual-tree search may not be warranted, and it may be faster
+   * to use single-tree search, either by setting singleMode to true in the
+   * constructor or with SingleMode().
+   *
+   * @param querySet Set of query points (can be a single point).
+   * @param k The number of maximum kernels to find.
+   * @param indices Matrix to store resulting indices of max-kernel search in.
+   * @param kernels Matrix to store resulting max-kernel values in.
    */
-  FastMKS(const typename TreeType::Mat& referenceSet,
-          TreeType* referenceTree,
-          const typename TreeType::Mat& querySet,
-          TreeType* queryTree,
-          const bool single = false,
-          const bool naive = false);
+  void Search(const typename TreeType::Mat& querySet,
+              const size_t k,
+              arma::Mat<size_t>& indices,
+              arma::mat& kernels);
 
-  //! Destructor for the FastMKS object.
-  ~FastMKS();
+  /**
+   * Search for the points in the reference set with maximum kernel evaluation
+   * to each point in the query set corresponding to the given pre-built query
+   * tree.  The resulting kernel evaluations are stored in the kernels matrix,
+   * and the corresponding point indices are stored in the indices matrix.  The
+   * results for each point in the query set are stored in the corresponding
+   * column of the kernels and indices matrices; for instance, the index of the
+   * point with maximum kernel evaluation to point 4 in the query set will be
+   * stored in row 0 and column 4 of the indices matrix.
+   *
+   * This will throw a std::invalid_argument exception if called while the
+   * FastMKS object has either singleMode or naive set to true.
+   *
+   * Be aware that if your tree modifies the original input matrix, the results
+   * here are with respect to the modified input matrix (that is,
+   * queryTree->Dataset()).
+   *
+   * @param queryTree Tree built on query points.
+   * @param k The number of maximum kernels to find.
+   * @param indices Matrix to store resulting indices of max-kernel search in.
+   * @param kernels Matrix to store resulting max-kernel values in.
+   */
+  void Search(TreeType* querySet,
+              const size_t k,
+              arma::Mat<size_t>& indices,
+              arma::mat& kernels);
 
   /**
    * Search for the maximum inner products of the query set (or if no query set
@@ -179,6 +175,11 @@ class FastMKS
   //! Modify the inner-product metric induced by the given kernel.
   metric::IPMetric<KernelType>& Metric() { return metric; }
 
+  //! Get whether or not single-tree search is used.
+  bool SingleMode() const { return singleMode; }
+  //! Modify whether or not single-tree search is used.
+  bool& SingleMode() { return singleMode; }
+
   /**
    * Returns a string representation of this object.
    */
@@ -187,20 +188,13 @@ class FastMKS
  private:
   //! The reference dataset.
   const typename TreeType::Mat& referenceSet;
-  //! The query dataset.
-  const typename TreeType::Mat& querySet;
-
   //! The tree built on the reference dataset.
   TreeType* referenceTree;
-  //! The tree built on the query dataset.  This is NULL if there is no query
-  //! set.
-  TreeType* queryTree;
-
-  //! If true, this object created the trees and is responsible for them.
+  //! If true, this object created the tree and is responsible for it.
   bool treeOwner;
 
   //! If true, single-tree search is used.
-  bool single;
+  bool singleMode;
   //! If true, naive (brute-force) search is used.
   bool naive;
 
diff --git a/src/mlpack/methods/fastmks/fastmks_impl.hpp b/src/mlpack/methods/fastmks/fastmks_impl.hpp
index 19f466e..1835d8d 100644
--- a/src/mlpack/methods/fastmks/fastmks_impl.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_impl.hpp
@@ -18,17 +18,15 @@
 namespace mlpack {
 namespace fastmks {
 
-// Single dataset, no instantiated kernel.
+// No instantiated kernel.
 template<typename KernelType, typename TreeType>
 FastMKS<KernelType, TreeType>::FastMKS(const typename TreeType::Mat& referenceSet,
-                                       const bool single,
+                                       const bool singleMode,
                                        const bool naive) :
     referenceSet(referenceSet),
-    querySet(referenceSet),
     referenceTree(NULL),
-    queryTree(NULL),
     treeOwner(true),
-    single(single),
+    singleMode(singleMode),
     naive(naive)
 {
   Timer::Start("tree_building");
@@ -36,50 +34,19 @@ FastMKS<KernelType, TreeType>::FastMKS(const typename TreeType::Mat& referenceSe
   if (!naive)
     referenceTree = new TreeType(referenceSet);
 
-  if (!naive && !single)
-    queryTree = new TreeType(referenceSet);
-
-  Timer::Stop("tree_building");
-}
-
-// Two datasets, no instantiated kernel.
-template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::FastMKS(const typename TreeType::Mat& referenceSet,
-                                       const typename TreeType::Mat& querySet,
-                                       const bool single,
-                                       const bool naive) :
-    referenceSet(referenceSet),
-    querySet(querySet),
-    referenceTree(NULL),
-    queryTree(NULL),
-    treeOwner(true),
-    single(single),
-    naive(naive)
-{
-  Timer::Start("tree_building");
-
-  // If necessary, the trees should be built.
-  if (!naive)
-    referenceTree = new TreeType(referenceSet);
-
-  if (!naive && !single)
-    queryTree = new TreeType(querySet);
-
   Timer::Stop("tree_building");
 }
 
-// One dataset, instantiated kernel.
+// Instantiated kernel.
 template<typename KernelType, typename TreeType>
 FastMKS<KernelType, TreeType>::FastMKS(const typename TreeType::Mat& referenceSet,
                                        KernelType& kernel,
-                                       const bool single,
+                                       const bool singleMode,
                                        const bool naive) :
     referenceSet(referenceSet),
-    querySet(referenceSet),
     referenceTree(NULL),
-    queryTree(NULL),
     treeOwner(true),
-    single(single),
+    singleMode(singleMode),
     naive(naive),
     metric(kernel)
 {
@@ -89,135 +56,173 @@ FastMKS<KernelType, TreeType>::FastMKS(const typename TreeType::Mat& referenceSe
   if (!naive)
     referenceTree = new TreeType(referenceSet, metric);
 
-  if (!naive && !single)
-    queryTree = new TreeType(referenceSet, metric);
-
-  Timer::Stop("tree_building");
-}
-
-// Two datasets, instantiated kernel.
-template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::FastMKS(const typename TreeType::Mat& referenceSet,
-                                       const typename TreeType::Mat& querySet,
-                                       KernelType& kernel,
-                                       const bool single,
-                                       const bool naive) :
-    referenceSet(referenceSet),
-    querySet(querySet),
-    referenceTree(NULL),
-    queryTree(NULL),
-    treeOwner(true),
-    single(single),
-    naive(naive),
-    metric(kernel)
-{
-  Timer::Start("tree_building");
-
-  // If necessary, the trees should be built.
-  if (!naive)
-    referenceTree = new TreeType(referenceSet, metric);
-
-  if (!naive && !single)
-    queryTree = new TreeType(querySet, metric);
-
   Timer::Stop("tree_building");
 }
 
 // One dataset, pre-built tree.
 template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::FastMKS(
-    const typename TreeType::Mat& referenceSet,
-    TreeType* referenceTree,
-    const bool single,
-    const bool naive) :
-    referenceSet(referenceSet),
-    querySet(referenceSet),
+FastMKS<KernelType, TreeType>::FastMKS(TreeType* referenceTree,
+                                       const bool singleMode) :
+    referenceSet(referenceTree->Dataset()),
     referenceTree(referenceTree),
-    queryTree(NULL),
     treeOwner(false),
-    single(single),
-    naive(naive),
+    singleMode(singleMode),
+    naive(false),
     metric(referenceTree->Metric())
 {
-  // The query tree cannot be the same as the reference tree.
-  if (referenceTree)
-    queryTree = new TreeType(*referenceTree);
+  // Nothing to do.
 }
 
-// Two datasets, pre-built trees.
 template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::FastMKS(
-    const typename TreeType::Mat& referenceSet,
-    TreeType* referenceTree,
-    const typename TreeType::Mat& querySet,
-    TreeType* queryTree,
-    const bool single,
-    const bool naive) :
-    referenceSet(referenceSet),
-    querySet(querySet),
-    referenceTree(referenceTree),
-    queryTree(queryTree),
-    treeOwner(false),
-    single(single),
-    naive(naive),
-    metric(referenceTree->Metric())
+FastMKS<KernelType, TreeType>::~FastMKS()
 {
-  // Nothing to do.
+  // If we created the trees, we must delete them.
+  if (treeOwner && referenceTree)
+    delete referenceTree;
 }
 
 template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::~FastMKS()
+void FastMKS<KernelType, TreeType>::Search(
+    const typename TreeType::Mat& querySet,
+    const size_t k,
+    arma::Mat<size_t>& indices,
+    arma::mat& kernels)
 {
-  // If we created the trees, we must delete them.
-  if (treeOwner)
+  Timer::Start("computing_products");
+
+  // No remapping will be necessary because we are using the cover tree.
+  indices.set_size(k, querySet.n_cols);
+  kernels.set_size(k, querySet.n_cols);
+
+  // Naive implementation.
+  if (naive)
+  {
+    // Fill kernels.
+    kernels.fill(-DBL_MAX);
+
+    // Simple double loop.  Stupid, slow, but a good benchmark.
+    for (size_t q = 0; q < querySet.n_cols; ++q)
+    {
+      for (size_t r = 0; r < referenceSet.n_cols; ++r)
+      {
+        const double eval = metric.Kernel().Evaluate(querySet.col(q),
+                                                     referenceSet.col(r));
+
+        size_t insertPosition;
+        for (insertPosition = 0; insertPosition < indices.n_rows;
+            ++insertPosition)
+          if (eval > kernels(insertPosition, q))
+            break;
+
+        if (insertPosition < indices.n_rows)
+          InsertNeighbor(indices, kernels, q, insertPosition, r, eval);
+      }
+    }
+
+    Timer::Stop("computing_products");
+
+    return;
+  }
+
+  // Single-tree implementation.
+  if (singleMode)
   {
-    if (queryTree)
-      delete queryTree;
-    if (referenceTree)
-      delete referenceTree;
+    // Fill kernels.
+    kernels.fill(-DBL_MAX);
+
+    // Create rules object (this will store the results).  This constructor
+    // precalculates each self-kernel value.
+    typedef FastMKSRules<KernelType, TreeType> RuleType;
+    RuleType rules(referenceSet, querySet, indices, kernels, metric.Kernel());
+
+    typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
+
+    for (size_t i = 0; i < querySet.n_cols; ++i)
+      traverser.Traverse(i, *referenceTree);
+
+    Log::Info << rules.BaseCases() << " base cases." << std::endl;
+    Log::Info << rules.Scores() << " scores." << std::endl;
+
+    Timer::Stop("computing_products");
+    return;
   }
-  else if (&querySet == &referenceSet)
+
+  // Dual-tree implementation.  First, we need to build the query tree.  We are
+  // assuming it doesn't map anything...
+  Timer::Stop("computing_products");
+  Timer::Start("tree_building");
+  TreeType queryTree(querySet);
+  Timer::Stop("tree_building");
+
+  Search(&queryTree, k, indices, kernels);
+}
+
+template<typename KernelType, typename TreeType>
+void FastMKS<KernelType, TreeType>::Search(TreeType* queryTree,
+                                           const size_t k,
+                                           arma::Mat<size_t>& indices,
+                                           arma::mat& kernels)
+{
+  // If either naive mode or single mode is specified, this must fail.
+  if (naive || singleMode)
   {
-    // The user passed in a reference tree which we needed to copy.
-    if (queryTree)
-      delete queryTree;
+    throw std::invalid_argument("can't call Search() with a query tree when "
+        "single mode or naive search is enabled");
   }
+
+  // No remapping will be necessary because we are using the cover tree.
+  indices.set_size(k, queryTree->Dataset().n_cols);
+  kernels.set_size(k, queryTree->Dataset().n_cols);
+  kernels.fill(-DBL_MAX);
+
+  Timer::Start("computing_products");
+  typedef FastMKSRules<KernelType, TreeType> RuleType;
+  RuleType rules(referenceSet, queryTree->Dataset(), indices, kernels,
+      metric.Kernel());
+
+  typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+
+  traverser.Traverse(*queryTree, *referenceTree);
+
+  Log::Info << rules.BaseCases() << " base cases." << std::endl;
+  Log::Info << rules.Scores() << " scores." << std::endl;
+
+  Timer::Stop("computing_products");
 }
 
 template<typename KernelType, typename TreeType>
 void FastMKS<KernelType, TreeType>::Search(const size_t k,
                                            arma::Mat<size_t>& indices,
-                                           arma::mat& products)
+                                           arma::mat& kernels)
 {
   // No remapping will be necessary because we are using the cover tree.
-  indices.set_size(k, querySet.n_cols);
-  products.set_size(k, querySet.n_cols);
-  products.fill(-DBL_MAX);
-
   Timer::Start("computing_products");
+  indices.set_size(k, referenceSet.n_cols);
+  kernels.set_size(k, referenceSet.n_cols);
+  kernels.fill(-DBL_MAX);
 
   // Naive implementation.
   if (naive)
   {
     // Simple double loop.  Stupid, slow, but a good benchmark.
-    for (size_t q = 0; q < querySet.n_cols; ++q)
+    for (size_t q = 0; q < referenceSet.n_cols; ++q)
     {
       for (size_t r = 0; r < referenceSet.n_cols; ++r)
       {
-        if ((&querySet == &referenceSet) && (q == r))
-          continue;
+        if (q == r)
+          continue; // Don't return the point as its own candidate.
 
-        const double eval = metric.Kernel().Evaluate(querySet.col(q),
+        const double eval = metric.Kernel().Evaluate(referenceSet.col(q),
                                                      referenceSet.col(r));
 
         size_t insertPosition;
         for (insertPosition = 0; insertPosition < indices.n_rows;
             ++insertPosition)
-          if (eval > products(insertPosition, q))
+          if (eval > kernels(insertPosition, q))
             break;
 
         if (insertPosition < indices.n_rows)
-          InsertNeighbor(indices, products, q, insertPosition, r, eval);
+          InsertNeighbor(indices, kernels, q, insertPosition, r, eval);
       }
     }
 
@@ -227,16 +232,17 @@ void FastMKS<KernelType, TreeType>::Search(const size_t k,
   }
 
   // Single-tree implementation.
-  if (single)
+  if (singleMode)
   {
     // Create rules object (this will store the results).  This constructor
     // precalculates each self-kernel value.
     typedef FastMKSRules<KernelType, TreeType> RuleType;
-    RuleType rules(referenceSet, querySet, indices, products, metric.Kernel());
+    RuleType rules(referenceSet, referenceSet, indices, kernels,
+        metric.Kernel());
 
     typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
 
-    for (size_t i = 0; i < querySet.n_cols; ++i)
+    for (size_t i = 0; i < referenceSet.n_cols; ++i)
       traverser.Traverse(i, *referenceTree);
 
     // Save the number of pruned nodes.
@@ -252,21 +258,9 @@ void FastMKS<KernelType, TreeType>::Search(const size_t k,
   }
 
   // Dual-tree implementation.
-  typedef FastMKSRules<KernelType, TreeType> RuleType;
-  RuleType rules(referenceSet, querySet, indices, products, metric.Kernel());
-
-  typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
-
-  traverser.Traverse(*queryTree, *referenceTree);
-
-  const size_t numPrunes = traverser.NumPrunes();
-
-  Log::Info << "Pruned " << numPrunes << " nodes." << std::endl;
-  Log::Info << rules.BaseCases() << " base cases." << std::endl;
-  Log::Info << rules.Scores() << " scores." << std::endl;
-
   Timer::Stop("computing_products");
-  return;
+
+  Search(referenceTree, k, indices, kernels);
 }
 
 /**
@@ -309,229 +303,13 @@ std::string FastMKS<KernelType, TreeType>::ToString() const
   std::ostringstream convert;
   convert << "FastMKS [" << this << "]" << std::endl;
   convert << "  Naive: " << naive << std::endl;
-  convert << "  Single: " << single << std::endl;
+  convert << "  Single: " << singleMode << std::endl;
   convert << "  Metric: " << std::endl;
   convert << mlpack::util::Indent(metric.ToString(),2);
   convert << std::endl;
   return convert.str();
 }
 
-// Specialized implementation for tighter bounds for Gaussian.
-/*
-template<>
-void FastMKS<kernel::GaussianKernel>::Search(const size_t k,
-                                             arma::Mat<size_t>& indices,
-                                             arma::mat& products)
-{
-  Log::Warn << "Alternate implementation!" << std::endl;
-
-  // Terrible copypasta.  Bad bad bad.
-  // No remapping will be necessary.
-  indices.set_size(k, querySet.n_cols);
-  products.set_size(k, querySet.n_cols);
-  products.fill(-1.0);
-
-  Timer::Start("computing_products");
-
-  size_t kernelEvaluations = 0;
-
-  // Naive implementation.
-  if (naive)
-  {
-    // Simple double loop.  Stupid, slow, but a good benchmark.
-    for (size_t q = 0; q < querySet.n_cols; ++q)
-    {
-      for (size_t r = 0; r < referenceSet.n_cols; ++r)
-      {
-        const double eval = metric.Kernel().Evaluate(querySet.unsafe_col(q),
-            referenceSet.unsafe_col(r));
-        ++kernelEvaluations;
-
-        size_t insertPosition;
-        for (insertPosition = 0; insertPosition < indices.n_rows;
-            ++insertPosition)
-          if (eval > products(insertPosition, q))
-            break;
-
-        if (insertPosition < indices.n_rows)
-          InsertNeighbor(indices, products, q, insertPosition, r, eval);
-      }
-    }
-
-    Timer::Stop("computing_products");
-
-    Log::Info << "Kernel evaluations: " << kernelEvaluations << "." << std::endl;
-    return;
-  }
-
-  // Single-tree implementation.
-  if (single)
-  {
-    // Calculate number of pruned nodes.
-    size_t numPrunes = 0;
-
-    // Precalculate query products ( || q || for all q).
-    arma::vec queryProducts(querySet.n_cols);
-    for (size_t queryIndex = 0; queryIndex < querySet.n_cols; ++queryIndex)
-      queryProducts[queryIndex] = sqrt(metric.Kernel().Evaluate(
-          querySet.unsafe_col(queryIndex), querySet.unsafe_col(queryIndex)));
-    kernelEvaluations += querySet.n_cols;
-
-    // Screw the CoverTreeTraverser, we'll implement it by hand.
-    for (size_t queryIndex = 0; queryIndex < querySet.n_cols; ++queryIndex)
-    {
-      // Use an array of priority queues?
-      std::priority_queue<
-          SearchFrame<tree::CoverTree<IPMetric<kernel::GaussianKernel> > >,
-          std::vector<SearchFrame<tree::CoverTree<IPMetric<
-              kernel::GaussianKernel> > > >,
-          SearchFrameCompare<tree::CoverTree<IPMetric<
-              kernel::GaussianKernel> > > >
-          frameQueue;
-
-      // Add initial frame.
-      SearchFrame<tree::CoverTree<IPMetric<kernel::GaussianKernel> > >
-          nextFrame;
-      nextFrame.node = referenceTree;
-      nextFrame.eval = metric.Kernel().Evaluate(querySet.unsafe_col(queryIndex),
-          referenceSet.unsafe_col(referenceTree->Point()));
-      Log::Assert(nextFrame.eval <= 1);
-      ++kernelEvaluations;
-
-      // The initial evaluation will be the best so far.
-      indices(0, queryIndex) = referenceTree->Point();
-      products(0, queryIndex) = nextFrame.eval;
-
-      frameQueue.push(nextFrame);
-
-      tree::CoverTree<IPMetric<kernel::GaussianKernel> >* referenceNode;
-      double eval;
-      double maxProduct;
-
-      while (!frameQueue.empty())
-      {
-        // Get the information for this node.
-        const SearchFrame<tree::CoverTree<IPMetric<kernel::GaussianKernel> > >&
-            frame = frameQueue.top();
-
-        referenceNode = frame.node;
-        eval = frame.eval;
-
-        // Loop through the children, seeing if we can prune them; if not, add
-        // them to the queue.  The self-child is different -- it has the same
-        // parent (and therefore the same kernel evaluation).
-        if (referenceNode->NumChildren() > 0)
-        {
-          SearchFrame<tree::CoverTree<IPMetric<kernel::GaussianKernel> > >
-              childFrame;
-
-          // We must handle the self-child differently, to avoid adding it to
-          // the results twice.
-          childFrame.node = &(referenceNode->Child(0));
-          childFrame.eval = eval;
-
-          // Alternate pruning rule.
-          const double mdd = childFrame.node->FurthestDescendantDistance();
-          if (eval >= (1 - std::pow(mdd, 2.0) / 2.0))
-            maxProduct = 1;
-          else
-            maxProduct = eval * (1 - std::pow(mdd, 2.0) / 2.0) + sqrt(1 -
-                std::pow(eval, 2.0)) * mdd * sqrt(1 - std::pow(mdd, 2.0) / 4.0);
-
-          // Add self-child if we can't prune it.
-          if (maxProduct > products(products.n_rows - 1, queryIndex))
-          {
-            // But only if it has children of its own.
-            if (childFrame.node->NumChildren() > 0)
-              frameQueue.push(childFrame);
-          }
-          else
-            ++numPrunes;
-
-          for (size_t i = 1; i < referenceNode->NumChildren(); ++i)
-          {
-            // Before we evaluate the child, let's see if it can possibly have
-            // a better evaluation.
-            const double mpdd = std::min(
-                referenceNode->Child(i).ParentDistance() +
-                referenceNode->Child(i).FurthestDescendantDistance(), 2.0);
-            double maxChildEval = 1;
-            if (eval < (1 - std::pow(mpdd, 2.0) / 2.0))
-              maxChildEval = eval * (1 - std::pow(mpdd, 2.0) / 2.0) + sqrt(1 -
-                  std::pow(eval, 2.0)) * mpdd * sqrt(1 - std::pow(mpdd, 2.0)
-                  / 4.0);
-
-            if (maxChildEval > products(products.n_rows - 1, queryIndex))
-            {
-              // Evaluate child.
-              childFrame.node = &(referenceNode->Child(i));
-              childFrame.eval = metric.Kernel().Evaluate(
-                  querySet.unsafe_col(queryIndex),
-                  referenceSet.unsafe_col(referenceNode->Child(i).Point()));
-              ++kernelEvaluations;
-
-              // Can we prune it?  If we can, we can avoid putting it in the
-              // queue (saves time).
-              const double cmdd = childFrame.node->FurthestDescendantDistance();
-              if (childFrame.eval > (1 - std::pow(cmdd, 2.0) / 2.0))
-                maxProduct = 1;
-              else
-                maxProduct = childFrame.eval * (1 - std::pow(cmdd, 2.0) / 2.0)
-                    + sqrt(1 - std::pow(eval, 2.0)) * cmdd * sqrt(1 -
-                    std::pow(cmdd, 2.0) / 4.0);
-
-              if (maxProduct > products(products.n_rows - 1, queryIndex))
-              {
-                // Good enough to recurse into.  While we're at it, check the
-                // actual evaluation and see if it's an improvement.
-                if (childFrame.eval > products(products.n_rows - 1, queryIndex))
-                {
-                  // This is a better result.  Find out where to insert.
-                  size_t insertPosition = 0;
-                  for ( ; insertPosition < products.n_rows - 1;
-                      ++insertPosition)
-                    if (childFrame.eval > products(insertPosition, queryIndex))
-                      break;
-
-                  // Insert into the correct position; we are guaranteed that
-                  // insertPosition is valid.
-                  InsertNeighbor(indices, products, queryIndex, insertPosition,
-                      childFrame.node->Point(), childFrame.eval);
-                }
-
-                // Now add this to the queue (if it has any children which may
-                // need to be recursed into).
-                if (childFrame.node->NumChildren() > 0)
-                  frameQueue.push(childFrame);
-              }
-              else
-                ++numPrunes;
-            }
-            else
-              ++numPrunes;
-          }
-        }
-
-        frameQueue.pop();
-      }
-    }
-
-    Log::Info << "Pruned " << numPrunes << " nodes." << std::endl;
-    Log::Info << "Kernel evaluations: " << kernelEvaluations << "."
-        << std::endl;
-    Log::Info << "Distance evaluations: " << distanceEvaluations << "."
-        << std::endl;
-
-    Timer::Stop("computing_products");
-    return;
-  }
-
-  // Double-tree implementation.
-  Log::Fatal << "Dual-tree search not implemented yet... oops..." << std::endl;
-
-}
-*/
-
 }; // namespace fastmks
 }; // namespace mlpack
 
diff --git a/src/mlpack/methods/fastmks/fastmks_main.cpp b/src/mlpack/methods/fastmks/fastmks_main.cpp
index abb7708..1e0dc94 100644
--- a/src/mlpack/methods/fastmks/fastmks_main.cpp
+++ b/src/mlpack/methods/fastmks/fastmks_main.cpp
@@ -28,12 +28,12 @@ PROGRAM_INFO("FastMKS (Fast Max-Kernel Search)",
     "'kernels.csv' and the indices are stored in 'indices.csv'."
     "\n\n"
     "$ fastmks --k 5 --reference_file reference.csv --query_file query.csv\n"
-    "  --indices_file indices.csv --products_file kernels.csv --kernel linear"
+    "  --indices_file indices.csv --kernels_file kernels.csv --kernel linear"
     "\n\n"
     "The output files are organized such that row i and column j in the indices"
     " output file corresponds to the index of the point in the reference set "
     "that has i'th largest kernel evaluation with the point in the query set "
-    "with index j.  Row i and column j in the products output file corresponds "
+    "with index j.  Row i and column j in the kernels output file corresponds "
     "to the kernel evaluation between those two points."
     "\n\n"
     "This executable performs FastMKS using a cover tree.  The base used to "
@@ -44,10 +44,10 @@ PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
     "r");
 PARAM_STRING("query_file", "File containing the query dataset.", "q", "");
 
-PARAM_INT_REQ("k", "Number of maximum inner products to find.", "k");
+PARAM_INT_REQ("k", "Number of maximum kernels to find.", "k");
 
-PARAM_STRING("products_file", "File to save inner products into.", "p", "");
-PARAM_STRING("indices_file", "File to save indices of inner products into.",
+PARAM_STRING("kernels_file", "File to save kernels into.", "p", "");
+PARAM_STRING("indices_file", "File to save indices of kernels into.",
     "i", "");
 
 PARAM_STRING("kernel", "Kernel type to use: 'linear', 'polynomial', 'cosine', "
@@ -76,20 +76,29 @@ void RunFastMKS(const arma::mat& referenceData,
                 const double base,
                 const size_t k,
                 arma::Mat<size_t>& indices,
-                arma::mat& products,
+                arma::mat& kernels,
                 KernelType& kernel)
 {
-  // Create the tree with the specified base.
-  typedef CoverTree<IPMetric<KernelType>, FirstPointIsRoot, FastMKSStat>
-      TreeType;
-  IPMetric<KernelType> metric(kernel);
-  TreeType tree(referenceData, metric, base);
+  if (naive)
+  {
+    // No need for trees.
+    FastMKS<KernelType> fastmks(referenceData, kernel, false, naive);
+    fastmks.Search(k, indices, kernels);
+  }
+  else
+  {
+    // Create the tree with the specified base.
+    typedef CoverTree<IPMetric<KernelType>, FirstPointIsRoot, FastMKSStat>
+        TreeType;
+    IPMetric<KernelType> metric(kernel);
+    TreeType tree(referenceData, metric, base);
 
-  // Create FastMKS object.
-  FastMKS<KernelType> fastmks(referenceData, &tree, (single && !naive), naive);
+    // Create FastMKS object.
+    FastMKS<KernelType> fastmks(&tree, single);
 
-  // Now search with it.
-  fastmks.Search(k, indices, products);
+    // Now search with it.
+    fastmks.Search(k, indices, kernels);
+  }
 }
 
 //! Run FastMKS for a given query and reference set using the given kernel type.
@@ -101,22 +110,30 @@ void RunFastMKS(const arma::mat& referenceData,
                 const double base,
                 const size_t k,
                 arma::Mat<size_t>& indices,
-                arma::mat& products,
+                arma::mat& kernels,
                 KernelType& kernel)
 {
-  // Create the tree with the specified base.
-  typedef CoverTree<IPMetric<KernelType>, FirstPointIsRoot, FastMKSStat>
-      TreeType;
-  IPMetric<KernelType> metric(kernel);
-  TreeType referenceTree(referenceData, metric, base);
-  TreeType queryTree(queryData, metric, base);
-
-  // Create FastMKS object.
-  FastMKS<KernelType> fastmks(referenceData, &referenceTree, queryData,
-      &queryTree, (single && !naive), naive);
-
-  // Now search with it.
-  fastmks.Search(k, indices, products);
+  if (naive)
+  {
+    // No need for trees.
+    FastMKS<KernelType> fastmks(referenceData, kernel, false, naive);
+    fastmks.Search(queryData, k, indices, kernels);
+  }
+  else
+  {
+    // Create the tree with the specified base.
+    typedef CoverTree<IPMetric<KernelType>, FirstPointIsRoot, FastMKSStat>
+        TreeType;
+    IPMetric<KernelType> metric(kernel);
+    TreeType referenceTree(referenceData, metric, base);
+    TreeType queryTree(queryData, metric, base);
+
+    // Create FastMKS object.
+    FastMKS<KernelType> fastmks(&referenceTree, single);
+
+    // Now search with it.
+    fastmks.Search(&queryTree, k, indices, kernels);
+  }
 }
 
 int main(int argc, char** argv)
@@ -194,7 +211,7 @@ int main(int argc, char** argv)
 
   // Matrices for output storage.
   arma::Mat<size_t> indices;
-  arma::mat products;
+  arma::mat kernels;
 
   // Construct FastMKS object.
   if (queryData.n_elem == 0)
@@ -203,44 +220,44 @@ int main(int argc, char** argv)
     {
       LinearKernel lk;
       RunFastMKS<LinearKernel>(referenceData, single, naive, base, k, indices,
-          products, lk);
+          kernels, lk);
     }
     else if (kernelType == "polynomial")
     {
 
       PolynomialKernel pk(degree, offset);
       RunFastMKS<PolynomialKernel>(referenceData, single, naive, base, k,
-          indices, products, pk);
+          indices, kernels, pk);
     }
     else if (kernelType == "cosine")
     {
       CosineDistance cd;
       RunFastMKS<CosineDistance>(referenceData, single, naive, base, k, indices,
-          products, cd);
+          kernels, cd);
     }
     else if (kernelType == "gaussian")
     {
       GaussianKernel gk(bandwidth);
       RunFastMKS<GaussianKernel>(referenceData, single, naive, base, k, indices,
-          products, gk);
+          kernels, gk);
     }
     else if (kernelType == "epanechnikov")
     {
       EpanechnikovKernel ek(bandwidth);
       RunFastMKS<EpanechnikovKernel>(referenceData, single, naive, base, k,
-          indices, products, ek);
+          indices, kernels, ek);
     }
     else if (kernelType == "triangular")
     {
       TriangularKernel tk(bandwidth);
       RunFastMKS<TriangularKernel>(referenceData, single, naive, base, k,
-          indices, products, tk);
+          indices, kernels, tk);
     }
     else if (kernelType == "hyptan")
     {
       HyperbolicTangentKernel htk(scale, offset);
       RunFastMKS<HyperbolicTangentKernel>(referenceData, single, naive, base, k,
-          indices, products, htk);
+          indices, kernels, htk);
     }
   }
   else
@@ -249,51 +266,51 @@ int main(int argc, char** argv)
     {
       LinearKernel lk;
       RunFastMKS<LinearKernel>(referenceData, queryData, single, naive, base, k,
-          indices, products, lk);
+          indices, kernels, lk);
     }
     else if (kernelType == "polynomial")
     {
       PolynomialKernel pk(degree, offset);
       RunFastMKS<PolynomialKernel>(referenceData, queryData, single, naive,
-          base, k, indices, products, pk);
+          base, k, indices, kernels, pk);
     }
     else if (kernelType == "cosine")
     {
       CosineDistance cd;
       RunFastMKS<CosineDistance>(referenceData, queryData, single, naive, base,
-          k, indices, products, cd);
+          k, indices, kernels, cd);
     }
     else if (kernelType == "gaussian")
     {
       GaussianKernel gk(bandwidth);
       RunFastMKS<GaussianKernel>(referenceData, queryData, single, naive, base,
-          k, indices, products, gk);
+          k, indices, kernels, gk);
     }
     else if (kernelType == "epanechnikov")
     {
       EpanechnikovKernel ek(bandwidth);
       RunFastMKS<EpanechnikovKernel>(referenceData, queryData, single, naive,
-          base, k, indices, products, ek);
+          base, k, indices, kernels, ek);
     }
     else if (kernelType == "triangular")
     {
       TriangularKernel tk(bandwidth);
       RunFastMKS<TriangularKernel>(referenceData, queryData, single, naive,
-          base, k, indices, products, tk);
+          base, k, indices, kernels, tk);
     }
     else if (kernelType == "hyptan")
     {
       HyperbolicTangentKernel htk(scale, offset);
       RunFastMKS<HyperbolicTangentKernel>(referenceData, queryData, single,
-          naive, base, k, indices, products, htk);
+          naive, base, k, indices, kernels, htk);
     }
   }
 
   // Save output, if we were asked to.
-  if (CLI::HasParam("products_file"))
+  if (CLI::HasParam("kernels_file"))
   {
-    const string productsFile = CLI::GetParam<string>("products_file");
-    data::Save(productsFile, products, false);
+    const string kernelsFile = CLI::GetParam<string>("kernels_file");
+    data::Save(kernelsFile, kernels, false);
   }
 
   if (CLI::HasParam("indices_file"))



More information about the mlpack-git mailing list