[mlpack-git] master: Very minor fixes of HollowBallBound. (300882a)

gitdub at mlpack.org gitdub at mlpack.org
Mon Aug 8 14:31:11 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/acd81e11579f69e75aa8406b2982328c88cf1fde...1e9f0f39ea4443f0d595c395871ea8c6b27443af

>---------------------------------------------------------------

commit 300882ac96e7a663e3e303ca0c45c14c6fafe1a6
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date:   Tue Jul 19 16:53:45 2016 +0300

    Very minor fixes of HollowBallBound.


>---------------------------------------------------------------

300882ac96e7a663e3e303ca0c45c14c6fafe1a6
 src/mlpack/core/tree/hollow_ball_bound.hpp      |  22 ++---
 src/mlpack/core/tree/hollow_ball_bound_impl.hpp | 126 ++++++++++++------------
 2 files changed, 71 insertions(+), 77 deletions(-)

diff --git a/src/mlpack/core/tree/hollow_ball_bound.hpp b/src/mlpack/core/tree/hollow_ball_bound.hpp
index 5770acb..53be287 100644
--- a/src/mlpack/core/tree/hollow_ball_bound.hpp
+++ b/src/mlpack/core/tree/hollow_ball_bound.hpp
@@ -33,10 +33,8 @@ class HollowBallBound
   typedef VecType Vec;
 
  private:
-  //! The radius of the inner ball bound.
-  ElemType innerRadius;
-  //! The radius of the outer ball bound.
-  ElemType outerRadius;
+  //! The inner and the outer radii of the bound.
+  math::RangeType<ElemType> radii;
   //! The center of the ball bound.
   VecType center;
   //! The metric used in this bound.
@@ -65,8 +63,8 @@ class HollowBallBound
   /**
    * Create the ball bound with the specified radius and center.
    *
-   * @param innerRradius Inner radius of ball bound.
-   * @param outerRradius Outer radius of ball bound.
+   * @param innerRadius Inner radius of ball bound.
+   * @param outerRadius Outer radius of ball bound.
    * @param center Center of ball bound.
    */
   HollowBallBound(const ElemType innerRadius,
@@ -86,14 +84,14 @@ class HollowBallBound
   ~HollowBallBound();
 
   //! Get the outer radius of the ball.
-  ElemType OuterRadius() const { return outerRadius; }
+  ElemType OuterRadius() const { return radii.Hi(); }
   //! Modify the outer radius of the ball.
-  ElemType& OuterRadius() { return outerRadius; }
+  ElemType& OuterRadius() { return radii.Hi(); }
 
   //! Get the innner radius of the ball.
-  ElemType InnerRadius() const { return innerRadius; }
+  ElemType InnerRadius() const { return radii.Lo(); }
   //! Modify the inner radius of the ball.
-  ElemType& InnerRadius() { return innerRadius; }
+  ElemType& InnerRadius() { return radii.Lo(); }
 
   //! Get the center point of the ball.
   const VecType& Center() const { return center; }
@@ -107,7 +105,7 @@ class HollowBallBound
    * Get the minimum width of the bound (this is same as the diameter).
    * For ball bounds, width along all dimensions remain same.
    */
-  ElemType MinWidth() const { return outerRadius * 2.0; }
+  ElemType MinWidth() const { return radii.Hi() * 2.0; }
 
   //! Get the range in a certain dimension.
   math::RangeType<ElemType> operator[](const size_t i) const;
@@ -194,7 +192,7 @@ class HollowBallBound
   /**
    * Returns the diameter of the ballbound.
    */
-  ElemType Diameter() const { return 2 * outerRadius; }
+  ElemType Diameter() const { return 2 * radii.Hi(); }
 
   //! Returns the distance metric used in this bound.
   const MetricType& Metric() const { return *metric; }
diff --git a/src/mlpack/core/tree/hollow_ball_bound_impl.hpp b/src/mlpack/core/tree/hollow_ball_bound_impl.hpp
index 570d180..ee61daa 100644
--- a/src/mlpack/core/tree/hollow_ball_bound_impl.hpp
+++ b/src/mlpack/core/tree/hollow_ball_bound_impl.hpp
@@ -20,8 +20,8 @@ namespace bound {
 //! Empty Constructor.
 template<typename MetricType, typename VecType>
 HollowBallBound<MetricType, VecType>::HollowBallBound() :
-    innerRadius(std::numeric_limits<ElemType>::lowest()),
-    outerRadius(std::numeric_limits<ElemType>::lowest()),
+    radii(std::numeric_limits<ElemType>::lowest(),
+          std::numeric_limits<ElemType>::lowest()),
     metric(new MetricType()),
     ownsMetric(true)
 { /* Nothing to do. */ }
@@ -33,8 +33,8 @@ HollowBallBound<MetricType, VecType>::HollowBallBound() :
  */
 template<typename MetricType, typename VecType>
 HollowBallBound<MetricType, VecType>::HollowBallBound(const size_t dimension) :
-    innerRadius(std::numeric_limits<ElemType>::lowest()),
-    outerRadius(std::numeric_limits<ElemType>::lowest()),
+    radii(std::numeric_limits<ElemType>::lowest(),
+          std::numeric_limits<ElemType>::lowest()),
     center(dimension),
     metric(new MetricType()),
     ownsMetric(true)
@@ -52,8 +52,8 @@ HollowBallBound<MetricType, VecType>::
 HollowBallBound(const ElemType innerRadius,
                 const ElemType outerRadius,
                 const VecType& center) :
-    innerRadius(innerRadius),
-    outerRadius(outerRadius),
+    radii(innerRadius,
+          outerRadius),
     center(center),
     metric(new MetricType()),
     ownsMetric(true)
@@ -63,8 +63,7 @@ HollowBallBound(const ElemType innerRadius,
 template<typename MetricType, typename VecType>
 HollowBallBound<MetricType, VecType>::HollowBallBound(
     const HollowBallBound& other) :
-    innerRadius(other.innerRadius),
-    outerRadius(other.outerRadius),
+    radii(other.radii),
     center(other.center),
     metric(other.metric),
     ownsMetric(false)
@@ -75,8 +74,7 @@ template<typename MetricType, typename VecType>
 HollowBallBound<MetricType, VecType>& HollowBallBound<MetricType, VecType>::
 operator=(const HollowBallBound& other)
 {
-  innerRadius = other.innerRadius;
-  outerRadius = other.outerRadius;
+  radii = other.radii;
   center = other.center;
   metric = other.metric;
   ownsMetric = false;
@@ -87,15 +85,14 @@ operator=(const HollowBallBound& other)
 //! Move constructor.
 template<typename MetricType, typename VecType>
 HollowBallBound<MetricType, VecType>::HollowBallBound(HollowBallBound&& other) :
-    innerRadius(other.innerRadius),
-    outerRadius(other.outerRadius),
+    radii(other.radii),
     center(other.center),
     metric(other.metric),
     ownsMetric(other.ownsMetric)
 {
   // Fix the other bound.
-  other.innerRadius = 0.0;
-  other.outerRadius = 0.0;
+  other.radii.Hi() = 0.0;
+  other.radii.Lo() = 0.0;
   other.center = VecType();
   other.metric = NULL;
   other.ownsMetric = false;
@@ -114,10 +111,10 @@ template<typename MetricType, typename VecType>
 math::RangeType<typename HollowBallBound<MetricType, VecType>::ElemType>
 HollowBallBound<MetricType, VecType>::operator[](const size_t i) const
 {
-  if (outerRadius < 0)
+  if (radii.Hi() < 0)
     return math::Range();
   else
-    return math::Range(center[i] - outerRadius, center[i] + outerRadius);
+    return math::Range(center[i] - radii.Hi(), center[i] + radii.Hi());
 }
 
 /**
@@ -126,12 +123,12 @@ HollowBallBound<MetricType, VecType>::operator[](const size_t i) const
 template<typename MetricType, typename VecType>
 bool HollowBallBound<MetricType, VecType>::Contains(const VecType& point) const
 {
-  if (outerRadius < 0)
+  if (radii.Hi() < 0)
     return false;
   else
   {
     const ElemType dist = metric->Evaluate(center, point);
-    return ((dist <= outerRadius) && (dist >= innerRadius));
+    return ((dist <= radii.Hi()) && (dist >= radii.Lo()));
   }
 }
 
@@ -142,19 +139,19 @@ template<typename MetricType, typename VecType>
 bool HollowBallBound<MetricType, VecType>::Contains(
     const HollowBallBound& other) const
 {
-  if (outerRadius < 0)
+  if (radii.Hi() < 0)
     return false;
   else
   {
     const ElemType dist = metric->Evaluate(center, other.center);
 
-    bool containOnOneSide = (dist - other.outerRadius >= innerRadius) &&
-        (dist + other.outerRadius <= outerRadius);
-    bool containOnEverySide = (dist + innerRadius <= other.innerRadius) &&
-        (dist + other.outerRadius <= outerRadius);
+    bool containOnOneSide = (dist - other.radii.Hi() >= radii.Lo()) &&
+        (dist + other.radii.Hi() <= radii.Hi());
+    bool containOnEverySide = (dist + radii.Lo() <= other.radii.Lo()) &&
+        (dist + other.radii.Hi() <= radii.Hi());
 
-    bool containAsBall = (innerRadius == 0) &&
-        (dist + other.outerRadius <= outerRadius);
+    bool containAsBall = (radii.Lo() == 0) &&
+        (dist + other.radii.Hi() <= radii.Hi());
 
     return (containOnOneSide || containOnEverySide || containAsBall);
   }
@@ -171,14 +168,14 @@ HollowBallBound<MetricType, VecType>::MinDistance(
     const OtherVecType& point,
     typename boost::enable_if<IsVector<OtherVecType>>* /* junk */) const
 {
-  if (outerRadius < 0)
+  if (radii.Hi() < 0)
     return std::numeric_limits<ElemType>::max();
   else
   {
     const ElemType dist = metric->Evaluate(point, center);
 
-    const ElemType outerDistance = math::ClampNonNegative(dist - outerRadius);
-    const ElemType innerDistance = math::ClampNonNegative(innerRadius - dist);
+    const ElemType outerDistance = math::ClampNonNegative(dist - radii.Hi());
+    const ElemType innerDistance = math::ClampNonNegative(radii.Lo() - dist);
 
     return innerDistance + outerDistance;
   }
@@ -192,18 +189,18 @@ typename HollowBallBound<MetricType, VecType>::ElemType
 HollowBallBound<MetricType, VecType>::MinDistance(const HollowBallBound& other)
     const
 {
-  if (outerRadius < 0 || other.outerRadius < 0)
+  if (radii.Hi() < 0 || other.radii.Hi() < 0)
     return std::numeric_limits<ElemType>::max();
   else
   {
     const ElemType centerDistance = metric->Evaluate(center, other.center);
 
     const ElemType outerDistance = math::ClampNonNegative(centerDistance -
-        outerRadius - other.outerRadius);
-    const ElemType innerDistance1 = math::ClampNonNegative(other.innerRadius -
-        centerDistance - outerRadius);
-    const ElemType innerDistance2 = math::ClampNonNegative(innerRadius -
-        centerDistance - other.outerRadius);
+        radii.Hi() - other.radii.Hi());
+    const ElemType innerDistance1 = math::ClampNonNegative(other.radii.Lo() -
+        centerDistance - radii.Hi());
+    const ElemType innerDistance2 = math::ClampNonNegative(radii.Lo() -
+        centerDistance - other.radii.Hi());
 
     return outerDistance + innerDistance1 + innerDistance2;
   }
@@ -219,10 +216,10 @@ HollowBallBound<MetricType, VecType>::MaxDistance(
     const OtherVecType& point,
     typename boost::enable_if<IsVector<OtherVecType> >* /* junk */) const
 {
-  if (outerRadius < 0)
+  if (radii.Hi() < 0)
     return std::numeric_limits<ElemType>::max();
   else
-    return metric->Evaluate(point, center) + outerRadius;
+    return metric->Evaluate(point, center) + radii.Hi();
 }
 
 /**
@@ -233,11 +230,11 @@ typename HollowBallBound<MetricType, VecType>::ElemType
 HollowBallBound<MetricType, VecType>::MaxDistance(const HollowBallBound& other)
     const
 {
-  if (outerRadius < 0)
+  if (radii.Hi() < 0)
     return std::numeric_limits<ElemType>::max();
   else
-    return metric->Evaluate(other.center, center) + outerRadius +
-        other.outerRadius;
+    return metric->Evaluate(other.center, center) + radii.Hi() +
+        other.radii.Hi();
 }
 
 /**
@@ -252,15 +249,15 @@ HollowBallBound<MetricType, VecType>::RangeDistance(
     const OtherVecType& point,
     typename boost::enable_if<IsVector<OtherVecType> >* /* junk */) const
 {
-  if (outerRadius < 0)
+  if (radii.Hi() < 0)
     return math::Range(std::numeric_limits<ElemType>::max(),
                        std::numeric_limits<ElemType>::max());
   else
   {
     const ElemType dist = metric->Evaluate(center, point);
-    return math::Range(math::ClampNonNegative(dist - outerRadius) +
-                       math::ClampNonNegative(innerRadius - dist),
-                       dist + outerRadius);
+    return math::Range(math::ClampNonNegative(dist - radii.Hi()) +
+                       math::ClampNonNegative(radii.Lo() - dist),
+                       dist + radii.Hi());
   }
 }
 
@@ -269,13 +266,13 @@ math::RangeType<typename HollowBallBound<MetricType, VecType>::ElemType>
 HollowBallBound<MetricType, VecType>::RangeDistance(
     const HollowBallBound& other) const
 {
-  if (outerRadius < 0)
+  if (radii.Hi() < 0)
     return math::Range(std::numeric_limits<ElemType>::max(),
                        std::numeric_limits<ElemType>::max());
   else
   {
     const ElemType dist = metric->Evaluate(center, other.center);
-    const ElemType sumradius = outerRadius + other.outerRadius;
+    const ElemType sumradius = radii.Hi() + other.radii.Hi();
     return math::Range(MinDistance(other), dist + sumradius);
   }
 }
@@ -291,11 +288,11 @@ template<typename MatType>
 const HollowBallBound<MetricType, VecType>&
 HollowBallBound<MetricType, VecType>::operator|=(const MatType& data)
 {
-  if (outerRadius < 0)
+  if (radii.Hi() < 0)
   {
     center = data.col(0);
-    outerRadius = 0;
-    innerRadius = 0;
+    radii.Hi() = 0;
+    radii.Lo() = 0;
 
     // Now iteratively add points.
     for (size_t i = 0; i < data.n_cols; ++i)
@@ -303,13 +300,13 @@ HollowBallBound<MetricType, VecType>::operator|=(const MatType& data)
       const ElemType dist = metric->Evaluate(center, (VecType) data.col(i));
 
       // See if the new point lies outside the bound.
-      if (dist > outerRadius)
+      if (dist > radii.Hi())
       {
         // Move towards the new point and increase the radius just enough to
         // accommodate the new point.
         const VecType diff = data.col(i) - center;
-        center += ((dist - outerRadius) / (2 * dist)) * diff;
-        outerRadius = 0.5 * (dist + outerRadius);
+        center += ((dist - radii.Hi()) / (2 * dist)) * diff;
+        radii.Hi() = 0.5 * (dist + radii.Hi());
       }
     }
   }
@@ -321,10 +318,10 @@ HollowBallBound<MetricType, VecType>::operator|=(const MatType& data)
       const ElemType dist = metric->Evaluate(center, data.col(i));
 
       // See if the new point lies outside the bound.
-      if (dist > outerRadius)
-        outerRadius = dist;
-      if (dist < innerRadius)
-        innerRadius = dist;
+      if (dist > radii.Hi())
+        radii.Hi() = dist;
+      if (dist < radii.Lo())
+        radii.Lo() = dist;
     }
   }
 
@@ -338,23 +335,23 @@ template<typename MetricType, typename VecType>
 const HollowBallBound<MetricType, VecType>&
 HollowBallBound<MetricType, VecType>::operator|=(const HollowBallBound& other)
 {
-  if (outerRadius < 0)
+  if (radii.Hi() < 0)
   {
     center = other.center;
-    outerRadius = other.outerRadius;
-    innerRadius = other.innerRadius;
+    radii.Hi() = other.radii.Hi();
+    radii.Lo() = other.radii.Lo();
     return *this;
   }
 
   const ElemType dist = metric->Evaluate(center, other.center);
 
-  if (outerRadius < dist + other.outerRadius)
-    outerRadius = dist + other.outerRadius;
+  if (radii.Hi() < dist + other.radii.Hi())
+    radii.Hi() = dist + other.radii.Hi();
 
-  const ElemType innerDist = math::ClampNonNegative(other.innerRadius - dist);
+  const ElemType innerDist = math::ClampNonNegative(other.radii.Lo() - dist);
 
-  if (innerRadius > innerDist)
-    innerRadius = innerDist;
+  if (radii.Lo() > innerDist)
+    radii.Lo() = innerDist;
 
   return *this;
 }
@@ -367,8 +364,7 @@ void HollowBallBound<MetricType, VecType>::Serialize(
     Archive& ar,
     const unsigned int /* version */)
 {
-  ar & data::CreateNVP(innerRadius, "innerRadius");
-  ar & data::CreateNVP(outerRadius, "outerRadius");
+  ar & data::CreateNVP(radii, "radii");
   ar & data::CreateNVP(center, "center");
 
   if (Archive::is_loading::value)




More information about the mlpack-git mailing list