[mlpack-git] master: Modify NSModel to use boost variant. (86a9852)
gitdub at mlpack.org
gitdub at mlpack.org
Mon Jun 13 14:29:07 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/a7e8d3bac60d6aa2717f27e0d4f6c53ff20607f5...779556fed748819a18cc898d9a6f69900740ef23
>---------------------------------------------------------------
commit 86a9852f19daf58e4ae5a3d3ca745deee43d8d16
Author: MarcosPividori <marcos.pividori at gmail.com>
Date: Mon Jun 13 10:31:58 2016 -0300
Modify NSModel to use boost variant.
>---------------------------------------------------------------
86a9852f19daf58e4ae5a3d3ca745deee43d8d16
.../methods/neighbor_search/neighbor_search.hpp | 5 +-
src/mlpack/methods/neighbor_search/ns_model.hpp | 153 ++++++-
.../methods/neighbor_search/ns_model_impl.hpp | 506 +++++++++------------
3 files changed, 345 insertions(+), 319 deletions(-)
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index e84b74c..bf62ae7 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -27,8 +27,7 @@ namespace neighbor /** Neighbor-search routines. These include
* searches. */ {
// Forward declaration.
-template<typename SortPolicy>
-class NSModel;
+class TrainVisitor;
/**
* The NeighborSearch class is a template class for performing distance-based
@@ -308,7 +307,7 @@ class NeighborSearch
bool treeNeedsReset;
//! The NSModel class should have access to internal members.
- friend class NSModel<SortPolicy>;
+ friend class TrainVisitor;
}; // class NeighborSearch
} // namespace neighbor
diff --git a/src/mlpack/methods/neighbor_search/ns_model.hpp b/src/mlpack/methods/neighbor_search/ns_model.hpp
index 9c16199..458bf97 100644
--- a/src/mlpack/methods/neighbor_search/ns_model.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model.hpp
@@ -13,12 +13,24 @@
#include <mlpack/core/tree/binary_space_tree.hpp>
#include <mlpack/core/tree/cover_tree.hpp>
#include <mlpack/core/tree/rectangle_tree.hpp>
-
+#include <boost/variant.hpp>
#include "neighbor_search.hpp"
namespace mlpack {
namespace neighbor {
+template<typename SortPolicy,
+ template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType>
+using NSType = NeighborSearch<SortPolicy,
+ metric::EuclideanDistance,
+ arma::mat,
+ TreeType,
+ TreeType<metric::EuclideanDistance,
+ NeighborSearchStat<SortPolicy>,
+ arma::mat>::template DualTreeTraverser>;
+
template<typename SortPolicy>
struct NSModelName
{
@@ -37,6 +49,121 @@ struct NSModelName<FurthestNeighborSort>
static const std::string Name() { return "furthest_neighbor_search_model"; }
};
+class SearchKVisitor : public boost::static_visitor<void>
+{
+ private:
+ const size_t k;
+ arma::Mat<size_t>& neighbors;
+ arma::mat& distances;
+
+ public:
+ template<typename NSType>
+ void operator()(NSType *ns) const;
+
+ SearchKVisitor(const size_t k,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances);
+};
+
+class SearchVisitor : public boost::static_visitor<void>
+{
+ private:
+ const arma::mat& querySet;
+ const size_t k;
+ arma::Mat<size_t>& neighbors;
+ arma::mat& distances;
+ const size_t leafSize;
+
+ template<typename NSType>
+ void SearchLeaf(NSType *ns) const;
+
+ public:
+ template<typename SortPolicy,
+ template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType>
+ void operator()(NSType<SortPolicy,TreeType> *ns) const;
+
+ template<typename SortPolicy>
+ void operator()(NSType<SortPolicy,tree::KDTree> *ns) const;
+
+ template<typename SortPolicy>
+ void operator()(NSType<SortPolicy,tree::BallTree> *ns) const;
+
+ SearchVisitor(const arma::mat& querySet,
+ const size_t k,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances,
+ const size_t leafSize);
+};
+
+class TrainVisitor : public boost::static_visitor<void>
+{
+ private:
+ arma::mat&& referenceSet;
+ size_t leafSize;
+
+ template<typename NSType>
+ void TrainLeaf(NSType* ns) const;
+
+ public:
+ template<typename SortPolicy,
+ template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType>
+ void operator()(NSType<SortPolicy,TreeType> *ns) const;
+
+ template<typename SortPolicy>
+ void operator()(NSType<SortPolicy,tree::KDTree> *ns) const;
+
+ template<typename SortPolicy>
+ void operator()(NSType<SortPolicy,tree::BallTree> *ns) const;
+
+ TrainVisitor(arma::mat&& referenceSet, const size_t leafSize);
+};
+
+class SingleModeVisitor : public boost::static_visitor<bool&>
+{
+ public:
+ template<typename NSType>
+ bool& operator()(NSType *ns) const;
+};
+
+class NaiveVisitor : public boost::static_visitor<bool&>
+{
+ public:
+ template<typename NSType>
+ bool& operator()(NSType *ns) const;
+};
+
+class ReferenceSetVisitor : public boost::static_visitor<const arma::mat&>
+{
+ public:
+ template<typename NSType>
+ const arma::mat& operator()(NSType *ns) const;
+};
+
+class DeleteVisitor : public boost::static_visitor<void>
+{
+ public:
+ template<typename NSType>
+ void operator()(NSType *ns) const;
+};
+
+template<typename Archive>
+class SerializeVisitor : public boost::static_visitor<void>
+{
+ private:
+ Archive& ar;
+ const std::string& name;
+
+ public:
+ template<typename NSType>
+ void operator()(NSType *ns) const;
+
+ SerializeVisitor(Archive& ar, const std::string& name);
+};
+
template<typename SortPolicy>
class NSModel
{
@@ -59,24 +186,12 @@ class NSModel
bool randomBasis;
arma::mat q;
- template<template<typename TreeMetricType,
- typename TreeStatType,
- typename TreeMatType> class TreeType>
- using NSType = NeighborSearch<SortPolicy,
- metric::EuclideanDistance,
- arma::mat,
- TreeType,
- TreeType<metric::EuclideanDistance,
- NeighborSearchStat<SortPolicy>,
- arma::mat>::template DualTreeTraverser>;
-
- // Only one of these pointers will be non-NULL.
- NSType<tree::KDTree>* kdTreeNS;
- NSType<tree::StandardCoverTree>* coverTreeNS;
- NSType<tree::RTree>* rTreeNS;
- NSType<tree::RStarTree>* rStarTreeNS;
- NSType<tree::BallTree>* ballTreeNS;
- NSType<tree::XTree>* xTreeNS;
+ boost::variant<NSType<SortPolicy, tree::KDTree>*,
+ NSType<SortPolicy, tree::StandardCoverTree>*,
+ NSType<SortPolicy, tree::RTree>*,
+ NSType<SortPolicy, tree::RStarTree>*,
+ NSType<SortPolicy, tree::BallTree>*,
+ NSType<SortPolicy, tree::XTree>*> nSearch;
public:
/**
diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
index e4aa7c1..ea2206e 100644
--- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
@@ -16,6 +16,190 @@
namespace mlpack {
namespace neighbor {
+SearchKVisitor::SearchKVisitor(const size_t k,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances) :
+ k(k),
+ neighbors(neighbors),
+ distances(distances)
+{}
+
+template<typename NSType>
+void SearchKVisitor::operator()(NSType *ns) const
+{
+ if (ns)
+ return ns->Search(k, neighbors, distances);
+ throw std::runtime_error("no neighbor search model initialized");
+}
+
+
+SearchVisitor::SearchVisitor(const arma::mat& querySet,
+ const size_t k,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances,
+ const size_t leafSize) :
+ querySet(querySet),
+ k(k),
+ neighbors(neighbors),
+ distances(distances),
+ leafSize(leafSize)
+{}
+
+template<typename SortPolicy,
+ template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType>
+void SearchVisitor::operator()(NSType<SortPolicy,TreeType> *ns) const
+{
+ if (ns)
+ return ns->Search(querySet, k, neighbors, distances);
+ throw std::runtime_error("no neighbor search model initialized");
+}
+
+template<typename SortPolicy>
+void SearchVisitor::operator()(NSType<SortPolicy,tree::KDTree> *ns) const
+{
+ if (ns)
+ return SearchLeaf(ns);
+ throw std::runtime_error("no neighbor search model initialized");
+}
+
+template<typename SortPolicy>
+void SearchVisitor::operator()(NSType<SortPolicy,tree::BallTree> *ns) const
+{
+ if (ns)
+ return SearchLeaf(ns);
+ throw std::runtime_error("no neighbor search model initialized");
+}
+
+template<typename NSType>
+void SearchVisitor::SearchLeaf(NSType *ns) const
+{
+ if (!ns->Naive() && !ns->SingleMode())
+ {
+ std::vector<size_t> oldFromNewQueries;
+ typename NSType::Tree queryTree(std::move(querySet), oldFromNewQueries,
+ leafSize);
+
+ arma::Mat<size_t> neighborsOut;
+ arma::mat distancesOut;
+ ns->Search(&queryTree, k, neighborsOut, distancesOut);
+
+ // Unmap the query points.
+ distances.set_size(distancesOut.n_rows, distancesOut.n_cols);
+ neighbors.set_size(neighborsOut.n_rows, neighborsOut.n_cols);
+ for (size_t i = 0; i < neighborsOut.n_cols; ++i)
+ {
+ neighbors.col(oldFromNewQueries[i]) = neighborsOut.col(i);
+ distances.col(oldFromNewQueries[i]) = distancesOut.col(i);
+ }
+ }
+ else
+ ns->Search(querySet, k, neighbors, distances);
+}
+
+
+TrainVisitor::TrainVisitor(arma::mat&& referenceSet, const size_t leafSize) :
+ referenceSet(std::move(referenceSet)),
+ leafSize(leafSize)
+{}
+
+template<typename SortPolicy,
+ template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType>
+void TrainVisitor::operator()(NSType<SortPolicy,TreeType> *ns) const
+{
+ if (ns)
+ return ns->Train(std::move(referenceSet));
+ throw std::runtime_error("no neighbor search model initialized");
+}
+
+template<typename SortPolicy>
+void TrainVisitor::operator ()(NSType<SortPolicy,tree::KDTree> *ns) const
+{
+ if (ns)
+ return TrainLeaf(ns);
+ throw std::runtime_error("no neighbor search model initialized");
+}
+
+template<typename SortPolicy>
+void TrainVisitor::operator ()(NSType<SortPolicy,tree::BallTree> *ns) const
+{
+ if (ns)
+ return TrainLeaf(ns);
+ throw std::runtime_error("no neighbor search model initialized");
+}
+
+template<typename NSType>
+void TrainVisitor::TrainLeaf(NSType* ns) const
+{
+ if (ns->Naive())
+ ns->Train(std::move(referenceSet));
+ else
+ {
+ std::vector<size_t> oldFromNewReferences;
+ typename NSType::Tree* tree =
+ new typename NSType::Tree(std::move(referenceSet),
+ oldFromNewReferences, leafSize);
+ ns->Train(tree);
+
+ // Give the model ownership of the tree and the mappings.
+ ns->treeOwner = true;
+ ns->oldFromNewReferences = std::move(oldFromNewReferences);
+ }
+}
+
+
+template<typename NSType>
+bool& SingleModeVisitor::operator()(NSType *ns) const
+{
+ if (ns)
+ return ns->SingleMode();
+ throw std::runtime_error("no neighbor search model initialized");
+}
+
+
+template<typename NSType>
+bool& NaiveVisitor::operator()(NSType *ns) const
+{
+ if (ns)
+ return ns->Naive();
+ throw std::runtime_error("no neighbor search model initialized");
+}
+
+
+template<typename NSType>
+const arma::mat& ReferenceSetVisitor::operator()(NSType *ns) const
+{
+ if (ns)
+ return ns->ReferenceSet();
+ throw std::runtime_error("no neighbor search model initialized");
+}
+
+
+template<typename NSType>
+void DeleteVisitor::operator()(NSType *ns) const
+{
+ if (ns)
+ delete ns;
+}
+
+
+template<typename Archive>
+SerializeVisitor<Archive>::SerializeVisitor(Archive& ar,
+ const std::string& name) :
+ ar(ar),
+ name(name)
+{}
+
+template<typename Archive>
+template<typename NSType>
+void SerializeVisitor<Archive>::operator()(NSType *ns) const
+{
+ ar & data::CreateNVP(ns, name);
+}
+
/**
* Initialize the NSModel with the given type and whether or not a random
* basis should be used.
@@ -23,13 +207,7 @@ namespace neighbor {
template<typename SortPolicy>
NSModel<SortPolicy>::NSModel(TreeTypes treeType, bool randomBasis) :
treeType(treeType),
- randomBasis(randomBasis),
- kdTreeNS(NULL),
- coverTreeNS(NULL),
- rTreeNS(NULL),
- rStarTreeNS(NULL),
- ballTreeNS(NULL),
- xTreeNS(NULL)
+ randomBasis(randomBasis)
{
// Nothing to do.
}
@@ -38,18 +216,7 @@ NSModel<SortPolicy>::NSModel(TreeTypes treeType, bool randomBasis) :
template<typename SortPolicy>
NSModel<SortPolicy>::~NSModel()
{
- if (kdTreeNS)
- delete kdTreeNS;
- if (coverTreeNS)
- delete coverTreeNS;
- if (rTreeNS)
- delete rTreeNS;
- if (rStarTreeNS)
- delete rStarTreeNS;
- if (ballTreeNS)
- delete ballTreeNS;
- if (xTreeNS)
- delete xTreeNS;
+ boost::apply_visitor(DeleteVisitor(), nSearch);
}
//! Serialize the kNN model.
@@ -64,148 +231,43 @@ void NSModel<SortPolicy>::Serialize(Archive& ar,
// This should never happen, but just in case, be clean with memory.
if (Archive::is_loading::value)
- {
- if (kdTreeNS)
- delete kdTreeNS;
- if (coverTreeNS)
- delete coverTreeNS;
- if (rTreeNS)
- delete rTreeNS;
- if (rStarTreeNS)
- delete rStarTreeNS;
- if (ballTreeNS)
- delete ballTreeNS;
- if (xTreeNS)
- delete xTreeNS;
-
- // Set all the pointers to NULL.
- kdTreeNS = NULL;
- coverTreeNS = NULL;
- rTreeNS = NULL;
- rStarTreeNS = NULL;
- ballTreeNS = NULL;
- xTreeNS = NULL;
- }
+ boost::apply_visitor(DeleteVisitor(), nSearch);
// We'll only need to serialize one of the kNN objects, based on the type.
const std::string& name = NSModelName<SortPolicy>::Name();
- switch (treeType)
- {
- case KD_TREE:
- ar & data::CreateNVP(kdTreeNS, name);
- break;
- case COVER_TREE:
- ar & data::CreateNVP(coverTreeNS, name);
- break;
- case R_TREE:
- ar & data::CreateNVP(rTreeNS, name);
- break;
- case R_STAR_TREE:
- ar & data::CreateNVP(rStarTreeNS, name);
- break;
- case BALL_TREE:
- ar & data::CreateNVP(ballTreeNS, name);
- break;
- case X_TREE:
- ar & data::CreateNVP(xTreeNS, name);
- break;
- }
+ SerializeVisitor<Archive> s(ar, name);
+ boost::apply_visitor(s, nSearch);
}
template<typename SortPolicy>
const arma::mat& NSModel<SortPolicy>::Dataset() const
{
- if (kdTreeNS)
- return kdTreeNS->ReferenceSet();
- else if (coverTreeNS)
- return coverTreeNS->ReferenceSet();
- else if (rTreeNS)
- return rTreeNS->ReferenceSet();
- else if (rStarTreeNS)
- return rStarTreeNS->ReferenceSet();
- else if (ballTreeNS)
- return ballTreeNS->ReferenceSet();
- else if (xTreeNS)
- return xTreeNS->ReferenceSet();
-
- throw std::runtime_error("no neighbor search model initialized");
+ return boost::apply_visitor(ReferenceSetVisitor(), nSearch);
}
//! Expose singleMode.
template<typename SortPolicy>
bool NSModel<SortPolicy>::SingleMode() const
{
- if (kdTreeNS)
- return kdTreeNS->SingleMode();
- else if (coverTreeNS)
- return coverTreeNS->SingleMode();
- else if (rTreeNS)
- return rTreeNS->SingleMode();
- else if (rStarTreeNS)
- return rStarTreeNS->SingleMode();
- else if (ballTreeNS)
- return ballTreeNS->SingleMode();
- else if (xTreeNS)
- return xTreeNS->SingleMode();
-
- throw std::runtime_error("no neighbor search model initialized");
+ return boost::apply_visitor(SingleModeVisitor(), nSearch);
}
template<typename SortPolicy>
bool& NSModel<SortPolicy>::SingleMode()
{
- if (kdTreeNS)
- return kdTreeNS->SingleMode();
- else if (coverTreeNS)
- return coverTreeNS->SingleMode();
- else if (rTreeNS)
- return rTreeNS->SingleMode();
- else if (rStarTreeNS)
- return rStarTreeNS->SingleMode();
- else if (ballTreeNS)
- return ballTreeNS->SingleMode();
- else if (xTreeNS)
- return xTreeNS->SingleMode();
-
- throw std::runtime_error("no neighbor search model initialized");
+ return boost::apply_visitor(SingleModeVisitor(), nSearch);
}
template<typename SortPolicy>
bool NSModel<SortPolicy>::Naive() const
{
- if (kdTreeNS)
- return kdTreeNS->Naive();
- else if (coverTreeNS)
- return coverTreeNS->Naive();
- else if (rTreeNS)
- return rTreeNS->Naive();
- else if (rStarTreeNS)
- return rStarTreeNS->Naive();
- else if (ballTreeNS)
- return ballTreeNS->Naive();
- else if (xTreeNS)
- return xTreeNS->Naive();
-
- throw std::runtime_error("no neighbor search model initialized");
+ return boost::apply_visitor(NaiveVisitor(), nSearch);
}
template<typename SortPolicy>
bool& NSModel<SortPolicy>::Naive()
{
- if (kdTreeNS)
- return kdTreeNS->Naive();
- else if (coverTreeNS)
- return coverTreeNS->Naive();
- else if (rTreeNS)
- return rTreeNS->Naive();
- else if (rStarTreeNS)
- return rStarTreeNS->Naive();
- else if (ballTreeNS)
- return ballTreeNS->Naive();
- else if (xTreeNS)
- return xTreeNS->Naive();
-
- throw std::runtime_error("no neighbor search model initialized");
+ return boost::apply_visitor(NaiveVisitor(), nSearch);
}
//! Build the reference tree.
@@ -248,18 +310,7 @@ void NSModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
}
// Clean memory, if necessary.
- if (kdTreeNS)
- delete kdTreeNS;
- if (coverTreeNS)
- delete coverTreeNS;
- if (rTreeNS)
- delete rTreeNS;
- if (rStarTreeNS)
- delete rStarTreeNS;
- if (ballTreeNS)
- delete ballTreeNS;
- if (xTreeNS)
- delete xTreeNS;
+ boost::apply_visitor(DeleteVisitor(), nSearch);
// Do we need to modify the reference set?
if (randomBasis)
@@ -274,69 +325,29 @@ void NSModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
switch (treeType)
{
case KD_TREE:
- // If necessary, build the kd-tree.
- if (naive)
- {
- kdTreeNS = new NSType<tree::KDTree>(std::move(referenceSet), naive,
- singleMode);
- }
- else
- {
- std::vector<size_t> oldFromNewReferences;
- typename NSType<tree::KDTree>::Tree* kdTree =
- new typename NSType<tree::KDTree>::Tree(std::move(referenceSet),
- oldFromNewReferences, leafSize);
- kdTreeNS = new NSType<tree::KDTree>(kdTree, singleMode);
-
- // Give the model ownership of the tree and the mappings.
- kdTreeNS->treeOwner = true;
- kdTreeNS->oldFromNewReferences = std::move(oldFromNewReferences);
- }
-
+ nSearch = new NSType<SortPolicy, tree::KDTree>(naive, singleMode);
break;
case COVER_TREE:
- // If necessary, build the cover tree.
- coverTreeNS = new NSType<tree::StandardCoverTree>(std::move(referenceSet),
- naive, singleMode);
+ nSearch = new NSType<SortPolicy, tree::StandardCoverTree>(naive,
+ singleMode);
break;
case R_TREE:
- // If necessary, build the R tree.
- rTreeNS = new NSType<tree::RTree>(std::move(referenceSet), naive,
- singleMode);
+ nSearch = new NSType<SortPolicy, tree::RTree>(naive, singleMode);
break;
case R_STAR_TREE:
- // If necessary, build the R* tree.
- rStarTreeNS = new NSType<tree::RStarTree>(std::move(referenceSet), naive,
- singleMode);
+ nSearch = new NSType<SortPolicy, tree::RStarTree>(naive, singleMode);
break;
case BALL_TREE:
- // If necessary, build the ball tree.
- if (naive)
- {
- ballTreeNS = new NSType<tree::BallTree>(std::move(referenceSet), naive,
- singleMode);
- }
- else
- {
- std::vector<size_t> oldFromNewReferences;
- typename NSType<tree::BallTree>::Tree* ballTree =
- new typename NSType<tree::BallTree>::Tree(std::move(referenceSet),
- oldFromNewReferences, leafSize);
- ballTreeNS = new NSType<tree::BallTree>(ballTree, singleMode);
-
- // Give the model ownership of the tree and the mappings.
- ballTreeNS->treeOwner = true;
- ballTreeNS->oldFromNewReferences = std::move(oldFromNewReferences);
- }
-
+ nSearch = new NSType<SortPolicy, tree::BallTree>(naive, singleMode);
break;
case X_TREE:
- // If necessary, build the X tree.
- xTreeNS = new NSType<tree::XTree>(std::move(referenceSet), naive,
- singleMode);
+ nSearch = new NSType<SortPolicy, tree::XTree>(naive, singleMode);
break;
}
+ TrainVisitor tn(std::move(referenceSet),leafSize);
+ boost::apply_visitor(tn, nSearch);
+
if (!naive)
{
Timer::Stop("tree_building");
@@ -363,88 +374,8 @@ void NSModel<SortPolicy>::Search(arma::mat&& querySet,
else
Log::Info << "brute-force (naive) search..." << std::endl;
- switch (treeType)
- {
- case KD_TREE:
- if (!kdTreeNS->Naive() && !kdTreeNS->SingleMode())
- {
- // Build a second tree and search.
- Timer::Start("tree_building");
- Log::Info << "Building query tree..." << std::endl;
- std::vector<size_t> oldFromNewQueries;
- typename NSType<tree::KDTree>::Tree queryTree(std::move(querySet),
- oldFromNewQueries, leafSize);
- Log::Info << "Tree built." << std::endl;
- Timer::Stop("tree_building");
-
- arma::Mat<size_t> neighborsOut;
- arma::mat distancesOut;
- kdTreeNS->Search(&queryTree, k, neighborsOut, distancesOut);
-
- // Unmap the query points.
- distances.set_size(distancesOut.n_rows, distancesOut.n_cols);
- neighbors.set_size(neighborsOut.n_rows, neighborsOut.n_cols);
- for (size_t i = 0; i < neighborsOut.n_cols; ++i)
- {
- neighbors.col(oldFromNewQueries[i]) = neighborsOut.col(i);
- distances.col(oldFromNewQueries[i]) = distancesOut.col(i);
- }
- }
- else
- {
- // Search without building a second tree.
- kdTreeNS->Search(querySet, k, neighbors, distances);
- }
- break;
- case COVER_TREE:
- // No mapping necessary.
- coverTreeNS->Search(querySet, k, neighbors, distances);
- break;
- case R_TREE:
- // No mapping necessary.
- rTreeNS->Search(querySet, k, neighbors, distances);
- break;
- case R_STAR_TREE:
- // No mapping necessary.
- rStarTreeNS->Search(querySet, k, neighbors, distances);
- break;
- case BALL_TREE:
- if (!ballTreeNS->Naive() && !ballTreeNS->SingleMode())
- {
- // Build a second tree and search.
- Timer::Start("tree_building");
- Log::Info << "Building query tree..." << std::endl;
- std::vector<size_t> oldFromNewQueries;
- typename NSType<tree::BallTree>::Tree queryTree(std::move(querySet),
- oldFromNewQueries, leafSize);
- Log::Info << "Tree built." << std::endl;
- Timer::Stop("tree_building");
-
- arma::Mat<size_t> neighborsOut;
- arma::mat distancesOut;
- ballTreeNS->Search(&queryTree, k, neighborsOut, distancesOut);
-
- // Unmap the query points.
- distances.set_size(distancesOut.n_rows, distancesOut.n_cols);
- neighbors.set_size(neighborsOut.n_rows, neighborsOut.n_cols);
- for (size_t i = 0; i < neighborsOut.n_cols; ++i)
- {
- neighbors.col(oldFromNewQueries[i]) = neighborsOut.col(i);
- distances.col(oldFromNewQueries[i]) = distancesOut.col(i);
- }
- }
- else
- {
- // Search without building a second tree.
- ballTreeNS->Search(querySet, k, neighbors, distances);
- }
-
- break;
- case X_TREE:
- // No mapping necessary.
- xTreeNS->Search(querySet, k, neighbors, distances);
- break;
- }
+ SearchVisitor search(querySet, k, neighbors, distances, leafSize);
+ boost::apply_visitor(search, nSearch);
}
//! Perform neighbor search.
@@ -461,27 +392,8 @@ void NSModel<SortPolicy>::Search(const size_t k,
else
Log::Info << "brute-force (naive) search..." << std::endl;
- switch (treeType)
- {
- case KD_TREE:
- kdTreeNS->Search(k, neighbors, distances);
- break;
- case COVER_TREE:
- coverTreeNS->Search(k, neighbors, distances);
- break;
- case R_TREE:
- rTreeNS->Search(k, neighbors, distances);
- break;
- case R_STAR_TREE:
- rStarTreeNS->Search(k, neighbors, distances);
- break;
- case BALL_TREE:
- ballTreeNS->Search(k, neighbors, distances);
- break;
- case X_TREE:
- xTreeNS->Search(k, neighbors, distances);
- break;
- }
+ SearchKVisitor search(k, neighbors, distances);
+ boost::apply_visitor(search, nSearch);
}
//! Get the name of the tree type.
More information about the mlpack-git mailing list