[mlpack-git] master: Use rvalue references to prevent copies. (95d1357)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Mon Oct 19 16:04:48 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/09cd0d67f2fdae252a8ab85324e71dbb4dfe0010...fecf1194c123ced12d56e7daad761c7b9aaac262

>---------------------------------------------------------------

commit 95d1357f44d8d1505b766cbbaba9215ccdad9ff7
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Oct 19 14:57:53 2015 -0400

    Use rvalue references to prevent copies.


>---------------------------------------------------------------

95d1357f44d8d1505b766cbbaba9215ccdad9ff7
 src/mlpack/methods/neighbor_search/ns_model.hpp    |  17 +--
 .../methods/neighbor_search/ns_model_impl.hpp      | 115 +++++++++------------
 2 files changed, 52 insertions(+), 80 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/ns_model.hpp b/src/mlpack/methods/neighbor_search/ns_model.hpp
index e9e8042..952ed9e 100644
--- a/src/mlpack/methods/neighbor_search/ns_model.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model.hpp
@@ -22,19 +22,19 @@ namespace neighbor {
 template<typename SortPolicy>
 struct NSModelName
 {
-  const static constexpr char value[22] = "neighbor_search_model";
+  static const std::string Name() { return "neighbor_search_model"; }
 };
 
 template<>
 struct NSModelName<NearestNeighborSort>
 {
-  const static constexpr char value[30] = "nearest_neighbor_search_model";
+  static const std::string Name() { return "nearest_neighbor_search_model"; }
 };
 
 template<>
 struct NSModelName<FurthestNeighborSort>
 {
-  const static constexpr char value[31] = "furthest_neighbor_search_model";
+  static const std::string Name() { return "furthest_neighbor_search_model"; }
 };
 
 template<typename SortPolicy>
@@ -57,9 +57,6 @@ class NSModel
   bool randomBasis;
   arma::mat q;
 
-  // Mappings, in case they are necessary.
-  std::vector<size_t> oldFromNewReferences;
-
   template<template<typename TreeMetricType,
                     typename TreeStatType,
                     typename TreeMatType> class TreeType>
@@ -74,10 +71,6 @@ class NSModel
   NSType<tree::RTree>* rTreeNS;
   NSType<tree::RStarTree>* rStarTreeNS;
 
-  // This pointers is only non-null if we are using kd-trees and we built the
-  // tree ourselves (which only happens if BuildModel() is called).
-  typename NSType<tree::KDTree>::Tree* kdTree;
-
  public:
   /**
    * Initialize the NSModel with the given type and whether or not a random
@@ -112,13 +105,13 @@ class NSModel
   bool& RandomBasis() { return randomBasis; }
 
   //! Build the reference tree.
-  void BuildModel(arma::mat& referenceSet,
+  void BuildModel(arma::mat&& referenceSet,
                   const size_t leafSize,
                   const bool naive,
                   const bool singleMode);
 
   //! Perform neighbor search.  The query set will be reordered.
-  void Search(arma::mat& querySet,
+  void Search(arma::mat&& querySet,
               const size_t k,
               arma::Mat<size_t>& neighbors,
               arma::mat& distances);
diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
index be0df8e..af5f3fb 100644
--- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
@@ -27,8 +27,7 @@ NSModel<SortPolicy>::NSModel(int treeType, bool randomBasis) :
     kdTreeNS(NULL),
     coverTreeNS(NULL),
     rTreeNS(NULL),
-    rStarTreeNS(NULL),
-    kdTree(NULL)
+    rStarTreeNS(NULL)
 {
   // Nothing to do.
 }
@@ -37,9 +36,6 @@ NSModel<SortPolicy>::NSModel(int treeType, bool randomBasis) :
 template<typename SortPolicy>
 NSModel<SortPolicy>::~NSModel()
 {
-  if (kdTree)
-    delete kdTree;
-
   if (kdTreeNS)
     delete kdTreeNS;
   if (coverTreeNS)
@@ -52,21 +48,17 @@ NSModel<SortPolicy>::~NSModel()
 
 //! Serialize the kNN model.
 template<typename SortPolicy>
-  template<typename Archive>
+template<typename Archive>
 void NSModel<SortPolicy>::Serialize(Archive& ar, 
                                     const unsigned int /* version */)
 {
   ar & data::CreateNVP(treeType, "treeType");
   ar & data::CreateNVP(randomBasis, "randomBasis");
   ar & data::CreateNVP(q, "q");
-  ar & data::CreateNVP(oldFromNewReferences, "oldFromNewReferences");
 
   // This should never happen, but just in case, be clean with memory.
   if (Archive::is_loading::value)
   {
-    if (kdTree)
-      delete kdTree;
-
     if (kdTreeNS)
       delete kdTreeNS;
     if (coverTreeNS)
@@ -77,8 +69,6 @@ void NSModel<SortPolicy>::Serialize(Archive& ar,
       delete rStarTreeNS;
 
     // Set all the pointers to NULL.
-    kdTree = NULL;
-
     kdTreeNS = NULL;
     coverTreeNS = NULL;
     rTreeNS = NULL;
@@ -86,7 +76,7 @@ void NSModel<SortPolicy>::Serialize(Archive& ar,
   }
 
   // We'll only need to serialize one of the kNN objects, based on the type.
-  const std::string& name = NSModelName<SortPolicy>::value;
+  const std::string& name = NSModelName<SortPolicy>::Name();
   switch (treeType)
   {
     case KD_TREE:
@@ -104,6 +94,21 @@ void NSModel<SortPolicy>::Serialize(Archive& ar,
   }
 }
 
+template<typename SortPolicy>
+const arma::mat& NSModel<SortPolicy>::Dataset() const
+{
+  if (kdTreeNS)
+    return kdTreeNS->ReferenceSet();
+  else if (coverTreeNS)
+    return coverTreeNS->ReferenceSet();
+  else if (rTreeNS)
+    return rTreeNS->ReferenceSet();
+  else if (rStarTreeNS)
+    return rStarTreeNS->ReferenceSet();
+
+  throw std::runtime_error("no neighbor search model initialized");
+}
+
 //! Expose singleMode.
 template<typename SortPolicy>
 bool NSModel<SortPolicy>::SingleMode() const
@@ -167,7 +172,7 @@ bool& NSModel<SortPolicy>::Naive()
 
 //! Build the reference tree.
 template<typename SortPolicy>
-void NSModel<SortPolicy>::BuildModel(arma::mat& referenceSet,
+void NSModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
                                      const size_t leafSize,
                                      const bool naive,
                                      const bool singleMode)
@@ -205,9 +210,6 @@ void NSModel<SortPolicy>::BuildModel(arma::mat& referenceSet,
   }
 
   // Clean memory, if necessary.
-  if (kdTree)
-    delete kdTree;
-
   if (kdTreeNS)
     delete kdTreeNS;
   if (coverTreeNS)
@@ -233,28 +235,37 @@ void NSModel<SortPolicy>::BuildModel(arma::mat& referenceSet,
       // If necessary, build the kd-tree.
       if (naive)
       {
-        kdTreeNS = new NSType<tree::KDTree>(referenceSet, naive, singleMode);
+        kdTreeNS = new NSType<tree::KDTree>(std::move(referenceSet), naive,
+            singleMode);
       }
       else
       {
-        kdTree = new typename NSType<tree::KDTree>::Tree(referenceSet,
+        std::vector<size_t> oldFromNewReferences;
+        typename NSType<tree::KDTree>::Tree* kdTree =
+            new typename NSType<tree::KDTree>::Tree(std::move(referenceSet),
             oldFromNewReferences, leafSize);
         kdTreeNS = new NSType<tree::KDTree>(kdTree, singleMode);
+
+        // Give the model ownership of the tree and the mappings.
+        kdTreeNS->treeOwner = true;
+        kdTreeNS->oldFromNewReferences = std::move(oldFromNewReferences);
       }
 
       break;
     case COVER_TREE:
       // If necessary, build the cover tree.
-      coverTreeNS = new NSType<tree::StandardCoverTree>(referenceSet,
-          singleMode);
+      coverTreeNS = new NSType<tree::StandardCoverTree>(std::move(referenceSet),
+          naive, singleMode);
       break;
     case R_TREE:
       // If necessary, build the R tree.
-      rTreeNS = new NSType<tree::RTree>(referenceSet, singleMode);
+      rTreeNS = new NSType<tree::RTree>(std::move(referenceSet), naive,
+          singleMode);
       break;
     case R_STAR_TREE:
       // If necessary, build the R* tree.
-      rStarTreeNS = new NSType<tree::RStarTree>(referenceSet, singleMode);
+      rStarTreeNS = new NSType<tree::RStarTree>(std::move(referenceSet), naive,
+          singleMode);
       break;
   }
 
@@ -267,7 +278,7 @@ void NSModel<SortPolicy>::BuildModel(arma::mat& referenceSet,
 
 //! Perform neighbor search.  The query set will be reordered.
 template<typename SortPolicy>
-void NSModel<SortPolicy>::Search(arma::mat& querySet,
+void NSModel<SortPolicy>::Search(arma::mat&& querySet,
                                  const size_t k,
                                  arma::Mat<size_t>& neighbors,
                                  arma::mat& distances)
@@ -293,7 +304,7 @@ void NSModel<SortPolicy>::Search(arma::mat& querySet,
         Timer::Start("tree_building");
         Log::Info << "Building query tree..." << std::endl;
         std::vector<size_t> oldFromNewQueries;
-        typename NSType<tree::KDTree>::Tree queryTree(querySet,
+        typename NSType<tree::KDTree>::Tree queryTree(std::move(querySet),
             oldFromNewQueries, leafSize);
         Log::Info << "Tree built." << std::endl;
         Timer::Stop("tree_building");
@@ -302,37 +313,19 @@ void NSModel<SortPolicy>::Search(arma::mat& querySet,
         arma::mat distancesOut;
         kdTreeNS->Search(&queryTree, k, neighborsOut, distancesOut);
 
-        // Unmap the results.
-        Unmap(neighborsOut, distancesOut, oldFromNewReferences,
-            oldFromNewQueries, neighbors, distances);
-      }
-      else if (kdTreeNS->SingleMode() && !kdTreeNS->Naive())
-      {
-        // Search without building a second tree.
-        arma::Mat<size_t> neighborsOut;
-        arma::mat distancesOut;
-        kdTreeNS->Search(querySet, k, neighborsOut, distancesOut);
-
-        Unmap(neighborsOut, distancesOut, oldFromNewReferences, neighbors,
-            distances);
+        // Unmap the query points.
+        distances.set_size(distancesOut.n_rows, distancesOut.n_cols);
+        neighbors.set_size(neighborsOut.n_rows, neighborsOut.n_cols);
+        for (size_t i = 0; i < neighborsOut.n_cols; ++i)
+        {
+          neighbors.col(oldFromNewQueries[i]) = neighborsOut.col(i);
+          distances.col(oldFromNewQueries[i]) = distancesOut.col(i);
+        }
       }
       else
       {
-        // Naive mode search.  No unmapping will be necessary... unless a tree
-        // has been built.
-        if (oldFromNewReferences.size() == 0)
-        {
-          kdTreeNS->Search(querySet, k, neighbors, distances);
-        }
-        else
-        {
-          arma::Mat<size_t> neighborsOut;
-          arma::mat distancesOut;
-          kdTreeNS->Search(querySet, k, neighborsOut, distancesOut);
-
-          Unmap(neighborsOut, distancesOut, oldFromNewReferences, neighbors,
-              distances);
-        }
+        // Search without building a second tree.
+        kdTreeNS->Search(querySet, k, neighbors, distances);
       }
       break;
     case COVER_TREE:
@@ -367,21 +360,7 @@ void NSModel<SortPolicy>::Search(const size_t k,
   switch (treeType)
   {
     case KD_TREE:
-      // If in dual-tree or single-tree mode, we'll have to do unmapping.  We
-      // also must do unmapping in naive mode, if a tree has been built on the
-      // data.
-      if (oldFromNewReferences.size() > 0) // Mapping has occured.
-      {
-        arma::Mat<size_t> neighborsOut;
-        arma::mat distancesOut;
-        kdTreeNS->Search(k, neighborsOut, distancesOut);
-        Unmap(neighborsOut, distancesOut, oldFromNewReferences,
-            oldFromNewReferences, neighbors, distances);
-      }
-      else
-      {
-        kdTreeNS->Search(k, neighbors, distances);
-      }
+      kdTreeNS->Search(k, neighbors, distances);
       break;
     case COVER_TREE:
       coverTreeNS->Search(k, neighbors, distances);



More information about the mlpack-git mailing list