[mlpack-git] master: Refactor to handle arbitrary element types. (819bef7)
gitdub at mlpack.org
gitdub at mlpack.org
Mon Mar 7 14:59:46 EST 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/f45c17bc4d70ee5d82bf11a91850a34b814eccff...a69871c4eb63087c825502fd2277565453720568
>---------------------------------------------------------------
commit 819bef748545035f9f7520de877f810b02210c25
Author: Ryan Curtin <ryan at ratml.org>
Date: Thu Jan 28 12:04:54 2016 +0000
Refactor to handle arbitrary element types.
>---------------------------------------------------------------
819bef748545035f9f7520de877f810b02210c25
src/mlpack/core/tree/hrectbound.hpp | 51 ++++----
src/mlpack/core/tree/hrectbound_impl.hpp | 197 ++++++++++++++++---------------
2 files changed, 129 insertions(+), 119 deletions(-)
diff --git a/src/mlpack/core/tree/hrectbound.hpp b/src/mlpack/core/tree/hrectbound.hpp
index 2f6c3ee..442f575 100644
--- a/src/mlpack/core/tree/hrectbound.hpp
+++ b/src/mlpack/core/tree/hrectbound.hpp
@@ -41,11 +41,11 @@ struct IsLMetric<metric::LMetric<Power, TakeRoot>>
* with the LMetric class. Be sure to use the same template parameters for
* LMetric as you do for HRectBound -- otherwise odd results may occur.
*
- * @tparam Power The metric to use; use 2 for Euclidean (L2).
- * @tparam TakeRoot Whether or not the root should be taken (see LMetric
- * documentation).
+ * @tparam MetricType Type of metric to use; must be of type LMetric.
+ * @tparam ElemType Element type (double/float/int/etc.).
*/
-template<typename MetricType = metric::LMetric<2, true>>
+template<typename MetricType = metric::LMetric<2, true>,
+ typename ElemType = double>
class HRectBound
{
// It is required that HRectBound have an LMetric as the given MetricType.
@@ -86,28 +86,29 @@ class HRectBound
//! Get the range for a particular dimension. No bounds checking. Be
//! careful: this may make MinWidth() invalid.
- math::Range& operator[](const size_t i) { return bounds[i]; }
+ math::RangeType<ElemType>& operator[](const size_t i) { return bounds[i]; }
//! Modify the range for a particular dimension. No bounds checking.
- const math::Range& operator[](const size_t i) const { return bounds[i]; }
+ const math::RangeType<ElemType>& operator[](const size_t i) const
+ { return bounds[i]; }
//! Get the minimum width of the bound.
- double MinWidth() const { return minWidth; }
+ ElemType MinWidth() const { return minWidth; }
//! Modify the minimum width of the bound.
- double& MinWidth() { return minWidth; }
+ ElemType& MinWidth() { return minWidth; }
/**
* Calculates the center of the range, placing it into the given vector.
*
* @param center Vector which the center will be written to.
*/
- void Center(arma::vec& center) const;
+ void Center(arma::Col<ElemType>& center) const;
/**
* Calculate the volume of the hyperrectangle.
*
* @return Volume of the hyperrectangle.
*/
- double Volume() const;
+ ElemType Volume() const;
/**
* Calculates minimum bound-to-point distance.
@@ -115,15 +116,15 @@ class HRectBound
* @param point Point to which the minimum distance is requested.
*/
template<typename VecType>
- double MinDistance(const VecType& point,
- typename boost::enable_if<IsVector<VecType> >* = 0) const;
+ ElemType MinDistance(const VecType& point,
+ typename boost::enable_if<IsVector<VecType>>* = 0) const;
/**
* Calculates minimum bound-to-bound distance.
*
* @param other Bound to which the minimum distance is requested.
*/
- double MinDistance(const HRectBound& other) const;
+ ElemType MinDistance(const HRectBound& other) const;
/**
* Calculates maximum bound-to-point squared distance.
@@ -131,15 +132,15 @@ class HRectBound
* @param point Point to which the maximum distance is requested.
*/
template<typename VecType>
- double MaxDistance(const VecType& point,
- typename boost::enable_if<IsVector<VecType> >* = 0) const;
+ ElemType MaxDistance(const VecType& point,
+ typename boost::enable_if<IsVector<VecType>>* = 0) const;
/**
* Computes maximum distance.
*
* @param other Bound to which the maximum distance is requested.
*/
- double MaxDistance(const HRectBound& other) const;
+ ElemType MaxDistance(const HRectBound& other) const;
/**
* Calculates minimum and maximum bound-to-bound distance.
@@ -147,7 +148,7 @@ class HRectBound
* @param other Bound to which the minimum and maximum distances are
* requested.
*/
- math::Range RangeDistance(const HRectBound& other) const;
+ math::RangeType<ElemType> RangeDistance(const HRectBound& other) const;
/**
* Calculates minimum and maximum bound-to-point distance.
@@ -156,9 +157,9 @@ class HRectBound
* requested.
*/
template<typename VecType>
- math::Range RangeDistance(const VecType& point,
- typename boost::enable_if<IsVector<VecType> >* = 0)
- const;
+ math::RangeType<ElemType> RangeDistance(
+ const VecType& point,
+ typename boost::enable_if<IsVector<VecType>>* = 0) const;
/**
* Expands this region to include new points.
@@ -184,7 +185,7 @@ class HRectBound
/**
* Returns the diameter of the hyperrectangle (that is, the longest diagonal).
*/
- double Diameter() const;
+ ElemType Diameter() const;
/**
* Serialize the bound object.
@@ -196,14 +197,14 @@ class HRectBound
//! The dimensionality of the bound.
size_t dim;
//! The bounds for each dimension.
- math::Range* bounds;
+ math::RangeType<ElemType>* bounds;
//! Cached minimum width of bound.
- double minWidth;
+ ElemType minWidth;
};
// A specialization of BoundTraits for this class.
-template<typename MetricType>
-struct BoundTraits<HRectBound<MetricType>>
+template<typename MetricType, typename ElemType>
+struct BoundTraits<HRectBound<MetricType, ElemType>>
{
//! These bounds are always tight for each dimension.
const static bool HasTightBounds = true;
diff --git a/src/mlpack/core/tree/hrectbound_impl.hpp b/src/mlpack/core/tree/hrectbound_impl.hpp
index 5b6420f..ed5c6c7 100644
--- a/src/mlpack/core/tree/hrectbound_impl.hpp
+++ b/src/mlpack/core/tree/hrectbound_impl.hpp
@@ -20,8 +20,8 @@ namespace bound {
/**
* Empty constructor.
*/
-template<typename MetricType>
-inline HRectBound<MetricType>::HRectBound() :
+template<typename MetricType, typename ElemType>
+inline HRectBound<MetricType, ElemType>::HRectBound() :
dim(0),
bounds(NULL),
minWidth(0)
@@ -31,20 +31,21 @@ inline HRectBound<MetricType>::HRectBound() :
* Initializes to specified dimensionality with each dimension the empty
* set.
*/
-template<typename MetricType>
-inline HRectBound<MetricType>::HRectBound(const size_t dimension) :
+template<typename MetricType, typename ElemType>
+inline HRectBound<MetricType, ElemType>::HRectBound(const size_t dimension) :
dim(dimension),
- bounds(new math::Range[dim]),
+ bounds(new math::RangeType<ElemType>[dim]),
minWidth(0)
{ /* Nothing to do. */ }
/**
* Copy constructor necessary to prevent memory leaks.
*/
-template<typename MetricType>
-inline HRectBound<MetricType>::HRectBound(const HRectBound& other) :
+template<typename MetricType, typename ElemType>
+inline HRectBound<MetricType, ElemType>::HRectBound(
+ const HRectBound<MetricType, ElemType>& other) :
dim(other.Dim()),
- bounds(new math::Range[dim]),
+ bounds(new math::RangeType<ElemType>[dim]),
minWidth(other.MinWidth())
{
// Copy other bounds over.
@@ -55,9 +56,9 @@ inline HRectBound<MetricType>::HRectBound(const HRectBound& other) :
/**
* Same as the copy constructor.
*/
-template<typename MetricType>
-inline HRectBound<MetricType>& HRectBound<MetricType>::operator=(
- const HRectBound& other)
+template<typename MetricType, typename ElemType>
+inline HRectBound<MetricType, ElemType>& HRectBound<MetricType, ElemType>::operator=(
+ const HRectBound<MetricType, ElemType>& other)
{
if (dim != other.Dim())
{
@@ -66,7 +67,7 @@ inline HRectBound<MetricType>& HRectBound<MetricType>::operator=(
delete[] bounds;
dim = other.Dim();
- bounds = new math::Range[dim];
+ bounds = new math::RangeType<ElemType>[dim];
}
// Now copy each of the bound values.
@@ -81,8 +82,9 @@ inline HRectBound<MetricType>& HRectBound<MetricType>::operator=(
/**
* Move constructor: take possession of another bound's information.
*/
-template<typename MetricType>
-inline HRectBound<MetricType>::HRectBound(HRectBound&& other) :
+template<typename MetricType, typename ElemType>
+inline HRectBound<MetricType, ElemType>::HRectBound(
+ HRectBound<MetricType, ElemType>&& other) :
dim(other.dim),
bounds(other.bounds),
minWidth(other.minWidth)
@@ -96,8 +98,8 @@ inline HRectBound<MetricType>::HRectBound(HRectBound&& other) :
/**
* Destructor: clean up memory.
*/
-template<typename MetricType>
-inline HRectBound<MetricType>::~HRectBound()
+template<typename MetricType, typename ElemType>
+inline HRectBound<MetricType, ElemType>::~HRectBound()
{
if (bounds)
delete[] bounds;
@@ -106,11 +108,11 @@ inline HRectBound<MetricType>::~HRectBound()
/**
* Resets all dimensions to the empty set.
*/
-template<typename MetricType>
-inline void HRectBound<MetricType>::Clear()
+template<typename MetricType, typename ElemType>
+inline void HRectBound<MetricType, ElemType>::Clear()
{
for (size_t i = 0; i < dim; i++)
- bounds[i] = math::Range();
+ bounds[i] = math::RangeType<ElemType>();
minWidth = 0;
}
@@ -119,8 +121,9 @@ inline void HRectBound<MetricType>::Clear()
*
* @param centroid Vector which the centroid will be written to.
*/
-template<typename MetricType>
-inline void HRectBound<MetricType>::Center(arma::vec& center) const
+template<typename MetricType, typename ElemType>
+inline void HRectBound<MetricType, ElemType>::Center(
+ arma::Col<ElemType>& center) const
{
// Set size correctly if necessary.
if (!(center.n_elem == dim))
@@ -135,10 +138,10 @@ inline void HRectBound<MetricType>::Center(arma::vec& center) const
*
* @return Volume of the hyperrectangle.
*/
-template<typename MetricType>
-inline double HRectBound<MetricType>::Volume() const
+template<typename MetricType, typename ElemType>
+inline ElemType HRectBound<MetricType, ElemType>::Volume() const
{
- double volume = 1.0;
+ ElemType volume = 1.0;
for (size_t i = 0; i < dim; ++i)
volume *= (bounds[i].Hi() - bounds[i].Lo());
@@ -148,17 +151,17 @@ inline double HRectBound<MetricType>::Volume() const
/**
* Calculates minimum bound-to-point squared distance.
*/
-template<typename MetricType>
+template<typename MetricType, typename ElemType>
template<typename VecType>
-inline double HRectBound<MetricType>::MinDistance(
+inline ElemType HRectBound<MetricType, ElemType>::MinDistance(
const VecType& point,
- typename boost::enable_if<IsVector<VecType> >* /* junk */) const
+ typename boost::enable_if<IsVector<VecType>>* /* junk */) const
{
Log::Assert(point.n_elem == dim);
- double sum = 0;
+ ElemType sum = 0;
- double lower, higher;
+ ElemType lower, higher;
for (size_t d = 0; d < dim; d++)
{
lower = bounds[d].Lo() - point[d];
@@ -168,7 +171,7 @@ inline double HRectBound<MetricType>::MinDistance(
// absolute value to itself and then sum those two, our result is the
// nonnegative half of the equation times two; then we raise to power Power.
sum += pow((lower + fabs(lower)) + (higher + fabs(higher)),
- (double) MetricType::Power);
+ (ElemType) MetricType::Power);
}
// Now take the Power'th root (but make sure our result is squared if it needs
@@ -176,7 +179,7 @@ inline double HRectBound<MetricType>::MinDistance(
// that was introduced earlier. The compiler should optimize out the if
// statement entirely.
if (MetricType::TakeRoot)
- return pow(sum, 1.0 / (double) MetricType::Power) / 2.0;
+ return (ElemType) pow((double) sum, 1.0 / (double) MetricType::Power) / 2.0;
else
return sum / pow(2.0, MetricType::Power);
}
@@ -184,16 +187,17 @@ inline double HRectBound<MetricType>::MinDistance(
/**
* Calculates minimum bound-to-bound squared distance.
*/
-template<typename MetricType>
-double HRectBound<MetricType>::MinDistance(const HRectBound& other) const
+template<typename MetricType, typename ElemType>
+ElemType HRectBound<MetricType, ElemType>::MinDistance(const HRectBound& other)
+ const
{
Log::Assert(dim == other.dim);
- double sum = 0;
- const math::Range* mbound = bounds;
- const math::Range* obound = other.bounds;
+ ElemType sum = 0;
+ const math::RangeType<ElemType>* mbound = bounds;
+ const math::RangeType<ElemType>* obound = other.bounds;
- double lower, higher;
+ ElemType lower, higher;
for (size_t d = 0; d < dim; d++)
{
lower = obound->Lo() - mbound->Hi();
@@ -202,7 +206,7 @@ double HRectBound<MetricType>::MinDistance(const HRectBound& other) const
// x + fabs(x) = max(x * 2, 0)
// (x * 2)^2 / 4 = x^2
sum += pow((lower + fabs(lower)) + (higher + fabs(higher)),
- (double) MetricType::Power);
+ (ElemType) MetricType::Power);
// Move bound pointers.
mbound++;
@@ -211,7 +215,7 @@ double HRectBound<MetricType>::MinDistance(const HRectBound& other) const
// The compiler should optimize out this if statement entirely.
if (MetricType::TakeRoot)
- return pow(sum, 1.0 / (double) MetricType::Power) / 2.0;
+ return (ElemType) pow((double) sum, 1.0 / (double) MetricType::Power) / 2.0;
else
return sum / pow(2.0, MetricType::Power);
}
@@ -219,26 +223,26 @@ double HRectBound<MetricType>::MinDistance(const HRectBound& other) const
/**
* Calculates maximum bound-to-point squared distance.
*/
-template<typename MetricType>
+template<typename MetricType, typename ElemType>
template<typename VecType>
-inline double HRectBound<MetricType>::MaxDistance(
+inline ElemType HRectBound<MetricType, ElemType>::MaxDistance(
const VecType& point,
typename boost::enable_if<IsVector<VecType> >* /* junk */) const
{
- double sum = 0;
+ ElemType sum = 0;
Log::Assert(point.n_elem == dim);
for (size_t d = 0; d < dim; d++)
{
- double v = std::max(fabs(point[d] - bounds[d].Lo()),
+ ElemType v = std::max(fabs(point[d] - bounds[d].Lo()),
fabs(bounds[d].Hi() - point[d]));
- sum += pow(v, (double) MetricType::Power);
+ sum += pow(v, (ElemType) MetricType::Power);
}
// The compiler should optimize out this if statement entirely.
if (MetricType::TakeRoot)
- return pow(sum, 1.0 / (double) MetricType::Power);
+ return (ElemType) pow((double) sum, 1.0 / (double) MetricType::Power);
else
return sum;
}
@@ -246,25 +250,26 @@ inline double HRectBound<MetricType>::MaxDistance(
/**
* Computes maximum distance.
*/
-template<typename MetricType>
-inline double HRectBound<MetricType>::MaxDistance(const HRectBound& other)
+template<typename MetricType, typename ElemType>
+inline ElemType HRectBound<MetricType, ElemType>::MaxDistance(
+ const HRectBound& other)
const
{
- double sum = 0;
+ ElemType sum = 0;
Log::Assert(dim == other.dim);
- double v;
+ ElemType v;
for (size_t d = 0; d < dim; d++)
{
v = std::max(fabs(other.bounds[d].Hi() - bounds[d].Lo()),
fabs(bounds[d].Hi() - other.bounds[d].Lo()));
- sum += pow(v, (double) MetricType::Power); // v is non-negative.
+ sum += pow(v, (ElemType) MetricType::Power); // v is non-negative.
}
// The compiler should optimize out this if statement entirely.
if (MetricType::TakeRoot)
- return pow(sum, 1.0 / (double) MetricType::Power);
+ return (ElemType) pow((double) sum, 1.0 / (double) MetricType::Power);
else
return sum;
}
@@ -272,16 +277,17 @@ inline double HRectBound<MetricType>::MaxDistance(const HRectBound& other)
/**
* Calculates minimum and maximum bound-to-bound squared distance.
*/
-template<typename MetricType>
-inline math::Range HRectBound<MetricType>::RangeDistance(
+template<typename MetricType, typename ElemType>
+inline math::RangeType<ElemType>
+HRectBound<MetricType, ElemType>::RangeDistance(
const HRectBound& other) const
{
- double loSum = 0;
- double hiSum = 0;
+ ElemType loSum = 0;
+ ElemType hiSum = 0;
Log::Assert(dim == other.dim);
- double v1, v2, vLo, vHi;
+ ElemType v1, v2, vLo, vHi;
for (size_t d = 0; d < dim; d++)
{
v1 = other.bounds[d].Lo() - bounds[d].Hi();
@@ -298,32 +304,34 @@ inline math::Range HRectBound<MetricType>::RangeDistance(
vLo = (v2 > 0) ? v2 : 0; // Force to be 0 if negative.
}
- loSum += pow(vLo, (double) MetricType::Power);
- hiSum += pow(vHi, (double) MetricType::Power);
+ loSum += pow(vLo, (ElemType) MetricType::Power);
+ hiSum += pow(vHi, (ElemType) MetricType::Power);
}
if (MetricType::TakeRoot)
- return math::Range(pow(loSum, 1.0 / (double) MetricType::Power),
- pow(hiSum, 1.0 / (double) MetricType::Power));
+ return math::RangeType<ElemType>(
+ (ElemType) pow((double) loSum, 1.0 / (double) MetricType::Power),
+ (ElemType) pow((double) hiSum, 1.0 / (double) MetricType::Power));
else
- return math::Range(loSum, hiSum);
+ return math::RangeType<ElemType>(loSum, hiSum);
}
/**
* Calculates minimum and maximum bound-to-point squared distance.
*/
-template<typename MetricType>
+template<typename MetricType, typename ElemType>
template<typename VecType>
-inline math::Range HRectBound<MetricType>::RangeDistance(
+inline math::RangeType<ElemType>
+HRectBound<MetricType, ElemType>::RangeDistance(
const VecType& point,
- typename boost::enable_if<IsVector<VecType> >* /* junk */) const
+ typename boost::enable_if<IsVector<VecType>>* /* junk */) const
{
- double loSum = 0;
- double hiSum = 0;
+ ElemType loSum = 0;
+ ElemType hiSum = 0;
Log::Assert(point.n_elem == dim);
- double v1, v2, vLo, vHi;
+ ElemType v1, v2, vLo, vHi;
for (size_t d = 0; d < dim; d++)
{
v1 = bounds[d].Lo() - point[d]; // Negative if point[d] > lo.
@@ -348,35 +356,36 @@ inline math::Range HRectBound<MetricType>::RangeDistance(
}
}
- loSum += pow(vLo, (double) MetricType::Power);
- hiSum += pow(vHi, (double) MetricType::Power);
+ loSum += pow(vLo, (ElemType) MetricType::Power);
+ hiSum += pow(vHi, (ElemType) MetricType::Power);
}
if (MetricType::TakeRoot)
- return math::Range(pow(loSum, 1.0 / (double) MetricType::Power),
- pow(hiSum, 1.0 / (double) MetricType::Power));
+ return math::RangeType<ElemType>(
+ (ElemType) pow((double) loSum, 1.0 / (double) MetricType::Power),
+ (ElemType) pow((double) hiSum, 1.0 / (double) MetricType::Power));
else
- return math::Range(loSum, hiSum);
+ return math::RangeType<ElemType>(loSum, hiSum);
}
/**
* Expands this region to include a new point.
*/
-template<typename MetricType>
+template<typename MetricType, typename ElemType>
template<typename MatType>
-inline HRectBound<MetricType>& HRectBound<MetricType>::operator|=(
+inline HRectBound<MetricType, ElemType>& HRectBound<MetricType, ElemType>::operator|=(
const MatType& data)
{
Log::Assert(data.n_rows == dim);
- arma::vec mins(min(data, 1));
- arma::vec maxs(max(data, 1));
+ arma::Col<ElemType> mins(min(data, 1));
+ arma::Col<ElemType> maxs(max(data, 1));
- minWidth = DBL_MAX;
+ minWidth = std::numeric_limits<ElemType>::max();
for (size_t i = 0; i < dim; i++)
{
- bounds[i] |= math::Range(mins[i], maxs[i]);
- const double width = bounds[i].Width();
+ bounds[i] |= math::RangeType<ElemType>(mins[i], maxs[i]);
+ const ElemType width = bounds[i].Width();
if (width < minWidth)
minWidth = width;
}
@@ -387,17 +396,17 @@ inline HRectBound<MetricType>& HRectBound<MetricType>::operator|=(
/**
* Expands this region to encompass another bound.
*/
-template<typename MetricType>
-inline HRectBound<MetricType>& HRectBound<MetricType>::operator|=(
+template<typename MetricType, typename ElemType>
+inline HRectBound<MetricType, ElemType>& HRectBound<MetricType, ElemType>::operator|=(
const HRectBound& other)
{
assert(other.dim == dim);
- minWidth = DBL_MAX;
+ minWidth = std::numeric_limits<ElemType>::max();
for (size_t i = 0; i < dim; i++)
{
bounds[i] |= other.bounds[i];
- const double width = bounds[i].Width();
+ const ElemType width = bounds[i].Width();
if (width < minWidth)
minWidth = width;
}
@@ -408,9 +417,9 @@ inline HRectBound<MetricType>& HRectBound<MetricType>::operator|=(
/**
* Determines if a point is within this bound.
*/
-template<typename MetricType>
+template<typename MetricType, typename ElemType>
template<typename VecType>
-inline bool HRectBound<MetricType>::Contains(const VecType& point) const
+inline bool HRectBound<MetricType, ElemType>::Contains(const VecType& point) const
{
for (size_t i = 0; i < point.n_elem; i++)
{
@@ -424,25 +433,25 @@ inline bool HRectBound<MetricType>::Contains(const VecType& point) const
/**
* Returns the diameter of the hyperrectangle (that is, the longest diagonal).
*/
-template<typename MetricType>
-inline double HRectBound<MetricType>::Diameter() const
+template<typename MetricType, typename ElemType>
+inline ElemType HRectBound<MetricType, ElemType>::Diameter() const
{
- double d = 0;
+ ElemType d = 0;
for (size_t i = 0; i < dim; ++i)
d += std::pow(bounds[i].Hi() - bounds[i].Lo(),
- (double) MetricType::Power);
+ (ElemType) MetricType::Power);
if (MetricType::TakeRoot)
- return std::pow(d, 1.0 / (double) MetricType::Power);
+ return (ElemType) std::pow((double) d, 1.0 / (double) MetricType::Power);
else
return d;
}
//! Serialize the bound object.
-template<typename MetricType>
+template<typename MetricType, typename ElemType>
template<typename Archive>
-void HRectBound<MetricType>::Serialize(Archive& ar,
- const unsigned int /* version */)
+void HRectBound<MetricType, ElemType>::Serialize(Archive& ar,
+ const unsigned int /* version */)
{
ar & data::CreateNVP(dim, "dim");
@@ -451,7 +460,7 @@ void HRectBound<MetricType>::Serialize(Archive& ar,
{
if (bounds)
delete[] bounds;
- bounds = new math::Range[dim];
+ bounds = new math::RangeType<ElemType>[dim];
}
ar & data::CreateArrayNVP(bounds, dim, "bounds");
More information about the mlpack-git mailing list