[mlpack-git] master: Use a pointer to the dataset. (4db02b1)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Mon Nov 2 12:19:16 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/f86acf8be2c01568d8b3dcd2e529ee9f20f7585e...156787dd4f372a7fd740f733127ac200ea2564b7
>---------------------------------------------------------------
commit 4db02b1aac38a14a97d656484fc7fd7a2c295670
Author: ryan <ryan at ratml.org>
Date: Tue Oct 27 11:57:57 2015 -0400
Use a pointer to the dataset.
This is in preparation for serialization support.
>---------------------------------------------------------------
4db02b1aac38a14a97d656484fc7fd7a2c295670
src/mlpack/methods/rann/ra_search.hpp | 6 ++--
src/mlpack/methods/rann/ra_search_impl.hpp | 48 ++++++++++++++++--------------
2 files changed, 30 insertions(+), 24 deletions(-)
diff --git a/src/mlpack/methods/rann/ra_search.hpp b/src/mlpack/methods/rann/ra_search.hpp
index dfa5a72..3cd8714 100644
--- a/src/mlpack/methods/rann/ra_search.hpp
+++ b/src/mlpack/methods/rann/ra_search.hpp
@@ -295,11 +295,13 @@ class RASearch
std::vector<size_t> oldFromNewReferences;
//! Pointer to the root of the reference tree.
Tree* referenceTree;
- //! Reference dataset.
- const MatType& referenceSet;
+ //! Reference dataset. In some situations we may own this dataset.
+ const MatType* referenceSet;
//! If true, this object created the trees and is responsible for them.
bool treeOwner;
+ //! If true, we are responsible for deleting the dataset.
+ bool setOwner;
//! Indicates if naive random sampling on the set is being used.
bool naive;
diff --git a/src/mlpack/methods/rann/ra_search_impl.hpp b/src/mlpack/methods/rann/ra_search_impl.hpp
index 132d803..cc96166 100644
--- a/src/mlpack/methods/rann/ra_search_impl.hpp
+++ b/src/mlpack/methods/rann/ra_search_impl.hpp
@@ -62,8 +62,9 @@ RASearch(const MatType& referenceSetIn,
const MetricType metric) :
referenceTree(naive ? NULL : aux::BuildTree<Tree>(
const_cast<MatType&>(referenceSetIn), oldFromNewReferences)),
- referenceSet(naive ? referenceSetIn : referenceTree->Dataset()),
+ referenceSet(naive ? &referenceSetIn : &referenceTree->Dataset()),
treeOwner(!naive),
+ setOwner(false),
naive(naive),
singleMode(!naive && singleMode), // No single mode if naive.
tau(tau),
@@ -93,8 +94,9 @@ RASearch(Tree* referenceTree,
const size_t singleSampleLimit,
const MetricType metric) :
referenceTree(referenceTree),
- referenceSet(referenceTree->Dataset()),
+ referenceSet(&referenceTree->Dataset()),
treeOwner(false),
+ setOwner(false),
naive(false),
singleMode(singleMode),
tau(tau),
@@ -121,6 +123,8 @@ RASearch<SortPolicy, MetricType, MatType, TreeType>::
{
if (treeOwner && referenceTree)
delete referenceTree;
+ if (setOwner)
+ delete referenceSet;
}
/**
@@ -171,16 +175,16 @@ Search(const MatType& querySet,
if (naive)
{
- RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr, metric,
+ RuleType rules(*referenceSet, querySet, *neighborPtr, *distancePtr, metric,
tau, alpha, naive, sampleAtLeaves, firstLeafExact,
singleSampleLimit, false);
// Find how many samples from the reference set we need and sample uniformly
// from the reference set without replacement.
- const size_t numSamples = RAUtil::MinimumSamplesReqd(referenceSet.n_cols, k,
- tau, alpha);
+ const size_t numSamples = RAUtil::MinimumSamplesReqd(referenceSet->n_cols,
+ k, tau, alpha);
arma::uvec distinctSamples;
- RAUtil::ObtainDistinctSamples(numSamples, referenceSet.n_cols,
+ RAUtil::ObtainDistinctSamples(numSamples, referenceSet->n_cols,
distinctSamples);
// Run the base case on each combination of query point and sampled
@@ -191,7 +195,7 @@ Search(const MatType& querySet,
}
else if (singleMode)
{
- RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr, metric,
+ RuleType rules(*referenceSet, querySet, *neighborPtr, *distancePtr, metric,
tau, alpha, naive, sampleAtLeaves, firstLeafExact,
singleSampleLimit, false);
@@ -226,7 +230,7 @@ Search(const MatType& querySet,
Timer::Stop("tree_building");
Timer::Start("computing_neighbors");
- RuleType rules(referenceSet, queryTree->Dataset(), *neighborPtr,
+ RuleType rules(*referenceSet, queryTree->Dataset(), *neighborPtr,
*distancePtr, metric, tau, alpha, naive, sampleAtLeaves,
firstLeafExact, singleSampleLimit, false);
typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
@@ -340,7 +344,7 @@ void RASearch<SortPolicy, MetricType, MatType, TreeType>::Search(
// Create the helper object for the tree traversal.
typedef RASearchRules<SortPolicy, MetricType, Tree> RuleType;
- RuleType rules(referenceSet, queryTree->Dataset(), *neighborPtr, distances,
+ RuleType rules(*referenceSet, queryTree->Dataset(), *neighborPtr, distances,
metric, tau, alpha, naive, sampleAtLeaves, firstLeafExact,
singleSampleLimit, false);
@@ -390,14 +394,14 @@ void RASearch<SortPolicy, MetricType, MatType, TreeType>::Search(
}
// Initialize results.
- neighborPtr->set_size(k, referenceSet.n_cols);
+ neighborPtr->set_size(k, referenceSet->n_cols);
neighborPtr->fill(size_t() - 1);
- distancePtr->set_size(k, referenceSet.n_cols);
+ distancePtr->set_size(k, referenceSet->n_cols);
distancePtr->fill(SortPolicy::WorstDistance());
// Create the helper object for the tree traversal.
typedef RASearchRules<SortPolicy, MetricType, Tree> RuleType;
- RuleType rules(referenceSet, referenceSet, *neighborPtr, *distancePtr,
+ RuleType rules(*referenceSet, *referenceSet, *neighborPtr, *distancePtr,
metric, tau, alpha, naive, sampleAtLeaves, firstLeafExact,
singleSampleLimit, true /* sets are the same */);
@@ -405,15 +409,15 @@ void RASearch<SortPolicy, MetricType, MatType, TreeType>::Search(
{
// Find how many samples from the reference set we need and sample uniformly
// from the reference set without replacement.
- const size_t numSamples = RAUtil::MinimumSamplesReqd(referenceSet.n_cols, k,
- tau, alpha);
+ const size_t numSamples = RAUtil::MinimumSamplesReqd(referenceSet->n_cols,
+ k, tau, alpha);
arma::uvec distinctSamples;
- RAUtil::ObtainDistinctSamples(numSamples, referenceSet.n_cols,
+ RAUtil::ObtainDistinctSamples(numSamples, referenceSet->n_cols,
distinctSamples);
// The naive brute-force solution.
- for (size_t i = 0; i < referenceSet.n_cols; ++i)
- for (size_t j = 0; j < referenceSet.n_cols; ++j)
+ for (size_t i = 0; i < referenceSet->n_cols; ++i)
+ for (size_t j = 0; j < referenceSet->n_cols; ++j)
rules.BaseCase(i, j);
}
else if (singleMode)
@@ -422,7 +426,7 @@ void RASearch<SortPolicy, MetricType, MatType, TreeType>::Search(
typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
// Now have it traverse for each point.
- for (size_t i = 0; i < referenceSet.n_cols; ++i)
+ for (size_t i = 0; i < referenceSet->n_cols; ++i)
traverser.Traverse(i, *referenceTree);
}
else
@@ -438,8 +442,8 @@ void RASearch<SortPolicy, MetricType, MatType, TreeType>::Search(
// Do we need to map the reference indices?
if (treeOwner && tree::TreeTraits<Tree>::RearrangesDataset)
{
- neighbors.set_size(k, referenceSet.n_cols);
- distances.set_size(k, referenceSet.n_cols);
+ neighbors.set_size(k, referenceSet->n_cols);
+ distances.set_size(k, referenceSet->n_cols);
for (size_t i = 0; i < distances.n_cols; ++i)
{
@@ -486,8 +490,8 @@ std::string RASearch<SortPolicy, MetricType, MatType, TreeType>::ToString()
{
std::ostringstream convert;
convert << "RASearch [" << this << "]" << std::endl;
- convert << " referenceSet: " << referenceSet.n_rows << "x"
- << referenceSet.n_cols << std::endl;
+ convert << " referenceSet: " << referenceSet->n_rows << "x"
+ << referenceSet->n_cols << std::endl;
convert << " naive: ";
if (naive)
More information about the mlpack-git
mailing list