[mlpack-git] master: Create a new class SpillSearch that encapsulates an instance of NeighborSearch class, and adds the functionality to deal with spill trees. (e95e3b4)

gitdub at mlpack.org gitdub at mlpack.org
Thu Aug 18 13:39:22 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/0f4b25acd6aaa14294c044874ba6cc0751712baa...0a19d07bd39e6223991976474bc79671ba8aa0f0

>---------------------------------------------------------------

commit e95e3b429601a964c4fd31e7f073931e4e079781
Author: MarcosPividori <marcos.pividori at gmail.com>
Date:   Thu Jul 28 13:55:18 2016 -0300

    Create a new class SpillSearch that encapsulates an instance of NeighborSearch class, and adds the functionality to deal with spill trees.


>---------------------------------------------------------------

e95e3b429601a964c4fd31e7f073931e4e079781
 src/mlpack/methods/neighbor_search/CMakeLists.txt  |   2 +
 .../methods/neighbor_search/neighbor_search.hpp    |  19 +-
 .../neighbor_search/neighbor_search_impl.hpp       |   5 +-
 .../{neighbor_search.hpp => spill_search.hpp}      | 221 ++++++++-------------
 .../methods/neighbor_search/spill_search_impl.hpp  | 213 ++++++++++++++++++++
 5 files changed, 321 insertions(+), 139 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/CMakeLists.txt b/src/mlpack/methods/neighbor_search/CMakeLists.txt
index 95fe37b..e4c76e1 100644
--- a/src/mlpack/methods/neighbor_search/CMakeLists.txt
+++ b/src/mlpack/methods/neighbor_search/CMakeLists.txt
@@ -14,6 +14,8 @@ set(SOURCES
   sort_policies/nearest_neighbor_sort_impl.hpp
   sort_policies/furthest_neighbor_sort.hpp
   sort_policies/furthest_neighbor_sort_impl.hpp
+  spill_search.hpp
+  spill_search_impl.hpp
   typedef.hpp
   unmap.hpp
   unmap.cpp
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index f1acea4..e25c93a 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -27,6 +27,12 @@ namespace neighbor /** Neighbor-search routines.  These include
                     * searches. */ {
 
 // Forward declaration.
+template<typename MetricType,
+         typename MatType,
+         template<typename SplitBoundT, typename SplitMatT> class SplitType>
+class SpillSearch;
+
+// Forward declaration.
 template<typename SortPolicy>
 class TrainVisitor;
 
@@ -237,11 +243,14 @@ class NeighborSearch
    * @param neighbors Matrix storing lists of neighbors for each query point.
    * @param distances Matrix storing distances of neighbors for each query
    *      point.
+   * @param sameSet Denotes whether or not the reference and query sets are the
+   *      same.
    */
   void Search(Tree* queryTree,
               const size_t k,
               arma::Mat<size_t>& neighbors,
-              arma::mat& distances);
+              arma::mat& distances,
+              bool sameSet = false);
 
   /**
    * Search for the nearest neighbors of every point in the reference set.  This
@@ -323,7 +332,13 @@ class NeighborSearch
   bool treeNeedsReset;
 
   //! The NSModel class should have access to internal members.
-  friend class TrainVisitor<SortPolicy>;
+  template<typename SortPol>
+  friend class TrainVisitor;
+
+  template<typename MetricT,
+           typename MatT,
+           template<typename SplitBoundT, typename SplitMatT> class SplitType>
+  friend class SpillSearch;
 }; // class NeighborSearch
 
 } // namespace neighbor
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index 79a31d0..4fd073d 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -506,7 +506,8 @@ void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
 Search(Tree* queryTree,
        const size_t k,
        arma::Mat<size_t>& neighbors,
-       arma::mat& distances)
+       arma::mat& distances,
+       bool sameSet)
 {
   if (k > referenceSet->n_cols)
   {
@@ -540,7 +541,7 @@ Search(Tree* queryTree,
 
   // Create the helper object for the traversal.
   typedef NeighborSearchRules<SortPolicy, MetricType, Tree> RuleType;
-  RuleType rules(*referenceSet, querySet, k, metric, epsilon);
+  RuleType rules(*referenceSet, querySet, k, metric, epsilon, sameSet);
 
   // Create the traverser.
   TraversalType<RuleType> traverser(rules);
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/spill_search.hpp
similarity index 54%
copy from src/mlpack/methods/neighbor_search/neighbor_search.hpp
copy to src/mlpack/methods/neighbor_search/spill_search.hpp
index f1acea4..8490062 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/spill_search.hpp
@@ -1,177 +1,149 @@
 /**
- * @file neighbor_search.hpp
+ * @file spill_search.hpp
  * @author Ryan Curtin
+ * @author Marcos Pividori
  *
- * Defines the NeighborSearch class, which performs an abstract
- * nearest-neighbor-like query on two datasets.
+ * Defines the SpillSearch class, which performs a Hybrid sp-tree search on
+ * two datasets.
  */
-#ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_HPP
-#define MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_HPP
+#ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_SPILL_SEARCH_HPP
+#define MLPACK_METHODS_NEIGHBOR_SEARCH_SPILL_SEARCH_HPP
 
 #include <mlpack/core.hpp>
-#include <vector>
-#include <string>
-
-#include <mlpack/core/tree/binary_space_tree.hpp>
-#include <mlpack/core/tree/rectangle_tree.hpp>
-#include <mlpack/core/tree/binary_space_tree/binary_space_tree.hpp>
-
 #include <mlpack/core/metrics/lmetric.hpp>
-#include "neighbor_search_stat.hpp"
 #include "sort_policies/nearest_neighbor_sort.hpp"
-#include "neighbor_search_rules.hpp"
+#include "neighbor_search.hpp"
 
 namespace mlpack {
-namespace neighbor /** Neighbor-search routines.  These include
-                    * all-nearest-neighbors and all-furthest-neighbors
-                    * searches. */ {
+namespace neighbor {
 
 // Forward declaration.
 template<typename SortPolicy>
 class TrainVisitor;
 
 /**
- * The NeighborSearch class is a template class for performing distance-based
- * neighbor searches.  It takes a query dataset and a reference dataset (or just
- * a reference dataset) and, for each point in the query dataset, finds the k
- * neighbors in the reference dataset which have the 'best' distance according
- * to a given sorting policy.  A constructor is given which takes only a
- * reference dataset, and if that constructor is used, the given reference
- * dataset is also used as the query dataset.
+ * The SpillSearch class is a template class for performing distance-based
+ * neighbor searches with Spill Trees.  It takes a query dataset and a reference
+ * dataset (or just a reference dataset) and, for each point in the query
+ * dataset, finds the k neighbors in the reference dataset which have the 'best'
+ * distance according to a given sorting policy.  A constructor is given which
+ * takes only a reference dataset, and if that constructor is used, the given
+ * reference dataset is also used as the query dataset.
  *
- * The template parameters SortPolicy and Metric define the sort function used
- * and the metric (distance function) used.  More information on those classes
- * can be found in the NearestNeighborSort class and the kernel::ExampleKernel
- * class.
- *
- * @tparam SortPolicy The sort policy for distances; see NearestNeighborSort.
  * @tparam MetricType The metric to use for computation.
  * @tparam MatType The type of data matrix.
- * @tparam TreeType The tree type to use; must adhere to the TreeType API.
- * @tparam TraversalType The type of traversal to use (defaults to the tree's
- *      default traverser).
+ * @tparam SplitType The class that partitions the dataset/points at a
+ *     particular node into two parts. Its definition decides the way this split
+ *     is done when building spill trees.
  */
-template<typename SortPolicy = NearestNeighborSort,
-         typename MetricType = mlpack::metric::EuclideanDistance,
+template<typename MetricType = mlpack::metric::EuclideanDistance,
          typename MatType = arma::mat,
-         template<typename TreeMetricType,
-                  typename TreeStatType,
-                  typename TreeMatType> class TreeType = tree::KDTree,
-         template<typename RuleType> class TraversalType =
-             TreeType<MetricType,
-                      NeighborSearchStat<SortPolicy>,
-                      MatType>::template DualTreeTraverser>
-class NeighborSearch
+         template<typename SplitBoundT, typename SplitMatT> class SplitType =
+             tree::MidpointSplit>
+class SpillSearch
 {
  public:
   //! Convenience typedef.
-  typedef TreeType<MetricType, NeighborSearchStat<SortPolicy>, MatType> Tree;
+  typedef tree::SpillTree<MetricType, NeighborSearchStat<NearestNeighborSort>,
+      MatType, SplitType> Tree;
+
+  template<typename TreeMetricType,
+           typename TreeStatType,
+           typename TreeMatType>
+  using TreeType = tree::SpillTree<TreeMetricType, TreeStatType, TreeMatType,
+      SplitType>;
 
   /**
-   * Initialize the NeighborSearch object, passing a reference dataset (this is
+   * Initialize the SpillSearch object, passing a reference dataset (this is
    * the dataset which is searched).  Optionally, perform the computation in
    * naive mode or single-tree mode.  An initialized distance metric can be
    * given, for cases where the metric has internal data (i.e. the
    * distance::MahalanobisDistance class).
    *
-   * This method will copy the matrices to internal copies, which are rearranged
-   * during tree-building.  You can avoid this extra copy by pre-constructing
-   * the trees and passing them using a different constructor, or by using the
-   * construct that takes an rvalue reference to the dataset.
-   *
    * @param referenceSet Set of reference points.
    * @param naive If true, O(n^2) naive search will be used (as opposed to
    *      dual-tree search).  This overrides singleMode (if it is set to true).
    * @param singleMode If true, single-tree search will be used (as opposed to
    *      dual-tree search).
+   * @param tau Overlapping size (non-negative).
    * @param epsilon Relative approximate error (non-negative).
    * @param metric An optional instance of the MetricType class.
    */
-  NeighborSearch(const MatType& referenceSet,
-                 const bool naive = false,
-                 const bool singleMode = false,
-                 const double epsilon = 0,
-                 const MetricType metric = MetricType());
+  SpillSearch(const MatType& referenceSet,
+              const bool naive = false,
+              const bool singleMode = false,
+              const double tau = 0,
+              const double epsilon = 0,
+              const MetricType metric = MetricType());
 
   /**
-   * Initialize the NeighborSearch object, taking ownership of the reference
+   * Initialize the SpillSearch object, taking ownership of the reference
    * dataset (this is the dataset which is searched).  Optionally, perform the
    * computation in naive mode or single-tree mode.  An initialized distance
    * metric can be given, for cases where the metric has internal data (i.e. the
    * distance::MahalanobisDistance class).
    *
-   * This method will not copy the data matrix, but will take ownership of it,
-   * and depending on the type of tree used, may rearrange the points.  If you
-   * would rather a copy be made, consider using the constructor that takes a
-   * const reference to the data instead.
-   *
    * @param referenceSet Set of reference points.
    * @param naive If true, O(n^2) naive search will be used (as opposed to
    *      dual-tree search).  This overrides singleMode (if it is set to true).
    * @param singleMode If true, single-tree search will be used (as opposed to
    *      dual-tree search).
+   * @param tau Overlapping size (non-negative).
    * @param epsilon Relative approximate error (non-negative).
    * @param metric An optional instance of the MetricType class.
    */
-  NeighborSearch(MatType&& referenceSet,
-                 const bool naive = false,
-                 const bool singleMode = false,
-                 const double epsilon = 0,
-                 const MetricType metric = MetricType());
+  SpillSearch(MatType&& referenceSet,
+              const bool naive = false,
+              const bool singleMode = false,
+              const double tau = 0,
+              const double epsilon = 0,
+              const MetricType metric = MetricType());
 
   /**
-   * Initialize the NeighborSearch object with the given pre-constructed
+   * Initialize the SpillSearch object with the given pre-constructed
    * reference tree (this is the tree built on the points that will be
    * searched).  Optionally, choose to use single-tree mode.  Naive mode is not
    * available as an option for this constructor.  Additionally, an instantiated
    * distance metric can be given, for cases where the distance metric holds
    * data.
    *
-   * There is no copying of the data matrices in this constructor (because
-   * tree-building is not necessary), so this is the constructor to use when
-   * copies absolutely must be avoided.
-   *
-   * @note
-   * Mapping the points of the matrix back to their original indices is not done
-   * when this constructor is used, so if the tree type you are using maps
-   * points (like BinarySpaceTree), then you will have to perform the re-mapping
-   * manually.
-   * @endnote
-   *
    * @param referenceTree Pre-built tree for reference points.
-   * @param referenceSet Set of reference points corresponding to referenceTree.
    * @param singleMode Whether single-tree computation should be used (as
    *      opposed to dual-tree computation).
+   * @param tau Overlapping size (non-negative).
    * @param epsilon Relative approximate error (non-negative).
    * @param metric Instantiated distance metric.
    */
-  NeighborSearch(Tree* referenceTree,
-                 const bool singleMode = false,
-                 const double epsilon = 0,
-                 const MetricType metric = MetricType());
+  SpillSearch(Tree* referenceTree,
+              const bool singleMode = false,
+              const double tau = 0,
+              const double epsilon = 0,
+              const MetricType metric = MetricType());
 
   /**
-   * Create a NeighborSearch object without any reference data.  If Search() is
+   * Create a SpillSearch object without any reference data.  If Search() is
    * called before a reference set is set with Train(), an exception will be
    * thrown.
    *
    * @param naive Whether to use naive search.
    * @param singleMode Whether single-tree computation should be used (as
    *      opposed to dual-tree computation).
+   * @param tau Overlapping size (non-negative).
    * @param epsilon Relative approximate error (non-negative).
    * @param metric Instantiated metric.
    */
-  NeighborSearch(const bool naive = false,
-                 const bool singleMode = false,
-                 const double epsilon = 0,
-                 const MetricType metric = MetricType());
+  SpillSearch(const bool naive = false,
+              const bool singleMode = false,
+              const double tau = 0,
+              const double epsilon = 0,
+              const MetricType metric = MetricType());
 
 
   /**
-   * Delete the NeighborSearch object. The tree is the only member we are
+   * Delete the SpillSearch object. The tree is the only member we are
    * responsible for deleting.  The others will take care of themselves.
    */
-  ~NeighborSearch();
+  ~SpillSearch();
 
   /**
    * Set the reference set to a new reference set, and build a tree if
@@ -195,6 +167,8 @@ class NeighborSearch
 
   /**
    * Set the reference tree to a new reference tree.
+   *
+   * @param referenceTree Pre-built tree for reference points.
    */
   void Train(Tree* referenceTree);
 
@@ -263,76 +237,53 @@ class NeighborSearch
 
   //! Return the total number of base case evaluations performed during the last
   //! search.
-  size_t BaseCases() const { return baseCases; }
+  size_t BaseCases() const { return neighborSearch.BaseCases(); }
 
   //! Return the number of node combination scores during the last search.
-  size_t Scores() const { return scores; }
+  size_t Scores() const { return neighborSearch.Scores(); }
 
   //! Access whether or not search is done in naive linear scan mode.
-  bool Naive() const { return naive; }
+  bool Naive() const { return neighborSearch.Naive(); }
   //! Modify whether or not search is done in naive linear scan mode.
-  bool& Naive() { return naive; }
+  bool& Naive() { return neighborSearch.Naive(); }
 
   //! Access whether or not search is done in single-tree mode.
-  bool SingleMode() const { return singleMode; }
+  bool SingleMode() const { return neighborSearch.SingleMode(); }
   //! Modify whether or not search is done in single-tree mode.
-  bool& SingleMode() { return singleMode; }
+  bool& SingleMode() { return neighborSearch.SingleMode(); }
 
   //! Access the relative error to be considered in approximate search.
-  double Epsilon() const { return epsilon; }
+  double Epsilon() const { return neighborSearch.Epsilon(); }
   //! Modify the relative error to be considered in approximate search.
-  double& Epsilon() { return epsilon; }
+  double& Epsilon() { return neighborSearch.Epsilon(); }
+
+  //! Access the overlapping size.
+  double Tau() const { return tau; }
 
   //! Access the reference dataset.
-  const MatType& ReferenceSet() const { return *referenceSet; }
+  const MatType& ReferenceSet() const { return neighborSearch.ReferenceSet(); }
 
-  //! Serialize the NeighborSearch model.
+  //! Serialize the SpillSearch model.
   template<typename Archive>
   void Serialize(Archive& ar, const unsigned int /* version */);
 
  private:
-  //! Permutations of reference points during tree building.
-  std::vector<size_t> oldFromNewReferences;
-  //! Pointer to the root of the reference tree.
-  Tree* referenceTree;
-  //! Reference dataset.  In some situations we may be the owner of this.
-  const MatType* referenceSet;
-
-  //! If true, this object created the trees and is responsible for them.
-  bool treeOwner;
-  //! If true, we own the reference set.
-  bool setOwner;
+  //! Internal instance of NeighborSearch class.
+  NeighborSearch<NearestNeighborSort, MetricType, MatType, TreeType>
+      neighborSearch;
 
-  //! Indicates if O(n^2) naive search is being used.
-  bool naive;
-  //! Indicates if single-tree search is being used (as opposed to dual-tree).
-  bool singleMode;
-  //! Indicates the relative error to be considered in approximate search.
-  double epsilon;
-
-  //! Instantiation of metric.
-  MetricType metric;
-
-  //! The total number of base cases.
-  size_t baseCases;
-  //! The total number of scores (applicable for non-naive search).
-  size_t scores;
-
-  //! If this is true, the reference tree bounds need to be reset on a call to
-  //! Search() without a query set.
-  bool treeNeedsReset;
+  //! Overlapping size.
+  double tau;
 
   //! The NSModel class should have access to internal members.
-  friend class TrainVisitor<SortPolicy>;
-}; // class NeighborSearch
+  template<typename SortPolicy>
+  friend class TrainVisitor;
+}; // class SpillSearch
 
 } // namespace neighbor
 } // namespace mlpack
 
 // Include implementation.
-#include "neighbor_search_impl.hpp"
-
-// Include convenience typedefs.
-#include "typedef.hpp"
+#include "spill_search_impl.hpp"
 
 #endif
diff --git a/src/mlpack/methods/neighbor_search/spill_search_impl.hpp b/src/mlpack/methods/neighbor_search/spill_search_impl.hpp
new file mode 100644
index 0000000..f1ce70a
--- /dev/null
+++ b/src/mlpack/methods/neighbor_search/spill_search_impl.hpp
@@ -0,0 +1,213 @@
+/**
+ * @file spill_search_impl.hpp
+ * @author Ryan Curtin
+ * @author Marcos Pividori
+ *
+ * Implementation of SpillSearch class, which performs a Hybrid sp-tree search
+ * on two datasets.
+ */
+#ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_SPILL_SEARCH_IMPL_HPP
+#define MLPACK_METHODS_NEIGHBOR_SEARCH_SPILL_SEARCH_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "spill_search.hpp"
+
+namespace mlpack {
+namespace neighbor {
+
+// Construct from a const reference dataset: validate tau, then call Train().
+template<typename MetricType,
+         typename MatType,
+         template<typename SplitBoundT, typename SplitMatT> class SplitType>
+SpillSearch<MetricType, MatType, SplitType>::SpillSearch(
+    const MatType& referenceSetIn,
+    const bool naive,
+    const bool singleMode,
+    const double tau,
+    const double epsilon,
+    const MetricType metric) :
+    neighborSearch(naive, singleMode, epsilon, metric),
+    tau(tau)
+{
+  if (tau < 0)
+    throw std::invalid_argument("tau must be non-negative");
+  Train(referenceSetIn);
+}
+
+// Construct from an rvalue dataset: validate tau, then Train() by move.
+template<typename MetricType,
+         typename MatType,
+         template<typename SplitBoundT, typename SplitMatT> class SplitType>
+SpillSearch<MetricType, MatType, SplitType>::SpillSearch(
+    MatType&& referenceSetIn,
+    const bool naive,
+    const bool singleMode,
+    const double tau,
+    const double epsilon,
+    const MetricType metric) :
+    neighborSearch(naive, singleMode, epsilon, metric),
+    tau(tau)
+{
+  if (tau < 0)
+    throw std::invalid_argument("tau must be non-negative");
+  Train(std::move(referenceSetIn));
+}
+
+// Construct from a pre-built tree; naive mode is never used on this path.
+template<typename MetricType,
+         typename MatType,
+         template<typename SplitBoundT, typename SplitMatT> class SplitType>
+SpillSearch<MetricType, MatType, SplitType>::SpillSearch(
+    Tree* referenceTree,
+    const bool singleMode,
+    const double tau,
+    const double epsilon,
+    const MetricType metric) :
+    neighborSearch(false /* naive */, singleMode, epsilon, metric),
+    tau(tau)
+{
+  if (tau < 0)
+    throw std::invalid_argument("tau must be non-negative");
+  Train(referenceTree);
+}
+
+// Construct without a reference dataset: only stores options; no Train().
+template<typename MetricType,
+         typename MatType,
+         template<typename SplitBoundT, typename SplitMatT> class SplitType>
+SpillSearch<MetricType, MatType, SplitType>::SpillSearch(
+    const bool naive,
+    const bool singleMode,
+    const double tau,
+    const double epsilon,
+    const MetricType metric) :
+    neighborSearch(naive, singleMode, epsilon, metric),
+    tau(tau)
+{
+  if (tau < 0)
+    throw std::invalid_argument("tau must be non-negative");
+}
+
+// Destructor.  SpillSearch holds neighborSearch by value and owns no pointers.
+template<typename MetricType,
+         typename MatType,
+         template<typename SplitBoundT, typename SplitMatT> class SplitType>
+SpillSearch<MetricType, MatType, SplitType>::
+~SpillSearch()
+{
+  /* Nothing to do: the members are destroyed automatically. */
+}
+
+template<typename MetricType,
+         typename MatType,
+         template<typename SplitBoundT, typename SplitMatT> class SplitType>
+void SpillSearch<MetricType, MatType, SplitType>::
+Train(const MatType& referenceSet)
+{
+  if (Naive())
+  {
+    neighborSearch.Train(referenceSet);
+    return;
+  }
+  // Tree-based search: build the reference spill tree with this object's tau,
+  // install it, and then mark the inner searcher as the tree's owner.
+  Tree* spillTree = new Tree(referenceSet, tau);
+  neighborSearch.Train(spillTree);
+  neighborSearch.treeOwner = true;
+}
+
+template<typename MetricType,
+         typename MatType,
+         template<typename SplitBoundT, typename SplitMatT> class SplitType>
+void SpillSearch<MetricType, MatType, SplitType>::
+Train(MatType&& referenceSetIn)
+{
+  if (Naive())
+  {
+    neighborSearch.Train(std::move(referenceSetIn));
+    return;
+  }
+  // Tree-based search: move the points into a spill tree built with this
+  // object's tau, install it, and then mark the inner searcher as its owner.
+  Tree* spillTree = new Tree(std::move(referenceSetIn), tau);
+  neighborSearch.Train(spillTree);
+  neighborSearch.treeOwner = true;
+}
+
+template<typename MetricType,
+         typename MatType,
+         template<typename SplitBoundT, typename SplitMatT> class SplitType>
+void SpillSearch<MetricType, MatType, SplitType>::
+Train(Tree* referenceTree)
+{
+  neighborSearch.Train(referenceTree); // Forwarded; treeOwner is not set here.
+}
+
+template<typename MetricType,
+         typename MatType,
+         template<typename SplitBoundT, typename SplitMatT> class SplitType>
+void SpillSearch<MetricType, MatType, SplitType>::
+Search(const MatType& querySet,
+       const size_t k,
+       arma::Mat<size_t>& neighbors,
+       arma::mat& distances)
+{
+  const bool dualTree = !Naive() && !SingleMode();
+  if (!dualTree)
+  {
+    neighborSearch.Search(querySet, k, neighbors, distances);
+    return;
+  }
+  // Dual-tree search: the query tree is built without overlap (tau = 0).
+  Tree queryTree(querySet, 0 /* tau */);
+  neighborSearch.Search(&queryTree, k, neighbors, distances);
+}
+
+template<typename MetricType,
+         typename MatType,
+         template<typename SplitBoundT, typename SplitMatT> class SplitType>
+void SpillSearch<MetricType, MatType, SplitType>::
+Search(Tree* queryTree,
+       const size_t k,
+       arma::Mat<size_t>& neighbors,
+       arma::mat& distances)
+{
+  neighborSearch.Search(queryTree, k, neighbors, distances); // sameSet = false
+}
+
+template<typename MetricType,
+         typename MatType,
+         template<typename SplitBoundT, typename SplitMatT> class SplitType>
+void SpillSearch<MetricType, MatType, SplitType>::
+Search(const size_t k,
+       arma::Mat<size_t>& neighbors,
+       arma::mat& distances)
+{
+  // Without overlap, or without dual-tree search, no extra tree is needed.
+  if (tau == 0 || Naive() || SingleMode())
+  {
+    neighborSearch.Search(k, neighbors, distances);
+    return;
+  }
+  // tau is non-zero, so the reference tree may overlap: build a tau = 0 tree
+  // over the same points as query tree; sameSet = true flags the shared set.
+  Tree nonOverlappingTree(ReferenceSet(), 0 /* tau */);
+  neighborSearch.Search(&nonOverlappingTree, k, neighbors, distances, true);
+}
+
+//! Serialize the model: the wrapped NeighborSearch object, then tau.
+template<typename MetricType,
+         typename MatType,
+         template<typename SplitBoundT, typename SplitMatT> class SplitType>
+template<typename Archive>
+void SpillSearch<MetricType, MatType, SplitType>::
+    Serialize(Archive& ar, const unsigned int /* version */)
+{
+  ar & data::CreateNVP(neighborSearch, "neighborSearch");
+  ar & data::CreateNVP(tau, "tau");
+}
+
+} // namespace neighbor
+} // namespace mlpack
+
+#endif




More information about the mlpack-git mailing list