[mlpack-git] master: Add GetFurthestChild() and GetNearestChild() for queryNode. (8f27741)
gitdub at mlpack.org
gitdub at mlpack.org
Sat Aug 20 14:56:07 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/3274b05fcc545c3b36f783316fea2e22f79c3d03...1c77230c7d3b9c45fb102cd3c632d9c7248e085e
>---------------------------------------------------------------
commit 8f277415e401cf5b2197721829b939a49c3616f9
Author: MarcosPividori <marcos.pividori at gmail.com>
Date: Fri Aug 19 14:44:40 2016 -0300
Add GetFurthestChild() and GetNearestChild() for queryNode.
>---------------------------------------------------------------
8f277415e401cf5b2197721829b939a49c3616f9
.../tree/binary_space_tree/binary_space_tree.hpp | 12 +++
.../binary_space_tree/binary_space_tree_impl.hpp | 78 ++++++++++++++++++-
src/mlpack/core/tree/cover_tree/cover_tree.hpp | 12 +++
.../core/tree/cover_tree/cover_tree_impl.hpp | 66 +++++++++++++++-
.../core/tree/rectangle_tree/rectangle_tree.hpp | 12 +++
.../tree/rectangle_tree/rectangle_tree_impl.hpp | 72 ++++++++++++++++-
src/mlpack/core/tree/spill_tree/spill_tree.hpp | 11 +++
.../core/tree/spill_tree/spill_tree_impl.hpp | 90 ++++++++++++++++++++--
8 files changed, 335 insertions(+), 18 deletions(-)
diff --git a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
index 3be4fe1..b291456 100644
--- a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
@@ -351,6 +351,18 @@ class BinarySpaceTree
typename boost::enable_if<IsVector<VecType> >::type* = 0);
/**
+ * Return the nearest child node to the given query node, or a null pointer if
+ * this node is a leaf or both children are equally close to the query node.
+ */
+ BinarySpaceTree* GetNearestChild(const BinarySpaceTree& queryNode);
+
+ /**
+ * Return the furthest child node to the given query node, or a null pointer if
+ * this node is a leaf or both children are equally far from the query node.
+ */
+ BinarySpaceTree* GetFurthestChild(const BinarySpaceTree& queryNode);
+
+ /**
* Return the furthest distance to a point held in this node. If this is not
* a leaf node, then the distance is 0 because the node holds no points.
*/
diff --git a/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp b/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
index e1f77a8..200e631 100644
--- a/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
@@ -503,8 +503,13 @@ BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
{
if (IsLeaf())
return *this;
- if (left && (!right || left->MinDistance(point) <= right->MinDistance(point)))
- return *left;
+ if (!left)
+ return *right;
+ if (!right)
+ return *left;
+
+ if (left->MinDistance(point) <= right->MinDistance(point))
+ return *left;
return *right;
}
@@ -527,12 +532,77 @@ BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
{
if (IsLeaf())
return *this;
- if (left && (!right || left->MaxDistance(point) > right->MaxDistance(point)))
- return *left;
+ if (!left)
+ return *right;
+ if (!right)
+ return *left;
+
+ if (left->MaxDistance(point) > right->MaxDistance(point))
+ return *left;
return *right;
}
/**
+ * Return the nearest child node to the given query node, or a null pointer
+ * if this node is a leaf or both children are equally close to queryNode.
+ */
+template<typename MetricType,
+ typename StatisticType,
+ typename MatType,
+ template<typename BoundMetricType, typename...> class BoundType,
+ template<typename SplitBoundType, typename SplitMatType>
+ class SplitType>
+BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>*
+BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
+ GetNearestChild(const BinarySpaceTree& queryNode)
+{
+ if (IsLeaf())
+ return NULL;
+ if (!left)
+ return right;
+ if (!right)
+ return left;
+
+ ElemType leftDist = left->MinDistance(&queryNode);
+ ElemType rightDist = right->MinDistance(&queryNode);
+ if (leftDist < rightDist)
+ return left;
+ if (rightDist < leftDist)
+ return right;
+ return NULL;
+}
+
+/**
+ * Return the furthest child node to the given query node, or a null pointer
+ * if this node is a leaf or both children are equally far from queryNode.
+ */
+template<typename MetricType,
+ typename StatisticType,
+ typename MatType,
+ template<typename BoundMetricType, typename...> class BoundType,
+ template<typename SplitBoundType, typename SplitMatType>
+ class SplitType>
+BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>*
+BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
+ GetFurthestChild(const BinarySpaceTree& queryNode)
+{
+ if (IsLeaf())
+ return NULL;
+ if (!left)
+ return right;
+ if (!right)
+ return left;
+
+ ElemType leftDist = left->MaxDistance(&queryNode);
+ ElemType rightDist = right->MaxDistance(&queryNode);
+ if (leftDist > rightDist)
+ return left;
+ if (rightDist > leftDist)
+ return right;
+ return NULL;
+}
+
+/**
* Return a bound on the furthest point in the node from the center. This
* returns 0 unless the node is a leaf.
*/
diff --git a/src/mlpack/core/tree/cover_tree/cover_tree.hpp b/src/mlpack/core/tree/cover_tree/cover_tree.hpp
index 75fab6a..a87a188 100644
--- a/src/mlpack/core/tree/cover_tree/cover_tree.hpp
+++ b/src/mlpack/core/tree/cover_tree/cover_tree.hpp
@@ -312,6 +312,18 @@ class CoverTree
const VecType& point,
typename boost::enable_if<IsVector<VecType> >::type* = 0);
+ /**
+ * Return the nearest child node to the given query node (ties favor the later
+ * child). A null pointer is returned only if this node is a leaf.
+ */
+ CoverTree* GetNearestChild(const CoverTree& queryNode);
+
+ /**
+ * Return the furthest child node to the given query node (ties favor the later
+ * child). A null pointer is returned only if this node is a leaf.
+ */
+ CoverTree* GetFurthestChild(const CoverTree& queryNode);
+
//! Return the minimum distance to another node.
ElemType MinDistance(const CoverTree* other) const;
diff --git a/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp b/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
index c39263a..c73b469 100644
--- a/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
+++ b/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
@@ -623,11 +623,11 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::GetNearestChild(
if (IsLeaf())
return *this;
- double bestDistance = DBL_MAX;
+ ElemType bestDistance = std::numeric_limits<ElemType>::max();
size_t bestIndex = 0;
for (size_t i = 0; i < children.size(); ++i)
{
- double distance = children[i]->MinDistance(point);
+ ElemType distance = children[i]->MinDistance(point);
if (distance <= bestDistance)
{
bestDistance = distance;
@@ -655,11 +655,11 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
if (IsLeaf())
return *this;
- double bestDistance = 0;
+ ElemType bestDistance = 0;
size_t bestIndex = 0;
for (size_t i = 0; i < children.size(); ++i)
{
- double distance = children[i]->MaxDistance(point);
+ ElemType distance = children[i]->MaxDistance(point);
if (distance >= bestDistance)
{
bestDistance = distance;
@@ -669,6 +669,64 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
return *children[bestIndex];
}
+/**
+ * Return the child with the smallest MinDistance() to queryNode (ties favor
+ * the later child). Only returns a null pointer when this node is a leaf.
+ */
+template<typename MetricType,
+ typename StatisticType,
+ typename MatType,
+ typename RootPointPolicy>
+CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>*
+CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
+ GetNearestChild(const CoverTree& queryNode)
+{
+ if (IsLeaf())
+ return NULL;
+
+ ElemType bestDistance = std::numeric_limits<ElemType>::max();
+ size_t bestIndex = 0;
+ for (size_t i = 0; i < children.size(); ++i)
+ {
+ ElemType distance = children[i]->MinDistance(&queryNode);
+ if (distance <= bestDistance)
+ {
+ bestDistance = distance;
+ bestIndex = i;
+ }
+ }
+ return children[bestIndex];
+}
+
+/**
+ * Return the child with the largest MaxDistance() to queryNode (ties favor
+ * the later child). Only returns a null pointer when this node is a leaf.
+ */
+template<typename MetricType,
+ typename StatisticType,
+ typename MatType,
+ typename RootPointPolicy>
+CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>*
+CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
+ GetFurthestChild(const CoverTree& queryNode)
+{
+ if (IsLeaf())
+ return NULL;
+
+ ElemType bestDistance = 0;
+ size_t bestIndex = 0;
+ for (size_t i = 0; i < children.size(); ++i)
+ {
+ ElemType distance = children[i]->MaxDistance(&queryNode);
+ if (distance >= bestDistance)
+ {
+ bestDistance = distance;
+ bestIndex = i;
+ }
+ }
+ return children[bestIndex];
+}
+
template<
typename MetricType,
typename StatisticType,
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
index 2e2c053..6b13a53 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
@@ -361,6 +361,18 @@ class RectangleTree
typename boost::enable_if<IsVector<VecType> >::type* = 0);
/**
+ * Return the nearest child node to the given query node (ties favor the later
+ * child). A null pointer is returned only if this node is a leaf.
+ */
+ RectangleTree* GetNearestChild(const RectangleTree& queryNode);
+
+ /**
+ * Return the furthest child node to the given query node (ties favor the later
+ * child). A null pointer is returned only if this node is a leaf.
+ */
+ RectangleTree* GetFurthestChild(const RectangleTree& queryNode);
+
+ /**
* Return the furthest distance to a point held in this node. If this is not
* a leaf node, then the distance is 0 because the node holds no points.
*/
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
index 257bc18..b220bd8 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
@@ -584,11 +584,11 @@ RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
if (IsLeaf())
return *this;
- double bestDistance = DBL_MAX;
+ ElemType bestDistance = std::numeric_limits<ElemType>::max();
size_t bestIndex = 0;
for (size_t i = 0; i < NumChildren(); ++i)
{
- double distance = Child(i).MinDistance(point);
+ ElemType distance = Child(i).MinDistance(point);
if (distance <= bestDistance)
{
bestDistance = distance;
@@ -619,11 +619,11 @@ RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
if (IsLeaf())
return *this;
- double bestDistance = 0;
+ ElemType bestDistance = 0;
size_t bestIndex = 0;
for (size_t i = 0; i < NumChildren(); ++i)
{
- double distance = Child(i).MaxDistance(point);
+ ElemType distance = Child(i).MaxDistance(point);
if (distance >= bestDistance)
{
bestDistance = distance;
@@ -634,6 +634,70 @@ RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
}
/**
+ * Return the child with the smallest MinDistance() to queryNode (ties favor
+ * the later child). Only returns a null pointer when this node is a leaf.
+ */
+template<typename MetricType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType,
+ typename DescentType,
+ template<typename> class AuxiliaryInformationType>
+RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
+ AuxiliaryInformationType>*
+RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
+ AuxiliaryInformationType>::GetNearestChild(const RectangleTree& queryNode)
+{
+ if (IsLeaf())
+ return NULL;
+
+ ElemType bestDistance = std::numeric_limits<ElemType>::max();
+ size_t bestIndex = 0;
+ for (size_t i = 0; i < NumChildren(); ++i)
+ {
+ ElemType distance = Child(i).MinDistance(&queryNode);
+ if (distance <= bestDistance)
+ {
+ bestDistance = distance;
+ bestIndex = i;
+ }
+ }
+ return &Child(bestIndex);
+}
+
+/**
+ * Return the child with the largest MaxDistance() to queryNode (ties favor
+ * the later child). Only returns a null pointer when this node is a leaf.
+ */
+template<typename MetricType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType,
+ typename DescentType,
+ template<typename> class AuxiliaryInformationType>
+RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
+ AuxiliaryInformationType>*
+RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
+ AuxiliaryInformationType>::GetFurthestChild(const RectangleTree& queryNode)
+{
+ if (IsLeaf())
+ return NULL;
+
+ ElemType bestDistance = 0;
+ size_t bestIndex = 0;
+ for (size_t i = 0; i < NumChildren(); ++i)
+ {
+ ElemType distance = Child(i).MaxDistance(&queryNode);
+ if (distance >= bestDistance)
+ {
+ bestDistance = distance;
+ bestIndex = i;
+ }
+ }
+ return &Child(bestIndex);
+}
+
+/**
* Return a bound on the furthest point in the node form the centroid.
* This returns 0 unless the node is a leaf.
*/
diff --git a/src/mlpack/core/tree/spill_tree/spill_tree.hpp b/src/mlpack/core/tree/spill_tree/spill_tree.hpp
index 4a81512..42cd002 100644
--- a/src/mlpack/core/tree/spill_tree/spill_tree.hpp
+++ b/src/mlpack/core/tree/spill_tree/spill_tree.hpp
@@ -282,6 +282,17 @@ class SpillTree
const VecType& point,
typename boost::enable_if<IsVector<VecType> >::type* = 0);
+ /**
+ * Return the nearest child node to the given query node, or a null pointer if
+ * this node is a leaf or the nearest child cannot be determined.
+ */
+ SpillTree* GetNearestChild(const SpillTree& queryNode);
+
+ /**
+ * Return the furthest child node to the given query node, or a null pointer if
+ * this node is a leaf or both children are equally far from the query node.
+ */
+ SpillTree* GetFurthestChild(const SpillTree& queryNode);
/**
* Return the furthest distance to a point held in this node. If this is not
diff --git a/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp b/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp
index 815fb23..42ba1ba 100644
--- a/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp
+++ b/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp
@@ -325,18 +325,20 @@ SpillTree<MetricType, StatisticType, MatType, HyperplaneType, SplitType>::
{
if (IsLeaf())
return *this;
+ if (!left)
+ return *right;
+ if (!right)
+ return *left;
if (overlappingNode)
{
- if (left && (!right || hyperplane.Left(point)))
+ if (hyperplane.Left(point))
return *left;
- else
- return *right;
+ return *right;
}
else
{
- if (left && (!right ||
- left->MinDistance(point) <= right->MinDistance(point)))
+ if (left->MinDistance(point) <= right->MinDistance(point))
return *left;
return *right;
}
@@ -361,13 +363,89 @@ SpillTree<MetricType, StatisticType, MatType, HyperplaneType, SplitType>::
{
if (IsLeaf())
return *this;
+ if (!left)
+ return *right;
+ if (!right)
+ return *left;
- if (left && (!right || left->MaxDistance(point) > right->MaxDistance(point)))
+ if (left->MaxDistance(point) > right->MaxDistance(point))
return *left;
return *right;
}
/**
+ * Return the nearest child to queryNode, or a null pointer if this is a leaf
+ * or undecidable (overlapping: bound on neither side; else: tied MinDistance).
+ */
+template<typename MetricType,
+ typename StatisticType,
+ typename MatType,
+ template<typename HyperplaneMetricType> class HyperplaneType,
+ template<typename SplitMetricType, typename SplitMatType>
+ class SplitType>
+SpillTree<MetricType, StatisticType, MatType, HyperplaneType, SplitType>*
+SpillTree<MetricType, StatisticType, MatType, HyperplaneType, SplitType>::
+ GetNearestChild(const SpillTree& queryNode)
+{
+ if (IsLeaf())
+ return NULL;
+ if (!left)
+ return right;
+ if (!right)
+ return left;
+
+ if (overlappingNode)
+ {
+ if (hyperplane.Left(queryNode.Bound()))
+ return left;
+ if (hyperplane.Right(queryNode.Bound()))
+ return right;
+ // Can't decide.
+ return NULL;
+ }
+ else
+ {
+ ElemType leftDist = left->MinDistance(&queryNode);
+ ElemType rightDist = right->MinDistance(&queryNode);
+ if (leftDist < rightDist)
+ return left;
+ if (rightDist < leftDist)
+ return right;
+ return NULL;
+ }
+}
+
+/**
+ * Return the furthest child node to queryNode, or a null pointer if this node
+ * is a leaf or both children are equally far from queryNode (MaxDistance()).
+ */
+template<typename MetricType,
+ typename StatisticType,
+ typename MatType,
+ template<typename HyperplaneMetricType> class HyperplaneType,
+ template<typename SplitMetricType, typename SplitMatType>
+ class SplitType>
+SpillTree<MetricType, StatisticType, MatType, HyperplaneType, SplitType>*
+SpillTree<MetricType, StatisticType, MatType, HyperplaneType, SplitType>::
+ GetFurthestChild(const SpillTree& queryNode)
+{
+ if (IsLeaf())
+ return NULL;
+ if (!left)
+ return right;
+ if (!right)
+ return left;
+
+ ElemType leftDist = left->MaxDistance(&queryNode);
+ ElemType rightDist = right->MaxDistance(&queryNode);
+ if (leftDist > rightDist)
+ return left;
+ if (rightDist > leftDist)
+ return right;
+ return NULL;
+}
+
+/**
* Return a bound on the furthest point in the node from the center. This
* returns 0 unless the node is a leaf.
*/
More information about the mlpack-git
mailing list