[mlpack-git] master: Return index instead of pointers for GetBestChild(). (74ae685)
gitdub at mlpack.org
gitdub at mlpack.org
Sat Aug 20 14:56:08 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/3274b05fcc545c3b36f783316fea2e22f79c3d03...1c77230c7d3b9c45fb102cd3c632d9c7248e085e
>---------------------------------------------------------------
commit 74ae685eeb5f651b1895f69f009419b64f976644
Author: MarcosPividori <marcos.pividori at gmail.com>
Date: Fri Aug 19 23:23:01 2016 -0300
Return index instead of pointers for GetBestChild().
>---------------------------------------------------------------
74ae685eeb5f651b1895f69f009419b64f976644
.../tree/binary_space_tree/binary_space_tree.hpp | 24 +++---
.../binary_space_tree/binary_space_tree_impl.hpp | 88 ++++++++------------
src/mlpack/core/tree/cover_tree/cover_tree.hpp | 24 +++---
.../core/tree/cover_tree/cover_tree_impl.hpp | 53 ++++++------
.../tree/greedy_single_tree_traverser_impl.hpp | 4 +-
.../core/tree/rectangle_tree/rectangle_tree.hpp | 24 +++---
.../tree/rectangle_tree/rectangle_tree_impl.hpp | 48 +++++------
.../spill_tree/spill_dual_tree_traverser_impl.hpp | 17 ++--
.../spill_single_tree_traverser_impl.hpp | 4 +-
src/mlpack/core/tree/spill_tree/spill_tree.hpp | 40 ++++-----
.../core/tree/spill_tree/spill_tree_impl.hpp | 96 ++++++++++------------
.../neighbor_search/neighbor_search_rules.hpp | 4 +-
.../neighbor_search/neighbor_search_rules_impl.hpp | 4 +-
.../sort_policies/furthest_neighbor_sort.hpp | 6 +-
.../sort_policies/nearest_neighbor_sort.hpp | 6 +-
15 files changed, 196 insertions(+), 246 deletions(-)
diff --git a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
index b291456..9152a49 100644
--- a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
@@ -333,34 +333,34 @@ class BinarySpaceTree
size_t NumChildren() const;
/**
- * Return the nearest child node to the given query point. If this is a leaf
- * node, it will return a reference to itself.
+ * Return the index of the nearest child node to the given query point. If
+ * this is a leaf node, it will return NumChildren() (invalid index).
*/
template<typename VecType>
- BinarySpaceTree& GetNearestChild(
+ size_t GetNearestChild(
const VecType& point,
typename boost::enable_if<IsVector<VecType> >::type* = 0);
/**
- * Return the furthest child node to the given query point. If this is a leaf
- * node, it will return a reference to itself.
+ * Return the index of the furthest child node to the given query point. If
+ * this is a leaf node, it will return NumChildren() (invalid index).
*/
template<typename VecType>
- BinarySpaceTree& GetFurthestChild(
+ size_t GetFurthestChild(
const VecType& point,
typename boost::enable_if<IsVector<VecType> >::type* = 0);
/**
- * Return the nearest child node to the given query node. If it can't decide
- * it will return a null pointer.
+ * Return the index of the nearest child node to the given query node. If it
+ * can't decide, it will return NumChildren() (invalid index).
*/
- BinarySpaceTree* GetNearestChild(const BinarySpaceTree& queryNode);
+ size_t GetNearestChild(const BinarySpaceTree& queryNode);
/**
- * Return the furthest child node to the given query node. If it can't decide
- * it will return a null pointer.
+ * Return the index of the furthest child node to the given query node. If it
+ * can't decide, it will return NumChildren() (invalid index).
*/
- BinarySpaceTree* GetFurthestChild(const BinarySpaceTree& queryNode);
+ size_t GetFurthestChild(const BinarySpaceTree& queryNode);
/**
* Return the furthest distance to a point held in this node. If this is not
diff --git a/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp b/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
index 200e631..3ae07d1 100644
--- a/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
@@ -485,8 +485,8 @@ inline size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
}
/**
- * Return the nearest child node to the given query point. If this is a leaf
- * node, it will return a reference to itself.
+ * Return the index of the nearest child node to the given query point. If
+ * this is a leaf node, it will return NumChildren() (invalid index).
*/
template<typename MetricType,
typename StatisticType,
@@ -495,27 +495,22 @@ template<typename MetricType,
template<typename SplitBoundType, typename SplitMatType>
class SplitType>
template<typename VecType>
-BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>&
-BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
- GetNearestChild(
+size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
+ SplitType>::GetNearestChild(
const VecType& point,
typename boost::enable_if<IsVector<VecType> >::type*)
{
- if (IsLeaf())
- return *this;
- if (!left)
- return *right;
- if (!right)
- return *left;
+ if (IsLeaf() || !left || !right)
+ return 0;
if (left->MinDistance(point) <= right->MinDistance(point))
- return *left;
- return *right;
+ return 0;
+ return 1;
}
/**
- * Return the furthest child node to the given query point. If this is a leaf
- * node, it will return a reference to itself.
+ * Return the index of the furthest child node to the given query point. If
+ * this is a leaf node, it will return NumChildren() (invalid index).
*/
template<typename MetricType,
typename StatisticType,
@@ -524,27 +519,22 @@ template<typename MetricType,
template<typename SplitBoundType, typename SplitMatType>
class SplitType>
template<typename VecType>
-BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>&
-BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
- GetFurthestChild(
+size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
+ SplitType>::GetFurthestChild(
const VecType& point,
typename boost::enable_if<IsVector<VecType> >::type*)
{
- if (IsLeaf())
- return *this;
- if (!left)
- return *right;
- if (!right)
- return *left;
+ if (IsLeaf() || !left || !right)
+ return 0;
if (left->MaxDistance(point) > right->MaxDistance(point))
- return *left;
- return *right;
+ return 0;
+ return 1;
}
/**
- * Return the nearest child node to the given query node. If it can't decide
- * will return a null pointer.
+ * Return the index of the nearest child node to the given query node. If it
+ * can't decide, it will return NumChildren() (invalid index).
*/
template<typename MetricType,
typename StatisticType,
@@ -552,29 +542,24 @@ template<typename MetricType,
template<typename BoundMetricType, typename...> class BoundType,
template<typename SplitBoundType, typename SplitMatType>
class SplitType>
-BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>*
-BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
- GetNearestChild(const BinarySpaceTree& queryNode)
+size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
+ SplitType>::GetNearestChild(const BinarySpaceTree& queryNode)
{
- if (IsLeaf())
- return NULL;
- if (!left)
- return right;
- if (!right)
- return left;
+ if (IsLeaf() || !left || !right)
+ return 0;
ElemType leftDist = left->MinDistance(&queryNode);
ElemType rightDist = right->MinDistance(&queryNode);
if (leftDist < rightDist)
- return left;
+ return 0;
if (rightDist < leftDist)
- return right;
- return NULL;
+ return 1;
+ return NumChildren();
}
/**
- * Return the furthest child node to the given query node. If it can't decide
- * will return a null pointer.
+ * Return the index of the furthest child node to the given query node. If it
+ * can't decide, it will return NumChildren() (invalid index).
*/
template<typename MetricType,
typename StatisticType,
@@ -582,24 +567,19 @@ template<typename MetricType,
template<typename BoundMetricType, typename...> class BoundType,
template<typename SplitBoundType, typename SplitMatType>
class SplitType>
-BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>*
-BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
- GetFurthestChild(const BinarySpaceTree& queryNode)
+size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
+ SplitType>::GetFurthestChild(const BinarySpaceTree& queryNode)
{
- if (IsLeaf())
- return NULL;
- if (!left)
- return right;
- if (!right)
- return left;
+ if (IsLeaf() || !left || !right)
+ return 0;
ElemType leftDist = left->MaxDistance(&queryNode);
ElemType rightDist = right->MaxDistance(&queryNode);
if (leftDist > rightDist)
- return left;
+ return 0;
if (rightDist > leftDist)
- return right;
- return NULL;
+ return 1;
+ return NumChildren();
}
/**
diff --git a/src/mlpack/core/tree/cover_tree/cover_tree.hpp b/src/mlpack/core/tree/cover_tree/cover_tree.hpp
index a87a188..5cf2f73 100644
--- a/src/mlpack/core/tree/cover_tree/cover_tree.hpp
+++ b/src/mlpack/core/tree/cover_tree/cover_tree.hpp
@@ -295,34 +295,34 @@ class CoverTree
StatisticType& Stat() { return stat; }
/**
- * Return the nearest child node to the given query point. If this is a leaf
- * node, it will return a reference to itself.
+ * Return the index of the nearest child node to the given query point. If
+ * this is a leaf node, it will return NumChildren() (invalid index).
*/
template<typename VecType>
- CoverTree& GetNearestChild(
+ size_t GetNearestChild(
const VecType& point,
typename boost::enable_if<IsVector<VecType> >::type* = 0);
/**
- * Return the furthest child node to the given query point. If this is a leaf
- * node, it will return a reference to itself.
+ * Return the index of the furthest child node to the given query point. If
+ * this is a leaf node, it will return NumChildren() (invalid index).
*/
template<typename VecType>
- CoverTree& GetFurthestChild(
+ size_t GetFurthestChild(
const VecType& point,
typename boost::enable_if<IsVector<VecType> >::type* = 0);
/**
- * Return the nearest child node to the given query node. If it can't decide
- * it will return a null pointer.
+ * Return the index of the nearest child node to the given query node. If it
+ * can't decide, it will return NumChildren() (invalid index).
*/
- CoverTree* GetNearestChild(const CoverTree& queryNode);
+ size_t GetNearestChild(const CoverTree& queryNode);
/**
- * Return the furthest child node to the given query node. If it can't decide
- * it will return a null pointer.
+ * Return the index of the furthest child node to the given query node. If it
+ * can't decide, it will return NumChildren() (invalid index).
*/
- CoverTree* GetFurthestChild(const CoverTree& queryNode);
+ size_t GetFurthestChild(const CoverTree& queryNode);
//! Return the minimum distance to another node.
ElemType MinDistance(const CoverTree* other) const;
diff --git a/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp b/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
index c73b469..0fa742a 100644
--- a/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
+++ b/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
@@ -607,21 +607,20 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::Descendant(
}
/**
- * Return the nearest child node to the given query point. If this is a leaf
- * node, it will return a reference to itself.
+ * Return the index of the nearest child node to the given query point. If
+ * this is a leaf node, it will return NumChildren() (invalid index).
*/
template<typename MetricType,
typename StatisticType,
typename MatType,
typename RootPointPolicy>
template<typename VecType>
-CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>&
-CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::GetNearestChild(
- const VecType& point,
- typename boost::enable_if<IsVector<VecType> >::type*)
+size_t CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
+ GetNearestChild(const VecType& point,
+ typename boost::enable_if<IsVector<VecType> >::type*)
{
if (IsLeaf())
- return *this;
+ return 0;
ElemType bestDistance = std::numeric_limits<ElemType>::max();
size_t bestIndex = 0;
@@ -634,26 +633,24 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::GetNearestChild(
bestIndex = i;
}
}
- return *children[bestIndex];
+ return bestIndex;
}
/**
- * Return the furthest child node to the given query point. If this is a leaf
- * node, it will return a reference to itself.
+ * Return the index of the furthest child node to the given query point. If
+ * this is a leaf node, it will return NumChildren() (invalid index).
*/
template<typename MetricType,
typename StatisticType,
typename MatType,
typename RootPointPolicy>
template<typename VecType>
-CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>&
-CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
- GetFurthestChild(
- const VecType& point,
- typename boost::enable_if<IsVector<VecType> >::type*)
+size_t CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
+ GetFurthestChild(const VecType& point,
+ typename boost::enable_if<IsVector<VecType> >::type*)
{
if (IsLeaf())
- return *this;
+ return 0;
ElemType bestDistance = 0;
size_t bestIndex = 0;
@@ -666,23 +663,22 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
bestIndex = i;
}
}
- return *children[bestIndex];
+ return bestIndex;
}
/**
- * Return the nearest child node to the given query node. If it can't decide
- * will return a null pointer.
+ * Return the index of the nearest child node to the given query node. If it
+ * can't decide, it will return NumChildren() (invalid index).
*/
template<typename MetricType,
typename StatisticType,
typename MatType,
typename RootPointPolicy>
-CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>*
-CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
+size_t CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
GetNearestChild(const CoverTree& queryNode)
{
if (IsLeaf())
- return NULL;
+ return 0;
ElemType bestDistance = std::numeric_limits<ElemType>::max();
size_t bestIndex = 0;
@@ -695,23 +691,22 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
bestIndex = i;
}
}
- return children[bestIndex];
+ return bestIndex;
}
/**
- * Return the furthest child node to the given query node. If it can't decide
- * will return a null pointer.
+ * Return the index of the furthest child node to the given query node. If it
+ * can't decide, it will return NumChildren() (invalid index).
*/
template<typename MetricType,
typename StatisticType,
typename MatType,
typename RootPointPolicy>
-CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>*
-CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
+size_t CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
GetFurthestChild(const CoverTree& queryNode)
{
if (IsLeaf())
- return NULL;
+ return 0;
ElemType bestDistance = 0;
size_t bestIndex = 0;
@@ -724,7 +719,7 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
bestIndex = i;
}
}
- return children[bestIndex];
+ return bestIndex;
}
template<
diff --git a/src/mlpack/core/tree/greedy_single_tree_traverser_impl.hpp b/src/mlpack/core/tree/greedy_single_tree_traverser_impl.hpp
index d881158..5543f16 100644
--- a/src/mlpack/core/tree/greedy_single_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/greedy_single_tree_traverser_impl.hpp
@@ -36,8 +36,8 @@ void GreedySingleTreeTraverser<TreeType, RuleType>::Traverse(
// We are prunning all but one child.
numPrunes += referenceNode.NumChildren() - 1;
// Recurse the best child.
- TreeType& bestChild = rule.GetBestChild(queryIndex, referenceNode);
- Traverse(queryIndex, bestChild);
+ size_t bestChild = rule.GetBestChild(queryIndex, referenceNode);
+ Traverse(queryIndex, referenceNode.Child(bestChild));
}
}
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
index 6b13a53..d146a15 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
@@ -343,34 +343,34 @@ class RectangleTree
size_t& NumChildren() { return numChildren; }
/**
- * Return the nearest child node to the given query point. If this is a leaf
- * node, it will return a reference to itself.
+ * Return the index of the nearest child node to the given query point. If
+ * this is a leaf node, it will return NumChildren() (invalid index).
*/
template<typename VecType>
- RectangleTree& GetNearestChild(
+ size_t GetNearestChild(
const VecType& point,
typename boost::enable_if<IsVector<VecType> >::type* = 0);
/**
- * Return the furthest child node to the given query point. If this is a leaf
- * node, it will return a reference to itself.
+ * Return the index of the furthest child node to the given query point. If
+ * this is a leaf node, it will return NumChildren() (invalid index).
*/
template<typename VecType>
- RectangleTree& GetFurthestChild(
+ size_t GetFurthestChild(
const VecType& point,
typename boost::enable_if<IsVector<VecType> >::type* = 0);
/**
- * Return the nearest child node to the given query node. If it can't decide
- * it will return a null pointer.
+ * Return the index of the nearest child node to the given query node. If it
+ * can't decide, it will return NumChildren() (invalid index).
*/
- RectangleTree* GetNearestChild(const RectangleTree& queryNode);
+ size_t GetNearestChild(const RectangleTree& queryNode);
/**
- * Return the furthest child node to the given query node. If it can't decide
- * it will return a null pointer.
+ * Return the index of the furthest child node to the given query node. If it
+ * can't decide, it will return NumChildren() (invalid index).
*/
- RectangleTree* GetFurthestChild(const RectangleTree& queryNode);
+ size_t GetFurthestChild(const RectangleTree& queryNode);
/**
* Return the furthest distance to a point held in this node. If this is not
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
index b220bd8..d4e244e 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
@@ -564,8 +564,8 @@ inline bool RectangleTree<MetricType, StatisticType, MatType, SplitType,
}
/**
- * Return the nearest child node to the given query point. If this is a leaf
- * node, it will return a reference to itself.
+ * Return the index of the nearest child node to the given query point. If
+ * this is a leaf node, it will return NumChildren() (invalid index).
*/
template<typename MetricType,
typename StatisticType,
@@ -574,15 +574,13 @@ template<typename MetricType,
typename DescentType,
template<typename> class AuxiliaryInformationType>
template<typename VecType>
-RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
- AuxiliaryInformationType>&
-RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
+size_t RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
AuxiliaryInformationType>::GetNearestChild(
const VecType& point,
typename boost::enable_if<IsVector<VecType> >::type*)
{
if (IsLeaf())
- return *this;
+ return 0;
ElemType bestDistance = std::numeric_limits<ElemType>::max();
size_t bestIndex = 0;
@@ -595,12 +593,12 @@ RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
bestIndex = i;
}
}
- return Child(bestIndex);
+ return bestIndex;
}
/**
- * Return the furthest child node to the given query point. If this is a leaf
- * node, it will return a reference to itself.
+ * Return the index of the furthest child node to the given query point. If
+ * this is a leaf node, it will return NumChildren() (invalid index).
*/
template<typename MetricType,
typename StatisticType,
@@ -609,15 +607,13 @@ template<typename MetricType,
typename DescentType,
template<typename> class AuxiliaryInformationType>
template<typename VecType>
-RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
- AuxiliaryInformationType>&
-RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
+size_t RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
AuxiliaryInformationType>::GetFurthestChild(
const VecType& point,
typename boost::enable_if<IsVector<VecType> >::type*)
{
if (IsLeaf())
- return *this;
+ return 0;
ElemType bestDistance = 0;
size_t bestIndex = 0;
@@ -630,12 +626,12 @@ RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
bestIndex = i;
}
}
- return Child(bestIndex);
+ return bestIndex;
}
/**
- * Return the nearest child node to the given query node. If it can't decide
- * will return a null pointer.
+ * Return the index of the nearest child node to the given query node. If it
+ * can't decide, it will return NumChildren() (invalid index).
*/
template<typename MetricType,
typename StatisticType,
@@ -643,13 +639,11 @@ template<typename MetricType,
typename SplitType,
typename DescentType,
template<typename> class AuxiliaryInformationType>
-RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
- AuxiliaryInformationType>*
-RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
+size_t RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
AuxiliaryInformationType>::GetNearestChild(const RectangleTree& queryNode)
{
if (IsLeaf())
- return NULL;
+ return 0;
ElemType bestDistance = std::numeric_limits<ElemType>::max();
size_t bestIndex = 0;
@@ -662,12 +656,12 @@ RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
bestIndex = i;
}
}
- return &Child(bestIndex);
+ return bestIndex;
}
/**
- * Return the furthest child node to the given query node. If it can't decide
- * will return a null pointer.
+ * Return the index of the furthest child node to the given query node. If it
+ * can't decide, it will return NumChildren() (invalid index).
*/
template<typename MetricType,
typename StatisticType,
@@ -675,13 +669,11 @@ template<typename MetricType,
typename SplitType,
typename DescentType,
template<typename> class AuxiliaryInformationType>
-RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
- AuxiliaryInformationType>*
-RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
+size_t RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
AuxiliaryInformationType>::GetFurthestChild(const RectangleTree& queryNode)
{
if (IsLeaf())
- return NULL;
+ return 0;
ElemType bestDistance = 0;
size_t bestIndex = 0;
@@ -694,7 +686,7 @@ RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
bestIndex = i;
}
}
- return &Child(bestIndex);
+ return bestIndex;
}
/**
diff --git a/src/mlpack/core/tree/spill_tree/spill_dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/spill_tree/spill_dual_tree_traverser_impl.hpp
index c331451..fafbdfa 100644
--- a/src/mlpack/core/tree/spill_tree/spill_dual_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/spill_tree/spill_dual_tree_traverser_impl.hpp
@@ -105,10 +105,10 @@ SpillDualTreeTraverser<RuleType, Defeatist>::Traverse(
if (Defeatist && referenceNode.Overlap())
{
// If referenceNode is a overlapping node let's do defeatist search.
- SpillTree* bestChild = rule.GetBestChild(queryNode, referenceNode);
- if (bestChild)
+ size_t bestChild = rule.GetBestChild(queryNode, referenceNode);
+ if (bestChild < referenceNode.NumChildren())
{
- Traverse(queryNode, *bestChild);
+ Traverse(queryNode, referenceNode.Child(bestChild));
++numPrunes;
}
else
@@ -216,11 +216,10 @@ SpillDualTreeTraverser<RuleType, Defeatist>::Traverse(
if (Defeatist && referenceNode.Overlap())
{
// If referenceNode is a overlapping node let's do defeatist search.
- SpillTree* bestChild = rule.GetBestChild(*queryNode.Left(),
- referenceNode);
- if (bestChild)
+ size_t bestChild = rule.GetBestChild(*queryNode.Left(), referenceNode);
+ if (bestChild < referenceNode.NumChildren())
{
- Traverse(*queryNode.Left(), *bestChild);
+ Traverse(*queryNode.Left(), referenceNode.Child(bestChild));
++numPrunes;
}
else
@@ -232,9 +231,9 @@ SpillDualTreeTraverser<RuleType, Defeatist>::Traverse(
}
bestChild = rule.GetBestChild(*queryNode.Right(), referenceNode);
- if (bestChild)
+ if (bestChild < referenceNode.NumChildren())
{
- Traverse(*queryNode.Right(), *bestChild);
+ Traverse(*queryNode.Right(), referenceNode.Child(bestChild));
++numPrunes;
}
else
diff --git a/src/mlpack/core/tree/spill_tree/spill_single_tree_traverser_impl.hpp b/src/mlpack/core/tree/spill_tree/spill_single_tree_traverser_impl.hpp
index 9cff38c..568ed13 100644
--- a/src/mlpack/core/tree/spill_tree/spill_single_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/spill_tree/spill_single_tree_traverser_impl.hpp
@@ -54,8 +54,8 @@ SpillSingleTreeTraverser<RuleType, Defeatist>::Traverse(
if (Defeatist && referenceNode.Overlap())
{
// If referenceNode is a overlapping node we do defeatist search.
- SpillTree& bestChild = rule.GetBestChild(queryIndex, referenceNode);
- Traverse(queryIndex, bestChild);
+ size_t bestChild = rule.GetBestChild(queryIndex, referenceNode);
+ Traverse(queryIndex, referenceNode.Child(bestChild));
++numPrunes;
}
else
diff --git a/src/mlpack/core/tree/spill_tree/spill_tree.hpp b/src/mlpack/core/tree/spill_tree/spill_tree.hpp
index 3d34a75..1e5bc1e 100644
--- a/src/mlpack/core/tree/spill_tree/spill_tree.hpp
+++ b/src/mlpack/core/tree/spill_tree/spill_tree.hpp
@@ -265,42 +265,42 @@ class SpillTree
size_t NumChildren() const;
/**
- * Return the nearest child node to the given query point (this is an
- * efficient estimation based on the splitting hyperplane, the node returned
- * is not necessarily the nearest). If this is a leaf node, it will return a
- * reference to itself.
+ * Return the index of the nearest child node to the given query point (this
+ * is an efficient estimation based on the splitting hyperplane, the node
+ * returned is not necessarily the nearest). If this is a leaf node, it will
+ * return NumChildren() (invalid index).
*/
template<typename VecType>
- SpillTree& GetNearestChild(
+ size_t GetNearestChild(
const VecType& point,
typename boost::enable_if<IsVector<VecType> >::type* = 0);
/**
- * Return the furthest child node to the given query point (this is an
- * efficient estimation based on the splitting hyperplane, the node returned
- * is not necessarily the furthest). If this is a leaf node, it will return a
- * reference to itself.
+ * Return the index of the furthest child node to the given query point (this
+ * is an efficient estimation based on the splitting hyperplane, the node
+ * returned is not necessarily the furthest). If this is a leaf node, it will
+ * return NumChildren() (invalid index).
*/
template<typename VecType>
- SpillTree& GetFurthestChild(
+ size_t GetFurthestChild(
const VecType& point,
typename boost::enable_if<IsVector<VecType> >::type* = 0);
/**
- * Return the nearest child node to the given query node (this is an
- * efficient estimation based on the splitting hyperplane, the node returned
- * is not necessarily the nearest). If it can't decide it will return a null
- * pointer.
+ * Return the index of the nearest child node to the given query node (this
+ * is an efficient estimation based on the splitting hyperplane, the node
+ * returned is not necessarily the nearest). If it can't decide it will
+ * return NumChildren() (invalid index).
*/
- SpillTree* GetNearestChild(const SpillTree& queryNode);
+ size_t GetNearestChild(const SpillTree& queryNode);
/**
- * Return the furthest child node to the given query node (this is an
- * efficient estimation based on the splitting hyperplane, the node returned
- * is not necessarily the furthest). If it can't decide it will return a null
- * pointer.
+ * Return the index of the furthest child node to the given query node (this
+ * is an efficient estimation based on the splitting hyperplane, the node
+ * returned is not necessarily the furthest). If it can't decide it will
+ * return NumChildren() (invalid index).
*/
- SpillTree* GetFurthestChild(const SpillTree& queryNode);
+ size_t GetFurthestChild(const SpillTree& queryNode);
/**
* Return the furthest distance to a point held in this node. If this is not
diff --git a/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp b/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp
index 9edb7f9..1017f27 100644
--- a/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp
+++ b/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp
@@ -307,8 +307,10 @@ inline size_t SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
}
/**
- * Return the nearest child node to the given query point. If this is a leaf
- * node, it will return a reference to itself.
+ * Return the index of the nearest child node to the given query point (this
+ * is an efficient estimation based on the splitting hyperplane, the node
+ * returned is not necessarily the nearest). If this is a leaf node, it will
+ * return NumChildren() (invalid index).
*/
template<typename MetricType,
typename StatisticType,
@@ -317,27 +319,24 @@ template<typename MetricType,
template<typename SplitMetricType, typename SplitMatType>
class SplitType>
template<typename VecType>
-SpillTree<MetricType, StatisticType, MatType, HyperplaneType, SplitType>&
-SpillTree<MetricType, StatisticType, MatType, HyperplaneType, SplitType>::
- GetNearestChild(
+size_t SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
+ SplitType>::GetNearestChild(
const VecType& point,
typename boost::enable_if<IsVector<VecType> >::type*)
{
- if (IsLeaf())
- return *this;
- if (!left)
- return *right;
- if (!right)
- return *left;
+ if (IsLeaf() || !left || !right)
+ return 0;
if (hyperplane.Left(point))
- return *left;
- return *right;
+ return 0;
+ return 1;
}
/**
- * Return the furthest child node to the given query point. If this is a leaf
- * node, it will return a reference to itself.
+ * Return the index of the furthest child node to the given query point (this
+ * is an efficient estimation based on the splitting hyperplane, the node
+ * returned is not necessarily the furthest). If this is a leaf node, it will
+ * return NumChildren() (invalid index).
*/
template<typename MetricType,
typename StatisticType,
@@ -346,27 +345,24 @@ template<typename MetricType,
template<typename SplitMetricType, typename SplitMatType>
class SplitType>
template<typename VecType>
-SpillTree<MetricType, StatisticType, MatType, HyperplaneType, SplitType>&
-SpillTree<MetricType, StatisticType, MatType, HyperplaneType, SplitType>::
- GetFurthestChild(
+size_t SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
+ SplitType>::GetFurthestChild(
const VecType& point,
typename boost::enable_if<IsVector<VecType> >::type*)
{
- if (IsLeaf())
- return *this;
- if (!left)
- return *right;
- if (!right)
- return *left;
+ if (IsLeaf() || !left || !right)
+ return 0;
if (hyperplane.Left(point))
- return *right;
- return *left;
+ return 1;
+ return 0;
}
/**
- * Return the nearest child node to the given query node. If it can't decide
- * will return a null pointer.
+ * Return the index of the nearest child node to the given query node (this
+ * is an efficient estimation based on the splitting hyperplane, the node
+ * returned is not necessarily the nearest). If it can't decide it will
+ * return NumChildren() (invalid index).
*/
template<typename MetricType,
typename StatisticType,
@@ -374,28 +370,25 @@ template<typename MetricType,
template<typename HyperplaneMetricType> class HyperplaneType,
template<typename SplitMetricType, typename SplitMatType>
class SplitType>
-SpillTree<MetricType, StatisticType, MatType, HyperplaneType, SplitType>*
-SpillTree<MetricType, StatisticType, MatType, HyperplaneType, SplitType>::
- GetNearestChild(const SpillTree& queryNode)
+size_t SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
+ SplitType>::GetNearestChild(const SpillTree& queryNode)
{
- if (IsLeaf())
- return NULL;
- if (!left)
- return right;
- if (!right)
- return left;
+ if (IsLeaf() || !left || !right)
+ return 0;
if (hyperplane.Left(queryNode.Bound()))
- return left;
+ return 0;
if (hyperplane.Right(queryNode.Bound()))
- return right;
+ return 1;
// Can't decide.
- return NULL;
+ return 2;
}
/**
- * Return the furthest child node to the given query node. If it can't decide
- * will return a null pointer.
+ * Return the index of the furthest child node to the given query node (this
+ * is an efficient estimation based on the splitting hyperplane, the node
+ * returned is not necessarily the furthest). If it can't decide it will
+ * return NumChildren() (invalid index).
*/
template<typename MetricType,
typename StatisticType,
@@ -403,23 +396,18 @@ template<typename MetricType,
template<typename HyperplaneMetricType> class HyperplaneType,
template<typename SplitMetricType, typename SplitMatType>
class SplitType>
-SpillTree<MetricType, StatisticType, MatType, HyperplaneType, SplitType>*
-SpillTree<MetricType, StatisticType, MatType, HyperplaneType, SplitType>::
- GetFurthestChild(const SpillTree& queryNode)
+size_t SpillTree<MetricType, StatisticType, MatType, HyperplaneType,
+ SplitType>::GetFurthestChild(const SpillTree& queryNode)
{
- if (IsLeaf())
- return NULL;
- if (!left)
- return right;
- if (!right)
- return left;
+ if (IsLeaf() || !left || !right)
+ return 0;
if (hyperplane.Left(queryNode.Bound()))
- return right;
+ return 1;
if (hyperplane.Right(queryNode.Bound()))
- return left;
+ return 0;
// Can't decide.
- return NULL;
+ return 2;
}
/**
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
index ec61fd1..00d80ff 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
@@ -82,7 +82,7 @@ class NeighborSearchRules
* @param queryIndex Index of query point.
* @param referenceNode Candidate node to be recursed into.
*/
- TreeType& GetBestChild(const size_t queryIndex, TreeType& referenceNode);
+ size_t GetBestChild(const size_t queryIndex, TreeType& referenceNode);
/**
* Get the child node with the best score.
@@ -90,7 +90,7 @@ class NeighborSearchRules
* @param queryNode Node to be considered.
* @param referenceNode Candidate node to be recursed into.
*/
- TreeType* GetBestChild(const TreeType& queryNode, TreeType& referenceNode);
+ size_t GetBestChild(const TreeType& queryNode, TreeType& referenceNode);
/**
* Re-evaluate the score for recursion order. A low score indicates priority
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
index 6bf011b..65b8915 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
@@ -146,7 +146,7 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
}
template<typename SortPolicy, typename MetricType, typename TreeType>
-inline TreeType& NeighborSearchRules<SortPolicy, MetricType, TreeType>::
+inline size_t NeighborSearchRules<SortPolicy, MetricType, TreeType>::
GetBestChild(const size_t queryIndex, TreeType& referenceNode)
{
++scores;
@@ -154,7 +154,7 @@ GetBestChild(const size_t queryIndex, TreeType& referenceNode)
}
template<typename SortPolicy, typename MetricType, typename TreeType>
-inline TreeType* NeighborSearchRules<SortPolicy, MetricType, TreeType>::
+inline size_t NeighborSearchRules<SortPolicy, MetricType, TreeType>::
GetBestChild(const TreeType& queryNode, TreeType& referenceNode)
{
++scores;
diff --git a/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp b/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp
index 4e855a7..2d385a9 100644
--- a/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp
+++ b/src/mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp
@@ -99,8 +99,7 @@ class FurthestNeighborSort
* return the one with the maximum distance.
*/
template<typename VecType, typename TreeType>
- static TreeType& GetBestChild(const VecType& queryPoint,
- TreeType& referenceNode)
+ static size_t GetBestChild(const VecType& queryPoint, TreeType& referenceNode)
{
return referenceNode.GetFurthestChild(queryPoint);
};
@@ -110,8 +109,7 @@ class FurthestNeighborSort
* return the one with the maximum distance.
*/
template<typename TreeType>
- static TreeType* GetBestChild(const TreeType& queryNode,
- TreeType& referenceNode)
+ static size_t GetBestChild(const TreeType& queryNode, TreeType& referenceNode)
{
return referenceNode.GetFurthestChild(queryNode);
};
diff --git a/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp b/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp
index 7a0ac57..3034db3 100644
--- a/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp
+++ b/src/mlpack/methods/neighbor_search/sort_policies/nearest_neighbor_sort.hpp
@@ -103,8 +103,7 @@ class NearestNeighborSort
* return the one with the minimum distance.
*/
template<typename VecType, typename TreeType>
- static TreeType& GetBestChild(const VecType& queryPoint,
- TreeType& referenceNode)
+ static size_t GetBestChild(const VecType& queryPoint, TreeType& referenceNode)
{
return referenceNode.GetNearestChild(queryPoint);
};
@@ -114,8 +113,7 @@ class NearestNeighborSort
* return the one with the minimum distance.
*/
template<typename TreeType>
- static TreeType* GetBestChild(const TreeType& queryNode,
- TreeType& referenceNode)
+ static size_t GetBestChild(const TreeType& queryNode, TreeType& referenceNode)
{
return referenceNode.GetNearestChild(queryNode);
};
More information about the mlpack-git
mailing list