[mlpack-git] master: Update NSModel to use SpillSearch class. (0120357)
gitdub at mlpack.org
gitdub at mlpack.org
Thu Aug 18 13:39:14 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/0f4b25acd6aaa14294c044874ba6cc0751712baa...0a19d07bd39e6223991976474bc79671ba8aa0f0
>---------------------------------------------------------------
commit 0120357e4ae582e69850d79790e19f6c1bd9c5eb
Author: MarcosPividori <marcos.pividori at gmail.com>
Date: Thu Jul 28 14:22:00 2016 -0300
Update NSModel to use SpillSearch class.
>---------------------------------------------------------------
0120357e4ae582e69850d79790e19f6c1bd9c5eb
src/mlpack/methods/neighbor_search/ns_model.hpp | 16 +++++---
.../methods/neighbor_search/ns_model_impl.hpp | 45 +++++++++++++---------
2 files changed, 38 insertions(+), 23 deletions(-)
diff --git a/src/mlpack/methods/neighbor_search/ns_model.hpp b/src/mlpack/methods/neighbor_search/ns_model.hpp
index 6a31d5b..03855a1 100644
--- a/src/mlpack/methods/neighbor_search/ns_model.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model.hpp
@@ -16,6 +16,7 @@
#include <mlpack/core/tree/spill_tree.hpp>
#include <boost/variant.hpp>
#include "neighbor_search.hpp"
+#include "spill_search.hpp"
namespace mlpack {
namespace neighbor {
@@ -35,6 +36,11 @@ using NSType = NeighborSearch<SortPolicy,
NeighborSearchStat<SortPolicy>,
arma::mat>::template DualTreeTraverser>;
+/**
+ * Type alias (not a template) for Euclidean spill search; no SortPolicy param.
+ */
+using NSSpillType = SpillSearch<metric::EuclideanDistance, arma::mat>;
+
template<typename SortPolicy>
struct NSModelName
{
@@ -128,8 +134,8 @@ class BiSearchVisitor : public boost::static_visitor<void>
//! Bichromatic neighbor search on the given NSType specialized for BallTrees.
void operator()(NSTypeT<tree::BallTree>* ns) const;
- //! Bichromatic neighbor search on the given NSType specialized for SPTrees.
- void operator()(NSTypeT<tree::SPTree>* ns) const;
+ //! Bichromatic neighbor search specialized for SPTrees.
+ void operator()(NSSpillType* ns) const;
//! Construct the BiSearchVisitor.
BiSearchVisitor(const arma::mat& querySet,
@@ -180,8 +186,8 @@ class TrainVisitor : public boost::static_visitor<void>
//! Train on the given NSType specialized for BallTrees.
void operator()(NSTypeT<tree::BallTree>* ns) const;
- //! Train on the given NSType specialized for SPTrees.
- void operator()(NSTypeT<tree::SPTree>* ns) const;
+ //! Train specialized for SPTrees.
+ void operator()(NSSpillType* ns) const;
//! Construct the TrainVisitor object with the given reference set, leafSize
//! for BinarySpaceTrees, and tau for spill trees.
@@ -303,7 +309,7 @@ class NSModel
NSType<SortPolicy, tree::HilbertRTree>*,
NSType<SortPolicy, tree::RPlusTree>*,
NSType<SortPolicy, tree::RPlusPlusTree>*,
- NSType<SortPolicy, tree::SPTree>*> nSearch;
+ NSSpillType*> nSearch;
public:
/**
diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
index 1bb3e51..d46b7e5 100644
--- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
@@ -73,15 +73,17 @@ void BiSearchVisitor<SortPolicy>::operator()(NSTypeT<tree::BallTree>* ns) const
throw std::runtime_error("no neighbor search model initialized");
}
-//! Bichromatic neighbor search on the given NSType specialized for SPTrees.
+//! Bichromatic neighbor search specialized for SPTrees.
template<typename SortPolicy>
-void BiSearchVisitor<SortPolicy>::operator()(NSTypeT<tree::SPTree>* ns) const
+void BiSearchVisitor<SortPolicy>::operator()(NSSpillType* ns) const
{
if (ns)
{
if (!ns->Naive() && !ns->SingleMode())
{
- typename NSTypeT<tree::SPTree>::Tree queryTree(std::move(querySet), 0,
+ // For Dual Tree Search on SpillTrees, the queryTree must be built with
+ // non overlapping (tau = 0).
+ typename NSSpillType::Tree queryTree(std::move(querySet), 0 /* tau*/,
leafSize);
ns->Search(&queryTree, k, neighbors, distances);
}
@@ -160,9 +162,9 @@ void TrainVisitor<SortPolicy>::operator ()(NSTypeT<tree::BallTree>* ns) const
throw std::runtime_error("no neighbor search model initialized");
}
-//! Train on the given NSType specialized for SPTrees.
+//! Train specialized for SPTrees.
template<typename SortPolicy>
-void TrainVisitor<SortPolicy>::operator ()(NSTypeT<tree::SPTree>* ns) const
+void TrainVisitor<SortPolicy>::operator ()(NSSpillType* ns) const
{
if (ns)
{
@@ -170,11 +172,11 @@ void TrainVisitor<SortPolicy>::operator ()(NSTypeT<tree::SPTree>* ns) const
ns->Train(std::move(referenceSet));
else
{
- typename NSTypeT<tree::SPTree>::Tree* tree = new typename
- NSTypeT<tree::SPTree>::Tree(std::move(referenceSet), tau, leafSize);
+ typename NSSpillType::Tree* tree = new typename NSSpillType::Tree(
+ std::move(referenceSet), tau, leafSize);
ns->Train(tree);
// Give the model ownership of the tree.
- ns->treeOwner = true;
+ ns->neighborSearch.treeOwner = true;
}
}
else
@@ -268,26 +270,34 @@ NSModel<SortPolicy>::~NSModel()
}
/**
- * Non-intrusive serialization for Neighbor Search class. We need this
- * definition because we are going to use the serialize function for boost
- * variant, which will look for a serialize function for its member types.
+ * Non-intrusive serialization for NeighborSearch class. We need this definition
+ * because we are going to use the serialize function for boost variant, which
+ * will look for a serialize function for its member types.
*/
template<typename Archive,
typename SortPolicy,
- typename MetrType,
- typename MatType,
template<typename TreeMetricType,
typename TreeStatType,
- typename TreeMatType> class TreeType,
- template<typename RuleType> class TraversalType>
+ typename TreeMatType> class TreeType>
void serialize(
Archive& ar,
- NeighborSearch<SortPolicy, MetrType, MatType, TreeType, TraversalType>& ns,
+ NSType<SortPolicy, TreeType>& ns,
const unsigned int version)
{
ns.Serialize(ar, version);
}
+/**
+ * Non-intrusive serialization for NSSpillType (SpillSearch). Boost variant's
+ * serialize function looks for a serialize function for each member type, so
+ * this overload is required; it simply forwards to ns.Serialize(ar, version).
+ */
+template<typename Archive>
+void serialize(Archive& ar, NSSpillType& ns, const unsigned int version)
+{
+ ns.Serialize(ar, version);
+}
+
//! Serialize the kNN model.
template<typename SortPolicy>
template<typename Archive>
@@ -444,8 +454,7 @@ void NSModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
epsilon);
break;
case SPILL_TREE:
- nSearch = new NSType<SortPolicy, tree::SPTree>(naive, singleMode,
- epsilon);
+ nSearch = new NSSpillType(naive, singleMode, tau, epsilon);
break;
}
More information about the mlpack-git mailing list