[mlpack-git] master: Remove SpillSearch class. Use NeighborSearch class and typedefs for defeatist search. (5795c74)

gitdub at mlpack.org gitdub at mlpack.org
Thu Aug 18 10:46:01 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/0f4b25acd6aaa14294c044874ba6cc0751712baa...0a19d07bd39e6223991976474bc79671ba8aa0f0

>---------------------------------------------------------------

commit 5795c7468ec78962390ca378a721c2f7a26497fe
Author: MarcosPividori <marcos.pividori at gmail.com>
Date:   Thu Aug 18 11:46:01 2016 -0300

    Remove SpillSearch class. Use NeighborSearch class and typedefs for defeatist search.


>---------------------------------------------------------------

5795c7468ec78962390ca378a721c2f7a26497fe
 src/mlpack/methods/neighbor_search/CMakeLists.txt  |   2 -
 .../methods/neighbor_search/neighbor_search.hpp    |  13 -
 .../neighbor_search/neighbor_search_impl.hpp       |  18 +-
 src/mlpack/methods/neighbor_search/ns_model.hpp    |  12 +-
 .../methods/neighbor_search/ns_model_impl.hpp      |  23 +-
 .../methods/neighbor_search/spill_search.hpp       | 323 ---------------------
 .../methods/neighbor_search/spill_search_impl.hpp  | 251 ----------------
 src/mlpack/methods/neighbor_search/typedef.hpp     |  29 ++
 8 files changed, 52 insertions(+), 619 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/CMakeLists.txt b/src/mlpack/methods/neighbor_search/CMakeLists.txt
index e66b5a7..1c51ce4 100644
--- a/src/mlpack/methods/neighbor_search/CMakeLists.txt
+++ b/src/mlpack/methods/neighbor_search/CMakeLists.txt
@@ -12,8 +12,6 @@ set(SOURCES
   sort_policies/nearest_neighbor_sort_impl.hpp
   sort_policies/furthest_neighbor_sort.hpp
   sort_policies/furthest_neighbor_sort_impl.hpp
-  spill_search.hpp
-  spill_search_impl.hpp
   typedef.hpp
   unmap.hpp
   unmap.cpp
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index 9b5b3be..71e064c 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -27,13 +27,6 @@ namespace neighbor /** Neighbor-search routines.  These include
                     * searches. */ {
 
 // Forward declaration.
-template<typename MetricType,
-         typename MatType,
-         template<typename HyperplaneMetricType> class HyperplaneType,
-         template<typename SplitBoundT, typename SplitMatT> class SplitType>
-class SpillSearch;
-
-// Forward declaration.
 template<typename SortPolicy>
 class TrainVisitor;
 
@@ -371,12 +364,6 @@ class NeighborSearch
   //! The NSModel class should have access to internal members.
   template<typename SortPol>
   friend class TrainVisitor;
-
-  template<typename MetricT,
-           typename MatT,
-           template<typename HyperplaneMetricType> class HyperplaneType,
-           template<typename SplitBoundT, typename SplitMatT> class SplitType>
-  friend class SpillSearch;
 }; // class NeighborSearch
 
 } // namespace neighbor
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index 8968e63..3cbeaed 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -11,6 +11,7 @@
 #include <mlpack/core.hpp>
 
 #include "neighbor_search_rules.hpp"
+#include <mlpack/core/tree/spill_tree/is_spill_tree.hpp>
 
 namespace mlpack {
 namespace neighbor {
@@ -686,7 +687,19 @@ SingleTreeTraversalType>::Search(const size_t k,
     // Create the traverser.
     TraversalType<RuleType> traverser(rules);
 
-    traverser.Traverse(*referenceTree, *referenceTree);
+    if (tree::IsSpillTree<Tree>::value)
+    {
+      // For dual-tree search on a SpillTree, the query tree must be built
+      // without overlapping nodes (tau = 0).
+      Tree queryTree(*referenceSet);
+      traverser.Traverse(queryTree, *referenceTree);
+    }
+    else
+    {
+      traverser.Traverse(*referenceTree, *referenceTree);
+      // Next time we perform this search, we'll need to reset the tree.
+      treeNeedsReset = true;
+    }
 
     scores += rules.Scores();
     baseCases += rules.BaseCases();
@@ -695,9 +708,6 @@ SingleTreeTraversalType>::Search(const size_t k,
         << std::endl;
     Log::Info << rules.BaseCases() << " base cases were calculated."
         << std::endl;
-
-    // Next time we perform this search, we'll need to reset the tree.
-    treeNeedsReset = true;
   }
 
   rules.GetResults(*neighborPtr, *distancePtr);
diff --git a/src/mlpack/methods/neighbor_search/ns_model.hpp b/src/mlpack/methods/neighbor_search/ns_model.hpp
index e1606e9..bafccc1 100644
--- a/src/mlpack/methods/neighbor_search/ns_model.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model.hpp
@@ -16,7 +16,6 @@
 #include <mlpack/core/tree/spill_tree.hpp>
 #include <boost/variant.hpp>
 #include "neighbor_search.hpp"
-#include "spill_search.hpp"
 
 namespace mlpack {
 namespace neighbor {
@@ -36,11 +35,6 @@ using NSType = NeighborSearch<SortPolicy,
                                   NeighborSearchStat<SortPolicy>,
                                   arma::mat>::template DualTreeTraverser>;
 
-/**
- * Alias template for euclidean spill search.
- */
-using NSSpillType = SpillSearch<metric::EuclideanDistance, arma::mat>;
-
 template<typename SortPolicy>
 struct NSModelName
 {
@@ -137,7 +131,7 @@ class BiSearchVisitor : public boost::static_visitor<void>
   void operator()(NSTypeT<tree::BallTree>* ns) const;
 
   //! Bichromatic neighbor search specialized for SPTrees.
-  void operator()(NSSpillType* ns) const;
+  void operator()(SpillKNN* ns) const;
 
   //! Construct the BiSearchVisitor.
   BiSearchVisitor(const arma::mat& querySet,
@@ -192,7 +186,7 @@ class TrainVisitor : public boost::static_visitor<void>
   void operator()(NSTypeT<tree::BallTree>* ns) const;
 
   //! Train specialized for SPTrees.
-  void operator()(NSSpillType* ns) const;
+  void operator()(SpillKNN* ns) const;
 
   //! Construct the TrainVisitor object with the given reference set, leafSize
   //! for BinarySpaceTrees, and tau and rho for spill trees.
@@ -319,7 +313,7 @@ class NSModel
                  NSType<SortPolicy, tree::RPlusTree>*,
                  NSType<SortPolicy, tree::RPlusPlusTree>*,
                  NSType<SortPolicy, tree::VPTree>*,
-                 NSSpillType*> nSearch;
+                 SpillKNN*> nSearch;
 
  public:
   /**
diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
index e0d152b..3fb2b2e 100644
--- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
@@ -77,7 +77,7 @@ void BiSearchVisitor<SortPolicy>::operator()(NSTypeT<tree::BallTree>* ns) const
 
 //! Bichromatic neighbor search specialized for SPTrees.
 template<typename SortPolicy>
-void BiSearchVisitor<SortPolicy>::operator()(NSSpillType* ns) const
+void BiSearchVisitor<SortPolicy>::operator()(SpillKNN* ns) const
 {
   if (ns)
   {
@@ -85,7 +85,7 @@ void BiSearchVisitor<SortPolicy>::operator()(NSSpillType* ns) const
     {
       // For Dual Tree Search on SpillTrees, the queryTree must be built with
       // non overlapping (tau = 0).
-      typename NSSpillType::Tree queryTree(std::move(querySet), 0 /* tau*/,
+      typename SpillKNN::Tree queryTree(std::move(querySet), 0 /* tau*/,
           leafSize, rho);
       ns->Search(&queryTree, k, neighbors, distances);
     }
@@ -168,7 +168,7 @@ void TrainVisitor<SortPolicy>::operator ()(NSTypeT<tree::BallTree>* ns) const
 
 //! Train specialized for SPTrees.
 template<typename SortPolicy>
-void TrainVisitor<SortPolicy>::operator ()(NSSpillType* ns) const
+void TrainVisitor<SortPolicy>::operator ()(SpillKNN* ns) const
 {
   if (ns)
   {
@@ -176,11 +176,11 @@ void TrainVisitor<SortPolicy>::operator ()(NSSpillType* ns) const
       ns->Train(std::move(referenceSet));
     else
     {
-      typename NSSpillType::Tree* tree = new typename NSSpillType::Tree(
+      typename SpillKNN::Tree* tree = new typename SpillKNN::Tree(
           std::move(referenceSet), tau, leafSize, rho);
       ns->Train(tree);
       // Give the model ownership of the tree.
-      ns->neighborSearch.treeOwner = true;
+      ns->treeOwner = true;
     }
   }
   else
@@ -299,17 +299,6 @@ void serialize(
   ns.Serialize(ar, version);
 }
 
-/**
- * Non-intrusive serialization for SpillSearch class. We need this definition
- * because we are going to use the serialize function for boost variant, which
- * will look for a serialize function for its member types.
- */
-template<typename Archive>
-void serialize(Archive& ar, NSSpillType& ns, const unsigned int version)
-{
-  ns.Serialize(ar, version);
-}
-
 //! Serialize the kNN model.
 template<typename SortPolicy>
 template<typename Archive>
@@ -475,7 +464,7 @@ void NSModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
           epsilon);
       break;
     case SPILL_TREE:
-      nSearch = new NSSpillType(naive, singleMode, tau, leafSize, rho, epsilon);
+      nSearch = new SpillKNN(naive, singleMode, epsilon);
       break;
   }
 
diff --git a/src/mlpack/methods/neighbor_search/spill_search.hpp b/src/mlpack/methods/neighbor_search/spill_search.hpp
deleted file mode 100644
index 20053a0..0000000
--- a/src/mlpack/methods/neighbor_search/spill_search.hpp
+++ /dev/null
@@ -1,323 +0,0 @@
-/**
- * @file spill_search.hpp
- * @author Ryan Curtin
- * @author Marcos Pividori
- *
- * Defines the SpillSearch class, which performs a Hybrid sp-tree search on
- * two datasets.
- */
-#ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_SPILL_SEARCH_HPP
-#define MLPACK_METHODS_NEIGHBOR_SEARCH_SPILL_SEARCH_HPP
-
-#include <mlpack/core.hpp>
-#include <mlpack/core/metrics/lmetric.hpp>
-#include "sort_policies/nearest_neighbor_sort.hpp"
-#include "neighbor_search.hpp"
-
-namespace mlpack {
-namespace neighbor {
-
-// Forward declaration.
-template<typename SortPolicy>
-class TrainVisitor;
-
-/**
- * The SpillSearch class is a template class for performing distance-based
- * neighbor searches with Spill Trees.  It takes a query dataset and a reference
- * dataset (or just a reference dataset) and, for each point in the query
- * dataset, finds the k neighbors in the reference dataset which have the 'best'
- * distance according to a given sorting policy.  A constructor is given which
- * takes only a reference dataset, and if that constructor is used, the given
- * reference dataset is also used as the query dataset.
- *
- * @tparam MetricType The metric to use for computation.
- * @tparam MatType The type of data matrix.
- * @tparam SplitType The class that partitions the dataset/points at a
- *     particular node into two parts. Its definition decides the way this split
- *     is done when building spill trees.
- */
-template<typename MetricType = mlpack::metric::EuclideanDistance,
-         typename MatType = arma::mat,
-         template<typename HyperplaneMetricType>
-             class HyperplaneType = tree::AxisOrthogonalHyperplane,
-         template<typename SplitBoundT, typename SplitMatT> class SplitType =
-             tree::MidpointSpaceSplit>
-class SpillSearch
-{
- public:
-  //! Convenience typedef.
-  typedef tree::SpillTree<MetricType, NeighborSearchStat<NearestNeighborSort>,
-      MatType, HyperplaneType, SplitType> Tree;
-
-  template<typename TreeMetricType,
-           typename TreeStatType,
-           typename TreeMatType>
-  using TreeType = tree::SpillTree<TreeMetricType, TreeStatType, TreeMatType,
-      HyperplaneType, SplitType>;
-
-  /**
-   * Initialize the SpillSearch object, passing a reference dataset (this is
-   * the dataset which is searched).  Optionally, perform the computation in
-   * naive mode or single-tree mode.  An initialized distance metric can be
-   * given, for cases where the metric has internal data (e.g. the
-   * metric::MahalanobisDistance class).
-   *
-   * @param referenceSet Set of reference points.
-   * @param naive If true, O(n^2) naive search will be used (as opposed to
-   *      dual-tree search).  This overrides singleMode (if it is set to true).
-   * @param singleMode If true, single-tree search will be used (as opposed to
-   *      dual-tree search).
-   * @param tau Overlapping size (non-negative).
-   * @param leafSize Max size of each leaf in the tree.
-   * @param rho Balance threshold (non-negative).
-   * @param epsilon Relative approximate error (non-negative).
-   * @param metric An optional instance of the MetricType class.
-   */
-  SpillSearch(const MatType& referenceSet,
-              const bool naive = false,
-              const bool singleMode = false,
-              const double tau = 0,
-              const double leafSize = 20,
-              const double rho = 0.7,
-              const double epsilon = 0,
-              const MetricType metric = MetricType());
-
-  /**
-   * Initialize the SpillSearch object, taking ownership of the reference
-   * dataset (this is the dataset which is searched).  Optionally, perform the
-   * computation in naive mode or single-tree mode.  An initialized distance
-   * metric can be given, for cases where the metric has internal data (e.g. the
-   * metric::MahalanobisDistance class).
-   *
-   * @param referenceSet Set of reference points.
-   * @param naive If true, O(n^2) naive search will be used (as opposed to
-   *      dual-tree search).  This overrides singleMode (if it is set to true).
-   * @param singleMode If true, single-tree search will be used (as opposed to
-   *      dual-tree search).
-   * @param tau Overlapping size (non-negative).
-   * @param leafSize Max size of each leaf in the tree.
-   * @param rho Balance threshold (non-negative).
-   * @param epsilon Relative approximate error (non-negative).
-   * @param metric An optional instance of the MetricType class.
-   */
-  SpillSearch(MatType&& referenceSet,
-              const bool naive = false,
-              const bool singleMode = false,
-              const double tau = 0,
-              const double leafSize = 20,
-              const double rho = 0.7,
-              const double epsilon = 0,
-              const MetricType metric = MetricType());
-
-  /**
-   * Initialize the SpillSearch object with the given pre-constructed
-   * reference tree (this is the tree built on the points that will be
-   * searched).  Optionally, choose to use single-tree mode.  Naive mode is not
-   * available as an option for this constructor.  Additionally, an instantiated
-   * distance metric can be given, for cases where the distance metric holds
-   * data.
-   *
-   * @param referenceTree Pre-built tree for reference points.
-   * @param singleMode Whether single-tree computation should be used (as
-   *      opposed to dual-tree computation).
-   * @param tau Overlapping size (non-negative).
-   * @param leafSize Max size of each leaf in the tree.
-   * @param rho Balance threshold (non-negative).
-   * @param epsilon Relative approximate error (non-negative).
-   * @param metric Instantiated distance metric.
-   */
-  SpillSearch(Tree* referenceTree,
-              const bool singleMode = false,
-              const double tau = 0,
-              const double leafSize = 20,
-              const double rho = 0.7,
-              const double epsilon = 0,
-              const MetricType metric = MetricType());
-
-  /**
-   * Create a SpillSearch object without any reference data.  If Search() is
-   * called before a reference set is set with Train(), an exception will be
-   * thrown.
-   *
-   * @param naive Whether to use naive search.
-   * @param singleMode Whether single-tree computation should be used (as
-   *      opposed to dual-tree computation).
-   * @param tau Overlapping size (non-negative).
-   * @param leafSize Max size of each leaf in the tree.
-   * @param rho Balance threshold (non-negative).
-   * @param epsilon Relative approximate error (non-negative).
-   * @param metric Instantiated metric.
-   */
-  SpillSearch(const bool naive = false,
-              const bool singleMode = false,
-              const double tau = 0,
-              const double leafSize = 20,
-              const double rho = 0.7,
-              const double epsilon = 0,
-              const MetricType metric = MetricType());
-
-
-  /**
-   * Delete the SpillSearch object.  There is nothing to free here: any tree
-   * built by Train() is owned by the internal NeighborSearch object.
-   */
-  ~SpillSearch();
-
-  /**
-   * Set the reference set to a new reference set, and build a tree if
-   * necessary.  This method is called 'Train()' in order to match the rest of
-   * the mlpack abstractions, even though calling this "training" is maybe a bit
-   * of a stretch.
-   *
-   * @param referenceSet New set of reference data.
-   */
-  void Train(const MatType& referenceSet);
-
-  /**
-   * Set the reference set to a new reference set, taking ownership of the set,
-   * and build a tree if necessary.  This method is called 'Train()' in order to
-   * match the rest of the mlpack abstractions, even though calling this
-   * "training" is maybe a bit of a stretch.
-   *
-   * @param referenceSet New set of reference data.
-   */
-  void Train(MatType&& referenceSet);
-
-  /**
-   * Set the reference tree to a new reference tree.
-   *
-   * @param referenceTree Pre-built tree for reference points.
-   */
-  void Train(Tree* referenceTree);
-
-  /**
-   * For each point in the query set, compute the nearest neighbors and store
-   * the output in the given matrices.  The matrices will be set to the size of
-   * n columns by k rows, where n is the number of points in the query dataset
-   * and k is the number of neighbors being searched for.
-   *
-   * If querySet contains only a few query points, the extra cost of building a
-   * tree on the points for dual-tree search may not be warranted, and it may be
-   * worthwhile to set singleMode = true (either in the constructor or with
-   * SingleMode()).
-   *
-   * @param querySet Set of query points (can be just one point).
-   * @param k Number of neighbors to search for.
-   * @param neighbors Matrix storing lists of neighbors for each query point.
-   * @param distances Matrix storing distances of neighbors for each query
-   *     point.
-   */
-  void Search(const MatType& querySet,
-              const size_t k,
-              arma::Mat<size_t>& neighbors,
-              arma::mat& distances);
-
-  /**
-   * Given a pre-built query tree, search for the nearest neighbors of each
-   * point in the query tree, storing the output in the given matrices.  The
-   * matrices will be set to the size of n columns by k rows, where n is the
-   * number of points in the query dataset and k is the number of neighbors
-   * being searched for.
-   *
-   * Note that if you are calling Search() multiple times with a single query
-   * tree, you need to reset the bounds in the statistic of each query node,
-   * otherwise the result may be wrong!  You can do this by calling
-   * TreeType::Stat()::Reset() on each node in the query tree.
-   *
-   * @param queryTree Tree built on query points.
-   * @param k Number of neighbors to search for.
-   * @param neighbors Matrix storing lists of neighbors for each query point.
-   * @param distances Matrix storing distances of neighbors for each query
-   *      point.
-   */
-  void Search(Tree* queryTree,
-              const size_t k,
-              arma::Mat<size_t>& neighbors,
-              arma::mat& distances);
-
-  /**
-   * Search for the nearest neighbors of every point in the reference set.  This
-   * is basically equivalent to calling any other overload of Search() with the
-   * reference set as the query set; so, this lets you do
-   * all-k-nearest-neighbors search.  The results are stored in the given
-   * matrices.  The matrices will be set to the size of n columns by k rows,
-   * where n is the number of points in the query dataset and k is the number of
-   * neighbors being searched for.
-   *
-   * @param k Number of neighbors to search for.
-   * @param neighbors Matrix storing lists of neighbors for each query point.
-   * @param distances Matrix storing distances of neighbors for each query
-   *      point.
-   */
-  void Search(const size_t k,
-              arma::Mat<size_t>& neighbors,
-              arma::mat& distances);
-
-  //! Return the total number of base case evaluations performed during the last
-  //! search.
-  size_t BaseCases() const { return neighborSearch.BaseCases(); }
-
-  //! Return the number of node combination scores during the last search.
-  size_t Scores() const { return neighborSearch.Scores(); }
-
-  //! Access whether or not search is done in naive linear scan mode.
-  bool Naive() const { return neighborSearch.Naive(); }
-  //! Modify whether or not search is done in naive linear scan mode.
-  bool& Naive() { return neighborSearch.Naive(); }
-
-  //! Access whether or not search is done in single-tree mode.
-  bool SingleMode() const { return neighborSearch.SingleMode(); }
-  //! Modify whether or not search is done in single-tree mode.
-  bool& SingleMode() { return neighborSearch.SingleMode(); }
-
-  //! Access the relative error to be considered in approximate search.
-  double Epsilon() const { return neighborSearch.Epsilon(); }
-  //! Modify the relative error to be considered in approximate search.
-  double& Epsilon() { return neighborSearch.Epsilon(); }
-
-  //! Access the overlapping size.
-  double Tau() const { return tau; }
-
-  //! Access the balance threshold.
-  double Rho() const { return rho; }
-
-  //! Access the leaf size.
-  double LeafSize() const { return leafSize; }
-
-  //! Access the reference dataset.
-  const MatType& ReferenceSet() const { return neighborSearch.ReferenceSet(); }
-
-  //! Serialize the SpillSearch model.
-  template<typename Archive>
-  void Serialize(Archive& ar, const unsigned int /* version */);
-
- private:
-  //! Internal instance of NeighborSearch class.
-  NeighborSearch<NearestNeighborSort,
-                 MetricType,
-                 MatType,
-                 TreeType,
-                 Tree::template DefeatistDualTreeTraverser,
-                 Tree::template DefeatistSingleTreeTraverser> neighborSearch;
-
-  //! Overlapping size.
-  double tau;
-
-  //! Balance threshold.
-  double rho;
-
-  //! Max leaf size.
-  double leafSize;
-
-  //! The NSModel class should have access to internal members.
-  template<typename SortPolicy>
-  friend class TrainVisitor;
-}; // class SpillSearch
-
-} // namespace neighbor
-} // namespace mlpack
-
-// Include implementation.
-#include "spill_search_impl.hpp"
-
-#endif
diff --git a/src/mlpack/methods/neighbor_search/spill_search_impl.hpp b/src/mlpack/methods/neighbor_search/spill_search_impl.hpp
deleted file mode 100644
index 5638959..0000000
--- a/src/mlpack/methods/neighbor_search/spill_search_impl.hpp
+++ /dev/null
@@ -1,251 +0,0 @@
-/**
- * @file spill_search_impl.hpp
- * @author Ryan Curtin
- * @author Marcos Pividori
- *
- * Implementation of SpillSearch class, which performs a Hybrid sp-tree search
- * on two datasets.
- */
-#ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_SPILL_SEARCH_IMPL_HPP
-#define MLPACK_METHODS_NEIGHBOR_SEARCH_SPILL_SEARCH_IMPL_HPP
-
-// In case it hasn't been included yet.
-#include "spill_search.hpp"
-
-namespace mlpack {
-namespace neighbor {
-
-// Construct the object.
-template<typename MetricType,
-         typename MatType,
-         template<typename HyperplaneMetricType> class HyperplaneType,
-         template<typename SplitBoundT, typename SplitMatT> class SplitType>
-SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::SpillSearch(
-    const MatType& referenceSetIn,
-    const bool naive,
-    const bool singleMode,
-    const double tau,
-    const double leafSize,
-    const double rho,
-    const double epsilon,
-    const MetricType metric) :
-    neighborSearch(naive, singleMode, epsilon, metric),
-    tau(tau),
-    rho(rho),
-    leafSize(leafSize)
-{
-  if (tau < 0)
-    throw std::invalid_argument("tau must be non-negative");
-  if (rho < 0 || rho > 1)
-    throw std::invalid_argument("rho must be in the range [0,1]");
-  Train(referenceSetIn);
-}
-
-// Construct the object.
-template<typename MetricType,
-         typename MatType,
-         template<typename HyperplaneMetricType> class HyperplaneType,
-         template<typename SplitBoundT, typename SplitMatT> class SplitType>
-SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::SpillSearch(
-    MatType&& referenceSetIn,
-    const bool naive,
-    const bool singleMode,
-    const double tau,
-    const double leafSize,
-    const double rho,
-    const double epsilon,
-    const MetricType metric) :
-    neighborSearch(naive, singleMode, epsilon, metric),
-    tau(tau),
-    rho(rho),
-    leafSize(leafSize)
-{
-  if (tau < 0)
-    throw std::invalid_argument("tau must be non-negative");
-  if (rho < 0 || rho > 1)
-    throw std::invalid_argument("rho must be in the range [0,1]");
-  Train(std::move(referenceSetIn));
-}
-
-// Construct the object.
-template<typename MetricType,
-         typename MatType,
-         template<typename HyperplaneMetricType> class HyperplaneType,
-         template<typename SplitBoundT, typename SplitMatT> class SplitType>
-SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::SpillSearch(
-    Tree* referenceTree,
-    const bool singleMode,
-    const double tau,
-    const double leafSize,
-    const double rho,
-    const double epsilon,
-    const MetricType metric) :
-    neighborSearch(singleMode, epsilon, metric),
-    tau(tau),
-    rho(rho),
-    leafSize(leafSize)
-{
-  if (tau < 0)
-    throw std::invalid_argument("tau must be non-negative");
-  if (rho < 0 || rho > 1)
-    throw std::invalid_argument("rho must be in the range [0,1]");
-  Train(referenceTree);
-}
-
-// Construct the object without a reference dataset.
-template<typename MetricType,
-         typename MatType,
-         template<typename HyperplaneMetricType> class HyperplaneType,
-         template<typename SplitBoundT, typename SplitMatT> class SplitType>
-SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::SpillSearch(
-    const bool naive,
-    const bool singleMode,
-    const double tau,
-    const double leafSize,
-    const double rho,
-    const double epsilon,
-    const MetricType metric) :
-    neighborSearch(naive, singleMode, epsilon, metric),
-    tau(tau),
-    rho(rho),
-    leafSize(leafSize)
-{
-  if (tau < 0)
-    throw std::invalid_argument("tau must be non-negative");
-  if (rho < 0 || rho > 1)
-    throw std::invalid_argument("rho must be in the range [0,1]");
-}
-
-// Clean memory.
-template<typename MetricType,
-         typename MatType,
-         template<typename HyperplaneMetricType> class HyperplaneType,
-         template<typename SplitBoundT, typename SplitMatT> class SplitType>
-SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
-~SpillSearch()
-{
-  /* Nothing to do */
-}
-
-template<typename MetricType,
-         typename MatType,
-         template<typename HyperplaneMetricType> class HyperplaneType,
-         template<typename SplitBoundT, typename SplitMatT> class SplitType>
-void SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
-Train(const MatType& referenceSet)
-{
-  if (Naive())
-    neighborSearch.Train(referenceSet);
-  else
-  {
-    // Build reference tree with proper value for tau.
-    Tree* tree = new Tree(referenceSet, tau, leafSize, rho);
-    neighborSearch.Train(tree);
-    // Give the model ownership of the tree.
-    neighborSearch.treeOwner = true;
-  }
-}
-
-template<typename MetricType,
-         typename MatType,
-         template<typename HyperplaneMetricType> class HyperplaneType,
-         template<typename SplitBoundT, typename SplitMatT> class SplitType>
-void SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
-Train(MatType&& referenceSetIn)
-{
-  if (Naive())
-    neighborSearch.Train(std::move(referenceSetIn));
-  else
-  {
-    // Build reference tree with proper value for tau.
-    Tree* tree = new Tree(std::move(referenceSetIn), tau, leafSize, rho);
-    neighborSearch.Train(tree);
-    // Give the model ownership of the tree.
-    neighborSearch.treeOwner = true;
-  }
-}
-
-template<typename MetricType,
-         typename MatType,
-         template<typename HyperplaneMetricType> class HyperplaneType,
-         template<typename SplitBoundT, typename SplitMatT> class SplitType>
-void SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
-Train(Tree* referenceTree)
-{
-  neighborSearch.Train(referenceTree);
-}
-
-template<typename MetricType,
-         typename MatType,
-         template<typename HyperplaneMetricType> class HyperplaneType,
-         template<typename SplitBoundT, typename SplitMatT> class SplitType>
-void SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
-Search(const MatType& querySet,
-       const size_t k,
-       arma::Mat<size_t>& neighbors,
-       arma::mat& distances)
-{
-  if (Naive() || SingleMode())
-    neighborSearch.Search(querySet, k, neighbors, distances);
-  else
-  {
-    // For Dual Tree Search on SpillTrees, the queryTree must be built with non
-    // overlapping (tau = 0).
-    Tree queryTree(querySet, 0 /* tau */, leafSize, rho);
-    neighborSearch.Search(&queryTree, k, neighbors, distances);
-  }
-}
-
-template<typename MetricType,
-         typename MatType,
-         template<typename HyperplaneMetricType> class HyperplaneType,
-         template<typename SplitBoundT, typename SplitMatT> class SplitType>
-void SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
-Search(Tree* queryTree,
-       const size_t k,
-       arma::Mat<size_t>& neighbors,
-       arma::mat& distances)
-{
-  neighborSearch.Search(queryTree, k, neighbors, distances);
-}
-
-template<typename MetricType,
-         typename MatType,
-         template<typename HyperplaneMetricType> class HyperplaneType,
-         template<typename SplitBoundT, typename SplitMatT> class SplitType>
-void SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
-Search(const size_t k,
-       arma::Mat<size_t>& neighbors,
-       arma::mat& distances)
-{
-  if (tau == 0 || Naive() || SingleMode())
-    neighborSearch.Search(k, neighbors, distances);
-  else
-  {
-    // For Dual Tree Search on SpillTrees, the queryTree must be built with non
-    // overlapping (tau = 0). If the referenceTree was built with a non-zero
-    // value for tau, we need to build a new queryTree.
-    Tree queryTree(ReferenceSet(), 0 /* tau */, leafSize, rho);
-    neighborSearch.Search(&queryTree, k, neighbors, distances, true);
-  }
-}
-
-//! Serialize SpillSearch.
-template<typename MetricType,
-         typename MatType,
-         template<typename HyperplaneMetricType> class HyperplaneType,
-         template<typename SplitBoundT, typename SplitMatT> class SplitType>
-template<typename Archive>
-void SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
-    Serialize(Archive& ar, const unsigned int /* version */)
-{
-  ar & data::CreateNVP(neighborSearch, "neighborSearch");
-  ar & data::CreateNVP(tau, "tau");
-  ar & data::CreateNVP(rho, "rho");
-  ar & data::CreateNVP(leafSize, "leafSize");
-}
-
-} // namespace neighbor
-} // namespace mlpack
-
-#endif
diff --git a/src/mlpack/methods/neighbor_search/typedef.hpp b/src/mlpack/methods/neighbor_search/typedef.hpp
index 1059e20..5c12abe 100644
--- a/src/mlpack/methods/neighbor_search/typedef.hpp
+++ b/src/mlpack/methods/neighbor_search/typedef.hpp
@@ -33,6 +33,35 @@ typedef NeighborSearch<NearestNeighborSort, metric::EuclideanDistance> KNN;
 typedef NeighborSearch<FurthestNeighborSort, metric::EuclideanDistance> KFN;
 
 /**
+ * The DefeatistKNN class is the k-nearest-neighbors method considering
+ * defeatist search. It returns L2 distances (Euclidean distances) for each of
+ * the k nearest neighbors found.
+ * @tparam TreeType The tree type to use; must adhere to the TreeType API,
+ *     and implement Defeatist Traversers.
+ */
+template<template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType = tree::SPTree>
+using DefeatistKNN = NeighborSearch<
+    NearestNeighborSort,
+    metric::EuclideanDistance,
+    arma::mat,
+    TreeType,
+    TreeType<metric::EuclideanDistance,
+        NeighborSearchStat<NearestNeighborSort>,
+        arma::mat>::template DefeatistDualTreeTraverser,
+    TreeType<metric::EuclideanDistance,
+        NeighborSearchStat<NearestNeighborSort>,
+        arma::mat>::template DefeatistSingleTreeTraverser>;
+
+/**
+ * The SpillKNN class is the k-nearest-neighbors method considering defeatist
+ * search on SPTree.  It returns L2 distances (Euclidean distances) for each of
+ * the k nearest neighbors found.
+ */
+typedef DefeatistKNN<tree::SPTree> SpillKNN;
+
+/**
  * @deprecated
  * The AllkNN class is the k-nearest-neighbors method.  It returns L2 distances
  * (Euclidean distances) for each of the k nearest neighbors.  This typedef will




More information about the mlpack-git mailing list