[mlpack-svn] r15919 - mlpack/trunk/src/mlpack/methods/rann
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Oct 4 00:30:09 EDT 2013
Author: rcurtin
Date: Fri Oct 4 00:30:09 2013
New Revision: 15919
Log:
Overhaul RASearchRules so that they will work correctly with cover trees.
Modified:
mlpack/trunk/src/mlpack/methods/rann/ra_search_rules_impl.hpp
Modified: mlpack/trunk/src/mlpack/methods/rann/ra_search_rules_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/rann/ra_search_rules_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/rann/ra_search_rules_impl.hpp Fri Oct 4 00:30:09 2013
@@ -165,11 +165,11 @@
template<typename SortPolicy, typename MetricType, typename TreeType>
-double RASearchRules<SortPolicy, MetricType, TreeType>::
-SuccessProbability(const size_t n,
- const size_t k,
- const size_t m,
- const size_t t) const
+double RASearchRules<SortPolicy, MetricType, TreeType>::SuccessProbability(
+ const size_t n,
+ const size_t k,
+ const size_t m,
+ const size_t t) const
{
if (k == 1)
{
@@ -262,8 +262,9 @@
template<typename SortPolicy, typename MetricType, typename TreeType>
inline force_inline
-double RASearchRules<SortPolicy, MetricType, TreeType>::
-BaseCase(const size_t queryIndex, const size_t referenceIndex)
+double RASearchRules<SortPolicy, MetricType, TreeType>::BaseCase(
+ const size_t queryIndex,
+ const size_t referenceIndex)
{
// If the datasets are the same, then this search is only using one dataset
// and we should not return identical points.
@@ -290,32 +291,10 @@
return distance;
}
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double RASearchRules<SortPolicy, MetricType, TreeType>::
-Prescore(TreeType& queryNode,
- TreeType& referenceNode,
- TreeType& referenceChildNode,
- const double baseCaseResult) const
-{
- return 0.0;
-}
-
-
-template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double RASearchRules<SortPolicy, MetricType, TreeType>::
-PrescoreQ(TreeType& queryNode,
- TreeType& queryChildNode,
- TreeType& referenceNode,
- const double baseCaseResult) const
-{
- return 0.0;
-}
-
-
template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double RASearchRules<SortPolicy, MetricType, TreeType>::
-Score(const size_t queryIndex, TreeType& referenceNode)
+inline double RASearchRules<SortPolicy, MetricType, TreeType>::Score(
+ const size_t queryIndex,
+ TreeType& referenceNode)
{
const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
const double distance = SortPolicy::BestPointToNodeDistance(queryPoint,
@@ -326,10 +305,10 @@
}
template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double RASearchRules<SortPolicy, MetricType, TreeType>::
-Score(const size_t queryIndex,
- TreeType& referenceNode,
- const double baseCaseResult)
+inline double RASearchRules<SortPolicy, MetricType, TreeType>::Score(
+ const size_t queryIndex,
+ TreeType& referenceNode,
+ const double baseCaseResult)
{
const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
const double distance = SortPolicy::BestPointToNodeDistance(queryPoint,
@@ -340,11 +319,11 @@
}
template<typename SortPolicy, typename MetricType, typename TreeType>
-inline double RASearchRules<SortPolicy, MetricType, TreeType>::
-Score(const size_t queryIndex,
- TreeType& referenceNode,
- const double distance,
- const double bestDistance)
+inline double RASearchRules<SortPolicy, MetricType, TreeType>::Score(
+ const size_t queryIndex,
+ TreeType& referenceNode,
+ const double distance,
+ const double bestDistance)
{
// If this is better than the best distance we've seen so far, maybe there
// will be something down this node. Also check if enough samples are already
@@ -360,7 +339,7 @@
{
// Check if this node can be approximated by sampling.
size_t samplesReqd = (size_t) std::ceil(samplingRatio *
- (double) referenceNode.Count());
+ (double) referenceNode.NumDescendants());
samplesReqd = std::min(samplesReqd,
numSamplesReqd - numSamplesMade[queryIndex]);
@@ -376,13 +355,12 @@
// Then samplesReqd <= singleSampleLimit.
// Hence, approximate the node by sampling enough number of points.
arma::uvec distinctSamples;
- ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
+ ObtainDistinctSamples(samplesReqd, referenceNode.NumDescendants(),
distinctSamples);
for (size_t i = 0; i < distinctSamples.n_elem; i++)
// The counting of the samples are done in the 'BaseCase' function
// so no book-keeping is required here.
- BaseCase(queryIndex,
- referenceNode.Begin() + (size_t) distinctSamples[i]);
+ BaseCase(queryIndex, referenceNode.Descendant(distinctSamples[i]));
// Node approximated, so we can prune it.
return DBL_MAX;
@@ -393,15 +371,15 @@
{
// Approximate node by sampling enough number of points.
arma::uvec distinctSamples;
- ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
+ ObtainDistinctSamples(samplesReqd, referenceNode.NumDescendants(),
distinctSamples);
for (size_t i = 0; i < distinctSamples.n_elem; i++)
// The counting of the samples are done in the 'BaseCase' function
// so no book-keeping is required here.
BaseCase(queryIndex,
- referenceNode.Begin() + (size_t) distinctSamples[i]);
+ referenceNode.Descendant(distinctSamples[i]));
- // (Leaf) node approximated so can prune it.
+ // (Leaf) node approximated, so we can prune it.
return DBL_MAX;
}
else
@@ -429,8 +407,8 @@
// If enough samples are already made, this step does not change the result
// of the search.
- numSamplesMade[queryIndex] +=
- (size_t) std::floor(samplingRatio * (double) referenceNode.Count());
+ numSamplesMade[queryIndex] += (size_t) std::floor(
+ samplingRatio * (double) referenceNode.NumDescendants());
return DBL_MAX;
}
@@ -464,7 +442,7 @@
// Check if this node can be approximated by sampling.
size_t samplesReqd = (size_t) std::ceil(samplingRatio *
- (double) referenceNode.Count());
+ (double) referenceNode.NumDescendants());
samplesReqd = std::min(samplesReqd, numSamplesReqd -
numSamplesMade[queryIndex]);
@@ -481,13 +459,12 @@
// Then, samplesReqd <= singleSampleLimit. Hence, approximate the node
// by sampling enough number of points.
arma::uvec distinctSamples;
- ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
+ ObtainDistinctSamples(samplesReqd, referenceNode.NumDescendants(),
distinctSamples);
for (size_t i = 0; i < distinctSamples.n_elem; i++)
// The counting of the samples are done in the 'BaseCase' function so
// no book-keeping is required here.
- BaseCase(queryIndex, referenceNode.Begin() + (size_t)
- distinctSamples[i]);
+ BaseCase(queryIndex, referenceNode.Descendant(distinctSamples[i]));
// Node approximated, so we can prune it.
return DBL_MAX;
@@ -498,13 +475,12 @@
{
// Approximate node by sampling enough points.
arma::uvec distinctSamples;
- ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
+ ObtainDistinctSamples(samplesReqd, referenceNode.NumDescendants(),
distinctSamples);
for (size_t i = 0; i < distinctSamples.n_elem; i++)
// The counting of the samples are done in the 'BaseCase' function
// so no book-keeping is required here.
- BaseCase(queryIndex,
- referenceNode.Begin() + (size_t) distinctSamples[i]);
+ BaseCase(queryIndex, referenceNode.Descendant(distinctSamples[i]));
// (Leaf) node approximated, so we can prune it.
return DBL_MAX;
@@ -526,7 +502,7 @@
// these samples need not be computed. If enough samples are already made,
// this step does not change the result of the search.
numSamplesMade[queryIndex] += (size_t) std::floor(samplingRatio *
- (double) referenceNode.Count());
+ (double) referenceNode.NumDescendants());
return DBL_MAX;
}
@@ -589,7 +565,7 @@
for (size_t i = 0; i < queryNode.NumPoints(); i++)
{
const double bound = distances(distances.n_rows - 1, queryNode.Point(i))
- + maxDescendantDistance;
+ + maxDescendantDistance;
if (bound < pointBound)
pointBound = bound;
}
@@ -656,7 +632,7 @@
{
// Check if this node can be approximated by sampling.
size_t samplesReqd = (size_t) std::ceil(samplingRatio
- * (double) referenceNode.Count());
+ * (double) referenceNode.NumDescendants());
samplesReqd = std::min(samplesReqd, numSamplesReqd -
queryNode.Stat().NumSamplesMade());
@@ -682,17 +658,17 @@
{
// Then samplesReqd <= singleSampleLimit. Hence, approximate node by
// sampling enough number of points for every query in the query node.
- for (size_t queryIndex = queryNode.Begin();
- queryIndex < queryNode.End(); queryIndex++)
+ for (size_t i = 0; i < queryNode.NumDescendants(); ++i)
{
+ const size_t queryIndex = queryNode.Descendant(i);
arma::uvec distinctSamples;
- ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
+ ObtainDistinctSamples(samplesReqd, referenceNode.NumDescendants(),
distinctSamples);
- for (size_t i = 0; i < distinctSamples.n_elem; i++)
+ for (size_t j = 0; j < distinctSamples.n_elem; j++)
// The counting of the samples are done in the 'BaseCase' function
// so no book-keeping is required here.
- BaseCase(queryIndex, referenceNode.Begin() + (size_t)
- distinctSamples[i]);
+ BaseCase(queryIndex,
+ referenceNode.Descendant(distinctSamples[j]));
}
// Update the number of samples made for the queryNode and also update
@@ -712,17 +688,17 @@
{
// Approximate node by sampling enough number of points for every
// query in the query node.
- for (size_t queryIndex = queryNode.Begin();
- queryIndex < queryNode.End(); queryIndex++)
+ for (size_t i = 0; i < queryNode.NumDescendants(); ++i)
{
+ const size_t queryIndex = queryNode.Descendant(i);
arma::uvec distinctSamples;
- ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
+ ObtainDistinctSamples(samplesReqd, referenceNode.NumDescendants(),
distinctSamples);
- for (size_t i = 0; i < distinctSamples.n_elem; i++)
+ for (size_t j = 0; j < distinctSamples.n_elem; j++)
// The counting of the samples are done in the 'BaseCase'
// function so no book-keeping is required here.
- BaseCase(queryIndex, referenceNode.Begin() +
- (size_t) distinctSamples[i]);
+ BaseCase(queryIndex,
+ referenceNode.Descendant(distinctSamples[j]));
}
// Update the number of samples made for the queryNode and also
@@ -776,7 +752,7 @@
// this step does not change the result of the search since this queryNode
// will never be descended anymore.
queryNode.Stat().NumSamplesMade() += (size_t) std::floor(samplingRatio *
- (double) referenceNode.Count());
+ (double) referenceNode.NumDescendants());
// Since we are not going to descend down the query tree for this reference
// node, there is no point updating the number of samples made for the child
@@ -803,7 +779,7 @@
for (size_t i = 0; i < queryNode.NumPoints(); i++)
{
const double bound = distances(distances.n_rows - 1, queryNode.Point(i))
- + maxDescendantDistance;
+ + maxDescendantDistance;
if (bound < pointBound)
pointBound = bound;
}
@@ -841,9 +817,8 @@
// The number of samples made for a node is propagated up from the
// child nodes if the child nodes have made samples that the parent
// (which is the current 'queryNode') is not aware of.
- queryNode.Stat().NumSamplesMade()
- = std::max(queryNode.Stat().NumSamplesMade(),
- numSamplesMadeInChildNodes);
+ queryNode.Stat().NumSamplesMade() = std::max(
+ queryNode.Stat().NumSamplesMade(), numSamplesMadeInChildNodes);
}
// Now check if the node-pair interaction can be pruned by sampling
@@ -851,8 +826,8 @@
// If this is better than the best distance we've seen so far,
// maybe there will be something down this node.
// Also check if enough samples are already made for this query.
- if (SortPolicy::IsBetter(oldScore, bestDistance)
- && queryNode.Stat().NumSamplesMade() < numSamplesReqd)
+ if (SortPolicy::IsBetter(oldScore, bestDistance) &&
+ queryNode.Stat().NumSamplesMade() < numSamplesReqd)
{
// We cannot prune this node
// Try approximating this node by sampling
@@ -863,27 +838,27 @@
// So no checks regarding that is made any more.
//
// check if this node can be approximated by sampling
- size_t samplesReqd =
- (size_t) std::ceil(samplingRatio * (double) referenceNode.Count());
- samplesReqd
- = std::min(samplesReqd,
- numSamplesReqd - queryNode.Stat().NumSamplesMade());
+ size_t samplesReqd = (size_t) std::ceil(
+ samplingRatio * (double) referenceNode.NumDescendants());
+ samplesReqd = std::min(samplesReqd,
+ numSamplesReqd - queryNode.Stat().NumSamplesMade());
if (samplesReqd > singleSampleLimit && !referenceNode.IsLeaf())
{
- // if too many samples required and not at a leaf, then can't prune
+ // If too many samples are required and we are not at a leaf, then we
+ // can't prune.
- // Since query tree descend is necessary now,
- // propagate the number of samples made down to the children
+ // Since query tree descent is necessary now, propagate the number of
+ // samples made down to the children.
// Go through all children and propagate the number of
// samples made to the children.
// Only update if the parent node has made samples the children
// have not seen
for (size_t i = 0; i < queryNode.NumChildren(); i++)
- queryNode.Child(i).Stat().NumSamplesMade()
- = std::max(queryNode.Stat().NumSamplesMade(),
- queryNode.Child(i).Stat().NumSamplesMade());
+ queryNode.Child(i).Stat().NumSamplesMade() = std::max(
+ queryNode.Stat().NumSamplesMade(),
+ queryNode.Child(i).Stat().NumSamplesMade());
return oldScore;
}
@@ -894,17 +869,16 @@
// Then samplesReqd <= singleSampleLimit.
// Hence approximate node by sampling enough number of points
// for every query in the 'queryNode'
- for (size_t queryIndex = queryNode.Begin();
- queryIndex < queryNode.End(); queryIndex++)
+ for (size_t i = 0; i < queryNode.NumDescendants(); ++i)
{
+ const size_t queryIndex = queryNode.Descendant(i);
arma::uvec distinctSamples;
- ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
- distinctSamples);
- for (size_t i = 0; i < distinctSamples.n_elem; i++)
+ ObtainDistinctSamples(samplesReqd, referenceNode.NumDescendants(),
+ distinctSamples);
+ for (size_t j = 0; j < distinctSamples.n_elem; j++)
// The counting of the samples are done in the 'BaseCase'
// function so no book-keeping is required here.
- BaseCase(queryIndex,
- referenceNode.Begin() + (size_t) distinctSamples[i]);
+ BaseCase(queryIndex, referenceNode.Descendant(distinctSamples[j]));
}
// update the number of samples made for the queryNode and
@@ -924,17 +898,17 @@
{
// Approximate node by sampling enough number of points
// for every query in the 'queryNode'.
- for (size_t queryIndex = queryNode.Begin();
- queryIndex < queryNode.End(); queryIndex++)
+ for (size_t i = 0; i < queryNode.NumDescendants(); ++i)
{
+ const size_t queryIndex = queryNode.Descendant(i);
arma::uvec distinctSamples;
- ObtainDistinctSamples(samplesReqd, referenceNode.Count(),
+ ObtainDistinctSamples(samplesReqd, referenceNode.NumDescendants(),
distinctSamples);
- for (size_t i = 0; i < distinctSamples.n_elem; i++)
+ for (size_t j = 0; j < distinctSamples.n_elem; j++)
// The counting of the samples are done in the 'BaseCase'
// function so no book-keeping is required here.
BaseCase(queryIndex,
- referenceNode.Begin() + (size_t) distinctSamples[i]);
+ referenceNode.Descendant(distinctSamples[j]));
}
// Update the number of samples made for the query node and also
@@ -972,7 +946,7 @@
// this step does not change the result of the search since this query node
// will never be descended anymore.
queryNode.Stat().NumSamplesMade() += (size_t) std::floor(samplingRatio *
- (double) referenceNode.Count());
+ (double) referenceNode.NumDescendants());
// Since we are not going to descend down the query tree for this reference
// node, there is no point updating the number of samples made for the child
More information about the mlpack-svn mailing list