[mlpack-svn] r16769 - mlpack/trunk/src/mlpack/methods/rann

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Jul 7 09:10:04 EDT 2014


Author: rcurtin
Date: Mon Jul  7 09:10:04 2014
New Revision: 16769

Log:
Refactor RASearch so that it does not accept a leafSize parameter and can build
arbitrary tree types.


Modified:
   mlpack/trunk/src/mlpack/methods/rann/ra_search.hpp
   mlpack/trunk/src/mlpack/methods/rann/ra_search_impl.hpp

Modified: mlpack/trunk/src/mlpack/methods/rann/ra_search.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/rann/ra_search.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/rann/ra_search.hpp	Mon Jul  7 09:10:04 2014
@@ -79,7 +79,6 @@
            const typename TreeType::Mat& querySet,
            const bool naive = false,
            const bool singleMode = false,
-           const size_t leafSize = 20,
            const MetricType metric = MetricType());
 
   /**
@@ -107,7 +106,6 @@
   RASearch(const typename TreeType::Mat& referenceSet,
            const bool naive = false,
            const bool singleMode = false,
-           const size_t leafSize = 20,
            const MetricType metric = MetricType());
 
   /**
@@ -258,10 +256,10 @@
   //! Pointer to the root of the query tree (might not exist).
   TreeType* queryTree;
 
-  //! Indicates if we should free the reference tree at deletion time.
-  bool ownReferenceTree;
-  //! Indicates if we should free the query tree at deletion time.
-  bool ownQueryTree;
+  //! If true, this object created the trees and is responsible for them.
+  bool treeOwner;
+  //! Indicates if a separate query set was passed.
+  bool hasQuerySet;
 
   //! Indicates if naive random sampling on the set is being used.
   bool naive;

Modified: mlpack/trunk/src/mlpack/methods/rann/ra_search_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/rann/ra_search_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/methods/rann/ra_search_impl.hpp	Mon Jul  7 09:10:04 2014
@@ -15,23 +15,50 @@
 namespace mlpack {
 namespace neighbor {
 
+namespace aux {
+
+//! Build a tree whose type rearranges its dataset, passing oldFromNew along.
+template<typename TreeType>
+TreeType* BuildTree(
+    typename TreeType::Mat& dataset,
+    std::vector<size_t>& oldFromNew,
+    typename boost::enable_if_c<
+        tree::TreeTraits<TreeType>::RearrangesDataset == true, TreeType*
+    >::type = 0) // Unused dummy argument; it only enables this overload.
+{
+  return new TreeType(dataset, oldFromNew);
+}
+
+//! Build a tree whose type leaves its dataset alone; oldFromNew is ignored.
+template<typename TreeType>
+TreeType* BuildTree(
+    const typename TreeType::Mat& dataset,
+    const std::vector<size_t>& /* oldFromNew */,
+    const typename boost::enable_if_c<
+        tree::TreeTraits<TreeType>::RearrangesDataset == false, TreeType*
+    >::type = 0) // Unused dummy argument; it only enables this overload.
+{
+  return new TreeType(dataset);
+}
+
+} // namespace aux
+
 // Construct the object.
 template<typename SortPolicy, typename MetricType, typename TreeType>
 RASearch<SortPolicy, MetricType, TreeType>::
-RASearch(const typename TreeType::Mat& referenceSet,
-         const typename TreeType::Mat& querySet,
+RASearch(const typename TreeType::Mat& referenceSetIn,
+         const typename TreeType::Mat& querySetIn,
          const bool naive,
          const bool singleMode,
-         const size_t leafSize,
          const MetricType metric) :
-    referenceCopy(referenceSet),
-    queryCopy(querySet),
-    referenceSet(referenceCopy),
-    querySet(queryCopy),
+    referenceSet(tree::TreeTraits<TreeType>::RearrangesDataset ? referenceCopy :
+        referenceSetIn),
+    querySet((tree::TreeTraits<TreeType>::RearrangesDataset && !singleMode) ?
+        queryCopy : querySetIn),
     referenceTree(NULL),
     queryTree(NULL),
-    ownReferenceTree(true), // False if a tree was passed.
-    ownQueryTree(true), // False if a tree was passed.
+    treeOwner(!naive),
+    hasQuerySet(true),
     naive(naive),
     singleMode(!naive && singleMode), // No single mode if naive.
     metric(metric),
@@ -40,12 +67,22 @@
   // We'll time tree building.
   Timer::Start("tree_building");
 
+  if (tree::TreeTraits<TreeType>::RearrangesDataset)
+  {
+    referenceCopy = referenceSetIn;
+    if (!singleMode)
+      queryCopy = querySetIn;
+  }
+
   // Construct as a naive object if we need to.
   if (!naive)
   {
-    referenceTree = new TreeType(referenceCopy, oldFromNewReferences, leafSize);
+    referenceTree = aux::BuildTree<TreeType>(const_cast<typename
+        TreeType::Mat&>(referenceSet), oldFromNewReferences);
 
-    queryTree = new TreeType(queryCopy, oldFromNewQueries, leafSize);
+    if (!singleMode)
+      queryTree = aux::BuildTree<TreeType>(const_cast<typename
+          TreeType::Mat&>(querySet), oldFromNewQueries);
   }
 
   // Stop the timer we started above.
@@ -55,18 +92,18 @@
 // Construct the object.
 template<typename SortPolicy, typename MetricType, typename TreeType>
 RASearch<SortPolicy, MetricType, TreeType>::
-RASearch(const typename TreeType::Mat& referenceSet,
+RASearch(const typename TreeType::Mat& referenceSetIn,
          const bool naive,
          const bool singleMode,
-         const size_t leafSize,
          const MetricType metric) :
-    referenceCopy(referenceSet),
-    referenceSet(referenceCopy),
-    querySet(referenceCopy),
+    referenceSet(tree::TreeTraits<TreeType>::RearrangesDataset ? referenceCopy :
+        referenceSetIn),
+    querySet(tree::TreeTraits<TreeType>::RearrangesDataset && !singleMode ?
+        referenceCopy : referenceSetIn),
     referenceTree(NULL),
     queryTree(NULL),
-    ownReferenceTree(true),
-    ownQueryTree(false), // Since it will be the same as referenceTree.
+    treeOwner(!naive),
+    hasQuerySet(false),
     naive(naive),
     singleMode(!naive && singleMode), // No single mode if naive.
     metric(metric),
@@ -75,11 +112,13 @@
   // We'll time tree building.
   Timer::Start("tree_building");
 
+  if (tree::TreeTraits<TreeType>::RearrangesDataset)
+    referenceCopy = referenceSetIn;
+
   // Construct as a naive object if we need to.
   if (!naive)
-  {
-    referenceTree = new TreeType(referenceCopy, oldFromNewReferences, leafSize);
-  }
+    referenceTree = aux::BuildTree<TreeType>(const_cast<typename
+        TreeType::Mat&>(referenceSet), oldFromNewReferences);
 
   // Stop the timer we started above.
   Timer::Stop("tree_building");
@@ -98,8 +137,8 @@
     querySet(querySet),
     referenceTree(referenceTree),
     queryTree(queryTree),
-    ownReferenceTree(false),
-    ownQueryTree(false),
+    treeOwner(false),
+    hasQuerySet(true),
     naive(false),
     singleMode(singleMode),
     metric(metric),
@@ -114,16 +153,16 @@
          const typename TreeType::Mat& referenceSet,
          const bool singleMode,
          const MetricType metric) :
-  referenceSet(referenceSet),
-  querySet(referenceSet),
-  referenceTree(referenceTree),
-  queryTree(NULL),
-  ownReferenceTree(false),
-  ownQueryTree(false),
-  naive(false),
-  singleMode(singleMode),
-  metric(metric),
-  numberOfPrunes(0)
+    referenceSet(referenceSet),
+    querySet(referenceSet),
+    referenceTree(referenceTree),
+    queryTree(NULL),
+    treeOwner(false),
+    hasQuerySet(false),
+    naive(false),
+    singleMode(singleMode),
+    metric(metric),
+    numberOfPrunes(0)
 // Nothing else to initialize.
 { }
 
@@ -135,10 +174,13 @@
 RASearch<SortPolicy, MetricType, TreeType>::
 ~RASearch()
 {
-  if (ownReferenceTree)
-    delete referenceTree;
-  if (ownQueryTree)
-    delete queryTree;
+  if (treeOwner)
+  {
+    if (referenceTree)
+      delete referenceTree;
+   if (queryTree)
+      delete queryTree;
+  }
 }
 
 /**
@@ -165,11 +207,13 @@
   arma::Mat<size_t>* neighborPtr = &resultingNeighbors;
   arma::mat* distancePtr = &distances;
 
-  if (!naive) // If naive, no re-mapping required since points are not mapped.
+  // Mapping is only required if this tree type rearranges points and we are not
+  // in naive mode.
+  if (tree::TreeTraits<TreeType>::RearrangesDataset)
   {
-    if (ownQueryTree || (ownReferenceTree && !queryTree))
+    if (treeOwner && !(singleMode && hasQuerySet))
       distancePtr = new arma::mat; // Query indices need to be mapped.
-    if (ownReferenceTree || ownQueryTree)
+    if (treeOwner)
       neighborPtr = new arma::Mat<size_t>; // All indices need mapping.
   }
 
@@ -201,12 +245,8 @@
     // Run the base case on each combination of query point and sampled
     // reference point.
     for (size_t i = 0; i < querySet.n_cols; ++i)
-    {
       for (size_t j = 0; j < distinctSamples.n_elem; ++j)
-      {
         rules.BaseCase(i, (size_t) distinctSamples[j]);
-      }
-    }
   }
   else if (singleMode)
   {
@@ -274,13 +314,12 @@
   Log::Info << "Pruned " << numPrunes << " nodes." << std::endl;
 
   // Now, do we need to do mapping of indices?
-  if ((!ownReferenceTree && !ownQueryTree) || naive)
+  if (!treeOwner || !tree::TreeTraits<TreeType>::RearrangesDataset)
   {
-    // No mapping needed if we do not own the trees or if we are doing naive
-    // sampling.  We are done.
+    // No mapping needed.  We are done.
     return;
   }
-  else if (ownReferenceTree && ownQueryTree) // Map references and queries.
+  else if (treeOwner && hasQuerySet && !singleMode) // Map both sets.
   {
     // Set size of output matrices correctly.
     resultingNeighbors.set_size(k, querySet.n_cols);
@@ -303,62 +342,41 @@
     delete neighborPtr;
     delete distancePtr;
   }
-  else if (ownReferenceTree)
+  else if (treeOwner && !hasQuerySet)
   {
-    if (!queryTree) // No query tree -- map both references and queries.
-    {
-      resultingNeighbors.set_size(k, querySet.n_cols);
-      distances.set_size(k, querySet.n_cols);
-
-      for (size_t i = 0; i < distances.n_cols; i++)
-      {
-        // Map distances (copy a column).
-        distances.col(oldFromNewReferences[i]) = distancePtr->col(i);
+    // No query tree -- map both references and queries.
+    resultingNeighbors.set_size(k, querySet.n_cols);
+    distances.set_size(k, querySet.n_cols);
 
-        // Map indices of neighbors.
-        for (size_t j = 0; j < distances.n_rows; j++)
-        {
-          resultingNeighbors(j, oldFromNewReferences[i]) =
-              oldFromNewReferences[(*neighborPtr)(j, i)];
-        }
-      }
-    }
-    else // Map only references.
+    for (size_t i = 0; i < distances.n_cols; i++)
     {
-      // Set size of neighbor indices matrix correctly.
-      resultingNeighbors.set_size(k, querySet.n_cols);
+      // Map distances (copy a column).
+      distances.col(oldFromNewReferences[i]) = distancePtr->col(i);
 
       // Map indices of neighbors.
-      for (size_t i = 0; i < resultingNeighbors.n_cols; i++)
+      for (size_t j = 0; j < distances.n_rows; j++)
       {
-        for (size_t j = 0; j < resultingNeighbors.n_rows; j++)
-        {
-          resultingNeighbors(j, i) = oldFromNewReferences[(*neighborPtr)(j, i)];
-        }
+        resultingNeighbors(j, oldFromNewReferences[i]) =
+            oldFromNewReferences[(*neighborPtr)(j, i)];
       }
     }
-
-    // Finished with temporary matrix.
-    delete neighborPtr;
   }
-  else if (ownQueryTree)
+  else if (treeOwner && hasQuerySet && singleMode) // Map only references.
   {
-    // Set size of matrices correctly.
+    // Set size of neighbor indices matrix correctly.
     resultingNeighbors.set_size(k, querySet.n_cols);
-    distances.set_size(k, querySet.n_cols);
 
-    for (size_t i = 0; i < distances.n_cols; i++)
+    // Map indices of neighbors.
+    for (size_t i = 0; i < resultingNeighbors.n_cols; i++)
     {
-      // Map distances (copy a column).
-      distances.col(oldFromNewQueries[i]) = distancePtr->col(i);
-
-      // Map indices of neighbors.
-      resultingNeighbors.col(oldFromNewQueries[i]) = neighborPtr->col(i);
+      for (size_t j = 0; j < resultingNeighbors.n_rows; j++)
+      {
+        resultingNeighbors(j, i) = oldFromNewReferences[(*neighborPtr)(j, i)];
+      }
     }
 
-    // Finished with temporary matrices.
+    // Finished with temporary matrix.
     delete neighborPtr;
-    delete distancePtr;
   }
 } // Search
 



More information about the mlpack-svn mailing list