[mlpack-git] master: Create a new class SpillSearch that encapsulates an instance of NeighborSearch class, and adds the functionality to deal with spill trees. (e95e3b4)
gitdub at mlpack.org
gitdub at mlpack.org
Thu Aug 18 13:39:22 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/0f4b25acd6aaa14294c044874ba6cc0751712baa...0a19d07bd39e6223991976474bc79671ba8aa0f0
>---------------------------------------------------------------
commit e95e3b429601a964c4fd31e7f073931e4e079781
Author: MarcosPividori <marcos.pividori at gmail.com>
Date: Thu Jul 28 13:55:18 2016 -0300
Create a new class SpillSearch that encapsulates an instance of NeighborSearch class, and adds the functionality to deal with spill trees.
>---------------------------------------------------------------
e95e3b429601a964c4fd31e7f073931e4e079781
src/mlpack/methods/neighbor_search/CMakeLists.txt | 2 +
.../methods/neighbor_search/neighbor_search.hpp | 19 +-
.../neighbor_search/neighbor_search_impl.hpp | 5 +-
.../{neighbor_search.hpp => spill_search.hpp} | 221 ++++++++-------------
.../methods/neighbor_search/spill_search_impl.hpp | 213 ++++++++++++++++++++
5 files changed, 321 insertions(+), 139 deletions(-)
diff --git a/src/mlpack/methods/neighbor_search/CMakeLists.txt b/src/mlpack/methods/neighbor_search/CMakeLists.txt
index 95fe37b..e4c76e1 100644
--- a/src/mlpack/methods/neighbor_search/CMakeLists.txt
+++ b/src/mlpack/methods/neighbor_search/CMakeLists.txt
@@ -14,6 +14,8 @@ set(SOURCES
sort_policies/nearest_neighbor_sort_impl.hpp
sort_policies/furthest_neighbor_sort.hpp
sort_policies/furthest_neighbor_sort_impl.hpp
+ spill_search.hpp
+ spill_search_impl.hpp
typedef.hpp
unmap.hpp
unmap.cpp
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index f1acea4..e25c93a 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -27,6 +27,12 @@ namespace neighbor /** Neighbor-search routines. These include
* searches. */ {
// Forward declaration.
+template<typename MetricType,
+ typename MatType,
+ template<typename SplitBoundT, typename SplitMatT> class SplitType>
+class SpillSearch;
+
+// Forward declaration.
template<typename SortPolicy>
class TrainVisitor;
@@ -237,11 +243,14 @@ class NeighborSearch
* @param neighbors Matrix storing lists of neighbors for each query point.
* @param distances Matrix storing distances of neighbors for each query
* point.
+ * @param sameSet Denotes whether or not the reference and query sets are the
+ * same.
*/
void Search(Tree* queryTree,
const size_t k,
arma::Mat<size_t>& neighbors,
- arma::mat& distances);
+ arma::mat& distances,
+ bool sameSet = false);
/**
* Search for the nearest neighbors of every point in the reference set. This
@@ -323,7 +332,13 @@ class NeighborSearch
bool treeNeedsReset;
//! The NSModel class should have access to internal members.
- friend class TrainVisitor<SortPolicy>;
+ template<typename SortPol>
+ friend class TrainVisitor;
+
+ template<typename MetricT,
+ typename MatT,
+ template<typename SplitBoundT, typename SplitMatT> class SplitType>
+ friend class SpillSearch;
}; // class NeighborSearch
} // namespace neighbor
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index 79a31d0..4fd073d 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -506,7 +506,8 @@ void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
Search(Tree* queryTree,
const size_t k,
arma::Mat<size_t>& neighbors,
- arma::mat& distances)
+ arma::mat& distances,
+ bool sameSet)
{
if (k > referenceSet->n_cols)
{
@@ -540,7 +541,7 @@ Search(Tree* queryTree,
// Create the helper object for the traversal.
typedef NeighborSearchRules<SortPolicy, MetricType, Tree> RuleType;
- RuleType rules(*referenceSet, querySet, k, metric, epsilon);
+ RuleType rules(*referenceSet, querySet, k, metric, epsilon, sameSet);
// Create the traverser.
TraversalType<RuleType> traverser(rules);
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/spill_search.hpp
similarity index 54%
copy from src/mlpack/methods/neighbor_search/neighbor_search.hpp
copy to src/mlpack/methods/neighbor_search/spill_search.hpp
index f1acea4..8490062 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/spill_search.hpp
@@ -1,177 +1,149 @@
/**
- * @file neighbor_search.hpp
+ * @file spill_search.hpp
* @author Ryan Curtin
+ * @author Marcos Pividori
*
- * Defines the NeighborSearch class, which performs an abstract
- * nearest-neighbor-like query on two datasets.
+ * Defines the SpillSearch class, which performs a Hybrid sp-tree search on
+ * two datasets.
*/
-#ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_HPP
-#define MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_HPP
+#ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_SPILL_SEARCH_HPP
+#define MLPACK_METHODS_NEIGHBOR_SEARCH_SPILL_SEARCH_HPP
#include <mlpack/core.hpp>
-#include <vector>
-#include <string>
-
-#include <mlpack/core/tree/binary_space_tree.hpp>
-#include <mlpack/core/tree/rectangle_tree.hpp>
-#include <mlpack/core/tree/binary_space_tree/binary_space_tree.hpp>
-
#include <mlpack/core/metrics/lmetric.hpp>
-#include "neighbor_search_stat.hpp"
#include "sort_policies/nearest_neighbor_sort.hpp"
-#include "neighbor_search_rules.hpp"
+#include "neighbor_search.hpp"
namespace mlpack {
-namespace neighbor /** Neighbor-search routines. These include
- * all-nearest-neighbors and all-furthest-neighbors
- * searches. */ {
+namespace neighbor {
// Forward declaration.
template<typename SortPolicy>
class TrainVisitor;
/**
- * The NeighborSearch class is a template class for performing distance-based
- * neighbor searches. It takes a query dataset and a reference dataset (or just
- * a reference dataset) and, for each point in the query dataset, finds the k
- * neighbors in the reference dataset which have the 'best' distance according
- * to a given sorting policy. A constructor is given which takes only a
- * reference dataset, and if that constructor is used, the given reference
- * dataset is also used as the query dataset.
+ * The SpillSearch class is a template class for performing distance-based
+ * neighbor searches with Spill Trees. It takes a query dataset and a reference
+ * dataset (or just a reference dataset) and, for each point in the query
+ * dataset, finds the k neighbors in the reference dataset which have the 'best'
+ * distance according to a given sorting policy. A constructor is given which
+ * takes only a reference dataset, and if that constructor is used, the given
+ * reference dataset is also used as the query dataset.
*
- * The template parameters SortPolicy and Metric define the sort function used
- * and the metric (distance function) used. More information on those classes
- * can be found in the NearestNeighborSort class and the kernel::ExampleKernel
- * class.
- *
- * @tparam SortPolicy The sort policy for distances; see NearestNeighborSort.
* @tparam MetricType The metric to use for computation.
* @tparam MatType The type of data matrix.
- * @tparam TreeType The tree type to use; must adhere to the TreeType API.
- * @tparam TraversalType The type of traversal to use (defaults to the tree's
- * default traverser).
+ * @tparam SplitType The class that partitions the dataset/points at a
+ * particular node into two parts. Its definition decides the way this split
+ * is done when building spill trees.
*/
-template<typename SortPolicy = NearestNeighborSort,
- typename MetricType = mlpack::metric::EuclideanDistance,
+template<typename MetricType = mlpack::metric::EuclideanDistance,
typename MatType = arma::mat,
- template<typename TreeMetricType,
- typename TreeStatType,
- typename TreeMatType> class TreeType = tree::KDTree,
- template<typename RuleType> class TraversalType =
- TreeType<MetricType,
- NeighborSearchStat<SortPolicy>,
- MatType>::template DualTreeTraverser>
-class NeighborSearch
+ template<typename SplitBoundT, typename SplitMatT> class SplitType =
+ tree::MidpointSplit>
+class SpillSearch
{
public:
//! Convenience typedef.
- typedef TreeType<MetricType, NeighborSearchStat<SortPolicy>, MatType> Tree;
+ typedef tree::SpillTree<MetricType, NeighborSearchStat<NearestNeighborSort>,
+ MatType, SplitType> Tree;
+
+ template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType>
+ using TreeType = tree::SpillTree<TreeMetricType, TreeStatType, TreeMatType,
+ SplitType>;
/**
- * Initialize the NeighborSearch object, passing a reference dataset (this is
+ * Initialize the SpillSearch object, passing a reference dataset (this is
* the dataset which is searched). Optionally, perform the computation in
* naive mode or single-tree mode. An initialized distance metric can be
* given, for cases where the metric has internal data (i.e. the
* distance::MahalanobisDistance class).
*
- * This method will copy the matrices to internal copies, which are rearranged
- * during tree-building. You can avoid this extra copy by pre-constructing
- * the trees and passing them using a different constructor, or by using the
- * construct that takes an rvalue reference to the dataset.
- *
* @param referenceSet Set of reference points.
* @param naive If true, O(n^2) naive search will be used (as opposed to
* dual-tree search). This overrides singleMode (if it is set to true).
* @param singleMode If true, single-tree search will be used (as opposed to
* dual-tree search).
+ * @param tau Overlapping size (non-negative).
* @param epsilon Relative approximate error (non-negative).
* @param metric An optional instance of the MetricType class.
*/
- NeighborSearch(const MatType& referenceSet,
- const bool naive = false,
- const bool singleMode = false,
- const double epsilon = 0,
- const MetricType metric = MetricType());
+ SpillSearch(const MatType& referenceSet,
+ const bool naive = false,
+ const bool singleMode = false,
+ const double tau = 0,
+ const double epsilon = 0,
+ const MetricType metric = MetricType());
/**
- * Initialize the NeighborSearch object, taking ownership of the reference
+ * Initialize the SpillSearch object, taking ownership of the reference
* dataset (this is the dataset which is searched). Optionally, perform the
* computation in naive mode or single-tree mode. An initialized distance
* metric can be given, for cases where the metric has internal data (i.e. the
* distance::MahalanobisDistance class).
*
- * This method will not copy the data matrix, but will take ownership of it,
- * and depending on the type of tree used, may rearrange the points. If you
- * would rather a copy be made, consider using the constructor that takes a
- * const reference to the data instead.
- *
* @param referenceSet Set of reference points.
* @param naive If true, O(n^2) naive search will be used (as opposed to
* dual-tree search). This overrides singleMode (if it is set to true).
* @param singleMode If true, single-tree search will be used (as opposed to
* dual-tree search).
+ * @param tau Overlapping size (non-negative).
* @param epsilon Relative approximate error (non-negative).
* @param metric An optional instance of the MetricType class.
*/
- NeighborSearch(MatType&& referenceSet,
- const bool naive = false,
- const bool singleMode = false,
- const double epsilon = 0,
- const MetricType metric = MetricType());
+ SpillSearch(MatType&& referenceSet,
+ const bool naive = false,
+ const bool singleMode = false,
+ const double tau = 0,
+ const double epsilon = 0,
+ const MetricType metric = MetricType());
/**
- * Initialize the NeighborSearch object with the given pre-constructed
+ * Initialize the SpillSearch object with the given pre-constructed
* reference tree (this is the tree built on the points that will be
* searched). Optionally, choose to use single-tree mode. Naive mode is not
* available as an option for this constructor. Additionally, an instantiated
* distance metric can be given, for cases where the distance metric holds
* data.
*
- * There is no copying of the data matrices in this constructor (because
- * tree-building is not necessary), so this is the constructor to use when
- * copies absolutely must be avoided.
- *
- * @note
- * Mapping the points of the matrix back to their original indices is not done
- * when this constructor is used, so if the tree type you are using maps
- * points (like BinarySpaceTree), then you will have to perform the re-mapping
- * manually.
- * @endnote
- *
* @param referenceTree Pre-built tree for reference points.
- * @param referenceSet Set of reference points corresponding to referenceTree.
* @param singleMode Whether single-tree computation should be used (as
* opposed to dual-tree computation).
+ * @param tau Overlapping size (non-negative).
* @param epsilon Relative approximate error (non-negative).
* @param metric Instantiated distance metric.
*/
- NeighborSearch(Tree* referenceTree,
- const bool singleMode = false,
- const double epsilon = 0,
- const MetricType metric = MetricType());
+ SpillSearch(Tree* referenceTree,
+ const bool singleMode = false,
+ const double tau = 0,
+ const double epsilon = 0,
+ const MetricType metric = MetricType());
/**
- * Create a NeighborSearch object without any reference data. If Search() is
+ * Create a SpillSearch object without any reference data. If Search() is
* called before a reference set is set with Train(), an exception will be
* thrown.
*
* @param naive Whether to use naive search.
* @param singleMode Whether single-tree computation should be used (as
* opposed to dual-tree computation).
+ * @param tau Overlapping size (non-negative).
* @param epsilon Relative approximate error (non-negative).
* @param metric Instantiated metric.
*/
- NeighborSearch(const bool naive = false,
- const bool singleMode = false,
- const double epsilon = 0,
- const MetricType metric = MetricType());
+ SpillSearch(const bool naive = false,
+ const bool singleMode = false,
+ const double tau = 0,
+ const double epsilon = 0,
+ const MetricType metric = MetricType());
/**
- * Delete the NeighborSearch object. The tree is the only member we are
+ * Delete the SpillSearch object. The tree is the only member we are
* responsible for deleting. The others will take care of themselves.
*/
- ~NeighborSearch();
+ ~SpillSearch();
/**
* Set the reference set to a new reference set, and build a tree if
@@ -195,6 +167,8 @@ class NeighborSearch
/**
* Set the reference tree to a new reference tree.
+ *
+ * @param referenceTree Pre-built tree for reference points.
*/
void Train(Tree* referenceTree);
@@ -263,76 +237,53 @@ class NeighborSearch
//! Return the total number of base case evaluations performed during the last
//! search.
- size_t BaseCases() const { return baseCases; }
+ size_t BaseCases() const { return neighborSearch.BaseCases(); }
//! Return the number of node combination scores during the last search.
- size_t Scores() const { return scores; }
+ size_t Scores() const { return neighborSearch.Scores(); }
//! Access whether or not search is done in naive linear scan mode.
- bool Naive() const { return naive; }
+ bool Naive() const { return neighborSearch.Naive(); }
//! Modify whether or not search is done in naive linear scan mode.
- bool& Naive() { return naive; }
+ bool& Naive() { return neighborSearch.Naive(); }
//! Access whether or not search is done in single-tree mode.
- bool SingleMode() const { return singleMode; }
+ bool SingleMode() const { return neighborSearch.SingleMode(); }
//! Modify whether or not search is done in single-tree mode.
- bool& SingleMode() { return singleMode; }
+ bool& SingleMode() { return neighborSearch.SingleMode(); }
//! Access the relative error to be considered in approximate search.
- double Epsilon() const { return epsilon; }
+ double Epsilon() const { return neighborSearch.Epsilon(); }
//! Modify the relative error to be considered in approximate search.
- double& Epsilon() { return epsilon; }
+ double& Epsilon() { return neighborSearch.Epsilon(); }
+
+ //! Access the overlapping size.
+ double Tau() const { return tau; }
//! Access the reference dataset.
- const MatType& ReferenceSet() const { return *referenceSet; }
+ const MatType& ReferenceSet() const { return neighborSearch.ReferenceSet(); }
- //! Serialize the NeighborSearch model.
+ //! Serialize the SpillSearch model.
template<typename Archive>
void Serialize(Archive& ar, const unsigned int /* version */);
private:
- //! Permutations of reference points during tree building.
- std::vector<size_t> oldFromNewReferences;
- //! Pointer to the root of the reference tree.
- Tree* referenceTree;
- //! Reference dataset. In some situations we may be the owner of this.
- const MatType* referenceSet;
-
- //! If true, this object created the trees and is responsible for them.
- bool treeOwner;
- //! If true, we own the reference set.
- bool setOwner;
+ //! Internal instance of NeighborSearch class.
+ NeighborSearch<NearestNeighborSort, MetricType, MatType, TreeType>
+ neighborSearch;
- //! Indicates if O(n^2) naive search is being used.
- bool naive;
- //! Indicates if single-tree search is being used (as opposed to dual-tree).
- bool singleMode;
- //! Indicates the relative error to be considered in approximate search.
- double epsilon;
-
- //! Instantiation of metric.
- MetricType metric;
-
- //! The total number of base cases.
- size_t baseCases;
- //! The total number of scores (applicable for non-naive search).
- size_t scores;
-
- //! If this is true, the reference tree bounds need to be reset on a call to
- //! Search() without a query set.
- bool treeNeedsReset;
+ //! Overlapping size.
+ double tau;
//! The NSModel class should have access to internal members.
- friend class TrainVisitor<SortPolicy>;
-}; // class NeighborSearch
+ template<typename SortPolicy>
+ friend class TrainVisitor;
+}; // class SpillSearch
} // namespace neighbor
} // namespace mlpack
// Include implementation.
-#include "neighbor_search_impl.hpp"
-
-// Include convenience typedefs.
-#include "typedef.hpp"
+#include "spill_search_impl.hpp"
#endif
diff --git a/src/mlpack/methods/neighbor_search/spill_search_impl.hpp b/src/mlpack/methods/neighbor_search/spill_search_impl.hpp
new file mode 100644
index 0000000..f1ce70a
--- /dev/null
+++ b/src/mlpack/methods/neighbor_search/spill_search_impl.hpp
@@ -0,0 +1,213 @@
+/**
+ * @file spill_search_impl.hpp
+ * @author Ryan Curtin
+ * @author Marcos Pividori
+ *
+ * Implementation of SpillSearch class, which performs a Hybrid sp-tree search
+ * on two datasets.
+ */
+#ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_SPILL_SEARCH_IMPL_HPP
+#define MLPACK_METHODS_NEIGHBOR_SEARCH_SPILL_SEARCH_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "spill_search.hpp"
+
+namespace mlpack {
+namespace neighbor {
+
+// Construct the object from a const reference set; Train() builds the tree.
+template<typename MetricType,
+ typename MatType,
+ template<typename SplitBoundT, typename SplitMatT> class SplitType>
+SpillSearch<MetricType, MatType, SplitType>::SpillSearch(
+ const MatType& referenceSetIn,
+ const bool naive,
+ const bool singleMode,
+ const double tau,
+ const double epsilon,
+ const MetricType metric) :
+ neighborSearch(naive, singleMode, epsilon, metric),
+ tau(tau)
+{
+ if (tau < 0) // Reject a negative overlap size before building anything.
+ throw std::invalid_argument("tau must be non-negative");
+ Train(referenceSetIn); // Builds a spill tree with this tau unless naive.
+}
+
+// Construct the object, taking ownership of the moved-in reference set.
+template<typename MetricType,
+ typename MatType,
+ template<typename SplitBoundT, typename SplitMatT> class SplitType>
+SpillSearch<MetricType, MatType, SplitType>::SpillSearch(
+ MatType&& referenceSetIn,
+ const bool naive,
+ const bool singleMode,
+ const double tau,
+ const double epsilon,
+ const MetricType metric) :
+ neighborSearch(naive, singleMode, epsilon, metric),
+ tau(tau)
+{
+ if (tau < 0) // Reject a negative overlap size before building anything.
+ throw std::invalid_argument("tau must be non-negative");
+ Train(std::move(referenceSetIn)); // Data is moved into the model/tree.
+}
+
+// Construct from a pre-built reference tree (naive = false: a tree is used).
+template<typename MetricType,
+ typename MatType,
+ template<typename SplitBoundT, typename SplitMatT> class SplitType>
+SpillSearch<MetricType, MatType, SplitType>::SpillSearch(
+ Tree* referenceTree,
+ const bool singleMode,
+ const double tau,
+ const double epsilon,
+ const MetricType metric) :
+ neighborSearch(false /* naive */, singleMode, epsilon, metric),
+ tau(tau)
+{
+ if (tau < 0) // Reject a negative overlap size.
+ throw std::invalid_argument("tau must be non-negative");
+ Train(referenceTree); // Hands the caller's tree to the wrapped model.
+}
+
+// Construct the object without a reference dataset; call Train() before use.
+template<typename MetricType,
+ typename MatType,
+ template<typename SplitBoundT, typename SplitMatT> class SplitType>
+SpillSearch<MetricType, MatType, SplitType>::SpillSearch(
+ const bool naive,
+ const bool singleMode,
+ const double tau,
+ const double epsilon,
+ const MetricType metric) :
+ neighborSearch(naive, singleMode, epsilon, metric),
+ tau(tau)
+{
+ if (tau < 0) // tau is stored for later Train() calls; validate it now.
+ throw std::invalid_argument("tau must be non-negative");
+}
+
+// Destructor: nothing to free here; resources live in neighborSearch.
+template<typename MetricType,
+ typename MatType,
+ template<typename SplitBoundT, typename SplitMatT> class SplitType>
+SpillSearch<MetricType, MatType, SplitType>::
+~SpillSearch()
+{
+ /* Nothing to do */
+}
+
+template<typename MetricType,
+ typename MatType,
+ template<typename SplitBoundT, typename SplitMatT> class SplitType>
+void SpillSearch<MetricType, MatType, SplitType>::
+Train(const MatType& referenceSet)
+{ // Set a new reference set; a spill tree is built unless in naive mode.
+ if (Naive())
+ neighborSearch.Train(referenceSet); // Naive search needs no tree.
+ else
+ {
+ // Build the reference spill tree with this object's overlap size (tau).
+ Tree* tree = new Tree(referenceSet, tau);
+ neighborSearch.Train(tree);
+ // Mark the tree as owned by the wrapped model (treeOwner), not by us.
+ neighborSearch.treeOwner = true;
+ }
+}
+
+template<typename MetricType,
+ typename MatType,
+ template<typename SplitBoundT, typename SplitMatT> class SplitType>
+void SpillSearch<MetricType, MatType, SplitType>::
+Train(MatType&& referenceSetIn)
+{ // Move in a new reference set; a spill tree is built unless naive.
+ if (Naive())
+ neighborSearch.Train(std::move(referenceSetIn)); // No tree needed.
+ else
+ {
+ // Build the reference spill tree with this object's overlap size (tau).
+ Tree* tree = new Tree(std::move(referenceSetIn), tau);
+ neighborSearch.Train(tree);
+ // Mark the tree as owned by the wrapped model (treeOwner), not by us.
+ neighborSearch.treeOwner = true;
+ }
+}
+
+template<typename MetricType,
+ typename MatType,
+ template<typename SplitBoundT, typename SplitMatT> class SplitType>
+void SpillSearch<MetricType, MatType, SplitType>::
+Train(Tree* referenceTree)
+{ // Forwarded unchanged; unlike the other overloads, treeOwner is not set.
+ neighborSearch.Train(referenceTree);
+}
+
+template<typename MetricType,
+ typename MatType,
+ template<typename SplitBoundT, typename SplitMatT> class SplitType>
+void SpillSearch<MetricType, MatType, SplitType>::
+Search(const MatType& querySet,
+ const size_t k,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances)
+{ // Find the k nearest reference neighbors of each point in querySet.
+ if (Naive() || SingleMode()) // No query tree is built in these modes.
+ neighborSearch.Search(querySet, k, neighbors, distances);
+ else
+ {
+ // For dual-tree search on spill trees, the query tree must be built
+ // without overlap (tau = 0), whatever tau the reference tree uses.
+ Tree queryTree(querySet, 0 /* tau */);
+ neighborSearch.Search(&queryTree, k, neighbors, distances);
+ }
+}
+
+template<typename MetricType,
+ typename MatType,
+ template<typename SplitBoundT, typename SplitMatT> class SplitType>
+void SpillSearch<MetricType, MatType, SplitType>::
+Search(Tree* queryTree,
+ const size_t k,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances)
+{ // NOTE(review): queryTree is used as-is; it should be built with tau = 0.
+ neighborSearch.Search(queryTree, k, neighbors, distances);
+}
+
+template<typename MetricType,
+ typename MatType,
+ template<typename SplitBoundT, typename SplitMatT> class SplitType>
+void SpillSearch<MetricType, MatType, SplitType>::
+Search(const size_t k,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances)
+{ // Monochromatic search: the reference set is also the query set.
+ if (tau == 0 || Naive() || SingleMode()) // Reuse the reference data as-is.
+ neighborSearch.Search(k, neighbors, distances);
+ else
+ {
+ // For dual-tree search on spill trees, the query tree must be built
+ // without overlap (tau = 0). Since the reference tree was built with a
+ // non-zero tau, a separate non-overlapping query tree is built here.
+ Tree queryTree(ReferenceSet(), 0 /* tau */);
+ neighborSearch.Search(&queryTree, k, neighbors, distances, true); // sameSet
+ }
+}
+
+//! Serialize SpillSearch: the wrapped NeighborSearch model, then tau.
+template<typename MetricType,
+ typename MatType,
+ template<typename SplitBoundT, typename SplitMatT> class SplitType>
+template<typename Archive>
+void SpillSearch<MetricType, MatType, SplitType>::
+ Serialize(Archive& ar, const unsigned int /* version */)
+{
+ ar & data::CreateNVP(neighborSearch, "neighborSearch");
+ ar & data::CreateNVP(tau, "tau");
+}
+
+} // namespace neighbor
+} // namespace mlpack
+
+#endif
More information about the mlpack-git
mailing list