[mlpack-git] master: Refactor to handle internally-copying trees correctly. (45d9117)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Jul 29 16:42:29 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/f8ceffae0613b350f4d6bdd46c6c8633a40b4897...6ee21879488fe98612a4619b17f8b51e8da5215b

>---------------------------------------------------------------

commit 45d9117361478923883e002d87f573f2be1be58c
Author: ryan <ryan at ratml.org>
Date:   Mon Jul 27 00:09:30 2015 -0400

    Refactor to handle internally-copying trees correctly.


>---------------------------------------------------------------

45d9117361478923883e002d87f573f2be1be58c
 src/mlpack/methods/range_search/range_search.hpp   | 15 +++----
 .../methods/range_search/range_search_impl.hpp     | 48 ++++++----------------
 2 files changed, 21 insertions(+), 42 deletions(-)

diff --git a/src/mlpack/methods/range_search/range_search.hpp b/src/mlpack/methods/range_search/range_search.hpp
index 04d3071..2b7d03e 100644
--- a/src/mlpack/methods/range_search/range_search.hpp
+++ b/src/mlpack/methods/range_search/range_search.hpp
@@ -198,18 +198,19 @@ class RangeSearch
               std::vector<std::vector<size_t>>& neighbors,
               std::vector<std::vector<double>>& distances);
 
-  // Returns a string representation of this object.
+  //! Returns a string representation of this object.
   std::string ToString() const;
 
+  //! Return the reference tree (or NULL if in naive mode).
+  Tree* ReferenceTree() { return referenceTree; }
+
  private:
-  //! Copy of reference matrix; used when a tree is built internally.
-  MatType referenceCopy;
-  //! Reference set (data should be accessed using this).
-  const MatType& referenceSet;
-  //! Reference tree.
-  Tree* referenceTree;
   //! Mappings to old reference indices (used when this object builds trees).
   std::vector<size_t> oldFromNewReferences;
+  //! Reference tree.
+  Tree* referenceTree;
+  //! Reference set (data should be accessed using this).
+  const MatType& referenceSet;
 
   //! If true, this object is responsible for deleting the trees.
   bool treeOwner;
diff --git a/src/mlpack/methods/range_search/range_search_impl.hpp b/src/mlpack/methods/range_search/range_search_impl.hpp
index a7ad027..d9325c7 100644
--- a/src/mlpack/methods/range_search/range_search_impl.hpp
+++ b/src/mlpack/methods/range_search/range_search_impl.hpp
@@ -48,32 +48,15 @@ RangeSearch<MetricType, MatType, TreeType>::RangeSearch(
     const bool naive,
     const bool singleMode,
     const MetricType metric) :
-    referenceSet((tree::TreeTraits<Tree>::RearrangesDataset && !naive)
-        ? referenceCopy : referenceSetIn),
-    referenceTree(NULL),
+    referenceTree(naive ? NULL : BuildTree<Tree>(
+        const_cast<MatType&>(referenceSetIn), oldFromNewReferences)),
+    referenceSet(naive ? referenceSetIn : referenceTree->Dataset()),
     treeOwner(!naive), // If in naive mode, we are not building any trees.
     naive(naive),
     singleMode(!naive && singleMode), // Naive overrides single mode.
     metric(metric)
 {
-  // Build the tree.
-  Timer::Start("range_search/tree_building");
-
-  // If in naive mode, then we do not need to build trees.
-  if (!naive)
-  {
-    // Copy the dataset, if it will be modified during tree building.
-    if (tree::TreeTraits<Tree>::RearrangesDataset)
-      referenceCopy = referenceSetIn;
-
-    // The const_cast is safe; if RearrangesDataset == false, then it'll be
-    // casted back to const anyway, and if not, referenceSet points to
-    // referenceCopy, which isn't const.
-    referenceTree = BuildTree<Tree>(const_cast<MatType&>(referenceSet),
-        oldFromNewReferences);
-  }
-
-  Timer::Stop("range_search/tree_building");
+  // Nothing to do.
 }
 
 template<typename MetricType,
@@ -84,8 +67,8 @@ RangeSearch<MetricType, MatType, TreeType>::RangeSearch(
     Tree* referenceTree,
     const bool singleMode,
     const MetricType metric) :
-    referenceSet(referenceTree->Dataset()),
     referenceTree(referenceTree),
+    referenceSet(referenceTree->Dataset()),
     treeOwner(false),
     naive(false),
     singleMode(singleMode),
@@ -119,16 +102,6 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
   // This will hold mappings for query points, if necessary.
   std::vector<size_t> oldFromNewQueries;
 
-  // If we will be building a tree and it will modify the query set, make a copy
-  // of the dataset.
-  MatType queryCopy;
-  const bool needsCopy = (!naive && !singleMode &&
-      tree::TreeTraits<Tree>::RearrangesDataset);
-  if (needsCopy)
-    queryCopy = querySet;
-
-  const MatType& querySetRef = (needsCopy) ? queryCopy : querySet;
-
   // If we have built the trees ourselves, then we will have to map all the
   // indices back to their original indices when this computation is finished.
   // To avoid extra copies, we will store the unmapped neighbors and distances
@@ -158,11 +131,12 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
 
   // Create the helper object for the traversal.
   typedef RangeSearchRules<MetricType, Tree> RuleType;
-  RuleType rules(referenceSet, querySetRef, range, *neighborPtr, *distancePtr,
-      metric);
 
   if (naive)
   {
+    RuleType rules(referenceSet, querySet, range, *neighborPtr, *distancePtr,
+        metric);
+
     // The naive brute-force solution.
     for (size_t i = 0; i < querySet.n_cols; ++i)
       for (size_t j = 0; j < referenceSet.n_cols; ++j)
@@ -171,6 +145,8 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
   else if (singleMode)
   {
     // Create the traverser.
+    RuleType rules(referenceSet, querySet, range, *neighborPtr, *distancePtr,
+        metric);
     typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
 
     // Now have it traverse for each point.
@@ -182,12 +158,14 @@ void RangeSearch<MetricType, MatType, TreeType>::Search(
     // Build the query tree.
     Timer::Stop("range_search/computing_neighbors");
     Timer::Start("range_search/tree_building");
-    Tree* queryTree = BuildTree<Tree>(const_cast<MatType&>(querySetRef),
+    Tree* queryTree = BuildTree<Tree>(const_cast<MatType&>(querySet),
         oldFromNewQueries);
     Timer::Stop("range_search/tree_building");
     Timer::Start("range_search/computing_neighbors");
 
     // Create the traverser.
+    RuleType rules(referenceSet, queryTree->Dataset(), range, *neighborPtr,
+        *distancePtr, metric);
     typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
 
     traverser.Traverse(*queryTree, *referenceTree);



More information about the mlpack-git mailing list