[mlpack-git] master: Add move constructors and Train() functions. (06b2e58)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Mon Nov 2 12:19:22 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/f86acf8be2c01568d8b3dcd2e529ee9f20f7585e...156787dd4f372a7fd740f733127ac200ea2564b7
>---------------------------------------------------------------
commit 06b2e584866e060bc6e478306c7a963c6688811f
Author: Ryan Curtin <ryan at ratml.org>
Date: Mon Nov 2 17:17:01 2015 +0000
Add move constructors and Train() functions.
>---------------------------------------------------------------
06b2e584866e060bc6e478306c7a963c6688811f
src/mlpack/methods/range_search/range_search.hpp | 76 +++++++++
.../methods/range_search/range_search_impl.hpp | 182 +++++++++++++++++++++
src/mlpack/tests/range_search_test.cpp | 23 +++
3 files changed, 281 insertions(+)
diff --git a/src/mlpack/methods/range_search/range_search.hpp b/src/mlpack/methods/range_search/range_search.hpp
index de95d79..ffb3d86 100644
--- a/src/mlpack/methods/range_search/range_search.hpp
+++ b/src/mlpack/methods/range_search/range_search.hpp
@@ -59,6 +59,30 @@ class RangeSearch
const MetricType metric = MetricType());
/**
+ * Initialize the RangeSearch object with the given reference dataset (this is
+ * the dataset which is searched), taking ownership of the matrix.
+ * Optionally, perform the computation in naive mode or single-tree mode.
+ * Additionally, an instantiated metric can be given, for cases where the
+ * distance metric holds data.
+ *
+ * This method will not copy the data matrix, but will take ownership of it,
+ * and depending on the type of tree used, may rearrange the points. If you
+ * would rather a copy be made, consider using the constructor that takes a
+ * const reference to the data instead.
+ *
+ * @param referenceSet Set of reference points.
+ * @param naive If true, brute force naive search will be used (as opposed to
+ * dual-tree search). This overrides singleMode (if it is set to true).
+ * @param singleMode If true, single-tree search will be used (as opposed to
+ * dual-tree search).
+ * @param metric An optional instance of the MetricType class.
+ */
+ RangeSearch(MatType&& referenceSet,
+ const bool naive = false,
+ const bool singleMode = false,
+ const MetricType metric = MetricType());
+
+ /**
* Initialize the RangeSearch object with the given pre-constructed reference
* tree (this is the tree built on the reference set, which is the set that is
* searched). Optionally, choose to use single-tree mode, which will not
@@ -87,12 +111,51 @@ class RangeSearch
const MetricType metric = MetricType());
/**
+ * Initialize the RangeSearch object without any reference data. If the
+ * monochromatic Search() is called before a reference set is set with
+ * Train(), no results will be returned (since the reference set is empty).
+ *
+ * @param naive Whether to use naive search.
+ * @param singleMode Whether single-tree computation should be used (as
+ * opposed to dual-tree computation).
+ * @param metric Instantiated metric.
+ */
+ RangeSearch(const bool naive = false,
+ const bool singleMode = false,
+ const MetricType metric = MetricType());
+
+ /**
* Destroy the RangeSearch object. If trees were created, they will be
* deleted.
*/
~RangeSearch();
/**
+ * Set the reference set to a new reference set, and build a tree if
+ * necessary. This method is called 'Train()' in order to match the rest of
+ * the mlpack abstractions, even though calling this "training" is maybe a bit
+ * of a stretch.
+ *
+ * @param referenceSet New set of reference data.
+ */
+ void Train(const MatType& referenceSet);
+
+ /**
+ * Set the reference set to a new reference set, taking ownership of the set.
+ * A tree is built if necessary. This method is called 'Train()' in order to
+ * match the rest of the mlpack abstractions, even though calling this
+ * "training" is maybe a bit of a stretch.
+ *
+ * @param referenceSet New set of reference data.
+ */
+ void Train(MatType&& referenceSet);
+
+ /**
+ * Set the reference tree to the given (non-owned) tree; throws in naive mode.
+ */
+ void Train(Tree* referenceTree);
+
+ /**
* Search for all reference points in the given range for each point in the
* query set, returning the results in the neighbors and distances objects.
* Each entry in the external vector corresponds to a query point. Each of
@@ -199,6 +262,16 @@ class RangeSearch
std::vector<std::vector<size_t>>& neighbors,
std::vector<std::vector<double>>& distances);
+ //! Get whether single-tree search is being used.
+ bool SingleMode() const { return singleMode; }
+ //! Modify whether single-tree search is being used.
+ bool& SingleMode() { return singleMode; }
+
+ //! Get whether naive search is being used.
+ bool Naive() const { return naive; }
+ //! Modify whether naive search is being used.
+ bool& Naive() { return naive; }
+
//! Get the number of base cases during the last search.
size_t BaseCases() const { return baseCases; }
//! Get the number of scores during the last search.
@@ -211,6 +284,9 @@ class RangeSearch
//! Returns a string representation of this object.
std::string ToString() const;
+ //! Return the reference set.
+ const MatType& ReferenceSet() const { return *referenceSet; }
+
//! Return the reference tree (or NULL if in naive mode).
Tree* ReferenceTree() { return referenceTree; }
diff --git a/src/mlpack/methods/range_search/range_search_impl.hpp b/src/mlpack/methods/range_search/range_search_impl.hpp
index a813dca..4eef20f 100644
--- a/src/mlpack/methods/range_search/range_search_impl.hpp
+++ b/src/mlpack/methods/range_search/range_search_impl.hpp
@@ -39,6 +39,28 @@ TreeType* BuildTree(
return new TreeType(dataset);
}
+template<typename TreeType> // Rvalue overload, enabled if RearrangesDataset.
+TreeType* BuildTree(
+ typename TreeType::Mat&& dataset,
+ std::vector<size_t>& oldFromNew,
+ const typename boost::enable_if_c<
+ tree::TreeTraits<TreeType>::RearrangesDataset == true, TreeType*
+ >::type = 0)
+{
+ return new TreeType(std::move(dataset), oldFromNew); // Heap-allocated tree.
+}
+
+template<typename TreeType> // Rvalue overload, enabled if !RearrangesDataset.
+TreeType* BuildTree(
+ typename TreeType::Mat&& dataset,
+ const std::vector<size_t>& /* oldFromNew (unused by this overload) */,
+ const typename boost::enable_if_c<
+ tree::TreeTraits<TreeType>::RearrangesDataset == false, TreeType*
+ >::type = 0)
+{
+ return new TreeType(std::move(dataset)); // Heap-allocated tree, no mapping.
+}
+
template<typename MetricType,
typename MatType,
template<typename TreeMetricType,
@@ -63,6 +85,32 @@ RangeSearch<MetricType, MatType, TreeType>::RangeSearch(
// Nothing to do.
}
+// Constructor that moves from the given reference matrix (not a move ctor).
+template<typename MetricType,
+ typename MatType,
+ template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType>
+RangeSearch<MetricType, MatType, TreeType>::RangeSearch(
+ MatType&& referenceSet,
+ const bool naive,
+ const bool singleMode,
+ const MetricType metric) :
+ referenceTree(naive ? NULL : BuildTree<Tree>(std::move(referenceSet),
+ oldFromNewReferences)),
+ referenceSet(naive ? new MatType(std::move(referenceSet)) :
+ &referenceTree->Dataset()),
+ treeOwner(!naive), // A tree is built (and owned) only when not naive.
+ setOwner(naive), // The matrix copy is heap-allocated only when naive.
+ naive(naive),
+ singleMode(!naive && singleMode), // Naive mode overrides single-tree mode.
+ metric(metric),
+ baseCases(0),
+ scores(0)
+{
+ // NOTE(review): referenceSet's initializer reads referenceTree; confirm order.
+}
+
template<typename MetricType,
typename MatType,
template<typename TreeMetricType,
@@ -90,6 +138,34 @@ template<typename MetricType,
template<typename TreeMetricType,
typename TreeStatType,
typename TreeMatType> class TreeType>
+RangeSearch<MetricType, MatType, TreeType>::RangeSearch(
+ const bool naive,
+ const bool singleMode,
+ const MetricType metric) :
+ referenceTree(NULL),
+ referenceSet(new MatType()), // Empty matrix, owned by this object.
+ treeOwner(false),
+ setOwner(true),
+ naive(naive),
+ singleMode(singleMode), // NOTE(review): stored as given, even when naive.
+ metric(metric),
+ baseCases(0),
+ scores(0)
+{
+ // Build the tree on the empty dataset, if necessary; we then own that tree.
+ if (!naive)
+ {
+ referenceTree = BuildTree<Tree>(const_cast<MatType&>(*referenceSet),
+ oldFromNewReferences);
+ treeOwner = true;
+ }
+}
+
+template<typename MetricType,
+ typename MatType,
+ template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType>
RangeSearch<MetricType, MatType, TreeType>::~RangeSearch()
{
if (treeOwner && referenceTree)
@@ -103,12 +179,118 @@ template<typename MetricType,
template<typename TreeMetricType,
typename TreeStatType,
typename TreeMatType> class TreeType>
+void RangeSearch<MetricType, MatType, TreeType>::Train(
+    const MatType& referenceSet)
+{
+  // Release the old tree if we built it; a borrowed tree is left untouched.
+  if (treeOwner && referenceTree)
+    delete referenceTree;
+  referenceTree = NULL; // No dangling tree pointer is kept in naive mode.
+  treeOwner = false;
+
+  // Rebuild the tree if necessary (BuildTree<> takes the tree type, Tree).
+  if (!naive)
+  {
+    referenceTree = BuildTree<Tree>(const_cast<MatType&>(referenceSet),
+        oldFromNewReferences);
+    treeOwner = true;
+  }
+
+  // Delete the old reference set, if we owned it.
+  if (setOwner && this->referenceSet)
+    delete this->referenceSet;
+
+  // Point at the tree's dataset, or borrow the caller's matrix when naive.
+  if (!naive)
+    this->referenceSet = &referenceTree->Dataset();
+  else
+    this->referenceSet = &referenceSet;
+  setOwner = false;
+}
+
+template<typename MetricType,
+ typename MatType,
+ template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType>
+void RangeSearch<MetricType, MatType, TreeType>::Train(
+    MatType&& referenceSet)
+{
+  // Release the old tree if we built it; a borrowed tree is left untouched.
+  if (treeOwner && referenceTree)
+    delete referenceTree;
+  referenceTree = NULL; // No dangling tree pointer is kept in naive mode.
+  treeOwner = false;
+
+  // We may need to rebuild the tree; it is constructed from the moved matrix.
+  if (!naive)
+  {
+    referenceTree = BuildTree<Tree>(std::move(referenceSet),
+        oldFromNewReferences);
+    treeOwner = true;
+  }
+
+  // Delete the old reference set, if we owned it.
+  if (setOwner && this->referenceSet)
+    delete this->referenceSet;
+
+  // With a tree, the data is reached through the tree's dataset; in naive
+  // mode the matrix is moved into a heap copy that this object owns.
+  if (!naive)
+  {
+    this->referenceSet = &referenceTree->Dataset();
+    setOwner = false;
+  }
+  else
+  {
+    this->referenceSet = new MatType(std::move(referenceSet));
+    setOwner = true;
+  }
+}
+
+template<typename MetricType,
+ typename MatType,
+ template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType>
+void RangeSearch<MetricType, MatType, TreeType>::Train(
+    Tree* referenceTree)
+{
+  if (naive)
+    throw std::invalid_argument("cannot train on given reference tree when "
+        "naive search (without trees) is desired");
+  // Test the members (this->), not the shadowing parameter, before releasing.
+  if (treeOwner && this->referenceTree)
+    delete this->referenceTree;
+  if (setOwner && this->referenceSet)
+    delete this->referenceSet;
+  // The given tree (which must be non-NULL) and its dataset are only borrowed.
+  this->referenceTree = referenceTree;
+  this->referenceSet = &referenceTree->Dataset();
+  treeOwner = false;
+  setOwner = false;
+}
+
+template<typename MetricType,
+ typename MatType,
+ template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType>
void RangeSearch<MetricType, MatType, TreeType>::Search(
const MatType& querySet,
const math::Range& range,
std::vector<std::vector<size_t>>& neighbors,
std::vector<std::vector<double>>& distances)
{
+ if (querySet.n_rows != referenceSet->n_rows)
+ {
+ std::ostringstream oss;
+ oss << "RangeSearch::Search(): dimensionalities of query set ("
+ << querySet.n_rows << ") and reference set (" << referenceSet->n_rows
+ << ") do not match!";
+ throw std::invalid_argument(oss.str());
+ }
+
Timer::Start("range_search/computing_neighbors");
// This will hold mappings for query points, if necessary.
diff --git a/src/mlpack/tests/range_search_test.cpp b/src/mlpack/tests/range_search_test.cpp
index 3110a9f..bc34d59 100644
--- a/src/mlpack/tests/range_search_test.cpp
+++ b/src/mlpack/tests/range_search_test.cpp
@@ -1043,4 +1043,27 @@ BOOST_AUTO_TEST_CASE(DualBallTreeTest2)
}
}
+/**
+ * Make sure that a range search object built with no reference set returns no
+ * results, and that a query set of mismatched dimensionality is rejected.
+ */
+BOOST_AUTO_TEST_CASE(EmptySearchTest)
+{
+ RangeSearch<EuclideanDistance, arma::mat, KDTree> rs;
+
+ vector<vector<size_t>> neighbors;
+ vector<vector<double>> distances;
+
+ rs.Search(math::Range(0.0, 10.0), neighbors, distances);
+
+ BOOST_REQUIRE_EQUAL(neighbors.size(), 0);
+ BOOST_REQUIRE_EQUAL(distances.size(), 0);
+
+ // A 3-row query set cannot match the empty reference set, so Search() throws.
+ arma::mat querySet = arma::randu<arma::mat>(3, 100);
+
+ BOOST_REQUIRE_THROW(rs.Search(querySet, math::Range(0.0, 10.0), neighbors,
+ distances), std::invalid_argument);
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git
mailing list