[mlpack-git] master: Refactor NeighborSearch internals to deal with the tree holding the dataset internally. (e7890e2)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Fri Jul 10 19:00:12 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/4a97187bbba7ce8a6191b714949dd818ef0f37d2...e5905e62c15d1bcff21e6359b11efcd7ab6d7ca0
>---------------------------------------------------------------
commit e7890e2540ddf84cb6bb5fd798e0ad2487da99da
Author: ryan <ryan at ratml.org>
Date: Wed Apr 22 18:11:06 2015 -0400
Refactor NeighborSearch internals to deal with the tree holding the dataset internally.
>---------------------------------------------------------------
e7890e2540ddf84cb6bb5fd798e0ad2487da99da
.../methods/neighbor_search/neighbor_search.hpp | 11 ++---
.../neighbor_search/neighbor_search_impl.hpp | 56 +++++++---------------
2 files changed, 21 insertions(+), 46 deletions(-)
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index 18111a3..6621b72 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -174,15 +174,12 @@ class NeighborSearch
size_t& Scores() { return scores; }
private:
- //! Copy of reference dataset (if we need it, because tree building modifies
- //! it).
- typename TreeType::Mat referenceCopy;
- //! Reference dataset.
- const typename TreeType::Mat& referenceSet;
- //! Pointer to the root of the reference tree.
- TreeType* referenceTree;
//! Permutations of reference points during tree building.
std::vector<size_t> oldFromNewReferences;
+ //! Pointer to the root of the reference tree.
+ TreeType* referenceTree;
+ //! Reference to reference dataset.
+ const typename TreeType::Mat& referenceSet;
//! If true, this object created the trees and is responsible for them.
bool treeOwner;
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index 5a1c3c6..c6dc565 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -18,7 +18,7 @@ namespace neighbor {
//! Call the tree constructor that does mapping.
template<typename TreeType>
TreeType* BuildTree(
- typename TreeType::Mat& dataset,
+ const typename TreeType::Mat& dataset,
std::vector<size_t>& oldFromNew,
typename boost::enable_if_c<
tree::TreeTraits<TreeType>::RearrangesDataset == true, TreeType*
@@ -46,9 +46,9 @@ NeighborSearch(const typename TreeType::Mat& referenceSetIn,
const bool naive,
const bool singleMode,
const MetricType metric) :
- referenceSet((tree::TreeTraits<TreeType>::RearrangesDataset && !naive)
- ? referenceCopy : referenceSetIn),
- referenceTree(NULL),
+ referenceTree(naive ? NULL :
+ BuildTree<TreeType>(referenceSetIn, oldFromNewReferences)),
+ referenceSet(naive ? referenceSetIn : referenceTree->Dataset()),
treeOwner(!naive), // False if a tree was passed. If naive, then no trees.
naive(naive),
singleMode(!naive && singleMode), // No single mode if naive.
@@ -56,25 +56,7 @@ NeighborSearch(const typename TreeType::Mat& referenceSetIn,
baseCases(0),
scores(0)
{
- // Build the tree.
- Timer::Start("tree_building");
-
- if (!naive)
- {
- // Copy the dataset, if it will be modified during tree building.
- if (tree::TreeTraits<TreeType>::RearrangesDataset)
- referenceCopy = referenceSetIn;
-
- // The const_cast is safe; if RearrangesDataset == false, then it'll be
- // casted back to const anyway, and if not, referenceSet points to
- // referenceCopy, which isn't const.
- referenceTree = BuildTree<TreeType>(
- const_cast<typename TreeType::Mat&>(referenceSet),
- oldFromNewReferences);
- }
-
- // Stop the timer we started above.
- Timer::Stop("tree_building");
+ // Nothing to do.
}
// Construct the object.
@@ -83,8 +65,8 @@ NeighborSearch<SortPolicy, MetricType, TreeType>::NeighborSearch(
TreeType* referenceTree,
const bool singleMode,
const MetricType metric) :
- referenceSet(referenceTree->Dataset()),
referenceTree(referenceTree),
+ referenceSet(referenceTree->Dataset()),
treeOwner(false),
naive(false),
singleMode(singleMode),
@@ -142,23 +124,13 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
distancePtr->set_size(k, querySet.n_cols);
distancePtr->fill(SortPolicy::WorstDistance());
- // If we will be building a tree and it will modify the query set, make a copy
- // of the dataset.
- typename TreeType::Mat queryCopy;
- const bool needsCopy = (!naive && !singleMode &&
- tree::TreeTraits<TreeType>::RearrangesDataset);
- if (needsCopy)
- queryCopy = querySet;
-
- const typename TreeType::Mat& querySetRef = (needsCopy) ? queryCopy :
- querySet;
-
- // Create the helper object for the tree traversal.
typedef NeighborSearchRules<SortPolicy, MetricType, TreeType> RuleType;
- RuleType rules(referenceSet, querySetRef, *neighborPtr, *distancePtr, metric);
if (naive)
{
+ // Create the helper object for the tree traversal.
+ RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr, metric);
+
// The naive brute-force traversal.
for (size_t i = 0; i < querySet.n_cols; ++i)
for (size_t j = 0; j < referenceSet.n_cols; ++j)
@@ -168,6 +140,9 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
}
else if (singleMode)
{
+ // Create the helper object for the tree traversal.
+ RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr, metric);
+
// Create the traverser.
typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
@@ -186,11 +161,14 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
// Build the query tree.
Timer::Stop("computing_neighbors");
Timer::Start("tree_building");
- TreeType* queryTree = BuildTree<TreeType>(
- const_cast<typename TreeType::Mat&>(querySetRef), oldFromNewQueries);
+ TreeType* queryTree = BuildTree<TreeType>(querySet, oldFromNewQueries);
Timer::Stop("tree_building");
Timer::Start("computing_neighbors");
+ // Create the helper object for the tree traversal.
+ RuleType rules(referenceSet, queryTree->Dataset(), *neighborPtr,
+ *distancePtr, metric);
+
// Create the traverser.
typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
More information about the mlpack-git mailing list