[mlpack-git] master,mlpack-1.0.x: Patch from Saheb for #301; refactor RangeSearch constructors so that leafSize is not a parameter. (0081a00)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:46:06 EST 2015


Repository : https://github.com/mlpack/mlpack

On branches: master,mlpack-1.0.x
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit 0081a0056af89db7c6145c322b572e6e35f11d4c
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Apr 9 22:27:08 2014 +0000

    Patch from Saheb for #301; refactor RangeSearch constructors so that leafSize is
    not a parameter.


>---------------------------------------------------------------

0081a0056af89db7c6145c322b572e6e35f11d4c
 src/mlpack/methods/range_search/range_search.hpp   |  5 --
 .../methods/range_search/range_search_impl.hpp     | 95 ++++++++++++++++------
 2 files changed, 70 insertions(+), 30 deletions(-)

diff --git a/src/mlpack/methods/range_search/range_search.hpp b/src/mlpack/methods/range_search/range_search.hpp
index 1f9c5d3..c6025c6 100644
--- a/src/mlpack/methods/range_search/range_search.hpp
+++ b/src/mlpack/methods/range_search/range_search.hpp
@@ -9,11 +9,8 @@
 #define __MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_HPP
 
 #include <mlpack/core.hpp>
-
 #include <mlpack/core/metrics/lmetric.hpp>
-
 #include <mlpack/core/tree/binary_space_tree.hpp>
-
 #include "range_search_stat.hpp"
 
 namespace mlpack {
@@ -54,7 +51,6 @@ class RangeSearch
               const typename TreeType::Mat& querySet,
               const bool naive = false,
               const bool singleMode = false,
-              const size_t leafSize = 20,
               const MetricType metric = MetricType());
 
   /**
@@ -79,7 +75,6 @@ class RangeSearch
   RangeSearch(const typename TreeType::Mat& referenceSet,
               const bool naive = false,
               const bool singleMode = false,
-              const size_t leafSize = 20,
               const MetricType metric = MetricType());
 
   /**
diff --git a/src/mlpack/methods/range_search/range_search_impl.hpp b/src/mlpack/methods/range_search/range_search_impl.hpp
index 061d398..b688cdf 100644
--- a/src/mlpack/methods/range_search/range_search_impl.hpp
+++ b/src/mlpack/methods/range_search/range_search_impl.hpp
@@ -16,18 +16,40 @@
 namespace mlpack {
 namespace range {
 
+template<typename TreeType>
+TreeType* BuildTree(
+    typename TreeType::Mat& dataset,
+    std::vector<size_t>& oldFromNew,
+    typename boost::enable_if_c<
+        tree::TreeTraits<TreeType>::RearrangesDataset == true, TreeType*
+    >::type = 0)
+{
+  return new TreeType(dataset, oldFromNew);
+}
+
+//! Call the tree constructor that does not do mapping.
+template<typename TreeType>
+TreeType* BuildTree(
+    const typename TreeType::Mat& dataset,
+    const std::vector<size_t>& /* oldFromNew */,
+    const typename boost::enable_if_c<
+        tree::TreeTraits<TreeType>::RearrangesDataset == false, TreeType*
+    >::type = 0)
+{
+  return new TreeType(dataset);
+}
+
 template<typename MetricType, typename TreeType>
 RangeSearch<MetricType, TreeType>::RangeSearch(
-    const typename TreeType::Mat& referenceSet,
-    const typename TreeType::Mat& querySet,
+    const typename TreeType::Mat& referenceSetIn,
+    const typename TreeType::Mat& querySetIn,
     const bool naive,
     const bool singleMode,
-    const size_t leafSize,
     const MetricType metric) :
-    referenceCopy(referenceSet),
-    queryCopy(querySet),
-    referenceSet(referenceCopy),
-    querySet(queryCopy),
+    referenceSet(tree::TreeTraits<TreeType>::RearrangesDataset ? referenceCopy
+        : referenceSetIn),
+    querySet(tree::TreeTraits<TreeType>::RearrangesDataset ? queryCopy
+        : querySetIn),
     treeOwner(!naive), // If in naive mode, we are not building any trees.
     hasQuerySet(true),
     naive(naive),
@@ -38,16 +60,26 @@ RangeSearch<MetricType, TreeType>::RangeSearch(
   // Build the trees.
   Timer::Start("range_search/tree_building");
 
+  // Copy the datasets, if they will be modified during tree building.
+  if (tree::TreeTraits<TreeType>::RearrangesDataset)
+  {
+    referenceCopy = referenceSetIn;
+    queryCopy = querySetIn;
+  }
+
   // If in naive mode, then we do not need to build trees.
   if (!naive)
   {
-    referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
-        (naive ? referenceCopy.n_cols : leafSize));
+    // The const_cast is safe; if RearrangesDataset == false, then it'll be
+    // casted back to const anyway, and if not, referenceSet points to
+    // referenceCopy, which isn't const.
+    referenceTree = BuildTree<TreeType>(
+        const_cast<typename TreeType::Mat&>(referenceSet),
+        oldFromNewReferences);
 
-    // If we are in dual-tree mode, we need to build a query tree too.
     if (!singleMode)
-      queryTree = new TreeType(queryCopy, oldFromNewQueries,
-          (naive ? queryCopy.n_cols : leafSize));
+      queryTree = BuildTree<TreeType>(
+          const_cast<typename TreeType::Mat&>(querySet), oldFromNewQueries);
   }
 
   Timer::Stop("range_search/tree_building");
@@ -55,14 +87,14 @@ RangeSearch<MetricType, TreeType>::RangeSearch(
 
 template<typename MetricType, typename TreeType>
 RangeSearch<MetricType, TreeType>::RangeSearch(
-    const typename TreeType::Mat& referenceSet,
+    const typename TreeType::Mat& referenceSetIn,
     const bool naive,
     const bool singleMode,
-    const size_t leafSize,
     const MetricType metric) :
-    referenceCopy(referenceSet),
-    referenceSet(referenceCopy),
-    querySet(referenceCopy),
+    referenceSet(tree::TreeTraits<TreeType>::RearrangesDataset ? referenceCopy
+        : referenceSetIn),
+    querySet(tree::TreeTraits<TreeType>::RearrangesDataset ? referenceCopy
+        : referenceSetIn),
     queryTree(NULL),
     treeOwner(!naive), // If in naive mode, we are not building any trees.
     hasQuerySet(false),
@@ -74,13 +106,20 @@ RangeSearch<MetricType, TreeType>::RangeSearch(
   // Build the trees.
   Timer::Start("range_search/tree_building");
 
+  // Copy the dataset, if it will be modified during tree building.
+  if (tree::TreeTraits<TreeType>::RearrangesDataset)
+    referenceCopy = referenceSetIn;
+
   // If in naive mode, then we do not need to build trees.
   if (!naive)
   {
-    referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
-        (naive ? referenceCopy.n_cols : leafSize));
+    // The const_cast is safe; if RearrangesDataset == false, then it'll be
+    // casted back to const anyway, and if not, referenceSet points to
+    // referenceCopy, which isn't const.
+    referenceTree = BuildTree<TreeType>(
+        const_cast<typename TreeType::Mat&>(referenceSet),
+        oldFromNewReferences);
 
-    // If using dual-tree mode, then we need a second tree.
     if (!singleMode)
       queryTree = new TreeType(*referenceTree);
   }
@@ -165,10 +204,15 @@ void RangeSearch<MetricType, TreeType>::Search(
   std::vector<std::vector<size_t> >* neighborPtr = &neighbors;
   std::vector<std::vector<double> >* distancePtr = &distances;
 
-  if (treeOwner && !(singleMode && hasQuerySet))
-    distancePtr = new std::vector<std::vector<double> >;
-  if (treeOwner)
-    neighborPtr = new std::vector<std::vector<size_t> >;
+  // Mapping is only necessary if the tree rearranges points.
+  if (tree::TreeTraits<TreeType>::RearrangesDataset)
+  {
+    if (treeOwner && !(singleMode && hasQuerySet))
+      distancePtr = new std::vector<std::vector<double> >; // Query indices need to be mapped.
+
+    if (treeOwner)
+      neighborPtr = new std::vector<std::vector<size_t> >; // All indices need mapping.
+  }
 
   // Resize each vector.
   neighborPtr->clear(); // Just in case there was anything in it.
@@ -216,7 +260,8 @@ void RangeSearch<MetricType, TreeType>::Search(
       << "." << std::endl;
 
   // Map points back to original indices, if necessary.
-  if (!treeOwner)
+
+  if (!treeOwner || !tree::TreeTraits<TreeType>::RearrangesDataset)
   {
     // No mapping needed.  We are done.
     return;



More information about the mlpack-git mailing list