[mlpack-git] master: Refactor for new TreeType API. (48a4fef)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Jul 29 16:42:23 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/f8ceffae0613b350f4d6bdd46c6c8633a40b4897...6ee21879488fe98612a4619b17f8b51e8da5215b
>---------------------------------------------------------------
commit 48a4fefcd4c5d12ce3c0e60d6fc3cc56efdda05e
Author: ryan <ryan at ratml.org>
Date: Sun Jul 26 23:07:45 2015 -0400
Refactor for new TreeType API.
>---------------------------------------------------------------
48a4fefcd4c5d12ce3c0e60d6fc3cc56efdda05e
src/mlpack/methods/rann/CMakeLists.txt | 6 +-
src/mlpack/methods/rann/allkrann_main.cpp | 5 +-
src/mlpack/methods/rann/ra_query_stat.hpp | 3 +
src/mlpack/methods/rann/ra_search.hpp | 34 ++--
src/mlpack/methods/rann/ra_search_impl.hpp | 139 +++++++++------
src/mlpack/methods/rann/ra_search_rules.hpp | 51 +-----
src/mlpack/methods/rann/ra_search_rules_impl.hpp | 215 ++---------------------
7 files changed, 136 insertions(+), 317 deletions(-)
diff --git a/src/mlpack/methods/rann/CMakeLists.txt b/src/mlpack/methods/rann/CMakeLists.txt
index b848cd2..30e1a89 100644
--- a/src/mlpack/methods/rann/CMakeLists.txt
+++ b/src/mlpack/methods/rann/CMakeLists.txt
@@ -15,6 +15,10 @@ set(SOURCES
# typedefs
ra_typedef.hpp
+
+ # utilities
+ ra_util.hpp
+ ra_util.cpp
)
# add directory name to sources
@@ -23,7 +27,7 @@ foreach(file ${SOURCES})
set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
endforeach()
# append sources (with directory name) to list of all MLPACK sources (used at the parent scope)
-set(MLPACK_CONTRIB_SRCS ${MLPACK_CONTRIB_SRCS} ${DIR_SRCS} PARENT_SCOPE)
+set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
# The code to compute the rank-approximate neighbor
diff --git a/src/mlpack/methods/rann/allkrann_main.cpp b/src/mlpack/methods/rann/allkrann_main.cpp
index eb0ab87..c5add09 100644
--- a/src/mlpack/methods/rann/allkrann_main.cpp
+++ b/src/mlpack/methods/rann/allkrann_main.cpp
@@ -20,6 +20,7 @@ using namespace std;
using namespace mlpack;
using namespace mlpack::neighbor;
using namespace mlpack::tree;
+using namespace mlpack::metric;
// Information about the program itself.
PROGRAM_INFO("All K-Rank-Approximate-Nearest-Neighbors",
@@ -176,8 +177,8 @@ int main(int argc, char *argv[])
// NeighborSearch, it does not copy the matrix.
Log::Info << "Building reference tree..." << endl;
Timer::Start("tree_building");
- typedef BinarySpaceTree<bound::HRectBound<2, false>,
- RAQueryStat<NearestNeighborSort> > TreeType;
+ typedef KDTree<EuclideanDistance, RAQueryStat<NearestNeighborSort>,
+ arma::mat> TreeType;
TreeType refTree(referenceData, oldFromNewRefs, leafSize);
Timer::Stop("tree_building");
diff --git a/src/mlpack/methods/rann/ra_query_stat.hpp b/src/mlpack/methods/rann/ra_query_stat.hpp
index 47ad4dc..d9c2610 100644
--- a/src/mlpack/methods/rann/ra_query_stat.hpp
+++ b/src/mlpack/methods/rann/ra_query_stat.hpp
@@ -64,4 +64,7 @@ class RAQueryStat
};
+} // namespace neighbor
+} // namespace mlpack
+
#endif
diff --git a/src/mlpack/methods/rann/ra_search.hpp b/src/mlpack/methods/rann/ra_search.hpp
index 135f381..219b33f 100644
--- a/src/mlpack/methods/rann/ra_search.hpp
+++ b/src/mlpack/methods/rann/ra_search.hpp
@@ -26,6 +26,10 @@
#include <mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp>
#include "ra_query_stat.hpp"
+#include "ra_util.hpp"
+
+namespace mlpack {
+namespace neighbor {
/**
* The RASearch class: This class provides a generic manner to perform
@@ -50,12 +54,16 @@
* @tparam TreeType The tree type to use.
*/
template<typename SortPolicy = NearestNeighborSort,
- typename MetricType = mlpack::metric::SquaredEuclideanDistance,
- typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2, false>,
- RAQueryStat<SortPolicy> > >
+ typename MetricType = metric::EuclideanDistance,
+ typename MatType = arma::mat,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType = tree::KDTree>
class RASearch
{
public:
+ //! Convenience typedef.
+ typedef TreeType<MetricType, RAQueryStat<SortPolicy>, MatType> Tree;
+
/**
* Initialize the RASearch object, passing both a reference dataset (this is
* the dataset that will be searched). Optionally, perform the computation in
@@ -100,7 +108,7 @@ class RASearch
* @param singleSampleLimit The limit on the largest node that can be
* approximated by sampling. This defaults to 20.
*/
- RASearch(const typename TreeType::Mat& referenceSet,
+ RASearch(const MatType& referenceSet,
const bool naive = false,
const bool singleMode = false,
const double tau = 5,
@@ -158,7 +166,7 @@ class RASearch
* @param singleSampleLimit The limit on the largest node that can be
* approximated by sampling. This defaults to 20.
*/
- RASearch(TreeType* referenceTree,
+ RASearch(Tree* referenceTree,
const bool singleMode = false,
const double tau = 5,
const double alpha = 0.95,
@@ -189,7 +197,7 @@ class RASearch
* @param distances Matrix storing distances of neighbors for each query
* point.
*/
- void Search(const typename TreeType::Mat& querySet,
+ void Search(const MatType& querySet,
const size_t k,
arma::Mat<size_t>& neighbors,
arma::mat& distances);
@@ -217,7 +225,7 @@ class RASearch
* @param distances Matrix storing distances of neighbors for each query
* point.
*/
- void Search(TreeType* queryTree,
+ void Search(Tree* queryTree,
const size_t k,
arma::Mat<size_t>& neighbors,
arma::mat& distances);
@@ -251,7 +259,7 @@ class RASearch
*
* @param queryTree Tree whose statistics should be reset.
*/
- void ResetQueryTree(TreeType* queryTree) const;
+ void ResetQueryTree(Tree* queryTree) const;
//! Get the rank-approximation in percentile of the data.
double Tau() const { return tau; }
@@ -284,11 +292,11 @@ class RASearch
private:
//! Copy of reference dataset (if we need it, because tree building modifies
//! it).
- arma::mat referenceCopy;
+ MatType referenceCopy;
//! Reference dataset.
- const arma::mat& referenceSet;
+ const MatType& referenceSet;
//! Pointer to the root of the reference tree.
- TreeType* referenceTree;
+ Tree* referenceTree;
//! If true, this object created the trees and is responsible for them.
bool treeOwner;
@@ -318,8 +326,8 @@ class RASearch
}; // class RASearch
-}; // namespace neighbor
-}; // namespace mlpack
+} // namespace neighbor
+} // namespace mlpack
// Include implementation.
#include "ra_search_impl.hpp"
diff --git a/src/mlpack/methods/rann/ra_search_impl.hpp b/src/mlpack/methods/rann/ra_search_impl.hpp
index b0db731..d44707a 100644
--- a/src/mlpack/methods/rann/ra_search_impl.hpp
+++ b/src/mlpack/methods/rann/ra_search_impl.hpp
@@ -44,9 +44,13 @@ TreeType* BuildTree(
} // namespace aux
// Construct the object.
-template<typename SortPolicy, typename MetricType, typename TreeType>
-RASearch<SortPolicy, MetricType, TreeType>::
-RASearch(const typename TreeType::Mat& referenceSetIn,
+template<typename SortPolicy,
+ typename MetricType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType>
+RASearch<SortPolicy, MetricType, MatType, TreeType>::
+RASearch(const MatType& referenceSetIn,
const bool naive,
const bool singleMode,
const double tau,
@@ -55,7 +59,7 @@ RASearch(const typename TreeType::Mat& referenceSetIn,
const bool firstLeafExact,
const size_t singleSampleLimit,
const MetricType metric) :
- referenceSet((tree::TreeTraits<TreeType>::RearrangesDataset && !naive)
+ referenceSet((tree::TreeTraits<Tree>::RearrangesDataset && !naive)
? referenceCopy : referenceSetIn),
referenceTree(NULL),
treeOwner(!naive),
@@ -73,11 +77,10 @@ RASearch(const typename TreeType::Mat& referenceSetIn,
if (!naive)
{
- if (tree::TreeTraits<TreeType>::RearrangesDataset)
+ if (tree::TreeTraits<Tree>::RearrangesDataset)
referenceCopy = referenceSetIn;
- referenceTree = aux::BuildTree<TreeType>(
- const_cast<typename TreeType::Mat&>(referenceSet),
+ referenceTree = aux::BuildTree<Tree>(const_cast<MatType&>(referenceSet),
oldFromNewReferences);
}
@@ -86,9 +89,13 @@ RASearch(const typename TreeType::Mat& referenceSetIn,
}
// Construct the object.
-template<typename SortPolicy, typename MetricType, typename TreeType>
-RASearch<SortPolicy, MetricType, TreeType>::
-RASearch(TreeType* referenceTree,
+template<typename SortPolicy,
+ typename MetricType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType>
+RASearch<SortPolicy, MetricType, MatType, TreeType>::
+RASearch(Tree* referenceTree,
const bool singleMode,
const double tau,
const double alpha,
@@ -114,8 +121,12 @@ RASearch(TreeType* referenceTree,
* The tree is the only member we may be responsible for deleting. The others
* will take care of themselves.
*/
-template<typename SortPolicy, typename MetricType, typename TreeType>
-RASearch<SortPolicy, MetricType, TreeType>::
+template<typename SortPolicy,
+ typename MetricType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType>
+RASearch<SortPolicy, MetricType, MatType, TreeType>::
~RASearch()
{
if (treeOwner && referenceTree)
@@ -126,9 +137,13 @@ RASearch<SortPolicy, MetricType, TreeType>::
* Computes the best neighbors and stores them in resultingNeighbors and
* distances.
*/
-template<typename SortPolicy, typename MetricType, typename TreeType>
-void RASearch<SortPolicy, MetricType, TreeType>::
-Search(const typename TreeType::Mat& querySet,
+template<typename SortPolicy,
+ typename MetricType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType>
+void RASearch<SortPolicy, MetricType, MatType, TreeType>::
+Search(const MatType& querySet,
const size_t k,
arma::Mat<size_t>& neighbors,
arma::mat& distances)
@@ -147,7 +162,7 @@ Search(const typename TreeType::Mat& querySet,
// Mapping is only required if this tree type rearranges points and we are not
// in naive mode.
- if (tree::TreeTraits<TreeType>::RearrangesDataset)
+ if (tree::TreeTraits<Tree>::RearrangesDataset)
{
if (!singleMode && !naive)
distancePtr = new arma::mat; // Query indices need to be mapped.
@@ -163,17 +178,16 @@ Search(const typename TreeType::Mat& querySet,
// If we will be building a tree and it will modify the query set, make a copy
// of the dataset.
- typename TreeType::Mat queryCopy;
+ MatType queryCopy;
const bool needsCopy = (!naive && !singleMode &&
- tree::TreeTraits<TreeType>::RearrangesDataset);
+ tree::TreeTraits<Tree>::RearrangesDataset);
if (needsCopy)
queryCopy = querySet;
- const typename TreeType::Mat& querySetRef = (needsCopy) ? queryCopy :
- querySet;
+ const MatType& querySetRef = (needsCopy) ? queryCopy : querySet;
// Create the helper object for the tree traversal.
- typedef RASearchRules<SortPolicy, MetricType, TreeType> RuleType;
+ typedef RASearchRules<SortPolicy, MetricType, Tree> RuleType;
RuleType rules(referenceSet, querySetRef, *neighborPtr, *distancePtr,
metric, tau, alpha, naive, sampleAtLeaves, firstLeafExact,
singleSampleLimit, false);
@@ -182,10 +196,10 @@ Search(const typename TreeType::Mat& querySet,
{
// Find how many samples from the reference set we need and sample uniformly
// from the reference set without replacement.
- const size_t numSamples = rules.MinimumSamplesReqd(referenceSet.n_cols, k,
+ const size_t numSamples = RAUtil::MinimumSamplesReqd(referenceSet.n_cols, k,
tau, alpha);
arma::uvec distinctSamples;
- rules.ObtainDistinctSamples(numSamples, referenceSet.n_cols,
+ RAUtil::ObtainDistinctSamples(numSamples, referenceSet.n_cols,
distinctSamples);
// Run the base case on each combination of query point and sampled
@@ -203,8 +217,7 @@ Search(const typename TreeType::Mat& querySet,
Log::Info << "Performing single-tree traversal..." << std::endl;
// Create the traverser.
- typename TreeType::template SingleTreeTraverser<RuleType>
- traverser(rules);
+ typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
// Now have it traverse for each point.
for (size_t i = 0; i < querySetRef.n_cols; ++i)
@@ -223,12 +236,12 @@ Search(const typename TreeType::Mat& querySet,
// Build the query tree.
Timer::Stop("computing_neighbors");
Timer::Start("tree_building");
- TreeType* queryTree = aux::BuildTree<TreeType>(
- const_cast<typename TreeType::Mat&>(querySetRef), oldFromNewQueries);
+ Tree* queryTree = aux::BuildTree<Tree>(const_cast<MatType&>(querySetRef),
+ oldFromNewQueries);
Timer::Stop("tree_building");
Timer::Start("computing_neighbors");
- typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+ typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
Log::Info << "Query statistic pre-search: "
<< queryTree->Stat().NumSamplesMade() << std::endl;
@@ -245,7 +258,7 @@ Search(const typename TreeType::Mat& querySet,
Timer::Stop("computing_neighbors");
// Map points back to original indices, if necessary.
- if (tree::TreeTraits<TreeType>::RearrangesDataset)
+ if (tree::TreeTraits<Tree>::RearrangesDataset)
{
if (!singleMode && !naive && treeOwner)
{
@@ -304,9 +317,13 @@ Search(const typename TreeType::Mat& querySet,
}
}
-template<typename SortPolicy, typename MetricType, typename TreeType>
-void RASearch<SortPolicy, MetricType, TreeType>::Search(
- TreeType* queryTree,
+template<typename SortPolicy,
+ typename MetricType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType>
+void RASearch<SortPolicy, MetricType, MatType, TreeType>::Search(
+ Tree* queryTree,
const size_t k,
arma::Mat<size_t>& neighbors,
arma::mat& distances)
@@ -314,7 +331,7 @@ void RASearch<SortPolicy, MetricType, TreeType>::Search(
Timer::Start("computing_neighbors");
// Get a reference to the query set.
- const typename TreeType::Mat& querySet = queryTree->Dataset();
+ const MatType& querySet = queryTree->Dataset();
// Make sure we are in dual-tree mode.
if (singleMode || naive)
@@ -324,7 +341,7 @@ void RASearch<SortPolicy, MetricType, TreeType>::Search(
// We won't need to map query indices, but will we need to map distances?
arma::Mat<size_t>* neighborPtr = &neighbors;
- if (treeOwner && tree::TreeTraits<TreeType>::RearrangesDataset)
+ if (treeOwner && tree::TreeTraits<Tree>::RearrangesDataset)
neighborPtr = new arma::Mat<size_t>;
neighborPtr->set_size(k, querySet.n_cols);
@@ -333,19 +350,19 @@ void RASearch<SortPolicy, MetricType, TreeType>::Search(
distances.fill(SortPolicy::WorstDistance());
// Create the helper object for the tree traversal.
- typedef RASearchRules<SortPolicy, MetricType, TreeType> RuleType;
+ typedef RASearchRules<SortPolicy, MetricType, Tree> RuleType;
RuleType rules(referenceSet, queryTree->Dataset(), *neighborPtr, distances,
metric, tau, alpha, naive, sampleAtLeaves, firstLeafExact,
singleSampleLimit, false);
// Create the traverser.
- typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+ typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
traverser.Traverse(*queryTree, *referenceTree);
Timer::Stop("computing_neighbors");
// Do we need to map indices?
- if (treeOwner && tree::TreeTraits<TreeType>::RearrangesDataset)
+ if (treeOwner && tree::TreeTraits<Tree>::RearrangesDataset)
{
// We must map reference indices only.
neighbors.set_size(k, querySet.n_cols);
@@ -360,8 +377,12 @@ void RASearch<SortPolicy, MetricType, TreeType>::Search(
}
}
-template<typename SortPolicy, typename MetricType, typename TreeType>
-void RASearch<SortPolicy, MetricType, TreeType>::Search(
+template<typename SortPolicy,
+ typename MetricType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType>
+void RASearch<SortPolicy, MetricType, MatType, TreeType>::Search(
const size_t k,
arma::Mat<size_t>& neighbors,
arma::mat& distances)
@@ -371,7 +392,7 @@ void RASearch<SortPolicy, MetricType, TreeType>::Search(
arma::Mat<size_t>* neighborPtr = &neighbors;
arma::mat* distancePtr = &distances;
- if (tree::TreeTraits<TreeType>::RearrangesDataset && treeOwner)
+ if (tree::TreeTraits<Tree>::RearrangesDataset && treeOwner)
{
// We will always need to rearrange in this case.
distancePtr = new arma::mat;
@@ -385,7 +406,7 @@ void RASearch<SortPolicy, MetricType, TreeType>::Search(
distancePtr->fill(SortPolicy::WorstDistance());
// Create the helper object for the tree traversal.
- typedef RASearchRules<SortPolicy, MetricType, TreeType> RuleType;
+ typedef RASearchRules<SortPolicy, MetricType, Tree> RuleType;
RuleType rules(referenceSet, referenceSet, *neighborPtr, *distancePtr,
metric, tau, alpha, naive, sampleAtLeaves, firstLeafExact,
singleSampleLimit, true /* sets are the same */);
@@ -394,10 +415,10 @@ void RASearch<SortPolicy, MetricType, TreeType>::Search(
{
// Find how many samples from the reference set we need and sample uniformly
// from the reference set without replacement.
- const size_t numSamples = rules.MinimumSamplesReqd(referenceSet.n_cols, k,
+ const size_t numSamples = RAUtil::MinimumSamplesReqd(referenceSet.n_cols, k,
tau, alpha);
arma::uvec distinctSamples;
- rules.ObtainDistinctSamples(numSamples, referenceSet.n_cols,
+ RAUtil::ObtainDistinctSamples(numSamples, referenceSet.n_cols,
distinctSamples);
// The naive brute-force solution.
@@ -408,7 +429,7 @@ void RASearch<SortPolicy, MetricType, TreeType>::Search(
else if (singleMode)
{
// Create the traverser.
- typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
+ typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
// Now have it traverse for each point.
for (size_t i = 0; i < referenceSet.n_cols; ++i)
@@ -417,7 +438,7 @@ void RASearch<SortPolicy, MetricType, TreeType>::Search(
else
{
// Create the traverser.
- typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+ typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
traverser.Traverse(*referenceTree, *referenceTree);
}
@@ -425,7 +446,7 @@ void RASearch<SortPolicy, MetricType, TreeType>::Search(
Timer::Stop("computing_neighbors");
// Do we need to map the reference indices?
- if (treeOwner && tree::TreeTraits<TreeType>::RearrangesDataset)
+ if (treeOwner && tree::TreeTraits<Tree>::RearrangesDataset)
{
neighbors.set_size(k, referenceSet.n_cols);
distances.set_size(k, referenceSet.n_cols);
@@ -447,9 +468,13 @@ void RASearch<SortPolicy, MetricType, TreeType>::Search(
}
}
-template<typename SortPolicy, typename MetricType, typename TreeType>
-void RASearch<SortPolicy, MetricType, TreeType>::ResetQueryTree(
- TreeType* queryNode) const
+template<typename SortPolicy,
+ typename MetricType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType>
+void RASearch<SortPolicy, MetricType, MatType, TreeType>::ResetQueryTree(
+ Tree* queryNode) const
{
queryNode->Stat().Bound() = SortPolicy::WorstDistance();
queryNode->Stat().NumSamplesMade() = 0;
@@ -459,8 +484,13 @@ void RASearch<SortPolicy, MetricType, TreeType>::ResetQueryTree(
}
// Returns a string representation of the object.
-template<typename SortPolicy, typename MetricType, typename TreeType>
-std::string RASearch<SortPolicy, MetricType, TreeType>::ToString() const
+template<typename SortPolicy,
+ typename MetricType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType>
+std::string RASearch<SortPolicy, MetricType, MatType, TreeType>::ToString()
+ const
{
std::ostringstream convert;
convert << "RASearch [" << this << "]" << std::endl;
@@ -493,12 +523,11 @@ std::string RASearch<SortPolicy, MetricType, TreeType>::ToString() const
else
convert << "false" << std::endl;
convert << " singleSampleLimit: " << singleSampleLimit << std::endl;
- convert << " metric: " << std::endl <<
- mlpack::util::Indent(metric.ToString(),2);
+ convert << " metric: " << std::endl << util::Indent(metric.ToString(), 2);
return convert.str();
}
-}; // namespace neighbor
-}; // namespace mlpack
+} // namespace neighbor
+} // namespace mlpack
#endif
diff --git a/src/mlpack/methods/rann/ra_search_rules.hpp b/src/mlpack/methods/rann/ra_search_rules.hpp
index 2642afe..fe4cafe 100644
--- a/src/mlpack/methods/rann/ra_search_rules.hpp
+++ b/src/mlpack/methods/rann/ra_search_rules.hpp
@@ -10,7 +10,6 @@
#define __MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP
#include "../neighbor_search/ns_traversal_info.hpp"
-#include "ra_search.hpp" // For friend declaration.
namespace mlpack {
namespace neighbor {
@@ -248,47 +247,6 @@ class RASearchRules
const double distance);
/**
- * Compute the minimum number of samples required to guarantee
- * the given rank-approximation and success probability.
- *
- * @param n Size of the set to be sampled from.
- * @param k The number of neighbors required within the rank-approximation.
- * @param tau The rank-approximation in percentile of the data.
- * @param alpha The success probability desired.
- */
- size_t MinimumSamplesReqd(const size_t n,
- const size_t k,
- const double tau,
- const double alpha) const;
-
- /**
- * Compute the success probability of obtaining 'k'-neighbors from a
- * set of size 'n' within the top 't' neighbors if 'm' samples are made.
- *
- * @param n Size of the set being sampled from.
- * @param k The number of neighbors required within the rank-approximation.
- * @param m The number of random samples.
- * @param t The desired rank-approximation.
- */
- double SuccessProbability(const size_t n,
- const size_t k,
- const size_t m,
- const size_t t) const;
-
- /**
- * Pick up desired number of samples (with replacement) from a given range
- * of integers so that only the distinct samples are returned from
- * the range [0 - specified upper bound)
- *
- * @param numSamples Number of random samples.
- * @param rangeUpperBound The upper bound on the range of integers.
- * @param distinctSamples The list of the distinct samples.
- */
- void ObtainDistinctSamples(const size_t numSamples,
- const size_t rangeUpperBound,
- arma::uvec& distinctSamples) const;
-
- /**
* Perform actual scoring for single-tree case.
*/
double Score(const size_t queryIndex,
@@ -303,15 +261,10 @@ class RASearchRules
TreeType& referenceNode,
const double distance,
const double bestDistance);
-
- // So that RASearch can access ObtainDistinctSamples() and
- // MinimumSamplesReqd(). Maybe refactoring is a better solution but this is
- // okay for now.
- friend class RASearch<SortPolicy, MetricType, TreeType>;
}; // class RASearchRules
-}; // namespace neighbor
-}; // namespace mlpack
+} // namespace neighbor
+} // namespace mlpack
// Include implementation.
#include "ra_search_rules_impl.hpp"
diff --git a/src/mlpack/methods/rann/ra_search_rules_impl.hpp b/src/mlpack/methods/rann/ra_search_rules_impl.hpp
index 77a5b0e..88e74eb 100644
--- a/src/mlpack/methods/rann/ra_search_rules_impl.hpp
+++ b/src/mlpack/methods/rann/ra_search_rules_impl.hpp
@@ -57,7 +57,7 @@ RASearchRules(const arma::mat& referenceSet,
<< std::endl;
Timer::Start("computing_number_of_samples_reqd");
- numSamplesReqd = MinimumSamplesReqd(n, k, tau, alpha);
+ numSamplesReqd = RAUtil::MinimumSamplesReqd(n, k, tau, alpha);
Timer::Stop("computing_number_of_samples_reqd");
// Initialize some statistics to be collected during the search.
@@ -74,194 +74,13 @@ RASearchRules(const arma::mat& referenceSet,
for (size_t i = 0; i < querySet.n_cols; ++i)
{
arma::uvec distinctSamples;
- ObtainDistinctSamples(numSamplesReqd, n, distinctSamples);
+ RAUtil::ObtainDistinctSamples(numSamplesReqd, n, distinctSamples);
for (size_t j = 0; j < distinctSamples.n_elem; j++)
BaseCase(i, (size_t) distinctSamples[j]);
}
}
}
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline force_inline
-void RASearchRules<SortPolicy, MetricType, TreeType>::
-ObtainDistinctSamples(const size_t numSamples,
- const size_t rangeUpperBound,
- arma::uvec& distinctSamples) const
-{
- // Keep track of the points that are sampled.
- arma::Col<size_t> sampledPoints;
- sampledPoints.zeros(rangeUpperBound);
-
- for (size_t i = 0; i < numSamples; i++)
- sampledPoints[(size_t) math::RandInt(rangeUpperBound)]++;
-
- distinctSamples = arma::find(sampledPoints > 0);
- return;
-}
-
-
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-size_t RASearchRules<SortPolicy, MetricType, TreeType>::
-MinimumSamplesReqd(const size_t n,
- const size_t k,
- const double tau,
- const double alpha) const
-{
- size_t ub = n; // The upper bound on the binary search.
- size_t lb = k; // The lower bound on the binary search.
- size_t m = lb; // The minimum number of random samples.
-
- // The rank-approximation.
- const size_t t = (size_t) std::ceil(tau * (double) n / 100.0);
-
- double prob;
- Log::Assert(alpha <= 1.0);
-
- // going through all values of sample sizes
- // to find the minimum samples required to satisfy the
- // desired bound
- bool done = false;
-
- // This performs a binary search on the integer values between 'lb = k'
- // and 'ub = n' to find the minimum number of samples 'm' required to obtain
- // the desired success probability 'alpha'.
- do
- {
- prob = SuccessProbability(n, k, m, t);
-
- if (prob > alpha)
- {
- if (prob - alpha < 0.001 || ub < lb + 2) {
- done = true;
- break;
- }
- else
- ub = m;
- }
- else
- {
- if (prob < alpha)
- {
- if (m == lb)
- {
- m++;
- continue;
- }
- else
- lb = m;
- }
- else
- {
- done = true;
- break;
- }
- }
- m = (ub + lb) / 2;
-
- } while (!done);
-
- return (std::min(m + 1, n));
-}
-
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-double RASearchRules<SortPolicy, MetricType, TreeType>::SuccessProbability(
- const size_t n,
- const size_t k,
- const size_t m,
- const size_t t) const
-{
- if (k == 1)
- {
- if (m > n - t)
- return 1.0;
-
- double eps = (double) t / (double) n;
-
- return 1.0 - std::pow(1.0 - eps, (double) m);
-
- } // Faster implementation for topK = 1.
- else
- {
- if (m < k)
- return 0.0;
-
- if (m > n - t + k - 1)
- return 1.0;
-
- double eps = (double) t / (double) n;
- double sum = 0.0;
-
- // The probability that 'k' of the 'm' samples lie within the top 't'
- // of the neighbors is given by:
- // sum_{j = k}^m Choose(m, j) (t/n)^j (1 - t/n)^{m - j}
- // which is also equal to
- // 1 - sum_{j = 0}^{k - 1} Choose(m, j) (t/n)^j (1 - t/n)^{m - j}
- //
- // So this is a m - k term summation or a k term summation. So if
- // m > 2k, do the k term summation, otherwise do the m term summation.
-
- size_t lb;
- size_t ub;
- bool topHalf;
-
- if (2 * k < m)
- {
- // Compute 1 - sum_{j = 0}^{k - 1} Choose(m, j) eps^j (1 - eps)^{m - j}
- // eps = t/n.
- //
- // Choosing 'lb' as 1 and 'ub' as k so as to sum from 1 to (k - 1), and
- // add the term (1 - eps)^m term separately.
- lb = 1;
- ub = k;
- topHalf = true;
- sum = std::pow(1 - eps, (double) m);
- }
- else
- {
- // Compute sum_{j = k}^m Choose(m, j) eps^j (1 - eps)^{m - j}
- // eps = t/n.
- //
- // Choosing 'lb' as k and 'ub' as m so as to sum from k to (m - 1), and
- // add the term eps^m term separately.
- lb = k;
- ub = m;
- topHalf = false;
- sum = std::pow(eps, (double) m);
- }
-
- for (size_t j = lb; j < ub; j++)
- {
- // Compute Choose(m, j).
- double mCj = (double) m;
- size_t jTrans;
-
- // If j < m - j, compute Choose(m, j).
- // If j > m - j, compute Choose(m, m - j).
- if (topHalf)
- jTrans = j;
- else
- jTrans = m - j;
-
- for(size_t i = 2; i <= jTrans; i++)
- {
- mCj *= (double) (m - (i - 1));
- mCj /= (double) i;
- }
-
- sum += (mCj * std::pow(eps, (double) j)
- * std::pow(1.0 - eps, (double) (m - j)));
- }
-
- if (topHalf)
- sum = 1.0 - sum;
-
- return sum;
- } // For k > 1.
-}
-
template<typename SortPolicy, typename MetricType, typename TreeType>
inline force_inline
double RASearchRules<SortPolicy, MetricType, TreeType>::BaseCase(
@@ -359,7 +178,8 @@ inline double RASearchRules<SortPolicy, MetricType, TreeType>::Score(
// Then samplesReqd <= singleSampleLimit.
// Hence, approximate the node by sampling enough number of points.
arma::uvec distinctSamples;
- ObtainDistinctSamples(samplesReqd, referenceNode.NumDescendants(),
+ RAUtil::ObtainDistinctSamples(samplesReqd,
+ referenceNode.NumDescendants(),
distinctSamples);
for (size_t i = 0; i < distinctSamples.n_elem; i++)
// The counting of the samples are done in the 'BaseCase' function
@@ -375,7 +195,8 @@ inline double RASearchRules<SortPolicy, MetricType, TreeType>::Score(
{
// Approximate node by sampling enough number of points.
arma::uvec distinctSamples;
- ObtainDistinctSamples(samplesReqd, referenceNode.NumDescendants(),
+ RAUtil::ObtainDistinctSamples(samplesReqd,
+ referenceNode.NumDescendants(),
distinctSamples);
for (size_t i = 0; i < distinctSamples.n_elem; i++)
// The counting of the samples are done in the 'BaseCase' function
@@ -463,8 +284,8 @@ Rescore(const size_t queryIndex,
// Then, samplesReqd <= singleSampleLimit. Hence, approximate the node
// by sampling enough number of points.
arma::uvec distinctSamples;
- ObtainDistinctSamples(samplesReqd, referenceNode.NumDescendants(),
- distinctSamples);
+ RAUtil::ObtainDistinctSamples(samplesReqd,
+ referenceNode.NumDescendants(), distinctSamples);
for (size_t i = 0; i < distinctSamples.n_elem; i++)
// The counting of the samples are done in the 'BaseCase' function so
// no book-keeping is required here.
@@ -479,8 +300,8 @@ Rescore(const size_t queryIndex,
{
// Approximate node by sampling enough points.
arma::uvec distinctSamples;
- ObtainDistinctSamples(samplesReqd, referenceNode.NumDescendants(),
- distinctSamples);
+ RAUtil::ObtainDistinctSamples(samplesReqd,
+ referenceNode.NumDescendants(), distinctSamples);
for (size_t i = 0; i < distinctSamples.n_elem; i++)
// The counting of the samples are done in the 'BaseCase' function
// so no book-keeping is required here.
@@ -666,8 +487,8 @@ inline double RASearchRules<SortPolicy, MetricType, TreeType>::Score(
{
const size_t queryIndex = queryNode.Descendant(i);
arma::uvec distinctSamples;
- ObtainDistinctSamples(samplesReqd, referenceNode.NumDescendants(),
- distinctSamples);
+ RAUtil::ObtainDistinctSamples(samplesReqd,
+ referenceNode.NumDescendants(), distinctSamples);
for (size_t j = 0; j < distinctSamples.n_elem; j++)
// The counting of the samples are done in the 'BaseCase' function
// so no book-keeping is required here.
@@ -696,8 +517,8 @@ inline double RASearchRules<SortPolicy, MetricType, TreeType>::Score(
{
const size_t queryIndex = queryNode.Descendant(i);
arma::uvec distinctSamples;
- ObtainDistinctSamples(samplesReqd, referenceNode.NumDescendants(),
- distinctSamples);
+ RAUtil::ObtainDistinctSamples(samplesReqd,
+ referenceNode.NumDescendants(), distinctSamples);
for (size_t j = 0; j < distinctSamples.n_elem; j++)
// The counting of the samples are done in the 'BaseCase'
// function so no book-keeping is required here.
@@ -871,8 +692,8 @@ Rescore(TreeType& queryNode,
{
const size_t queryIndex = queryNode.Descendant(i);
arma::uvec distinctSamples;
- ObtainDistinctSamples(samplesReqd, referenceNode.NumDescendants(),
- distinctSamples);
+ RAUtil::ObtainDistinctSamples(samplesReqd,
+ referenceNode.NumDescendants(), distinctSamples);
for (size_t j = 0; j < distinctSamples.n_elem; j++)
// The counting of the samples are done in the 'BaseCase'
// function so no book-keeping is required here.
@@ -900,8 +721,8 @@ Rescore(TreeType& queryNode,
{
const size_t queryIndex = queryNode.Descendant(i);
arma::uvec distinctSamples;
- ObtainDistinctSamples(samplesReqd, referenceNode.NumDescendants(),
- distinctSamples);
+ RAUtil::ObtainDistinctSamples(samplesReqd,
+ referenceNode.NumDescendants(), distinctSamples);
for (size_t j = 0; j < distinctSamples.n_elem; j++)
// The counting of the samples are done in BaseCase() so no
// book-keeping is required here.
More information about the mlpack-git mailing list