[mlpack-svn] r14966 - mlpack/trunk/src/mlpack/methods/fastmks

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Apr 26 23:10:48 EDT 2013


Author: rcurtin
Date: 2013-04-26 23:10:47 -0400 (Fri, 26 Apr 2013)
New Revision: 14966

Modified:
   mlpack/trunk/src/mlpack/methods/fastmks/fastmks.hpp
   mlpack/trunk/src/mlpack/methods/fastmks/fastmks_impl.hpp
Log:
Update FastMKS API and do some documentation.  Add some new constructors for
flexibility.


Modified: mlpack/trunk/src/mlpack/methods/fastmks/fastmks.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/fastmks/fastmks.hpp	2013-04-27 03:05:40 UTC (rev 14965)
+++ mlpack/trunk/src/mlpack/methods/fastmks/fastmks.hpp	2013-04-27 03:10:47 UTC (rev 14966)
@@ -2,7 +2,8 @@
  * @file fastmks.hpp
  * @author Ryan Curtin
  *
- * Definition of the FastMKS class, which is the fast max-kernel search.
+ * Definition of the FastMKS class, which implements fast exact max-kernel
+ * search.
  */
 #ifndef __MLPACK_METHODS_FASTMKS_FASTMKS_HPP
 #define __MLPACK_METHODS_FASTMKS_FASTMKS_HPP
@@ -15,44 +16,192 @@
 namespace mlpack {
 namespace fastmks {
 
-template<typename KernelType>
+/**
+ * An implementation of fast exact max-kernel search.  Given a query dataset and
+ * a reference dataset (or optionally just a reference dataset which is also
+ * used as the query dataset), fast exact max-kernel search finds, for each
+ * point in the query dataset, the k points in the reference set with maximum
+ * kernel value K(p_q, p_r), where k is a specified parameter and K() is a
+ * Mercer kernel.
+ *
+ * For more information, see the following paper.
+ *
+ * @code
+ * @inproceedings{curtin2013fast,
+ *   title={Fast Exact Max-Kernel Search},
+ *   author={Curtin, Ryan R. and Ram, Parikshit and Gray, Alexander G.},
+ *   booktitle={Proceedings of the 2013 SIAM International Conference on Data
+ *       Mining (SDM 13)},
+ *   year={2013}
+ * }
+ * @endcode
+ *
+ * This class allows specification of the type of kernel and also of the type of
+ * tree.  FastMKS can be run on kernels that work on arbitrary objects --
+ * however, this only works with cover trees and other trees that are built only
+ * on points in the dataset (and not centroids of regions or anything like
+ * that).
+ *
+ * @tparam KernelType Type of kernel to run FastMKS with.
+ * @tparam TreeType Type of tree to run FastMKS with; it must have metric
+ *     IPMetric<KernelType>.
+ */
+template<
+    typename KernelType,
+    typename TreeType = tree::CoverTree<IPMetric<KernelType>,
+        tree::FirstPointIsRoot, FastMKSStat>
+>
 class FastMKS
 {
  public:
+  /**
+   * Create the FastMKS object using the reference set as the query set.
+   * Optionally, specify whether or not single-tree search or naive
+   * (brute-force) search should be used.
+   *
+   * @param referenceSet Set of data to run FastMKS on.
+   * @param single Whether or not to run single-tree search.
+   * @param naive Whether or not to run brute-force (naive) search.
+   */
   FastMKS(const arma::mat& referenceSet,
+          const bool single = false,
+          const bool naive = false);
+
+  /**
+   * Create the FastMKS object using separate reference and query sets.
+   * Optionally, specify whether or not single-tree search or naive
+   * (brute-force) search should be used.
+   *
+   * @param referenceSet Reference set of data for FastMKS.
+   * @param querySet Set of query points for FastMKS.
+   * @param single Whether or not to run single-tree search.
+   * @param naive Whether or not to run brute-force (naive) search.
+   */
+  FastMKS(const arma::mat& referenceSet,
+          const arma::mat& querySet,
+          const bool single = false,
+          const bool naive = false);
+
+  /**
+   * Create the FastMKS object using the reference set as the query set, and
+   * with an initialized kernel.  This is useful for when the kernel stores
+   * state.  Optionally, specify whether or not single-tree search or naive
+   * (brute-force) search should be used.
+   *
+   * @param referenceSet Reference set of data for FastMKS.
+   * @param kernel Initialized kernel.
+   * @param single Whether or not to run single-tree search.
+   * @param naive Whether or not to run brute-force (naive) search.
+   */
+  FastMKS(const arma::mat& referenceSet,
           KernelType& kernel,
-          bool single = false,
-          bool naive = false,
-          double expansionConstant = 2.0);
+          const bool single = false,
+          const bool naive = false);
 
+  /**
+   * Create the FastMKS object using separate reference and query sets, and with
+   * an initialized kernel.  This is useful for when the kernel stores state.
+   * Optionally, specify whether or not single-tree search or naive
+   * (brute-force) search should be used.
+   *
+   * @param referenceSet Reference set of data for FastMKS.
+   * @param querySet Set of query points for FastMKS.
+   * @param kernel Initialized kernel.
+   * @param single Whether or not to run single-tree search.
+   * @param naive Whether or not to run brute-force (naive) search.
+   */
   FastMKS(const arma::mat& referenceSet,
           const arma::mat& querySet,
           KernelType& kernel,
-          bool single = false,
-          bool naive = false,
-          double expansionConstant = 2.0);
+          const bool single = false,
+          const bool naive = false);
 
+  /**
+   * Create the FastMKS object with an already-initialized tree built on the
+   * reference points.  Be sure that the tree is built with the metric type
+   * IPMetric<KernelType>.  For this constructor, the reference set and the
+   * query set are the same points.  Optionally, whether or not to run
+   * single-tree search or brute-force (naive) search can be specified.
+   *
+   * @param referenceSet Reference set of data for FastMKS.
+   * @param referenceTree Tree built on reference data; its metric (and thus
+   *     the kernel) is taken from this tree.
+   * @param single Whether or not to run single-tree search.
+   * @param naive Whether or not to run brute-force (naive) search.
+   */
+  FastMKS(const arma::mat& referenceSet,
+          TreeType* referenceTree,
+          const bool single = false,
+          const bool naive = false);
+
+  /**
+   * Create the FastMKS object with already-initialized trees built on the
+   * reference and query points.  Be sure that the trees are built with the
+   * metric type IPMetric<KernelType>.  Optionally, whether or not to run
+   * single-tree search or naive (brute-force) search can be specified.
+   *
+   * @param referenceSet Reference set of data for FastMKS.
+   * @param referenceTree Tree built on reference data; its metric (and thus
+   *     the kernel) is taken from this tree.
+   * @param querySet Set of query points for FastMKS.
+   * @param queryTree Tree built on query data.
+   * @param single Whether or not to use single-tree search.
+   * @param naive Whether or not to use naive (brute-force) search.
+   */
+  FastMKS(const arma::mat& referenceSet,
+          TreeType* referenceTree,
+          const arma::mat& querySet,
+          TreeType* queryTree,
+          const bool single = false,
+          const bool naive = false);
+
+  //! Destructor for the FastMKS object.
   ~FastMKS();
 
+  /**
+   * Search for the maximum inner products of the query set (or if no query set
+   * was passed, the reference set is used).  The resulting maximum inner
+   * products are stored in the products matrix and the corresponding point
+   * indices are stored in the indices matrix.  The results for each point in
+   * the query set are stored in the corresponding column of the indices and
+   * products matrices; for instance, the index of the point with maximum inner
+   * product to point 4 in the query set will be stored in row 0 and column 4 of
+   * the indices matrix.
+   *
+   * @param k The number of maximum kernels to find.
+   * @param indices Matrix to store resulting indices of max-kernel search in.
+   * @param products Matrix to store resulting max-kernel values in.
+   */
   void Search(const size_t k,
               arma::Mat<size_t>& indices,
               arma::mat& products);
 
+  //! Get the inner-product metric induced by the given kernel.
+  const IPMetric<KernelType>& Metric() const { return metric; }
+  //! Modify the inner-product metric induced by the given kernel.
+  IPMetric<KernelType>& Metric() { return metric; }
+
  private:
+  //! The reference dataset.
   const arma::mat& referenceSet;
-
+  //! The query dataset.
   const arma::mat& querySet;
 
-  tree::CoverTree<IPMetric<KernelType>, tree::FirstPointIsRoot, FastMKSStat>*
-      referenceTree;
+  //! The tree built on the reference dataset.
+  TreeType* referenceTree;
+  //! The tree built on the query dataset.  This is NULL if there is no query
+  //! set.
+  TreeType* queryTree;
 
-  tree::CoverTree<IPMetric<KernelType>, tree::FirstPointIsRoot, FastMKSStat>*
-      queryTree;
+  //! If true, this object created the trees and is responsible for them.
+  bool treeOwner;
 
+  //! If true, single-tree search is used.
   bool single;
-
+  //! If true, naive (brute-force) search is used.
   bool naive;
 
+  //! The instantiated inner-product metric induced by the given kernel.
   IPMetric<KernelType> metric;
 
   // Utility function.  Copied too many times from too many places.

Modified: mlpack/trunk/src/mlpack/methods/fastmks/fastmks_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/fastmks/fastmks_impl.hpp	2013-04-27 03:05:40 UTC (rev 14965)
+++ mlpack/trunk/src/mlpack/methods/fastmks/fastmks_impl.hpp	2013-04-27 03:10:47 UTC (rev 14966)
@@ -18,80 +18,161 @@
 namespace mlpack {
 namespace fastmks {
 
-template<typename KernelType>
-FastMKS<KernelType>::FastMKS(const arma::mat& referenceSet,
-                             KernelType& kernel,
-                             bool single,
-                             bool naive,
-                             double expansionConstant) :
+// Single dataset, no instantiated kernel.
+template<typename KernelType, typename TreeType>
+FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
+                                       const bool single,
+                                       const bool naive) :
     referenceSet(referenceSet),
-    querySet(referenceSet), // Gotta point it somewhere...
+    querySet(referenceSet),
+    referenceTree(NULL),
     queryTree(NULL),
+    treeOwner(true),
     single(single),
+    naive(naive)
+{
+  Timer::Start("tree_building");
+
+  if (!naive)
+    referenceTree = new TreeType(referenceSet);
+
+  Timer::Stop("tree_building");
+}
+
+// Two datasets, no instantiated kernel.
+template<typename KernelType, typename TreeType>
+FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
+                                       const arma::mat& querySet,
+                                       const bool single,
+                                       const bool naive) :
+    referenceSet(referenceSet),
+    querySet(querySet),
+    referenceTree(NULL),
+    queryTree(NULL),
+    treeOwner(true),
+    single(single),
+    naive(naive)
+{
+  Timer::Start("tree_building");
+
+  // If necessary, the trees should be built.
+  if (!naive)
+    referenceTree = new TreeType(referenceSet);
+
+  if (!naive && !single)
+    queryTree = new TreeType(querySet);
+
+  Timer::Stop("tree_building");
+}
+
+// One dataset, instantiated kernel.
+template<typename KernelType, typename TreeType>
+FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
+                                       KernelType& kernel,
+                                       const bool single,
+                                       const bool naive) :
+    referenceSet(referenceSet),
+    querySet(referenceSet),
+    referenceTree(NULL),
+    queryTree(NULL),
+    treeOwner(true),
+    single(single),
     naive(naive),
     metric(kernel)
 {
-
   Timer::Start("tree_building");
 
-  // Build the tree.  Could we do this in the initialization list?
-  if (naive)
-    referenceTree = NULL;
-  else
-    referenceTree = new tree::CoverTree<IPMetric<KernelType>,
-        tree::FirstPointIsRoot, FastMKSStat>(referenceSet, expansionConstant,
-        &metric);
+  // If necessary, the reference tree should be built.  There is no query tree.
+  if (!naive)
+    referenceTree = new TreeType(referenceSet, metric);
 
   Timer::Stop("tree_building");
 }
 
-template<typename KernelType>
-FastMKS<KernelType>::FastMKS(const arma::mat& referenceSet,
-                             const arma::mat& querySet,
-                             KernelType& kernel,
-                             bool single,
-                             bool naive,
-                             double expansionConstant) :
+// Two datasets, instantiated kernel.
+template<typename KernelType, typename TreeType>
+FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
+                                       const arma::mat& querySet,
+                                       KernelType& kernel,
+                                       const bool single,
+                                       const bool naive) :
     referenceSet(referenceSet),
     querySet(querySet),
+    referenceTree(NULL),
+    queryTree(NULL),
+    treeOwner(true),
     single(single),
     naive(naive),
     metric(kernel)
 {
   Timer::Start("tree_building");
 
-  // Build the trees.  Could we do this in the initialization lists?
-  if (naive)
-    referenceTree = NULL;
-  else
-    referenceTree = new tree::CoverTree<IPMetric<KernelType>,
-        tree::FirstPointIsRoot, FastMKSStat>(referenceSet, expansionConstant,
-        &metric);
+  // If necessary, the trees should be built.
+  if (!naive)
+    referenceTree = new TreeType(referenceSet, metric);
 
-  if (single || naive)
-    queryTree = NULL;
-  else
-    queryTree = new tree::CoverTree<IPMetric<KernelType>,
-        tree::FirstPointIsRoot, FastMKSStat>(querySet, expansionConstant,
-        &metric);
+  if (!naive && !single)
+    queryTree = new TreeType(querySet, metric);
 
   Timer::Stop("tree_building");
 }
 
-template<typename KernelType>
-FastMKS<KernelType>::~FastMKS()
+// One dataset, pre-built tree.
+template<typename KernelType, typename TreeType>
+FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
+                                       TreeType* referenceTree,
+                                       const bool single,
+                                       const bool naive) :
+    referenceSet(referenceSet),
+    querySet(referenceSet),
+    referenceTree(referenceTree),
+    queryTree(NULL),
+    treeOwner(false),
+    single(single),
+    naive(naive),
+    metric(referenceTree->Metric())
 {
-  if (queryTree)
-    delete queryTree;
-  if (referenceTree)
-    delete referenceTree;
+  // Nothing to do.
 }
 
-template<typename KernelType>
-void FastMKS<KernelType>::Search(const size_t k,
-                               arma::Mat<size_t>& indices,
-                               arma::mat& products)
+// Two datasets, pre-built trees.
+template<typename KernelType, typename TreeType>
+FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
+                                       TreeType* referenceTree,
+                                       const arma::mat& querySet,
+                                       TreeType* queryTree,
+                                       const bool single,
+                                       const bool naive) :
+    referenceSet(referenceSet),
+    querySet(querySet),
+    referenceTree(referenceTree),
+    queryTree(queryTree),
+    treeOwner(false),
+    single(single),
+    naive(naive),
+    metric(referenceTree->Metric())
 {
+  // Nothing to do.
+}
+
+template<typename KernelType, typename TreeType>
+FastMKS<KernelType, TreeType>::~FastMKS()
+{
+  // If we created the trees, we must delete them.
+  if (treeOwner)
+  {
+    if (queryTree)
+      delete queryTree;
+    if (referenceTree)
+      delete referenceTree;
+  }
+}
+
+template<typename KernelType, typename TreeType>
+void FastMKS<KernelType, TreeType>::Search(const size_t k,
+                                           arma::Mat<size_t>& indices,
+                                           arma::mat& products)
+{
   // No remapping will be necessary because we are using the cover tree.
   indices.set_size(k, querySet.n_cols);
   products.set_size(k, querySet.n_cols);
@@ -139,8 +220,6 @@
   {
     // Create rules object (this will store the results).  This constructor
     // precalculates each self-kernel value.
-    typedef tree::CoverTree<IPMetric<KernelType>, tree::FirstPointIsRoot,
-        FastMKSStat> TreeType;
     typedef FastMKSRules<KernelType, TreeType> RuleType;
     RuleType rules(referenceSet, querySet, indices, products, metric.Kernel());
 
@@ -163,8 +242,6 @@
   }
 
   // Dual-tree implementation.
-  typedef tree::CoverTree<IPMetric<KernelType>, tree::FirstPointIsRoot,
-      FastMKSStat> TreeType;
   typedef FastMKSRules<KernelType, TreeType> RuleType;
   RuleType rules(referenceSet, querySet, indices, products, metric.Kernel());
 
@@ -194,13 +271,13 @@
  * @param neighbor Index of reference point which is being inserted.
  * @param distance Distance from query point to reference point.
  */
-template<typename KernelType>
-void FastMKS<KernelType>::InsertNeighbor(arma::Mat<size_t>& indices,
-                                         arma::mat& products,
-                                         const size_t queryIndex,
-                                         const size_t pos,
-                                         const size_t neighbor,
-                                         const double distance)
+template<typename KernelType, typename TreeType>
+void FastMKS<KernelType, TreeType>::InsertNeighbor(arma::Mat<size_t>& indices,
+                                                   arma::mat& products,
+                                                   const size_t queryIndex,
+                                                   const size_t pos,
+                                                   const size_t neighbor,
+                                                   const double distance)
 {
   // We only memmove() if there is actually a need to shift something.
   if (pos < (products.n_rows - 1))




More information about the mlpack-svn mailing list