[mlpack-git] master: Update NSModel to use SpillSearch class. (0120357)

gitdub at mlpack.org gitdub at mlpack.org
Thu Aug 18 13:39:14 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/0f4b25acd6aaa14294c044874ba6cc0751712baa...0a19d07bd39e6223991976474bc79671ba8aa0f0

>---------------------------------------------------------------

commit 0120357e4ae582e69850d79790e19f6c1bd9c5eb
Author: MarcosPividori <marcos.pividori at gmail.com>
Date:   Thu Jul 28 14:22:00 2016 -0300

    Update NSModel to use SpillSearch class.


>---------------------------------------------------------------

0120357e4ae582e69850d79790e19f6c1bd9c5eb
 src/mlpack/methods/neighbor_search/ns_model.hpp    | 16 +++++---
 .../methods/neighbor_search/ns_model_impl.hpp      | 45 +++++++++++++---------
 2 files changed, 38 insertions(+), 23 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/ns_model.hpp b/src/mlpack/methods/neighbor_search/ns_model.hpp
index 6a31d5b..03855a1 100644
--- a/src/mlpack/methods/neighbor_search/ns_model.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model.hpp
@@ -16,6 +16,7 @@
 #include <mlpack/core/tree/spill_tree.hpp>
 #include <boost/variant.hpp>
 #include "neighbor_search.hpp"
+#include "spill_search.hpp"
 
 namespace mlpack {
 namespace neighbor {
@@ -35,6 +36,11 @@ using NSType = NeighborSearch<SortPolicy,
                                   NeighborSearchStat<SortPolicy>,
                                   arma::mat>::template DualTreeTraverser>;
 
+/**
+ * Type alias for Euclidean spill search (does not depend on SortPolicy).
+ */
+using NSSpillType = SpillSearch<metric::EuclideanDistance, arma::mat>;
+
 template<typename SortPolicy>
 struct NSModelName
 {
@@ -128,8 +134,8 @@ class BiSearchVisitor : public boost::static_visitor<void>
   //! Bichromatic neighbor search on the given NSType specialized for BallTrees.
   void operator()(NSTypeT<tree::BallTree>* ns) const;
 
-  //! Bichromatic neighbor search on the given NSType specialized for SPTrees.
-  void operator()(NSTypeT<tree::SPTree>* ns) const;
+  //! Bichromatic neighbor search on the given NSSpillType (spill trees).
+  void operator()(NSSpillType* ns) const;
 
   //! Construct the BiSearchVisitor.
   BiSearchVisitor(const arma::mat& querySet,
@@ -180,8 +186,8 @@ class TrainVisitor : public boost::static_visitor<void>
   //! Train on the given NSType specialized for BallTrees.
   void operator()(NSTypeT<tree::BallTree>* ns) const;
 
-  //! Train on the given NSType specialized for SPTrees.
-  void operator()(NSTypeT<tree::SPTree>* ns) const;
+  //! Train the given NSSpillType (spill trees).
+  void operator()(NSSpillType* ns) const;
 
   //! Construct the TrainVisitor object with the given reference set, leafSize
   //! for BinarySpaceTrees, and tau for spill trees.
@@ -303,7 +309,7 @@ class NSModel
                  NSType<SortPolicy, tree::HilbertRTree>*,
                  NSType<SortPolicy, tree::RPlusTree>*,
                  NSType<SortPolicy, tree::RPlusPlusTree>*,
-                 NSType<SortPolicy, tree::SPTree>*> nSearch;
+                 NSSpillType*> nSearch;
 
  public:
   /**
diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
index 1bb3e51..d46b7e5 100644
--- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
@@ -73,15 +73,17 @@ void BiSearchVisitor<SortPolicy>::operator()(NSTypeT<tree::BallTree>* ns) const
   throw std::runtime_error("no neighbor search model initialized");
 }
 
-//! Bichromatic neighbor search on the given NSType specialized for SPTrees.
+//! Bichromatic neighbor search on the given NSSpillType (spill trees).
 template<typename SortPolicy>
-void BiSearchVisitor<SortPolicy>::operator()(NSTypeT<tree::SPTree>* ns) const
+void BiSearchVisitor<SortPolicy>::operator()(NSSpillType* ns) const
 {
   if (ns)
   {
     if (!ns->Naive() && !ns->SingleMode())
     {
-      typename NSTypeT<tree::SPTree>::Tree queryTree(std::move(querySet), 0,
+      // For dual-tree search on spill trees, the query tree must be built
+      // without overlap (tau = 0).
+      typename NSSpillType::Tree queryTree(std::move(querySet), 0 /* tau*/,
           leafSize);
       ns->Search(&queryTree, k, neighbors, distances);
     }
@@ -160,9 +162,9 @@ void TrainVisitor<SortPolicy>::operator ()(NSTypeT<tree::BallTree>* ns) const
   throw std::runtime_error("no neighbor search model initialized");
 }
 
-//! Train on the given NSType specialized for SPTrees.
+//! Train the given NSSpillType (spill trees).
 template<typename SortPolicy>
-void TrainVisitor<SortPolicy>::operator ()(NSTypeT<tree::SPTree>* ns) const
+void TrainVisitor<SortPolicy>::operator ()(NSSpillType* ns) const
 {
   if (ns)
   {
@@ -170,11 +172,11 @@ void TrainVisitor<SortPolicy>::operator ()(NSTypeT<tree::SPTree>* ns) const
       ns->Train(std::move(referenceSet));
     else
     {
-      typename NSTypeT<tree::SPTree>::Tree* tree = new typename
-          NSTypeT<tree::SPTree>::Tree(std::move(referenceSet), tau, leafSize);
+      typename NSSpillType::Tree* tree = new typename NSSpillType::Tree(
+          std::move(referenceSet), tau, leafSize);
       ns->Train(tree);
       // Give the model ownership of the tree.
-      ns->treeOwner = true;
+      ns->neighborSearch.treeOwner = true;
     }
   }
   else
@@ -268,26 +270,34 @@ NSModel<SortPolicy>::~NSModel()
 }
 
 /**
- * Non-intrusive serialization for Neighbor Search class. We need this
- * definition because we are going to use the serialize function for boost
- * variant, which will look for a serialize function for its member types.
+ * Non-intrusive serialization for NeighborSearch class. We need this definition
+ * because we are going to use the serialize function for boost variant, which
+ * will look for a serialize function for its member types.
  */
 template<typename Archive,
          typename SortPolicy,
-         typename MetrType,
-         typename MatType,
          template<typename TreeMetricType,
                   typename TreeStatType,
-                  typename TreeMatType> class TreeType,
-         template<typename RuleType> class TraversalType>
+                  typename TreeMatType> class TreeType>
 void serialize(
     Archive& ar,
-    NeighborSearch<SortPolicy, MetrType, MatType, TreeType, TraversalType>& ns,
+    NSType<SortPolicy, TreeType>& ns,
     const unsigned int version)
 {
   ns.Serialize(ar, version);
 }
 
+/**
+ * Non-intrusive serialization for SpillSearch class. We need this definition
+ * because we are going to use the serialize function for boost variant, which
+ * will look for a serialize function for its member types.
+ */
+template<typename Archive>
+void serialize(Archive& ar, NSSpillType& ns, const unsigned int version)
+{
+  ns.Serialize(ar, version);
+}
+
 //! Serialize the kNN model.
 template<typename SortPolicy>
 template<typename Archive>
@@ -444,8 +454,7 @@ void NSModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
           epsilon);
       break;
     case SPILL_TREE:
-      nSearch = new NSType<SortPolicy, tree::SPTree>(naive, singleMode,
-          epsilon);
+      nSearch = new NSSpillType(naive, singleMode, tau, epsilon);
       break;
   }
 




More information about the mlpack-git mailing list