[mlpack-git] master: Use a pointer to the dataset. (4db02b1)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Mon Nov 2 12:19:16 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/f86acf8be2c01568d8b3dcd2e529ee9f20f7585e...156787dd4f372a7fd740f733127ac200ea2564b7

>---------------------------------------------------------------

commit 4db02b1aac38a14a97d656484fc7fd7a2c295670
Author: ryan <ryan at ratml.org>
Date:   Tue Oct 27 11:57:57 2015 -0400

    Use a pointer to the dataset.
    
    This is in preparation for serialization support.


>---------------------------------------------------------------

4db02b1aac38a14a97d656484fc7fd7a2c295670
 src/mlpack/methods/rann/ra_search.hpp      |  6 ++--
 src/mlpack/methods/rann/ra_search_impl.hpp | 48 ++++++++++++++++--------------
 2 files changed, 30 insertions(+), 24 deletions(-)

diff --git a/src/mlpack/methods/rann/ra_search.hpp b/src/mlpack/methods/rann/ra_search.hpp
index dfa5a72..3cd8714 100644
--- a/src/mlpack/methods/rann/ra_search.hpp
+++ b/src/mlpack/methods/rann/ra_search.hpp
@@ -295,11 +295,13 @@ class RASearch
   std::vector<size_t> oldFromNewReferences;
   //! Pointer to the root of the reference tree.
   Tree* referenceTree;
-  //! Reference dataset.
-  const MatType& referenceSet;
+  //! Reference dataset.  In some situations we may own this dataset.
+  const MatType* referenceSet;
 
   //! If true, this object created the trees and is responsible for them.
   bool treeOwner;
+  //! If true, we are responsible for deleting the dataset.
+  bool setOwner;
 
   //! Indicates if naive random sampling on the set is being used.
   bool naive;
diff --git a/src/mlpack/methods/rann/ra_search_impl.hpp b/src/mlpack/methods/rann/ra_search_impl.hpp
index 132d803..cc96166 100644
--- a/src/mlpack/methods/rann/ra_search_impl.hpp
+++ b/src/mlpack/methods/rann/ra_search_impl.hpp
@@ -62,8 +62,9 @@ RASearch(const MatType& referenceSetIn,
          const MetricType metric) :
     referenceTree(naive ? NULL : aux::BuildTree<Tree>(
         const_cast<MatType&>(referenceSetIn), oldFromNewReferences)),
-    referenceSet(naive ? referenceSetIn : referenceTree->Dataset()),
+    referenceSet(naive ? &referenceSetIn : &referenceTree->Dataset()),
     treeOwner(!naive),
+    setOwner(false),
     naive(naive),
     singleMode(!naive && singleMode), // No single mode if naive.
     tau(tau),
@@ -93,8 +94,9 @@ RASearch(Tree* referenceTree,
          const size_t singleSampleLimit,
          const MetricType metric) :
     referenceTree(referenceTree),
-    referenceSet(referenceTree->Dataset()),
+    referenceSet(&referenceTree->Dataset()),
     treeOwner(false),
+    setOwner(false),
     naive(false),
     singleMode(singleMode),
     tau(tau),
@@ -121,6 +123,8 @@ RASearch<SortPolicy, MetricType, MatType, TreeType>::
 {
   if (treeOwner && referenceTree)
     delete referenceTree;
+  if (setOwner)
+    delete referenceSet;
 }
 
 /**
@@ -171,16 +175,16 @@ Search(const MatType& querySet,
 
   if (naive)
   {
-    RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr, metric,
+    RuleType rules(*referenceSet, querySet, *neighborPtr, *distancePtr, metric,
                    tau, alpha, naive, sampleAtLeaves, firstLeafExact,
                    singleSampleLimit, false);
 
     // Find how many samples from the reference set we need and sample uniformly
     // from the reference set without replacement.
-    const size_t numSamples = RAUtil::MinimumSamplesReqd(referenceSet.n_cols, k,
-        tau, alpha);
+    const size_t numSamples = RAUtil::MinimumSamplesReqd(referenceSet->n_cols,
+        k, tau, alpha);
     arma::uvec distinctSamples;
-    RAUtil::ObtainDistinctSamples(numSamples, referenceSet.n_cols,
+    RAUtil::ObtainDistinctSamples(numSamples, referenceSet->n_cols,
         distinctSamples);
 
     // Run the base case on each combination of query point and sampled
@@ -191,7 +195,7 @@ Search(const MatType& querySet,
   }
   else if (singleMode)
   {
-    RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr, metric,
+    RuleType rules(*referenceSet, querySet, *neighborPtr, *distancePtr, metric,
                    tau, alpha, naive, sampleAtLeaves, firstLeafExact,
                    singleSampleLimit, false);
 
@@ -226,7 +230,7 @@ Search(const MatType& querySet,
     Timer::Stop("tree_building");
     Timer::Start("computing_neighbors");
 
-    RuleType rules(referenceSet, queryTree->Dataset(), *neighborPtr,
+    RuleType rules(*referenceSet, queryTree->Dataset(), *neighborPtr,
                    *distancePtr, metric, tau, alpha, naive, sampleAtLeaves,
                    firstLeafExact, singleSampleLimit, false);
     typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
@@ -340,7 +344,7 @@ void RASearch<SortPolicy, MetricType, MatType, TreeType>::Search(
 
   // Create the helper object for the tree traversal.
   typedef RASearchRules<SortPolicy, MetricType, Tree> RuleType;
-  RuleType rules(referenceSet, queryTree->Dataset(), *neighborPtr, distances,
+  RuleType rules(*referenceSet, queryTree->Dataset(), *neighborPtr, distances,
                  metric, tau, alpha, naive, sampleAtLeaves, firstLeafExact,
                  singleSampleLimit, false);
 
@@ -390,14 +394,14 @@ void RASearch<SortPolicy, MetricType, MatType, TreeType>::Search(
   }
 
   // Initialize results.
-  neighborPtr->set_size(k, referenceSet.n_cols);
+  neighborPtr->set_size(k, referenceSet->n_cols);
   neighborPtr->fill(size_t() - 1);
-  distancePtr->set_size(k, referenceSet.n_cols);
+  distancePtr->set_size(k, referenceSet->n_cols);
   distancePtr->fill(SortPolicy::WorstDistance());
 
   // Create the helper object for the tree traversal.
   typedef RASearchRules<SortPolicy, MetricType, Tree> RuleType;
-  RuleType rules(referenceSet, referenceSet, *neighborPtr, *distancePtr,
+  RuleType rules(*referenceSet, *referenceSet, *neighborPtr, *distancePtr,
                  metric, tau, alpha, naive, sampleAtLeaves, firstLeafExact,
                  singleSampleLimit, true /* sets are the same */);
 
@@ -405,15 +409,15 @@ void RASearch<SortPolicy, MetricType, MatType, TreeType>::Search(
   {
     // Find how many samples from the reference set we need and sample uniformly
     // from the reference set without replacement.
-    const size_t numSamples = RAUtil::MinimumSamplesReqd(referenceSet.n_cols, k,
-        tau, alpha);
+    const size_t numSamples = RAUtil::MinimumSamplesReqd(referenceSet->n_cols,
+        k, tau, alpha);
     arma::uvec distinctSamples;
-    RAUtil::ObtainDistinctSamples(numSamples, referenceSet.n_cols,
+    RAUtil::ObtainDistinctSamples(numSamples, referenceSet->n_cols,
         distinctSamples);
 
     // The naive brute-force solution.
-    for (size_t i = 0; i < referenceSet.n_cols; ++i)
-      for (size_t j = 0; j < referenceSet.n_cols; ++j)
+    for (size_t i = 0; i < referenceSet->n_cols; ++i)
+      for (size_t j = 0; j < referenceSet->n_cols; ++j)
         rules.BaseCase(i, j);
   }
   else if (singleMode)
@@ -422,7 +426,7 @@ void RASearch<SortPolicy, MetricType, MatType, TreeType>::Search(
     typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
 
     // Now have it traverse for each point.
-    for (size_t i = 0; i < referenceSet.n_cols; ++i)
+    for (size_t i = 0; i < referenceSet->n_cols; ++i)
       traverser.Traverse(i, *referenceTree);
   }
   else
@@ -438,8 +442,8 @@ void RASearch<SortPolicy, MetricType, MatType, TreeType>::Search(
   // Do we need to map the reference indices?
   if (treeOwner && tree::TreeTraits<Tree>::RearrangesDataset)
   {
-    neighbors.set_size(k, referenceSet.n_cols);
-    distances.set_size(k, referenceSet.n_cols);
+    neighbors.set_size(k, referenceSet->n_cols);
+    distances.set_size(k, referenceSet->n_cols);
 
     for (size_t i = 0; i < distances.n_cols; ++i)
     {
@@ -486,8 +490,8 @@ std::string RASearch<SortPolicy, MetricType, MatType, TreeType>::ToString()
 {
   std::ostringstream convert;
   convert << "RASearch [" << this << "]" << std::endl;
-  convert << "  referenceSet: " << referenceSet.n_rows << "x"
-      << referenceSet.n_cols << std::endl;
+  convert << "  referenceSet: " << referenceSet->n_rows << "x"
+      << referenceSet->n_cols << std::endl;
 
   convert << "  naive: ";
   if (naive)



More information about the mlpack-git mailing list