[mlpack-git] master: Simplify spill dual tree traverser to use GetBestChild. (bfd66ec)
gitdub at mlpack.org
gitdub at mlpack.org
Sat Aug 20 14:56:07 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/3274b05fcc545c3b36f783316fea2e22f79c3d03...1c77230c7d3b9c45fb102cd3c632d9c7248e085e
>---------------------------------------------------------------
commit bfd66ec8198dedaa20f0a51cd4b7e5b36bd0f8c5
Author: MarcosPividori <marcos.pividori at gmail.com>
Date: Fri Aug 19 14:47:45 2016 -0300
Simplify spill dual tree traverser to use GetBestChild.
>---------------------------------------------------------------
bfd66ec8198dedaa20f0a51cd4b7e5b36bd0f8c5
.../spill_tree/spill_dual_tree_traverser_impl.hpp | 57 ++++++++++------------
src/mlpack/core/tree/spill_tree/spill_tree.hpp | 9 ----
.../core/tree/spill_tree/spill_tree_impl.hpp | 41 ----------------
.../neighbor_search/neighbor_search_rules.hpp | 11 +++--
.../neighbor_search/neighbor_search_rules_impl.hpp | 8 +++
.../sort_policies/furthest_neighbor_sort.hpp | 11 +++++
.../sort_policies/nearest_neighbor_sort.hpp | 11 +++++
7 files changed, 65 insertions(+), 83 deletions(-)
diff --git a/src/mlpack/core/tree/spill_tree/spill_dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/spill_tree/spill_dual_tree_traverser_impl.hpp
index 6853b0d..c331451 100644
--- a/src/mlpack/core/tree/spill_tree/spill_dual_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/spill_tree/spill_dual_tree_traverser_impl.hpp
@@ -105,13 +105,12 @@ SpillDualTreeTraverser<RuleType, Defeatist>::Traverse(
if (Defeatist && referenceNode.Overlap())
{
// If referenceNode is a overlapping node let's do defeatist search.
- bool traverseLeft = referenceNode.Left()->HalfSpaceIntersects(queryNode);
- bool traverseRight = referenceNode.Right()->HalfSpaceIntersects(
- queryNode);
- if (traverseLeft && !traverseRight)
- Traverse(queryNode, *referenceNode.Left());
- else if (!traverseLeft && traverseRight)
- Traverse(queryNode, *referenceNode.Right());
+ SpillTree* bestChild = rule.GetBestChild(queryNode, referenceNode);
+ if (bestChild)
+ {
+ Traverse(queryNode, *bestChild);
+ ++numPrunes;
+ }
else
{
// If we can't decide which child node to traverse, this means that
@@ -147,14 +146,15 @@ SpillDualTreeTraverser<RuleType, Defeatist>::Traverse(
if (leftScore < rightScore)
{
- // Recurse to the left. Restore the left traversal info. Store the right
- // traversal info.
+ // Recurse to the left. Restore the left traversal info. Store the
+ // right traversal info.
traversalInfo = rule.TraversalInfo();
rule.TraversalInfo() = leftInfo;
Traverse(queryNode, *referenceNode.Left());
// Is it still valid to recurse to the right?
- rightScore = rule.Rescore(queryNode, *referenceNode.Right(), rightScore);
+ rightScore = rule.Rescore(queryNode, *referenceNode.Right(),
+ rightScore);
if (rightScore != DBL_MAX)
{
@@ -216,14 +216,13 @@ SpillDualTreeTraverser<RuleType, Defeatist>::Traverse(
if (Defeatist && referenceNode.Overlap())
{
// If referenceNode is a overlapping node let's do defeatist search.
- bool traverseLeft = referenceNode.Left()->HalfSpaceIntersects(
- *queryNode.Left());
- bool traverseRight = referenceNode.Right()->HalfSpaceIntersects(
- *queryNode.Left());
- if (traverseLeft && !traverseRight)
- Traverse(*queryNode.Left(), *referenceNode.Left());
- else if (!traverseLeft && traverseRight)
- Traverse(*queryNode.Left(), *referenceNode.Right());
+ SpillTree* bestChild = rule.GetBestChild(*queryNode.Left(),
+ referenceNode);
+ if (bestChild)
+ {
+ Traverse(*queryNode.Left(), *bestChild);
+ ++numPrunes;
+ }
else
{
// If we can't decide which child node to traverse, this means that
@@ -232,14 +231,12 @@ SpillDualTreeTraverser<RuleType, Defeatist>::Traverse(
Traverse(*queryNode.Left(), referenceNode);
}
- traverseLeft = referenceNode.Left()->HalfSpaceIntersects(
- *queryNode.Right());
- traverseRight = referenceNode.Right()->HalfSpaceIntersects(
- *queryNode.Right());
- if (traverseLeft && !traverseRight)
- Traverse(*queryNode.Right(), *referenceNode.Left());
- else if (!traverseLeft && traverseRight)
- Traverse(*queryNode.Right(), *referenceNode.Right());
+ bestChild = rule.GetBestChild(*queryNode.Right(), referenceNode);
+ if (bestChild)
+ {
+ Traverse(*queryNode.Right(), *bestChild);
+ ++numPrunes;
+ }
else
{
// If we can't decide which child node to traverse, this means that
@@ -263,8 +260,8 @@ SpillDualTreeTraverser<RuleType, Defeatist>::Traverse(
if (leftScore < rightScore)
{
- // Recurse to the left. Restore the left traversal info. Store the right
- // traversal info.
+ // Recurse to the left. Restore the left traversal info. Store the
+ // right traversal info.
rightInfo = rule.TraversalInfo();
rule.TraversalInfo() = leftInfo;
Traverse(*queryNode.Left(), *referenceNode.Left());
@@ -341,8 +338,8 @@ SpillDualTreeTraverser<RuleType, Defeatist>::Traverse(
if (leftScore < rightScore)
{
- // Recurse to the left. Restore the left traversal info. Store the right
- // traversal info.
+ // Recurse to the left. Restore the left traversal info. Store the
+ // right traversal info.
rightInfo = rule.TraversalInfo();
rule.TraversalInfo() = leftInfo;
Traverse(*queryNode.Right(), *referenceNode.Left());
diff --git a/src/mlpack/core/tree/spill_tree/spill_tree.hpp b/src/mlpack/core/tree/spill_tree/spill_tree.hpp
index 42cd002..452ebfb 100644
--- a/src/mlpack/core/tree/spill_tree/spill_tree.hpp
+++ b/src/mlpack/core/tree/spill_tree/spill_tree.hpp
@@ -359,15 +359,6 @@ class SpillTree
*/
size_t Point(const size_t index) const;
- //! Determines if the node's half space intersects the given node.
- bool HalfSpaceIntersects(const SpillTree& other) const;
-
- //! Determines if the node's half space contains the given point.
- template<typename VecType>
- bool HalfSpaceContains(
- const VecType& point,
- typename boost::enable_if<IsVector<VecType> >::type* = 0) const;
-
//! Return the minimum distance to another node.
ElemType MinDistance(const SpillTree* other) const
{
diff --git a/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp b/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp
index 42ba1ba..b60e84e 100644
--- a/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp
+++ b/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp
@@ -601,47 +601,6 @@ template<typename MetricType,
template<typename HyperplaneMetricType> class HyperplaneType,
template<typename SplitMetricType, typename SplitMatType>
class SplitType>
-inline bool SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
- SplitType>::HalfSpaceIntersects(const SpillTree& other) const
-{
- if (!Parent())
- return true;
-
- const bool left = this == Parent()->Left();
-
- if (left)
- return !Parent()->Hyperplane().Right(other.Bound());
- else
- return !Parent()->Hyperplane().Left(other.Bound());
-}
-
-template<typename MetricType,
- typename StatisticType,
- typename MatType,
- template<typename HyperplaneMetricType> class HyperplaneType,
- template<typename SplitMetricType, typename SplitMatType>
- class SplitType>
-template<typename VecType>
-inline bool SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
- SplitType>::HalfSpaceContains(
- const VecType& point,
- typename boost::enable_if<IsVector<VecType> >::type*) const
-{
- if (!Parent())
- return true;
-
- const bool left = this == Parent()->Left();
- const bool toTheLeft = Parent()->Hyperplane().Left(point);
-
- return left == toTheLeft;
-}
-
-template<typename MetricType,
- typename StatisticType,
- typename MatType,
- template<typename HyperplaneMetricType> class HyperplaneType,
- template<typename SplitMetricType, typename SplitMatType>
- class SplitType>
void SpillTree<MetricType, StatisticType, MatType, HyperplaneType, SplitType>::
SplitNode(arma::Col<size_t>& points,
const size_t maxLeafSize,
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
index 55d2163..ec61fd1 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
@@ -85,6 +85,14 @@ class NeighborSearchRules
TreeType& GetBestChild(const size_t queryIndex, TreeType& referenceNode);
/**
+ * Get the child node with the best score.
+ *
+ * @param queryNode Node to be considered.
+ * @param referenceNode Candidate node to be recursed into.
+ */
+ TreeType* GetBestChild(const TreeType& queryNode, TreeType& referenceNode);
+
+ /**
* Re-evaluate the score for recursion order. A low score indicates priority
* for recursion, while DBL_MAX indicates that the node should not be recursed
* into at all (it should be pruned). This is used when the score has already
@@ -142,9 +150,6 @@ class NeighborSearchRules
//! Modify the traversal info.
TraversalInfoType& TraversalInfo() { return traversalInfo; }
- //! Access the query set.
- const typename TreeType::Mat& QuerySet() { return querySet; }
-
protected:
//! The reference set.
const typename TreeType::Mat& referenceSet;
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
index e06683a..6bf011b 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
@@ -154,6 +154,14 @@ GetBestChild(const size_t queryIndex, TreeType& referenceNode)
}
template<typename SortPolicy, typename MetricType, typename TreeType>
+inline TreeType* NeighborSearchRules<SortPolicy, MetricType, TreeType>::
+GetBestChild(const TreeType& queryNode, TreeType& referenceNode)
+{
+ ++scores;
+ return SortPolicy::GetBestChild(queryNode, referenceNode);
+}
+
+template<typename SortPolicy, typename MetricType, typename TreeType>
inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
const size_t queryIndex,
TreeType& /* referenceNode */,
diff --git a/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp b/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp
index 2c55a75..4e855a7 100644
--- a/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp
+++ b/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp
@@ -106,6 +106,17 @@ class FurthestNeighborSort
};
/**
+ * Return the best child according to this sort policy. In this case it will
+ * return the one with the maximum distance.
+ */
+ template<typename TreeType>
+ static TreeType* GetBestChild(const TreeType& queryNode,
+ TreeType& referenceNode)
+ {
+ return referenceNode.GetFurthestChild(queryNode);
+ };
+
+ /**
* Return what should represent the worst possible distance with this
* particular sort policy. In our case, this should be the minimum possible
* distance, 0.
diff --git a/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp b/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp
index 90837ce..7a0ac57 100644
--- a/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp
+++ b/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp
@@ -110,6 +110,17 @@ class NearestNeighborSort
};
/**
+ * Return the best child according to this sort policy. In this case it will
+ * return the one with the minimum distance.
+ */
+ template<typename TreeType>
+ static TreeType* GetBestChild(const TreeType& queryNode,
+ TreeType& referenceNode)
+ {
+ return referenceNode.GetNearestChild(queryNode);
+ };
+
+ /**
* Return what should represent the worst possible distance with this
* particular sort policy. In our case, this should be the maximum possible
* distance, DBL_MAX.
More information about the mlpack-git
mailing list