[mlpack-git] master: Remove specialization of NeighborSearchRules for SpillTrees. (e4ce9be)
gitdub at mlpack.org
gitdub at mlpack.org
Wed Aug 17 02:33:04 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/0f4b25acd6aaa14294c044874ba6cc0751712baa...0a19d07bd39e6223991976474bc79671ba8aa0f0
>---------------------------------------------------------------
commit e4ce9be6795817f7ceebafc26dd2b24c38a356f9
Author: MarcosPividori <marcos.pividori at gmail.com>
Date: Wed Aug 17 03:33:04 2016 -0300
Remove specialization of NeighborSearchRules for SpillTrees.
>---------------------------------------------------------------
e4ce9be6795817f7ceebafc26dd2b24c38a356f9
.../spill_tree/spill_dual_tree_traverser_impl.hpp | 349 ++++++++--------
.../spill_single_tree_traverser_impl.hpp | 22 +-
src/mlpack/methods/neighbor_search/CMakeLists.txt | 2 -
.../neighbor_search/neighbor_search_rules.hpp | 6 +-
.../methods/neighbor_search/spill_search_rules.hpp | 229 -----------
.../neighbor_search/spill_search_rules_impl.hpp | 440 ---------------------
6 files changed, 195 insertions(+), 853 deletions(-)
diff --git a/src/mlpack/core/tree/spill_tree/spill_dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/spill_tree/spill_dual_tree_traverser_impl.hpp
index c6deca2..6853b0d 100644
--- a/src/mlpack/core/tree/spill_tree/spill_dual_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/spill_tree/spill_dual_tree_traverser_impl.hpp
@@ -102,80 +102,91 @@ SpillDualTreeTraverser<RuleType, Defeatist>::Traverse(
}
else if (queryNode.IsLeaf() && (!referenceNode.IsLeaf()))
{
- // We have to recurse down the reference node. In this case the recursion
- // order does matter. Before recursing, though, we have to set the
- // traversal information correctly.
- double leftScore = rule.Score(queryNode, *referenceNode.Left());
- typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo();
- rule.TraversalInfo() = traversalInfo;
- double rightScore = rule.Score(queryNode, *referenceNode.Right());
- numScores += 2;
-
- if (leftScore < rightScore)
+ if (Defeatist && referenceNode.Overlap())
{
- // Recurse to the left. Restore the left traversal info. Store the right
- // traversal info.
- traversalInfo = rule.TraversalInfo();
- rule.TraversalInfo() = leftInfo;
- Traverse(queryNode, *referenceNode.Left());
+ // If referenceNode is an overlapping node, let's do defeatist search.
+ bool traverseLeft = referenceNode.Left()->HalfSpaceIntersects(queryNode);
+ bool traverseRight = referenceNode.Right()->HalfSpaceIntersects(
+ queryNode);
+ if (traverseLeft && !traverseRight)
+ Traverse(queryNode, *referenceNode.Left());
+ else if (!traverseLeft && traverseRight)
+ Traverse(queryNode, *referenceNode.Right());
+ else
+ {
+ // If we can't decide which child node to traverse, this means that
+ // queryNode is on both sides of the splitting hyperplane. So, as
+ // queryNode is a leaf node, all we can do is a single-tree search for
+ // each point in the query node.
+ const size_t queryEnd = queryNode.NumPoints();
+ DefeatistSingleTreeTraverser<RuleType> st(rule);
+ // Loop through each of the points in query node.
+ for (size_t query = 0; query < queryEnd; ++query)
+ {
+ const size_t queryIndex = queryNode.Point(query);
+ // See if we need to investigate this point.
+ const double childScore = rule.Score(queryIndex, referenceNode);
- // Is it still valid to recurse to the right?
- rightScore = rule.Rescore(queryNode, *referenceNode.Right(), rightScore);
+ if (childScore == DBL_MAX)
+ continue; // We can't improve this particular point.
- if (rightScore != DBL_MAX)
- {
- // Restore the right traversal info.
- rule.TraversalInfo() = traversalInfo;
- Traverse(queryNode, *referenceNode.Right());
+ st.Traverse(queryIndex, referenceNode);
+ }
}
- else
- ++numPrunes;
}
- else if (rightScore < leftScore)
+ else
{
- // Recurse to the right.
- Traverse(queryNode, *referenceNode.Right());
-
- // Is it still valid to recurse to the left?
- leftScore = rule.Rescore(queryNode, *referenceNode.Left(), leftScore);
+ // We have to recurse down the reference node. In this case the recursion
+ // order does matter. Before recursing, though, we have to set the
+ // traversal information correctly.
+ double leftScore = rule.Score(queryNode, *referenceNode.Left());
+ typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo();
+ rule.TraversalInfo() = traversalInfo;
+ double rightScore = rule.Score(queryNode, *referenceNode.Right());
+ numScores += 2;
- if (leftScore != DBL_MAX)
+ if (leftScore < rightScore)
{
- // Restore the left traversal info.
+ // Recurse to the left. Restore the left traversal info. Store the right
+ // traversal info.
+ traversalInfo = rule.TraversalInfo();
rule.TraversalInfo() = leftInfo;
Traverse(queryNode, *referenceNode.Left());
+
+ // Is it still valid to recurse to the right?
+ rightScore = rule.Rescore(queryNode, *referenceNode.Right(), rightScore);
+
+ if (rightScore != DBL_MAX)
+ {
+ // Restore the right traversal info.
+ rule.TraversalInfo() = traversalInfo;
+ Traverse(queryNode, *referenceNode.Right());
+ }
+ else
+ ++numPrunes;
}
- else
- ++numPrunes;
- }
- else // leftScore is equal to rightScore.
- {
- if (leftScore == DBL_MAX)
+ else if (rightScore < leftScore)
{
- numPrunes += 2;
+ // Recurse to the right.
+ Traverse(queryNode, *referenceNode.Right());
+
+ // Is it still valid to recurse to the left?
+ leftScore = rule.Rescore(queryNode, *referenceNode.Left(), leftScore);
+
+ if (leftScore != DBL_MAX)
+ {
+ // Restore the left traversal info.
+ rule.TraversalInfo() = leftInfo;
+ Traverse(queryNode, *referenceNode.Left());
+ }
+ else
+ ++numPrunes;
}
- else
+ else // leftScore is equal to rightScore.
{
- if (Defeatist && referenceNode.Overlap())
+ if (leftScore == DBL_MAX)
{
- // If referenceNode is a overlapping node and we can't decide which
- // child node to traverse, this means that queryNode is at both sides
- // of the splitting hyperplane. So, as queryNode is a leafNode, all we
- // can do is single tree search for each point in the query node.
- const size_t queryEnd = queryNode.NumPoints();
- SingleTreeTraverser<RuleType> st(rule);
- // Loop through each of the points in query node.
- for (size_t query = 0; query < queryEnd; ++query)
- {
- const size_t queryIndex = queryNode.Point(query);
- // See if we need to investigate this point.
- const double childScore = rule.Score(queryIndex, referenceNode);
-
- if (childScore == DBL_MAX)
- continue; // We can't improve this particular point.
-
- st.Traverse(queryIndex, referenceNode);
- }
+ numPrunes += 2;
}
else
{
@@ -202,71 +213,98 @@ SpillDualTreeTraverser<RuleType, Defeatist>::Traverse(
}
else
{
- // We have to recurse down both query and reference nodes. Because the
- // query descent order does not matter, we will go to the left query child
- // first. Before recursing, we have to set the traversal information
- // correctly.
- double leftScore = rule.Score(*queryNode.Left(), *referenceNode.Left());
- typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo();
- rule.TraversalInfo() = traversalInfo;
- double rightScore = rule.Score(*queryNode.Left(), *referenceNode.Right());
- typename RuleType::TraversalInfoType rightInfo;
- numScores += 2;
-
- if (leftScore < rightScore)
+ if (Defeatist && referenceNode.Overlap())
{
- // Recurse to the left. Restore the left traversal info. Store the right
- // traversal info.
- rightInfo = rule.TraversalInfo();
- rule.TraversalInfo() = leftInfo;
- Traverse(*queryNode.Left(), *referenceNode.Left());
-
- // Is it still valid to recurse to the right?
- rightScore = rule.Rescore(*queryNode.Left(), *referenceNode.Right(),
- rightScore);
-
- if (rightScore != DBL_MAX)
- {
- // Restore the right traversal info.
- rule.TraversalInfo() = rightInfo;
+ // If referenceNode is an overlapping node, let's do defeatist search.
+ bool traverseLeft = referenceNode.Left()->HalfSpaceIntersects(
+ *queryNode.Left());
+ bool traverseRight = referenceNode.Right()->HalfSpaceIntersects(
+ *queryNode.Left());
+ if (traverseLeft && !traverseRight)
+ Traverse(*queryNode.Left(), *referenceNode.Left());
+ else if (!traverseLeft && traverseRight)
Traverse(*queryNode.Left(), *referenceNode.Right());
+ else
+ {
+ // If we can't decide which child node to traverse, this means that
+ // queryNode.Left() is on both sides of the splitting hyperplane. So,
+ // let's recurse down only the query node.
+ Traverse(*queryNode.Left(), referenceNode);
}
+
+ traverseLeft = referenceNode.Left()->HalfSpaceIntersects(
+ *queryNode.Right());
+ traverseRight = referenceNode.Right()->HalfSpaceIntersects(
+ *queryNode.Right());
+ if (traverseLeft && !traverseRight)
+ Traverse(*queryNode.Right(), *referenceNode.Left());
+ else if (!traverseLeft && traverseRight)
+ Traverse(*queryNode.Right(), *referenceNode.Right());
else
- ++numPrunes;
+ {
+ // If we can't decide which child node to traverse, this means that
+ // queryNode.Right() is on both sides of the splitting hyperplane. So,
+ // let's recurse down only the query node.
+ Traverse(*queryNode.Right(), referenceNode);
+ }
}
- else if (rightScore < leftScore)
+ else
{
- // Recurse to the right.
- Traverse(*queryNode.Left(), *referenceNode.Right());
-
- // Is it still valid to recurse to the left?
- leftScore = rule.Rescore(*queryNode.Left(), *referenceNode.Left(),
- leftScore);
+ // We have to recurse down both query and reference nodes. Because the
+ // query descent order does not matter, we will go to the left query child
+ // first. Before recursing, we have to set the traversal information
+ // correctly.
+ double leftScore = rule.Score(*queryNode.Left(), *referenceNode.Left());
+ typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo();
+ rule.TraversalInfo() = traversalInfo;
+ double rightScore = rule.Score(*queryNode.Left(), *referenceNode.Right());
+ typename RuleType::TraversalInfoType rightInfo;
+ numScores += 2;
- if (leftScore != DBL_MAX)
+ if (leftScore < rightScore)
{
- // Restore the left traversal info.
+ // Recurse to the left. Restore the left traversal info. Store the right
+ // traversal info.
+ rightInfo = rule.TraversalInfo();
rule.TraversalInfo() = leftInfo;
Traverse(*queryNode.Left(), *referenceNode.Left());
+
+ // Is it still valid to recurse to the right?
+ rightScore = rule.Rescore(*queryNode.Left(), *referenceNode.Right(),
+ rightScore);
+
+ if (rightScore != DBL_MAX)
+ {
+ // Restore the right traversal info.
+ rule.TraversalInfo() = rightInfo;
+ Traverse(*queryNode.Left(), *referenceNode.Right());
+ }
+ else
+ ++numPrunes;
}
- else
- ++numPrunes;
- }
- else
- {
- if (leftScore == DBL_MAX)
+ else if (rightScore < leftScore)
{
- numPrunes += 2;
+ // Recurse to the right.
+ Traverse(*queryNode.Left(), *referenceNode.Right());
+
+ // Is it still valid to recurse to the left?
+ leftScore = rule.Rescore(*queryNode.Left(), *referenceNode.Left(),
+ leftScore);
+
+ if (leftScore != DBL_MAX)
+ {
+ // Restore the left traversal info.
+ rule.TraversalInfo() = leftInfo;
+ Traverse(*queryNode.Left(), *referenceNode.Left());
+ }
+ else
+ ++numPrunes;
}
else
{
- if (Defeatist && referenceNode.Overlap())
+ if (leftScore == DBL_MAX)
{
- // If referenceNode is a overlapping node and we can't decide which
- // child node to traverse, this means that queryNode.Left() is at both
- // sides of the splitting hyperplane. So, let's recurse down only the
- // query node.
- Traverse(*queryNode.Left(), referenceNode);
+ numPrunes += 2;
}
else
{
@@ -290,72 +328,61 @@ SpillDualTreeTraverser<RuleType, Defeatist>::Traverse(
++numPrunes;
}
}
- }
- // Restore the main traversal information.
- rule.TraversalInfo() = traversalInfo;
+ // Restore the main traversal information.
+ rule.TraversalInfo() = traversalInfo;
- // Now recurse down the right query node.
- leftScore = rule.Score(*queryNode.Right(), *referenceNode.Left());
- leftInfo = rule.TraversalInfo();
- rule.TraversalInfo() = traversalInfo;
- rightScore = rule.Score(*queryNode.Right(), *referenceNode.Right());
- numScores += 2;
+ // Now recurse down the right query node.
+ leftScore = rule.Score(*queryNode.Right(), *referenceNode.Left());
+ leftInfo = rule.TraversalInfo();
+ rule.TraversalInfo() = traversalInfo;
+ rightScore = rule.Score(*queryNode.Right(), *referenceNode.Right());
+ numScores += 2;
- if (leftScore < rightScore)
- {
- // Recurse to the left. Restore the left traversal info. Store the right
- // traversal info.
- rightInfo = rule.TraversalInfo();
- rule.TraversalInfo() = leftInfo;
- Traverse(*queryNode.Right(), *referenceNode.Left());
+ if (leftScore < rightScore)
+ {
+ // Recurse to the left. Restore the left traversal info. Store the right
+ // traversal info.
+ rightInfo = rule.TraversalInfo();
+ rule.TraversalInfo() = leftInfo;
+ Traverse(*queryNode.Right(), *referenceNode.Left());
- // Is it still valid to recurse to the right?
- rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(),
- rightScore);
+ // Is it still valid to recurse to the right?
+ rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(),
+ rightScore);
- if (rightScore != DBL_MAX)
+ if (rightScore != DBL_MAX)
+ {
+ // Restore the right traversal info.
+ rule.TraversalInfo() = rightInfo;
+ Traverse(*queryNode.Right(), *referenceNode.Right());
+ }
+ else
+ ++numPrunes;
+ }
+ else if (rightScore < leftScore)
{
- // Restore the right traversal info.
- rule.TraversalInfo() = rightInfo;
+ // Recurse to the right.
Traverse(*queryNode.Right(), *referenceNode.Right());
- }
- else
- ++numPrunes;
- }
- else if (rightScore < leftScore)
- {
- // Recurse to the right.
- Traverse(*queryNode.Right(), *referenceNode.Right());
- // Is it still valid to recurse to the left?
- leftScore = rule.Rescore(*queryNode.Right(), *referenceNode.Left(),
- leftScore);
+ // Is it still valid to recurse to the left?
+ leftScore = rule.Rescore(*queryNode.Right(), *referenceNode.Left(),
+ leftScore);
- if (leftScore != DBL_MAX)
- {
- // Restore the left traversal info.
- rule.TraversalInfo() = leftInfo;
- Traverse(*queryNode.Right(), *referenceNode.Left());
- }
- else
- ++numPrunes;
- }
- else
- {
- if (leftScore == DBL_MAX)
- {
- numPrunes += 2;
+ if (leftScore != DBL_MAX)
+ {
+ // Restore the left traversal info.
+ rule.TraversalInfo() = leftInfo;
+ Traverse(*queryNode.Right(), *referenceNode.Left());
+ }
+ else
+ ++numPrunes;
}
else
{
- if (Defeatist && referenceNode.Overlap())
+ if (leftScore == DBL_MAX)
{
- // If referenceNode is a overlapping node and we can't decide which
- // child node to traverse, this means that queryNode.Right() is at
- // both sides of the splitting hyperplane. So, let's recurse down only
- // the query node.
- Traverse(*queryNode.Right(), referenceNode);
+ numPrunes += 2;
}
else
{
diff --git a/src/mlpack/core/tree/spill_tree/spill_single_tree_traverser_impl.hpp b/src/mlpack/core/tree/spill_tree/spill_single_tree_traverser_impl.hpp
index 4ce5a22..0d28969 100644
--- a/src/mlpack/core/tree/spill_tree/spill_single_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/spill_tree/spill_single_tree_traverser_impl.hpp
@@ -53,27 +53,13 @@ SpillSingleTreeTraverser<RuleType, Defeatist>::Traverse(
{
if (Defeatist && referenceNode.Overlap())
{
- // If referenceNode is a overlapping node we do defeatist search. In this
- // case, it is enough to calculate the score of only one child node. As we
- // know that the query point can't be at both sides of the splitting
- // hyperplane, the possible scores for the references child nodes are:
- // 0 or DBL_MAX.
- double leftScore = rule.Score(queryIndex, *referenceNode.Left());
-
- if (leftScore == 0)
- {
- // Recurse to the left.
+ // If referenceNode is an overlapping node we do defeatist search.
+ if (referenceNode.Left()->HalfSpaceContains(
+ rule.QuerySet().col(queryIndex)))
Traverse(queryIndex, *referenceNode.Left());
- // Prune the right node.
- ++numPrunes;
- }
else
- {
- // Recurse to the right.
Traverse(queryIndex, *referenceNode.Right());
- // Prune the left node.
- ++numPrunes;
- }
+ ++numPrunes;
}
else
{
diff --git a/src/mlpack/methods/neighbor_search/CMakeLists.txt b/src/mlpack/methods/neighbor_search/CMakeLists.txt
index 0903508..e66b5a7 100644
--- a/src/mlpack/methods/neighbor_search/CMakeLists.txt
+++ b/src/mlpack/methods/neighbor_search/CMakeLists.txt
@@ -14,8 +14,6 @@ set(SOURCES
sort_policies/furthest_neighbor_sort_impl.hpp
spill_search.hpp
spill_search_impl.hpp
- spill_search_rules.hpp
- spill_search_rules_impl.hpp
typedef.hpp
unmap.hpp
unmap.cpp
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
index b16a9ea..e7a7ce1 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
@@ -134,6 +134,9 @@ class NeighborSearchRules
//! Modify the traversal info.
TraversalInfoType& TraversalInfo() { return traversalInfo; }
+ //! Access the query set.
+ const typename TreeType::Mat& QuerySet() { return querySet; }
+
protected:
//! The reference set.
const typename TreeType::Mat& referenceSet;
@@ -210,7 +213,4 @@ class NeighborSearchRules
// Include implementation.
#include "neighbor_search_rules_impl.hpp"
-// Include specialization for Spill Trees.
-#include "spill_search_rules.hpp"
-
#endif // MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_RULES_HPP
diff --git a/src/mlpack/methods/neighbor_search/spill_search_rules.hpp b/src/mlpack/methods/neighbor_search/spill_search_rules.hpp
deleted file mode 100644
index a87578c..0000000
--- a/src/mlpack/methods/neighbor_search/spill_search_rules.hpp
+++ /dev/null
@@ -1,229 +0,0 @@
-/**
- * @file spill_search_rules.hpp
- * @author Ryan Curtin
- * @author Marcos Pividori
- *
- * Defines the pruning rules and base case rules necessary to perform a
- * tree-based search with Spill Trees for the NeighborSearch class.
- */
-#ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_SPILL_SEARCH_RULES_HPP
-#define MLPACK_METHODS_NEIGHBOR_SEARCH_SPILL_SEARCH_RULES_HPP
-
-#include <mlpack/core/tree/traversal_info.hpp>
-#include <mlpack/core/tree/spill_tree.hpp>
-
-namespace mlpack {
-namespace neighbor {
-
-/**
- * NeighborSearchRules specialization for Spill Trees.
- * The main difference with the general implementation is that Score() methods
- * consider the special case of a overlapping node.
- * Also, CalculateBound() only considers B_1 bound, because we can not use B_2
- * with spill trees.
- *
- * @tparam SortPolicy The sort policy for distances.
- * @tparam MetricType The metric to use for computation.
- * @tparam TreeType The tree type to use; must adhere to the TreeType API.
- */
-template<typename StatisticType,
- typename MatType,
- template<typename HyperplaneMetricType> class HyperplaneType,
- template<typename SplitBoundT, typename SplitMatT> class SplitType,
- typename SortPolicy,
- typename MetricType>
-class NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<MetricType,
- StatisticType, MatType, HyperplaneType, SplitType>>
-{
- typedef tree::SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
- SplitType> TreeType;
- public:
- /**
- * Construct the NeighborSearchRules object. This is usually done from within
- * the NeighborSearch class at search time.
- *
- * @param referenceSet Set of reference data.
- * @param querySet Set of query data.
- * @param k Number of neighbors to search for.
- * @param metric Instantiated metric.
- * @param epsilon Relative approximate error.
- * @param sameSet If true, the query and reference set are taken to be the
- * same, and a query point will not return itself in the results.
- */
- NeighborSearchRules(const typename TreeType::Mat& referenceSet,
- const typename TreeType::Mat& querySet,
- const size_t k,
- MetricType& metric,
- const double epsilon = 0,
- const bool sameSet = false);
-
- /**
- * Store the list of candidates for each query point in the given matrices.
- *
- * @param neighbors Matrix storing lists of neighbors for each query point.
- * @param distances Matrix storing distances of neighbors for each query
- * point.
- */
- void GetResults(arma::Mat<size_t>& neighbors, arma::mat& distances);
-
- /**
- * Get the distance from the query point to the reference point.
- * This will update the list of candidates with the new point if appropriate
- * and will track the number of base cases (number of points evaluated).
- *
- * @param queryIndex Index of query point.
- * @param referenceIndex Index of reference point.
- */
- double BaseCase(const size_t queryIndex, const size_t referenceIndex);
-
- /**
- * Get the score for recursion order. It implements a Hybrid sp-tree
- * search. If referenceNode's parent is a overlapping node, the score is
- * calculated based on the splitting hyperplane: if query point is on the same
- * side, returns 0, else DBL_MAX.
- * If referenceNode's parent is a non-overlapping node, proper score is
- * calculated, similar to the general Score() method.
- *
- * @param queryIndex Index of query point.
- * @param referenceNode Candidate node to be recursed into.
- */
- double Score(const size_t queryIndex, TreeType& referenceNode);
-
- /**
- * Re-evaluate the score for recursion order. A low score indicates priority
- * for recursion, while DBL_MAX indicates that the node should not be recursed
- * into at all (it should be pruned). This is used when the score has already
- * been calculated, but another recursion may have modified the bounds for
- * pruning. So the old score is checked against the new pruning bound.
- *
- * @param queryIndex Index of query point.
- * @param referenceNode Candidate node to be recursed into.
- * @param oldScore Old score produced by Score() (or Rescore()).
- */
- double Rescore(const size_t queryIndex,
- TreeType& referenceNode,
- const double oldScore) const;
-
- /**
- * Get the score for recursion order. It implements a Hybrid sp-tree
- * search. If referenceNode's parent is a overlapping node, the score is
- * calculated based on the splitting hyperplane: if queryNode's bound
- * intersects the referenceNode's half space, returns 0, else DBL_MAX.
- * If referenceNode's parent is a non-overlapping node, proper score is
- * calculated, similar to the general Score() method.
- *
- * @param queryNode Candidate query node to recurse into.
- * @param referenceNode Candidate reference node to recurse into.
- */
- double Score(TreeType& queryNode, TreeType& referenceNode);
-
- /**
- * Re-evaluate the score for recursion order. A low score indicates priority
- * for recursion, while DBL_MAX indicates that the node should not be recursed
- * into at all (it should be pruned). This is used when the score has already
- * been calculated, but another recursion may have modified the bounds for
- * pruning. So the old score is checked against the new pruning bound.
- *
- * @param queryNode Candidate query node to recurse into.
- * @param referenceNode Candidate reference node to recurse into.
- * @param oldScore Old score produced by Socre() (or Rescore()).
- */
- double Rescore(TreeType& queryNode,
- TreeType& referenceNode,
- const double oldScore) const;
-
- //! Get the number of base cases that have been performed.
- size_t BaseCases() const { return baseCases; }
- //! Modify the number of base cases that have been performed.
- size_t& BaseCases() { return baseCases; }
-
- //! Get the number of scores that have been performed.
- size_t Scores() const { return scores; }
- //! Modify the number of scores that have been performed.
- size_t& Scores() { return scores; }
-
- //! Convenience typedef.
- typedef typename tree::TraversalInfo<TreeType> TraversalInfoType;
-
- //! Get the traversal info.
- const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
- //! Modify the traversal info.
- TraversalInfoType& TraversalInfo() { return traversalInfo; }
-
- protected:
- //! The reference set.
- const typename TreeType::Mat& referenceSet;
-
- //! The query set.
- const typename TreeType::Mat& querySet;
-
- //! Candidate represents a possible candidate neighbor (distance, index).
- typedef std::pair<double, size_t> Candidate;
-
- //! Compare two candidates based on the distance.
- struct CandidateCmp {
- bool operator()(const Candidate& c1, const Candidate& c2)
- {
- return !SortPolicy::IsBetter(c2.first, c1.first);
- };
- };
-
- //! Use a priority queue to represent the list of candidate neighbors.
- typedef std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp>
- CandidateList;
-
- //! Set of candidate neighbors for each point.
- std::vector<CandidateList> candidates;
-
- //! Number of neighbors to search for.
- const size_t k;
-
- //! The instantiated metric.
- MetricType& metric;
-
- //! Denotes whether or not the reference and query sets are the same.
- bool sameSet;
-
- //! Relative error to be considered in approximate search.
- const double epsilon;
-
- //! The last query point BaseCase() was called with.
- size_t lastQueryIndex;
- //! The last reference point BaseCase() was called with.
- size_t lastReferenceIndex;
- //! The last base case result.
- double lastBaseCase;
-
- //! The number of base cases that have been performed.
- size_t baseCases;
- //! The number of scores that have been performed.
- size_t scores;
-
- //! Traversal info for the parent combination; this is updated by the
- //! traversal before each call to Score().
- TraversalInfoType traversalInfo;
-
- /**
- * Recalculate the bound for a given query node.
- */
- double CalculateBound(TreeType& queryNode) const;
-
- /**
- * Helper function to insert a point into the list of candidate points.
- *
- * @param queryIndex Index of point whose neighbors we are inserting into.
- * @param neighbor Index of reference point which is being inserted.
- * @param distance Distance from query point to reference point.
- */
- void InsertNeighbor(const size_t queryIndex,
- const size_t neighbor,
- const double distance);
-};
-
-} // namespace neighbor
-} // namespace mlpack
-
-// Include implementation.
-#include "spill_search_rules_impl.hpp"
-
-#endif // MLPACK_METHODS_NEIGHBOR_SEARCH_SPILL_SEARCH_RULES_HPP
diff --git a/src/mlpack/methods/neighbor_search/spill_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/spill_search_rules_impl.hpp
deleted file mode 100644
index 50a61ac..0000000
--- a/src/mlpack/methods/neighbor_search/spill_search_rules_impl.hpp
+++ /dev/null
@@ -1,440 +0,0 @@
-/**
- * @file spill_search_rules_impl.hpp
- * @author Ryan Curtin
- * @author Marcos Pividori
- *
- * Implementation of NeighborSearchRules for Spill Trees.
- */
-#ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_SPILL_SEARCH_RULES_IMPL_HPP
-#define MLPACK_METHODS_NEIGHBOR_SEARCH_SPILL_SEARCH_RULES_IMPL_HPP
-
-// In case it hasn't been included yet.
-#include "spill_search_rules.hpp"
-
-namespace mlpack {
-namespace neighbor {
-
-template<typename StatisticType,
- typename MatType,
- template<typename HyperplaneMetricType> class HyperplaneType,
- template<typename SplitBoundT, typename SplitMatT> class SplitType,
- typename SortPolicy,
- typename MetricType>
-NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<MetricType,
- StatisticType, MatType, HyperplaneType, SplitType>>::NeighborSearchRules(
- const typename TreeType::Mat& referenceSet,
- const typename TreeType::Mat& querySet,
- const size_t k,
- MetricType& metric,
- const double epsilon,
- const bool sameSet) :
- referenceSet(referenceSet),
- querySet(querySet),
- k(k),
- metric(metric),
- sameSet(sameSet),
- epsilon(epsilon),
- lastQueryIndex(querySet.n_cols),
- lastReferenceIndex(referenceSet.n_cols),
- baseCases(0),
- scores(0)
-{
- // We must set the traversal info last query and reference node pointers to
- // something that is both invalid (i.e. not a tree node) and not NULL. We'll
- // use the this pointer.
- traversalInfo.LastQueryNode() = (TreeType*) this;
- traversalInfo.LastReferenceNode() = (TreeType*) this;
-
- // Let's build the list of candidate neighbors for each query point.
- // It will be initialized with k candidates: (WorstDistance, size_t() - 1)
- // The list of candidates will be updated when visiting new points with the
- // BaseCase() method.
- const Candidate def = std::make_pair(SortPolicy::WorstDistance(),
- size_t() - 1);
-
- std::vector<Candidate> vect(k, def);
- CandidateList pqueue(CandidateCmp(), std::move(vect));
-
- candidates.reserve(querySet.n_cols);
- for (size_t i = 0; i < querySet.n_cols; i++)
- candidates.push_back(pqueue);
-}
-
-template<typename StatisticType,
- typename MatType,
- template<typename HyperplaneMetricType> class HyperplaneType,
- template<typename SplitBoundT, typename SplitMatT> class SplitType,
- typename SortPolicy,
- typename MetricType>
-void NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<MetricType,
- StatisticType, MatType, HyperplaneType, SplitType>>::GetResults(
- arma::Mat<size_t>& neighbors,
- arma::mat& distances)
-{
- neighbors.set_size(k, querySet.n_cols);
- distances.set_size(k, querySet.n_cols);
-
- for (size_t i = 0; i < querySet.n_cols; i++)
- {
- CandidateList& pqueue = candidates[i];
- for (size_t j = 1; j <= k; j++)
- {
- neighbors(k - j, i) = pqueue.top().second;
- distances(k - j, i) = pqueue.top().first;
- pqueue.pop();
- }
- }
-};
-
-template<typename StatisticType,
- typename MatType,
- template<typename HyperplaneMetricType> class HyperplaneType,
- template<typename SplitBoundT, typename SplitMatT> class SplitType,
- typename SortPolicy,
- typename MetricType>
-inline force_inline // Absolutely MUST be inline so optimizations can happen.
-double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<MetricType,
- StatisticType, MatType, HyperplaneType, SplitType>>::BaseCase(
- const size_t queryIndex,
- const size_t referenceIndex)
-{
- // If the datasets are the same, then this search is only using one dataset
- // and we should not return identical points.
- if (sameSet && (queryIndex == referenceIndex))
- return 0.0;
-
- double distance = metric.Evaluate(querySet.col(queryIndex),
- referenceSet.col(referenceIndex));
- ++baseCases;
-
- InsertNeighbor(queryIndex, referenceIndex, distance);
-
- return distance;
-}
-
-template<typename StatisticType,
- typename MatType,
- template<typename HyperplaneMetricType> class HyperplaneType,
- template<typename SplitBoundT, typename SplitMatT> class SplitType,
- typename SortPolicy,
- typename MetricType>
-inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
- MetricType, StatisticType, MatType, HyperplaneType, SplitType>>::Score(
- const size_t queryIndex,
- TreeType& referenceNode)
-{
- ++scores; // Count number of Score() calls.
-
- if (!referenceNode.Parent())
- return 0;
-
- if (referenceNode.Parent()->Overlap()) // Defeatist search.
- {
- if (referenceNode.HalfSpaceContains(querySet.col(queryIndex)))
- return 0;
- else
- return DBL_MAX;
- }
-
- double distance = SortPolicy::BestPointToNodeDistance(
- querySet.col(queryIndex), &referenceNode);
-
- // Compare against the best k'th distance for this query point so far.
- double bestDistance = candidates[queryIndex].top().first;
- bestDistance = SortPolicy::Relax(bestDistance, epsilon);
-
- return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
-}
-
-template<typename StatisticType,
- typename MatType,
- template<typename HyperplaneMetricType> class HyperplaneType,
- template<typename SplitBoundT, typename SplitMatT> class SplitType,
- typename SortPolicy,
- typename MetricType>
-inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
- MetricType, StatisticType, MatType, HyperplaneType, SplitType>>::Rescore(
- const size_t queryIndex,
- TreeType& /* referenceNode */,
- double oldScore) const
-{
- // If we are already pruning, still prune.
- if (oldScore == DBL_MAX)
- return oldScore;
-
- // Just check the score again against the distances.
- double bestDistance = candidates[queryIndex].top().first;
- bestDistance = SortPolicy::Relax(bestDistance, epsilon);
-
- return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
-}
-
-template<typename StatisticType,
- typename MatType,
- template<typename HyperplaneMetricType> class HyperplaneType,
- template<typename SplitBoundT, typename SplitMatT> class SplitType,
- typename SortPolicy,
- typename MetricType>
-inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
- MetricType, StatisticType, MatType, HyperplaneType, SplitType>>::Score(
- TreeType& queryNode,
- TreeType& referenceNode)
-{
- ++scores; // Count number of Score() calls
-
- if (!referenceNode.Parent())
- return 0;
-
- if (referenceNode.Parent()->Overlap()) // Defeatist search.
- {
- if (referenceNode.HalfSpaceIntersects(queryNode))
- return 0;
- else
- return DBL_MAX;
- }
-
- // Update our bound.
- const double bestDistance = CalculateBound(queryNode);
-
- // Use the traversal info to see if a parent-child or parent-parent prune is
- // possible. This is a looser bound than we could make, but it might be
- // sufficient.
- const double queryParentDist = queryNode.ParentDistance();
- const double queryDescDist = queryNode.FurthestDescendantDistance();
- const double refParentDist = referenceNode.ParentDistance();
- const double refDescDist = referenceNode.FurthestDescendantDistance();
- const double score = traversalInfo.LastScore();
- double adjustedScore;
-
- if (score == 0.0) // Nothing we can do here.
- {
- adjustedScore = 0.0;
- }
- else
- {
- // The last score is equal to the distance between the centroids minus the
- // radii of the query and reference bounds along the axis of the line
- // between the two centroids. In the best case, these radii are the
- // furthest descendant distances, but that is not always true. It would
- // take too long to calculate the exact radii, so we are forced to use
- // MinimumBoundDistance() as a lower-bound approximation.
- const double lastQueryDescDist =
- traversalInfo.LastQueryNode()->MinimumBoundDistance();
- const double lastRefDescDist =
- traversalInfo.LastReferenceNode()->MinimumBoundDistance();
- adjustedScore = SortPolicy::CombineWorst(score, lastQueryDescDist);
- adjustedScore = SortPolicy::CombineWorst(adjustedScore, lastRefDescDist);
- }
-
- // Assemble an adjusted score. For nearest neighbor search, this adjusted
- // score is a lower bound on MinDistance(queryNode, referenceNode) that is
- // assembled without actually calculating MinDistance(). For furthest
- // neighbor search, it is an upper bound on
- // MaxDistance(queryNode, referenceNode). If the traversalInfo isn't usable
- // then the node should not be pruned by this.
- if (traversalInfo.LastQueryNode() == queryNode.Parent())
- {
- const double queryAdjust = queryParentDist + queryDescDist;
- adjustedScore = SortPolicy::CombineBest(adjustedScore, queryAdjust);
- }
- else if (traversalInfo.LastQueryNode() == &queryNode)
- {
- adjustedScore = SortPolicy::CombineBest(adjustedScore, queryDescDist);
- }
- else
- {
- // The last query node wasn't this query node or its parent. So we force
- // the adjustedScore to be such that this combination can't be pruned here,
- // because we don't really know anything about it.
-
- // It would be possible to modify this section to try and make a prune based
- // on the query descendant distance and the distance between the query node
- // and last traversal query node, but this case doesn't actually happen for
- // kd-trees or cover trees.
- adjustedScore = SortPolicy::BestDistance();
- }
-
- if (traversalInfo.LastReferenceNode() == referenceNode.Parent())
- {
- const double refAdjust = refParentDist + refDescDist;
- adjustedScore = SortPolicy::CombineBest(adjustedScore, refAdjust);
- }
- else if (traversalInfo.LastReferenceNode() == &referenceNode)
- {
- adjustedScore = SortPolicy::CombineBest(adjustedScore, refDescDist);
- }
- else
- {
- // The last reference node wasn't this reference node or its parent. So we
- // force the adjustedScore to be such that this combination can't be pruned
- // here, because we don't really know anything about it.
-
- // It would be possible to modify this section to try and make a prune based
- // on the reference descendant distance and the distance between the
- // reference node and last traversal reference node, but this case doesn't
- // actually happen for kd-trees or cover trees.
- adjustedScore = SortPolicy::BestDistance();
- }
-
- // Can we prune?
- if (!SortPolicy::IsBetter(adjustedScore, bestDistance))
- {
- // There isn't any need to set the traversal information because no
- // descendant combinations will be visited, and those are the only
- // combinations that would depend on the traversal information.
- return DBL_MAX;
- }
-
- double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
- &referenceNode);
-
- if (SortPolicy::IsBetter(distance, bestDistance))
- {
- // Set traversal information.
- traversalInfo.LastQueryNode() = &queryNode;
- traversalInfo.LastReferenceNode() = &referenceNode;
- traversalInfo.LastScore() = distance;
-
- return distance;
- }
- else
- {
- // There isn't any need to set the traversal information because no
- // descendant combinations will be visited, and those are the only
- // combinations that would depend on the traversal information.
- return DBL_MAX;
- }
-}
-
-template<typename StatisticType,
- typename MatType,
- template<typename HyperplaneMetricType> class HyperplaneType,
- template<typename SplitBoundT, typename SplitMatT> class SplitType,
- typename SortPolicy,
- typename MetricType>
-inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
- MetricType, StatisticType, MatType, HyperplaneType, SplitType>>::Rescore(
- TreeType& queryNode,
- TreeType& /* referenceNode */,
- const double oldScore) const
-{
- if (oldScore == DBL_MAX)
- return oldScore;
-
- if (oldScore == SortPolicy::BestDistance())
- return oldScore;
-
- // Update our bound.
- const double bestDistance = CalculateBound(queryNode);
-
- return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
-}
-
-// Calculate the bound for a given query node in its current state and update
-// it.
-template<typename StatisticType,
- typename MatType,
- template<typename HyperplaneMetricType> class HyperplaneType,
- template<typename SplitBoundT, typename SplitMatT> class SplitType,
- typename SortPolicy,
- typename MetricType>
-inline double NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
- MetricType, StatisticType, MatType, HyperplaneType, SplitType>>::
- CalculateBound(TreeType& queryNode) const
-{
- // This is an adapted form of the B(N_q) function in the paper
- // ``Tree-Independent Dual-Tree Algorithms'' by Curtin et. al.; the goal is to
- // place a bound on the worst possible distance a point combination could have
- // to improve any of the current neighbor estimates. If the best possible
- // distance between two nodes is greater than this bound, then the node
- // combination can be pruned (see Score()).
-
- // There are a couple ways we can assemble a bound. For simplicity, this is
- // described for nearest neighbor search (SortPolicy = NearestNeighborSort),
- // but the code that is written is adapted for whichever SortPolicy.
-
- // First, we can consider the current worst neighbor candidate distance of any
- // descendant point. This is assembled with 'worstDistance' by looping
- // through the points held by the query node, and then by taking the cached
- // worst distance from any child nodes (Stat().FirstBound()). This
- // corresponds roughly to B_1(N_q) in the paper.
-
- double worstDistance = SortPolicy::BestDistance();
-
- // Loop over points held in the node.
- for (size_t i = 0; i < queryNode.NumPoints(); ++i)
- {
- const double distance = candidates[queryNode.Point(i)].top().first;
- if (SortPolicy::IsBetter(worstDistance, distance))
- worstDistance = distance;
- }
-
- // Loop over children of the node, and use their cached information to
- // assemble bounds.
- for (size_t i = 0; i < queryNode.NumChildren(); ++i)
- {
- const double firstBound = queryNode.Child(i).Stat().FirstBound();
-
- if (SortPolicy::IsBetter(worstDistance, firstBound))
- worstDistance = firstBound;
- }
-
- // At this point, worstDistance holds the value of B_1(N_q).
-
- // Now consider the parent bounds.
- if (queryNode.Parent() != NULL)
- {
- // The parent's worst distance bound implies that the bound for this node
- // must be at least as good. Thus, if the parent worst distance bound is
- // better, then take it.
- if (SortPolicy::IsBetter(queryNode.Parent()->Stat().FirstBound(),
- worstDistance))
- worstDistance = queryNode.Parent()->Stat().FirstBound();
- }
-
- // Could the existing bounds be better?
- if (SortPolicy::IsBetter(queryNode.Stat().FirstBound(), worstDistance))
- worstDistance = queryNode.Stat().FirstBound();
-
- // Cache bounds for later.
- queryNode.Stat().FirstBound() = worstDistance;
-
- worstDistance = SortPolicy::Relax(worstDistance, epsilon);
-
- return worstDistance;
-}
-
-/**
- * Helper function to insert a point into the list of candidate points.
- *
- * @param queryIndex Index of point whose neighbors we are inserting into.
- * @param neighbor Index of reference point which is being inserted.
- * @param distance Distance from query point to reference point.
- */
-template<typename StatisticType,
- typename MatType,
- template<typename HyperplaneMetricType> class HyperplaneType,
- template<typename SplitBoundT, typename SplitMatT> class SplitType,
- typename SortPolicy,
- typename MetricType>
-inline void NeighborSearchRules<SortPolicy, MetricType, tree::SpillTree<
- MetricType, StatisticType, MatType, HyperplaneType, SplitType>>::
- InsertNeighbor(
- const size_t queryIndex,
- const size_t neighbor,
- const double distance)
-{
- CandidateList& pqueue = candidates[queryIndex];
- Candidate c = std::make_pair(distance, neighbor);
-
- if (CandidateCmp()(c, pqueue.top()))
- {
- pqueue.pop();
- pqueue.push(c);
- }
-}
-
-} // namespace neighbor
-} // namespace mlpack
-
-#endif // MLPACK_METHODS_NEIGHBOR_SEARCH_SPILL_SEARCH_RULES_IMPL_HPP
More information about the mlpack-git
mailing list