[mlpack-git] master: Refactor NeighborSearch internals to deal with the tree holding the dataset internally. (e7890e2)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Jul 10 19:00:12 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/4a97187bbba7ce8a6191b714949dd818ef0f37d2...e5905e62c15d1bcff21e6359b11efcd7ab6d7ca0

>---------------------------------------------------------------

commit e7890e2540ddf84cb6bb5fd798e0ad2487da99da
Author: ryan <ryan at ratml.org>
Date:   Wed Apr 22 18:11:06 2015 -0400

    Refactor NeighborSearch internals to deal with the tree holding the dataset internally.


>---------------------------------------------------------------

e7890e2540ddf84cb6bb5fd798e0ad2487da99da
 .../methods/neighbor_search/neighbor_search.hpp    | 11 ++---
 .../neighbor_search/neighbor_search_impl.hpp       | 56 +++++++---------------
 2 files changed, 21 insertions(+), 46 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index 18111a3..6621b72 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -174,15 +174,12 @@ class NeighborSearch
   size_t& Scores() { return scores; }
 
  private:
-  //! Copy of reference dataset (if we need it, because tree building modifies
-  //! it).
-  typename TreeType::Mat referenceCopy;
-  //! Reference dataset.
-  const typename TreeType::Mat& referenceSet;
-  //! Pointer to the root of the reference tree.
-  TreeType* referenceTree;
   //! Permutations of reference points during tree building.
   std::vector<size_t> oldFromNewReferences;
+  //! Pointer to the root of the reference tree.
+  TreeType* referenceTree;
+  //! Reference to reference dataset.
+  const typename TreeType::Mat& referenceSet;
 
   //! If true, this object created the trees and is responsible for them.
   bool treeOwner;
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index 5a1c3c6..c6dc565 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -18,7 +18,7 @@ namespace neighbor {
 //! Call the tree constructor that does mapping.
 template<typename TreeType>
 TreeType* BuildTree(
-    typename TreeType::Mat& dataset,
+    const typename TreeType::Mat& dataset,
     std::vector<size_t>& oldFromNew,
     typename boost::enable_if_c<
         tree::TreeTraits<TreeType>::RearrangesDataset == true, TreeType*
@@ -46,9 +46,9 @@ NeighborSearch(const typename TreeType::Mat& referenceSetIn,
                const bool naive,
                const bool singleMode,
                const MetricType metric) :
-    referenceSet((tree::TreeTraits<TreeType>::RearrangesDataset && !naive)
-        ? referenceCopy : referenceSetIn),
-    referenceTree(NULL),
+    referenceTree(naive ? NULL :
+        BuildTree<TreeType>(referenceSetIn, oldFromNewReferences)),
+    referenceSet(naive ? referenceSetIn : referenceTree->Dataset()),
     treeOwner(!naive), // False if a tree was passed.  If naive, then no trees.
     naive(naive),
     singleMode(!naive && singleMode), // No single mode if naive.
@@ -56,25 +56,7 @@ NeighborSearch(const typename TreeType::Mat& referenceSetIn,
     baseCases(0),
     scores(0)
 {
-  // Build the tree.
-  Timer::Start("tree_building");
-
-  if (!naive)
-  {
-    // Copy the dataset, if it will be modified during tree building.
-    if (tree::TreeTraits<TreeType>::RearrangesDataset)
-      referenceCopy = referenceSetIn;
-
-    // The const_cast is safe; if RearrangesDataset == false, then it'll be
-    // casted back to const anyway, and if not, referenceSet points to
-    // referenceCopy, which isn't const.
-    referenceTree = BuildTree<TreeType>(
-        const_cast<typename TreeType::Mat&>(referenceSet),
-        oldFromNewReferences);
-  }
-
-  // Stop the timer we started above.
-  Timer::Stop("tree_building");
+  // Nothing to do.
 }
 
 // Construct the object.
@@ -83,8 +65,8 @@ NeighborSearch<SortPolicy, MetricType, TreeType>::NeighborSearch(
     TreeType* referenceTree,
     const bool singleMode,
     const MetricType metric) :
-    referenceSet(referenceTree->Dataset()),
     referenceTree(referenceTree),
+    referenceSet(referenceTree->Dataset()),
     treeOwner(false),
     naive(false),
     singleMode(singleMode),
@@ -142,23 +124,13 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
   distancePtr->set_size(k, querySet.n_cols);
   distancePtr->fill(SortPolicy::WorstDistance());
 
-  // If we will be building a tree and it will modify the query set, make a copy
-  // of the dataset.
-  typename TreeType::Mat queryCopy;
-  const bool needsCopy = (!naive && !singleMode &&
-      tree::TreeTraits<TreeType>::RearrangesDataset);
-  if (needsCopy)
-    queryCopy = querySet;
-
-  const typename TreeType::Mat& querySetRef = (needsCopy) ? queryCopy :
-      querySet;
-
-  // Create the helper object for the tree traversal.
   typedef NeighborSearchRules<SortPolicy, MetricType, TreeType> RuleType;
-  RuleType rules(referenceSet, querySetRef, *neighborPtr, *distancePtr, metric);
 
   if (naive)
   {
+    // Create the helper object for the tree traversal.
+    RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr, metric);
+
     // The naive brute-force traversal.
     for (size_t i = 0; i < querySet.n_cols; ++i)
       for (size_t j = 0; j < referenceSet.n_cols; ++j)
@@ -168,6 +140,9 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
   }
   else if (singleMode)
   {
+    // Create the helper object for the tree traversal.
+    RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr, metric);
+
     // Create the traverser.
     typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
 
@@ -186,11 +161,14 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
     // Build the query tree.
     Timer::Stop("computing_neighbors");
     Timer::Start("tree_building");
-    TreeType* queryTree = BuildTree<TreeType>(
-        const_cast<typename TreeType::Mat&>(querySetRef), oldFromNewQueries);
+    TreeType* queryTree = BuildTree<TreeType>(querySet, oldFromNewQueries);
     Timer::Stop("tree_building");
     Timer::Start("computing_neighbors");
 
+    // Create the helper object for the tree traversal.
+    RuleType rules(referenceSet, queryTree->Dataset(), *neighborPtr,
+        *distancePtr, metric);
+
     // Create the traverser.
     typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
 



More information about the mlpack-git mailing list