[mlpack-git] master: Add GetFurthestChild() and GetNearestChild() for queryNode. (8f27741)

gitdub at mlpack.org gitdub at mlpack.org
Sat Aug 20 14:56:07 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/3274b05fcc545c3b36f783316fea2e22f79c3d03...1c77230c7d3b9c45fb102cd3c632d9c7248e085e

>---------------------------------------------------------------

commit 8f277415e401cf5b2197721829b939a49c3616f9
Author: MarcosPividori <marcos.pividori at gmail.com>
Date:   Fri Aug 19 14:44:40 2016 -0300

    Add GetFurthestChild() and GetNearestChild() for queryNode.


>---------------------------------------------------------------

8f277415e401cf5b2197721829b939a49c3616f9
 .../tree/binary_space_tree/binary_space_tree.hpp   | 12 +++
 .../binary_space_tree/binary_space_tree_impl.hpp   | 78 ++++++++++++++++++-
 src/mlpack/core/tree/cover_tree/cover_tree.hpp     | 12 +++
 .../core/tree/cover_tree/cover_tree_impl.hpp       | 66 +++++++++++++++-
 .../core/tree/rectangle_tree/rectangle_tree.hpp    | 12 +++
 .../tree/rectangle_tree/rectangle_tree_impl.hpp    | 72 ++++++++++++++++-
 src/mlpack/core/tree/spill_tree/spill_tree.hpp     | 11 +++
 .../core/tree/spill_tree/spill_tree_impl.hpp       | 90 ++++++++++++++++++++--
 8 files changed, 335 insertions(+), 18 deletions(-)

diff --git a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
index 3be4fe1..b291456 100644
--- a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
@@ -351,6 +351,18 @@ class BinarySpaceTree
       typename boost::enable_if<IsVector<VecType> >::type* = 0);
 
   /**
+   * Return the nearest child node to the given query node.  If it can't decide
+   * it will return a null pointer.
+   */
+  BinarySpaceTree* GetNearestChild(const BinarySpaceTree& queryNode);
+
+  /**
+   * Return the furthest child node to the given query node.  If it can't decide
+   * it will return a null pointer.
+   */
+  BinarySpaceTree* GetFurthestChild(const BinarySpaceTree& queryNode);
+
+  /**
    * Return the furthest distance to a point held in this node.  If this is not
    * a leaf node, then the distance is 0 because the node holds no points.
    */
diff --git a/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp b/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
index e1f77a8..200e631 100644
--- a/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
@@ -503,8 +503,13 @@ BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
 {
   if (IsLeaf())
     return *this;
-  if (left && (!right || left->MinDistance(point) <= right->MinDistance(point)))
-        return *left;
+  if (!left)
+    return *right;
+  if (!right)
+    return *left;
+
+  if (left->MinDistance(point) <= right->MinDistance(point))
+    return *left;
   return *right;
 }
 
@@ -527,12 +532,77 @@ BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
 {
   if (IsLeaf())
     return *this;
-  if (left && (!right || left->MaxDistance(point) > right->MaxDistance(point)))
-        return *left;
+  if (!left)
+    return *right;
+  if (!right)
+    return *left;
+
+  if (left->MaxDistance(point) > right->MaxDistance(point))
+    return *left;
   return *right;
 }
 
 /**
+ * Return the nearest child node to the given query node.  If it can't decide
+ * it will return a null pointer.
+ */
+template<typename MetricType,
+         typename StatisticType,
+         typename MatType,
+         template<typename BoundMetricType, typename...> class BoundType,
+         template<typename SplitBoundType, typename SplitMatType>
+             class SplitType>
+BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>*
+BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
+    GetNearestChild(const BinarySpaceTree& queryNode)
+{
+  if (IsLeaf())
+    return NULL;
+  if (!left)
+    return right;
+  if (!right)
+    return left;
+
+  ElemType leftDist = left->MinDistance(&queryNode);
+  ElemType rightDist = right->MinDistance(&queryNode);
+  if (leftDist < rightDist)
+    return left;
+  if (rightDist < leftDist)
+    return right;
+  return NULL;
+}
+
+/**
+ * Return the furthest child node to the given query node.  If it can't decide
+ * it will return a null pointer.
+ */
+template<typename MetricType,
+         typename StatisticType,
+         typename MatType,
+         template<typename BoundMetricType, typename...> class BoundType,
+         template<typename SplitBoundType, typename SplitMatType>
+             class SplitType>
+BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>*
+BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
+    GetFurthestChild(const BinarySpaceTree& queryNode)
+{
+  if (IsLeaf())
+    return NULL;
+  if (!left)
+    return right;
+  if (!right)
+    return left;
+
+  ElemType leftDist = left->MaxDistance(&queryNode);
+  ElemType rightDist = right->MaxDistance(&queryNode);
+  if (leftDist > rightDist)
+    return left;
+  if (rightDist > leftDist)
+    return right;
+  return NULL;
+}
+
+/**
  * Return a bound on the furthest point in the node from the center.  This
  * returns 0 unless the node is a leaf.
  */
diff --git a/src/mlpack/core/tree/cover_tree/cover_tree.hpp b/src/mlpack/core/tree/cover_tree/cover_tree.hpp
index 75fab6a..a87a188 100644
--- a/src/mlpack/core/tree/cover_tree/cover_tree.hpp
+++ b/src/mlpack/core/tree/cover_tree/cover_tree.hpp
@@ -312,6 +312,18 @@ class CoverTree
       const VecType& point,
       typename boost::enable_if<IsVector<VecType> >::type* = 0);
 
+  /**
+   * Return the nearest child node to the given query node.  If it can't decide
+   * it will return a null pointer.
+   */
+  CoverTree* GetNearestChild(const CoverTree& queryNode);
+
+  /**
+   * Return the furthest child node to the given query node.  If it can't decide
+   * it will return a null pointer.
+   */
+  CoverTree* GetFurthestChild(const CoverTree& queryNode);
+
   //! Return the minimum distance to another node.
   ElemType MinDistance(const CoverTree* other) const;
 
diff --git a/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp b/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
index c39263a..c73b469 100644
--- a/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
+++ b/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
@@ -623,11 +623,11 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::GetNearestChild(
   if (IsLeaf())
     return *this;
 
-  double bestDistance = DBL_MAX;
+  ElemType bestDistance = std::numeric_limits<ElemType>::max();
   size_t bestIndex = 0;
   for (size_t i = 0; i < children.size(); ++i)
   {
-    double distance = children[i]->MinDistance(point);
+    ElemType distance = children[i]->MinDistance(point);
     if (distance <= bestDistance)
     {
       bestDistance = distance;
@@ -655,11 +655,11 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
   if (IsLeaf())
     return *this;
 
-  double bestDistance = 0;
+  ElemType bestDistance = 0;
   size_t bestIndex = 0;
   for (size_t i = 0; i < children.size(); ++i)
   {
-    double distance = children[i]->MaxDistance(point);
+    ElemType distance = children[i]->MaxDistance(point);
     if (distance >= bestDistance)
     {
       bestDistance = distance;
@@ -669,6 +669,64 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
   return *children[bestIndex];
 }
 
+/**
+ * Return the nearest child node to the given query node.  Returns a null
+ * pointer if this node is a leaf; on ties, the last tied child is returned.
+ */
+template<typename MetricType,
+         typename StatisticType,
+         typename MatType,
+         typename RootPointPolicy>
+CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>*
+CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
+    GetNearestChild(const CoverTree& queryNode)
+{
+  if (IsLeaf())
+    return NULL;
+
+  ElemType bestDistance = std::numeric_limits<ElemType>::max();
+  size_t bestIndex = 0;
+  for (size_t i = 0; i < children.size(); ++i)
+  {
+    ElemType distance = children[i]->MinDistance(&queryNode);
+    if (distance <= bestDistance)
+    {
+      bestDistance = distance;
+      bestIndex = i;
+    }
+  }
+  return children[bestIndex];
+}
+
+/**
+ * Return the furthest child node to the given query node.  Returns a null
+ * pointer if this node is a leaf; on ties, the last tied child is returned.
+ */
+template<typename MetricType,
+         typename StatisticType,
+         typename MatType,
+         typename RootPointPolicy>
+CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>*
+CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
+    GetFurthestChild(const CoverTree& queryNode)
+{
+  if (IsLeaf())
+    return NULL;
+
+  ElemType bestDistance = 0;
+  size_t bestIndex = 0;
+  for (size_t i = 0; i < children.size(); ++i)
+  {
+    ElemType distance = children[i]->MaxDistance(&queryNode);
+    if (distance >= bestDistance)
+    {
+      bestDistance = distance;
+      bestIndex = i;
+    }
+  }
+  return children[bestIndex];
+}
+
 template<
     typename MetricType,
     typename StatisticType,
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
index 2e2c053..6b13a53 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
@@ -361,6 +361,18 @@ class RectangleTree
       typename boost::enable_if<IsVector<VecType> >::type* = 0);
 
   /**
+   * Return the nearest child node to the given query node.  If it can't decide
+   * it will return a null pointer.
+   */
+  RectangleTree* GetNearestChild(const RectangleTree& queryNode);
+
+  /**
+   * Return the furthest child node to the given query node.  If it can't decide
+   * it will return a null pointer.
+   */
+  RectangleTree* GetFurthestChild(const RectangleTree& queryNode);
+
+  /**
    * Return the furthest distance to a point held in this node.  If this is not
    * a leaf node, then the distance is 0 because the node holds no points.
    */
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
index 257bc18..b220bd8 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
@@ -584,11 +584,11 @@ RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
   if (IsLeaf())
     return *this;
 
-  double bestDistance = DBL_MAX;
+  ElemType bestDistance = std::numeric_limits<ElemType>::max();
   size_t bestIndex = 0;
   for (size_t i = 0; i < NumChildren(); ++i)
   {
-    double distance = Child(i).MinDistance(point);
+    ElemType distance = Child(i).MinDistance(point);
     if (distance <= bestDistance)
     {
       bestDistance = distance;
@@ -619,11 +619,11 @@ RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
   if (IsLeaf())
     return *this;
 
-  double bestDistance = 0;
+  ElemType bestDistance = 0;
   size_t bestIndex = 0;
   for (size_t i = 0; i < NumChildren(); ++i)
   {
-    double distance = Child(i).MaxDistance(point);
+    ElemType distance = Child(i).MaxDistance(point);
     if (distance >= bestDistance)
     {
       bestDistance = distance;
@@ -634,6 +634,70 @@ RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
 }
 
 /**
+ * Return the nearest child node to the given query node.  Returns a null
+ * pointer if this node is a leaf; on ties, the last tied child is returned.
+ */
+template<typename MetricType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType,
+         typename DescentType,
+         template<typename> class AuxiliaryInformationType>
+RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
+    AuxiliaryInformationType>*
+RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
+    AuxiliaryInformationType>::GetNearestChild(const RectangleTree& queryNode)
+{
+  if (IsLeaf())
+    return NULL;
+
+  ElemType bestDistance = std::numeric_limits<ElemType>::max();
+  size_t bestIndex = 0;
+  for (size_t i = 0; i < NumChildren(); ++i)
+  {
+    ElemType distance = Child(i).MinDistance(&queryNode);
+    if (distance <= bestDistance)
+    {
+      bestDistance = distance;
+      bestIndex = i;
+    }
+  }
+  return &Child(bestIndex);
+}
+
+/**
+ * Return the furthest child node to the given query node.  Returns a null
+ * pointer if this node is a leaf; on ties, the last tied child is returned.
+ */
+template<typename MetricType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType,
+         typename DescentType,
+         template<typename> class AuxiliaryInformationType>
+RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
+    AuxiliaryInformationType>*
+RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
+    AuxiliaryInformationType>::GetFurthestChild(const RectangleTree& queryNode)
+{
+  if (IsLeaf())
+    return NULL;
+
+  ElemType bestDistance = 0;
+  size_t bestIndex = 0;
+  for (size_t i = 0; i < NumChildren(); ++i)
+  {
+    ElemType distance = Child(i).MaxDistance(&queryNode);
+    if (distance >= bestDistance)
+    {
+      bestDistance = distance;
+      bestIndex = i;
+    }
+  }
+  return &Child(bestIndex);
+}
+
+/**
  * Return a bound on the furthest point in the node form the centroid.
  * This returns 0 unless the node is a leaf.
  */
diff --git a/src/mlpack/core/tree/spill_tree/spill_tree.hpp b/src/mlpack/core/tree/spill_tree/spill_tree.hpp
index 4a81512..42cd002 100644
--- a/src/mlpack/core/tree/spill_tree/spill_tree.hpp
+++ b/src/mlpack/core/tree/spill_tree/spill_tree.hpp
@@ -282,6 +282,17 @@ class SpillTree
       const VecType& point,
       typename boost::enable_if<IsVector<VecType> >::type* = 0);
 
+  /**
+   * Return the nearest child node to the given query node.  If it can't decide
+   * it will return a null pointer.
+   */
+  SpillTree* GetNearestChild(const SpillTree& queryNode);
+
+  /**
+   * Return the furthest child node to the given query node.  If it can't decide
+   * it will return a null pointer.
+   */
+  SpillTree* GetFurthestChild(const SpillTree& queryNode);
 
   /**
    * Return the furthest distance to a point held in this node.  If this is not
diff --git a/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp b/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp
index 815fb23..42ba1ba 100644
--- a/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp
+++ b/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp
@@ -325,18 +325,20 @@ SpillTree<MetricType, StatisticType, MatType, HyperplaneType, SplitType>::
 {
   if (IsLeaf())
     return *this;
+  if (!left)
+    return *right;
+  if (!right)
+    return *left;
 
   if (overlappingNode)
   {
-    if (left && (!right || hyperplane.Left(point)))
+    if (hyperplane.Left(point))
       return *left;
-    else
-      return *right;
+    return *right;
   }
   else
   {
-    if (left && (!right ||
-        left->MinDistance(point) <= right->MinDistance(point)))
+    if (left->MinDistance(point) <= right->MinDistance(point))
       return *left;
     return *right;
   }
@@ -361,13 +363,89 @@ SpillTree<MetricType, StatisticType, MatType, HyperplaneType, SplitType>::
 {
   if (IsLeaf())
     return *this;
+  if (!left)
+    return *right;
+  if (!right)
+    return *left;
 
-  if (left && (!right || left->MaxDistance(point) > right->MaxDistance(point)))
+  if (left->MaxDistance(point) > right->MaxDistance(point))
     return *left;
   return *right;
 }
 
 /**
+ * Return the nearest child node to the given query node.  If it can't decide
+ * it will return a null pointer.
+ */
+template<typename MetricType,
+         typename StatisticType,
+         typename MatType,
+         template<typename HyperplaneMetricType> class HyperplaneType,
+         template<typename SplitMetricType, typename SplitMatType>
+             class SplitType>
+SpillTree<MetricType, StatisticType, MatType, HyperplaneType, SplitType>*
+SpillTree<MetricType, StatisticType, MatType, HyperplaneType, SplitType>::
+    GetNearestChild(const SpillTree& queryNode)
+{
+  if (IsLeaf())
+    return NULL;
+  if (!left)
+    return right;
+  if (!right)
+    return left;
+
+  if (overlappingNode)
+  {
+      if (hyperplane.Left(queryNode.Bound()))
+        return left;
+      if (hyperplane.Right(queryNode.Bound()))
+        return right;
+      // Can't decide.
+      return NULL;
+  }
+  else
+  {
+    ElemType leftDist = left->MinDistance(&queryNode);
+    ElemType rightDist = right->MinDistance(&queryNode);
+    if (leftDist < rightDist)
+      return left;
+    if (rightDist < leftDist)
+      return right;
+    return NULL;
+  }
+}
+
+/**
+ * Return the furthest child node to the given query node.  If it can't decide
+ * it will return a null pointer.
+ */
+template<typename MetricType,
+         typename StatisticType,
+         typename MatType,
+         template<typename HyperplaneMetricType> class HyperplaneType,
+         template<typename SplitMetricType, typename SplitMatType>
+             class SplitType>
+SpillTree<MetricType, StatisticType, MatType, HyperplaneType, SplitType>*
+SpillTree<MetricType, StatisticType, MatType, HyperplaneType, SplitType>::
+    GetFurthestChild(const SpillTree& queryNode)
+{
+  if (IsLeaf())
+    return NULL;
+  if (!left)
+    return right;
+  if (!right)
+    return left;
+
+  ElemType leftDist = left->MaxDistance(&queryNode);
+  ElemType rightDist = right->MaxDistance(&queryNode);
+  if (leftDist > rightDist)
+    return left;
+  if (rightDist > leftDist)
+    return right;
+  return NULL;
+}
+
+/**
  * Return a bound on the furthest point in the node from the center.  This
  * returns 0 unless the node is a leaf.
  */




More information about the mlpack-git mailing list