[mlpack-git] master: Use rvalue references to prevent copies. (95d1357)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Mon Oct 19 16:04:48 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/09cd0d67f2fdae252a8ab85324e71dbb4dfe0010...fecf1194c123ced12d56e7daad761c7b9aaac262
>---------------------------------------------------------------
commit 95d1357f44d8d1505b766cbbaba9215ccdad9ff7
Author: Ryan Curtin <ryan at ratml.org>
Date: Mon Oct 19 14:57:53 2015 -0400
Use rvalue references to prevent copies.
>---------------------------------------------------------------
95d1357f44d8d1505b766cbbaba9215ccdad9ff7
src/mlpack/methods/neighbor_search/ns_model.hpp | 17 +--
.../methods/neighbor_search/ns_model_impl.hpp | 115 +++++++++------------
2 files changed, 52 insertions(+), 80 deletions(-)
diff --git a/src/mlpack/methods/neighbor_search/ns_model.hpp b/src/mlpack/methods/neighbor_search/ns_model.hpp
index e9e8042..952ed9e 100644
--- a/src/mlpack/methods/neighbor_search/ns_model.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model.hpp
@@ -22,19 +22,19 @@ namespace neighbor {
template<typename SortPolicy>
struct NSModelName
{
- const static constexpr char value[22] = "neighbor_search_model";
+ static const std::string Name() { return "neighbor_search_model"; }
};
template<>
struct NSModelName<NearestNeighborSort>
{
- const static constexpr char value[30] = "nearest_neighbor_search_model";
+ static const std::string Name() { return "nearest_neighbor_search_model"; }
};
template<>
struct NSModelName<FurthestNeighborSort>
{
- const static constexpr char value[31] = "furthest_neighbor_search_model";
+ static const std::string Name() { return "furthest_neighbor_search_model"; }
};
template<typename SortPolicy>
@@ -57,9 +57,6 @@ class NSModel
bool randomBasis;
arma::mat q;
- // Mappings, in case they are necessary.
- std::vector<size_t> oldFromNewReferences;
-
template<template<typename TreeMetricType,
typename TreeStatType,
typename TreeMatType> class TreeType>
@@ -74,10 +71,6 @@ class NSModel
NSType<tree::RTree>* rTreeNS;
NSType<tree::RStarTree>* rStarTreeNS;
- // This pointers is only non-null if we are using kd-trees and we built the
- // tree ourselves (which only happens if BuildModel() is called).
- typename NSType<tree::KDTree>::Tree* kdTree;
-
public:
/**
* Initialize the NSModel with the given type and whether or not a random
@@ -112,13 +105,13 @@ class NSModel
bool& RandomBasis() { return randomBasis; }
//! Build the reference tree.
- void BuildModel(arma::mat& referenceSet,
+ void BuildModel(arma::mat&& referenceSet,
const size_t leafSize,
const bool naive,
const bool singleMode);
//! Perform neighbor search. The query set will be reordered.
- void Search(arma::mat& querySet,
+ void Search(arma::mat&& querySet,
const size_t k,
arma::Mat<size_t>& neighbors,
arma::mat& distances);
diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
index be0df8e..af5f3fb 100644
--- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
@@ -27,8 +27,7 @@ NSModel<SortPolicy>::NSModel(int treeType, bool randomBasis) :
kdTreeNS(NULL),
coverTreeNS(NULL),
rTreeNS(NULL),
- rStarTreeNS(NULL),
- kdTree(NULL)
+ rStarTreeNS(NULL)
{
// Nothing to do.
}
@@ -37,9 +36,6 @@ NSModel<SortPolicy>::NSModel(int treeType, bool randomBasis) :
template<typename SortPolicy>
NSModel<SortPolicy>::~NSModel()
{
- if (kdTree)
- delete kdTree;
-
if (kdTreeNS)
delete kdTreeNS;
if (coverTreeNS)
@@ -52,21 +48,17 @@ NSModel<SortPolicy>::~NSModel()
//! Serialize the kNN model.
template<typename SortPolicy>
- template<typename Archive>
+template<typename Archive>
void NSModel<SortPolicy>::Serialize(Archive& ar,
const unsigned int /* version */)
{
ar & data::CreateNVP(treeType, "treeType");
ar & data::CreateNVP(randomBasis, "randomBasis");
ar & data::CreateNVP(q, "q");
- ar & data::CreateNVP(oldFromNewReferences, "oldFromNewReferences");
// This should never happen, but just in case, be clean with memory.
if (Archive::is_loading::value)
{
- if (kdTree)
- delete kdTree;
-
if (kdTreeNS)
delete kdTreeNS;
if (coverTreeNS)
@@ -77,8 +69,6 @@ void NSModel<SortPolicy>::Serialize(Archive& ar,
delete rStarTreeNS;
// Set all the pointers to NULL.
- kdTree = NULL;
-
kdTreeNS = NULL;
coverTreeNS = NULL;
rTreeNS = NULL;
@@ -86,7 +76,7 @@ void NSModel<SortPolicy>::Serialize(Archive& ar,
}
// We'll only need to serialize one of the kNN objects, based on the type.
- const std::string& name = NSModelName<SortPolicy>::value;
+ const std::string& name = NSModelName<SortPolicy>::Name();
switch (treeType)
{
case KD_TREE:
@@ -104,6 +94,21 @@ void NSModel<SortPolicy>::Serialize(Archive& ar,
}
}
+template<typename SortPolicy>
+const arma::mat& NSModel<SortPolicy>::Dataset() const
+{
+ if (kdTreeNS)
+ return kdTreeNS->ReferenceSet();
+ else if (coverTreeNS)
+ return coverTreeNS->ReferenceSet();
+ else if (rTreeNS)
+ return rTreeNS->ReferenceSet();
+ else if (rStarTreeNS)
+ return rStarTreeNS->ReferenceSet();
+
+ throw std::runtime_error("no neighbor search model initialized");
+}
+
//! Expose singleMode.
template<typename SortPolicy>
bool NSModel<SortPolicy>::SingleMode() const
@@ -167,7 +172,7 @@ bool& NSModel<SortPolicy>::Naive()
//! Build the reference tree.
template<typename SortPolicy>
-void NSModel<SortPolicy>::BuildModel(arma::mat& referenceSet,
+void NSModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
const size_t leafSize,
const bool naive,
const bool singleMode)
@@ -205,9 +210,6 @@ void NSModel<SortPolicy>::BuildModel(arma::mat& referenceSet,
}
// Clean memory, if necessary.
- if (kdTree)
- delete kdTree;
-
if (kdTreeNS)
delete kdTreeNS;
if (coverTreeNS)
@@ -233,28 +235,37 @@ void NSModel<SortPolicy>::BuildModel(arma::mat& referenceSet,
// If necessary, build the kd-tree.
if (naive)
{
- kdTreeNS = new NSType<tree::KDTree>(referenceSet, naive, singleMode);
+ kdTreeNS = new NSType<tree::KDTree>(std::move(referenceSet), naive,
+ singleMode);
}
else
{
- kdTree = new typename NSType<tree::KDTree>::Tree(referenceSet,
+ std::vector<size_t> oldFromNewReferences;
+ typename NSType<tree::KDTree>::Tree* kdTree =
+ new typename NSType<tree::KDTree>::Tree(std::move(referenceSet),
oldFromNewReferences, leafSize);
kdTreeNS = new NSType<tree::KDTree>(kdTree, singleMode);
+
+ // Give the model ownership of the tree and the mappings.
+ kdTreeNS->treeOwner = true;
+ kdTreeNS->oldFromNewReferences = std::move(oldFromNewReferences);
}
break;
case COVER_TREE:
// If necessary, build the cover tree.
- coverTreeNS = new NSType<tree::StandardCoverTree>(referenceSet,
- singleMode);
+ coverTreeNS = new NSType<tree::StandardCoverTree>(std::move(referenceSet),
+ naive, singleMode);
break;
case R_TREE:
// If necessary, build the R tree.
- rTreeNS = new NSType<tree::RTree>(referenceSet, singleMode);
+ rTreeNS = new NSType<tree::RTree>(std::move(referenceSet), naive,
+ singleMode);
break;
case R_STAR_TREE:
// If necessary, build the R* tree.
- rStarTreeNS = new NSType<tree::RStarTree>(referenceSet, singleMode);
+ rStarTreeNS = new NSType<tree::RStarTree>(std::move(referenceSet), naive,
+ singleMode);
break;
}
@@ -267,7 +278,7 @@ void NSModel<SortPolicy>::BuildModel(arma::mat& referenceSet,
//! Perform neighbor search. The query set will be reordered.
template<typename SortPolicy>
-void NSModel<SortPolicy>::Search(arma::mat& querySet,
+void NSModel<SortPolicy>::Search(arma::mat&& querySet,
const size_t k,
arma::Mat<size_t>& neighbors,
arma::mat& distances)
@@ -293,7 +304,7 @@ void NSModel<SortPolicy>::Search(arma::mat& querySet,
Timer::Start("tree_building");
Log::Info << "Building query tree..." << std::endl;
std::vector<size_t> oldFromNewQueries;
- typename NSType<tree::KDTree>::Tree queryTree(querySet,
+ typename NSType<tree::KDTree>::Tree queryTree(std::move(querySet),
oldFromNewQueries, leafSize);
Log::Info << "Tree built." << std::endl;
Timer::Stop("tree_building");
@@ -302,37 +313,19 @@ void NSModel<SortPolicy>::Search(arma::mat& querySet,
arma::mat distancesOut;
kdTreeNS->Search(&queryTree, k, neighborsOut, distancesOut);
- // Unmap the results.
- Unmap(neighborsOut, distancesOut, oldFromNewReferences,
- oldFromNewQueries, neighbors, distances);
- }
- else if (kdTreeNS->SingleMode() && !kdTreeNS->Naive())
- {
- // Search without building a second tree.
- arma::Mat<size_t> neighborsOut;
- arma::mat distancesOut;
- kdTreeNS->Search(querySet, k, neighborsOut, distancesOut);
-
- Unmap(neighborsOut, distancesOut, oldFromNewReferences, neighbors,
- distances);
+ // Unmap the query points.
+ distances.set_size(distancesOut.n_rows, distancesOut.n_cols);
+ neighbors.set_size(neighborsOut.n_rows, neighborsOut.n_cols);
+ for (size_t i = 0; i < neighborsOut.n_cols; ++i)
+ {
+ neighbors.col(oldFromNewQueries[i]) = neighborsOut.col(i);
+ distances.col(oldFromNewQueries[i]) = distancesOut.col(i);
+ }
}
else
{
- // Naive mode search. No unmapping will be necessary... unless a tree
- // has been built.
- if (oldFromNewReferences.size() == 0)
- {
- kdTreeNS->Search(querySet, k, neighbors, distances);
- }
- else
- {
- arma::Mat<size_t> neighborsOut;
- arma::mat distancesOut;
- kdTreeNS->Search(querySet, k, neighborsOut, distancesOut);
-
- Unmap(neighborsOut, distancesOut, oldFromNewReferences, neighbors,
- distances);
- }
+ // Search without building a second tree.
+ kdTreeNS->Search(querySet, k, neighbors, distances);
}
break;
case COVER_TREE:
@@ -367,21 +360,7 @@ void NSModel<SortPolicy>::Search(const size_t k,
switch (treeType)
{
case KD_TREE:
- // If in dual-tree or single-tree mode, we'll have to do unmapping. We
- // also must do unmapping in naive mode, if a tree has been built on the
- // data.
- if (oldFromNewReferences.size() > 0) // Mapping has occured.
- {
- arma::Mat<size_t> neighborsOut;
- arma::mat distancesOut;
- kdTreeNS->Search(k, neighborsOut, distancesOut);
- Unmap(neighborsOut, distancesOut, oldFromNewReferences,
- oldFromNewReferences, neighbors, distances);
- }
- else
- {
- kdTreeNS->Search(k, neighbors, distances);
- }
+ kdTreeNS->Search(k, neighbors, distances);
break;
case COVER_TREE:
coverTreeNS->Search(k, neighbors, distances);
More information about the mlpack-git mailing list