[mlpack-git] master: Refactor for non-modifying TreeTypes. (fa72b45)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Jul 29 16:41:59 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/f8ceffae0613b350f4d6bdd46c6c8633a40b4897...6ee21879488fe98612a4619b17f8b51e8da5215b

>---------------------------------------------------------------

commit fa72b45f3ab9ae758ae3319c869fde7074a8d6a7
Author: ryan <ryan at ratml.org>
Date:   Mon Jul 27 15:06:02 2015 -0400

    Refactor for non-modifying TreeTypes.


>---------------------------------------------------------------

fa72b45f3ab9ae758ae3319c869fde7074a8d6a7
 src/mlpack/methods/rann/ra_search.hpp      | 13 +++----
 src/mlpack/methods/rann/ra_search_impl.hpp | 54 +++++++++++-------------------
 2 files changed, 23 insertions(+), 44 deletions(-)

diff --git a/src/mlpack/methods/rann/ra_search.hpp b/src/mlpack/methods/rann/ra_search.hpp
index 219b33f..59bd081 100644
--- a/src/mlpack/methods/rann/ra_search.hpp
+++ b/src/mlpack/methods/rann/ra_search.hpp
@@ -290,13 +290,12 @@ class RASearch
   std::string ToString() const;
 
  private:
-  //! Copy of reference dataset (if we need it, because tree building modifies
-  //! it).
-  MatType referenceCopy;
-  //! Reference dataset.
-  const MatType& referenceSet;
+  //! Permutations of reference points during tree building.
+  std::vector<size_t> oldFromNewReferences;
   //! Pointer to the root of the reference tree.
   Tree* referenceTree;
+  //! Reference dataset.
+  const MatType& referenceSet;
 
   //! If true, this object created the trees and is responsible for them.
   bool treeOwner;
@@ -320,10 +319,6 @@ class RASearch
 
   //! Instantiation of kernel.
   MetricType metric;
-
-  //! Permutations of reference points during tree building.
-  std::vector<size_t> oldFromNewReferences;
-
 }; // class RASearch
 
 } // namespace neighbor
diff --git a/src/mlpack/methods/rann/ra_search_impl.hpp b/src/mlpack/methods/rann/ra_search_impl.hpp
index d44707a..a700869 100644
--- a/src/mlpack/methods/rann/ra_search_impl.hpp
+++ b/src/mlpack/methods/rann/ra_search_impl.hpp
@@ -59,9 +59,9 @@ RASearch(const MatType& referenceSetIn,
          const bool firstLeafExact,
          const size_t singleSampleLimit,
          const MetricType metric) :
-    referenceSet((tree::TreeTraits<Tree>::RearrangesDataset && !naive)
-        ? referenceCopy : referenceSetIn),
-    referenceTree(NULL),
+    referenceTree(naive ? NULL : aux::BuildTree<Tree>(
+        const_cast<MatType&>(referenceSetIn), oldFromNewReferences)),
+    referenceSet(naive ? referenceSetIn : referenceTree->Dataset()),
     treeOwner(!naive),
     naive(naive),
     singleMode(!naive && singleMode), // No single mode if naive.
@@ -72,20 +72,7 @@ RASearch(const MatType& referenceSetIn,
     singleSampleLimit(singleSampleLimit),
     metric(metric)
 {
-  // We'll time tree building.
-  Timer::Start("tree_building");
-
-  if (!naive)
-  {
-    if (tree::TreeTraits<Tree>::RearrangesDataset)
-      referenceCopy = referenceSetIn;
-
-    referenceTree = aux::BuildTree<Tree>(const_cast<MatType&>(referenceSet),
-        oldFromNewReferences);
-  }
-
-  // Stop the timer we started above.
-  Timer::Stop("tree_building");
+  // Nothing to do.
 }
 
 // Construct the object.
@@ -103,8 +90,8 @@ RASearch(Tree* referenceTree,
          const bool firstLeafExact,
          const size_t singleSampleLimit,
          const MetricType metric) :
-    referenceSet(referenceTree->Dataset()),
     referenceTree(referenceTree),
+    referenceSet(referenceTree->Dataset()),
     treeOwner(false),
     naive(false),
     singleMode(singleMode),
@@ -176,24 +163,14 @@ Search(const MatType& querySet,
   distancePtr->set_size(k, querySet.n_cols);
   distancePtr->fill(SortPolicy::WorstDistance());
 
-  // If we will be building a tree and it will modify the query set, make a copy
-  // of the dataset.
-  MatType queryCopy;
-  const bool needsCopy = (!naive && !singleMode &&
-      tree::TreeTraits<Tree>::RearrangesDataset);
-  if (needsCopy)
-    queryCopy = querySet;
-
-  const MatType& querySetRef = (needsCopy) ? queryCopy : querySet;
-
-  // Create the helper object for the tree traversal.
   typedef RASearchRules<SortPolicy, MetricType, Tree> RuleType;
-  RuleType rules(referenceSet, querySetRef, *neighborPtr, *distancePtr,
-                 metric, tau, alpha, naive, sampleAtLeaves, firstLeafExact,
-                 singleSampleLimit, false);
 
   if (naive)
   {
+    RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr, metric,
+                   tau, alpha, naive, sampleAtLeaves, firstLeafExact,
+                   singleSampleLimit, false);
+
     // Find how many samples from the reference set we need and sample uniformly
     // from the reference set without replacement.
     const size_t numSamples = RAUtil::MinimumSamplesReqd(referenceSet.n_cols, k,
@@ -204,12 +181,16 @@ Search(const MatType& querySet,
 
     // Run the base case on each combination of query point and sampled
     // reference point.
-    for (size_t i = 0; i < querySetRef.n_cols; ++i)
+    for (size_t i = 0; i < querySet.n_cols; ++i)
       for (size_t j = 0; j < distinctSamples.n_elem; ++j)
         rules.BaseCase(i, (size_t) distinctSamples[j]);
   }
   else if (singleMode)
   {
+    RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr, metric,
+                   tau, alpha, naive, sampleAtLeaves, firstLeafExact,
+                   singleSampleLimit, false);
+
     // If the reference root node is a leaf, then the sampling has already been
     // done in the RASearchRules constructor.  This happens when naive = true.
     if (!referenceTree->IsLeaf())
@@ -220,7 +201,7 @@ Search(const MatType& querySet,
       typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
 
       // Now have it traverse for each point.
-      for (size_t i = 0; i < querySetRef.n_cols; ++i)
+      for (size_t i = 0; i < querySet.n_cols; ++i)
         traverser.Traverse(i, *referenceTree);
 
       Log::Info << "Single-tree traversal complete." << std::endl;
@@ -236,11 +217,14 @@ Search(const MatType& querySet,
     // Build the query tree.
     Timer::Stop("computing_neighbors");
     Timer::Start("tree_building");
-    Tree* queryTree = aux::BuildTree<Tree>(const_cast<MatType&>(querySetRef),
+    Tree* queryTree = aux::BuildTree<Tree>(const_cast<MatType&>(querySet),
         oldFromNewQueries);
     Timer::Stop("tree_building");
     Timer::Start("computing_neighbors");
 
+    RuleType rules(referenceSet, queryTree->Dataset(), *neighborPtr,
+                   *distancePtr, metric, tau, alpha, naive, sampleAtLeaves,
+                   firstLeafExact, singleSampleLimit, false);
     typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
 
     Log::Info << "Query statistic pre-search: "



More information about the mlpack-git mailing list