[mlpack-git] master: Update SpillSearch to the general definition of Spill Trees. (b9e23d4)

gitdub at mlpack.org gitdub at mlpack.org
Thu Aug 4 12:09:34 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/0f4b25acd6aaa14294c044874ba6cc0751712baa...0a19d07bd39e6223991976474bc79671ba8aa0f0

>---------------------------------------------------------------

commit b9e23d444e2ec5419b911a7f42e7d9136bed52ab
Author: MarcosPividori <marcos.pividori at gmail.com>
Date:   Thu Aug 4 13:09:34 2016 -0300

    Update SpillSearch to the general definition of Spill Trees.


>---------------------------------------------------------------

b9e23d444e2ec5419b911a7f42e7d9136bed52ab
 .../methods/neighbor_search/neighbor_search.hpp    |  2 ++
 .../methods/neighbor_search/spill_search.hpp       |  8 +++--
 .../methods/neighbor_search/spill_search_impl.hpp  | 36 ++++++++++++-------
 .../methods/neighbor_search/spill_search_rules.hpp |  5 +--
 .../neighbor_search/spill_search_rules_impl.hpp    | 42 +++++++++++-----------
 5 files changed, 55 insertions(+), 38 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index e25c93a..75e6a55 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -29,6 +29,7 @@ namespace neighbor /** Neighbor-search routines.  These include
 // Forward declaration.
 template<typename MetricType,
          typename MatType,
+         template<typename HyperplaneMetricType> class HyperplaneType,
          template<typename SplitBoundT, typename SplitMatT> class SplitType>
 class SpillSearch;
 
@@ -337,6 +338,7 @@ class NeighborSearch
 
   template<typename MetricT,
            typename MatT,
+           template<typename HyperplaneMetricType> class HyperplaneType,
            template<typename SplitBoundT, typename SplitMatT> class SplitType>
   friend class SpillSearch;
 }; // class NeighborSearch
diff --git a/src/mlpack/methods/neighbor_search/spill_search.hpp b/src/mlpack/methods/neighbor_search/spill_search.hpp
index 8490062..e1d4008 100644
--- a/src/mlpack/methods/neighbor_search/spill_search.hpp
+++ b/src/mlpack/methods/neighbor_search/spill_search.hpp
@@ -38,20 +38,22 @@ class TrainVisitor;
  */
 template<typename MetricType = mlpack::metric::EuclideanDistance,
          typename MatType = arma::mat,
+         template<typename HyperplaneMetricType>
+             class HyperplaneType = tree::AxisOrthogonalHyperplane,
          template<typename SplitBoundT, typename SplitMatT> class SplitType =
-             tree::MidpointSplit>
+             tree::MidpointSpaceSplit>
 class SpillSearch
 {
  public:
   //! Convenience typedef.
   typedef tree::SpillTree<MetricType, NeighborSearchStat<NearestNeighborSort>,
-      MatType, SplitType> Tree;
+      MatType, HyperplaneType, SplitType> Tree;
 
   template<typename TreeMetricType,
            typename TreeStatType,
            typename TreeMatType>
   using TreeType = tree::SpillTree<TreeMetricType, TreeStatType, TreeMatType,
-      SplitType>;
+      HyperplaneType, SplitType>;
 
   /**
    * Initialize the SpillSearch object, passing a reference dataset (this is
diff --git a/src/mlpack/methods/neighbor_search/spill_search_impl.hpp b/src/mlpack/methods/neighbor_search/spill_search_impl.hpp
index f1ce70a..0a06037 100644
--- a/src/mlpack/methods/neighbor_search/spill_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/spill_search_impl.hpp
@@ -18,8 +18,9 @@ namespace neighbor {
 // Construct the object.
 template<typename MetricType,
          typename MatType,
+         template<typename HyperplaneMetricType> class HyperplaneType,
          template<typename SplitBoundT, typename SplitMatT> class SplitType>
-SpillSearch<MetricType, MatType, SplitType>::SpillSearch(
+SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::SpillSearch(
     const MatType& referenceSetIn,
     const bool naive,
     const bool singleMode,
@@ -37,8 +38,9 @@ SpillSearch<MetricType, MatType, SplitType>::SpillSearch(
 // Construct the object.
 template<typename MetricType,
          typename MatType,
+         template<typename HyperplaneMetricType> class HyperplaneType,
          template<typename SplitBoundT, typename SplitMatT> class SplitType>
-SpillSearch<MetricType, MatType, SplitType>::SpillSearch(
+SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::SpillSearch(
     MatType&& referenceSetIn,
     const bool naive,
     const bool singleMode,
@@ -56,8 +58,9 @@ SpillSearch<MetricType, MatType, SplitType>::SpillSearch(
 // Construct the object.
 template<typename MetricType,
          typename MatType,
+         template<typename HyperplaneMetricType> class HyperplaneType,
          template<typename SplitBoundT, typename SplitMatT> class SplitType>
-SpillSearch<MetricType, MatType, SplitType>::SpillSearch(
+SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::SpillSearch(
     Tree* referenceTree,
     const bool singleMode,
     const double tau,
@@ -74,8 +77,9 @@ SpillSearch<MetricType, MatType, SplitType>::SpillSearch(
 // Construct the object without a reference dataset.
 template<typename MetricType,
          typename MatType,
+         template<typename HyperplaneMetricType> class HyperplaneType,
          template<typename SplitBoundT, typename SplitMatT> class SplitType>
-SpillSearch<MetricType, MatType, SplitType>::SpillSearch(
+SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::SpillSearch(
     const bool naive,
     const bool singleMode,
     const double tau,
@@ -91,8 +95,9 @@ SpillSearch<MetricType, MatType, SplitType>::SpillSearch(
 // Clean memory.
 template<typename MetricType,
          typename MatType,
+         template<typename HyperplaneMetricType> class HyperplaneType,
          template<typename SplitBoundT, typename SplitMatT> class SplitType>
-SpillSearch<MetricType, MatType, SplitType>::
+SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
 ~SpillSearch()
 {
   /* Nothing to do */
@@ -100,8 +105,9 @@ SpillSearch<MetricType, MatType, SplitType>::
 
 template<typename MetricType,
          typename MatType,
+         template<typename HyperplaneMetricType> class HyperplaneType,
          template<typename SplitBoundT, typename SplitMatT> class SplitType>
-void SpillSearch<MetricType, MatType, SplitType>::
+void SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
 Train(const MatType& referenceSet)
 {
   if (Naive())
@@ -118,8 +124,9 @@ Train(const MatType& referenceSet)
 
 template<typename MetricType,
          typename MatType,
+         template<typename HyperplaneMetricType> class HyperplaneType,
          template<typename SplitBoundT, typename SplitMatT> class SplitType>
-void SpillSearch<MetricType, MatType, SplitType>::
+void SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
 Train(MatType&& referenceSetIn)
 {
   if (Naive())
@@ -136,8 +143,9 @@ Train(MatType&& referenceSetIn)
 
 template<typename MetricType,
          typename MatType,
+         template<typename HyperplaneMetricType> class HyperplaneType,
          template<typename SplitBoundT, typename SplitMatT> class SplitType>
-void SpillSearch<MetricType, MatType, SplitType>::
+void SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
 Train(Tree* referenceTree)
 {
   neighborSearch.Train(referenceTree);
@@ -145,8 +153,9 @@ Train(Tree* referenceTree)
 
 template<typename MetricType,
          typename MatType,
+         template<typename HyperplaneMetricType> class HyperplaneType,
          template<typename SplitBoundT, typename SplitMatT> class SplitType>
-void SpillSearch<MetricType, MatType, SplitType>::
+void SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
 Search(const MatType& querySet,
        const size_t k,
        arma::Mat<size_t>& neighbors,
@@ -165,8 +174,9 @@ Search(const MatType& querySet,
 
 template<typename MetricType,
          typename MatType,
+         template<typename HyperplaneMetricType> class HyperplaneType,
          template<typename SplitBoundT, typename SplitMatT> class SplitType>
-void SpillSearch<MetricType, MatType, SplitType>::
+void SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
 Search(Tree* queryTree,
        const size_t k,
        arma::Mat<size_t>& neighbors,
@@ -177,8 +187,9 @@ Search(Tree* queryTree,
 
 template<typename MetricType,
          typename MatType,
+         template<typename HyperplaneMetricType> class HyperplaneType,
          template<typename SplitBoundT, typename SplitMatT> class SplitType>
-void SpillSearch<MetricType, MatType, SplitType>::
+void SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
 Search(const size_t k,
        arma::Mat<size_t>& neighbors,
        arma::mat& distances)
@@ -198,9 +209,10 @@ Search(const size_t k,
 //! Serialize SpillSearch.
 template<typename MetricType,
          typename MatType,
+         template<typename HyperplaneMetricType> class HyperplaneType,
          template<typename SplitBoundT, typename SplitMatT> class SplitType>
 template<typename Archive>
-void SpillSearch<MetricType, MatType, SplitType>::
+void SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
     Serialize(Archive& ar, const unsigned int /* version */)
 {
   ar & data::CreateNVP(neighborSearch, "neighborSearch");
diff --git a/src/mlpack/methods/neighbor_search/spill_search_rules.hpp b/src/mlpack/methods/neighbor_search/spill_search_rules.hpp
index 54b50ca..f7d7085 100644
--- a/src/mlpack/methods/neighbor_search/spill_search_rules.hpp
+++ b/src/mlpack/methods/neighbor_search/spill_search_rules.hpp
@@ -28,13 +28,14 @@ namespace neighbor {
  */
 template<typename StatisticType,
          typename MatType,
+         template<typename HyperplaneMetricType> class HyperplaneType,
          template<typename SplitBoundT, typename SplitMatT> class SplitType,
          typename SortPolicy,
          typename MetricType>
 class NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<MetricType,
-    StatisticType, MatType, SplitType>>
+    StatisticType, MatType, HyperplaneType, SplitType>>
 {
-  typedef tree::SpillTree<MetricType, StatisticType, MatType, SplitType>
+  typedef tree::SpillTree<MetricType, StatisticType, MatType, HyperplaneType, SplitType>
       TreeType;
  public:
   /**
diff --git a/src/mlpack/methods/neighbor_search/spill_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/spill_search_rules_impl.hpp
index 16cdd24..50a61ac 100644
--- a/src/mlpack/methods/neighbor_search/spill_search_rules_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/spill_search_rules_impl.hpp
@@ -16,11 +16,12 @@ namespace neighbor {
 
 template<typename StatisticType,
          typename MatType,
+         template<typename HyperplaneMetricType> class HyperplaneType,
          template<typename SplitBoundT, typename SplitMatT> class SplitType,
          typename SortPolicy,
          typename MetricType>
 NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<MetricType,
-    StatisticType, MatType, SplitType>>::NeighborSearchRules(
+    StatisticType, MatType, HyperplaneType, SplitType>>::NeighborSearchRules(
     const typename TreeType::Mat& referenceSet,
     const typename TreeType::Mat& querySet,
     const size_t k,
@@ -61,11 +62,12 @@ NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<MetricType,
 
 template<typename StatisticType,
          typename MatType,
+         template<typename HyperplaneMetricType> class HyperplaneType,
          template<typename SplitBoundT, typename SplitMatT> class SplitType,
          typename SortPolicy,
          typename MetricType>
 void NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<MetricType,
-    StatisticType, MatType, SplitType>>::GetResults(
+    StatisticType, MatType, HyperplaneType, SplitType>>::GetResults(
     arma::Mat<size_t>& neighbors,
     arma::mat& distances)
 {
@@ -86,12 +88,13 @@ void NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<MetricType,
 
 template<typename StatisticType,
          typename MatType,
+         template<typename HyperplaneMetricType> class HyperplaneType,
          template<typename SplitBoundT, typename SplitMatT> class SplitType,
          typename SortPolicy,
          typename MetricType>
 inline force_inline // Absolutely MUST be inline so optimizations can happen.
 double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<MetricType,
-    StatisticType, MatType, SplitType>>::BaseCase(
+    StatisticType, MatType, HyperplaneType, SplitType>>::BaseCase(
     const size_t queryIndex,
     const size_t referenceIndex)
 {
@@ -111,11 +114,12 @@ double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<MetricType,
 
 template<typename StatisticType,
          typename MatType,
+         template<typename HyperplaneMetricType> class HyperplaneType,
          template<typename SplitBoundT, typename SplitMatT> class SplitType,
          typename SortPolicy,
          typename MetricType>
 inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
-    MetricType, StatisticType, MatType, SplitType>>::Score(
+    MetricType, StatisticType, MatType, HyperplaneType, SplitType>>::Score(
     const size_t queryIndex,
     TreeType& referenceNode)
 {
@@ -126,12 +130,7 @@ inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
 
   if (referenceNode.Parent()->Overlap()) // Defeatist search.
   {
-    const double value = referenceNode.Parent()->SplitValue();
-    const size_t dim = referenceNode.Parent()->SplitDimension();
-    const bool left = &referenceNode == referenceNode.Parent()->Left();
-
-    if ((left && querySet(dim, queryIndex) <= value) ||
-        (!left && querySet(dim, queryIndex) > value))
+    if (referenceNode.HalfSpaceContains(querySet.col(queryIndex)))
       return 0;
     else
       return DBL_MAX;
@@ -149,11 +148,12 @@ inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
 
 template<typename StatisticType,
          typename MatType,
+         template<typename HyperplaneMetricType> class HyperplaneType,
          template<typename SplitBoundT, typename SplitMatT> class SplitType,
          typename SortPolicy,
          typename MetricType>
 inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
-    MetricType, StatisticType, MatType, SplitType>>::Rescore(
+    MetricType, StatisticType, MatType, HyperplaneType, SplitType>>::Rescore(
     const size_t queryIndex,
     TreeType& /* referenceNode */,
     double oldScore) const
@@ -171,11 +171,12 @@ inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
 
 template<typename StatisticType,
          typename MatType,
+         template<typename HyperplaneMetricType> class HyperplaneType,
          template<typename SplitBoundT, typename SplitMatT> class SplitType,
          typename SortPolicy,
          typename MetricType>
 inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
-    MetricType, StatisticType, MatType, SplitType>>::Score(
+    MetricType, StatisticType, MatType, HyperplaneType, SplitType>>::Score(
     TreeType& queryNode,
     TreeType& referenceNode)
 {
@@ -186,12 +187,7 @@ inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
 
   if (referenceNode.Parent()->Overlap()) // Defeatist search.
   {
-    const double value = referenceNode.Parent()->SplitValue();
-    const size_t dim = referenceNode.Parent()->SplitDimension();
-    const bool left = &referenceNode == referenceNode.Parent()->Left();
-
-    if ((left && queryNode.Bound()[dim].Lo() <= value) ||
-        (!left && queryNode.Bound()[dim].Hi() > value))
+    if (referenceNode.HalfSpaceIntersects(queryNode))
       return 0;
     else
       return DBL_MAX;
@@ -312,11 +308,12 @@ inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
 
 template<typename StatisticType,
          typename MatType,
+         template<typename HyperplaneMetricType> class HyperplaneType,
          template<typename SplitBoundT, typename SplitMatT> class SplitType,
          typename SortPolicy,
          typename MetricType>
 inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
-    MetricType, StatisticType, MatType, SplitType>>::Rescore(
+    MetricType, StatisticType, MatType, HyperplaneType, SplitType>>::Rescore(
     TreeType& queryNode,
     TreeType& /* referenceNode */,
     const double oldScore) const
@@ -337,11 +334,12 @@ inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
 // it.
 template<typename StatisticType,
          typename MatType,
+         template<typename HyperplaneMetricType> class HyperplaneType,
          template<typename SplitBoundT, typename SplitMatT> class SplitType,
          typename SortPolicy,
          typename MetricType>
 inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
-    MetricType, StatisticType, MatType, SplitType>>::
+    MetricType, StatisticType, MatType, HyperplaneType, SplitType>>::
     CalculateBound(TreeType& queryNode) const
 {
   // This is an adapted form of the B(N_q) function in the paper
@@ -415,11 +413,13 @@ inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
  */
 template<typename StatisticType,
          typename MatType,
+         template<typename HyperplaneMetricType> class HyperplaneType,
          template<typename SplitBoundT, typename SplitMatT> class SplitType,
          typename SortPolicy,
          typename MetricType>
 inline void NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
-    MetricType, StatisticType, MatType, SplitType>>::InsertNeighbor(
+    MetricType, StatisticType, MatType, HyperplaneType, SplitType>>::
+    InsertNeighbor(
     const size_t queryIndex,
     const size_t neighbor,
     const double distance)




More information about the mlpack-git mailing list