[mlpack-git] master: Allow ownership of data matrix. (9ae893e)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Oct 22 11:11:17 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/c81893381e80d4ecae4283cec5fe5264bdf4f677...d1dfaa8e0978e01c240660a3217e68c4fa7c3e0a
>---------------------------------------------------------------
commit 9ae893ee70f006d0e58c3cf647120b381039d29f
Author: Ryan Curtin <ryan at ratml.org>
Date: Thu Oct 22 15:10:43 2015 +0000
Allow ownership of data matrix.
>---------------------------------------------------------------
9ae893ee70f006d0e58c3cf647120b381039d29f
src/mlpack/methods/range_search/range_search.hpp | 7 ++--
.../methods/range_search/range_search_impl.hpp | 39 +++++++++++++---------
2 files changed, 28 insertions(+), 18 deletions(-)
diff --git a/src/mlpack/methods/range_search/range_search.hpp b/src/mlpack/methods/range_search/range_search.hpp
index d1a7c02..de95d79 100644
--- a/src/mlpack/methods/range_search/range_search.hpp
+++ b/src/mlpack/methods/range_search/range_search.hpp
@@ -219,11 +219,14 @@ class RangeSearch
std::vector<size_t> oldFromNewReferences;
//! Reference tree.
Tree* referenceTree;
- //! Reference set (data should be accessed using this).
- const MatType& referenceSet;
+ //! Reference set (data should be accessed using this). In some situations we
+ //! may be the owner of this.
+ const MatType* referenceSet;
//! If true, this object is responsible for deleting the trees.
bool treeOwner;
+ //! If true, we own the reference set.
+ bool setOwner;
//! If true, O(n^2) naive computation is used.
bool naive;
diff --git a/src/mlpack/methods/range_search/range_search_impl.hpp b/src/mlpack/methods/range_search/range_search_impl.hpp
index c6b2e44..f832af7 100644
--- a/src/mlpack/methods/range_search/range_search_impl.hpp
+++ b/src/mlpack/methods/range_search/range_search_impl.hpp
@@ -51,8 +51,9 @@ RangeSearch<MetricType, MatType, TreeType>::RangeSearch(
const MetricType metric) :
referenceTree(naive ? NULL : BuildTree<Tree>(
const_cast<MatType&>(referenceSetIn), oldFromNewReferences)),
- referenceSet(naive ? referenceSetIn : referenceTree->Dataset()),
+ referenceSet(naive ? &referenceSetIn : &referenceTree->Dataset()),
treeOwner(!naive), // If in naive mode, we are not building any trees.
+ setOwner(false),
naive(naive),
singleMode(!naive && singleMode), // Naive overrides single mode.
metric(metric),
@@ -72,8 +73,9 @@ RangeSearch<MetricType, MatType, TreeType>::RangeSearch(
const bool singleMode,
const MetricType metric) :
referenceTree(referenceTree),
- referenceSet(referenceTree->Dataset()),
+ referenceSet(&referenceTree->Dataset()),
treeOwner(false),
+ setOwner(false),
naive(false),
singleMode(singleMode),
metric(metric),
@@ -92,6 +94,8 @@ RangeSearch<MetricType, MatType, TreeType>::~RangeSearch()
{
if (treeOwner && referenceTree)
delete referenceTree;
+ if (setOwner && referenceSet)
+ delete referenceSet;
}
template<typename MetricType,
@@ -146,12 +150,12 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
if (naive)
{
- RuleType rules(referenceSet, querySet, range, *neighborPtr, *distancePtr,
+ RuleType rules(*referenceSet, querySet, range, *neighborPtr, *distancePtr,
metric);
// The naive brute-force solution.
for (size_t i = 0; i < querySet.n_cols; ++i)
- for (size_t j = 0; j < referenceSet.n_cols; ++j)
+ for (size_t j = 0; j < referenceSet->n_cols; ++j)
rules.BaseCase(i, j);
baseCases += (querySet.n_cols * referenceSet->n_cols);
@@ -159,7 +163,7 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
else if (singleMode)
{
// Create the traverser.
- RuleType rules(referenceSet, querySet, range, *neighborPtr, *distancePtr,
+ RuleType rules(*referenceSet, querySet, range, *neighborPtr, *distancePtr,
metric);
typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
@@ -181,7 +185,7 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
Timer::Start("range_search/computing_neighbors");
// Create the traverser.
- RuleType rules(referenceSet, queryTree->Dataset(), range, *neighborPtr,
+ RuleType rules(*referenceSet, queryTree->Dataset(), range, *neighborPtr,
*distancePtr, metric);
typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
@@ -298,7 +302,7 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
// Create the helper object for the traversal.
typedef RangeSearchRules<MetricType, Tree> RuleType;
- RuleType rules(referenceSet, queryTree->Dataset(), range, *neighborPtr,
+ RuleType rules(*referenceSet, queryTree->Dataset(), range, *neighborPtr,
distances, metric);
// Create the traverser.
@@ -355,21 +359,24 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
// Resize each vector.
neighborPtr->clear(); // Just in case there was anything in it.
- neighborPtr->resize(referenceSet.n_cols);
+ neighborPtr->resize(referenceSet->n_cols);
distancePtr->clear();
- distancePtr->resize(referenceSet.n_cols);
+ distancePtr->resize(referenceSet->n_cols);
// Create the helper object for the traversal.
typedef RangeSearchRules<MetricType, Tree> RuleType;
- RuleType rules(referenceSet, referenceSet, range, *neighborPtr, *distancePtr,
- metric, true /* don't return the query point in the results */);
+ RuleType rules(*referenceSet, *referenceSet, range, *neighborPtr,
+ *distancePtr, metric, true /* don't return the query in the results */);
if (naive)
{
// The naive brute-force solution.
- for (size_t i = 0; i < referenceSet.n_cols; ++i)
- for (size_t j = 0; j < referenceSet.n_cols; ++j)
+ for (size_t i = 0; i < referenceSet->n_cols; ++i)
+ for (size_t j = 0; j < referenceSet->n_cols; ++j)
rules.BaseCase(i, j);
+
+ baseCases = (referenceSet->n_cols * referenceSet->n_cols);
+ scores = 0;
}
else if (singleMode)
{
@@ -377,7 +384,7 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
// Now have it traverse for each point.
- for (size_t i = 0; i < referenceSet.n_cols; ++i)
+ for (size_t i = 0; i < referenceSet->n_cols; ++i)
traverser.Traverse(i, *referenceTree);
baseCases = rules.BaseCases();
@@ -400,9 +407,9 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
if (treeOwner && tree::TreeTraits<Tree>::RearrangesDataset)
{
neighbors.clear();
- neighbors.resize(referenceSet.n_cols);
+ neighbors.resize(referenceSet->n_cols);
distances.clear();
- distances.resize(referenceSet.n_cols);
+ distances.resize(referenceSet->n_cols);
for (size_t i = 0; i < distances.size(); i++)
{
More information about the mlpack-git mailing list