[mlpack-git] master: Refactor to handle internally-copying trees correctly. (45d9117)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Jul 29 16:42:29 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/f8ceffae0613b350f4d6bdd46c6c8633a40b4897...6ee21879488fe98612a4619b17f8b51e8da5215b
>---------------------------------------------------------------
commit 45d9117361478923883e002d87f573f2be1be58c
Author: ryan <ryan at ratml.org>
Date: Mon Jul 27 00:09:30 2015 -0400
Refactor to handle internally-copying trees correctly.
>---------------------------------------------------------------
45d9117361478923883e002d87f573f2be1be58c
src/mlpack/methods/range_search/range_search.hpp | 15 +++----
.../methods/range_search/range_search_impl.hpp | 48 ++++++----------------
2 files changed, 21 insertions(+), 42 deletions(-)
diff --git a/src/mlpack/methods/range_search/range_search.hpp b/src/mlpack/methods/range_search/range_search.hpp
index 04d3071..2b7d03e 100644
--- a/src/mlpack/methods/range_search/range_search.hpp
+++ b/src/mlpack/methods/range_search/range_search.hpp
@@ -198,18 +198,19 @@ class RangeSearch
std::vector<std::vector<size_t>>& neighbors,
std::vector<std::vector<double>>& distances);
- // Returns a string representation of this object.
+ //! Returns a string representation of this object.
std::string ToString() const;
+ //! Return the reference tree (or NULL if in naive mode).
+ Tree* ReferenceTree() { return referenceTree; }
+
private:
- //! Copy of reference matrix; used when a tree is built internally.
- MatType referenceCopy;
- //! Reference set (data should be accessed using this).
- const MatType& referenceSet;
- //! Reference tree.
- Tree* referenceTree;
//! Mappings to old reference indices (used when this object builds trees).
std::vector<size_t> oldFromNewReferences;
+ //! Reference tree.
+ Tree* referenceTree;
+ //! Reference set (data should be accessed using this).
+ const MatType& referenceSet;
//! If true, this object is responsible for deleting the trees.
bool treeOwner;
diff --git a/src/mlpack/methods/range_search/range_search_impl.hpp b/src/mlpack/methods/range_search/range_search_impl.hpp
index a7ad027..d9325c7 100644
--- a/src/mlpack/methods/range_search/range_search_impl.hpp
+++ b/src/mlpack/methods/range_search/range_search_impl.hpp
@@ -48,32 +48,15 @@ RangeSearch<MetricType, MatType, TreeType>::RangeSearch(
const bool naive,
const bool singleMode,
const MetricType metric) :
- referenceSet((tree::TreeTraits<Tree>::RearrangesDataset && !naive)
- ? referenceCopy : referenceSetIn),
- referenceTree(NULL),
+ referenceTree(naive ? NULL : BuildTree<Tree>(
+ const_cast<MatType&>(referenceSetIn), oldFromNewReferences)),
+ referenceSet(naive ? referenceSetIn : referenceTree->Dataset()),
treeOwner(!naive), // If in naive mode, we are not building any trees.
naive(naive),
singleMode(!naive && singleMode), // Naive overrides single mode.
metric(metric)
{
- // Build the tree.
- Timer::Start("range_search/tree_building");
-
- // If in naive mode, then we do not need to build trees.
- if (!naive)
- {
- // Copy the dataset, if it will be modified during tree building.
- if (tree::TreeTraits<Tree>::RearrangesDataset)
- referenceCopy = referenceSetIn;
-
- // The const_cast is safe; if RearrangesDataset == false, then it'll be
- // casted back to const anyway, and if not, referenceSet points to
- // referenceCopy, which isn't const.
- referenceTree = BuildTree<Tree>(const_cast<MatType&>(referenceSet),
- oldFromNewReferences);
- }
-
- Timer::Stop("range_search/tree_building");
+ // Nothing to do.
}
template<typename MetricType,
@@ -84,8 +67,8 @@ RangeSearch<MetricType, MatType, TreeType>::RangeSearch(
Tree* referenceTree,
const bool singleMode,
const MetricType metric) :
- referenceSet(referenceTree->Dataset()),
referenceTree(referenceTree),
+ referenceSet(referenceTree->Dataset()),
treeOwner(false),
naive(false),
singleMode(singleMode),
@@ -119,16 +102,6 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
// This will hold mappings for query points, if necessary.
std::vector<size_t> oldFromNewQueries;
- // If we will be building a tree and it will modify the query set, make a copy
- // of the dataset.
- MatType queryCopy;
- const bool needsCopy = (!naive && !singleMode &&
- tree::TreeTraits<Tree>::RearrangesDataset);
- if (needsCopy)
- queryCopy = querySet;
-
- const MatType& querySetRef = (needsCopy) ? queryCopy : querySet;
-
// If we have built the trees ourselves, then we will have to map all the
// indices back to their original indices when this computation is finished.
// To avoid extra copies, we will store the unmapped neighbors and distances
@@ -158,11 +131,12 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
// Create the helper object for the traversal.
typedef RangeSearchRules<MetricType, Tree> RuleType;
- RuleType rules(referenceSet, querySetRef, range, *neighborPtr, *distancePtr,
- metric);
if (naive)
{
+ RuleType rules(referenceSet, querySet, range, *neighborPtr, *distancePtr,
+ metric);
+
// The naive brute-force solution.
for (size_t i = 0; i < querySet.n_cols; ++i)
for (size_t j = 0; j < referenceSet.n_cols; ++j)
@@ -171,6 +145,8 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
else if (singleMode)
{
// Create the traverser.
+ RuleType rules(referenceSet, querySet, range, *neighborPtr, *distancePtr,
+ metric);
typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
// Now have it traverse for each point.
@@ -182,12 +158,14 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
// Build the query tree.
Timer::Stop("range_search/computing_neighbors");
Timer::Start("range_search/tree_building");
- Tree* queryTree = BuildTree<Tree>(const_cast<MatType&>(querySetRef),
+ Tree* queryTree = BuildTree<Tree>(const_cast<MatType&>(querySet),
oldFromNewQueries);
Timer::Stop("range_search/tree_building");
Timer::Start("range_search/computing_neighbors");
// Create the traverser.
+ RuleType rules(referenceSet, queryTree->Dataset(), range, *neighborPtr,
+ *distancePtr, metric);
typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
traverser.Traverse(*queryTree, *referenceTree);
More information about the mlpack-git
mailing list