[mlpack-git] master: Add support for rvalue references when setting a given reference tree in NeighborSearch class. (260f711)
gitdub at mlpack.org
gitdub at mlpack.org
Tue Aug 23 15:58:25 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/1148f1652e139c9037eb3813550090313d089a30...a8a8a1381b529a01420de6e792a4a1e7bd58a626
>---------------------------------------------------------------
commit 260f711f2cdf69880355fea905b2e44c2dfbf54b
Author: MarcosPividori <marcos.pividori at gmail.com>
Date: Thu Jul 28 18:26:38 2016 -0300
Add support for rvalue references when setting a given reference tree in
NeighborSearch class.
>---------------------------------------------------------------
260f711f2cdf69880355fea905b2e44c2dfbf54b
.../methods/neighbor_search/neighbor_search.hpp | 46 ++++++++++++++++++++--
.../neighbor_search/neighbor_search_impl.hpp | 37 ++++++++++++++++-
.../methods/neighbor_search/ns_model_impl.hpp | 16 +++-----
3 files changed, 84 insertions(+), 15 deletions(-)
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index e3871f3..f9806af 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -234,9 +234,10 @@ class NeighborSearch
*
* Deprecated. Will be removed in mlpack 3.0.0.
*
- * There is no copying of the data matrices in this constructor (because
- * tree-building is not necessary), so this is the constructor to use when
- * copies absolutely must be avoided.
+ * This method won't take ownership of the given tree. There is no copying of
+ * the data matrices in this constructor (because tree-building is not
+ * necessary), so this is the constructor to use when copies absolutely must
+ * be avoided.
*
* @note
* Mapping the points of the matrix back to their original indices is not done
@@ -257,6 +258,36 @@ class NeighborSearch
const MetricType metric = MetricType());
/**
+ * Initialize the NeighborSearch object with the given pre-constructed
+ * reference tree (this is the tree built on the points that will be
+ * searched). Optionally, choose to use single-tree mode. Naive mode is not
+ * available as an option for this constructor. Additionally, an instantiated
+ * distance metric can be given, for cases where the distance metric holds
+ * data.
+ *
+ * This method will take ownership of the given tree. There is no copying of
+ * the data matrices (because tree-building is not necessary), so this is the
+ * constructor to use when copies absolutely must be avoided.
+ *
+ * @note
+ * Mapping the points of the matrix back to their original indices is not done
+ * when this constructor is used, so if the tree type you are using maps
+ * points (like BinarySpaceTree), then you will have to perform the re-mapping
+ * manually.
+ * @endnote
+ *
+ * @param referenceTree Pre-built tree for reference points.
+ * @param singleMode Whether single-tree computation should be used (as
+ * opposed to dual-tree computation).
+ * @param epsilon Relative approximate error (non-negative).
+ * @param metric Instantiated distance metric.
+ */
+ NeighborSearch(Tree&& referenceTree,
+ const bool singleMode = false,
+ const double epsilon = 0,
+ const MetricType metric = MetricType());
+
+ /**
* Create a NeighborSearch object without any reference data. If Search() is
* called before a reference set is set with Train(), an exception will be
* thrown.
@@ -309,6 +340,15 @@ class NeighborSearch
void Train(Tree* referenceTree);
/**
+ * Set the reference tree to a new reference tree.
+ *
+ * This method will take ownership of the given tree.
+ *
+ * @param referenceTree Pre-built tree for reference points.
+ */
+ void Train(Tree&& referenceTree);
+
+ /**
* For each point in the query set, compute the nearest neighbors and store
* the output in the given matrices. The matrices will be set to the size of
* n columns by k rows, where n is the number of points in the query dataset
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index ac06418..24fa9b2 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -314,6 +314,26 @@ SingleTreeTraversalType>::NeighborSearch(Tree* referenceTree,
throw std::invalid_argument("epsilon must be non-negative");
}
+// Construct the object.
+template<typename SortPolicy,
+ typename MetricType,
+ typename MatType,
+ template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType,
+ template<typename> class DualTreeTraversalType,
+ template<typename> class SingleTreeTraversalType>
+NeighborSearch<SortPolicy, MetricType, MatType, TreeType, DualTreeTraversalType,
+SingleTreeTraversalType>::NeighborSearch(Tree&& referenceTree,
+ const bool singleMode,
+ const double epsilon,
+ const MetricType metric) :
+ NeighborSearch(new Tree(std::move(referenceTree)), singleMode, epsilon,
+ metric)
+{
+ treeOwner = true;
+}
+
// Construct the object without a reference dataset.
template<typename SortPolicy,
typename MetricType,
@@ -480,7 +500,7 @@ DualTreeTraversalType, SingleTreeTraversalType>::Train(Tree* referenceTree)
throw std::invalid_argument("cannot train on given reference tree when "
"naive search (without trees) is desired");
- if (treeOwner && referenceTree)
+ if (treeOwner && this->referenceTree)
delete this->referenceTree;
if (setOwner && referenceSet)
delete this->referenceSet;
@@ -491,6 +511,21 @@ DualTreeTraversalType, SingleTreeTraversalType>::Train(Tree* referenceTree)
setOwner = false;
}
+template<typename SortPolicy,
+ typename MetricType,
+ typename MatType,
+ template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType,
+ template<typename> class DualTreeTraversalType,
+ template<typename> class SingleTreeTraversalType>
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
+DualTreeTraversalType, SingleTreeTraversalType>::Train(Tree&& referenceTree)
+{
+ Train(new Tree(std::move(referenceTree)));
+ treeOwner = true;
+}
+
/**
* Computes the best neighbors and stores them in resultingNeighbors and
* distances.
diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
index 1591730..5c16bca 100644
--- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
@@ -176,11 +176,8 @@ void TrainVisitor<SortPolicy>::operator ()(SpillKNN* ns) const
ns->Train(std::move(referenceSet));
else
{
- typename SpillKNN::Tree* tree = new typename SpillKNN::Tree(
- std::move(referenceSet), tau, leafSize, rho);
- ns->Train(tree);
- // Give the model ownership of the tree.
- ns->treeOwner = true;
+ typename SpillKNN::Tree tree(std::move(referenceSet), tau, leafSize, rho);
+ ns->Train(std::move(tree));
}
}
else
@@ -197,13 +194,10 @@ void TrainVisitor<SortPolicy>::TrainLeaf(NSType* ns) const
else
{
std::vector<size_t> oldFromNewReferences;
- typename NSType::Tree* tree =
- new typename NSType::Tree(std::move(referenceSet),
+ typename NSType::Tree referenceTree(std::move(referenceSet),
oldFromNewReferences, leafSize);
- ns->Train(tree);
-
- // Give the model ownership of the tree and the mappings.
- ns->treeOwner = true;
+ ns->Train(std::move(referenceTree));
+ // Set the mappings.
ns->oldFromNewReferences = std::move(oldFromNewReferences);
}
}
More information about the mlpack-git
mailing list