[mlpack-git] master: Merge branch 'searchMode-3.0.0' of https://github.com/MarcosPividori/mlpack into MarcosPividori-searchMode-3.0.0 (46a9c0d)
gitdub at mlpack.org
gitdub at mlpack.org
Tue Nov 1 13:50:32 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/3d3d733ba3c41c4f51764f44185767384ab6d9c7...94d14187222231ca29e4f6419c5999c660db4f8a
>---------------------------------------------------------------
commit 46a9c0de4cd21cc626d50ec9c5dcbd6c20718482
Merge: 91564ff 211ec7b
Author: Ryan Curtin <ryan at ratml.org>
Date: Tue Nov 1 13:50:32 2016 -0400
Merge branch 'searchMode-3.0.0' of https://github.com/MarcosPividori/mlpack into MarcosPividori-searchMode-3.0.0
>---------------------------------------------------------------
46a9c0de4cd21cc626d50ec9c5dcbd6c20718482
src/mlpack/methods/neighbor_search/kfn_main.cpp | 2 +-
src/mlpack/methods/neighbor_search/knn_main.cpp | 2 +-
.../methods/neighbor_search/neighbor_search.hpp | 193 +------------
.../neighbor_search/neighbor_search_impl.hpp | 310 +--------------------
src/mlpack/methods/neighbor_search/ns_model.hpp | 26 +-
.../methods/neighbor_search/ns_model_impl.hpp | 62 +----
src/mlpack/tests/knn_test.cpp | 8 +-
7 files changed, 26 insertions(+), 577 deletions(-)
diff --cc src/mlpack/methods/neighbor_search/neighbor_search.hpp
index e57d802,378fcb0..e23641c
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@@ -333,33 -195,9 +225,23 @@@ class NeighborSearc
void Train(MatType&& referenceSet);
/**
- * Set the reference tree to a new reference tree.
- *
- * This method is deprecated and will be removed in mlpack 3.0.0! Train()
- * methods taking a reference to the reference tree are preferred.
- *
- * @param referenceTree Pre-built tree for reference points.
- */
- mlpack_deprecated void Train(Tree* referenceTree);
-
- /**
+ * Set the reference tree as a copy of the given reference tree.
+ *
+ * This method will copy the given tree. You can avoid this copy by using the
+ * Train() method that takes a rvalue reference to the tree.
+ *
+ * @param referenceTree Pre-built tree for reference points.
+ */
+ void Train(const Tree& referenceTree);
+
+ /**
* Set the reference tree to a new reference tree.
+ *
+ * This method will take ownership of the given tree.
+ *
+ * @param referenceTree Pre-built tree for reference points.
*/
- void Train(Tree* referenceTree);
+ void Train(Tree&& referenceTree);
/**
* For each point in the query set, compute the nearest neighbors and store
diff --cc src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index fb099e3,0f78a70..d8c16ee
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@@ -147,41 -141,9 +141,38 @@@ SingleTreeTraversalType>::NeighborSearc
const NeighborSearchMode mode,
const double epsilon,
const MetricType metric) :
- referenceTree(referenceTree),
- referenceSet(&referenceTree->Dataset()),
- treeOwner(false),
+ referenceTree(new Tree(referenceTree)),
+ referenceSet(&this->referenceTree->Dataset()),
+ treeOwner(true),
+ setOwner(false),
+ searchMode(mode),
+ epsilon(epsilon),
+ metric(metric),
+ baseCases(0),
+ scores(0),
+ treeNeedsReset(false)
+{
- // Update naive, singleMode and greedy flags according to searchMode.
- UpdateSearchModeFlags();
-
+ if (epsilon < 0)
+ throw std::invalid_argument("epsilon must be non-negative");
+}
+
+// Construct the object.
+template<typename SortPolicy,
+ typename MetricType,
+ typename MatType,
+ template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType,
+ template<typename> class DualTreeTraversalType,
+ template<typename> class SingleTreeTraversalType>
+NeighborSearch<SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType,
+SingleTreeTraversalType>::NeighborSearch(Tree&& referenceTree,
+ const NeighborSearchMode mode,
+ const double epsilon,
+ const MetricType metric) :
+ referenceTree(new Tree(std::move(referenceTree))),
+ referenceSet(&this->referenceTree->Dataset()),
+ treeOwner(true),
setOwner(false),
searchMode(mode),
epsilon(epsilon),
@@@ -190,9 -152,8 +181,6 @@@
scores(0),
treeNeedsReset(false)
{
- // Update naive, singleMode and greedy flags according to searchMode.
- UpdateSearchModeFlags();
-
- if (mode == NAIVE_MODE)
- throw std::invalid_argument("invalid constructor for naive mode");
if (epsilon < 0)
throw std::invalid_argument("epsilon must be non-negative");
}
@@@ -416,15 -224,9 +251,12 @@@ void NeighborSearch<SortPolicy, MetricT
DualTreeTraversalType, SingleTreeTraversalType>::Train(
const MatType& referenceSet)
{
- // Update searchMode.
- UpdateSearchMode();
-
// Clean up the old tree, if we built one.
if (treeOwner && referenceTree)
+ {
+ oldFromNewReferences.clear();
delete referenceTree;
+ }
// We may need to rebuild the tree.
if (searchMode != NAIVE_MODE)
@@@ -460,15 -262,9 +292,12 @@@ template<typename SortPolicy
void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
DualTreeTraversalType, SingleTreeTraversalType>::Train(MatType&& referenceSetIn)
{
- // Update searchMode.
- UpdateSearchMode();
-
// Clean up the old tree, if we built one.
if (treeOwner && referenceTree)
+ {
+ oldFromNewReferences.clear();
delete referenceTree;
+ }
// We may need to rebuild the tree.
if (searchMode != NAIVE_MODE)
@@@ -507,88 -303,14 +336,49 @@@ template<typename SortPolicy
template<typename> class DualTreeTraversalType,
template<typename> class SingleTreeTraversalType>
void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
--DualTreeTraversalType, SingleTreeTraversalType>::Train(Tree* referenceTree)
- {
- // Update searchMode.
- UpdateSearchMode();
-
- if (searchMode == NAIVE_MODE)
- throw std::invalid_argument("cannot train on given reference tree when "
- "naive search (without trees) is desired");
-
- if (treeOwner && this->referenceTree)
- {
- oldFromNewReferences.clear();
- delete this->referenceTree;
- }
-
- if (setOwner && referenceSet)
- delete this->referenceSet;
-
- this->referenceTree = referenceTree;
- this->referenceSet = &referenceTree->Dataset();
- treeOwner = false;
- setOwner = false;
- }
-
- template<typename SortPolicy,
- typename MetricType,
- typename MatType,
- template<typename TreeMetricType,
- typename TreeStatType,
- typename TreeMatType> class TreeType,
- template<typename> class DualTreeTraversalType,
- template<typename> class SingleTreeTraversalType>
- void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
+DualTreeTraversalType, SingleTreeTraversalType>::Train(
+ const Tree& referenceTree)
{
- // Update searchMode according to naive, singleMode and greedy flags.
- UpdateSearchMode();
-
- if (naive)
+ if (searchMode == NAIVE_MODE)
throw std::invalid_argument("cannot train on given reference tree when "
"naive search (without trees) is desired");
- if (treeOwner && referenceTree)
+ if (treeOwner && this->referenceTree)
+ {
+ oldFromNewReferences.clear();
delete this->referenceTree;
+ }
+
+ if (setOwner && referenceSet)
+ delete this->referenceSet;
+
+ this->referenceTree = new Tree(referenceTree);
+ this->referenceSet = &this->referenceTree->Dataset();
+ treeOwner = true;
+ setOwner = false;
+}
+
+template<typename SortPolicy,
+ typename MetricType,
+ typename MatType,
+ template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType,
+ template<typename> class DualTreeTraversalType,
+ template<typename> class SingleTreeTraversalType>
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
+DualTreeTraversalType, SingleTreeTraversalType>::Train(Tree&& referenceTree)
+{
- // Update searchMode according to naive, singleMode and greedy flags.
- UpdateSearchMode();
-
- if (naive)
++ if (searchMode == NAIVE_MODE)
+ throw std::invalid_argument("cannot train on given reference tree when "
+ "naive search (without trees) is desired");
+
+ if (treeOwner && this->referenceTree)
+ {
+ oldFromNewReferences.clear();
+ delete this->referenceTree;
+ }
+
if (setOwner && referenceSet)
delete this->referenceSet;
@@@ -828,26 -547,7 +615,7 @@@ template<typename SortPolicy
template<typename> class SingleTreeTraversalType>
void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
DualTreeTraversalType, SingleTreeTraversalType>::Search(
-- Tree* queryTree,
- const size_t k,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances,
- bool sameSet)
- {
- Search(*queryTree, k, neighbors, distances, sameSet);
- }
-
- template<typename SortPolicy,
- typename MetricType,
- typename MatType,
- template<typename TreeMetricType,
- typename TreeStatType,
- typename TreeMatType> class TreeType,
- template<typename> class DualTreeTraversalType,
- template<typename> class SingleTreeTraversalType>
- void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
- DualTreeTraversalType, SingleTreeTraversalType>::Search(
+ Tree& queryTree,
const size_t k,
arma::Mat<size_t>& neighbors,
arma::mat& distances,
More information about the mlpack-git mailing list