[mlpack-git] master: Refactor to handle arbitrary element types. (1d7cc4a)
gitdub at mlpack.org
gitdub at mlpack.org
Mon Mar 7 14:59:45 EST 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/f45c17bc4d70ee5d82bf11a91850a34b814eccff...a69871c4eb63087c825502fd2277565453720568
>---------------------------------------------------------------
commit 1d7cc4a098b9167f2917897665d1b808075b3860
Author: Ryan Curtin <ryan at ratml.org>
Date: Thu Jan 28 12:04:34 2016 +0000
Refactor to handle arbitrary element types.
>---------------------------------------------------------------
1d7cc4a098b9167f2917897665d1b808075b3860
src/mlpack/core/tree/ballbound.hpp | 41 ++++++++++++-----------
src/mlpack/core/tree/ballbound_impl.hpp | 58 +++++++++++++++++++--------------
2 files changed, 56 insertions(+), 43 deletions(-)
diff --git a/src/mlpack/core/tree/ballbound.hpp b/src/mlpack/core/tree/ballbound.hpp
index 2201fc4..c6018e6 100644
--- a/src/mlpack/core/tree/ballbound.hpp
+++ b/src/mlpack/core/tree/ballbound.hpp
@@ -19,21 +19,24 @@ namespace bound {
* specific point (center). TMetricType is the custom metric type that defaults
* to the Euclidean (L2) distance.
*
- * @tparam VecType Type of vector (arma::vec or arma::sp_vec).
+ * @tparam VecType Type of vector (arma::vec or arma::sp_vec or similar).
* @tparam TMetricType metric type used in the distance measure.
*/
template<typename VecType = arma::vec,
- typename TMetricType = metric::LMetric<2, true> >
+ typename TMetricType = metric::LMetric<2, true>>
class BallBound
{
public:
+ //! The underlying data type.
+ typedef typename VecType::elem_type ElemType;
+ //! A public version of the vector type.
typedef VecType Vec;
//! Needed for BinarySpaceTree.
typedef TMetricType MetricType;
private:
//! The radius of the ball bound.
- double radius;
+ ElemType radius;
//! The center of the ball bound.
VecType center;
//! The metric used in this bound.
@@ -65,7 +68,7 @@ class BallBound
* @param radius Radius of ball bound.
* @param center Center of ball bound.
*/
- BallBound(const double radius, const VecType& center);
+ BallBound(const ElemType radius, const VecType& center);
//! Copy constructor. To prevent memory leaks.
BallBound(const BallBound& other);
@@ -80,9 +83,9 @@ class BallBound
~BallBound();
//! Get the radius of the ball.
- double Radius() const { return radius; }
+ ElemType Radius() const { return radius; }
//! Modify the radius of the ball.
- double& Radius() { return radius; }
+ ElemType& Radius() { return radius; }
//! Get the center point of the ball.
const VecType& Center() const { return center; }
@@ -90,16 +93,16 @@ class BallBound
VecType& Center() { return center; }
//! Get the dimensionality of the ball.
- double Dim() const { return center.n_elem; }
+ size_t Dim() const { return center.n_elem; }
/**
* Get the minimum width of the bound (this is the same as the diameter).
* For ball bounds, the width along all dimensions remains the same.
*/
- double MinWidth() const { return radius * 2.0; }
+ ElemType MinWidth() const { return radius * 2.0; }
//! Get the range in a certain dimension.
- math::Range operator[](const size_t i) const;
+ math::RangeType<ElemType> operator[](const size_t i) const;
/**
* Determines if a point is within this bound.
@@ -117,42 +120,42 @@ class BallBound
* Calculates minimum bound-to-point distance.
*/
template<typename OtherVecType>
- double MinDistance(const OtherVecType& point,
- typename boost::enable_if<IsVector<OtherVecType> >* = 0)
+ ElemType MinDistance(const OtherVecType& point,
+ typename boost::enable_if<IsVector<OtherVecType>>* = 0)
const;
/**
* Calculates minimum bound-to-bound distance.
*/
- double MinDistance(const BallBound& other) const;
+ ElemType MinDistance(const BallBound& other) const;
/**
* Computes maximum distance.
*/
template<typename OtherVecType>
- double MaxDistance(const OtherVecType& point,
- typename boost::enable_if<IsVector<OtherVecType> >* = 0)
+ ElemType MaxDistance(const OtherVecType& point,
+ typename boost::enable_if<IsVector<OtherVecType>>* = 0)
const;
/**
* Computes maximum distance.
*/
- double MaxDistance(const BallBound& other) const;
+ ElemType MaxDistance(const BallBound& other) const;
/**
* Calculates minimum and maximum bound-to-point distance.
*/
template<typename OtherVecType>
- math::Range RangeDistance(
+ math::RangeType<ElemType> RangeDistance(
const OtherVecType& other,
- typename boost::enable_if<IsVector<OtherVecType> >* = 0) const;
+ typename boost::enable_if<IsVector<OtherVecType>>* = 0) const;
/**
* Calculates minimum and maximum bound-to-bound distance.
*
* Example: bound1.MinDistanceSq(other) for minimum distance.
*/
- math::Range RangeDistance(const BallBound& other) const;
+ math::RangeType<ElemType> RangeDistance(const BallBound& other) const;
/**
* Expand the bound to include the given node.
@@ -173,7 +176,7 @@ class BallBound
/**
* Returns the diameter of the ballbound.
*/
- double Diameter() const { return 2 * radius; }
+ ElemType Diameter() const { return 2 * radius; }
//! Returns the distance metric used in this bound.
const TMetricType& Metric() const { return *metric; }
diff --git a/src/mlpack/core/tree/ballbound_impl.hpp b/src/mlpack/core/tree/ballbound_impl.hpp
index 94eac4c..66d6a12 100644
--- a/src/mlpack/core/tree/ballbound_impl.hpp
+++ b/src/mlpack/core/tree/ballbound_impl.hpp
@@ -20,7 +20,7 @@ namespace bound {
//! Empty Constructor.
template<typename VecType, typename TMetricType>
BallBound<VecType, TMetricType>::BallBound() :
- radius(-DBL_MAX),
+ radius(std::numeric_limits<ElemType>::min()),
metric(new TMetricType()),
ownsMetric(true)
{ /* Nothing to do. */ }
@@ -32,7 +32,7 @@ BallBound<VecType, TMetricType>::BallBound() :
*/
template<typename VecType, typename TMetricType>
BallBound<VecType, TMetricType>::BallBound(const size_t dimension) :
- radius(-DBL_MAX),
+ radius(std::numeric_limits<ElemType>::min()),
center(dimension),
metric(new TMetricType()),
ownsMetric(true)
@@ -45,8 +45,8 @@ BallBound<VecType, TMetricType>::BallBound(const size_t dimension) :
* @param center Center of ball bound.
*/
template<typename VecType, typename TMetricType>
-BallBound<VecType, TMetricType>::BallBound(const double radius,
- const VecType& center) :
+BallBound<VecType, TMetricType>::BallBound(const ElemType radius,
+ const VecType& center) :
radius(radius),
center(center),
metric(new TMetricType()),
@@ -98,7 +98,8 @@ BallBound<VecType, TMetricType>::~BallBound()
//! Get the range in a certain dimension.
template<typename VecType, typename TMetricType>
-math::Range BallBound<VecType, TMetricType>::operator[](const size_t i) const
+math::RangeType<typename BallBound<VecType, TMetricType>::ElemType>
+BallBound<VecType, TMetricType>::operator[](const size_t i) const
{
if (radius < 0)
return math::Range();
@@ -123,12 +124,13 @@ bool BallBound<VecType, TMetricType>::Contains(const VecType& point) const
*/
template<typename VecType, typename TMetricType>
template<typename OtherVecType>
-double BallBound<VecType, TMetricType>::MinDistance(
+typename BallBound<VecType, TMetricType>::ElemType
+BallBound<VecType, TMetricType>::MinDistance(
const OtherVecType& point,
- typename boost::enable_if<IsVector<OtherVecType> >* /* junk */) const
+ typename boost::enable_if<IsVector<OtherVecType>>* /* junk */) const
{
if (radius < 0)
- return DBL_MAX;
+ return std::numeric_limits<ElemType>::max();
else
return math::ClampNonNegative(metric->Evaluate(point, center) - radius);
}
@@ -137,13 +139,15 @@ double BallBound<VecType, TMetricType>::MinDistance(
* Calculates minimum bound-to-bound distance.
*/
template<typename VecType, typename TMetricType>
-double BallBound<VecType, TMetricType>::MinDistance(const BallBound& other) const
+typename BallBound<VecType, TMetricType>::ElemType
+BallBound<VecType, TMetricType>::MinDistance(const BallBound& other)
+ const
{
if (radius < 0)
- return DBL_MAX;
+ return std::numeric_limits<ElemType>::max();
else
{
- const double delta = metric->Evaluate(center, other.center) - radius -
+ const ElemType delta = metric->Evaluate(center, other.center) - radius -
other.radius;
return math::ClampNonNegative(delta);
}
@@ -154,12 +158,13 @@ double BallBound<VecType, TMetricType>::MinDistance(const BallBound& other) cons
*/
template<typename VecType, typename TMetricType>
template<typename OtherVecType>
-double BallBound<VecType, TMetricType>::MaxDistance(
+typename BallBound<VecType, TMetricType>::ElemType
+BallBound<VecType, TMetricType>::MaxDistance(
const OtherVecType& point,
typename boost::enable_if<IsVector<OtherVecType> >* /* junk */) const
{
if (radius < 0)
- return DBL_MAX;
+ return std::numeric_limits<ElemType>::max();
else
return metric->Evaluate(point, center) + radius;
}
@@ -168,11 +173,12 @@ double BallBound<VecType, TMetricType>::MaxDistance(
* Computes maximum distance.
*/
template<typename VecType, typename TMetricType>
-double BallBound<VecType, TMetricType>::MaxDistance(const BallBound& other)
+typename BallBound<VecType, TMetricType>::ElemType
+BallBound<VecType, TMetricType>::MaxDistance(const BallBound& other)
const
{
if (radius < 0)
- return DBL_MAX;
+ return std::numeric_limits<ElemType>::max();
else
return metric->Evaluate(other.center, center) + radius + other.radius;
}
@@ -184,30 +190,34 @@ double BallBound<VecType, TMetricType>::MaxDistance(const BallBound& other)
*/
template<typename VecType, typename TMetricType>
template<typename OtherVecType>
-math::Range BallBound<VecType, TMetricType>::RangeDistance(
+math::RangeType<typename BallBound<VecType, TMetricType>::ElemType>
+BallBound<VecType, TMetricType>::RangeDistance(
const OtherVecType& point,
typename boost::enable_if<IsVector<OtherVecType> >* /* junk */) const
{
if (radius < 0)
- return math::Range(DBL_MAX, DBL_MAX);
+ return math::Range(std::numeric_limits<ElemType>::max(),
+ std::numeric_limits<ElemType>::max());
else
{
- const double dist = metric->Evaluate(center, point);
+ const ElemType dist = metric->Evaluate(center, point);
return math::Range(math::ClampNonNegative(dist - radius),
dist + radius);
}
}
template<typename VecType, typename TMetricType>
-math::Range BallBound<VecType, TMetricType>::RangeDistance(
+math::RangeType<typename BallBound<VecType, TMetricType>::ElemType>
+BallBound<VecType, TMetricType>::RangeDistance(
const BallBound& other) const
{
if (radius < 0)
- return math::Range(DBL_MAX, DBL_MAX);
+ return math::Range(std::numeric_limits<ElemType>::max(),
+ std::numeric_limits<ElemType>::max());
else
{
- const double dist = metric->Evaluate(center, other.center);
- const double sumradius = radius + other.radius;
+ const ElemType dist = metric->Evaluate(center, other.center);
+ const ElemType sumradius = radius + other.radius;
return math::Range(math::ClampNonNegative(dist - sumradius),
dist + sumradius);
}
@@ -250,14 +260,14 @@ BallBound<VecType, TMetricType>::operator|=(const MatType& data)
// Now iteratively add points.
for (size_t i = 0; i < data.n_cols; ++i)
{
- const double dist = metric->Evaluate(center, (VecType) data.col(i));
+ const ElemType dist = metric->Evaluate(center, (VecType) data.col(i));
// See if the new point lies outside the bound.
if (dist > radius)
{
// Move towards the new point and increase the radius just enough to
// accommodate the new point.
- arma::vec diff = data.col(i) - center;
+ const VecType diff = data.col(i) - center;
center += ((dist - radius) / (2 * dist)) * diff;
radius = 0.5 * (dist + radius);
}
More information about the mlpack-git mailing list