[mlpack-git] master: Add Train() functions to NeighborSearch for consistency. (60c8170)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Tue Sep 29 09:33:25 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/cbeb3ea17262b7c5115247dc217e316c529249b7...f85a9b22f3ce56143943a2488c05c2810d6b2bf3
>---------------------------------------------------------------
commit 60c81702c432014544351948c091d03deeef6985
Author: Ryan Curtin <ryan at ratml.org>
Date: Fri Sep 25 17:18:34 2015 -0400
Add Train() functions to NeighborSearch for consistency.
This will make serialization much easier, and also adds more flexibility to the
class.
>---------------------------------------------------------------
60c81702c432014544351948c091d03deeef6985
.../methods/neighbor_search/neighbor_search.hpp | 52 ++++-
.../neighbor_search/neighbor_search_impl.hpp | 247 +++++++++++++++++++--
2 files changed, 265 insertions(+), 34 deletions(-)
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index 0e56d25..ff01c88 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -116,12 +116,42 @@ class NeighborSearch
const MetricType metric = MetricType());
/**
+ * Create a NeighborSearch object without any reference data. If Search() is
+ * called before a reference set is set with Train(), an exception will be
+ * thrown.
+ *
+ * @param naive Whether to use naive search.
+ * @param singleMode Whether single-tree computation should be used (as
+ * opposed to dual-tree computation).
+ * @param metric Instantiated metric.
+ */
+ NeighborSearch(const bool naive = false,
+ const bool singleMode = false,
+ const MetricType metric = MetricType());
+
+
+ /**
* Delete the NeighborSearch object. The tree is the only member we are
* responsible for deleting. The others will take care of themselves.
*/
~NeighborSearch();
/**
+ * Set the reference set to a new reference set, and build a tree if
+ * necessary. This method is called 'Train()' in order to match the rest of
+ * the mlpack abstractions, even though calling this "training" is maybe a bit
+ * of a stretch.
+ *
+ * @param referenceSet New set of reference data.
+ */
+ void Train(const MatType& referenceSet);
+
+ /**
+ * Set the reference tree to a new reference tree.
+ */
+ void Train(Tree* referenceTree);
+
+ /**
* For each point in the query set, compute the nearest neighbors and store
* the output in the given matrices. The matrices will be set to the size of
* n columns by k rows, where n is the number of points in the query dataset
@@ -182,16 +212,12 @@ class NeighborSearch
//! Returns a string representation of this object.
std::string ToString() const;
- //! Return the total number of base case evaluations performed during
- //! searches.
+ //! Return the total number of base case evaluations performed during the last
+ //! search.
size_t BaseCases() const { return baseCases; }
- //! Modify the total number of base case evaluations.
- size_t& BaseCases() { return baseCases; }
- //! Return the number of node combination scores during the search.
+ //! Return the number of node combination scores during the last search.
size_t Scores() const { return scores; }
- //! Modify the number of node combination scores.
- size_t& Scores() { return scores; }
//! Access whether or not search is done in naive linear scan mode.
bool Naive() const { return naive; }
@@ -203,20 +229,26 @@ class NeighborSearch
//! Modify whether or not search is done in single-tree mode.
bool& SingleMode() { return singleMode; }
+ //! Serialize the NeighborSearch model.
+ template<typename Archive>
+ void Serialize(Archive& ar, const unsigned int /* version */);
+
private:
//! Permutations of reference points during tree building.
std::vector<size_t> oldFromNewReferences;
//! Pointer to the root of the reference tree.
Tree* referenceTree;
- //! Reference to reference dataset.
- const MatType& referenceSet;
+ //! Reference dataset. In some situations we may be the owner of this.
+ const MatType* referenceSet;
//! If true, this object created the trees and is responsible for them.
bool treeOwner;
+ //! If true, we own the reference set.
+ bool setOwner;
//! Indicates if O(n^2) naive search is being used.
bool naive;
- //! Indicates if single-tree search is being used (opposed to dual-tree).
+ //! Indicates if single-tree search is being used (as opposed to dual-tree).
bool singleMode;
//! Instantiation of metric.
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index 82bec9c..19eabc8 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -54,8 +54,9 @@ NeighborSearch(const MatType& referenceSetIn,
const MetricType metric) :
referenceTree(naive ? NULL :
BuildTree<MatType, Tree>(referenceSetIn, oldFromNewReferences)),
- referenceSet(naive ? referenceSetIn : referenceTree->Dataset()),
+ referenceSet(naive ? &referenceSetIn : &referenceTree->Dataset()),
treeOwner(!naive), // False if a tree was passed. If naive, then no trees.
+ setOwner(false),
naive(naive),
singleMode(!naive && singleMode), // No single mode if naive.
metric(metric),
@@ -78,8 +79,9 @@ NeighborSearch(Tree* referenceTree,
const bool singleMode,
const MetricType metric) :
referenceTree(referenceTree),
- referenceSet(referenceTree->Dataset()),
+ referenceSet(&referenceTree->Dataset()),
treeOwner(false),
+ setOwner(false),
naive(false),
singleMode(singleMode),
metric(metric),
@@ -89,6 +91,37 @@ NeighborSearch(Tree* referenceTree,
// Nothing else to initialize.
}
+// Construct the object without a reference dataset.
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType,
+         template<typename> class TraversalType>
+NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
+NeighborSearch(const bool naive,
+               const bool singleMode,
+               const MetricType metric) :
+    referenceTree(NULL),
+    referenceSet(new MatType()), // Empty matrix, owned (and freed) by us.
+    treeOwner(false),
+    setOwner(true),
+    naive(naive),
+    singleMode(!naive && singleMode), // No single mode if naive.
+    metric(metric),
+    baseCases(0),
+    scores(0)
+{
+  // Build the tree on the empty dataset, if necessary; tree modes need one.
+  if (!naive)
+  {
+    referenceTree = BuildTree<MatType, Tree>(*referenceSet,
+                                             oldFromNewReferences);
+    treeOwner = true;
+  }
+}
+
// Clean memory.
template<typename SortPolicy,
typename MetricType,
@@ -98,10 +131,69 @@ template<typename SortPolicy,
typename TreeMatType> class TreeType,
template<typename> class TraversalType>
NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
- ~NeighborSearch()
+~NeighborSearch()
{
if (treeOwner && referenceTree)
delete referenceTree;
+ if (setOwner && referenceSet)
+ delete referenceSet;
+}
+
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType,
+         template<typename> class TraversalType>
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
+Train(const MatType& referenceSet)
+{
+  // Build any new tree before freeing the old one, so a throwing BuildTree()
+  // never leaves a dangling owned tree for the destructor to delete again.
+  if (!naive)
+  {
+    std::vector<size_t> newOldFromNew;
+    Tree* newTree = BuildTree<MatType, Tree>(referenceSet, newOldFromNew);
+    if (treeOwner && referenceTree)
+      delete referenceTree;
+    referenceTree = newTree;
+    oldFromNewReferences.swap(newOldFromNew);
+    treeOwner = true;
+  }
+  if (setOwner && this->referenceSet)
+    delete this->referenceSet;
+  // Naive mode keeps a pointer to the caller's matrix; no copy is made.
+  if (!naive)
+    this->referenceSet = &referenceTree->Dataset();
+  else
+    this->referenceSet = &referenceSet;
+  setOwner = false; // We don't own the set in either case.
+}
+
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType,
+         template<typename> class TraversalType>
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
+Train(Tree* referenceTree)
+{
+  if (naive)
+    throw std::invalid_argument("cannot train on given reference tree when "
+        "naive search (without trees) is desired");
+  // Re-training on the tree we already hold is a no-op (avoids self-deletion).
+  if (referenceTree == this->referenceTree)
+    return;
+  if (treeOwner && this->referenceTree)
+    delete this->referenceTree;
+  if (setOwner && this->referenceSet)
+    delete this->referenceSet;
+  this->referenceTree = referenceTree;
+  this->referenceSet = &referenceTree->Dataset();
+  treeOwner = setOwner = false; // Neither the tree nor its dataset is ours.
}
/**
@@ -121,8 +213,19 @@ Search(const MatType& querySet,
arma::Mat<size_t>& neighbors,
arma::mat& distances)
{
+ if (k > referenceSet->n_cols)
+ {
+ std::stringstream ss;
+ ss << "requested value of k (" << k << ") is greater than the number of "
+ << "points in the reference set (" << referenceSet->n_cols << ")";
+ throw std::invalid_argument(ss.str());
+ }
+
Timer::Start("computing_neighbors");
+ baseCases = 0;
+ scores = 0;
+
// This will hold mappings for query points, if necessary.
std::vector<size_t> oldFromNewQueries;
@@ -154,19 +257,19 @@ Search(const MatType& querySet,
if (naive)
{
// Create the helper object for the tree traversal.
- RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr, metric);
+ RuleType rules(*referenceSet, querySet, *neighborPtr, *distancePtr, metric);
// The naive brute-force traversal.
for (size_t i = 0; i < querySet.n_cols; ++i)
- for (size_t j = 0; j < referenceSet.n_cols; ++j)
+ for (size_t j = 0; j < referenceSet->n_cols; ++j)
rules.BaseCase(i, j);
- baseCases += querySet.n_cols * referenceSet.n_cols;
+ baseCases += querySet.n_cols * referenceSet->n_cols;
}
else if (singleMode)
{
// Create the helper object for the tree traversal.
- RuleType rules(referenceSet, querySet, *neighborPtr, *distancePtr, metric);
+ RuleType rules(*referenceSet, querySet, *neighborPtr, *distancePtr, metric);
// Create the traverser.
typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
@@ -191,7 +294,7 @@ Search(const MatType& querySet,
Timer::Start("computing_neighbors");
// Create the helper object for the tree traversal.
- RuleType rules(referenceSet, queryTree->Dataset(), *neighborPtr,
+ RuleType rules(*referenceSet, queryTree->Dataset(), *neighborPtr,
*distancePtr, metric);
// Create the traverser.
@@ -283,16 +386,27 @@ Search(Tree* queryTree,
arma::Mat<size_t>& neighbors,
arma::mat& distances)
{
- Timer::Start("computing_neighbors");
-
- // Get a reference to the query set.
- const MatType& querySet = queryTree->Dataset();
+ if (k > referenceSet->n_cols)
+ {
+ std::stringstream ss;
+ ss << "requested value of k (" << k << ") is greater than the number of "
+ << "points in the reference set (" << referenceSet->n_cols << ")";
+ throw std::invalid_argument(ss.str());
+ }
// Make sure we are in dual-tree mode.
if (singleMode || naive)
throw std::invalid_argument("cannot call NeighborSearch::Search() with a "
"query tree when naive or singleMode are set to true");
+ Timer::Start("computing_neighbors");
+
+ baseCases = 0;
+ scores = 0;
+
+ // Get a reference to the query set.
+ const MatType& querySet = queryTree->Dataset();
+
// We won't need to map query indices, but will we need to map distances?
arma::Mat<size_t>* neighborPtr = &neighbors;
@@ -306,7 +420,7 @@ Search(Tree* queryTree,
// Create the helper object for the traversal.
typedef NeighborSearchRules<SortPolicy, MetricType, Tree> RuleType;
- RuleType rules(referenceSet, querySet, *neighborPtr, distances, metric);
+ RuleType rules(*referenceSet, querySet, *neighborPtr, distances, metric);
// Create the traverser.
TraversalType<RuleType> traverser(rules);
@@ -345,8 +459,19 @@ Search(const size_t k,
arma::Mat<size_t>& neighbors,
arma::mat& distances)
{
+ if (k > referenceSet->n_cols)
+ {
+ std::stringstream ss;
+ ss << "requested value of k (" << k << ") is greater than the number of "
+ << "points in the reference set (" << referenceSet->n_cols << ")";
+ throw std::invalid_argument(ss.str());
+ }
+
Timer::Start("computing_neighbors");
+ baseCases = 0;
+ scores = 0;
+
arma::Mat<size_t>* neighborPtr = &neighbors;
arma::mat* distancePtr = &distances;
@@ -358,24 +483,24 @@ Search(const size_t k,
}
// Initialize results.
- neighborPtr->set_size(k, referenceSet.n_cols);
+ neighborPtr->set_size(k, referenceSet->n_cols);
neighborPtr->fill(size_t() - 1);
- distancePtr->set_size(k, referenceSet.n_cols);
+ distancePtr->set_size(k, referenceSet->n_cols);
distancePtr->fill(SortPolicy::WorstDistance());
// Create the helper object for the traversal.
typedef NeighborSearchRules<SortPolicy, MetricType, Tree> RuleType;
- RuleType rules(referenceSet, referenceSet, *neighborPtr, *distancePtr,
+ RuleType rules(*referenceSet, *referenceSet, *neighborPtr, *distancePtr,
metric, true /* don't return the same point as nearest neighbor */);
if (naive)
{
// The naive brute-force solution.
- for (size_t i = 0; i < referenceSet.n_cols; ++i)
- for (size_t j = 0; j < referenceSet.n_cols; ++j)
+ for (size_t i = 0; i < referenceSet->n_cols; ++i)
+ for (size_t j = 0; j < referenceSet->n_cols; ++j)
rules.BaseCase(i, j);
- baseCases += referenceSet.n_cols * referenceSet.n_cols;
+ baseCases += referenceSet->n_cols * referenceSet->n_cols;
}
else if (singleMode)
{
@@ -383,7 +508,7 @@ Search(const size_t k,
typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
// Now have it traverse for each point.
- for (size_t i = 0; i < referenceSet.n_cols; ++i)
+ for (size_t i = 0; i < referenceSet->n_cols; ++i)
traverser.Traverse(i, *referenceTree);
scores += rules.Scores();
@@ -411,8 +536,8 @@ Search(const size_t k,
// Do we need to map the reference indices?
if (treeOwner && tree::TreeTraits<Tree>::RearrangesDataset)
{
- neighbors.set_size(k, referenceSet.n_cols);
- distances.set_size(k, referenceSet.n_cols);
+ neighbors.set_size(k, referenceSet->n_cols);
+ distances.set_size(k, referenceSet->n_cols);
for (size_t i = 0; i < distances.n_cols; ++i)
{
@@ -444,8 +569,8 @@ std::string NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
{
std::ostringstream convert;
convert << "NeighborSearch [" << this << "]" << std::endl;
- convert << " Reference set: " << referenceSet.n_rows << "x" ;
- convert << referenceSet.n_cols << std::endl;
+ convert << " Reference set: " << referenceSet->n_rows << "x" ;
+ convert << referenceSet->n_cols << std::endl;
if (referenceTree)
convert << " Reference tree: " << referenceTree << std::endl;
convert << " Tree owner: " << treeOwner << std::endl;
@@ -455,6 +580,80 @@ std::string NeighborSearch<SortPolicy, MetricType, MatType, TreeType,
return convert.str();
}
+//! Serialize the NeighborSearch model.
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType,
+         template<typename> class TraversalType>
+template<typename Archive>
+void NeighborSearch<SortPolicy, MetricType, MatType, TreeType, TraversalType>::
+    Serialize(Archive& ar, const unsigned int /* version */)
+{
+  using data::CreateNVP;
+
+  // Serialize preferences for search.
+  ar & CreateNVP(naive, "naive");
+  ar & CreateNVP(singleMode, "singleMode");
+
+  // If we are doing naive search, we serialize the dataset.  Otherwise we
+  // serialize the tree.
+  if (naive)
+  {
+    // Delete the current reference set, if necessary and if we are loading.
+    if (Archive::is_loading::value)
+    {
+      if (setOwner && referenceSet)
+        delete referenceSet;
+      referenceSet = NULL; // Never left dangling if the load below throws.
+      setOwner = true; // We will own the reference set when we load it.
+    }
+
+    ar & CreateNVP(referenceSet, "referenceSet");
+    ar & CreateNVP(metric, "metric");
+
+    // If we are loading, set the tree to NULL and clean up memory if necessary.
+    if (Archive::is_loading::value)
+    {
+      if (treeOwner && referenceTree)
+        delete referenceTree;
+
+      referenceTree = NULL;
+      oldFromNewReferences.clear();
+      treeOwner = false;
+    }
+  }
+  else
+  {
+    // Delete the current reference tree, if necessary and if we are loading.
+    if (Archive::is_loading::value)
+    {
+      if (treeOwner && referenceTree)
+        delete referenceTree;
+      referenceTree = NULL; // Never left dangling if the load below throws.
+      // After we load the tree, we will own it.
+      treeOwner = true;
+    }
+
+    ar & CreateNVP(referenceTree, "referenceTree");
+    ar & CreateNVP(oldFromNewReferences, "oldFromNewReferences");
+
+    // If we are loading, set the dataset accordingly and clean up memory if
+    // necessary.
+    if (Archive::is_loading::value)
+    {
+      if (setOwner && referenceSet)
+        delete referenceSet;
+
+      referenceSet = &referenceTree->Dataset();
+      metric = referenceTree->Metric(); // Get the metric from the tree.
+      setOwner = false;
+    }
+  }
+}
+
} // namespace neighbor
} // namespace mlpack
More information about the mlpack-git
mailing list