[mlpack-git] master: Add support for rvalue references when setting a given reference tree in NeighborSearch class. (260f711)

gitdub at mlpack.org gitdub at mlpack.org
Tue Aug 23 15:58:25 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/1148f1652e139c9037eb3813550090313d089a30...a8a8a1381b529a01420de6e792a4a1e7bd58a626

>---------------------------------------------------------------

commit 260f711f2cdf69880355fea905b2e44c2dfbf54b
Author: MarcosPividori <marcos.pividori at gmail.com>
Date:   Thu Jul 28 18:26:38 2016 -0300

    Add support for rvalue references when setting a given reference tree in
    NeighborSearch class.


>---------------------------------------------------------------

260f711f2cdf69880355fea905b2e44c2dfbf54b
 .../methods/neighbor_search/neighbor_search.hpp    | 46 ++++++++++++++++++++--
 .../neighbor_search/neighbor_search_impl.hpp       | 37 ++++++++++++++++-
 .../methods/neighbor_search/ns_model_impl.hpp      | 16 +++-----
 3 files changed, 84 insertions(+), 15 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index e3871f3..f9806af 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -234,9 +234,10 @@ class NeighborSearch
    *
    * Deprecated. Will be removed in mlpack 3.0.0.
    *
-   * There is no copying of the data matrices in this constructor (because
-   * tree-building is not necessary), so this is the constructor to use when
-   * copies absolutely must be avoided.
+   * This method won't take ownership of the given tree. There is no copying of
+   * the data matrices in this constructor (because tree-building is not
+   * necessary), so this is the constructor to use when copies absolutely must
+   * be avoided.
    *
    * @note
    * Mapping the points of the matrix back to their original indices is not done
@@ -257,6 +258,36 @@ class NeighborSearch
                  const MetricType metric = MetricType());
 
   /**
+   * Initialize the NeighborSearch object with the given pre-constructed
+   * reference tree (this is the tree built on the points that will be
+   * searched).  Optionally, choose to use single-tree mode.  Naive mode is not
+   * available as an option for this constructor.  Additionally, an instantiated
+   * distance metric can be given, for cases where the distance metric holds
+   * data.
+   *
+   * This method will take ownership of the given tree. There is no copying of
+   * the data matrices (because tree-building is not necessary), so this is the
+   * constructor to use when copies absolutely must be avoided.
+   *
+   * @note
+   * Mapping the points of the matrix back to their original indices is not done
+   * when this constructor is used, so if the tree type you are using maps
+   * points (like BinarySpaceTree), then you will have to perform the re-mapping
+   * manually.
+   * @endnote
+   *
+   * @param referenceTree Pre-built tree for reference points.
+   * @param singleMode Whether single-tree computation should be used (as
+   *      opposed to dual-tree computation).
+   * @param epsilon Relative approximate error (non-negative).
+   * @param metric Instantiated distance metric.
+   */
+  NeighborSearch(Tree&& referenceTree,
+                 const bool singleMode = false,
+                 const double epsilon = 0,
+                 const MetricType metric = MetricType());
+
+  /**
    * Create a NeighborSearch object without any reference data.  If Search() is
    * called before a reference set is set with Train(), an exception will be
    * thrown.
@@ -309,6 +340,15 @@ class NeighborSearch
   void Train(Tree* referenceTree);
 
   /**
+   * Set the reference tree to a new reference tree.
+   *
+   * This method will take ownership of the given tree.
+   *
+   * @param referenceTree Pre-built tree for reference points.
+   */
+  void Train(Tree&& referenceTree);
+
+  /**
    * For each point in the query set, compute the nearest neighbors and store
    * the output in the given matrices.  The matrices will be set to the size of
    * n columns by k rows, where n is the number of points in the query dataset
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index ac06418..24fa9b2 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -314,6 +314,26 @@ SingleTreeTraversalType>::NeighborSearch(Tree* referenceTree,
     throw std::invalid_argument("epsilon must be non-negative");
 }
 
+// Construct the object.
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType,
+         template<typename> class DualTreeTraversalType,
+         template<typename> class SingleTreeTraversalType>
+NeighborSearch<SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType,
+SingleTreeTraversalType>::NeighborSearch(Tree&& referenceTree,
+                                         const bool singleMode,
+                                         const double epsilon,
+                                         const MetricType metric) :
+    NeighborSearch(new Tree(std::move(referenceTree)), singleMode, epsilon,
+        metric)
+{
+  treeOwner = true;
+}
+
 // Construct the object without a reference dataset.
 template<typename SortPolicy,
          typename MetricType,
@@ -480,7 +500,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Train(Tree* referenceTree)
     throw std::invalid_argument("cannot train on given reference tree when "
         "naive search (without trees) is desired");
 
-  if (treeOwner && referenceTree)
+  if (treeOwner && this->referenceTree)
     delete this->referenceTree;
   if (setOwner && referenceSet)
     delete this->referenceSet;
@@ -491,6 +511,21 @@ DualTreeTraversalType, SingleTreeTraversalType>::Train(Tree* referenceTree)
   setOwner = false;
 }
 
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType,
+         template<typename> class DualTreeTraversalType,
+         template<typename> class SingleTreeTraversalType>
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
+DualTreeTraversalType, SingleTreeTraversalType>::Train(Tree&& referenceTree)
+{
+  Train(new Tree(std::move(referenceTree)));
+  treeOwner = true;
+}
+
 /**
  * Computes the best neighbors and stores them in resultingNeighbors and
  * distances.
diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
index 1591730..5c16bca 100644
--- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
@@ -176,11 +176,8 @@ void TrainVisitor<SortPolicy>::operator ()(SpillKNN* ns) const
       ns->Train(std::move(referenceSet));
     else
     {
-      typename SpillKNN::Tree* tree = new typename SpillKNN::Tree(
-          std::move(referenceSet), tau, leafSize, rho);
-      ns->Train(tree);
-      // Give the model ownership of the tree.
-      ns->treeOwner = true;
+      typename SpillKNN::Tree tree(std::move(referenceSet), tau, leafSize, rho);
+      ns->Train(std::move(tree));
     }
   }
   else
@@ -197,13 +194,10 @@ void TrainVisitor<SortPolicy>::TrainLeaf(NSType* ns) const
   else
   {
     std::vector<size_t> oldFromNewReferences;
-    typename NSType::Tree* tree =
-        new typename NSType::Tree(std::move(referenceSet),
+    typename NSType::Tree referenceTree(std::move(referenceSet),
         oldFromNewReferences, leafSize);
-    ns->Train(tree);
-
-    // Give the model ownership of the tree and the mappings.
-    ns->treeOwner = true;
+    ns->Train(std::move(referenceTree));
+    // Set the mappings.
     ns->oldFromNewReferences = std::move(oldFromNewReferences);
   }
 }




More information about the mlpack-git mailing list