[mlpack-git] master: Update SpillSearch to the general definition of Spill Trees. (b9e23d4)
gitdub at mlpack.org
gitdub at mlpack.org
Thu Aug 4 12:09:34 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/0f4b25acd6aaa14294c044874ba6cc0751712baa...0a19d07bd39e6223991976474bc79671ba8aa0f0
>---------------------------------------------------------------
commit b9e23d444e2ec5419b911a7f42e7d9136bed52ab
Author: MarcosPividori <marcos.pividori at gmail.com>
Date: Thu Aug 4 13:09:34 2016 -0300
Update SpillSearch to the general definition of Spill Trees.
>---------------------------------------------------------------
b9e23d444e2ec5419b911a7f42e7d9136bed52ab
.../methods/neighbor_search/neighbor_search.hpp | 2 ++
.../methods/neighbor_search/spill_search.hpp | 8 +++--
.../methods/neighbor_search/spill_search_impl.hpp | 36 ++++++++++++-------
.../methods/neighbor_search/spill_search_rules.hpp | 5 +--
.../neighbor_search/spill_search_rules_impl.hpp | 42 +++++++++++-----------
5 files changed, 55 insertions(+), 38 deletions(-)
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index e25c93a..75e6a55 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -29,6 +29,7 @@ namespace neighbor /** Neighbor-search routines. These include
// Forward declaration.
template<typename MetricType,
typename MatType,
+ template<typename HyperplaneMetricType> class HyperplaneType,
template<typename SplitBoundT, typename SplitMatT> class SplitType>
class SpillSearch;
@@ -337,6 +338,7 @@ class NeighborSearch
template<typename MetricT,
typename MatT,
+ template<typename HyperplaneMetricType> class HyperplaneType,
template<typename SplitBoundT, typename SplitMatT> class SplitType>
friend class SpillSearch;
}; // class NeighborSearch
diff --git a/src/mlpack/methods/neighbor_search/spill_search.hpp b/src/mlpack/methods/neighbor_search/spill_search.hpp
index 8490062..e1d4008 100644
--- a/src/mlpack/methods/neighbor_search/spill_search.hpp
+++ b/src/mlpack/methods/neighbor_search/spill_search.hpp
@@ -38,20 +38,22 @@ class TrainVisitor;
*/
template<typename MetricType = mlpack::metric::EuclideanDistance,
typename MatType = arma::mat,
+ template<typename HyperplaneMetricType>
+ class HyperplaneType = tree::AxisOrthogonalHyperplane,
template<typename SplitBoundT, typename SplitMatT> class SplitType =
- tree::MidpointSplit>
+ tree::MidpointSpaceSplit>
class SpillSearch
{
public:
//! Convenience typedef.
typedef tree::SpillTree<MetricType, NeighborSearchStat<NearestNeighborSort>,
- MatType, SplitType> Tree;
+ MatType, HyperplaneType, SplitType> Tree;
template<typename TreeMetricType,
typename TreeStatType,
typename TreeMatType>
using TreeType = tree::SpillTree<TreeMetricType, TreeStatType, TreeMatType,
- SplitType>;
+ HyperplaneType, SplitType>;
/**
* Initialize the SpillSearch object, passing a reference dataset (this is
diff --git a/src/mlpack/methods/neighbor_search/spill_search_impl.hpp b/src/mlpack/methods/neighbor_search/spill_search_impl.hpp
index f1ce70a..0a06037 100644
--- a/src/mlpack/methods/neighbor_search/spill_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/spill_search_impl.hpp
@@ -18,8 +18,9 @@ namespace neighbor {
// Construct the object.
template<typename MetricType,
typename MatType,
+ template<typename HyperplaneMetricType> class HyperplaneType,
template<typename SplitBoundT, typename SplitMatT> class SplitType>
-SpillSearch<MetricType, MatType, SplitType>::SpillSearch(
+SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::SpillSearch(
const MatType& referenceSetIn,
const bool naive,
const bool singleMode,
@@ -37,8 +38,9 @@ SpillSearch<MetricType, MatType, SplitType>::SpillSearch(
// Construct the object.
template<typename MetricType,
typename MatType,
+ template<typename HyperplaneMetricType> class HyperplaneType,
template<typename SplitBoundT, typename SplitMatT> class SplitType>
-SpillSearch<MetricType, MatType, SplitType>::SpillSearch(
+SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::SpillSearch(
MatType&& referenceSetIn,
const bool naive,
const bool singleMode,
@@ -56,8 +58,9 @@ SpillSearch<MetricType, MatType, SplitType>::SpillSearch(
// Construct the object.
template<typename MetricType,
typename MatType,
+ template<typename HyperplaneMetricType> class HyperplaneType,
template<typename SplitBoundT, typename SplitMatT> class SplitType>
-SpillSearch<MetricType, MatType, SplitType>::SpillSearch(
+SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::SpillSearch(
Tree* referenceTree,
const bool singleMode,
const double tau,
@@ -74,8 +77,9 @@ SpillSearch<MetricType, MatType, SplitType>::SpillSearch(
// Construct the object without a reference dataset.
template<typename MetricType,
typename MatType,
+ template<typename HyperplaneMetricType> class HyperplaneType,
template<typename SplitBoundT, typename SplitMatT> class SplitType>
-SpillSearch<MetricType, MatType, SplitType>::SpillSearch(
+SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::SpillSearch(
const bool naive,
const bool singleMode,
const double tau,
@@ -91,8 +95,9 @@ SpillSearch<MetricType, MatType, SplitType>::SpillSearch(
// Clean memory.
template<typename MetricType,
typename MatType,
+ template<typename HyperplaneMetricType> class HyperplaneType,
template<typename SplitBoundT, typename SplitMatT> class SplitType>
-SpillSearch<MetricType, MatType, SplitType>::
+SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
~SpillSearch()
{
/* Nothing to do */
@@ -100,8 +105,9 @@ SpillSearch<MetricType, MatType, SplitType>::
template<typename MetricType,
typename MatType,
+ template<typename HyperplaneMetricType> class HyperplaneType,
template<typename SplitBoundT, typename SplitMatT> class SplitType>
-void SpillSearch<MetricType, MatType, SplitType>::
+void SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
Train(const MatType& referenceSet)
{
if (Naive())
@@ -118,8 +124,9 @@ Train(const MatType& referenceSet)
template<typename MetricType,
typename MatType,
+ template<typename HyperplaneMetricType> class HyperplaneType,
template<typename SplitBoundT, typename SplitMatT> class SplitType>
-void SpillSearch<MetricType, MatType, SplitType>::
+void SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
Train(MatType&& referenceSetIn)
{
if (Naive())
@@ -136,8 +143,9 @@ Train(MatType&& referenceSetIn)
template<typename MetricType,
typename MatType,
+ template<typename HyperplaneMetricType> class HyperplaneType,
template<typename SplitBoundT, typename SplitMatT> class SplitType>
-void SpillSearch<MetricType, MatType, SplitType>::
+void SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
Train(Tree* referenceTree)
{
neighborSearch.Train(referenceTree);
@@ -145,8 +153,9 @@ Train(Tree* referenceTree)
template<typename MetricType,
typename MatType,
+ template<typename HyperplaneMetricType> class HyperplaneType,
template<typename SplitBoundT, typename SplitMatT> class SplitType>
-void SpillSearch<MetricType, MatType, SplitType>::
+void SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
Search(const MatType& querySet,
const size_t k,
arma::Mat<size_t>& neighbors,
@@ -165,8 +174,9 @@ Search(const MatType& querySet,
template<typename MetricType,
typename MatType,
+ template<typename HyperplaneMetricType> class HyperplaneType,
template<typename SplitBoundT, typename SplitMatT> class SplitType>
-void SpillSearch<MetricType, MatType, SplitType>::
+void SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
Search(Tree* queryTree,
const size_t k,
arma::Mat<size_t>& neighbors,
@@ -177,8 +187,9 @@ Search(Tree* queryTree,
template<typename MetricType,
typename MatType,
+ template<typename HyperplaneMetricType> class HyperplaneType,
template<typename SplitBoundT, typename SplitMatT> class SplitType>
-void SpillSearch<MetricType, MatType, SplitType>::
+void SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
Search(const size_t k,
arma::Mat<size_t>& neighbors,
arma::mat& distances)
@@ -198,9 +209,10 @@ Search(const size_t k,
//! Serialize SpillSearch.
template<typename MetricType,
typename MatType,
+ template<typename HyperplaneMetricType> class HyperplaneType,
template<typename SplitBoundT, typename SplitMatT> class SplitType>
template<typename Archive>
-void SpillSearch<MetricType, MatType, SplitType>::
+void SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
Serialize(Archive& ar, const unsigned int /* version */)
{
ar & data::CreateNVP(neighborSearch, "neighborSearch");
diff --git a/src/mlpack/methods/neighbor_search/spill_search_rules.hpp b/src/mlpack/methods/neighbor_search/spill_search_rules.hpp
index 54b50ca..f7d7085 100644
--- a/src/mlpack/methods/neighbor_search/spill_search_rules.hpp
+++ b/src/mlpack/methods/neighbor_search/spill_search_rules.hpp
@@ -28,13 +28,14 @@ namespace neighbor {
*/
template<typename StatisticType,
typename MatType,
+ template<typename HyperplaneMetricType> class HyperplaneType,
template<typename SplitBoundT, typename SplitMatT> class SplitType,
typename SortPolicy,
typename MetricType>
class NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<MetricType,
- StatisticType, MatType, SplitType>>
+ StatisticType, MatType, HyperplaneType, SplitType>>
{
- typedef tree::SpillTree<MetricType, StatisticType, MatType, SplitType>
+ typedef tree::SpillTree<MetricType, StatisticType, MatType, HyperplaneType, SplitType>
TreeType;
public:
/**
diff --git a/src/mlpack/methods/neighbor_search/spill_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/spill_search_rules_impl.hpp
index 16cdd24..50a61ac 100644
--- a/src/mlpack/methods/neighbor_search/spill_search_rules_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/spill_search_rules_impl.hpp
@@ -16,11 +16,12 @@ namespace neighbor {
template<typename StatisticType,
typename MatType,
+ template<typename HyperplaneMetricType> class HyperplaneType,
template<typename SplitBoundT, typename SplitMatT> class SplitType,
typename SortPolicy,
typename MetricType>
NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<MetricType,
- StatisticType, MatType, SplitType>>::NeighborSearchRules(
+ StatisticType, MatType, HyperplaneType, SplitType>>::NeighborSearchRules(
const typename TreeType::Mat& referenceSet,
const typename TreeType::Mat& querySet,
const size_t k,
@@ -61,11 +62,12 @@ NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<MetricType,
template<typename StatisticType,
typename MatType,
+ template<typename HyperplaneMetricType> class HyperplaneType,
template<typename SplitBoundT, typename SplitMatT> class SplitType,
typename SortPolicy,
typename MetricType>
void NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<MetricType,
- StatisticType, MatType, SplitType>>::GetResults(
+ StatisticType, MatType, HyperplaneType, SplitType>>::GetResults(
arma::Mat<size_t>& neighbors,
arma::mat& distances)
{
@@ -86,12 +88,13 @@ void NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<MetricType,
template<typename StatisticType,
typename MatType,
+ template<typename HyperplaneMetricType> class HyperplaneType,
template<typename SplitBoundT, typename SplitMatT> class SplitType,
typename SortPolicy,
typename MetricType>
inline force_inline // Absolutely MUST be inline so optimizations can happen.
double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<MetricType,
- StatisticType, MatType, SplitType>>::BaseCase(
+ StatisticType, MatType, HyperplaneType, SplitType>>::BaseCase(
const size_t queryIndex,
const size_t referenceIndex)
{
@@ -111,11 +114,12 @@ double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<MetricType,
template<typename StatisticType,
typename MatType,
+ template<typename HyperplaneMetricType> class HyperplaneType,
template<typename SplitBoundT, typename SplitMatT> class SplitType,
typename SortPolicy,
typename MetricType>
inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
- MetricType, StatisticType, MatType, SplitType>>::Score(
+ MetricType, StatisticType, MatType, HyperplaneType, SplitType>>::Score(
const size_t queryIndex,
TreeType& referenceNode)
{
@@ -126,12 +130,7 @@ inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
if (referenceNode.Parent()->Overlap()) // Defeatist search.
{
- const double value = referenceNode.Parent()->SplitValue();
- const size_t dim = referenceNode.Parent()->SplitDimension();
- const bool left = &referenceNode == referenceNode.Parent()->Left();
-
- if ((left && querySet(dim, queryIndex) <= value) ||
- (!left && querySet(dim, queryIndex) > value))
+ if (referenceNode.HalfSpaceContains(querySet.col(queryIndex)))
return 0;
else
return DBL_MAX;
@@ -149,11 +148,12 @@ inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
template<typename StatisticType,
typename MatType,
+ template<typename HyperplaneMetricType> class HyperplaneType,
template<typename SplitBoundT, typename SplitMatT> class SplitType,
typename SortPolicy,
typename MetricType>
inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
- MetricType, StatisticType, MatType, SplitType>>::Rescore(
+ MetricType, StatisticType, MatType, HyperplaneType, SplitType>>::Rescore(
const size_t queryIndex,
TreeType& /* referenceNode */,
double oldScore) const
@@ -171,11 +171,12 @@ inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
template<typename StatisticType,
typename MatType,
+ template<typename HyperplaneMetricType> class HyperplaneType,
template<typename SplitBoundT, typename SplitMatT> class SplitType,
typename SortPolicy,
typename MetricType>
inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
- MetricType, StatisticType, MatType, SplitType>>::Score(
+ MetricType, StatisticType, MatType, HyperplaneType, SplitType>>::Score(
TreeType& queryNode,
TreeType& referenceNode)
{
@@ -186,12 +187,7 @@ inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
if (referenceNode.Parent()->Overlap()) // Defeatist search.
{
- const double value = referenceNode.Parent()->SplitValue();
- const size_t dim = referenceNode.Parent()->SplitDimension();
- const bool left = &referenceNode == referenceNode.Parent()->Left();
-
- if ((left && queryNode.Bound()[dim].Lo() <= value) ||
- (!left && queryNode.Bound()[dim].Hi() > value))
+ if (referenceNode.HalfSpaceIntersects(queryNode))
return 0;
else
return DBL_MAX;
@@ -312,11 +308,12 @@ inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
template<typename StatisticType,
typename MatType,
+ template<typename HyperplaneMetricType> class HyperplaneType,
template<typename SplitBoundT, typename SplitMatT> class SplitType,
typename SortPolicy,
typename MetricType>
inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
- MetricType, StatisticType, MatType, SplitType>>::Rescore(
+ MetricType, StatisticType, MatType, HyperplaneType, SplitType>>::Rescore(
TreeType& queryNode,
TreeType& /* referenceNode */,
const double oldScore) const
@@ -337,11 +334,12 @@ inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
// it.
template<typename StatisticType,
typename MatType,
+ template<typename HyperplaneMetricType> class HyperplaneType,
template<typename SplitBoundT, typename SplitMatT> class SplitType,
typename SortPolicy,
typename MetricType>
inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
- MetricType, StatisticType, MatType, SplitType>>::
+ MetricType, StatisticType, MatType, HyperplaneType, SplitType>>::
CalculateBound(TreeType& queryNode) const
{
// This is an adapted form of the B(N_q) function in the paper
@@ -415,11 +413,13 @@ inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
*/
template<typename StatisticType,
typename MatType,
+ template<typename HyperplaneMetricType> class HyperplaneType,
template<typename SplitBoundT, typename SplitMatT> class SplitType,
typename SortPolicy,
typename MetricType>
inline void NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
- MetricType, StatisticType, MatType, SplitType>>::InsertNeighbor(
+ MetricType, StatisticType, MatType, HyperplaneType, SplitType>>::
+ InsertNeighbor(
const size_t queryIndex,
const size_t neighbor,
const double distance)
More information about the mlpack-git mailing list