[mlpack-git] master: Add empty constructor, rvalue reference constructor, and tests. (157595c)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Tue Dec 8 10:53:09 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/de31c15aea34256db01ba8d0816a9c0740041c1f...157595c68e3d26679e90152f07e1ee28e5e563c2
>---------------------------------------------------------------
commit 157595c68e3d26679e90152f07e1ee28e5e563c2
Author: Ryan Curtin <ryan at ratml.org>
Date: Tue Dec 8 10:52:40 2015 -0500
Add empty constructor, rvalue reference constructor, and tests.
>---------------------------------------------------------------
157595c68e3d26679e90152f07e1ee28e5e563c2
src/mlpack/methods/rann/ra_search.hpp | 122 ++++++++++++++++++-
src/mlpack/methods/rann/ra_search_impl.hpp | 188 ++++++++++++++++++++++++++++-
src/mlpack/tests/allkrann_search_test.cpp | 57 +++++++++
3 files changed, 362 insertions(+), 5 deletions(-)
diff --git a/src/mlpack/methods/rann/ra_search.hpp b/src/mlpack/methods/rann/ra_search.hpp
index 40a4c0f..9a4084b 100644
--- a/src/mlpack/methods/rann/ra_search.hpp
+++ b/src/mlpack/methods/rann/ra_search.hpp
@@ -74,7 +74,8 @@ class RASearch
*
* This method will copy the matrices to internal copies, which are rearranged
* during tree-building. You can avoid this extra copy by pre-constructing
- * the trees and passing them using a different constructor.
+ * the trees and using the appropriate constructor, or by using the
+ * constructor that takes an rvalue reference to the data with std::move().
*
* tau, the rank-approximation parameter, specifies that we are looking for k
* neighbors with probability alpha of being in the top tau percent of nearest
@@ -120,6 +121,60 @@ class RASearch
const MetricType metric = MetricType());
/**
+ * Initialize the RASearch object, passing a reference dataset (this is the
+ * dataset that will be searched). Optionally, perform the computation in
+ * naive mode or single-tree mode. An initialized distance metric can be
+ * given, for cases where the metric has internal data (i.e. the
+ * distance::MahalanobisDistance class).
+ *
+ * This method will take ownership of the given reference set, avoiding a
+ * copy. If you need to use the reference set for other purposes, too,
+ * consider using the constructor that takes a const reference.
+ *
+ * tau, the rank-approximation parameter, specifies that we are looking for k
+ * neighbors with probability alpha of being in the top tau percent of nearest
+ * neighbors. So, as an example, if our dataset has 1000 points, and we want
+ * 5 nearest neighbors with 95% probability of being in the top 5% of nearest
+ * neighbors (or, the top 50 nearest neighbors), we set k = 5, tau = 5, and
+ * alpha = 0.95.
+ *
+ * The method will fail (and throw a std::invalid_argument exception) if the
+ * value of tau is too low: tau must be set such that the number of points in
+ * the corresponding percentile of the data is greater than k. Thus, if we
+ * choose tau = 0.1 with a dataset of 1000 points and k = 5, then we are
+ * attempting to choose 5 nearest neighbors out of the closest 1 point -- this
+ * is invalid.
+ *
+ * @param referenceSet Set of reference points.
+ * @param naive If true, the rank-approximate search will be performed by
+ * directly sampling the whole set instead of using the stratified
+ * sampling on the tree.
+ * @param singleMode If true, single-tree search will be used (as opposed to
+ * dual-tree search). This is useful when Search() will be called with
+ * few query points.
+ * @param metric An optional instance of the MetricType class.
+ * @param tau The rank-approximation in percentile of the data. The default
+ * value is 5%.
+ * @param alpha The desired success probability. The default value is 0.95.
+ * @param sampleAtLeaves Sample at leaves for faster but less accurate
+ * computation. This defaults to 'false'.
+ * @param firstLeafExact Traverse to the first leaf without approximation.
+ * This can ensure that the query definitely finds its (near) duplicate
+ * if there exists one. This defaults to 'false' for now.
+ * @param singleSampleLimit The limit on the largest node that can be
+ * approximated by sampling. This defaults to 20.
+ */
+ RASearch(MatType&& referenceSet,
+ const bool naive = false,
+ const bool singleMode = false,
+ const double tau = 5,
+ const double alpha = 0.95,
+ const bool sampleAtLeaves = false,
+ const bool firstLeafExact = false,
+ const size_t singleSampleLimit = 20,
+ const MetricType metric = MetricType());
+
+ /**
* Initialize the RASearch object with the given pre-constructed reference
* tree. It is assumed that the points in the tree's dataset correspond to
* the reference set. Optionally, choose to use single-tree mode. Naive mode
@@ -155,7 +210,6 @@ class RASearch
* @param referenceTree Pre-built tree for reference points.
* @param singleMode Whether single-tree computation should be used (as
* opposed to dual-tree computation).
- * @param metric Instantiated distance metric.
* @param tau The rank-approximation in percentile of the data. The default
* value is 5%.
* @param alpha The desired success probability. The default value is 0.95.
@@ -166,6 +220,7 @@ class RASearch
* if there exists one. This defaults to 'false' for now.
* @param singleSampleLimit The limit on the largest node that can be
* approximated by sampling. This defaults to 20.
+ * @param metric Instantiated distance metric.
*/
RASearch(Tree* referenceTree,
const bool singleMode = false,
@@ -177,12 +232,62 @@ class RASearch
const MetricType metric = MetricType());
/**
+ * Create an RASearch object with no reference data. If Search() is called
+ * before a reference set is set with Train(), an exception will be thrown.
+ *
+ * @param naive Whether naive (brute-force) search should be used.
+ * @param singleMode Whether single-tree computation should be used (as
+ * opposed to dual-tree computation).
+ * @param tau The rank-approximation in percentile of the data. The default
+ * value is 5%.
+ * @param alpha The desired success probability. The default value is 0.95.
+ * @param sampleAtLeaves Sample at leaves for faster but less accurate
+ * computation. This defaults to 'false'.
+ * @param firstLeafExact Traverse to the first leaf without approximation.
+ * This can ensure that the query definitely finds its (near) duplicate
+ * if there exists one. This defaults to 'false' for now.
+ * @param singleSampleLimit The limit on the largest node that can be
+ * approximated by sampling. This defaults to 20.
+ * @param metric Instantiated distance metric.
+ */
+ RASearch(const bool naive = false,
+ const bool singleMode = false,
+ const double tau = 5,
+ const double alpha = 0.95,
+ const bool sampleAtLeaves = false,
+ const bool firstLeafExact = false,
+ const size_t singleSampleLimit = 20,
+ const MetricType metric = MetricType());
+
+ /**
* Delete the RASearch object. The tree and the dataset are the only members
* we may be responsible for deleting. The others take care of themselves.
*/
~RASearch();
/**
+ * "Train" the model on the given reference set. If tree-based search is
+ * being used (if Naive() is false), this means rebuilding the reference tree
+ * on a copy of the given data; to avoid that copy, use the Train() method
+ * that takes an rvalue reference with std::move(). In naive mode, only a
+ * pointer to the given matrix is kept, so it must outlive this object.
+ *
+ * @param referenceSet New reference set to use.
+ */
+ void Train(const MatType& referenceSet);
+
+ /**
+ * "Train" the model on the given reference set, taking ownership of the data
+ * matrix. If tree-based search is being used (if Naive() is false), this
+ * also means rebuilding the reference tree. If you need to keep a copy of
+ * the reference data, use the Train() method that takes a const reference to
+ * the data.
+ *
+ * @param referenceSet New reference set to use.
+ */
+ void Train(MatType&& referenceSet);
+
+ /**
* Compute the rank approximate nearest neighbors of each query point in the
* query set and store the output in the given matrices. The matrices will be
* set to the size of n columns by k rows, where n is the number of points in
@@ -262,6 +367,19 @@ class RASearch
*/
void ResetQueryTree(Tree* queryTree) const;
+ //! Access the reference set.
+ const MatType& ReferenceSet() const { return *referenceSet; }
+
+ //! Get whether or not naive (brute-force) search is used.
+ bool Naive() const { return naive; }
+ //! Modify whether or not naive (brute-force) search is used.
+ bool& Naive() { return naive; }
+
+ //! Get whether or not single-tree search is used.
+ bool SingleMode() const { return singleMode; }
+ //! Modify whether or not single-tree search is used.
+ bool& SingleMode() { return singleMode; }
+
//! Get the rank-approximation in percentile of the data.
double Tau() const { return tau; }
//! Modify the rank-approximation in percentile of the data.
diff --git a/src/mlpack/methods/rann/ra_search_impl.hpp b/src/mlpack/methods/rann/ra_search_impl.hpp
index 0c60234..02e063f 100644
--- a/src/mlpack/methods/rann/ra_search_impl.hpp
+++ b/src/mlpack/methods/rann/ra_search_impl.hpp
@@ -20,7 +20,7 @@ namespace aux {
//! Call the tree constructor that does mapping.
template<typename TreeType>
TreeType* BuildTree(
- typename TreeType::Mat& dataset,
+ const typename TreeType::Mat& dataset,
std::vector<size_t>& oldFromNew,
typename boost::enable_if_c<
tree::TreeTraits<TreeType>::RearrangesDataset == true, TreeType*
@@ -41,6 +41,30 @@ TreeType* BuildTree(
return new TreeType(dataset);
}
+//! Call the tree constructor that does mapping.
+template<typename TreeType>
+TreeType* BuildTree(
+ typename TreeType::Mat&& dataset,
+ std::vector<size_t>& oldFromNew,
+ typename boost::enable_if_c<
+ tree::TreeTraits<TreeType>::RearrangesDataset == true, TreeType*
+ >::type = 0)
+{
+ return new TreeType(std::move(dataset), oldFromNew);
+}
+
+//! Call the tree constructor that does not do mapping.
+template<typename TreeType>
+TreeType* BuildTree(
+ typename TreeType::Mat&& dataset,
+ const std::vector<size_t>& /* oldFromNew */,
+ const typename boost::enable_if_c<
+ tree::TreeTraits<TreeType>::RearrangesDataset == false, TreeType*
+ >::type = 0)
+{
+ return new TreeType(std::move(dataset));
+}
+
} // namespace aux
// Construct the object.
@@ -77,6 +101,41 @@ RASearch(const MatType& referenceSetIn,
// Nothing to do.
}
+// Construct the object, taking ownership of the data matrix.
+template<typename SortPolicy,
+ typename MetricType,
+ typename MatType,
+ template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType>
+RASearch<SortPolicy, MetricType, MatType, TreeType>::
+RASearch(MatType&& referenceSetIn,
+ const bool naive,
+ const bool singleMode,
+ const double tau,
+ const double alpha,
+ const bool sampleAtLeaves,
+ const bool firstLeafExact,
+ const size_t singleSampleLimit,
+ const MetricType metric) :
+ referenceTree(naive ? NULL : aux::BuildTree<Tree>(
+ std::move(referenceSetIn), oldFromNewReferences)),
+ referenceSet(naive ? new MatType(std::move(referenceSetIn)) :
+ &referenceTree->Dataset()),
+ treeOwner(!naive),
+ setOwner(naive),
+ naive(naive),
+ singleMode(!naive && singleMode), // No single mode if naive.
+ tau(tau),
+ alpha(alpha),
+ sampleAtLeaves(sampleAtLeaves),
+ firstLeafExact(firstLeafExact),
+ singleSampleLimit(singleSampleLimit),
+ metric(metric)
+{
+ // Nothing to do.
+}
+
// Construct the object.
template<typename SortPolicy,
typename MetricType,
@@ -108,9 +167,46 @@ RASearch(Tree* referenceTree,
// Nothing else to initialize.
{ }
+// Empty constructor.
+template<typename SortPolicy,
+ typename MetricType,
+ typename MatType,
+ template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType>
+RASearch<SortPolicy, MetricType, MatType, TreeType>::
+RASearch(const bool naive,
+ const bool singleMode,
+ const double tau,
+ const double alpha,
+ const bool sampleAtLeaves,
+ const bool firstLeafExact,
+ const size_t singleSampleLimit,
+ const MetricType metric) :
+ referenceTree(NULL),
+ referenceSet(new MatType()),
+ treeOwner(false),
+ setOwner(true),
+ naive(naive),
+ singleMode(singleMode),
+ tau(tau),
+ alpha(alpha),
+ sampleAtLeaves(sampleAtLeaves),
+ firstLeafExact(firstLeafExact),
+ singleSampleLimit(singleSampleLimit),
+ metric(metric)
+{
+ // Build the tree on the empty dataset, if necessary.
+ if (!naive)
+ {
+ referenceTree = aux::BuildTree<Tree>(*referenceSet, oldFromNewReferences);
+ treeOwner = true;
+ }
+}
+
/**
- * The tree is the only member we may be responsible for deleting. The others
- * will take care of themselves.
+ * The tree and the dataset are the only members we may be responsible for
+ * deleting. The others will take care of themselves.
*/
template<typename SortPolicy,
typename MetricType,
@@ -127,6 +223,84 @@ RASearch<SortPolicy, MetricType, MatType, TreeType>::
delete referenceSet;
}
+// Train on a new reference set.
+template<typename SortPolicy,
+ typename MetricType,
+ typename MatType,
+ template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType>
+void RASearch<SortPolicy, MetricType, MatType, TreeType>::Train(
+ const MatType& referenceSet)
+{
+ // Clean up the old tree, if we built one.
+ if (treeOwner && referenceTree)
+ delete referenceTree;
+
+ // We may need to rebuild the tree.
+ if (!naive)
+ {
+ referenceTree = aux::BuildTree<Tree>(referenceSet, oldFromNewReferences);
+ treeOwner = true;
+ }
+ else
+ {
+ treeOwner = false;
+ }
+
+ // Delete the old reference set, if we owned it.
+ if (setOwner && this->referenceSet)
+ delete this->referenceSet;
+
+ if (!naive)
+ this->referenceSet = &referenceTree->Dataset();
+ else
+ this->referenceSet = &referenceSet;
+ setOwner = false; // We don't own the set in either case.
+}
+
+// Train on a new reference set.
+template<typename SortPolicy,
+ typename MetricType,
+ typename MatType,
+ template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType>
+void RASearch<SortPolicy, MetricType, MatType, TreeType>::Train(
+ MatType&& referenceSet)
+{
+ // Clean up the old tree, if we built one.
+ if (treeOwner && referenceTree)
+ delete referenceTree;
+
+ // We may need to rebuild the tree.
+ if (!naive)
+ {
+ referenceTree = aux::BuildTree<Tree>(std::move(referenceSet),
+ oldFromNewReferences);
+ treeOwner = true;
+ }
+ else
+ {
+ treeOwner = false;
+ }
+
+ // Delete the old reference set, if we owned it.
+ if (setOwner && this->referenceSet)
+ delete this->referenceSet;
+
+ if (!naive)
+ {
+ this->referenceSet = &referenceTree->Dataset();
+ setOwner = false;
+ }
+ else
+ {
+ this->referenceSet = new MatType(std::move(referenceSet));
+ setOwner = true;
+ }
+}
+
/**
* Computes the best neighbors and stores them in resultingNeighbors and
* distances.
@@ -143,6 +317,14 @@ Search(const MatType& querySet,
arma::Mat<size_t>& neighbors,
arma::mat& distances)
{
+ if (k > referenceSet->n_cols)
+ {
+ std::stringstream ss;
+ ss << "requested value of k (" << k << ") is greater than the number of "
+ << "points in the reference set (" << referenceSet->n_cols << ")";
+ throw std::invalid_argument(ss.str());
+ }
+
Timer::Start("computing_neighbors");
// This will hold mappings for query points, if necessary.
diff --git a/src/mlpack/tests/allkrann_search_test.cpp b/src/mlpack/tests/allkrann_search_test.cpp
index 852ce19..e116c4c 100644
--- a/src/mlpack/tests/allkrann_search_test.cpp
+++ b/src/mlpack/tests/allkrann_search_test.cpp
@@ -553,4 +553,61 @@ BOOST_AUTO_TEST_CASE(NeighborPtrDeleteTest)
BOOST_REQUIRE_EQUAL(distances.n_rows, 3);
}
+/**
+ * Test that the rvalue reference move constructor works.
+ */
+BOOST_AUTO_TEST_CASE(MoveConstructorTest)
+{
+ arma::mat dataset = arma::randu<arma::mat>(3, 200);
+ arma::mat copy(dataset);
+
+ AllkRANN moveknn(std::move(copy));
+ AllkRANN allknn(dataset);
+
+ BOOST_REQUIRE_EQUAL(copy.n_elem, 0);
+ BOOST_REQUIRE_EQUAL(moveknn.ReferenceSet().n_rows, 3);
+ BOOST_REQUIRE_EQUAL(moveknn.ReferenceSet().n_cols, 200);
+
+ arma::mat moveDistances, distances;
+ arma::Mat<size_t> moveNeighbors, neighbors;
+
+ moveknn.Search(1, moveNeighbors, moveDistances);
+ allknn.Search(1, neighbors, distances);
+
+ BOOST_REQUIRE_EQUAL(moveNeighbors.n_rows, neighbors.n_rows);
+ BOOST_REQUIRE_EQUAL(moveNeighbors.n_rows, neighbors.n_rows);
+ BOOST_REQUIRE_EQUAL(moveNeighbors.n_cols, neighbors.n_cols);
+ BOOST_REQUIRE_EQUAL(moveDistances.n_rows, distances.n_rows);
+ BOOST_REQUIRE_EQUAL(moveDistances.n_cols, distances.n_cols);
+}
+
+/**
+ * Test that the dataset can be retrained with the move Train() function.
+ */
+BOOST_AUTO_TEST_CASE(MoveTrainTest)
+{
+ arma::mat dataset = arma::randu<arma::mat>(3, 200);
+
+ // Do it in tree mode, and in naive mode.
+ AllkRANN knn;
+ knn.Train(std::move(dataset));
+
+ arma::mat distances;
+ arma::Mat<size_t> neighbors;
+ knn.Search(1, neighbors, distances);
+
+ BOOST_REQUIRE_EQUAL(dataset.n_elem, 0);
+ BOOST_REQUIRE_EQUAL(neighbors.n_cols, 200);
+ BOOST_REQUIRE_EQUAL(distances.n_cols, 200);
+
+ dataset = arma::randu<arma::mat>(3, 300);
+ knn.Naive() = true;
+ knn.Train(std::move(dataset));
+ knn.Search(1, neighbors, distances);
+
+ BOOST_REQUIRE_EQUAL(dataset.n_elem, 0);
+ BOOST_REQUIRE_EQUAL(neighbors.n_cols, 300);
+ BOOST_REQUIRE_EQUAL(distances.n_cols, 300);
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git
mailing list