[mlpack-git] master: Add std::move constructor for datasets. (89e4f43)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Mon Oct 19 16:04:54 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/09cd0d67f2fdae252a8ab85324e71dbb4dfe0010...fecf1194c123ced12d56e7daad761c7b9aaac262
>---------------------------------------------------------------
commit 89e4f430f64e04fcf59332adb8211452e9a5eee0
Author: Ryan Curtin <ryan at ratml.org>
Date: Mon Oct 19 14:25:45 2015 -0400
Add std::move constructor for datasets.
>---------------------------------------------------------------
89e4f430f64e04fcf59332adb8211452e9a5eee0
.../methods/neighbor_search/neighbor_search.hpp | 43 +++++++-
.../neighbor_search/neighbor_search_impl.hpp | 108 ++++++++++++++++++++-
2 files changed, 146 insertions(+), 5 deletions(-)
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index 56d9841..fa6a987 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -26,6 +26,10 @@ namespace neighbor /** Neighbor-search routines. These include
* all-nearest-neighbors and all-furthest-neighbors
* searches. */ {
+// Forward declaration.
+template<typename SortPolicy>
+class NSModel;
+
/**
* The NeighborSearch class is a template class for performing distance-based
* neighbor searches. It takes a query dataset and a reference dataset (or just
@@ -72,7 +76,8 @@ class NeighborSearch
*
* This method will copy the matrices to internal copies, which are rearranged
* during tree-building. You can avoid this extra copy by pre-constructing
- * the trees and passing them using a diferent constructor.
+ * the trees and passing them using a different constructor, or by using the
+ * constructor that takes an rvalue reference to the dataset.
*
* @param referenceSet Set of reference points.
* @param naive If true, O(n^2) naive search will be used (as opposed to
@@ -87,6 +92,30 @@ class NeighborSearch
const MetricType metric = MetricType());
/**
+ * Initialize the NeighborSearch object, taking ownership of the reference
+ * dataset (this is the dataset which is searched). Optionally, perform the
+ * computation in naive mode or single-tree mode. An initialized distance
+ * metric can be given, for cases where the metric has internal data (e.g. the
+ * distance::MahalanobisDistance class).
+ *
+ * This method will not copy the data matrix, but will take ownership of it,
+ * and depending on the type of tree used, may rearrange the points. If you
+ * would rather a copy be made, consider using the constructor that takes a
+ * const reference to the data instead.
+ *
+ * @param referenceSet Set of reference points.
+ * @param naive If true, O(n^2) naive search will be used (as opposed to
+ * dual-tree search). This overrides singleMode (if it is set to true).
+ * @param singleMode If true, single-tree search will be used (as opposed to
+ * dual-tree search).
+ * @param metric An optional instance of the MetricType class.
+ */
+ NeighborSearch(MatType&& referenceSet,
+ const bool naive = false,
+ const bool singleMode = false,
+ const MetricType metric = MetricType());
+
+ /**
* Initialize the NeighborSearch object with the given pre-constructed
* reference tree (this is the tree built on the points that will be
* searched). Optionally, choose to use single-tree mode. Naive mode is not
@@ -147,6 +176,16 @@ class NeighborSearch
void Train(const MatType& referenceSet);
/**
+ * Set the reference set to a new reference set, taking ownership of the set,
+ * and build a tree if necessary. This method is called 'Train()' in order to
+ * match the rest of the mlpack abstractions, even though calling this
+ * "training" is maybe a bit of a stretch.
+ *
+ * @param referenceSet New set of reference data.
+ */
+ void Train(MatType&& referenceSet);
+
+ /**
* Set the reference tree to a new reference tree.
*/
void Train(Tree* referenceTree);
@@ -262,6 +301,8 @@ class NeighborSearch
//! The total number of scores (applicable for non-naive search).
size_t scores;
+ //! The NSModel class should have access to internal members.
+ friend class NSModel<SortPolicy>;
}; // class NeighborSearch
} // namespace neighbor
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index 19eabc8..6b31d72 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -39,6 +39,30 @@ TreeType* BuildTree(
return new TreeType(dataset);
}
+//! Call the tree constructor that does mapping, moving the dataset into the
+//! new tree.  boost::enable_if_c removes this overload unless
+//! TreeTraits<TreeType>::RearrangesDataset is true; the trailing unnamed
+//! parameter exists only for that SFINAE check and is never passed.
+//! oldFromNew is handed to the tree constructor (presumably to be filled with
+//! the point mapping -- TODO confirm against the tree type).  If MatType is
+//! given explicitly (e.g. BuildTree<MatType, Tree>(...)), MatType&& is a plain
+//! rvalue reference, not a forwarding reference, so lvalues cannot bind here.
+//! Returns a tree allocated with new.
+template<typename MatType, typename TreeType>
+TreeType* BuildTree(
+ MatType&& dataset,
+ std::vector<size_t>& oldFromNew,
+ typename boost::enable_if_c<
+ tree::TreeTraits<TreeType>::RearrangesDataset == true, TreeType*
+ >::type = 0)
+{
+ return new TreeType(std::move(dataset), oldFromNew);
+}
+
+//! Call the tree constructor that does not do mapping, moving the dataset
+//! into the new tree.  Selected (via boost::enable_if_c) only when
+//! TreeTraits<TreeType>::RearrangesDataset is false.  oldFromNew is accepted
+//! so both overloads share a call signature, but it is not modified here.
+//! Returns a tree allocated with new.
+template<typename MatType, typename TreeType>
+TreeType* BuildTree(
+ MatType&& dataset,
+ std::vector<size_t>& oldFromNew,
+ typename boost::enable_if_c<
+ tree::TreeTraits<TreeType>::RearrangesDataset == false, TreeType*
+ >::type = 0)
+{
+ return new TreeType(std::move(dataset));
+}
+
// Construct the object.
template<typename SortPolicy,
typename MetricType,
@@ -75,6 +99,35 @@ template<typename SortPolicy,
typename TreeMatType> class TreeType,
template<typename> class TraversalType>
NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
+NeighborSearch(MatType&& referenceSetIn,
+ const bool naive,
+ const bool singleMode,
+ const MetricType metric) :
+ // Tree mode: move the data into a newly built tree (which writes the point
+ // mapping into oldFromNewReferences).  Naive mode: no tree at all.
+ // NOTE(review): relies on oldFromNewReferences being declared (and so
+ // initialized) before referenceTree in the class -- confirm.
+ referenceTree(naive ? NULL :
+ BuildTree<MatType, Tree>(std::move(referenceSetIn),
+ oldFromNewReferences)),
+ // Naive mode: move the data into a heap matrix owned by this object.  Tree
+ // mode: point at the tree's dataset instead.  referenceSetIn is therefore
+ // moved from exactly once, whichever branch is taken.
+ // NOTE(review): this initializer reads referenceTree, so it relies on
+ // referenceTree being declared before referenceSet in the class -- confirm.
+ referenceSet(naive ? new MatType(std::move(referenceSetIn)) :
+ &referenceTree->Dataset()),
+ // We own the tree in tree mode and the matrix in naive mode, never both.
+ treeOwner(!naive),
+ setOwner(naive),
+ // In these initializers the parameters shadow the same-named members, so
+ // naive/singleMode below read the arguments; naive overrides singleMode.
+ naive(naive),
+ singleMode(!naive && singleMode),
+ metric(metric),
+ baseCases(0),
+ scores(0)
+{
+ // Nothing to do: every member is set in the initializer list above.
+}
+
+// Construct the object.
+template<typename SortPolicy,
+ typename MetricType,
+ typename MatType,
+ template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType,
+ template<typename> class TraversalType>
+NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
NeighborSearch(Tree* referenceTree,
const bool singleMode,
const MetricType metric) :
@@ -149,18 +202,23 @@ template<typename SortPolicy,
void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
Train(const MatType& referenceSet)
{
+ // Clean up the old tree, if we built one.
+ if (treeOwner && referenceTree)
+ delete referenceTree;
+
// We may need to rebuild the tree.
if (!naive)
{
- if (treeOwner && referenceTree)
- delete referenceTree;
-
referenceTree = BuildTree<MatType, Tree>(referenceSet,
oldFromNewReferences);
-
treeOwner = true;
}
+ else
+ {
+ treeOwner = false;
+ }
+ // Delete the old reference set, if we owned it.
if (setOwner && this->referenceSet)
delete this->referenceSet;
@@ -179,6 +237,48 @@ template<typename SortPolicy,
typename TreeMatType> class TreeType,
template<typename> class TraversalType>
void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
+Train(MatType&& referenceSetIn)
+{
+ // Replace the reference data, taking ownership of referenceSetIn: in tree
+ // mode it is moved into a freshly built tree, in naive mode it is moved into
+ // a new matrix owned by this object.  It is moved from exactly once.
+
+ // Clean up the old tree, if we built one.
+ if (treeOwner && referenceTree)
+ delete referenceTree;
+
+ // Rebuild the tree from the new data unless we are in naive mode.
+ if (!naive)
+ {
+ referenceTree = BuildTree<MatType, Tree>(std::move(referenceSetIn),
+ oldFromNewReferences);
+ treeOwner = true;
+ }
+ else
+ {
+ // NOTE(review): referenceTree is not reset here, so if an owned tree was
+ // just deleted above, it still holds the freed pointer.  Consider setting
+ // referenceTree = NULL in naive mode.
+ treeOwner = false;
+ }
+
+ // Delete the old reference set, if we owned it.
+ if (setOwner && referenceSet)
+ delete referenceSet;
+
+ // Point referenceSet at the new data: the tree's dataset in tree mode (not
+ // owned by us), or a newly allocated matrix in naive mode (owned by us).
+ if (!naive)
+ {
+ referenceSet = &referenceTree->Dataset();
+ setOwner = false;
+ }
+ else
+ {
+ referenceSet = new MatType(std::move(referenceSetIn));
+ setOwner = true;
+ }
+}
+
+template<typename SortPolicy,
+ typename MetricType,
+ typename MatType,
+ template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType,
+ template<typename> class TraversalType>
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
Train(Tree* referenceTree)
{
if (naive)
More information about the mlpack-git
mailing list