[mlpack-git] master: Update constructors taking a reference to the reference tree. (250bd45)

gitdub at mlpack.org gitdub at mlpack.org
Tue Aug 23 15:58:25 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/1148f1652e139c9037eb3813550090313d089a30...a8a8a1381b529a01420de6e792a4a1e7bd58a626

>---------------------------------------------------------------

commit 250bd4567e6e486e329e3c0c0f3c50f45ddb65d4
Author: MarcosPividori <marcos.pividori at gmail.com>
Date:   Tue Aug 23 15:02:45 2016 -0300

    Update constructors taking a reference to the reference tree.


>---------------------------------------------------------------

250bd4567e6e486e329e3c0c0f3c50f45ddb65d4
 .../methods/neighbor_search/neighbor_search.hpp    | 122 +++++++++---------
 .../neighbor_search/neighbor_search_impl.hpp       | 140 ++++++++++++---------
 2 files changed, 142 insertions(+), 120 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index c0a40c5..9b8d51a 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -160,6 +160,65 @@ class NeighborSearch
       const MetricType metric = MetricType());
 
   /**
+   * Initialize the NeighborSearch object with a copy of the given
+   * pre-constructed reference tree (this is the tree built on the points that
+   * will be searched).  Optionally, choose to use single-tree mode.  Naive mode
+   * is not available as an option for this constructor.  Additionally, an
+   * instantiated distance metric can be given, for cases where the distance
+   * metric holds data.
+   *
+   * This method will copy the given tree.  You can avoid this copy by using the
+   * constructor that takes an rvalue reference to the tree.
+   *
+   * @note
+   * Mapping the points of the matrix back to their original indices is not done
+   * when this constructor is used, so if the tree type you are using maps
+   * points (like BinarySpaceTree), then you will have to perform the re-mapping
+   * manually.
+   * @endnote
+   *
+   * @param referenceTree Pre-built tree for reference points.
+   * @param mode Neighbor search mode.
+   * @param epsilon Relative approximate error (non-negative).
+   * @param metric Instantiated distance metric.
+   */
+  NeighborSearch(
+      Tree& referenceTree,
+      const NeighborSearchMode mode = DUAL_TREE_MODE,
+      const double epsilon = 0,
+      const MetricType metric = MetricType());
+
+  /**
+   * Initialize the NeighborSearch object with the given pre-constructed
+   * reference tree (this is the tree built on the points that will be
+   * searched).  Optionally, choose to use single-tree mode.  Naive mode is not
+   * available as an option for this constructor.  Additionally, an instantiated
+   * distance metric can be given, for cases where the distance metric holds
+   * data.
+   *
+   * This method will take ownership of the given tree. There is no copying of
+   * the data matrices (because tree-building is not necessary), so this is the
+   * constructor to use when copies absolutely must be avoided.
+   *
+   * @note
+   * Mapping the points of the matrix back to their original indices is not done
+   * when this constructor is used, so if the tree type you are using maps
+   * points (like BinarySpaceTree), then you will have to perform the re-mapping
+   * manually.
+   * @endnote
+   *
+   * @param referenceTree Pre-built tree for reference points.
+   * @param mode Neighbor search mode.
+   * @param epsilon Relative approximate error (non-negative).
+   * @param metric Instantiated distance metric.
+   */
+  NeighborSearch(
+      Tree&& referenceTree,
+      const NeighborSearchMode mode = DUAL_TREE_MODE,
+      const double epsilon = 0,
+      const MetricType metric = MetricType());
+
+  /**
    * Create a NeighborSearch object without any reference data.  If Search() is
    * called before a reference set is set with Train(), an exception will be
    * thrown.
@@ -229,37 +288,6 @@ class NeighborSearch
                                    const MetricType metric = MetricType());
 
   /**
-   * Initialize the NeighborSearch object with a copy of the given
-   * pre-constructed reference tree (this is the tree built on the points that
-   * will be searched).  Optionally, choose to use single-tree mode.  Naive mode
-   * is not available as an option for this constructor.  Additionally, an
-   * instantiated distance metric can be given, for cases where the distance
-   * metric holds data.
-   *
-   * Deprecated. Will be removed in mlpack 3.0.0.
-   *
-   * This method will copy the given tree.  You can avoid this copy by using the
-   * construct that takes a rvalue reference to the tree.
-   *
-   * @note
-   * Mapping the points of the matrix back to their original indices is not done
-   * when this constructor is used, so if the tree type you are using maps
-   * points (like BinarySpaceTree), then you will have to perform the re-mapping
-   * manually.
-   * @endnote
-   *
-   * @param referenceTree Pre-built tree for reference points.
-   * @param singleMode Whether single-tree computation should be used (as
-   *      opposed to dual-tree computation).
-   * @param epsilon Relative approximate error (non-negative).
-   * @param metric Instantiated distance metric.
-   */
-  mlpack_deprecated NeighborSearch(Tree& referenceTree,
-                                   const bool singleMode = false,
-                                   const double epsilon = 0,
-                                   const MetricType metric = MetricType());
-
-  /**
    * Initialize the NeighborSearch object with the given pre-constructed
    * reference tree (this is the tree built on the points that will be
    * searched).  Optionally, choose to use single-tree mode.  Naive mode is not
@@ -293,38 +321,6 @@ class NeighborSearch
                                    const MetricType metric = MetricType());
 
   /**
-   * Initialize the NeighborSearch object with the given pre-constructed
-   * reference tree (this is the tree built on the points that will be
-   * searched).  Optionally, choose to use single-tree mode.  Naive mode is not
-   * available as an option for this constructor.  Additionally, an instantiated
-   * distance metric can be given, for cases where the distance metric holds
-   * data.
-   *
-   * Deprecated. Will be removed in mlpack 3.0.0.
-   *
-   * This method will take ownership of the given tree. There is no copying of
-   * the data matrices (because tree-building is not necessary), so this is the
-   * constructor to use when copies absolutely must be avoided.
-   *
-   * @note
-   * Mapping the points of the matrix back to their original indices is not done
-   * when this constructor is used, so if the tree type you are using maps
-   * points (like BinarySpaceTree), then you will have to perform the re-mapping
-   * manually.
-   * @endnote
-   *
-   * @param referenceTree Pre-built tree for reference points.
-   * @param singleMode Whether single-tree computation should be used (as
-   *      opposed to dual-tree computation).
-   * @param epsilon Relative approximate error (non-negative).
-   * @param metric Instantiated distance metric.
-   */
-  mlpack_deprecated NeighborSearch(Tree&& referenceTree,
-                                   const bool singleMode = false,
-                                   const double epsilon = 0,
-                                   const MetricType metric = MetricType());
-
-  /**
    * Create a NeighborSearch object without any reference data.  If Search() is
    * called before a reference set is set with Train(), an exception will be
    * thrown.
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index 2bd7af4..af5d4cc 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -167,7 +167,7 @@ SingleTreeTraversalType>::NeighborSearch(Tree* referenceTree,
     throw std::invalid_argument("epsilon must be non-negative");
 }
 
-// Construct the object without a reference dataset.
+// Construct the object.
 template<typename SortPolicy,
          typename MetricType,
          typename MatType,
@@ -177,13 +177,15 @@ template<typename SortPolicy,
          template<typename> class DualTreeTraversalType,
          template<typename> class SingleTreeTraversalType>
 NeighborSearch<SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType,
-SingleTreeTraversalType>::NeighborSearch(const NeighborSearchMode mode,
+SingleTreeTraversalType>::NeighborSearch(Tree& referenceTree,
+                                         std::vector<size_t>& oldFromNew,
+                                         const NeighborSearchMode mode,
                                          const double epsilon,
                                          const MetricType metric) :
-    referenceTree(NULL),
-    referenceSet(new MatType()), // Empty matrix.
-    treeOwner(false),
-    setOwner(true),
+    referenceTree(new Tree(referenceTree)),
+    referenceSet(&this->referenceTree->Dataset()),
+    treeOwner(true),
+    setOwner(false),
     searchMode(mode),
     epsilon(epsilon),
     metric(metric),
@@ -196,13 +198,12 @@ SingleTreeTraversalType>::NeighborSearch(const NeighborSearchMode mode,
 
   if (epsilon < 0)
     throw std::invalid_argument("epsilon must be non-negative");
-
-  // Build the tree on the empty dataset, if necessary.
-  if (mode != NAIVE_MODE)
+  if (tree::TreeTraits<Tree>::RearrangesDataset)
   {
-    referenceTree = BuildTree<MatType, Tree>(*referenceSet,
-        oldFromNewReferences);
-    treeOwner = true;
+    if (oldFromNew.size() != referenceSet->n_cols)
+      throw std::invalid_argument("the size of oldFromNew vector must match the"
+          " number of points in the given dataset");
+    oldFromNewReferences = oldFromNew;
   }
 }
 
@@ -216,33 +217,37 @@ template<typename SortPolicy,
          template<typename> class DualTreeTraversalType,
          template<typename> class SingleTreeTraversalType>
 NeighborSearch<SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType,
-SingleTreeTraversalType>::NeighborSearch(const MatType& referenceSetIn,
-                                         const bool naive,
-                                         const bool singleMode,
+SingleTreeTraversalType>::NeighborSearch(Tree&& referenceTree,
+                                         std::vector<size_t>&& oldFromNew,
+                                         const NeighborSearchMode mode,
                                          const double epsilon,
                                          const MetricType metric) :
-    referenceTree(naive ? NULL :
-        BuildTree<MatType, Tree>(referenceSetIn, oldFromNewReferences)),
-    referenceSet(naive ? &referenceSetIn : &referenceTree->Dataset()),
-    treeOwner(!naive), // False if a tree was passed.  If naive, then no trees.
+    referenceTree(new Tree(std::move(referenceTree))),
+    referenceSet(&this->referenceTree->Dataset()),
+    treeOwner(true),
     setOwner(false),
-    naive(naive),
-    singleMode(!naive && singleMode), // No single mode if naive.
-    greedy(false),
+    searchMode(mode),
     epsilon(epsilon),
     metric(metric),
     baseCases(0),
     scores(0),
     treeNeedsReset(false)
 {
-  // Update searchMode according to naive, singleMode and greedy flags.
-  UpdateSearchMode();
+  // Update naive, singleMode and greedy flags according to searchMode.
+  UpdateSearchModeFlags();
 
   if (epsilon < 0)
     throw std::invalid_argument("epsilon must be non-negative");
+  if (tree::TreeTraits<Tree>::RearrangesDataset)
+  {
+    if (oldFromNew.size() != referenceSet->n_cols)
+      throw std::invalid_argument("the size of oldFromNew vector must match the"
+          " number of points in the given dataset");
+    oldFromNewReferences = std::move(oldFromNew);
+  }
 }
 
-// Construct the object.
+// Construct the object without a reference dataset.
 template<typename SortPolicy,
          typename MetricType,
          typename MatType,
@@ -252,32 +257,33 @@ template<typename SortPolicy,
          template<typename> class DualTreeTraversalType,
          template<typename> class SingleTreeTraversalType>
 NeighborSearch<SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType,
-SingleTreeTraversalType>::NeighborSearch(MatType&& referenceSetIn,
-                                         const bool naive,
-                                         const bool singleMode,
+SingleTreeTraversalType>::NeighborSearch(const NeighborSearchMode mode,
                                          const double epsilon,
                                          const MetricType metric) :
-    referenceTree(naive ? NULL :
-        BuildTree<MatType, Tree>(std::move(referenceSetIn),
-                                 oldFromNewReferences)),
-    referenceSet(naive ? new MatType(std::move(referenceSetIn)) :
-        &referenceTree->Dataset()),
-    treeOwner(!naive),
-    setOwner(naive),
-    naive(naive),
-    singleMode(!naive && singleMode),
-    greedy(false),
+    referenceTree(NULL),
+    referenceSet(new MatType()), // Empty matrix.
+    treeOwner(false),
+    setOwner(true),
+    searchMode(mode),
     epsilon(epsilon),
     metric(metric),
     baseCases(0),
     scores(0),
     treeNeedsReset(false)
 {
-  // Update searchMode according to naive, singleMode and greedy flags.
-  UpdateSearchMode();
+  // Update naive, singleMode and greedy flags according to searchMode.
+  UpdateSearchModeFlags();
 
   if (epsilon < 0)
     throw std::invalid_argument("epsilon must be non-negative");
+
+  // Build the tree on the empty dataset, if necessary.
+  if (mode != NAIVE_MODE)
+  {
+    referenceTree = BuildTree<MatType, Tree>(*referenceSet,
+        oldFromNewReferences);
+    treeOwner = true;
+  }
 }
 
 // Construct the object.
@@ -290,16 +296,18 @@ template<typename SortPolicy,
          template<typename> class DualTreeTraversalType,
          template<typename> class SingleTreeTraversalType>
 NeighborSearch<SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType,
-SingleTreeTraversalType>::NeighborSearch(Tree* referenceTree,
+SingleTreeTraversalType>::NeighborSearch(const MatType& referenceSetIn,
+                                         const bool naive,
                                          const bool singleMode,
                                          const double epsilon,
                                          const MetricType metric) :
-    referenceTree(referenceTree),
-    referenceSet(&referenceTree->Dataset()),
-    treeOwner(false),
+    referenceTree(naive ? NULL :
+        BuildTree<MatType, Tree>(referenceSetIn, oldFromNewReferences)),
+    referenceSet(naive ? &referenceSetIn : &referenceTree->Dataset()),
+    treeOwner(!naive), // False if a tree was passed.  If naive, then no trees.
     setOwner(false),
-    naive(false),
-    singleMode(singleMode),
+    naive(naive),
+    singleMode(!naive && singleMode), // No single mode if naive.
     greedy(false),
     epsilon(epsilon),
     metric(metric),
@@ -324,22 +332,30 @@ template<typename SortPolicy,
          template<typename> class DualTreeTraversalType,
          template<typename> class SingleTreeTraversalType>
 NeighborSearch<SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType,
-SingleTreeTraversalType>::NeighborSearch(Tree& referenceTree,
+SingleTreeTraversalType>::NeighborSearch(MatType&& referenceSetIn,
+                                         const bool naive,
                                          const bool singleMode,
                                          const double epsilon,
                                          const MetricType metric) :
-    referenceTree(new Tree(referenceTree)),
-    referenceSet(&this->referenceTree->Dataset()),
-    treeOwner(true),
-    setOwner(false),
-    naive(false),
-    singleMode(singleMode),
+    referenceTree(naive ? NULL :
+        BuildTree<MatType, Tree>(std::move(referenceSetIn),
+                                 oldFromNewReferences)),
+    referenceSet(naive ? new MatType(std::move(referenceSetIn)) :
+        &referenceTree->Dataset()),
+    treeOwner(!naive),
+    setOwner(naive),
+    naive(naive),
+    singleMode(!naive && singleMode),
+    greedy(false),
     epsilon(epsilon),
     metric(metric),
     baseCases(0),
     scores(0),
     treeNeedsReset(false)
 {
+  // Update searchMode according to naive, singleMode and greedy flags.
+  UpdateSearchMode();
+
   if (epsilon < 0)
     throw std::invalid_argument("epsilon must be non-negative");
 }
@@ -354,22 +370,26 @@ template<typename SortPolicy,
          template<typename> class DualTreeTraversalType,
          template<typename> class SingleTreeTraversalType>
 NeighborSearch<SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType,
-SingleTreeTraversalType>::NeighborSearch(Tree&& referenceTree,
+SingleTreeTraversalType>::NeighborSearch(Tree* referenceTree,
                                          const bool singleMode,
                                          const double epsilon,
                                          const MetricType metric) :
-    referenceTree(new Tree(std::move(referenceTree))),
-    referenceSet(&this->referenceTree->Dataset()),
-    treeOwner(true),
+    referenceTree(referenceTree),
+    referenceSet(&referenceTree->Dataset()),
+    treeOwner(false),
     setOwner(false),
     naive(false),
     singleMode(singleMode),
+    greedy(false),
     epsilon(epsilon),
     metric(metric),
     baseCases(0),
     scores(0),
     treeNeedsReset(false)
 {
+  // Update searchMode according to naive, singleMode and greedy flags.
+  UpdateSearchMode();
+
   if (epsilon < 0)
     throw std::invalid_argument("epsilon must be non-negative");
 }
@@ -562,6 +582,9 @@ template<typename SortPolicy,
 void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
 DualTreeTraversalType, SingleTreeTraversalType>::Train(Tree& referenceTree)
 {
+  // Update searchMode according to naive, singleMode and greedy flags.
+  UpdateSearchMode();
+
   if (naive)
     throw std::invalid_argument("cannot train on given reference tree when "
         "naive search (without trees) is desired");
@@ -588,6 +611,9 @@ template<typename SortPolicy,
 void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
 DualTreeTraversalType, SingleTreeTraversalType>::Train(Tree&& referenceTree)
 {
+  // Update searchMode according to naive, singleMode and greedy flags.
+  UpdateSearchMode();
+
   if (naive)
     throw std::invalid_argument("cannot train on given reference tree when "
         "naive search (without trees) is desired");




More information about the mlpack-git mailing list