[mlpack-git] master: Refactor to take a MetricType. (47cc6fe)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Jul 29 16:41:44 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/f8ceffae0613b350f4d6bdd46c6c8633a40b4897...6ee21879488fe98612a4619b17f8b51e8da5215b
>---------------------------------------------------------------
commit 47cc6fe6f072e72ef28604e203e9873c2e1038e8
Author: ryan <ryan at ratml.org>
Date: Wed Jul 22 22:07:47 2015 -0400
Refactor to take a MetricType.
>---------------------------------------------------------------
47cc6fe6f072e72ef28604e203e9873c2e1038e8
src/mlpack/core/tree/hrectbound.hpp | 40 ++++++---
src/mlpack/core/tree/hrectbound_impl.hpp | 142 ++++++++++++++++---------------
2 files changed, 100 insertions(+), 82 deletions(-)
diff --git a/src/mlpack/core/tree/hrectbound.hpp b/src/mlpack/core/tree/hrectbound.hpp
index ea525f3..13bf1cc 100644
--- a/src/mlpack/core/tree/hrectbound.hpp
+++ b/src/mlpack/core/tree/hrectbound.hpp
@@ -17,6 +17,26 @@
namespace mlpack {
namespace bound {
+namespace meta /** Metaprogramming utilities. */ {
+
+//! Utility struct where Value is true if and only if the argument is of type
+//! LMetric.
+template<typename MetricType>
+struct IsLMetric
+{
+ static const bool Value = false;
+};
+
+//! Specialization for IsLMetric when the argument is of type LMetric.
+//! This is a partial specialization, so it takes a single template<...> list.
+template<int Power, bool TakeRoot>
+struct IsLMetric<metric::LMetric<Power, TakeRoot>>
+{
+ static const bool Value = true;
+};
+
+} // namespace meta
+
/**
* Hyper-rectangle bound for an L-metric. This should be used in conjunction
* with the LMetric class. Be sure to use the same template parameters for
@@ -26,13 +46,14 @@ namespace bound {
* @tparam TakeRoot Whether or not the root should be taken (see LMetric
* documentation).
*/
-template<int Power = 2, bool TakeRoot = true>
+template<typename MetricType = metric::LMetric<2, true>>
class HRectBound
{
- public:
- //! This is the metric type that this bound is using.
- typedef metric::LMetric<Power, TakeRoot> MetricType;
+ // It is required that HRectBound have an LMetric as the given MetricType.
+ static_assert(meta::IsLMetric<MetricType>::Value == true,
+ "HRectBound can only be used with the LMetric<> metric type.");
+ public:
/**
* Empty constructor; creates a bound of dimensionality 0.
*/
@@ -174,13 +195,6 @@ class HRectBound
*/
std::string ToString() const;
- /**
- * Return the metric associated with this bound. Because it is an LMetric, it
- * cannot store state, so we can make it on the fly. It is also static
- * because the metric is only dependent on the template arguments.
- */
- static MetricType Metric() { return metric::LMetric<Power, TakeRoot>(); }
-
private:
//! The dimensionality of the bound.
size_t dim;
@@ -191,8 +205,8 @@ class HRectBound
};
// A specialization of BoundTraits for this class.
-template<int Power, bool TakeRoot>
-struct BoundTraits<HRectBound<Power, TakeRoot>>
+template<typename MetricType>
+struct BoundTraits<HRectBound<MetricType>>
{
//! These bounds are always tight for each dimension.
const static bool HasTightBounds = true;
diff --git a/src/mlpack/core/tree/hrectbound_impl.hpp b/src/mlpack/core/tree/hrectbound_impl.hpp
index 43a7982..0b9f76c 100644
--- a/src/mlpack/core/tree/hrectbound_impl.hpp
+++ b/src/mlpack/core/tree/hrectbound_impl.hpp
@@ -20,8 +20,8 @@ namespace bound {
/**
* Empty constructor.
*/
-template<int Power, bool TakeRoot>
-inline HRectBound<Power, TakeRoot>::HRectBound() :
+template<typename MetricType>
+inline HRectBound<MetricType>::HRectBound() :
dim(0),
bounds(NULL),
minWidth(0)
@@ -31,8 +31,8 @@ inline HRectBound<Power, TakeRoot>::HRectBound() :
* Initializes to specified dimensionality with each dimension the empty
* set.
*/
-template<int Power, bool TakeRoot>
-inline HRectBound<Power, TakeRoot>::HRectBound(const size_t dimension) :
+template<typename MetricType>
+inline HRectBound<MetricType>::HRectBound(const size_t dimension) :
dim(dimension),
bounds(new math::Range[dim]),
minWidth(0)
@@ -41,8 +41,8 @@ inline HRectBound<Power, TakeRoot>::HRectBound(const size_t dimension) :
/***
* Copy constructor necessary to prevent memory leaks.
*/
-template<int Power, bool TakeRoot>
-inline HRectBound<Power, TakeRoot>::HRectBound(const HRectBound& other) :
+template<typename MetricType>
+inline HRectBound<MetricType>::HRectBound(const HRectBound& other) :
dim(other.Dim()),
bounds(new math::Range[dim]),
minWidth(other.MinWidth())
@@ -55,8 +55,8 @@ inline HRectBound<Power, TakeRoot>::HRectBound(const HRectBound& other) :
/***
* Same as the copy constructor.
*/
-template<int Power, bool TakeRoot>
-inline HRectBound<Power, TakeRoot>& HRectBound<Power, TakeRoot>::operator=(
+template<typename MetricType>
+inline HRectBound<MetricType>& HRectBound<MetricType>::operator=(
const HRectBound& other)
{
if (dim != other.Dim())
@@ -81,8 +81,8 @@ inline HRectBound<Power, TakeRoot>& HRectBound<Power, TakeRoot>::operator=(
/**
* Destructor: clean up memory.
*/
-template<int Power, bool TakeRoot>
-inline HRectBound<Power, TakeRoot>::~HRectBound()
+template<typename MetricType>
+inline HRectBound<MetricType>::~HRectBound()
{
if (bounds)
delete[] bounds;
@@ -91,8 +91,8 @@ inline HRectBound<Power, TakeRoot>::~HRectBound()
/**
* Resets all dimensions to the empty set.
*/
-template<int Power, bool TakeRoot>
-inline void HRectBound<Power, TakeRoot>::Clear()
+template<typename MetricType>
+inline void HRectBound<MetricType>::Clear()
{
for (size_t i = 0; i < dim; i++)
bounds[i] = math::Range();
@@ -104,8 +104,8 @@ inline void HRectBound<Power, TakeRoot>::Clear()
*
* @param centroid Vector which the centroid will be written to.
*/
-template<int Power, bool TakeRoot>
-inline void HRectBound<Power, TakeRoot>::Center(arma::vec& center) const
+template<typename MetricType>
+inline void HRectBound<MetricType>::Center(arma::vec& center) const
{
// Set size correctly if necessary.
if (!(center.n_elem == dim))
@@ -120,8 +120,8 @@ inline void HRectBound<Power, TakeRoot>::Center(arma::vec& center) const
*
* @return Volume of the hyperrectangle.
*/
-template<int Power, bool TakeRoot>
-inline double HRectBound<Power, TakeRoot>::Volume() const
+template<typename MetricType>
+inline double HRectBound<MetricType>::Volume() const
{
double volume = 1.0;
for (size_t i = 0; i < dim; ++i)
@@ -133,9 +133,9 @@ inline double HRectBound<Power, TakeRoot>::Volume() const
/**
* Calculates minimum bound-to-point squared distance.
*/
-template<int Power, bool TakeRoot>
+template<typename MetricType>
template<typename VecType>
-inline double HRectBound<Power, TakeRoot>::MinDistance(
+inline double HRectBound<MetricType>::MinDistance(
const VecType& point,
typename boost::enable_if<IsVector<VecType> >* /* junk */) const
{
@@ -152,24 +152,25 @@ inline double HRectBound<Power, TakeRoot>::MinDistance(
// Since only one of 'lower' or 'higher' is negative, if we add each's
// absolute value to itself and then sum those two, our result is the
// nonnegative half of the equation times two; then we raise to power Power.
- sum += pow((lower + fabs(lower)) + (higher + fabs(higher)), (double) Power);
+ sum += pow((lower + fabs(lower)) + (higher + fabs(higher)),
+ (double) MetricType::Power);
}
// Now take the Power'th root (but make sure our result is squared if it needs
// to be); then cancel out the constant of 2 (which may have been squared now)
// that was introduced earlier. The compiler should optimize out the if
// statement entirely.
- if (TakeRoot)
- return pow(sum, 1.0 / (double) Power) / 2.0;
+ if (MetricType::TakeRoot)
+ return pow(sum, 1.0 / (double) MetricType::Power) / 2.0;
else
- return sum / pow(2.0, Power);
+ return sum / pow(2.0, MetricType::Power);
}
/**
* Calculates minimum bound-to-bound squared distance.
*/
-template<int Power, bool TakeRoot>
-double HRectBound<Power, TakeRoot>::MinDistance(const HRectBound& other) const
+template<typename MetricType>
+double HRectBound<MetricType>::MinDistance(const HRectBound& other) const
{
Log::Assert(dim == other.dim);
@@ -185,7 +186,8 @@ double HRectBound<Power, TakeRoot>::MinDistance(const HRectBound& other) const
// We invoke the following:
// x + fabs(x) = max(x * 2, 0)
// (x * 2)^2 / 4 = x^2
- sum += pow((lower + fabs(lower)) + (higher + fabs(higher)), (double) Power);
+ sum += pow((lower + fabs(lower)) + (higher + fabs(higher)),
+ (double) MetricType::Power);
// Move bound pointers.
mbound++;
@@ -193,18 +195,18 @@ double HRectBound<Power, TakeRoot>::MinDistance(const HRectBound& other) const
}
// The compiler should optimize out this if statement entirely.
- if (TakeRoot)
- return pow(sum, 1.0 / (double) Power) / 2.0;
+ if (MetricType::TakeRoot)
+ return pow(sum, 1.0 / (double) MetricType::Power) / 2.0;
else
- return sum / pow(2.0, Power);
+ return sum / pow(2.0, MetricType::Power);
}
/**
* Calculates maximum bound-to-point squared distance.
*/
-template<int Power, bool TakeRoot>
+template<typename MetricType>
template<typename VecType>
-inline double HRectBound<Power, TakeRoot>::MaxDistance(
+inline double HRectBound<MetricType>::MaxDistance(
const VecType& point,
typename boost::enable_if<IsVector<VecType> >* /* junk */) const
{
@@ -216,12 +218,12 @@ inline double HRectBound<Power, TakeRoot>::MaxDistance(
{
double v = std::max(fabs(point[d] - bounds[d].Lo()),
fabs(bounds[d].Hi() - point[d]));
- sum += pow(v, (double) Power);
+ sum += pow(v, (double) MetricType::Power);
}
// The compiler should optimize out this if statement entirely.
- if (TakeRoot)
- return pow(sum, 1.0 / (double) Power);
+ if (MetricType::TakeRoot)
+ return pow(sum, 1.0 / (double) MetricType::Power);
else
return sum;
}
@@ -229,8 +231,8 @@ inline double HRectBound<Power, TakeRoot>::MaxDistance(
/**
* Computes maximum distance.
*/
-template<int Power, bool TakeRoot>
-inline double HRectBound<Power, TakeRoot>::MaxDistance(const HRectBound& other)
+template<typename MetricType>
+inline double HRectBound<MetricType>::MaxDistance(const HRectBound& other)
const
{
double sum = 0;
@@ -242,12 +244,12 @@ inline double HRectBound<Power, TakeRoot>::MaxDistance(const HRectBound& other)
{
v = std::max(fabs(other.bounds[d].Hi() - bounds[d].Lo()),
fabs(bounds[d].Hi() - other.bounds[d].Lo()));
- sum += pow(v, (double) Power); // v is non-negative.
+ sum += pow(v, (double) MetricType::Power); // v is non-negative.
}
// The compiler should optimize out this if statement entirely.
- if (TakeRoot)
- return pow(sum, 1.0 / (double) Power);
+ if (MetricType::TakeRoot)
+ return pow(sum, 1.0 / (double) MetricType::Power);
else
return sum;
}
@@ -255,8 +257,8 @@ inline double HRectBound<Power, TakeRoot>::MaxDistance(const HRectBound& other)
/**
* Calculates minimum and maximum bound-to-bound squared distance.
*/
-template<int Power, bool TakeRoot>
-inline math::Range HRectBound<Power, TakeRoot>::RangeDistance(
+template<typename MetricType>
+inline math::Range HRectBound<MetricType>::RangeDistance(
const HRectBound& other) const
{
double loSum = 0;
@@ -281,13 +283,13 @@ inline math::Range HRectBound<Power, TakeRoot>::RangeDistance(
vLo = (v2 > 0) ? v2 : 0; // Force to be 0 if negative.
}
- loSum += pow(vLo, (double) Power);
- hiSum += pow(vHi, (double) Power);
+ loSum += pow(vLo, (double) MetricType::Power);
+ hiSum += pow(vHi, (double) MetricType::Power);
}
- if (TakeRoot)
- return math::Range(pow(loSum, 1.0 / (double) Power),
- pow(hiSum, 1.0 / (double) Power));
+ if (MetricType::TakeRoot)
+ return math::Range(pow(loSum, 1.0 / (double) MetricType::Power),
+ pow(hiSum, 1.0 / (double) MetricType::Power));
else
return math::Range(loSum, hiSum);
}
@@ -295,9 +297,9 @@ inline math::Range HRectBound<Power, TakeRoot>::RangeDistance(
/**
* Calculates minimum and maximum bound-to-point squared distance.
*/
-template<int Power, bool TakeRoot>
+template<typename MetricType>
template<typename VecType>
-inline math::Range HRectBound<Power, TakeRoot>::RangeDistance(
+inline math::Range HRectBound<MetricType>::RangeDistance(
const VecType& point,
typename boost::enable_if<IsVector<VecType> >* /* junk */) const
{
@@ -331,13 +333,13 @@ inline math::Range HRectBound<Power, TakeRoot>::RangeDistance(
}
}
- loSum += pow(vLo, (double) Power);
- hiSum += pow(vHi, (double) Power);
+ loSum += pow(vLo, (double) MetricType::Power);
+ hiSum += pow(vHi, (double) MetricType::Power);
}
- if (TakeRoot)
- return math::Range(pow(loSum, 1.0 / (double) Power),
- pow(hiSum, 1.0 / (double) Power));
+ if (MetricType::TakeRoot)
+ return math::Range(pow(loSum, 1.0 / (double) MetricType::Power),
+ pow(hiSum, 1.0 / (double) MetricType::Power));
else
return math::Range(loSum, hiSum);
}
@@ -345,9 +347,9 @@ inline math::Range HRectBound<Power, TakeRoot>::RangeDistance(
/**
* Expands this region to include a new point.
*/
-template<int Power, bool TakeRoot>
+template<typename MetricType>
template<typename MatType>
-inline HRectBound<Power, TakeRoot>& HRectBound<Power, TakeRoot>::operator|=(
+inline HRectBound<MetricType>& HRectBound<MetricType>::operator|=(
const MatType& data)
{
Log::Assert(data.n_rows == dim);
@@ -370,8 +372,8 @@ inline HRectBound<Power, TakeRoot>& HRectBound<Power, TakeRoot>::operator|=(
/**
* Expands this region to encompass another bound.
*/
-template<int Power, bool TakeRoot>
-inline HRectBound<Power, TakeRoot>& HRectBound<Power, TakeRoot>::operator|=(
+template<typename MetricType>
+inline HRectBound<MetricType>& HRectBound<MetricType>::operator|=(
const HRectBound& other)
{
assert(other.dim == dim);
@@ -391,9 +393,9 @@ inline HRectBound<Power, TakeRoot>& HRectBound<Power, TakeRoot>::operator|=(
/**
* Determines if a point is within this bound.
*/
-template<int Power, bool TakeRoot>
+template<typename MetricType>
template<typename VecType>
-inline bool HRectBound<Power, TakeRoot>::Contains(const VecType& point) const
+inline bool HRectBound<MetricType>::Contains(const VecType& point) const
{
for (size_t i = 0; i < point.n_elem; i++)
{
@@ -407,23 +409,24 @@ inline bool HRectBound<Power, TakeRoot>::Contains(const VecType& point) const
/**
* Returns the diameter of the hyperrectangle (that is, the longest diagonal).
*/
-template<int Power, bool TakeRoot>
-inline double HRectBound<Power, TakeRoot>::Diameter() const
+template<typename MetricType>
+inline double HRectBound<MetricType>::Diameter() const
{
double d = 0;
for (size_t i = 0; i < dim; ++i)
- d += std::pow(bounds[i].Hi() - bounds[i].Lo(), (double) Power);
+ d += std::pow(bounds[i].Hi() - bounds[i].Lo(),
+ (double) MetricType::Power);
- if (TakeRoot)
- return std::pow(d, 1.0 / (double) Power);
+ if (MetricType::TakeRoot)
+ return std::pow(d, 1.0 / (double) MetricType::Power);
else
return d;
}
//! Serialize the bound object.
-template<int Power, bool TakeRoot>
+template<typename MetricType>
template<typename Archive>
-void HRectBound<Power, TakeRoot>::Serialize(Archive& ar,
+void HRectBound<MetricType>::Serialize(Archive& ar,
const unsigned int /* version */)
{
ar & data::CreateNVP(dim, "dim");
@@ -443,13 +446,14 @@ void HRectBound<Power, TakeRoot>::Serialize(Archive& ar,
/**
* Returns a string representation of this object.
*/
-template<int Power, bool TakeRoot>
-std::string HRectBound<Power, TakeRoot>::ToString() const
+template<typename MetricType>
+std::string HRectBound<MetricType>::ToString() const
{
std::ostringstream convert;
convert << "HRectBound [" << this << "]" << std::endl;
- convert << " Power: " << Power << std::endl;
- convert << " TakeRoot: " << (TakeRoot ? "true" : "false") << std::endl;
+ convert << " Power: " << MetricType::Power << std::endl;
+ convert << " TakeRoot: " << (MetricType::TakeRoot ? "true" : "false")
+ << std::endl;
convert << " Dimensionality: " << dim << std::endl;
convert << " Bounds: " << std::endl;
for (size_t i = 0; i < dim; ++i)
More information about the mlpack-git
mailing list