[mlpack-git] master: Add std::move constructor for datasets. (89e4f43)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Mon Oct 19 16:04:54 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/09cd0d67f2fdae252a8ab85324e71dbb4dfe0010...fecf1194c123ced12d56e7daad761c7b9aaac262

>---------------------------------------------------------------

commit 89e4f430f64e04fcf59332adb8211452e9a5eee0
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Oct 19 14:25:45 2015 -0400

    Add std::move constructor for datasets.


>---------------------------------------------------------------

89e4f430f64e04fcf59332adb8211452e9a5eee0
 .../methods/neighbor_search/neighbor_search.hpp    |  43 +++++++-
 .../neighbor_search/neighbor_search_impl.hpp       | 108 ++++++++++++++++++++-
 2 files changed, 146 insertions(+), 5 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index 56d9841..fa6a987 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -26,6 +26,10 @@ namespace neighbor /** Neighbor-search routines.  These include
                     * all-nearest-neighbors and all-furthest-neighbors
                     * searches. */ {
 
+// Forward declaration.
+template<typename SortPolicy>
+class NSModel;
+
 /**
  * The NeighborSearch class is a template class for performing distance-based
  * neighbor searches.  It takes a query dataset and a reference dataset (or just
@@ -72,7 +76,8 @@ class NeighborSearch
    *
    * This method will copy the matrices to internal copies, which are rearranged
    * during tree-building.  You can avoid this extra copy by pre-constructing
-   * the trees and passing them using a diferent constructor.
+   * the trees and passing them using a different constructor, or by using the
+   * constructor that takes an rvalue reference to the dataset.
    *
    * @param referenceSet Set of reference points.
    * @param naive If true, O(n^2) naive search will be used (as opposed to
@@ -87,6 +92,30 @@ class NeighborSearch
                  const MetricType metric = MetricType());
 
   /**
+   * Initialize the NeighborSearch object, taking ownership of the reference
+   * dataset (this is the dataset which is searched).  Optionally, perform the
+   * computation in naive mode or single-tree mode.  An initialized distance
+   * metric can be given, for cases where the metric has internal data (e.g. the
+   * distance::MahalanobisDistance class).
+   *
+   * This method will not copy the data matrix, but will take ownership of it,
+   * and depending on the type of tree used, may rearrange the points.  If you
+   * would rather a copy be made, consider using the constructor that takes a
+   * const reference to the data instead.
+   *
+   * @param referenceSet Set of reference points.
+   * @param naive If true, O(n^2) naive search will be used (as opposed to
+   *      dual-tree search).  This overrides singleMode (if it is set to true).
+   * @param singleMode If true, single-tree search will be used (as opposed to
+   *      dual-tree search).
+   * @param metric An optional instance of the MetricType class.
+   */
+  NeighborSearch(MatType&& referenceSet,
+                 const bool naive = false,
+                 const bool singleMode = false,
+                 const MetricType metric = MetricType());
+
+  /**
    * Initialize the NeighborSearch object with the given pre-constructed
    * reference tree (this is the tree built on the points that will be
    * searched).  Optionally, choose to use single-tree mode.  Naive mode is not
@@ -147,6 +176,16 @@ class NeighborSearch
   void Train(const MatType& referenceSet);
 
   /**
+   * Set the reference set to a new reference set, taking ownership of the set,
+   * and build a tree if necessary.  This method is called 'Train()' in order to
+   * match the rest of the mlpack abstractions, even though calling this
+   * "training" is maybe a bit of a stretch.
+   *
+   * @param referenceSet New set of reference data.
+   */
+  void Train(MatType&& referenceSet);
+
+  /**
    * Set the reference tree to a new reference tree.
    */
   void Train(Tree* referenceTree);
@@ -262,6 +301,8 @@ class NeighborSearch
   //! The total number of scores (applicable for non-naive search).
   size_t scores;
 
+  //! The NSModel class should have access to internal members.
+  friend class NSModel<SortPolicy>;
 }; // class NeighborSearch
 
 } // namespace neighbor
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index 19eabc8..6b31d72 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -39,6 +39,30 @@ TreeType* BuildTree(
   return new TreeType(dataset);
 }
 
+//! Call the tree constructor that does mapping.
+template<typename MatType, typename TreeType>
+TreeType* BuildTree(
+    MatType&& dataset,
+    std::vector<size_t>& oldFromNew,
+    typename boost::enable_if_c<
+        tree::TreeTraits<TreeType>::RearrangesDataset == true, TreeType*
+    >::type = 0)
+{
+  return new TreeType(std::move(dataset), oldFromNew);
+}
+
+//! Call the tree constructor that does not do mapping.
+template<typename MatType, typename TreeType>
+TreeType* BuildTree(
+    MatType&& dataset,
+    std::vector<size_t>& oldFromNew,
+    typename boost::enable_if_c<
+        tree::TreeTraits<TreeType>::RearrangesDataset == false, TreeType*
+    >::type = 0)
+{
+  return new TreeType(std::move(dataset));
+}
+
 // Construct the object.
 template<typename SortPolicy,
          typename MetricType,
@@ -75,6 +99,35 @@ template<typename SortPolicy,
                   typename TreeMatType> class TreeType,
          template<typename> class TraversalType>
 NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
+NeighborSearch(MatType&& referenceSetIn,
+               const bool naive,
+               const bool singleMode,
+               const MetricType metric) :
+    referenceTree(naive ? NULL :
+        BuildTree<MatType, Tree>(std::move(referenceSetIn),
+                                 oldFromNewReferences)),
+    referenceSet(naive ? new MatType(std::move(referenceSetIn)) :
+        &referenceTree->Dataset()),
+    treeOwner(!naive),
+    setOwner(naive),
+    naive(naive),
+    singleMode(!naive && singleMode),
+    metric(metric),
+    baseCases(0),
+    scores(0)
+{
+  // Nothing to do.
+}
+
+// Construct the object.
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType,
+         template<typename> class TraversalType>
+NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
 NeighborSearch(Tree* referenceTree,
                const bool singleMode,
                const MetricType metric) :
@@ -149,18 +202,23 @@ template<typename SortPolicy,
 void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
 Train(const MatType& referenceSet)
 {
+  // Clean up the old tree, if we built one.
+  if (treeOwner && referenceTree)
+    delete referenceTree;
+
   // We may need to rebuild the tree.
   if (!naive)
   {
-    if (treeOwner && referenceTree)
-      delete referenceTree;
-
     referenceTree = BuildTree<MatType, Tree>(referenceSet,
         oldFromNewReferences);
-
     treeOwner = true;
   }
+  else
+  {
+    treeOwner = false;
+  }
 
+  // Delete the old reference set, if we owned it.
   if (setOwner && this->referenceSet)
     delete this->referenceSet;
 
@@ -179,6 +237,48 @@ template<typename SortPolicy,
                   typename TreeMatType> class TreeType,
          template<typename> class TraversalType>
 void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
+Train(MatType&& referenceSetIn)
+{
+  // Clean up the old tree, if we built one.
+  if (treeOwner && referenceTree)
+    delete referenceTree;
+
+  // We may need to rebuild the tree.
+  if (!naive)
+  {
+    referenceTree = BuildTree<MatType, Tree>(std::move(referenceSetIn),
+        oldFromNewReferences);
+    treeOwner = true;
+  }
+  else
+  {
+    treeOwner = false;
+  }
+
+  // Delete the old reference set, if we owned it.
+  if (setOwner && referenceSet)
+    delete referenceSet;
+
+  if (!naive)
+  {
+    referenceSet = &referenceTree->Dataset();
+    setOwner = false;
+  }
+  else
+  {
+    referenceSet = new MatType(std::move(referenceSetIn));
+    setOwner = true;
+  }
+}
+
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType,
+         template<typename> class TraversalType>
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
 Train(Tree* referenceTree)
 {
   if (naive)



More information about the mlpack-git mailing list