[mlpack-git] master: Add empty constructor, rvalue reference constructor, and tests. (157595c)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Tue Dec 8 10:53:09 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/de31c15aea34256db01ba8d0816a9c0740041c1f...157595c68e3d26679e90152f07e1ee28e5e563c2

>---------------------------------------------------------------

commit 157595c68e3d26679e90152f07e1ee28e5e563c2
Author: Ryan Curtin <ryan at ratml.org>
Date:   Tue Dec 8 10:52:40 2015 -0500

    Add empty constructor, rvalue reference constructor, and tests.


>---------------------------------------------------------------

157595c68e3d26679e90152f07e1ee28e5e563c2
 src/mlpack/methods/rann/ra_search.hpp      | 122 ++++++++++++++++++-
 src/mlpack/methods/rann/ra_search_impl.hpp | 188 ++++++++++++++++++++++++++++-
 src/mlpack/tests/allkrann_search_test.cpp  |  57 +++++++++
 3 files changed, 362 insertions(+), 5 deletions(-)

diff --git a/src/mlpack/methods/rann/ra_search.hpp b/src/mlpack/methods/rann/ra_search.hpp
index 40a4c0f..9a4084b 100644
--- a/src/mlpack/methods/rann/ra_search.hpp
+++ b/src/mlpack/methods/rann/ra_search.hpp
@@ -74,7 +74,8 @@ class RASearch
    *
    * This method will copy the matrices to internal copies, which are rearranged
    * during tree-building.  You can avoid this extra copy by pre-constructing
-   * the trees and passing them using a different constructor.
+   * the trees and using the appropriate constructor, or by using the
+   * constructor that takes an rvalue reference to the data with std::move().
    *
    * tau, the rank-approximation parameter, specifies that we are looking for k
    * neighbors with probability alpha of being in the top tau percent of nearest
@@ -120,6 +121,60 @@ class RASearch
            const MetricType metric = MetricType());
 
   /**
+   * Initialize the RASearch object, passing a reference dataset (this is
+   * the dataset that will be searched).  Optionally, perform the computation in
+   * naive mode or single-tree mode.  An initialized distance metric can be
+   * given, for cases where the metric has internal data (e.g. the
+   * distance::MahalanobisDistance class).
+   *
+   * This method will take ownership of the given reference set, avoiding a
+   * copy.  If you need to use the reference set for other purposes, too,
+   * consider using the constructor that takes a const reference.
+   *
+   * tau, the rank-approximation parameter, specifies that we are looking for k
+   * neighbors with probability alpha of being in the top tau percent of nearest
+   * neighbors.  So, as an example, if our dataset has 1000 points, and we want
+   * 5 nearest neighbors with 95% probability of being in the top 5% of nearest
+   * neighbors (or, the top 50 nearest neighbors), we set k = 5, tau = 5, and
+   * alpha = 0.95.
+   *
+   * The method will fail (and throw a std::invalid_argument exception) if the
+   * value of tau is too low: tau must be set such that the number of points in
+   * the corresponding percentile of the data is greater than k.  Thus, if we
+   * choose tau = 0.1 with a dataset of 1000 points and k = 5, then we are
+   * attempting to choose 5 nearest neighbors out of the closest 1 point -- this
+   * is invalid.
+   *
+   * @param referenceSet Set of reference points.
+   * @param naive If true, the rank-approximate search will be performed by
+   *      directly sampling the whole set instead of using the stratified
+   *      sampling on the tree.
+   * @param singleMode If true, single-tree search will be used (as opposed to
+   *      dual-tree search).  This is useful when Search() will be called with
+   *      few query points.
+   * @param tau The rank-approximation in percentile of the data. The default
+   *     value is 5%.
+   * @param alpha The desired success probability. The default value is 0.95.
+   * @param sampleAtLeaves Sample at leaves for faster but less accurate
+   *      computation. This defaults to 'false'.
+   * @param firstLeafExact Traverse to the first leaf without approximation.
+   *     This can ensure that the query definitely finds its (near) duplicate
+   *     if there exists one.  This defaults to 'false' for now.
+   * @param singleSampleLimit The limit on the largest node that can be
+   *     approximated by sampling. This defaults to 20.
+   * @param metric An optional instance of the MetricType class.
+   */
+  RASearch(MatType&& referenceSet,
+           const bool naive = false,
+           const bool singleMode = false,
+           const double tau = 5,
+           const double alpha = 0.95,
+           const bool sampleAtLeaves = false,
+           const bool firstLeafExact = false,
+           const size_t singleSampleLimit = 20,
+           const MetricType metric = MetricType());
+
+  /**
    * Initialize the RASearch object with the given pre-constructed reference
    * tree.  It is assumed that the points in the tree's dataset correspond to
    * the reference set.  Optionally, choose to use single-tree mode.  Naive mode
@@ -155,7 +210,6 @@ class RASearch
    * @param referenceTree Pre-built tree for reference points.
    * @param singleMode Whether single-tree computation should be used (as
    *      opposed to dual-tree computation).
-   * @param metric Instantiated distance metric.
    * @param tau The rank-approximation in percentile of the data. The default
    *     value is 5%.
    * @param alpha The desired success probability. The default value is 0.95.
@@ -166,6 +220,7 @@ class RASearch
    *     if there exists one.  This defaults to 'false' for now.
    * @param singleSampleLimit The limit on the largest node that can be
    *     approximated by sampling. This defaults to 20.
+   * @param metric Instantiated distance metric.
    */
   RASearch(Tree* referenceTree,
            const bool singleMode = false,
@@ -177,12 +232,62 @@ class RASearch
            const MetricType metric = MetricType());
 
   /**
+   * Create an RASearch object with no reference data.  If Search() is called
+   * before Train() has supplied a reference set, an exception will be thrown.
+   *
+   * @param naive Whether naive (brute-force) search should be used.
+   * @param singleMode Whether single-tree computation should be used (as
+   *      opposed to dual-tree computation).
+   * @param tau The rank-approximation in percentile of the data. The default
+   *     value is 5%.
+   * @param alpha The desired success probability. The default value is 0.95.
+   * @param sampleAtLeaves Sample at leaves for faster but less accurate
+   *      computation. This defaults to 'false'.
+   * @param firstLeafExact Traverse to the first leaf without approximation.
+   *     This can ensure that the query definitely finds its (near) duplicate
+   *     if there exists one.  This defaults to 'false' for now.
+   * @param singleSampleLimit The limit on the largest node that can be
+   *     approximated by sampling. This defaults to 20.
+   * @param metric Instantiated distance metric.
+   */
+  RASearch(const bool naive = false,
+           const bool singleMode = false,
+           const double tau = 5,
+           const double alpha = 0.95,
+           const bool sampleAtLeaves = false,
+           const bool firstLeafExact = false,
+           const size_t singleSampleLimit = 20,
+           const MetricType metric = MetricType());
+
+  /**
    * Delete the RASearch object. The tree is the only member we are
    * responsible for deleting.  The others will take care of themselves.
    */
   ~RASearch();
 
   /**
+   * "Train" the model on the given reference set.  If tree-based search is
+   * being used (if Naive() is false), this means rebuilding the reference tree
+   * on a copy of the given data.  In naive mode, only a pointer to the data
+   * may be kept, so the data must outlive this object.  To avoid the copy in
+   * tree mode, use the Train() method that takes an rvalue reference.
+   *
+   * @param referenceSet New reference set to use.
+   */
+  void Train(const MatType& referenceSet);
+
+  /**
+   * "Train" the model on the given reference set, taking ownership of the data
+   * matrix.  If tree-based search is being used (if Naive() is false), this
+   * also means rebuilding the reference tree.  If you need to keep a copy of
+   * the reference data, use the Train() method that takes a const reference to
+   * the data.
+   *
+   * @param referenceSet New reference set to use.
+   */
+  void Train(MatType&& referenceSet);
+
+  /**
    * Compute the rank approximate nearest neighbors of each query point in the
    * query set and store the output in the given matrices. The matrices will be
    * set to the size of n columns by k rows, where n is the number of points in
@@ -262,6 +367,19 @@ class RASearch
    */
   void ResetQueryTree(Tree* queryTree) const;
 
+  //! Access the reference set.
+  const MatType& ReferenceSet() const { return *referenceSet; }
+
+  //! Get whether or not naive (brute-force) search is used.
+  bool Naive() const { return naive; }
+  //! Modify whether naive search is used (retrain with Train() afterwards).
+  bool& Naive() { return naive; }
+
+  //! Get whether or not single-tree search is used.
+  bool SingleMode() const { return singleMode; }
+  //! Modify whether or not single-tree search is used.
+  bool& SingleMode() { return singleMode; }
+
   //! Get the rank-approximation in percentile of the data.
   double Tau() const { return tau; }
   //! Modify the rank-approximation in percentile of the data.
diff --git a/src/mlpack/methods/rann/ra_search_impl.hpp b/src/mlpack/methods/rann/ra_search_impl.hpp
index 0c60234..02e063f 100644
--- a/src/mlpack/methods/rann/ra_search_impl.hpp
+++ b/src/mlpack/methods/rann/ra_search_impl.hpp
@@ -20,7 +20,7 @@ namespace aux {
 //! Call the tree constructor that does mapping.
 template<typename TreeType>
 TreeType* BuildTree(
-    typename TreeType::Mat& dataset,
+    const typename TreeType::Mat& dataset,
     std::vector<size_t>& oldFromNew,
     typename boost::enable_if_c<
         tree::TreeTraits<TreeType>::RearrangesDataset == true, TreeType*
@@ -41,6 +41,30 @@ TreeType* BuildTree(
   return new TreeType(dataset);
 }
 
+//! Call the tree constructor that does mapping (rvalue overload; moves data).
+template<typename TreeType>
+TreeType* BuildTree(
+    typename TreeType::Mat&& dataset,
+    std::vector<size_t>& oldFromNew,
+    typename boost::enable_if_c<
+        tree::TreeTraits<TreeType>::RearrangesDataset == true, TreeType*
+    >::type = 0)
+{
+  return new TreeType(std::move(dataset), oldFromNew);
+}
+
+//! Call the tree constructor that does not do mapping (rvalue overload).
+template<typename TreeType>
+TreeType* BuildTree(
+    typename TreeType::Mat&& dataset,
+    const std::vector<size_t>& /* oldFromNew */,
+    const typename boost::enable_if_c<
+        tree::TreeTraits<TreeType>::RearrangesDataset == false, TreeType*
+    >::type = 0)
+{
+  return new TreeType(std::move(dataset));
+}
+
 } // namespace aux
 
 // Construct the object.
@@ -77,6 +101,41 @@ RASearch(const MatType& referenceSetIn,
   // Nothing to do.
 }
 
+// Construct the object, moving the data into the tree (naive: a new matrix).
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType>
+RASearch<SortPolicy, MetricType, MatType, TreeType>::
+RASearch(MatType&& referenceSetIn,
+         const bool naive,
+         const bool singleMode,
+         const double tau,
+         const double alpha,
+         const bool sampleAtLeaves,
+         const bool firstLeafExact,
+         const size_t singleSampleLimit,
+         const MetricType metric) :
+    referenceTree(naive ? NULL : aux::BuildTree<Tree>(
+        std::move(referenceSetIn), oldFromNewReferences)),
+    referenceSet(naive ? new MatType(std::move(referenceSetIn)) :
+        &referenceTree->Dataset()),
+    treeOwner(!naive), // We built the tree (if any), so we free it.
+    setOwner(naive), // Only the naive-mode matrix is allocated here.
+    naive(naive),
+    singleMode(!naive && singleMode), // No single mode if naive.
+    tau(tau),
+    alpha(alpha),
+    sampleAtLeaves(sampleAtLeaves),
+    firstLeafExact(firstLeafExact),
+    singleSampleLimit(singleSampleLimit),
+    metric(metric)
+{
+  // Nothing to do.  (Assumes referenceTree is declared before referenceSet.)
+}
+
 // Construct the object.
 template<typename SortPolicy,
          typename MetricType,
@@ -108,9 +167,46 @@ RASearch(Tree* referenceTree,
 // Nothing else to initialize.
 {  }
 
+// Empty constructor: the reference set is empty until Train() is called.
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType>
+RASearch<SortPolicy, MetricType, MatType, TreeType>::
+RASearch(const bool naive,
+         const bool singleMode,
+         const double tau,
+         const double alpha,
+         const bool sampleAtLeaves,
+         const bool firstLeafExact,
+         const size_t singleSampleLimit,
+         const MetricType metric) :
+    referenceTree(NULL),
+    referenceSet(new MatType()), // Empty set, owned by this object.
+    treeOwner(false),
+    setOwner(true),
+    naive(naive),
+    singleMode(!naive && singleMode), // No single mode if naive.
+    tau(tau),
+    alpha(alpha),
+    sampleAtLeaves(sampleAtLeaves),
+    firstLeafExact(firstLeafExact),
+    singleSampleLimit(singleSampleLimit),
+    metric(metric)
+{
+  // Build the tree on the empty dataset, if necessary.
+  if (!naive)
+  {
+    referenceTree = aux::BuildTree<Tree>(*referenceSet, oldFromNewReferences);
+    treeOwner = true;
+  }
+}
+
 /**
- * The tree is the only member we may be responsible for deleting.  The others
- * will take care of themselves.
+ * The tree and the dataset are the only members we may be responsible for
+ * deleting.  The others will take care of themselves.
  */
 template<typename SortPolicy,
          typename MetricType,
@@ -127,6 +223,84 @@ RASearch<SortPolicy, MetricType, MatType, TreeType>::
     delete referenceSet;
 }
 
+// Train on a new reference set, given as a const reference.
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType>
+void RASearch<SortPolicy, MetricType, MatType, TreeType>::Train(
+    const MatType& referenceSet)
+{
+  // Keep the old tree and set (if we own them) alive until the new state is
+  // built, since referenceSet may alias them (e.g. Train(ReferenceSet())).
+  Tree* oldTree = treeOwner ? referenceTree : NULL;
+  const MatType* oldSet = setOwner ? this->referenceSet : NULL;
+
+  if (!naive)
+  {
+    // The tree is built on its own copy of the data.
+    referenceTree = aux::BuildTree<Tree>(referenceSet, oldFromNewReferences);
+    this->referenceSet = &referenceTree->Dataset();
+    setOwner = false;
+  }
+  else
+  {
+    // No tree in naive mode.  If given our own current set (which may be
+    // freed below), copy it; otherwise just point at the caller's matrix.
+    referenceTree = NULL;
+    const bool aliased = (&referenceSet == this->referenceSet);
+    this->referenceSet = aliased ? new MatType(referenceSet) : &referenceSet;
+    setOwner = aliased;
+  }
+  treeOwner = !naive;
+  delete oldTree;
+  delete oldSet;
+}
+
+// Train on a new reference set, taking ownership of it (rvalue reference).
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType>
+void RASearch<SortPolicy, MetricType, MatType, TreeType>::Train(
+    MatType&& referenceSet)
+{
+  // Clean up the old tree, if we built one.
+  if (treeOwner && referenceTree)
+    delete referenceTree;
+
+  // We may need to rebuild the tree.  In naive mode there is no tree, so
+  // reset the pointer instead of leaving it dangling.
+  if (!naive)
+  {
+    referenceTree = aux::BuildTree<Tree>(std::move(referenceSet),
+        oldFromNewReferences);
+    treeOwner = true;
+  }
+  else
+  {
+    referenceTree = NULL;
+    treeOwner = false;
+  }
+
+  // Delete the old reference set, if we owned it.
+  if (setOwner && this->referenceSet)
+    delete this->referenceSet;
+
+  // Point at the tree's dataset, or take ownership of the moved-in matrix.
+  if (!naive)
+    this->referenceSet = &referenceTree->Dataset();
+  else
+    this->referenceSet = new MatType(std::move(referenceSet));
+
+  // We own the reference set only in naive mode.
+  setOwner = naive;
+}
+
 /**
  * Computes the best neighbors and stores them in resultingNeighbors and
  * distances.
@@ -143,6 +317,14 @@ Search(const MatType& querySet,
        arma::Mat<size_t>& neighbors,
        arma::mat& distances)
 {
+  if (k > referenceSet->n_cols)
+  {
+    std::stringstream ss;
+    ss << "requested value of k (" << k << ") is greater than the number of "
+        << "points in the reference set (" << referenceSet->n_cols << ")";
+    throw std::invalid_argument(ss.str());
+  }
+
   Timer::Start("computing_neighbors");
 
   // This will hold mappings for query points, if necessary.
diff --git a/src/mlpack/tests/allkrann_search_test.cpp b/src/mlpack/tests/allkrann_search_test.cpp
index 852ce19..e116c4c 100644
--- a/src/mlpack/tests/allkrann_search_test.cpp
+++ b/src/mlpack/tests/allkrann_search_test.cpp
@@ -553,4 +553,61 @@ BOOST_AUTO_TEST_CASE(NeighborPtrDeleteTest)
   BOOST_REQUIRE_EQUAL(distances.n_rows, 3);
 }
 
+/**
+ * Test that the rvalue-reference constructor takes the data and works.
+ */
+BOOST_AUTO_TEST_CASE(MoveConstructorTest)
+{
+  arma::mat dataset = arma::randu<arma::mat>(3, 200);
+  arma::mat copy(dataset);
+
+  AllkRANN moveknn(std::move(copy));
+  AllkRANN allknn(dataset);
+
+  BOOST_REQUIRE_EQUAL(copy.n_elem, 0); // The data was moved, not copied.
+  BOOST_REQUIRE_EQUAL(moveknn.ReferenceSet().n_rows, 3);
+  BOOST_REQUIRE_EQUAL(moveknn.ReferenceSet().n_cols, 200);
+
+  arma::mat moveDistances, distances;
+  arma::Mat<size_t> moveNeighbors, neighbors;
+
+  moveknn.Search(1, moveNeighbors, moveDistances);
+  allknn.Search(1, neighbors, distances);
+
+  BOOST_REQUIRE_EQUAL(moveNeighbors.n_rows, neighbors.n_rows);
+  BOOST_REQUIRE_EQUAL(moveNeighbors.n_rows, 1); // k = 1 neighbor per point.
+  BOOST_REQUIRE_EQUAL(moveNeighbors.n_cols, neighbors.n_cols);
+  BOOST_REQUIRE_EQUAL(moveDistances.n_rows, distances.n_rows);
+  BOOST_REQUIRE_EQUAL(moveDistances.n_cols, distances.n_cols);
+}
+
+/**
+ * Test that the rvalue-reference Train() works in both tree and naive mode.
+ */
+BOOST_AUTO_TEST_CASE(MoveTrainTest)
+{
+  arma::mat dataset = arma::randu<arma::mat>(3, 200);
+
+  // Train first in tree mode (the default), then in naive mode below.
+  AllkRANN knn;
+  knn.Train(std::move(dataset));
+
+  arma::mat distances;
+  arma::Mat<size_t> neighbors;
+  knn.Search(1, neighbors, distances);
+
+  BOOST_REQUIRE_EQUAL(dataset.n_elem, 0); // The data was moved, not copied.
+  BOOST_REQUIRE_EQUAL(neighbors.n_cols, 200);
+  BOOST_REQUIRE_EQUAL(distances.n_cols, 200);
+  // Now retrain on a fresh 300-point dataset in naive mode.
+  dataset = arma::randu<arma::mat>(3, 300);
+  knn.Naive() = true;
+  knn.Train(std::move(dataset));
+  knn.Search(1, neighbors, distances);
+
+  BOOST_REQUIRE_EQUAL(dataset.n_elem, 0);
+  BOOST_REQUIRE_EQUAL(neighbors.n_cols, 300);
+  BOOST_REQUIRE_EQUAL(distances.n_cols, 300);
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list