[mlpack-git] master: Refactor for non-modifying TreeTypes. (fa72b45)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Jul 29 16:41:59 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/f8ceffae0613b350f4d6bdd46c6c8633a40b4897...6ee21879488fe98612a4619b17f8b51e8da5215b
>---------------------------------------------------------------
commit fa72b45f3ab9ae758ae3319c869fde7074a8d6a7
Author: ryan <ryan at ratml.org>
Date: Mon Jul 27 15:06:02 2015 -0400
Refactor for non-modifying TreeTypes.
>---------------------------------------------------------------
fa72b45f3ab9ae758ae3319c869fde7074a8d6a7
src/mlpack/methods/rann/ra_search.hpp | 13 +++----
src/mlpack/methods/rann/ra_search_impl.hpp | 54 +++++++++++-------------------
2 files changed, 23 insertions(+), 44 deletions(-)
diff --git a/src/mlpack/methods/rann/ra_search.hpp b/src/mlpack/methods/rann/ra_search.hpp
index 219b33f..59bd081 100644
--- a/src/mlpack/methods/rann/ra_search.hpp
+++ b/src/mlpack/methods/rann/ra_search.hpp
@@ -290,13 +290,12 @@ class RASearch
std::string ToString() const;
private:
- //! Copy of reference dataset (if we need it, because tree building modifies
- //! it).
- MatType referenceCopy;
- //! Reference dataset.
- const MatType& referenceSet;
+ //! Permutations of reference points during tree building.
+ std::vector<size_t> oldFromNewReferences;
//! Pointer to the root of the reference tree.
Tree* referenceTree;
+ //! Reference dataset.
+ const MatType& referenceSet;
//! If true, this object created the trees and is responsible for them.
bool treeOwner;
@@ -320,10 +319,6 @@ class RASearch
//! Instantiation of kernel.
MetricType metric;
-
- //! Permutations of reference points during tree building.
- std::vector<size_t> oldFromNewReferences;
-
}; // class RASearch
} // namespace neighbor
diff --git a/src/mlpack/methods/rann/ra_search_impl.hpp b/src/mlpack/methods/rann/ra_search_impl.hpp
index d44707a..a700869 100644
--- a/src/mlpack/methods/rann/ra_search_impl.hpp
+++ b/src/mlpack/methods/rann/ra_search_impl.hpp
@@ -59,9 +59,9 @@ RASearch(const MatType& referenceSetIn,
const bool firstLeafExact,
const size_t singleSampleLimit,
const MetricType metric) :
- referenceSet((tree::TreeTraits<Tree>::RearrangesDataset && !naive)
- ? referenceCopy : referenceSetIn),
- referenceTree(NULL),
+ referenceTree(naive ? NULL : aux::BuildTree<Tree>(
+ const_cast<MatType&>(referenceSetIn), oldFromNewReferences)),
+ referenceSet(naive ? referenceSetIn : referenceTree->Dataset()),
treeOwner(!naive),
naive(naive),
singleMode(!naive && singleMode), // No single mode if naive.
@@ -72,20 +72,7 @@ RASearch(const MatType& referenceSetIn,
singleSampleLimit(singleSampleLimit),
metric(metric)
{
- // We'll time tree building.
- Timer::Start("tree_building");
-
- if (!naive)
- {
- if (tree::TreeTraits<Tree>::RearrangesDataset)
- referenceCopy = referenceSetIn;
-
- referenceTree = aux::BuildTree<Tree>(const_cast<MatType&>(referenceSet),
- oldFromNewReferences);
- }
-
- // Stop the timer we started above.
- Timer::Stop("tree_building");
+ // Nothing to do.
}
// Construct the object.
@@ -103,8 +90,8 @@ RASearch(Tree* referenceTree,
const bool firstLeafExact,
const size_t singleSampleLimit,
const MetricType metric) :
- referenceSet(referenceTree->Dataset()),
referenceTree(referenceTree),
+ referenceSet(referenceTree->Dataset()),
treeOwner(false),
naive(false),
singleMode(singleMode),
@@ -176,24 +163,14 @@ Search(const MatType& querySet,
distancePtr->set_size(k, querySet.n_cols);
distancePtr->fill(SortPolicy::WorstDistance());
- // If we will be building a tree and it will modify the query set, make a copy
- // of the dataset.
- MatType queryCopy;
- const bool needsCopy = (!naive && !singleMode &&
- tree::TreeTraits<Tree>::RearrangesDataset);
- if (needsCopy)
- queryCopy = querySet;
-
- const MatType& querySetRef = (needsCopy) ? queryCopy : querySet;
-
- // Create the helper object for the tree traversal.
typedef RASearchRules<SortPolicy, MetricType, Tree> RuleType;
- RuleType rules(referenceSet, querySetRef, *neighborPtr, *distancePtr,
- metric, tau, alpha, naive, sampleAtLeaves, firstLeafExact,
- singleSampleLimit, false);
if (naive)
{
+ RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr, metric,
+ tau, alpha, naive, sampleAtLeaves, firstLeafExact,
+ singleSampleLimit, false);
+
// Find how many samples from the reference set we need and sample uniformly
// from the reference set without replacement.
const size_t numSamples = RAUtil::MinimumSamplesReqd(referenceSet.n_cols, k,
@@ -204,12 +181,16 @@ Search(const MatType& querySet,
// Run the base case on each combination of query point and sampled
// reference point.
- for (size_t i = 0; i < querySetRef.n_cols; ++i)
+ for (size_t i = 0; i < querySet.n_cols; ++i)
for (size_t j = 0; j < distinctSamples.n_elem; ++j)
rules.BaseCase(i, (size_t) distinctSamples[j]);
}
else if (singleMode)
{
+ RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr, metric,
+ tau, alpha, naive, sampleAtLeaves, firstLeafExact,
+ singleSampleLimit, false);
+
// If the reference root node is a leaf, then the sampling has already been
// done in the RASearchRules constructor. This happens when naive = true.
if (!referenceTree->IsLeaf())
@@ -220,7 +201,7 @@ Search(const MatType& querySet,
typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
// Now have it traverse for each point.
- for (size_t i = 0; i < querySetRef.n_cols; ++i)
+ for (size_t i = 0; i < querySet.n_cols; ++i)
traverser.Traverse(i, *referenceTree);
Log::Info << "Single-tree traversal complete." << std::endl;
@@ -236,11 +217,14 @@ Search(const MatType& querySet,
// Build the query tree.
Timer::Stop("computing_neighbors");
Timer::Start("tree_building");
- Tree* queryTree = aux::BuildTree<Tree>(const_cast<MatType&>(querySetRef),
+ Tree* queryTree = aux::BuildTree<Tree>(const_cast<MatType&>(querySet),
oldFromNewQueries);
Timer::Stop("tree_building");
Timer::Start("computing_neighbors");
+ RuleType rules(referenceSet, queryTree->Dataset(), *neighborPtr,
+ *distancePtr, metric, tau, alpha, naive, sampleAtLeaves,
+ firstLeafExact, singleSampleLimit, false);
typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
Log::Info << "Query statistic pre-search: "
More information about the mlpack-git mailing list