[mlpack-git] master: Refactor for new TreeType API. (48a4fef)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Jul 29 16:42:23 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/f8ceffae0613b350f4d6bdd46c6c8633a40b4897...6ee21879488fe98612a4619b17f8b51e8da5215b

>---------------------------------------------------------------

commit 48a4fefcd4c5d12ce3c0e60d6fc3cc56efdda05e
Author: ryan <ryan at ratml.org>
Date:   Sun Jul 26 23:07:45 2015 -0400

    Refactor for new TreeType API.


>---------------------------------------------------------------

48a4fefcd4c5d12ce3c0e60d6fc3cc56efdda05e
 src/mlpack/methods/rann/CMakeLists.txt           |   6 +-
 src/mlpack/methods/rann/allkrann_main.cpp        |   5 +-
 src/mlpack/methods/rann/ra_query_stat.hpp        |   3 +
 src/mlpack/methods/rann/ra_search.hpp            |  34 ++--
 src/mlpack/methods/rann/ra_search_impl.hpp       | 139 +++++++++------
 src/mlpack/methods/rann/ra_search_rules.hpp      |  51 +-----
 src/mlpack/methods/rann/ra_search_rules_impl.hpp | 215 ++---------------------
 7 files changed, 136 insertions(+), 317 deletions(-)

diff --git a/src/mlpack/methods/rann/CMakeLists.txt b/src/mlpack/methods/rann/CMakeLists.txt
index b848cd2..30e1a89 100644
--- a/src/mlpack/methods/rann/CMakeLists.txt
+++ b/src/mlpack/methods/rann/CMakeLists.txt
@@ -15,6 +15,10 @@ set(SOURCES
 
   # typedefs
   ra_typedef.hpp
+
+  # utilities
+  ra_util.hpp
+  ra_util.cpp
 )
 
 # add directory name to sources
@@ -23,7 +27,7 @@ foreach(file ${SOURCES})
   set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
 endforeach()
 # append sources (with directory name) to list of all MLPACK sources (used at the parent scope)
-set(MLPACK_CONTRIB_SRCS ${MLPACK_CONTRIB_SRCS} ${DIR_SRCS} PARENT_SCOPE)
+set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
 
 
 # The code to compute the rank-approximate neighbor
diff --git a/src/mlpack/methods/rann/allkrann_main.cpp b/src/mlpack/methods/rann/allkrann_main.cpp
index eb0ab87..c5add09 100644
--- a/src/mlpack/methods/rann/allkrann_main.cpp
+++ b/src/mlpack/methods/rann/allkrann_main.cpp
@@ -20,6 +20,7 @@ using namespace std;
 using namespace mlpack;
 using namespace mlpack::neighbor;
 using namespace mlpack::tree;
+using namespace mlpack::metric;
 
 // Information about the program itself.
 PROGRAM_INFO("All K-Rank-Approximate-Nearest-Neighbors",
@@ -176,8 +177,8 @@ int main(int argc, char *argv[])
     // NeighborSearch, it does not copy the matrix.
     Log::Info << "Building reference tree..." << endl;
     Timer::Start("tree_building");
-    typedef BinarySpaceTree<bound::HRectBound<2, false>,
-        RAQueryStat<NearestNeighborSort> > TreeType;
+    typedef KDTree<EuclideanDistance, RAQueryStat<NearestNeighborSort>,
+        arma::mat> TreeType;
     TreeType refTree(referenceData, oldFromNewRefs, leafSize);
     Timer::Stop("tree_building");
 
diff --git a/src/mlpack/methods/rann/ra_query_stat.hpp b/src/mlpack/methods/rann/ra_query_stat.hpp
index 47ad4dc..d9c2610 100644
--- a/src/mlpack/methods/rann/ra_query_stat.hpp
+++ b/src/mlpack/methods/rann/ra_query_stat.hpp
@@ -64,4 +64,7 @@ class RAQueryStat
 
 };
 
+} // namespace neighbor
+} // namespace mlpack
+
 #endif
diff --git a/src/mlpack/methods/rann/ra_search.hpp b/src/mlpack/methods/rann/ra_search.hpp
index 135f381..219b33f 100644
--- a/src/mlpack/methods/rann/ra_search.hpp
+++ b/src/mlpack/methods/rann/ra_search.hpp
@@ -26,6 +26,10 @@
 #include <mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp>
 
 #include "ra_query_stat.hpp"
+#include "ra_util.hpp"
+
+namespace mlpack {
+namespace neighbor {
 
 /**
  * The RASearch class: This class provides a generic manner to perform
@@ -50,12 +54,16 @@
  * @tparam TreeType The tree type to use.
  */
 template<typename SortPolicy = NearestNeighborSort,
-         typename MetricType = mlpack::metric::SquaredEuclideanDistance,
-         typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2, false>,
-                                                   RAQueryStat<SortPolicy> > >
+         typename MetricType = metric::EuclideanDistance,
+         typename MatType = arma::mat,
+         template<typename MetricType, typename StatisticType, typename MatType>
+             class TreeType = tree::KDTree>
 class RASearch
 {
  public:
+  //! Convenience typedef.
+  typedef TreeType<MetricType, RAQueryStat<SortPolicy>, MatType> Tree;
+
   /**
    * Initialize the RASearch object, passing both a reference dataset (this is
    * the dataset that will be searched).  Optionally, perform the computation in
@@ -100,7 +108,7 @@ class RASearch
    * @param singleSampleLimit The limit on the largest node that can be
    *     approximated by sampling. This defaults to 20.
    */
-  RASearch(const typename TreeType::Mat& referenceSet,
+  RASearch(const MatType& referenceSet,
            const bool naive = false,
            const bool singleMode = false,
            const double tau = 5,
@@ -158,7 +166,7 @@ class RASearch
    * @param singleSampleLimit The limit on the largest node that can be
    *     approximated by sampling. This defaults to 20.
    */
-  RASearch(TreeType* referenceTree,
+  RASearch(Tree* referenceTree,
            const bool singleMode = false,
            const double tau = 5,
            const double alpha = 0.95,
@@ -189,7 +197,7 @@ class RASearch
    * @param distances Matrix storing distances of neighbors for each query
    *     point.
    */
-  void Search(const typename TreeType::Mat& querySet,
+  void Search(const MatType& querySet,
               const size_t k,
               arma::Mat<size_t>& neighbors,
               arma::mat& distances);
@@ -217,7 +225,7 @@ class RASearch
    * @param distances Matrix storing distances of neighbors for each query
    *     point.
    */
-  void Search(TreeType* queryTree,
+  void Search(Tree* queryTree,
               const size_t k,
               arma::Mat<size_t>& neighbors,
               arma::mat& distances);
@@ -251,7 +259,7 @@ class RASearch
    *
    * @param queryTree Tree whose statistics should be reset.
    */
-  void ResetQueryTree(TreeType* queryTree) const;
+  void ResetQueryTree(Tree* queryTree) const;
 
   //! Get the rank-approximation in percentile of the data.
   double Tau() const { return tau; }
@@ -284,11 +292,11 @@ class RASearch
  private:
   //! Copy of reference dataset (if we need it, because tree building modifies
   //! it).
-  arma::mat referenceCopy;
+  MatType referenceCopy;
   //! Reference dataset.
-  const arma::mat& referenceSet;
+  const MatType& referenceSet;
   //! Pointer to the root of the reference tree.
-  TreeType* referenceTree;
+  Tree* referenceTree;
 
   //! If true, this object created the trees and is responsible for them.
   bool treeOwner;
@@ -318,8 +326,8 @@ class RASearch
 
 }; // class RASearch
 
-}; // namespace neighbor
-}; // namespace mlpack
+} // namespace neighbor
+} // namespace mlpack
 
 // Include implementation.
 #include "ra_search_impl.hpp"
diff --git a/src/mlpack/methods/rann/ra_search_impl.hpp b/src/mlpack/methods/rann/ra_search_impl.hpp
index b0db731..d44707a 100644
--- a/src/mlpack/methods/rann/ra_search_impl.hpp
+++ b/src/mlpack/methods/rann/ra_search_impl.hpp
@@ -44,9 +44,13 @@ TreeType* BuildTree(
 } // namespace aux
 
 // Construct the object.
-template<typename SortPolicy, typename MetricType, typename TreeType>
-RASearch<SortPolicy, MetricType, TreeType>::
-RASearch(const typename TreeType::Mat& referenceSetIn,
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename MetricType, typename StatisticType, typename MatType>
+             class TreeType>
+RASearch<SortPolicy, MetricType, MatType, TreeType>::
+RASearch(const MatType& referenceSetIn,
          const bool naive,
          const bool singleMode,
          const double tau,
@@ -55,7 +59,7 @@ RASearch(const typename TreeType::Mat& referenceSetIn,
          const bool firstLeafExact,
          const size_t singleSampleLimit,
          const MetricType metric) :
-    referenceSet((tree::TreeTraits<TreeType>::RearrangesDataset && !naive)
+    referenceSet((tree::TreeTraits<Tree>::RearrangesDataset && !naive)
         ? referenceCopy : referenceSetIn),
     referenceTree(NULL),
     treeOwner(!naive),
@@ -73,11 +77,10 @@ RASearch(const typename TreeType::Mat& referenceSetIn,
 
   if (!naive)
   {
-    if (tree::TreeTraits<TreeType>::RearrangesDataset)
+    if (tree::TreeTraits<Tree>::RearrangesDataset)
       referenceCopy = referenceSetIn;
 
-    referenceTree = aux::BuildTree<TreeType>(
-        const_cast<typename TreeType::Mat&>(referenceSet),
+    referenceTree = aux::BuildTree<Tree>(const_cast<MatType&>(referenceSet),
         oldFromNewReferences);
   }
 
@@ -86,9 +89,13 @@ RASearch(const typename TreeType::Mat& referenceSetIn,
 }
 
 // Construct the object.
-template<typename SortPolicy, typename MetricType, typename TreeType>
-RASearch<SortPolicy, MetricType, TreeType>::
-RASearch(TreeType* referenceTree,
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename MetricType, typename StatisticType, typename MatType>
+             class TreeType>
+RASearch<SortPolicy, MetricType, MatType, TreeType>::
+RASearch(Tree* referenceTree,
          const bool singleMode,
          const double tau,
          const double alpha,
@@ -114,8 +121,12 @@ RASearch(TreeType* referenceTree,
  * The tree is the only member we may be responsible for deleting.  The others
  * will take care of themselves.
  */
-template<typename SortPolicy, typename MetricType, typename TreeType>
-RASearch<SortPolicy, MetricType, TreeType>::
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename MetricType, typename StatisticType, typename MatType>
+             class TreeType>
+RASearch<SortPolicy, MetricType, MatType, TreeType>::
 ~RASearch()
 {
   if (treeOwner && referenceTree)
@@ -126,9 +137,13 @@ RASearch<SortPolicy, MetricType, TreeType>::
  * Computes the best neighbors and stores them in resultingNeighbors and
  * distances.
  */
-template<typename SortPolicy, typename MetricType, typename TreeType>
-void RASearch<SortPolicy, MetricType, TreeType>::
-Search(const typename TreeType::Mat& querySet,
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename MetricType, typename StatisticType, typename MatType>
+             class TreeType>
+void RASearch<SortPolicy, MetricType, MatType, TreeType>::
+Search(const MatType& querySet,
        const size_t k,
        arma::Mat<size_t>& neighbors,
        arma::mat& distances)
@@ -147,7 +162,7 @@ Search(const typename TreeType::Mat& querySet,
 
   // Mapping is only required if this tree type rearranges points and we are not
   // in naive mode.
-  if (tree::TreeTraits<TreeType>::RearrangesDataset)
+  if (tree::TreeTraits<Tree>::RearrangesDataset)
   {
     if (!singleMode && !naive)
       distancePtr = new arma::mat; // Query indices need to be mapped.
@@ -163,17 +178,16 @@ Search(const typename TreeType::Mat& querySet,
 
   // If we will be building a tree and it will modify the query set, make a copy
   // of the dataset.
-  typename TreeType::Mat queryCopy;
+  MatType queryCopy;
   const bool needsCopy = (!naive && !singleMode &&
-      tree::TreeTraits<TreeType>::RearrangesDataset);
+      tree::TreeTraits<Tree>::RearrangesDataset);
   if (needsCopy)
     queryCopy = querySet;
 
-  const typename TreeType::Mat& querySetRef = (needsCopy) ? queryCopy :
-      querySet;
+  const MatType& querySetRef = (needsCopy) ? queryCopy : querySet;
 
   // Create the helper object for the tree traversal.
-  typedef RASearchRules<SortPolicy, MetricType, TreeType> RuleType;
+  typedef RASearchRules<SortPolicy, MetricType, Tree> RuleType;
   RuleType rules(referenceSet, querySetRef, *neighborPtr, *distancePtr,
                  metric, tau, alpha, naive, sampleAtLeaves, firstLeafExact,
                  singleSampleLimit, false);
@@ -182,10 +196,10 @@ Search(const typename TreeType::Mat& querySet,
   {
     // Find how many samples from the reference set we need and sample uniformly
     // from the reference set without replacement.
-    const size_t numSamples = rules.MinimumSamplesReqd(referenceSet.n_cols, k,
+    const size_t numSamples = RAUtil::MinimumSamplesReqd(referenceSet.n_cols, k,
         tau, alpha);
     arma::uvec distinctSamples;
-    rules.ObtainDistinctSamples(numSamples, referenceSet.n_cols,
+    RAUtil::ObtainDistinctSamples(numSamples, referenceSet.n_cols,
         distinctSamples);
 
     // Run the base case on each combination of query point and sampled
@@ -203,8 +217,7 @@ Search(const typename TreeType::Mat& querySet,
       Log::Info << "Performing single-tree traversal..." << std::endl;
 
       // Create the traverser.
-      typename TreeType::template SingleTreeTraverser<RuleType>
-        traverser(rules);
+      typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
 
       // Now have it traverse for each point.
       for (size_t i = 0; i < querySetRef.n_cols; ++i)
@@ -223,12 +236,12 @@ Search(const typename TreeType::Mat& querySet,
     // Build the query tree.
     Timer::Stop("computing_neighbors");
     Timer::Start("tree_building");
-    TreeType* queryTree = aux::BuildTree<TreeType>(
-        const_cast<typename TreeType::Mat&>(querySetRef), oldFromNewQueries);
+    Tree* queryTree = aux::BuildTree<Tree>(const_cast<MatType&>(querySetRef),
+        oldFromNewQueries);
     Timer::Stop("tree_building");
     Timer::Start("computing_neighbors");
 
-    typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+    typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
 
     Log::Info << "Query statistic pre-search: "
         << queryTree->Stat().NumSamplesMade() << std::endl;
@@ -245,7 +258,7 @@ Search(const typename TreeType::Mat& querySet,
   Timer::Stop("computing_neighbors");
 
   // Map points back to original indices, if necessary.
-  if (tree::TreeTraits<TreeType>::RearrangesDataset)
+  if (tree::TreeTraits<Tree>::RearrangesDataset)
   {
     if (!singleMode && !naive && treeOwner)
     {
@@ -304,9 +317,13 @@ Search(const typename TreeType::Mat& querySet,
   }
 }
 
-template<typename SortPolicy, typename MetricType, typename TreeType>
-void RASearch<SortPolicy, MetricType, TreeType>::Search(
-    TreeType* queryTree,
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename MetricType, typename StatisticType, typename MatType>
+             class TreeType>
+void RASearch<SortPolicy, MetricType, MatType, TreeType>::Search(
+    Tree* queryTree,
     const size_t k,
     arma::Mat<size_t>& neighbors,
     arma::mat& distances)
@@ -314,7 +331,7 @@ void RASearch<SortPolicy, MetricType, TreeType>::Search(
   Timer::Start("computing_neighbors");
 
   // Get a reference to the query set.
-  const typename TreeType::Mat& querySet = queryTree->Dataset();
+  const MatType& querySet = queryTree->Dataset();
 
   // Make sure we are in dual-tree mode.
   if (singleMode || naive)
@@ -324,7 +341,7 @@ void RASearch<SortPolicy, MetricType, TreeType>::Search(
   // We won't need to map query indices, but will we need to map distances?
   arma::Mat<size_t>* neighborPtr = &neighbors;
 
-  if (treeOwner && tree::TreeTraits<TreeType>::RearrangesDataset)
+  if (treeOwner && tree::TreeTraits<Tree>::RearrangesDataset)
     neighborPtr = new arma::Mat<size_t>;
 
   neighborPtr->set_size(k, querySet.n_cols);
@@ -333,19 +350,19 @@ void RASearch<SortPolicy, MetricType, TreeType>::Search(
   distances.fill(SortPolicy::WorstDistance());
 
   // Create the helper object for the tree traversal.
-  typedef RASearchRules<SortPolicy, MetricType, TreeType> RuleType;
+  typedef RASearchRules<SortPolicy, MetricType, Tree> RuleType;
   RuleType rules(referenceSet, queryTree->Dataset(), *neighborPtr, distances,
                  metric, tau, alpha, naive, sampleAtLeaves, firstLeafExact,
                  singleSampleLimit, false);
 
   // Create the traverser.
-  typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+  typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
   traverser.Traverse(*queryTree, *referenceTree);
 
   Timer::Stop("computing_neighbors");
 
   // Do we need to map indices?
-  if (treeOwner && tree::TreeTraits<TreeType>::RearrangesDataset)
+  if (treeOwner && tree::TreeTraits<Tree>::RearrangesDataset)
   {
     // We must map reference indices only.
     neighbors.set_size(k, querySet.n_cols);
@@ -360,8 +377,12 @@ void RASearch<SortPolicy, MetricType, TreeType>::Search(
   }
 }
 
-template<typename SortPolicy, typename MetricType, typename TreeType>
-void RASearch<SortPolicy, MetricType, TreeType>::Search(
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename MetricType, typename StatisticType, typename MatType>
+             class TreeType>
+void RASearch<SortPolicy, MetricType, MatType, TreeType>::Search(
     const size_t k,
     arma::Mat<size_t>& neighbors,
     arma::mat& distances)
@@ -371,7 +392,7 @@ void RASearch<SortPolicy, MetricType, TreeType>::Search(
   arma::Mat<size_t>* neighborPtr = &neighbors;
   arma::mat* distancePtr = &distances;
 
-  if (tree::TreeTraits<TreeType>::RearrangesDataset && treeOwner)
+  if (tree::TreeTraits<Tree>::RearrangesDataset && treeOwner)
   {
     // We will always need to rearrange in this case.
     distancePtr = new arma::mat;
@@ -385,7 +406,7 @@ void RASearch<SortPolicy, MetricType, TreeType>::Search(
   distancePtr->fill(SortPolicy::WorstDistance());
 
   // Create the helper object for the tree traversal.
-  typedef RASearchRules<SortPolicy, MetricType, TreeType> RuleType;
+  typedef RASearchRules<SortPolicy, MetricType, Tree> RuleType;
   RuleType rules(referenceSet, referenceSet, *neighborPtr, *distancePtr,
                  metric, tau, alpha, naive, sampleAtLeaves, firstLeafExact,
                  singleSampleLimit, true /* sets are the same */);
@@ -394,10 +415,10 @@ void RASearch<SortPolicy, MetricType, TreeType>::Search(
   {
     // Find how many samples from the reference set we need and sample uniformly
     // from the reference set without replacement.
-    const size_t numSamples = rules.MinimumSamplesReqd(referenceSet.n_cols, k,
+    const size_t numSamples = RAUtil::MinimumSamplesReqd(referenceSet.n_cols, k,
         tau, alpha);
     arma::uvec distinctSamples;
-    rules.ObtainDistinctSamples(numSamples, referenceSet.n_cols,
+    RAUtil::ObtainDistinctSamples(numSamples, referenceSet.n_cols,
         distinctSamples);
 
     // The naive brute-force solution.
@@ -408,7 +429,7 @@ void RASearch<SortPolicy, MetricType, TreeType>::Search(
   else if (singleMode)
   {
     // Create the traverser.
-    typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
+    typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
 
     // Now have it traverse for each point.
     for (size_t i = 0; i < referenceSet.n_cols; ++i)
@@ -417,7 +438,7 @@ void RASearch<SortPolicy, MetricType, TreeType>::Search(
   else
   {
     // Create the traverser.
-    typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+    typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
 
     traverser.Traverse(*referenceTree, *referenceTree);
   }
@@ -425,7 +446,7 @@ void RASearch<SortPolicy, MetricType, TreeType>::Search(
   Timer::Stop("computing_neighbors");
 
   // Do we need to map the reference indices?
-  if (treeOwner && tree::TreeTraits<TreeType>::RearrangesDataset)
+  if (treeOwner && tree::TreeTraits<Tree>::RearrangesDataset)
   {
     neighbors.set_size(k, referenceSet.n_cols);
     distances.set_size(k, referenceSet.n_cols);
@@ -447,9 +468,13 @@ void RASearch<SortPolicy, MetricType, TreeType>::Search(
   }
 }
 
-template<typename SortPolicy, typename MetricType, typename TreeType>
-void RASearch<SortPolicy, MetricType, TreeType>::ResetQueryTree(
-    TreeType* queryNode) const
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename MetricType, typename StatisticType, typename MatType>
+             class TreeType>
+void RASearch<SortPolicy, MetricType, MatType, TreeType>::ResetQueryTree(
+    Tree* queryNode) const
 {
   queryNode->Stat().Bound() = SortPolicy::WorstDistance();
   queryNode->Stat().NumSamplesMade() = 0;
@@ -459,8 +484,13 @@ void RASearch<SortPolicy, MetricType, TreeType>::ResetQueryTree(
 }
 
 // Returns a string representation of the object.
-template<typename SortPolicy, typename MetricType, typename TreeType>
-std::string RASearch<SortPolicy, MetricType, TreeType>::ToString() const
+template<typename SortPolicy,
+         typename MetricType,
+         typename MatType,
+         template<typename MetricType, typename StatisticType, typename MatType>
+             class TreeType>
+std::string RASearch<SortPolicy, MetricType, MatType, TreeType>::ToString()
+    const
 {
   std::ostringstream convert;
   convert << "RASearch [" << this << "]" << std::endl;
@@ -493,12 +523,11 @@ std::string RASearch<SortPolicy, MetricType, TreeType>::ToString() const
   else
     convert << "false" << std::endl;
   convert << "  singleSampleLimit: " << singleSampleLimit << std::endl;
-  convert << "  metric: " << std::endl <<
-      mlpack::util::Indent(metric.ToString(),2);
+  convert << "  metric: " << std::endl << util::Indent(metric.ToString(), 2);
   return convert.str();
 }
 
-}; // namespace neighbor
-}; // namespace mlpack
+} // namespace neighbor
+} // namespace mlpack
 
 #endif
diff --git a/src/mlpack/methods/rann/ra_search_rules.hpp b/src/mlpack/methods/rann/ra_search_rules.hpp
index 2642afe..fe4cafe 100644
--- a/src/mlpack/methods/rann/ra_search_rules.hpp
+++ b/src/mlpack/methods/rann/ra_search_rules.hpp
@@ -10,7 +10,6 @@
 #define __MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP
 
 #include "../neighbor_search/ns_traversal_info.hpp"
-#include "ra_search.hpp" // For friend declaration.
 
 namespace mlpack {
 namespace neighbor {
@@ -248,47 +247,6 @@ class RASearchRules
                       const double distance);
 
   /**
-   * Compute the minimum number of samples required to guarantee
-   * the given rank-approximation and success probability.
-   *
-   * @param n Size of the set to be sampled from.
-   * @param k The number of neighbors required within the rank-approximation.
-   * @param tau The rank-approximation in percentile of the data.
-   * @param alpha The success probability desired.
-   */
-  size_t MinimumSamplesReqd(const size_t n,
-                            const size_t k,
-                            const double tau,
-                            const double alpha) const;
-
-  /**
-   * Compute the success probability of obtaining 'k'-neighbors from a
-   * set of size 'n' within the top 't' neighbors if 'm' samples are made.
-   *
-   * @param n Size of the set being sampled from.
-   * @param k The number of neighbors required within the rank-approximation.
-   * @param m The number of random samples.
-   * @param t The desired rank-approximation.
-   */
-  double SuccessProbability(const size_t n,
-                            const size_t k,
-                            const size_t m,
-                            const size_t t) const;
-
-  /**
-   * Pick up desired number of samples (with replacement) from a given range
-   * of integers so that only the distinct samples are returned from
-   * the range [0 - specified upper bound)
-   *
-   * @param numSamples Number of random samples.
-   * @param rangeUpperBound The upper bound on the range of integers.
-   * @param distinctSamples The list of the distinct samples.
-   */
-  void ObtainDistinctSamples(const size_t numSamples,
-                             const size_t rangeUpperBound,
-                             arma::uvec& distinctSamples) const;
-
-  /**
    * Perform actual scoring for single-tree case.
    */
   double Score(const size_t queryIndex,
@@ -303,15 +261,10 @@ class RASearchRules
                TreeType& referenceNode,
                const double distance,
                const double bestDistance);
-
-  // So that RASearch can access ObtainDistinctSamples() and
-  // MinimumSamplesReqd().  Maybe refactoring is a better solution but this is
-  // okay for now.
-  friend class RASearch<SortPolicy, MetricType, TreeType>;
 }; // class RASearchRules
 
-}; // namespace neighbor
-}; // namespace mlpack
+} // namespace neighbor
+} // namespace mlpack
 
 // Include implementation.
 #include "ra_search_rules_impl.hpp"
diff --git a/src/mlpack/methods/rann/ra_search_rules_impl.hpp b/src/mlpack/methods/rann/ra_search_rules_impl.hpp
index 77a5b0e..88e74eb 100644
--- a/src/mlpack/methods/rann/ra_search_rules_impl.hpp
+++ b/src/mlpack/methods/rann/ra_search_rules_impl.hpp
@@ -57,7 +57,7 @@ RASearchRules(const arma::mat& referenceSet,
         << std::endl;
 
   Timer::Start("computing_number_of_samples_reqd");
-  numSamplesReqd = MinimumSamplesReqd(n, k, tau, alpha);
+  numSamplesReqd = RAUtil::MinimumSamplesReqd(n, k, tau, alpha);
   Timer::Stop("computing_number_of_samples_reqd");
 
   // Initialize some statistics to be collected during the search.
@@ -74,194 +74,13 @@ RASearchRules(const arma::mat& referenceSet,
     for (size_t i = 0; i < querySet.n_cols; ++i)
     {
       arma::uvec distinctSamples;
-      ObtainDistinctSamples(numSamplesReqd, n, distinctSamples);
+      RAUtil::ObtainDistinctSamples(numSamplesReqd, n, distinctSamples);
       for (size_t j = 0; j < distinctSamples.n_elem; j++)
         BaseCase(i, (size_t) distinctSamples[j]);
     }
   }
 }
 
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline force_inline
-void RASearchRules<SortPolicy, MetricType, TreeType>::
-ObtainDistinctSamples(const size_t numSamples,
-                      const size_t rangeUpperBound,
-                      arma::uvec& distinctSamples) const
-{
-  // Keep track of the points that are sampled.
-  arma::Col<size_t> sampledPoints;
-  sampledPoints.zeros(rangeUpperBound);
-
-  for (size_t i = 0; i < numSamples; i++)
-    sampledPoints[(size_t) math::RandInt(rangeUpperBound)]++;
-
-  distinctSamples = arma::find(sampledPoints > 0);
-  return;
-}
-
-
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-size_t RASearchRules<SortPolicy, MetricType, TreeType>::
-MinimumSamplesReqd(const size_t n,
-                   const size_t k,
-                   const double tau,
-                   const double alpha) const
-{
-  size_t ub = n; // The upper bound on the binary search.
-  size_t lb = k; // The lower bound on the binary search.
-  size_t  m = lb; // The minimum number of random samples.
-
-  // The rank-approximation.
-  const size_t t = (size_t) std::ceil(tau * (double) n / 100.0);
-
-  double prob;
-  Log::Assert(alpha <= 1.0);
-
-  // going through all values of sample sizes
-  // to find the minimum samples required to satisfy the
-  // desired bound
-  bool done = false;
-
-  // This performs a binary search on the integer values between 'lb = k'
-  // and 'ub = n' to find the minimum number of samples 'm' required to obtain
-  // the desired success probability 'alpha'.
-  do
-  {
-    prob = SuccessProbability(n, k, m, t);
-
-    if (prob > alpha)
-    {
-      if (prob - alpha < 0.001 || ub < lb + 2) {
-        done = true;
-        break;
-      }
-      else
-        ub = m;
-    }
-    else
-    {
-      if (prob < alpha)
-      {
-        if (m == lb)
-        {
-          m++;
-          continue;
-        }
-        else
-          lb = m;
-      }
-      else
-      {
-        done = true;
-        break;
-      }
-    }
-    m = (ub + lb) / 2;
-
-  } while (!done);
-
-  return (std::min(m + 1, n));
-}
-
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-double RASearchRules<SortPolicy, MetricType, TreeType>::SuccessProbability(
-    const size_t n,
-    const size_t k,
-    const size_t m,
-    const size_t t) const
-{
-  if (k == 1)
-  {
-    if (m > n - t)
-      return 1.0;
-
-    double eps = (double) t / (double) n;
-
-    return 1.0 - std::pow(1.0 - eps, (double) m);
-
-  } // Faster implementation for topK = 1.
-  else
-  {
-    if (m < k)
-      return 0.0;
-
-    if (m > n - t + k - 1)
-      return 1.0;
-
-    double eps = (double) t / (double) n;
-    double sum = 0.0;
-
-    // The probability that 'k' of the 'm' samples lie within the top 't'
-    // of the neighbors is given by:
-    // sum_{j = k}^m Choose(m, j) (t/n)^j (1 - t/n)^{m - j}
-    // which is also equal to
-    // 1 - sum_{j = 0}^{k - 1} Choose(m, j) (t/n)^j (1 - t/n)^{m - j}
-    //
-    // So this is a m - k term summation or a k term summation. So if
-    // m > 2k, do the k term summation, otherwise do the m term summation.
-
-    size_t lb;
-    size_t ub;
-    bool topHalf;
-
-    if (2 * k < m)
-    {
-      // Compute 1 - sum_{j = 0}^{k - 1} Choose(m, j) eps^j (1 - eps)^{m - j}
-      // eps = t/n.
-      //
-      // Choosing 'lb' as 1 and 'ub' as k so as to sum from 1 to (k - 1), and
-      // add the term (1 - eps)^m term separately.
-      lb = 1;
-      ub = k;
-      topHalf = true;
-      sum = std::pow(1 - eps, (double) m);
-    }
-    else
-    {
-      // Compute sum_{j = k}^m Choose(m, j) eps^j (1 - eps)^{m - j}
-      // eps = t/n.
-      //
-      // Choosing 'lb' as k and 'ub' as m so as to sum from k to (m - 1), and
-      // add the term eps^m term separately.
-      lb = k;
-      ub = m;
-      topHalf = false;
-      sum = std::pow(eps, (double) m);
-    }
-
-    for (size_t j = lb; j < ub; j++)
-    {
-      // Compute Choose(m, j).
-      double mCj = (double) m;
-      size_t jTrans;
-
-      // If j < m - j, compute Choose(m, j).
-      // If j > m - j, compute Choose(m, m - j).
-      if (topHalf)
-        jTrans = j;
-      else
-        jTrans = m - j;
-
-      for(size_t i = 2; i <= jTrans; i++)
-      {
-        mCj *= (double) (m - (i - 1));
-        mCj /= (double) i;
-      }
-
-      sum += (mCj * std::pow(eps, (double) j)
-              * std::pow(1.0 - eps, (double) (m - j)));
-    }
-
-    if (topHalf)
-      sum = 1.0 - sum;
-
-    return sum;
-  } // For k > 1.
-}
-
 template<typename SortPolicy, typename MetricType, typename TreeType>
 inline force_inline
 double RASearchRules<SortPolicy, MetricType, TreeType>::BaseCase(
@@ -359,7 +178,8 @@ inline double RASearchRules<SortPolicy, MetricType, TreeType>::Score(
           // Then samplesReqd <= singleSampleLimit.
           // Hence, approximate the node by sampling enough number of points.
           arma::uvec distinctSamples;
-          ObtainDistinctSamples(samplesReqd, referenceNode.NumDescendants(),
+          RAUtil::ObtainDistinctSamples(samplesReqd,
+              referenceNode.NumDescendants(),
                                 distinctSamples);
           for (size_t i = 0; i < distinctSamples.n_elem; i++)
             // The counting of the samples are done in the 'BaseCase' function
@@ -375,7 +195,8 @@ inline double RASearchRules<SortPolicy, MetricType, TreeType>::Score(
           {
             // Approximate node by sampling enough number of points.
             arma::uvec distinctSamples;
-            ObtainDistinctSamples(samplesReqd, referenceNode.NumDescendants(),
+            RAUtil::ObtainDistinctSamples(samplesReqd,
+                referenceNode.NumDescendants(),
                                   distinctSamples);
             for (size_t i = 0; i < distinctSamples.n_elem; i++)
               // The counting of the samples are done in the 'BaseCase' function
@@ -463,8 +284,8 @@ Rescore(const size_t queryIndex,
         // Then, samplesReqd <= singleSampleLimit.  Hence, approximate the node
         // by sampling enough number of points.
         arma::uvec distinctSamples;
-        ObtainDistinctSamples(samplesReqd, referenceNode.NumDescendants(),
-            distinctSamples);
+        RAUtil::ObtainDistinctSamples(samplesReqd,
+            referenceNode.NumDescendants(), distinctSamples);
         for (size_t i = 0; i < distinctSamples.n_elem; i++)
           // The counting of the samples are done in the 'BaseCase' function so
           // no book-keeping is required here.
@@ -479,8 +300,8 @@ Rescore(const size_t queryIndex,
         {
           // Approximate node by sampling enough points.
           arma::uvec distinctSamples;
-          ObtainDistinctSamples(samplesReqd, referenceNode.NumDescendants(),
-                                distinctSamples);
+          RAUtil::ObtainDistinctSamples(samplesReqd,
+              referenceNode.NumDescendants(), distinctSamples);
           for (size_t i = 0; i < distinctSamples.n_elem; i++)
             // The counting of the samples are done in the 'BaseCase' function
             // so no book-keeping is required here.
@@ -666,8 +487,8 @@ inline double RASearchRules<SortPolicy, MetricType, TreeType>::Score(
           {
             const size_t queryIndex = queryNode.Descendant(i);
             arma::uvec distinctSamples;
-            ObtainDistinctSamples(samplesReqd, referenceNode.NumDescendants(),
-                                  distinctSamples);
+            RAUtil::ObtainDistinctSamples(samplesReqd,
+                referenceNode.NumDescendants(), distinctSamples);
             for (size_t j = 0; j < distinctSamples.n_elem; j++)
               // The counting of the samples are done in the 'BaseCase' function
               // so no book-keeping is required here.
@@ -696,8 +517,8 @@ inline double RASearchRules<SortPolicy, MetricType, TreeType>::Score(
             {
               const size_t queryIndex = queryNode.Descendant(i);
               arma::uvec distinctSamples;
-              ObtainDistinctSamples(samplesReqd, referenceNode.NumDescendants(),
-                                    distinctSamples);
+              RAUtil::ObtainDistinctSamples(samplesReqd,
+                  referenceNode.NumDescendants(), distinctSamples);
               for (size_t j = 0; j < distinctSamples.n_elem; j++)
                 // The counting of the samples are done in the 'BaseCase'
                 // function so no book-keeping is required here.
@@ -871,8 +692,8 @@ Rescore(TreeType& queryNode,
         {
           const size_t queryIndex = queryNode.Descendant(i);
           arma::uvec distinctSamples;
-          ObtainDistinctSamples(samplesReqd, referenceNode.NumDescendants(),
-              distinctSamples);
+          RAUtil::ObtainDistinctSamples(samplesReqd,
+              referenceNode.NumDescendants(), distinctSamples);
           for (size_t j = 0; j < distinctSamples.n_elem; j++)
             // The counting of the samples are done in the 'BaseCase'
             // function so no book-keeping is required here.
@@ -900,8 +721,8 @@ Rescore(TreeType& queryNode,
           {
             const size_t queryIndex = queryNode.Descendant(i);
             arma::uvec distinctSamples;
-            ObtainDistinctSamples(samplesReqd, referenceNode.NumDescendants(),
-                                  distinctSamples);
+            RAUtil::ObtainDistinctSamples(samplesReqd,
+                referenceNode.NumDescendants(), distinctSamples);
             for (size_t j = 0; j < distinctSamples.n_elem; j++)
               // The counting of the samples are done in BaseCase() so no
               // book-keeping is required here.



More information about the mlpack-git mailing list