[mlpack-git] master: Add move constructors and Train() functions. (06b2e58)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Mon Nov 2 12:19:22 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/f86acf8be2c01568d8b3dcd2e529ee9f20f7585e...156787dd4f372a7fd740f733127ac200ea2564b7

>---------------------------------------------------------------

commit 06b2e584866e060bc6e478306c7a963c6688811f
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Nov 2 17:17:01 2015 +0000

    Add move constructors and Train() functions.


>---------------------------------------------------------------

06b2e584866e060bc6e478306c7a963c6688811f
 src/mlpack/methods/range_search/range_search.hpp   |  76 +++++++++
 .../methods/range_search/range_search_impl.hpp     | 182 +++++++++++++++++++++
 src/mlpack/tests/range_search_test.cpp             |  23 +++
 3 files changed, 281 insertions(+)

diff --git a/src/mlpack/methods/range_search/range_search.hpp b/src/mlpack/methods/range_search/range_search.hpp
index de95d79..ffb3d86 100644
--- a/src/mlpack/methods/range_search/range_search.hpp
+++ b/src/mlpack/methods/range_search/range_search.hpp
@@ -59,6 +59,30 @@ class RangeSearch
               const MetricType metric = MetricType());
 
   /**
+   * Initialize the RangeSearch object with the given reference dataset (this is
+   * the dataset which is searched), taking ownership of the matrix.
+   * Optionally, perform the computation in naive mode or single-tree mode.
+   * Additionally, an instantiated metric can be given, for cases where the
+   * distance metric holds data.
+   *
+   * This constructor will not copy the data matrix, but will take ownership of
+   * it, and depending on the type of tree used, may rearrange the points.  If
+   * you would rather a copy be made, consider using the constructor that takes
+   * a const reference to the data instead; referenceSet is left moved-from.
+   *
+   * @param referenceSet Set of reference points.
+   * @param naive If true, brute force naive search will be used (as opposed to
+   *      dual-tree search).  This overrides singleMode (if it is set to true).
+   * @param singleMode If true, single-tree search will be used (as opposed to
+   *      dual-tree search).
+   * @param metric An optional instance of the MetricType class.
+   */
+  RangeSearch(MatType&& referenceSet,
+              const bool naive = false,
+              const bool singleMode = false,
+              const MetricType metric = MetricType());
+
+  /**
    * Initialize the RangeSearch object with the given pre-constructed reference
    * tree (this is the tree built on the reference set, which is the set that is
    * searched).  Optionally, choose to use single-tree mode, which will not
@@ -87,12 +111,51 @@ class RangeSearch
               const MetricType metric = MetricType());
 
   /**
+   * Initialize the RangeSearch object without any reference data.  Until
+   * Train() is called, the monochromatic Search() returns no results, and
+   * Search() with a query set of nonzero dimensionality throws.
+   *
+   * @param naive Whether to use naive search.
+   * @param singleMode Whether single-tree computation should be used (as
+   *      opposed to dual-tree computation).
+   * @param metric Instantiated metric.
+   */
+  RangeSearch(const bool naive = false,
+              const bool singleMode = false,
+              const MetricType metric = MetricType());
+
+  /**
    * Destroy the RangeSearch object.  If trees were created, they will be
    * deleted.
    */
   ~RangeSearch();
 
   /**
+   * Set the reference set to a new reference set, and build a tree if
+   * necessary.  In naive mode the set is not copied, so it must stay valid.
+   * This method is called 'Train()' to match the rest of the mlpack
+   * abstractions, even though calling this "training" is a bit of a stretch.
+   *
+   * @param referenceSet New set of reference data.
+   */
+  void Train(const MatType& referenceSet);
+
+  /**
+   * Set the reference set to a new reference set, taking ownership of the set.
+   * A tree is built if necessary.  This method is called 'Train()' in order to
+   * match the rest of the mlpack abstractions, even though calling this
+   * "training" is maybe a bit of a stretch.
+   *
+   * @param referenceSet New set of reference data.
+   */
+  void Train(MatType&& referenceSet);
+
+  /**
+   * Set the reference tree (not owned); throws std::invalid_argument if naive.
+   */
+  void Train(Tree* referenceTree);
+
+  /**
    * Search for all reference points in the given range for each point in the
    * query set, returning the results in the neighbors and distances objects.
    * Each entry in the external vector corresponds to a query point.  Each of
@@ -199,6 +262,16 @@ class RangeSearch
               std::vector<std::vector<size_t>>& neighbors,
               std::vector<std::vector<double>>& distances);
 
+  //! Return true if single-tree (rather than dual-tree) search is used.
+  bool SingleMode() const { return this->singleMode; }
+  //! Mutable reference to the single-tree search flag.
+  bool& SingleMode() { return this->singleMode; }
+
+  //! Return true if naive (brute-force) search is used.
+  bool Naive() const { return this->naive; }
+  //! Mutable reference to the naive search flag.
+  bool& Naive() { return this->naive; }
+
   //! Get the number of base cases during the last search.
   size_t BaseCases() const { return baseCases; }
   //! Get the number of scores during the last search.
@@ -211,6 +284,9 @@ class RangeSearch
   //! Returns a string representation of this object.
   std::string ToString() const;
 
+  //! Get the reference set (the dataset that is searched).
+  const MatType& ReferenceSet() const { return *this->referenceSet; }
+
   //! Return the reference tree (or NULL if in naive mode).
   Tree* ReferenceTree() { return referenceTree; }
 
diff --git a/src/mlpack/methods/range_search/range_search_impl.hpp b/src/mlpack/methods/range_search/range_search_impl.hpp
index a813dca..4eef20f 100644
--- a/src/mlpack/methods/range_search/range_search_impl.hpp
+++ b/src/mlpack/methods/range_search/range_search_impl.hpp
@@ -39,6 +39,28 @@ TreeType* BuildTree(
   return new TreeType(dataset);
 }
 
+template<typename TreeType>
+TreeType* BuildTree(
+    typename TreeType::Mat&& data,
+    std::vector<size_t>& oldFromNew,
+    const typename boost::enable_if_c<
+        tree::TreeTraits<TreeType>::RearrangesDataset, TreeType*
+    >::type = 0)
+{
+  return new TreeType(std::move(data), oldFromNew);
+}
+// Non-rearranging trees need no point mapping, so oldFromNew is ignored here.
+template<typename TreeType>
+TreeType* BuildTree(
+    typename TreeType::Mat&& data,
+    const std::vector<size_t>& /* oldFromNew */,
+    const typename boost::enable_if_c<
+        !tree::TreeTraits<TreeType>::RearrangesDataset, TreeType*
+    >::type = 0)
+{
+  return new TreeType(std::move(data));
+}
+
 template<typename MetricType,
          typename MatType,
          template<typename TreeMetricType,
@@ -63,6 +85,32 @@ RangeSearch<MetricType, MatType, TreeType>::RangeSearch(
   // Nothing to do.
 }
 
+// Constructor taking ownership of (moving from) the given reference set.
+template<typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType>
+RangeSearch<MetricType, MatType, TreeType>::RangeSearch(
+    MatType&& referenceSet,
+    const bool naive,
+    const bool singleMode,
+    const MetricType metric) :
+    referenceTree(naive ? NULL : BuildTree<Tree>(std::move(referenceSet),
+        oldFromNewReferences)),
+    referenceSet(naive ? new MatType(std::move(referenceSet)) :
+        &referenceTree->Dataset()),
+    treeOwner(!naive),
+    setOwner(naive),
+    naive(naive),
+    singleMode(!naive && singleMode),
+    metric(metric),
+    baseCases(0),
+    scores(0)
+{
+  // Nothing to do; NOTE: initializers above rely on member declaration order.
+}
+
 template<typename MetricType,
          typename MatType,
          template<typename TreeMetricType,
@@ -90,6 +138,34 @@ template<typename MetricType,
          template<typename TreeMetricType,
                   typename TreeStatType,
                   typename TreeMatType> class TreeType>
+RangeSearch<MetricType, MatType, TreeType>::RangeSearch(
+    const bool naive,
+    const bool singleMode,
+    const MetricType metric) :
+    referenceTree(NULL),
+    referenceSet(new MatType()), // Empty matrix.
+    treeOwner(false),
+    setOwner(true),
+    naive(naive),
+    singleMode(!naive && singleMode), // Naive search overrides singleMode.
+    metric(metric),
+    baseCases(0),
+    scores(0)
+{
+  // Build the tree on the empty dataset, if necessary.
+  if (!naive)
+  {
+    referenceTree = BuildTree<Tree>(const_cast<MatType&>(*referenceSet),
+        oldFromNewReferences);
+    treeOwner = true;
+  }
+}
+
+template<typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType>
 RangeSearch<MetricType, MatType, TreeType>::~RangeSearch()
 {
   if (treeOwner && referenceTree)
@@ -103,12 +179,118 @@ template<typename MetricType,
          template<typename TreeMetricType,
                   typename TreeStatType,
                   typename TreeMatType> class TreeType>
+void RangeSearch<MetricType, MatType, TreeType>::Train(
+    const MatType& referenceSet)
+{
+  // Clean up the old tree, if we built one.
+  if (treeOwner && referenceTree)
+    delete referenceTree;
+
+  // Rebuild the tree, if necessary.
+  if (!naive)
+  {
+    referenceTree = BuildTree<Tree>(referenceSet, oldFromNewReferences);
+    treeOwner = true;
+  }
+  else
+  {
+    referenceTree = NULL; // No tree in naive mode; don't leave it dangling.
+    treeOwner = false;
+  }
+  // Delete the old reference set, if we owned it.
+  if (setOwner && this->referenceSet)
+    delete this->referenceSet;
+  // Naive mode searches the caller's matrix directly, without copying it.
+  if (!naive)
+    this->referenceSet = &referenceTree->Dataset();
+  else
+    this->referenceSet = &referenceSet;
+  setOwner = false;
+}
+
+template<typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType>
+void RangeSearch<MetricType, MatType, TreeType>::Train(
+    MatType&& referenceSet)
+{
+  // Clean up the old tree, if we built one.
+  if (treeOwner && referenceTree)
+    delete referenceTree;
+
+  // We may need to rebuild the tree.
+  if (!naive)
+  {
+    referenceTree = BuildTree<Tree>(std::move(referenceSet),
+        oldFromNewReferences);
+    treeOwner = true;
+  }
+  else
+  {
+    referenceTree = NULL; // No tree in naive mode; don't leave it dangling.
+    treeOwner = false;
+  }
+  // Delete the old reference set, if we owned it.
+  if (setOwner && this->referenceSet)
+    delete this->referenceSet;
+  // Search the tree's dataset, or, in naive mode, take ownership of the data.
+  if (!naive)
+  {
+    this->referenceSet = &referenceTree->Dataset();
+    setOwner = false;
+  }
+  else
+  {
+    this->referenceSet = new MatType(std::move(referenceSet));
+    setOwner = true;
+  }
+}
+
+template<typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType>
+void RangeSearch<MetricType, MatType, TreeType>::Train(
+  Tree* referenceTree)
+{
+  if (naive)
+    throw std::invalid_argument("cannot train on given reference tree when "
+        "naive search (without trees) is desired");
+  // The parameter shadows the member, so the members need 'this->' here.
+  if (treeOwner && this->referenceTree)
+    delete this->referenceTree;
+  if (setOwner && this->referenceSet)
+    delete this->referenceSet;
+
+  this->referenceTree = referenceTree;
+  this->referenceSet = &referenceTree->Dataset();
+  treeOwner = false;
+  setOwner = false;
+}
+
+template<typename MetricType,
+         typename MatType,
+         template<typename TreeMetricType,
+                  typename TreeStatType,
+                  typename TreeMatType> class TreeType>
 void RangeSearch<MetricType, MatType, TreeType>::Search(
     const MatType& querySet,
     const math::Range& range,
     std::vector<std::vector<size_t>>& neighbors,
     std::vector<std::vector<double>>& distances)
 {
+  if (querySet.n_rows != referenceSet->n_rows)
+  {
+    std::ostringstream oss;
+    oss << "RangeSearch::Search(): dimensionalities of query set ("
+        << querySet.n_rows << ") and reference set (" << referenceSet->n_rows
+        << ") do not match!";
+    throw std::invalid_argument(oss.str());
+  }
+
   Timer::Start("range_search/computing_neighbors");
 
   // This will hold mappings for query points, if necessary.
diff --git a/src/mlpack/tests/range_search_test.cpp b/src/mlpack/tests/range_search_test.cpp
index 3110a9f..bc34d59 100644
--- a/src/mlpack/tests/range_search_test.cpp
+++ b/src/mlpack/tests/range_search_test.cpp
@@ -1043,4 +1043,27 @@ BOOST_AUTO_TEST_CASE(DualBallTreeTest2)
   }
 }
 
+/**
+ * Make sure that a range search object built with no reference set returns no
+ * results, and that searching it with a 3-dimensional query set throws.
+ */
+BOOST_AUTO_TEST_CASE(EmptySearchTest)
+{
+  RangeSearch<EuclideanDistance, arma::mat, KDTree> rs;
+
+  vector<vector<size_t>> neighbors;
+  vector<vector<double>> distances;
+
+  rs.Search(math::Range(0.0, 10.0), neighbors, distances);
+  // Use 0u: comparing size_t to a signed 0 can trigger -Wsign-compare.
+  BOOST_REQUIRE_EQUAL(neighbors.size(), 0u);
+  BOOST_REQUIRE_EQUAL(distances.size(), 0u);
+
+  // Now check with a query set.
+  arma::mat querySet = arma::randu<arma::mat>(3, 100);
+
+  BOOST_REQUIRE_THROW(rs.Search(querySet, math::Range(0.0, 10.0), neighbors,
+      distances), std::invalid_argument);
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list