[mlpack-git] master: Very minor fixes of HollowBallBound. (300882a)
gitdub at mlpack.org
gitdub at mlpack.org
Mon Aug 8 14:31:11 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/acd81e11579f69e75aa8406b2982328c88cf1fde...1e9f0f39ea4443f0d595c395871ea8c6b27443af
>---------------------------------------------------------------
commit 300882ac96e7a663e3e303ca0c45c14c6fafe1a6
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date: Tue Jul 19 16:53:45 2016 +0300
Very minor fixes of HollowBallBound.
>---------------------------------------------------------------
300882ac96e7a663e3e303ca0c45c14c6fafe1a6
src/mlpack/core/tree/hollow_ball_bound.hpp | 22 ++---
src/mlpack/core/tree/hollow_ball_bound_impl.hpp | 126 ++++++++++++------------
2 files changed, 71 insertions(+), 77 deletions(-)
diff --git a/src/mlpack/core/tree/hollow_ball_bound.hpp b/src/mlpack/core/tree/hollow_ball_bound.hpp
index 5770acb..53be287 100644
--- a/src/mlpack/core/tree/hollow_ball_bound.hpp
+++ b/src/mlpack/core/tree/hollow_ball_bound.hpp
@@ -33,10 +33,8 @@ class HollowBallBound
typedef VecType Vec;
private:
- //! The radius of the inner ball bound.
- ElemType innerRadius;
- //! The radius of the outer ball bound.
- ElemType outerRadius;
+ //! The inner and the outer radii of the bound.
+ math::RangeType<ElemType> radii;
//! The center of the ball bound.
VecType center;
//! The metric used in this bound.
@@ -65,8 +63,8 @@ class HollowBallBound
/**
* Create the ball bound with the specified radius and center.
*
- * @param innerRradius Inner radius of ball bound.
- * @param outerRradius Outer radius of ball bound.
+ * @param innerRadius Inner radius of ball bound.
+ * @param outerRadius Outer radius of ball bound.
* @param center Center of ball bound.
*/
HollowBallBound(const ElemType innerRadius,
@@ -86,14 +84,14 @@ class HollowBallBound
~HollowBallBound();
//! Get the outer radius of the ball.
- ElemType OuterRadius() const { return outerRadius; }
+ ElemType OuterRadius() const { return radii.Hi(); }
//! Modify the outer radius of the ball.
- ElemType& OuterRadius() { return outerRadius; }
+ ElemType& OuterRadius() { return radii.Hi(); }
//! Get the innner radius of the ball.
- ElemType InnerRadius() const { return innerRadius; }
+ ElemType InnerRadius() const { return radii.Lo(); }
//! Modify the inner radius of the ball.
- ElemType& InnerRadius() { return innerRadius; }
+ ElemType& InnerRadius() { return radii.Lo(); }
//! Get the center point of the ball.
const VecType& Center() const { return center; }
@@ -107,7 +105,7 @@ class HollowBallBound
* Get the minimum width of the bound (this is same as the diameter).
* For ball bounds, width along all dimensions remain same.
*/
- ElemType MinWidth() const { return outerRadius * 2.0; }
+ ElemType MinWidth() const { return radii.Hi() * 2.0; }
//! Get the range in a certain dimension.
math::RangeType<ElemType> operator[](const size_t i) const;
@@ -194,7 +192,7 @@ class HollowBallBound
/**
* Returns the diameter of the ballbound.
*/
- ElemType Diameter() const { return 2 * outerRadius; }
+ ElemType Diameter() const { return 2 * radii.Hi(); }
//! Returns the distance metric used in this bound.
const MetricType& Metric() const { return *metric; }
diff --git a/src/mlpack/core/tree/hollow_ball_bound_impl.hpp b/src/mlpack/core/tree/hollow_ball_bound_impl.hpp
index 570d180..ee61daa 100644
--- a/src/mlpack/core/tree/hollow_ball_bound_impl.hpp
+++ b/src/mlpack/core/tree/hollow_ball_bound_impl.hpp
@@ -20,8 +20,8 @@ namespace bound {
//! Empty Constructor.
template<typename MetricType, typename VecType>
HollowBallBound<MetricType, VecType>::HollowBallBound() :
- innerRadius(std::numeric_limits<ElemType>::lowest()),
- outerRadius(std::numeric_limits<ElemType>::lowest()),
+ radii(std::numeric_limits<ElemType>::lowest(),
+ std::numeric_limits<ElemType>::lowest()),
metric(new MetricType()),
ownsMetric(true)
{ /* Nothing to do. */ }
@@ -33,8 +33,8 @@ HollowBallBound<MetricType, VecType>::HollowBallBound() :
*/
template<typename MetricType, typename VecType>
HollowBallBound<MetricType, VecType>::HollowBallBound(const size_t dimension) :
- innerRadius(std::numeric_limits<ElemType>::lowest()),
- outerRadius(std::numeric_limits<ElemType>::lowest()),
+ radii(std::numeric_limits<ElemType>::lowest(),
+ std::numeric_limits<ElemType>::lowest()),
center(dimension),
metric(new MetricType()),
ownsMetric(true)
@@ -52,8 +52,8 @@ HollowBallBound<MetricType, VecType>::
HollowBallBound(const ElemType innerRadius,
const ElemType outerRadius,
const VecType& center) :
- innerRadius(innerRadius),
- outerRadius(outerRadius),
+ radii(innerRadius,
+ outerRadius),
center(center),
metric(new MetricType()),
ownsMetric(true)
@@ -63,8 +63,7 @@ HollowBallBound(const ElemType innerRadius,
template<typename MetricType, typename VecType>
HollowBallBound<MetricType, VecType>::HollowBallBound(
const HollowBallBound& other) :
- innerRadius(other.innerRadius),
- outerRadius(other.outerRadius),
+ radii(other.radii),
center(other.center),
metric(other.metric),
ownsMetric(false)
@@ -75,8 +74,7 @@ template<typename MetricType, typename VecType>
HollowBallBound<MetricType, VecType>& HollowBallBound<MetricType, VecType>::
operator=(const HollowBallBound& other)
{
- innerRadius = other.innerRadius;
- outerRadius = other.outerRadius;
+ radii = other.radii;
center = other.center;
metric = other.metric;
ownsMetric = false;
@@ -87,15 +85,14 @@ operator=(const HollowBallBound& other)
//! Move constructor.
template<typename MetricType, typename VecType>
HollowBallBound<MetricType, VecType>::HollowBallBound(HollowBallBound&& other) :
- innerRadius(other.innerRadius),
- outerRadius(other.outerRadius),
+ radii(other.radii),
center(other.center),
metric(other.metric),
ownsMetric(other.ownsMetric)
{
// Fix the other bound.
- other.innerRadius = 0.0;
- other.outerRadius = 0.0;
+ other.radii.Hi() = 0.0;
+ other.radii.Lo() = 0.0;
other.center = VecType();
other.metric = NULL;
other.ownsMetric = false;
@@ -114,10 +111,10 @@ template<typename MetricType, typename VecType>
math::RangeType<typename HollowBallBound<MetricType, VecType>::ElemType>
HollowBallBound<MetricType, VecType>::operator[](const size_t i) const
{
- if (outerRadius < 0)
+ if (radii.Hi() < 0)
return math::Range();
else
- return math::Range(center[i] - outerRadius, center[i] + outerRadius);
+ return math::Range(center[i] - radii.Hi(), center[i] + radii.Hi());
}
/**
@@ -126,12 +123,12 @@ HollowBallBound<MetricType, VecType>::operator[](const size_t i) const
template<typename MetricType, typename VecType>
bool HollowBallBound<MetricType, VecType>::Contains(const VecType& point) const
{
- if (outerRadius < 0)
+ if (radii.Hi() < 0)
return false;
else
{
const ElemType dist = metric->Evaluate(center, point);
- return ((dist <= outerRadius) && (dist >= innerRadius));
+ return ((dist <= radii.Hi()) && (dist >= radii.Lo()));
}
}
@@ -142,19 +139,19 @@ template<typename MetricType, typename VecType>
bool HollowBallBound<MetricType, VecType>::Contains(
const HollowBallBound& other) const
{
- if (outerRadius < 0)
+ if (radii.Hi() < 0)
return false;
else
{
const ElemType dist = metric->Evaluate(center, other.center);
- bool containOnOneSide = (dist - other.outerRadius >= innerRadius) &&
- (dist + other.outerRadius <= outerRadius);
- bool containOnEverySide = (dist + innerRadius <= other.innerRadius) &&
- (dist + other.outerRadius <= outerRadius);
+ bool containOnOneSide = (dist - other.radii.Hi() >= radii.Lo()) &&
+ (dist + other.radii.Hi() <= radii.Hi());
+ bool containOnEverySide = (dist + radii.Lo() <= other.radii.Lo()) &&
+ (dist + other.radii.Hi() <= radii.Hi());
- bool containAsBall = (innerRadius == 0) &&
- (dist + other.outerRadius <= outerRadius);
+ bool containAsBall = (radii.Lo() == 0) &&
+ (dist + other.radii.Hi() <= radii.Hi());
return (containOnOneSide || containOnEverySide || containAsBall);
}
@@ -171,14 +168,14 @@ HollowBallBound<MetricType, VecType>::MinDistance(
const OtherVecType& point,
typename boost::enable_if<IsVector<OtherVecType>>* /* junk */) const
{
- if (outerRadius < 0)
+ if (radii.Hi() < 0)
return std::numeric_limits<ElemType>::max();
else
{
const ElemType dist = metric->Evaluate(point, center);
- const ElemType outerDistance = math::ClampNonNegative(dist - outerRadius);
- const ElemType innerDistance = math::ClampNonNegative(innerRadius - dist);
+ const ElemType outerDistance = math::ClampNonNegative(dist - radii.Hi());
+ const ElemType innerDistance = math::ClampNonNegative(radii.Lo() - dist);
return innerDistance + outerDistance;
}
@@ -192,18 +189,18 @@ typename HollowBallBound<MetricType, VecType>::ElemType
HollowBallBound<MetricType, VecType>::MinDistance(const HollowBallBound& other)
const
{
- if (outerRadius < 0 || other.outerRadius < 0)
+ if (radii.Hi() < 0 || other.radii.Hi() < 0)
return std::numeric_limits<ElemType>::max();
else
{
const ElemType centerDistance = metric->Evaluate(center, other.center);
const ElemType outerDistance = math::ClampNonNegative(centerDistance -
- outerRadius - other.outerRadius);
- const ElemType innerDistance1 = math::ClampNonNegative(other.innerRadius -
- centerDistance - outerRadius);
- const ElemType innerDistance2 = math::ClampNonNegative(innerRadius -
- centerDistance - other.outerRadius);
+ radii.Hi() - other.radii.Hi());
+ const ElemType innerDistance1 = math::ClampNonNegative(other.radii.Lo() -
+ centerDistance - radii.Hi());
+ const ElemType innerDistance2 = math::ClampNonNegative(radii.Lo() -
+ centerDistance - other.radii.Hi());
return outerDistance + innerDistance1 + innerDistance2;
}
@@ -219,10 +216,10 @@ HollowBallBound<MetricType, VecType>::MaxDistance(
const OtherVecType& point,
typename boost::enable_if<IsVector<OtherVecType> >* /* junk */) const
{
- if (outerRadius < 0)
+ if (radii.Hi() < 0)
return std::numeric_limits<ElemType>::max();
else
- return metric->Evaluate(point, center) + outerRadius;
+ return metric->Evaluate(point, center) + radii.Hi();
}
/**
@@ -233,11 +230,11 @@ typename HollowBallBound<MetricType, VecType>::ElemType
HollowBallBound<MetricType, VecType>::MaxDistance(const HollowBallBound& other)
const
{
- if (outerRadius < 0)
+ if (radii.Hi() < 0)
return std::numeric_limits<ElemType>::max();
else
- return metric->Evaluate(other.center, center) + outerRadius +
- other.outerRadius;
+ return metric->Evaluate(other.center, center) + radii.Hi() +
+ other.radii.Hi();
}
/**
@@ -252,15 +249,15 @@ HollowBallBound<MetricType, VecType>::RangeDistance(
const OtherVecType& point,
typename boost::enable_if<IsVector<OtherVecType> >* /* junk */) const
{
- if (outerRadius < 0)
+ if (radii.Hi() < 0)
return math::Range(std::numeric_limits<ElemType>::max(),
std::numeric_limits<ElemType>::max());
else
{
const ElemType dist = metric->Evaluate(center, point);
- return math::Range(math::ClampNonNegative(dist - outerRadius) +
- math::ClampNonNegative(innerRadius - dist),
- dist + outerRadius);
+ return math::Range(math::ClampNonNegative(dist - radii.Hi()) +
+ math::ClampNonNegative(radii.Lo() - dist),
+ dist + radii.Hi());
}
}
@@ -269,13 +266,13 @@ math::RangeType<typename HollowBallBound<MetricType, VecType>::ElemType>
HollowBallBound<MetricType, VecType>::RangeDistance(
const HollowBallBound& other) const
{
- if (outerRadius < 0)
+ if (radii.Hi() < 0)
return math::Range(std::numeric_limits<ElemType>::max(),
std::numeric_limits<ElemType>::max());
else
{
const ElemType dist = metric->Evaluate(center, other.center);
- const ElemType sumradius = outerRadius + other.outerRadius;
+ const ElemType sumradius = radii.Hi() + other.radii.Hi();
return math::Range(MinDistance(other), dist + sumradius);
}
}
@@ -291,11 +288,11 @@ template<typename MatType>
const HollowBallBound<MetricType, VecType>&
HollowBallBound<MetricType, VecType>::operator|=(const MatType& data)
{
- if (outerRadius < 0)
+ if (radii.Hi() < 0)
{
center = data.col(0);
- outerRadius = 0;
- innerRadius = 0;
+ radii.Hi() = 0;
+ radii.Lo() = 0;
// Now iteratively add points.
for (size_t i = 0; i < data.n_cols; ++i)
@@ -303,13 +300,13 @@ HollowBallBound<MetricType, VecType>::operator|=(const MatType& data)
const ElemType dist = metric->Evaluate(center, (VecType) data.col(i));
// See if the new point lies outside the bound.
- if (dist > outerRadius)
+ if (dist > radii.Hi())
{
// Move towards the new point and increase the radius just enough to
// accommodate the new point.
const VecType diff = data.col(i) - center;
- center += ((dist - outerRadius) / (2 * dist)) * diff;
- outerRadius = 0.5 * (dist + outerRadius);
+ center += ((dist - radii.Hi()) / (2 * dist)) * diff;
+ radii.Hi() = 0.5 * (dist + radii.Hi());
}
}
}
@@ -321,10 +318,10 @@ HollowBallBound<MetricType, VecType>::operator|=(const MatType& data)
const ElemType dist = metric->Evaluate(center, data.col(i));
// See if the new point lies outside the bound.
- if (dist > outerRadius)
- outerRadius = dist;
- if (dist < innerRadius)
- innerRadius = dist;
+ if (dist > radii.Hi())
+ radii.Hi() = dist;
+ if (dist < radii.Lo())
+ radii.Lo() = dist;
}
}
@@ -338,23 +335,23 @@ template<typename MetricType, typename VecType>
const HollowBallBound<MetricType, VecType>&
HollowBallBound<MetricType, VecType>::operator|=(const HollowBallBound& other)
{
- if (outerRadius < 0)
+ if (radii.Hi() < 0)
{
center = other.center;
- outerRadius = other.outerRadius;
- innerRadius = other.innerRadius;
+ radii.Hi() = other.radii.Hi();
+ radii.Lo() = other.radii.Lo();
return *this;
}
const ElemType dist = metric->Evaluate(center, other.center);
- if (outerRadius < dist + other.outerRadius)
- outerRadius = dist + other.outerRadius;
+ if (radii.Hi() < dist + other.radii.Hi())
+ radii.Hi() = dist + other.radii.Hi();
- const ElemType innerDist = math::ClampNonNegative(other.innerRadius - dist);
+ const ElemType innerDist = math::ClampNonNegative(other.radii.Lo() - dist);
- if (innerRadius > innerDist)
- innerRadius = innerDist;
+ if (radii.Lo() > innerDist)
+ radii.Lo() = innerDist;
return *this;
}
@@ -367,8 +364,7 @@ void HollowBallBound<MetricType, VecType>::Serialize(
Archive& ar,
const unsigned int /* version */)
{
- ar & data::CreateNVP(innerRadius, "innerRadius");
- ar & data::CreateNVP(outerRadius, "outerRadius");
+ ar & data::CreateNVP(radii, "radii");
ar & data::CreateNVP(center, "center");
if (Archive::is_loading::value)
More information about the mlpack-git
mailing list