[mlpack-git] master: Allow ownership of data matrix. (9ae893e)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Oct 22 11:11:17 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/c81893381e80d4ecae4283cec5fe5264bdf4f677...d1dfaa8e0978e01c240660a3217e68c4fa7c3e0a

>---------------------------------------------------------------

commit 9ae893ee70f006d0e58c3cf647120b381039d29f
Author: Ryan Curtin <ryan at ratml.org>
Date:   Thu Oct 22 15:10:43 2015 +0000

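    Allow RangeSearch to own its reference data matrix (referenceSet becomes a pointer that the destructor deletes when setOwner is true).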


>---------------------------------------------------------------

9ae893ee70f006d0e58c3cf647120b381039d29f
 src/mlpack/methods/range_search/range_search.hpp   |  7 ++--
 .../methods/range_search/range_search_impl.hpp     | 39 +++++++++++++---------
 2 files changed, 28 insertions(+), 18 deletions(-)

diff --git a/src/mlpack/methods/range_search/range_search.hpp b/src/mlpack/methods/range_search/range_search.hpp
index d1a7c02..de95d79 100644
--- a/src/mlpack/methods/range_search/range_search.hpp
+++ b/src/mlpack/methods/range_search/range_search.hpp
@@ -219,11 +219,14 @@ class RangeSearch
   std::vector<size_t> oldFromNewReferences;
   //! Reference tree.
   Tree* referenceTree;
-  //! Reference set (data should be accessed using this).
-  const MatType& referenceSet;
+  //! Reference set (data should be accessed using this).  In some situations we
+  //! may be the owner of this.
+  const MatType* referenceSet;
 
   //! If true, this object is responsible for deleting the trees.
   bool treeOwner;
+  //! If true, we own the reference set.
+  bool setOwner;
 
   //! If true, O(n^2) naive computation is used.
   bool naive;
diff --git a/src/mlpack/methods/range_search/range_search_impl.hpp b/src/mlpack/methods/range_search/range_search_impl.hpp
index c6b2e44..f832af7 100644
--- a/src/mlpack/methods/range_search/range_search_impl.hpp
+++ b/src/mlpack/methods/range_search/range_search_impl.hpp
@@ -51,8 +51,9 @@ RangeSearch<MetricType, MatType, TreeType>::RangeSearch(
     const MetricType metric) :
     referenceTree(naive ? NULL : BuildTree<Tree>(
         const_cast<MatType&>(referenceSetIn), oldFromNewReferences)),
-    referenceSet(naive ? referenceSetIn : referenceTree->Dataset()),
+    referenceSet(naive ? &referenceSetIn : &referenceTree->Dataset()),
     treeOwner(!naive), // If in naive mode, we are not building any trees.
+    setOwner(false),
     naive(naive),
     singleMode(!naive && singleMode), // Naive overrides single mode.
     metric(metric),
@@ -72,8 +73,9 @@ RangeSearch<MetricType, MatType, TreeType>::RangeSearch(
     const bool singleMode,
     const MetricType metric) :
     referenceTree(referenceTree),
-    referenceSet(referenceTree->Dataset()),
+    referenceSet(&referenceTree->Dataset()),
     treeOwner(false),
+    setOwner(false),
     naive(false),
     singleMode(singleMode),
     metric(metric),
@@ -92,6 +94,8 @@ RangeSearch<MetricType, MatType, TreeType>::~RangeSearch()
 {
   if (treeOwner && referenceTree)
     delete referenceTree;
+  if (setOwner && referenceSet)
+    delete referenceSet;
 }
 
 template<typename MetricType,
@@ -146,12 +150,12 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
 
   if (naive)
   {
-    RuleType rules(referenceSet, querySet, range, *neighborPtr, *distancePtr,
+    RuleType rules(*referenceSet, querySet, range, *neighborPtr, *distancePtr,
         metric);
 
     // The naive brute-force solution.
     for (size_t i = 0; i < querySet.n_cols; ++i)
-      for (size_t j = 0; j < referenceSet.n_cols; ++j)
+      for (size_t j = 0; j < referenceSet->n_cols; ++j)
         rules.BaseCase(i, j);
 
     baseCases += (querySet.n_cols * referenceSet->n_cols);
@@ -159,7 +163,7 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
   else if (singleMode)
   {
     // Create the traverser.
-    RuleType rules(referenceSet, querySet, range, *neighborPtr, *distancePtr,
+    RuleType rules(*referenceSet, querySet, range, *neighborPtr, *distancePtr,
         metric);
     typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
 
@@ -181,7 +185,7 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
     Timer::Start("range_search/computing_neighbors");
 
     // Create the traverser.
-    RuleType rules(referenceSet, queryTree->Dataset(), range, *neighborPtr,
+    RuleType rules(*referenceSet, queryTree->Dataset(), range, *neighborPtr,
         *distancePtr, metric);
     typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
 
@@ -298,7 +302,7 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
 
   // Create the helper object for the traversal.
   typedef RangeSearchRules<MetricType, Tree> RuleType;
-  RuleType rules(referenceSet, queryTree->Dataset(), range, *neighborPtr,
+  RuleType rules(*referenceSet, queryTree->Dataset(), range, *neighborPtr,
       distances, metric);
 
   // Create the traverser.
@@ -355,21 +359,24 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
 
   // Resize each vector.
   neighborPtr->clear(); // Just in case there was anything in it.
-  neighborPtr->resize(referenceSet.n_cols);
+  neighborPtr->resize(referenceSet->n_cols);
   distancePtr->clear();
-  distancePtr->resize(referenceSet.n_cols);
+  distancePtr->resize(referenceSet->n_cols);
 
   // Create the helper object for the traversal.
   typedef RangeSearchRules<MetricType, Tree> RuleType;
-  RuleType rules(referenceSet, referenceSet, range, *neighborPtr, *distancePtr,
-      metric, true /* don't return the query point in the results */);
+  RuleType rules(*referenceSet, *referenceSet, range, *neighborPtr,
+      *distancePtr, metric, true /* don't return the query in the results */);
 
   if (naive)
   {
     // The naive brute-force solution.
-    for (size_t i = 0; i < referenceSet.n_cols; ++i)
-      for (size_t j = 0; j < referenceSet.n_cols; ++j)
+    for (size_t i = 0; i < referenceSet->n_cols; ++i)
+      for (size_t j = 0; j < referenceSet->n_cols; ++j)
         rules.BaseCase(i, j);
+
+    baseCases = (referenceSet->n_cols * referenceSet->n_cols);
+    scores = 0;
   }
   else if (singleMode)
   {
@@ -377,7 +384,7 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
     typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
 
     // Now have it traverse for each point.
-    for (size_t i = 0; i < referenceSet.n_cols; ++i)
+    for (size_t i = 0; i < referenceSet->n_cols; ++i)
       traverser.Traverse(i, *referenceTree);
 
     baseCases = rules.BaseCases();
@@ -400,9 +407,9 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
   if (treeOwner && tree::TreeTraits<Tree>::RearrangesDataset)
   {
     neighbors.clear();
-    neighbors.resize(referenceSet.n_cols);
+    neighbors.resize(referenceSet->n_cols);
     distances.clear();
-    distances.resize(referenceSet.n_cols);
+    distances.resize(referenceSet->n_cols);
 
     for (size_t i = 0; i < distances.size(); i++)
     {



More information about the mlpack-git mailing list