[mlpack-git] master: Remove SpillSearch class. Use NeighborSearch class and typedefs for defeatist search. (5795c74)
gitdub at mlpack.org
gitdub at mlpack.org
Thu Aug 18 10:46:01 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/0f4b25acd6aaa14294c044874ba6cc0751712baa...0a19d07bd39e6223991976474bc79671ba8aa0f0
>---------------------------------------------------------------
commit 5795c7468ec78962390ca378a721c2f7a26497fe
Author: MarcosPividori <marcos.pividori at gmail.com>
Date: Thu Aug 18 11:46:01 2016 -0300
Remove SpillSearch class. Use NeighborSearch class and typedefs for defeatist search.
>---------------------------------------------------------------
5795c7468ec78962390ca378a721c2f7a26497fe
src/mlpack/methods/neighbor_search/CMakeLists.txt | 2 -
.../methods/neighbor_search/neighbor_search.hpp | 13 -
.../neighbor_search/neighbor_search_impl.hpp | 18 +-
src/mlpack/methods/neighbor_search/ns_model.hpp | 12 +-
.../methods/neighbor_search/ns_model_impl.hpp | 23 +-
.../methods/neighbor_search/spill_search.hpp | 323 ---------------------
.../methods/neighbor_search/spill_search_impl.hpp | 251 ----------------
src/mlpack/methods/neighbor_search/typedef.hpp | 29 ++
8 files changed, 52 insertions(+), 619 deletions(-)
diff --git a/src/mlpack/methods/neighbor_search/CMakeLists.txt b/src/mlpack/methods/neighbor_search/CMakeLists.txt
index e66b5a7..1c51ce4 100644
--- a/src/mlpack/methods/neighbor_search/CMakeLists.txt
+++ b/src/mlpack/methods/neighbor_search/CMakeLists.txt
@@ -12,8 +12,6 @@ set(SOURCES
sort_policies/nearest_neighbor_sort_impl.hpp
sort_policies/furthest_neighbor_sort.hpp
sort_policies/furthest_neighbor_sort_impl.hpp
- spill_search.hpp
- spill_search_impl.hpp
typedef.hpp
unmap.hpp
unmap.cpp
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index 9b5b3be..71e064c 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -27,13 +27,6 @@ namespace neighbor /** Neighbor-search routines. These include
* searches. */ {
// Forward declaration.
-template<typename MetricType,
- typename MatType,
- template<typename HyperplaneMetricType> class HyperplaneType,
- template<typename SplitBoundT, typename SplitMatT> class SplitType>
-class SpillSearch;
-
-// Forward declaration.
template<typename SortPolicy>
class TrainVisitor;
@@ -371,12 +364,6 @@ class NeighborSearch
//! The NSModel class should have access to internal members.
template<typename SortPol>
friend class TrainVisitor;
-
- template<typename MetricT,
- typename MatT,
- template<typename HyperplaneMetricType> class HyperplaneType,
- template<typename SplitBoundT, typename SplitMatT> class SplitType>
- friend class SpillSearch;
}; // class NeighborSearch
} // namespace neighbor
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index 8968e63..3cbeaed 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -11,6 +11,7 @@
#include <mlpack/core.hpp>
#include "neighbor_search_rules.hpp"
+#include <mlpack/core/tree/spill_tree/is_spill_tree.hpp>
namespace mlpack {
namespace neighbor {
@@ -686,7 +687,19 @@ SingleTreeTraversalType>::Search(const size_t k,
// Create the traverser.
TraversalType<RuleType> traverser(rules);
- traverser.Traverse(*referenceTree, *referenceTree);
+ if (tree::IsSpillTree<Tree>::value)
+ {
+ // For Dual Tree Search on SpillTree, the queryTree must be built with non
+ // overlapping (tau = 0).
+ Tree queryTree(*referenceSet);
+ traverser.Traverse(queryTree, *referenceTree);
+ }
+ else
+ {
+ traverser.Traverse(*referenceTree, *referenceTree);
+ // Next time we perform this search, we'll need to reset the tree.
+ treeNeedsReset = true;
+ }
scores += rules.Scores();
baseCases += rules.BaseCases();
@@ -695,9 +708,6 @@ SingleTreeTraversalType>::Search(const size_t k,
<< std::endl;
Log::Info << rules.BaseCases() << " base cases were calculated."
<< std::endl;
-
- // Next time we perform this search, we'll need to reset the tree.
- treeNeedsReset = true;
}
rules.GetResults(*neighborPtr, *distancePtr);
diff --git a/src/mlpack/methods/neighbor_search/ns_model.hpp b/src/mlpack/methods/neighbor_search/ns_model.hpp
index e1606e9..bafccc1 100644
--- a/src/mlpack/methods/neighbor_search/ns_model.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model.hpp
@@ -16,7 +16,6 @@
#include <mlpack/core/tree/spill_tree.hpp>
#include <boost/variant.hpp>
#include "neighbor_search.hpp"
-#include "spill_search.hpp"
namespace mlpack {
namespace neighbor {
@@ -36,11 +35,6 @@ using NSType = NeighborSearch<SortPolicy,
NeighborSearchStat<SortPolicy>,
arma::mat>::template DualTreeTraverser>;
-/**
- * Alias template for euclidean spill search.
- */
-using NSSpillType = SpillSearch<metric::EuclideanDistance, arma::mat>;
-
template<typename SortPolicy>
struct NSModelName
{
@@ -137,7 +131,7 @@ class BiSearchVisitor : public boost::static_visitor<void>
void operator()(NSTypeT<tree::BallTree>* ns) const;
//! Bichromatic neighbor search specialized for SPTrees.
- void operator()(NSSpillType* ns) const;
+ void operator()(SpillKNN* ns) const;
//! Construct the BiSearchVisitor.
BiSearchVisitor(const arma::mat& querySet,
@@ -192,7 +186,7 @@ class TrainVisitor : public boost::static_visitor<void>
void operator()(NSTypeT<tree::BallTree>* ns) const;
//! Train specialized for SPTrees.
- void operator()(NSSpillType* ns) const;
+ void operator()(SpillKNN* ns) const;
//! Construct the TrainVisitor object with the given reference set, leafSize
//! for BinarySpaceTrees, and tau and rho for spill trees.
@@ -319,7 +313,7 @@ class NSModel
NSType<SortPolicy, tree::RPlusTree>*,
NSType<SortPolicy, tree::RPlusPlusTree>*,
NSType<SortPolicy, tree::VPTree>*,
- NSSpillType*> nSearch;
+ SpillKNN*> nSearch;
public:
/**
diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
index e0d152b..3fb2b2e 100644
--- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
@@ -77,7 +77,7 @@ void BiSearchVisitor<SortPolicy>::operator()(NSTypeT<tree::BallTree>* ns) const
//! Bichromatic neighbor search specialized for SPTrees.
template<typename SortPolicy>
-void BiSearchVisitor<SortPolicy>::operator()(NSSpillType* ns) const
+void BiSearchVisitor<SortPolicy>::operator()(SpillKNN* ns) const
{
if (ns)
{
@@ -85,7 +85,7 @@ void BiSearchVisitor<SortPolicy>::operator()(NSSpillType* ns) const
{
// For Dual Tree Search on SpillTrees, the queryTree must be built with
// non overlapping (tau = 0).
- typename NSSpillType::Tree queryTree(std::move(querySet), 0 /* tau*/,
+ typename SpillKNN::Tree queryTree(std::move(querySet), 0 /* tau*/,
leafSize, rho);
ns->Search(&queryTree, k, neighbors, distances);
}
@@ -168,7 +168,7 @@ void TrainVisitor<SortPolicy>::operator ()(NSTypeT<tree::BallTree>* ns) const
//! Train specialized for SPTrees.
template<typename SortPolicy>
-void TrainVisitor<SortPolicy>::operator ()(NSSpillType* ns) const
+void TrainVisitor<SortPolicy>::operator ()(SpillKNN* ns) const
{
if (ns)
{
@@ -176,11 +176,11 @@ void TrainVisitor<SortPolicy>::operator ()(NSSpillType* ns) const
ns->Train(std::move(referenceSet));
else
{
- typename NSSpillType::Tree* tree = new typename NSSpillType::Tree(
+ typename SpillKNN::Tree* tree = new typename SpillKNN::Tree(
std::move(referenceSet), tau, leafSize, rho);
ns->Train(tree);
// Give the model ownership of the tree.
- ns->neighborSearch.treeOwner = true;
+ ns->treeOwner = true;
}
}
else
@@ -299,17 +299,6 @@ void serialize(
ns.Serialize(ar, version);
}
-/**
- * Non-intrusive serialization for SpillSearch class. We need this definition
- * because we are going to use the serialize function for boost variant, which
- * will look for a serialize function for its member types.
- */
-template<typename Archive>
-void serialize(Archive& ar, NSSpillType& ns, const unsigned int version)
-{
- ns.Serialize(ar, version);
-}
-
//! Serialize the kNN model.
template<typename SortPolicy>
template<typename Archive>
@@ -475,7 +464,7 @@ void NSModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
epsilon);
break;
case SPILL_TREE:
- nSearch = new NSSpillType(naive, singleMode, tau, leafSize, rho, epsilon);
+ nSearch = new SpillKNN(naive, singleMode, epsilon);
break;
}
diff --git a/src/mlpack/methods/neighbor_search/spill_search.hpp b/src/mlpack/methods/neighbor_search/spill_search.hpp
deleted file mode 100644
index 20053a0..0000000
--- a/src/mlpack/methods/neighbor_search/spill_search.hpp
+++ /dev/null
@@ -1,323 +0,0 @@
-/**
- * @file spill_search.hpp
- * @author Ryan Curtin
- * @author Marcos Pividori
- *
- * Defines the SpillSearch class, which performs a Hybrid sp-tree search on
- * two datasets.
- */
-#ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_SPILL_SEARCH_HPP
-#define MLPACK_METHODS_NEIGHBOR_SEARCH_SPILL_SEARCH_HPP
-
-#include <mlpack/core.hpp>
-#include <mlpack/core/metrics/lmetric.hpp>
-#include "sort_policies/nearest_neighbor_sort.hpp"
-#include "neighbor_search.hpp"
-
-namespace mlpack {
-namespace neighbor {
-
-// Forward declaration.
-template<typename SortPolicy>
-class TrainVisitor;
-
-/**
- * The SpillSearch class is a template class for performing distance-based
- * neighbor searches with Spill Trees. It takes a query dataset and a reference
- * dataset (or just a reference dataset) and, for each point in the query
- * dataset, finds the k neighbors in the reference dataset which have the 'best'
- * distance according to a given sorting policy. A constructor is given which
- * takes only a reference dataset, and if that constructor is used, the given
- * reference dataset is also used as the query dataset.
- *
- * @tparam MetricType The metric to use for computation.
- * @tparam MatType The type of data matrix.
- * @tparam SplitType The class that partitions the dataset/points at a
- * particular node into two parts. Its definition decides the way this split
- * is done when building spill trees.
- */
-template<typename MetricType = mlpack::metric::EuclideanDistance,
- typename MatType = arma::mat,
- template<typename HyperplaneMetricType>
- class HyperplaneType = tree::AxisOrthogonalHyperplane,
- template<typename SplitBoundT, typename SplitMatT> class SplitType =
- tree::MidpointSpaceSplit>
-class SpillSearch
-{
- public:
- //! Convenience typedef.
- typedef tree::SpillTree<MetricType, NeighborSearchStat<NearestNeighborSort>,
- MatType, HyperplaneType, SplitType> Tree;
-
- template<typename TreeMetricType,
- typename TreeStatType,
- typename TreeMatType>
- using TreeType = tree::SpillTree<TreeMetricType, TreeStatType, TreeMatType,
- HyperplaneType, SplitType>;
-
- /**
- * Initialize the SpillSearch object, passing a reference dataset (this is
- * the dataset which is searched). Optionally, perform the computation in
- * naive mode or single-tree mode. An initialized distance metric can be
- * given, for cases where the metric has internal data (i.e. the
- * distance::MahalanobisDistance class).
- *
- * @param referenceSet Set of reference points.
- * @param naive If true, O(n^2) naive search will be used (as opposed to
- * dual-tree search). This overrides singleMode (if it is set to true).
- * @param singleMode If true, single-tree search will be used (as opposed to
- * dual-tree search).
- * @param tau Overlapping size (non-negative).
- * @param leafSize Max size of each leaf in the tree.
- * @param rho Balance threshold (non-negative).
- * @param epsilon Relative approximate error (non-negative).
- * @param metric An optional instance of the MetricType class.
- */
- SpillSearch(const MatType& referenceSet,
- const bool naive = false,
- const bool singleMode = false,
- const double tau = 0,
- const double leafSize = 20,
- const double rho = 0.7,
- const double epsilon = 0,
- const MetricType metric = MetricType());
-
- /**
- * Initialize the SpillSearch object, taking ownership of the reference
- * dataset (this is the dataset which is searched). Optionally, perform the
- * computation in naive mode or single-tree mode. An initialized distance
- * metric can be given, for cases where the metric has internal data (i.e. the
- * distance::MahalanobisDistance class).
- *
- * @param referenceSet Set of reference points.
- * @param naive If true, O(n^2) naive search will be used (as opposed to
- * dual-tree search). This overrides singleMode (if it is set to true).
- * @param singleMode If true, single-tree search will be used (as opposed to
- * dual-tree search).
- * @param tau Overlapping size (non-negative).
- * @param leafSize Max size of each leaf in the tree.
- * @param rho Balance threshold (non-negative).
- * @param epsilon Relative approximate error (non-negative).
- * @param metric An optional instance of the MetricType class.
- */
- SpillSearch(MatType&& referenceSet,
- const bool naive = false,
- const bool singleMode = false,
- const double tau = 0,
- const double leafSize = 20,
- const double rho = 0.7,
- const double epsilon = 0,
- const MetricType metric = MetricType());
-
- /**
- * Initialize the SpillSearch object with the given pre-constructed
- * reference tree (this is the tree built on the points that will be
- * searched). Optionally, choose to use single-tree mode. Naive mode is not
- * available as an option for this constructor. Additionally, an instantiated
- * distance metric can be given, for cases where the distance metric holds
- * data.
- *
- * @param referenceTree Pre-built tree for reference points.
- * @param singleMode Whether single-tree computation should be used (as
- * opposed to dual-tree computation).
- * @param tau Overlapping size (non-negative).
- * @param leafSize Max size of each leaf in the tree.
- * @param rho Balance threshold (non-negative).
- * @param epsilon Relative approximate error (non-negative).
- * @param metric Instantiated distance metric.
- */
- SpillSearch(Tree* referenceTree,
- const bool singleMode = false,
- const double tau = 0,
- const double leafSize = 20,
- const double rho = 0.7,
- const double epsilon = 0,
- const MetricType metric = MetricType());
-
- /**
- * Create a SpillSearch object without any reference data. If Search() is
- * called before a reference set is set with Train(), an exception will be
- * thrown.
- *
- * @param naive Whether to use naive search.
- * @param singleMode Whether single-tree computation should be used (as
- * opposed to dual-tree computation).
- * @param tau Overlapping size (non-negative).
- * @param leafSize Max size of each leaf in the tree.
- * @param rho Balance threshold (non-negative).
- * @param epsilon Relative approximate error (non-negative).
- * @param metric Instantiated metric.
- */
- SpillSearch(const bool naive = false,
- const bool singleMode = false,
- const double tau = 0,
- const double leafSize = 20,
- const double rho = 0.7,
- const double epsilon = 0,
- const MetricType metric = MetricType());
-
-
- /**
- * Delete the SpillSearch object. The tree is the only member we are
- * responsible for deleting. The others will take care of themselves.
- */
- ~SpillSearch();
-
- /**
- * Set the reference set to a new reference set, and build a tree if
- * necessary. This method is called 'Train()' in order to match the rest of
- * the mlpack abstractions, even though calling this "training" is maybe a bit
- * of a stretch.
- *
- * @param referenceSet New set of reference data.
- */
- void Train(const MatType& referenceSet);
-
- /**
- * Set the reference set to a new reference set, taking ownership of the set,
- * and build a tree if necessary. This method is called 'Train()' in order to
- * match the rest of the mlpack abstractions, even though calling this
- * "training" is maybe a bit of a stretch.
- *
- * @param referenceSet New set of reference data.
- */
- void Train(MatType&& referenceSet);
-
- /**
- * Set the reference tree to a new reference tree.
- *
- * @param referenceTree Pre-built tree for reference points.
- */
- void Train(Tree* referenceTree);
-
- /**
- * For each point in the query set, compute the nearest neighbors and store
- * the output in the given matrices. The matrices will be set to the size of
- * n columns by k rows, where n is the number of points in the query dataset
- * and k is the number of neighbors being searched for.
- *
- * If querySet contains only a few query points, the extra cost of building a
- * tree on the points for dual-tree search may not be warranted, and it may be
- * worthwhile to set singleMode = false (either in the constructor or with
- * SingleMode()).
- *
- * @param querySet Set of query points (can be just one point).
- * @param k Number of neighbors to search for.
- * @param neighbors Matrix storing lists of neighbors for each query point.
- * @param distances Matrix storing distances of neighbors for each query
- * point.
- */
- void Search(const MatType& querySet,
- const size_t k,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances);
-
- /**
- * Given a pre-built query tree, search for the nearest neighbors of each
- * point in the query tree, storing the output in the given matrices. The
- * matrices will be set to the size of n columns by k rows, where n is the
- * number of points in the query dataset and k is the number of neighbors
- * being searched for.
- *
- * Note that if you are calling Search() multiple times with a single query
- * tree, you need to reset the bounds in the statistic of each query node,
- * otherwise the result may be wrong! You can do this by calling
- * TreeType::Stat()::Reset() on each node in the query tree.
- *
- * @param queryTree Tree built on query points.
- * @param k Number of neighbors to search for.
- * @param neighbors Matrix storing lists of neighbors for each query point.
- * @param distances Matrix storing distances of neighbors for each query
- * point.
- */
- void Search(Tree* queryTree,
- const size_t k,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances);
-
- /**
- * Search for the nearest neighbors of every point in the reference set. This
- * is basically equivalent to calling any other overload of Search() with the
- * reference set as the query set; so, this lets you do
- * all-k-nearest-neighbors search. The results are stored in the given
- * matrices. The matrices will be set to the size of n columns by k rows,
- * where n is the number of points in the query dataset and k is the number of
- * neighbors being searched for.
- *
- * @param k Number of neighbors to search for.
- * @param neighbors Matrix storing lists of neighbors for each query point.
- * @param distances Matrix storing distances of neighbors for each query
- * point.
- */
- void Search(const size_t k,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances);
-
- //! Return the total number of base case evaluations performed during the last
- //! search.
- size_t BaseCases() const { return neighborSearch.BaseCases(); }
-
- //! Return the number of node combination scores during the last search.
- size_t Scores() const { return neighborSearch.Scores(); }
-
- //! Access whether or not search is done in naive linear scan mode.
- bool Naive() const { return neighborSearch.Naive(); }
- //! Modify whether or not search is done in naive linear scan mode.
- bool& Naive() { return neighborSearch.Naive(); }
-
- //! Access whether or not search is done in single-tree mode.
- bool SingleMode() const { return neighborSearch.SingleMode(); }
- //! Modify whether or not search is done in single-tree mode.
- bool& SingleMode() { return neighborSearch.SingleMode(); }
-
- //! Access the relative error to be considered in approximate search.
- double Epsilon() const { return neighborSearch.Epsilon(); }
- //! Modify the relative error to be considered in approximate search.
- double& Epsilon() { return neighborSearch.Epsilon(); }
-
- //! Access the overlapping size.
- double Tau() const { return tau; }
-
- //! Access the balance threshold.
- double Rho() const { return rho; }
-
- //! Access the leaf size.
- double LeafSize() const { return leafSize; }
-
- //! Access the reference dataset.
- const MatType& ReferenceSet() const { return neighborSearch.ReferenceSet(); }
-
- //! Serialize the SpillSearch model.
- template<typename Archive>
- void Serialize(Archive& ar, const unsigned int /* version */);
-
- private:
- //! Internal instance of NeighborSearch class.
- NeighborSearch<NearestNeighborSort,
- MetricType,
- MatType,
- TreeType,
- Tree::template DefeatistDualTreeTraverser,
- Tree::template DefeatistSingleTreeTraverser> neighborSearch;
-
- //! Overlapping size.
- double tau;
-
- //! Balance threshold.
- double rho;
-
- //! Max leaf size.
- double leafSize;
-
- //! The NSModel class should have access to internal members.
- template<typename SortPolicy>
- friend class TrainVisitor;
-}; // class SpillSearch
-
-} // namespace neighbor
-} // namespace mlpack
-
-// Include implementation.
-#include "spill_search_impl.hpp"
-
-#endif
diff --git a/src/mlpack/methods/neighbor_search/spill_search_impl.hpp b/src/mlpack/methods/neighbor_search/spill_search_impl.hpp
deleted file mode 100644
index 5638959..0000000
--- a/src/mlpack/methods/neighbor_search/spill_search_impl.hpp
+++ /dev/null
@@ -1,251 +0,0 @@
-/**
- * @file spill_search_impl.hpp
- * @author Ryan Curtin
- * @author Marcos Pividori
- *
- * Implementation of SpillSearch class, which performs a Hybrid sp-tree search
- * on two datasets.
- */
-#ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_SPILL_SEARCH_IMPL_HPP
-#define MLPACK_METHODS_NEIGHBOR_SEARCH_SPILL_SEARCH_IMPL_HPP
-
-// In case it hasn't been included yet.
-#include "spill_search.hpp"
-
-namespace mlpack {
-namespace neighbor {
-
-// Construct the object.
-template<typename MetricType,
- typename MatType,
- template<typename HyperplaneMetricType> class HyperplaneType,
- template<typename SplitBoundT, typename SplitMatT> class SplitType>
-SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::SpillSearch(
- const MatType& referenceSetIn,
- const bool naive,
- const bool singleMode,
- const double tau,
- const double leafSize,
- const double rho,
- const double epsilon,
- const MetricType metric) :
- neighborSearch(naive, singleMode, epsilon, metric),
- tau(tau),
- rho(rho),
- leafSize(leafSize)
-{
- if (tau < 0)
- throw std::invalid_argument("tau must be non-negative");
- if (rho < 0 || rho > 1)
- throw std::invalid_argument("rho must be in the range [0,1]");
- Train(referenceSetIn);
-}
-
-// Construct the object.
-template<typename MetricType,
- typename MatType,
- template<typename HyperplaneMetricType> class HyperplaneType,
- template<typename SplitBoundT, typename SplitMatT> class SplitType>
-SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::SpillSearch(
- MatType&& referenceSetIn,
- const bool naive,
- const bool singleMode,
- const double tau,
- const double leafSize,
- const double rho,
- const double epsilon,
- const MetricType metric) :
- neighborSearch(naive, singleMode, epsilon, metric),
- tau(tau),
- rho(rho),
- leafSize(leafSize)
-{
- if (tau < 0)
- throw std::invalid_argument("tau must be non-negative");
- if (rho < 0 || rho > 1)
- throw std::invalid_argument("rho must be in the range [0,1]");
- Train(std::move(referenceSetIn));
-}
-
-// Construct the object.
-template<typename MetricType,
- typename MatType,
- template<typename HyperplaneMetricType> class HyperplaneType,
- template<typename SplitBoundT, typename SplitMatT> class SplitType>
-SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::SpillSearch(
- Tree* referenceTree,
- const bool singleMode,
- const double tau,
- const double leafSize,
- const double rho,
- const double epsilon,
- const MetricType metric) :
- neighborSearch(singleMode, epsilon, metric),
- tau(tau),
- rho(rho),
- leafSize(leafSize)
-{
- if (tau < 0)
- throw std::invalid_argument("tau must be non-negative");
- if (rho < 0 || rho > 1)
- throw std::invalid_argument("rho must be in the range [0,1]");
- Train(referenceTree);
-}
-
-// Construct the object without a reference dataset.
-template<typename MetricType,
- typename MatType,
- template<typename HyperplaneMetricType> class HyperplaneType,
- template<typename SplitBoundT, typename SplitMatT> class SplitType>
-SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::SpillSearch(
- const bool naive,
- const bool singleMode,
- const double tau,
- const double leafSize,
- const double rho,
- const double epsilon,
- const MetricType metric) :
- neighborSearch(naive, singleMode, epsilon, metric),
- tau(tau),
- rho(rho),
- leafSize(leafSize)
-{
- if (tau < 0)
- throw std::invalid_argument("tau must be non-negative");
- if (rho < 0 || rho > 1)
- throw std::invalid_argument("rho must be in the range [0,1]");
-}
-
-// Clean memory.
-template<typename MetricType,
- typename MatType,
- template<typename HyperplaneMetricType> class HyperplaneType,
- template<typename SplitBoundT, typename SplitMatT> class SplitType>
-SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
-~SpillSearch()
-{
- /* Nothing to do */
-}
-
-template<typename MetricType,
- typename MatType,
- template<typename HyperplaneMetricType> class HyperplaneType,
- template<typename SplitBoundT, typename SplitMatT> class SplitType>
-void SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
-Train(const MatType& referenceSet)
-{
- if (Naive())
- neighborSearch.Train(referenceSet);
- else
- {
- // Build reference tree with proper value for tau.
- Tree* tree = new Tree(referenceSet, tau, leafSize, rho);
- neighborSearch.Train(tree);
- // Give the model ownership of the tree.
- neighborSearch.treeOwner = true;
- }
-}
-
-template<typename MetricType,
- typename MatType,
- template<typename HyperplaneMetricType> class HyperplaneType,
- template<typename SplitBoundT, typename SplitMatT> class SplitType>
-void SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
-Train(MatType&& referenceSetIn)
-{
- if (Naive())
- neighborSearch.Train(std::move(referenceSetIn));
- else
- {
- // Build reference tree with proper value for tau.
- Tree* tree = new Tree(std::move(referenceSetIn), tau, leafSize, rho);
- neighborSearch.Train(tree);
- // Give the model ownership of the tree.
- neighborSearch.treeOwner = true;
- }
-}
-
-template<typename MetricType,
- typename MatType,
- template<typename HyperplaneMetricType> class HyperplaneType,
- template<typename SplitBoundT, typename SplitMatT> class SplitType>
-void SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
-Train(Tree* referenceTree)
-{
- neighborSearch.Train(referenceTree);
-}
-
-template<typename MetricType,
- typename MatType,
- template<typename HyperplaneMetricType> class HyperplaneType,
- template<typename SplitBoundT, typename SplitMatT> class SplitType>
-void SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
-Search(const MatType& querySet,
- const size_t k,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances)
-{
- if (Naive() || SingleMode())
- neighborSearch.Search(querySet, k, neighbors, distances);
- else
- {
- // For Dual Tree Search on SpillTrees, the queryTree must be built with non
- // overlapping (tau = 0).
- Tree queryTree(querySet, 0 /* tau */, leafSize, rho);
- neighborSearch.Search(&queryTree, k, neighbors, distances);
- }
-}
-
-template<typename MetricType,
- typename MatType,
- template<typename HyperplaneMetricType> class HyperplaneType,
- template<typename SplitBoundT, typename SplitMatT> class SplitType>
-void SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
-Search(Tree* queryTree,
- const size_t k,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances)
-{
- neighborSearch.Search(queryTree, k, neighbors, distances);
-}
-
-template<typename MetricType,
- typename MatType,
- template<typename HyperplaneMetricType> class HyperplaneType,
- template<typename SplitBoundT, typename SplitMatT> class SplitType>
-void SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
-Search(const size_t k,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances)
-{
- if (tau == 0 || Naive() || SingleMode())
- neighborSearch.Search(k, neighbors, distances);
- else
- {
- // For Dual Tree Search on SpillTrees, the queryTree must be built with non
- // overlapping (tau = 0). If the referenceTree was built with a non-zero
- // value for tau, we need to build a new queryTree.
- Tree queryTree(ReferenceSet(), 0 /* tau */, leafSize, rho);
- neighborSearch.Search(&queryTree, k, neighbors, distances, true);
- }
-}
-
-//! Serialize SpillSearch.
-template<typename MetricType,
- typename MatType,
- template<typename HyperplaneMetricType> class HyperplaneType,
- template<typename SplitBoundT, typename SplitMatT> class SplitType>
-template<typename Archive>
-void SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
- Serialize(Archive& ar, const unsigned int /* version */)
-{
- ar & data::CreateNVP(neighborSearch, "neighborSearch");
- ar & data::CreateNVP(tau, "tau");
- ar & data::CreateNVP(rho, "rho");
- ar & data::CreateNVP(leafSize, "leafSize");
-}
-
-} // namespace neighbor
-} // namespace mlpack
-
-#endif
diff --git a/src/mlpack/methods/neighbor_search/typedef.hpp b/src/mlpack/methods/neighbor_search/typedef.hpp
index 1059e20..5c12abe 100644
--- a/src/mlpack/methods/neighbor_search/typedef.hpp
+++ b/src/mlpack/methods/neighbor_search/typedef.hpp
@@ -33,6 +33,35 @@ typedef NeighborSearch<NearestNeighborSort, metric::EuclideanDistance> KNN;
typedef NeighborSearch<FurthestNeighborSort, metric::EuclideanDistance> KFN;
/**
+ * The DefeatistKNN class is the k-nearest-neighbors method considering
+ * defeatist search. It returns L2 distances (Euclidean distances) for each of
+ * the k nearest neighbors found.
+ * @tparam TreeType The tree type to use; must adhere to the TreeType API,
+ * and implement Defeatist Traversers.
+ */
+template<template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType = tree::SPTree>
+using DefeatistKNN = NeighborSearch<
+ NearestNeighborSort,
+ metric::EuclideanDistance,
+ arma::mat,
+ TreeType,
+ TreeType<metric::EuclideanDistance,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat>::template DefeatistDualTreeTraverser,
+ TreeType<metric::EuclideanDistance,
+ NeighborSearchStat<NearestNeighborSort>,
+ arma::mat>::template DefeatistSingleTreeTraverser>;
+
+/**
+ * The SpillKNN class is the k-nearest-neighbors method considering defeatist
+ * search on SPTree. It returns L2 distances (Euclidean distances) for each of
+ * the k nearest neighbors found.
+ */
+typedef DefeatistKNN<tree::SPTree> SpillKNN;
+
+/**
* @deprecated
* The AllkNN class is the k-nearest-neighbors method. It returns L2 distances
* (Euclidean distances) for each of the k nearest neighbors. This typedef will
More information about the mlpack-git mailing list