[mlpack-git] master: Merge branch 'searchMode-3.0.0' of https://github.com/MarcosPividori/mlpack into MarcosPividori-searchMode-3.0.0 (46a9c0d)

gitdub at mlpack.org gitdub at mlpack.org
Tue Nov 1 13:50:32 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/3d3d733ba3c41c4f51764f44185767384ab6d9c7...94d14187222231ca29e4f6419c5999c660db4f8a

>---------------------------------------------------------------

commit 46a9c0de4cd21cc626d50ec9c5dcbd6c20718482
Merge: 91564ff 211ec7b
Author: Ryan Curtin <ryan at ratml.org>
Date:   Tue Nov 1 13:50:32 2016 -0400

    Merge branch 'searchMode-3.0.0' of https://github.com/MarcosPividori/mlpack into MarcosPividori-searchMode-3.0.0


>---------------------------------------------------------------

46a9c0de4cd21cc626d50ec9c5dcbd6c20718482
 src/mlpack/methods/neighbor_search/kfn_main.cpp    |   2 +-
 src/mlpack/methods/neighbor_search/knn_main.cpp    |   2 +-
 .../methods/neighbor_search/neighbor_search.hpp    | 193 +------------
 .../neighbor_search/neighbor_search_impl.hpp       | 310 +--------------------
 src/mlpack/methods/neighbor_search/ns_model.hpp    |  26 +-
 .../methods/neighbor_search/ns_model_impl.hpp      |  62 +----
 src/mlpack/tests/knn_test.cpp                      |   8 +-
 7 files changed, 26 insertions(+), 577 deletions(-)

diff --cc src/mlpack/methods/neighbor_search/neighbor_search.hpp
index e57d802,378fcb0..e23641c
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@@ -333,33 -195,9 +225,23 @@@ class NeighborSearc
    void Train(MatType&& referenceSet);
  
    /**
-    * Set the reference tree to a new reference tree.
-    *
-    * This method is deprecated and will be removed in mlpack 3.0.0! Train()
-    * methods taking a reference to the reference tree are preferred.
-    *
-    * @param referenceTree Pre-built tree for reference points.
-    */
-   mlpack_deprecated void Train(Tree* referenceTree);
- 
-   /**
 +   * Set the reference tree as a copy of the given reference tree.
 +   *
 +   * This method will copy the given tree.  You can avoid this copy by using the
 +   * Train() method that takes an rvalue reference to the tree.
 +   *
 +   * @param referenceTree Pre-built tree for reference points.
 +   */
 +  void Train(const Tree& referenceTree);
 +
 +  /**
     * Set the reference tree to a new reference tree.
 +   *
 +   * This method will take ownership of the given tree.
 +   *
 +   * @param referenceTree Pre-built tree for reference points.
     */
 -  void Train(Tree* referenceTree);
 +  void Train(Tree&& referenceTree);
  
    /**
     * For each point in the query set, compute the nearest neighbors and store
diff --cc src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index fb099e3,0f78a70..d8c16ee
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@@ -147,41 -141,9 +141,38 @@@ SingleTreeTraversalType>::NeighborSearc
                                           const NeighborSearchMode mode,
                                           const double epsilon,
                                           const MetricType metric) :
 -    referenceTree(referenceTree),
 -    referenceSet(&referenceTree->Dataset()),
 -    treeOwner(false),
 +    referenceTree(new Tree(referenceTree)),
 +    referenceSet(&this->referenceTree->Dataset()),
 +    treeOwner(true),
 +    setOwner(false),
 +    searchMode(mode),
 +    epsilon(epsilon),
 +    metric(metric),
 +    baseCases(0),
 +    scores(0),
 +    treeNeedsReset(false)
 +{
-   // Update naive, singleMode and greedy flags according to searchMode.
-   UpdateSearchModeFlags();
- 
 +  if (epsilon < 0)
 +    throw std::invalid_argument("epsilon must be non-negative");
 +}
 +
 +// Construct the object.
 +template<typename SortPolicy,
 +         typename MetricType,
 +         typename MatType,
 +         template<typename TreeMetricType,
 +                  typename TreeStatType,
 +                  typename TreeMatType> class TreeType,
 +         template<typename> class DualTreeTraversalType,
 +         template<typename> class SingleTreeTraversalType>
 +NeighborSearch<SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType,
 +SingleTreeTraversalType>::NeighborSearch(Tree&& referenceTree,
 +                                         const NeighborSearchMode mode,
 +                                         const double epsilon,
 +                                         const MetricType metric) :
 +    referenceTree(new Tree(std::move(referenceTree))),
 +    referenceSet(&this->referenceTree->Dataset()),
 +    treeOwner(true),
      setOwner(false),
      searchMode(mode),
      epsilon(epsilon),
@@@ -190,9 -152,8 +181,6 @@@
      scores(0),
      treeNeedsReset(false)
  {
-   // Update naive, singleMode and greedy flags according to searchMode.
-   UpdateSearchModeFlags();
- 
 -  if (mode == NAIVE_MODE)
 -    throw std::invalid_argument("invalid constructor for naive mode");
    if (epsilon < 0)
      throw std::invalid_argument("epsilon must be non-negative");
  }
@@@ -416,15 -224,9 +251,12 @@@ void NeighborSearch<SortPolicy, MetricT
  DualTreeTraversalType, SingleTreeTraversalType>::Train(
      const MatType& referenceSet)
  {
-   // Update searchMode.
-   UpdateSearchMode();
- 
    // Clean up the old tree, if we built one.
    if (treeOwner && referenceTree)
 +  {
 +    oldFromNewReferences.clear();
      delete referenceTree;
 +  }
  
    // We may need to rebuild the tree.
    if (searchMode != NAIVE_MODE)
@@@ -460,15 -262,9 +292,12 @@@ template<typename SortPolicy
  void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
  DualTreeTraversalType, SingleTreeTraversalType>::Train(MatType&& referenceSetIn)
  {
-   // Update searchMode.
-   UpdateSearchMode();
- 
    // Clean up the old tree, if we built one.
    if (treeOwner && referenceTree)
 +  {
 +    oldFromNewReferences.clear();
      delete referenceTree;
 +  }
  
    // We may need to rebuild the tree.
    if (searchMode != NAIVE_MODE)
@@@ -507,88 -303,14 +336,49 @@@ template<typename SortPolicy
           template<typename> class DualTreeTraversalType,
           template<typename> class SingleTreeTraversalType>
  void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
--DualTreeTraversalType, SingleTreeTraversalType>::Train(Tree* referenceTree)
- {
-   // Update searchMode.
-   UpdateSearchMode();
- 
-   if (searchMode == NAIVE_MODE)
-     throw std::invalid_argument("cannot train on given reference tree when "
-         "naive search (without trees) is desired");
- 
-   if (treeOwner && this->referenceTree)
-   {
-     oldFromNewReferences.clear();
-     delete this->referenceTree;
-   }
- 
-   if (setOwner && referenceSet)
-     delete this->referenceSet;
- 
-   this->referenceTree = referenceTree;
-   this->referenceSet = &referenceTree->Dataset();
-   treeOwner = false;
-   setOwner = false;
- }
- 
- template<typename SortPolicy,
-          typename MetricType,
-          typename MatType,
-          template<typename TreeMetricType,
-                   typename TreeStatType,
-                   typename TreeMatType> class TreeType,
-          template<typename> class DualTreeTraversalType,
-          template<typename> class SingleTreeTraversalType>
- void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
 +DualTreeTraversalType, SingleTreeTraversalType>::Train(
 +    const Tree& referenceTree)
  {
-   // Update searchMode according to naive, singleMode and greedy flags.
-   UpdateSearchMode();
- 
-   if (naive)
+   if (searchMode == NAIVE_MODE)
      throw std::invalid_argument("cannot train on given reference tree when "
          "naive search (without trees) is desired");
  
 -  if (treeOwner && referenceTree)
 +  if (treeOwner && this->referenceTree)
 +  {
 +    oldFromNewReferences.clear();
      delete this->referenceTree;
 +  }
 +
 +  if (setOwner && referenceSet)
 +    delete this->referenceSet;
 +
 +  this->referenceTree = new Tree(referenceTree);
 +  this->referenceSet = &this->referenceTree->Dataset();
 +  treeOwner = true;
 +  setOwner = false;
 +}
 +
 +template<typename SortPolicy,
 +         typename MetricType,
 +         typename MatType,
 +         template<typename TreeMetricType,
 +                  typename TreeStatType,
 +                  typename TreeMatType> class TreeType,
 +         template<typename> class DualTreeTraversalType,
 +         template<typename> class SingleTreeTraversalType>
 +void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
 +DualTreeTraversalType, SingleTreeTraversalType>::Train(Tree&& referenceTree)
 +{
-   // Update searchMode according to naive, singleMode and greedy flags.
-   UpdateSearchMode();
- 
-   if (naive)
++  if (searchMode == NAIVE_MODE)
 +    throw std::invalid_argument("cannot train on given reference tree when "
 +        "naive search (without trees) is desired");
 +
 +  if (treeOwner && this->referenceTree)
 +  {
 +    oldFromNewReferences.clear();
 +    delete this->referenceTree;
 +  }
 +
    if (setOwner && referenceSet)
      delete this->referenceSet;
  
@@@ -828,26 -547,7 +615,7 @@@ template<typename SortPolicy
           template<typename> class SingleTreeTraversalType>
  void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
  DualTreeTraversalType, SingleTreeTraversalType>::Search(
--    Tree* queryTree,
-     const size_t k,
-     arma::Mat<size_t>& neighbors,
-     arma::mat& distances,
-     bool sameSet)
- {
-   Search(*queryTree, k, neighbors, distances, sameSet);
- }
- 
- template<typename SortPolicy,
-          typename MetricType,
-          typename MatType,
-          template<typename TreeMetricType,
-                   typename TreeStatType,
-                   typename TreeMatType> class TreeType,
-          template<typename> class DualTreeTraversalType,
-          template<typename> class SingleTreeTraversalType>
- void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
- DualTreeTraversalType, SingleTreeTraversalType>::Search(
 +    Tree& queryTree,
      const size_t k,
      arma::Mat<size_t>& neighbors,
      arma::mat& distances,




More information about the mlpack-git mailing list