[mlpack-git] master,mlpack-1.0.x: Modified patch from Saheb for #301; this unifies the constructors for NeighborSearch so they work with all tree types. The modifications I've made ensure that the referenceCopy and queryCopy matrices are left empty, rather than holding full copies of the referenceSet and querySet matrices, when the tree doesn't modify them (in that case copying is unnecessary and just wastes memory). (b8747ec)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:45:13 EST 2015


Repository : https://github.com/mlpack/mlpack

On branches: master,mlpack-1.0.x
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit b8747ec7df1e373d067bc437a463d4305b76790a
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Mar 10 21:15:41 2014 +0000

    Modified patch from Saheb for #301; this unifies the constructors for
    NeighborSearch so they work with all tree types.  The modifications I've made
    ensure that the referenceCopy and queryCopy matrices are left empty, rather
    than holding full copies of the referenceSet and querySet matrices, when the
    tree doesn't modify them (in that case copying is unnecessary and just wastes
    memory).  The leafSize parameter is also removed from both constructors.


>---------------------------------------------------------------

b8747ec7df1e373d067bc437a463d4305b76790a
 .../methods/neighbor_search/neighbor_search.hpp    |  2 -
 .../neighbor_search/neighbor_search_impl.hpp       | 94 ++++++++++++++++------
 2 files changed, 71 insertions(+), 25 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/neighbor_search.hpp b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
index 01836c7..71ad1fc 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search.hpp
@@ -72,7 +72,6 @@ class NeighborSearch
                  const typename TreeType::Mat& querySet,
                  const bool naive = false,
                  const bool singleMode = false,
-                 const size_t leafSize = 20,
                  const MetricType metric = MetricType());
 
   /**
@@ -99,7 +98,6 @@ class NeighborSearch
   NeighborSearch(const typename TreeType::Mat& referenceSet,
                  const bool naive = false,
                  const bool singleMode = false,
-                 const size_t leafSize = 20,
                  const MetricType metric = MetricType());
 
   /**
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
index c0aafb6..c694aa0 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_impl.hpp
@@ -15,19 +15,42 @@
 namespace mlpack {
 namespace neighbor {
 
+//! Call the tree constructor that does mapping.
+template<typename TreeType>
+TreeType* BuildTree(
+    typename TreeType::Mat& dataset,
+    std::vector<size_t>& oldFromNew,
+    typename boost::enable_if_c<
+        tree::TreeTraits<TreeType>::RearrangesDataset == true, TreeType*
+    >::type = 0)
+{
+  return new TreeType(dataset, oldFromNew);
+}
+
+//! Call the tree constructor that does not do mapping.
+template<typename TreeType>
+TreeType* BuildTree(
+    const typename TreeType::Mat& dataset,
+    const std::vector<size_t>& /* oldFromNew */,
+    const typename boost::enable_if_c<
+        tree::TreeTraits<TreeType>::RearrangesDataset == false, TreeType*
+    >::type = 0)
+{
+  return new TreeType(dataset);
+}
+
 // Construct the object.
 template<typename SortPolicy, typename MetricType, typename TreeType>
 NeighborSearch<SortPolicy, MetricType, TreeType>::
-NeighborSearch(const typename TreeType::Mat& referenceSet,
-               const typename TreeType::Mat& querySet,
+NeighborSearch(const typename TreeType::Mat& referenceSetIn,
+               const typename TreeType::Mat& querySetIn,
                const bool naive,
                const bool singleMode,
-               const size_t leafSize,
                const MetricType metric) :
-    referenceCopy(referenceSet),
-    queryCopy(querySet),
-    referenceSet(referenceCopy),
-    querySet(queryCopy),
+    referenceSet(tree::TreeTraits<TreeType>::RearrangesDataset ? referenceCopy
+        : referenceSetIn),
+    querySet(tree::TreeTraits<TreeType>::RearrangesDataset ? queryCopy
+        : querySetIn),
     referenceTree(NULL),
     queryTree(NULL),
     treeOwner(!naive), // False if a tree was passed.  If naive, then no trees.
@@ -42,15 +65,26 @@ NeighborSearch(const typename TreeType::Mat& referenceSet,
   // We'll time tree building, but only if we are building trees.
   Timer::Start("tree_building");
 
+  // Copy the datasets, if they will be modified during tree building.
+  if (tree::TreeTraits<TreeType>::RearrangesDataset)
+  {
+    referenceCopy = referenceSetIn;
+    queryCopy = querySetIn;
+  }
+
   // If not in naive mode, then we need to build trees.
   if (!naive)
   {
-    referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
-        (naive ? referenceCopy.n_cols : leafSize));
+    // The const_cast is safe; if RearrangesDataset == false, then it'll be
+    // casted back to const anyway, and if not, referenceSet points to
+    // referenceCopy, which isn't const.
+    referenceTree = BuildTree<TreeType>(
+        const_cast<typename TreeType::Mat&>(referenceSet),
+        oldFromNewReferences);
 
     if (!singleMode)
-      queryTree = new TreeType(queryCopy, oldFromNewQueries,
-          (naive ? querySet.n_cols : leafSize));
+      queryTree = BuildTree<TreeType>(
+          const_cast<typename TreeType::Mat&>(querySet), oldFromNewQueries);
   }
 
   // Stop the timer we started above (if we need to).
@@ -60,14 +94,14 @@ NeighborSearch(const typename TreeType::Mat& referenceSet,
 // Construct the object.
 template<typename SortPolicy, typename MetricType, typename TreeType>
 NeighborSearch<SortPolicy, MetricType, TreeType>::
-NeighborSearch(const typename TreeType::Mat& referenceSet,
+NeighborSearch(const typename TreeType::Mat& referenceSetIn,
                const bool naive,
                const bool singleMode,
-               const size_t leafSize,
                const MetricType metric) :
-    referenceCopy(referenceSet),
-    referenceSet(referenceCopy),
-    querySet(referenceCopy),
+    referenceSet(tree::TreeTraits<TreeType>::RearrangesDataset ? referenceCopy
+        : referenceSetIn),
+    querySet(tree::TreeTraits<TreeType>::RearrangesDataset ? referenceCopy
+        : referenceSetIn),
     referenceTree(NULL),
     queryTree(NULL),
     treeOwner(!naive), // If naive, then we are not building any trees.
@@ -79,11 +113,20 @@ NeighborSearch(const typename TreeType::Mat& referenceSet,
   // We'll time tree building, but only if we are building trees.
   Timer::Start("tree_building");
 
+  // Copy the dataset, if it will be modified during tree building.
+  if (tree::TreeTraits<TreeType>::RearrangesDataset)
+    referenceCopy = referenceSetIn;
+
   // If not in naive mode, then we may need to construct trees.
   if (!naive)
   {
-    referenceTree = new TreeType(referenceCopy, oldFromNewReferences,
-        (naive ? referenceSet.n_cols : leafSize));
+    // The const_cast is safe; if RearrangesDataset == false, then it'll be
+    // casted back to const anyway, and if not, referenceSet points to
+    // referenceCopy, which isn't const.
+    referenceTree = BuildTree<TreeType>(
+        const_cast<typename TreeType::Mat&>(referenceSet),
+        oldFromNewReferences);
+
     if (!singleMode)
       queryTree = new TreeType(*referenceTree);
   }
@@ -180,10 +223,15 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
   arma::Mat<size_t>* neighborPtr = &resultingNeighbors;
   arma::mat* distancePtr = &distances;
 
-  if (treeOwner && !(singleMode && hasQuerySet))
-    distancePtr = new arma::mat; // Query indices need to be mapped.
-  if (treeOwner)
-    neighborPtr = new arma::Mat<size_t>; // All indices need mapping.
+  // Mapping is only necessary if the tree rearranges points.
+  if (tree::TreeTraits<TreeType>::RearrangesDataset)
+  {
+    if (treeOwner && !(singleMode && hasQuerySet))
+      distancePtr = new arma::mat; // Query indices need to be mapped.
+
+    if (treeOwner)
+      neighborPtr = new arma::Mat<size_t>; // All indices need mapping.
+  }
 
   // Set the size of the neighbor and distance matrices.
   neighborPtr->set_size(k, querySet.n_cols);
@@ -225,7 +273,7 @@ void NeighborSearch<SortPolicy, MetricType, TreeType>::Search(
   Timer::Stop("computing_neighbors");
 
   // Now, do we need to do mapping of indices?
-  if (!treeOwner)
+  if (!treeOwner || !tree::TreeTraits<TreeType>::RearrangesDataset)
   {
     // No mapping needed.  We are done.
     return;



More information about the mlpack-git mailing list