[mlpack-git] master: HRectBound improvements. (55e1579)
gitdub at mlpack.org
gitdub at mlpack.org
Fri Aug 12 14:10:00 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/8a0ad2fbe4db5614a6aad27cfb5e101ae8b1db96...dc6bae4e8634486b384b67e3ae7a690f34bdc677
>---------------------------------------------------------------
commit 55e15792e4dc2135c6e3421c5dae1322d07b0dcc
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date: Fri Aug 12 21:10:00 2016 +0300
HRectBound improvements.
>---------------------------------------------------------------
55e15792e4dc2135c6e3421c5dae1322d07b0dcc
src/mlpack/core/tree/hrectbound_impl.hpp | 168 ++++++++++++++++++++++++++-----
1 file changed, 144 insertions(+), 24 deletions(-)
diff --git a/src/mlpack/core/tree/hrectbound_impl.hpp b/src/mlpack/core/tree/hrectbound_impl.hpp
index 60aa2b8..196067f 100644
--- a/src/mlpack/core/tree/hrectbound_impl.hpp
+++ b/src/mlpack/core/tree/hrectbound_impl.hpp
@@ -173,18 +173,40 @@ inline ElemType HRectBound<MetricType, ElemType>::MinDistance(
// Since only one of 'lower' or 'higher' is negative, if we add each's
// absolute value to itself and then sum those two, our result is the
// nonnegative half of the equation times two; then we raise to power Power.
- sum += pow((lower + fabs(lower)) + (higher + fabs(higher)),
- (ElemType) MetricType::Power);
+ if (MetricType::Power == 1)
+ sum += (lower + std::fabs(lower)) + (higher + std::fabs(higher));
+ else if (MetricType::Power == 2)
+ {
+ ElemType dist = (lower + std::fabs(lower)) + (higher + std::fabs(higher));
+ sum += dist * dist;
+ }
+ else
+ {
+ sum += pow((lower + fabs(lower)) + (higher + fabs(higher)),
+ (ElemType) MetricType::Power);
+ }
}
// Now take the Power'th root (but make sure our result is squared if it needs
// to be); then cancel out the constant of 2 (which may have been squared now)
// that was introduced earlier. The compiler should optimize out the if
// statement entirely.
- if (MetricType::TakeRoot)
- return (ElemType) pow((double) sum, 1.0 / (double) MetricType::Power) / 2.0;
+ if (MetricType::Power == 1)
+ return sum * 0.5;
+ else if (MetricType::Power == 2)
+ {
+ if (MetricType::TakeRoot)
+ return std::sqrt(sum) * 0.5;
+ else
+ return sum * 0.25;
+ }
else
- return sum / pow(2.0, MetricType::Power);
+ {
+ if (MetricType::TakeRoot)
+ return (ElemType) pow((double) sum, 1.0 / (double) MetricType::Power) / 2.0;
+ else
+ return sum / pow(2.0, MetricType::Power);
+ }
}
/**
@@ -208,8 +230,20 @@ ElemType HRectBound<MetricType, ElemType>::MinDistance(const HRectBound& other)
// We invoke the following:
// x + fabs(x) = max(x * 2, 0)
// (x * 2)^2 / 4 = x^2
- sum += pow((lower + fabs(lower)) + (higher + fabs(higher)),
- (ElemType) MetricType::Power);
+
+ // The compiler should optimize out this if statement entirely.
+ if (MetricType::Power == 1)
+ sum += (lower + std::fabs(lower)) + (higher + std::fabs(higher));
+ else if (MetricType::Power == 2)
+ {
+ ElemType dist = (lower + std::fabs(lower)) + (higher + std::fabs(higher));
+ sum += dist * dist;
+ }
+ else
+ {
+ sum += pow((lower + fabs(lower)) + (higher + fabs(higher)),
+ (ElemType) MetricType::Power);
+ }
// Move bound pointers.
mbound++;
@@ -217,10 +251,22 @@ ElemType HRectBound<MetricType, ElemType>::MinDistance(const HRectBound& other)
}
// The compiler should optimize out this if statement entirely.
- if (MetricType::TakeRoot)
- return (ElemType) pow((double) sum, 1.0 / (double) MetricType::Power) / 2.0;
+ if (MetricType::Power == 1)
+ return sum * 0.5;
+ else if (MetricType::Power == 2)
+ {
+ if (MetricType::TakeRoot)
+ return std::sqrt(sum) * 0.5;
+ else
+ return sum * 0.25;
+ }
else
- return sum / pow(2.0, MetricType::Power);
+ {
+ if (MetricType::TakeRoot)
+ return (ElemType) pow((double) sum, 1.0 / (double) MetricType::Power) / 2.0;
+ else
+ return sum / pow(2.0, MetricType::Power);
+ }
}
/**
@@ -240,12 +286,26 @@ inline ElemType HRectBound<MetricType, ElemType>::MaxDistance(
{
ElemType v = std::max(fabs(point[d] - bounds[d].Lo()),
fabs(bounds[d].Hi() - point[d]));
- sum += pow(v, (ElemType) MetricType::Power);
+
+ // The compiler should optimize out this if statement entirely.
+ if (MetricType::Power == 1)
+ sum += v; // v is non-negative.
+ else if (MetricType::Power == 2)
+ sum += v * v;
+ else
+ sum += std::pow(v, (ElemType) MetricType::Power);
}
// The compiler should optimize out this if statement entirely.
if (MetricType::TakeRoot)
- return (ElemType) pow((double) sum, 1.0 / (double) MetricType::Power);
+ {
+ if (MetricType::Power == 1)
+ return sum;
+ else if (MetricType::Power == 2)
+ return std::sqrt(sum);
+ else
+ return (ElemType) pow((double) sum, 1.0 / (double) MetricType::Power);
+ }
else
return sum;
}
@@ -267,12 +327,26 @@ inline ElemType HRectBound<MetricType, ElemType>::MaxDistance(
{
v = std::max(fabs(other.bounds[d].Hi() - bounds[d].Lo()),
fabs(bounds[d].Hi() - other.bounds[d].Lo()));
- sum += pow(v, (ElemType) MetricType::Power); // v is non-negative.
+
+ // The compiler should optimize out this if statement entirely.
+ if (MetricType::Power == 1)
+ sum += v; // v is non-negative.
+ else if (MetricType::Power == 2)
+ sum += v * v;
+ else
+ sum += std::pow(v, (ElemType) MetricType::Power);
}
// The compiler should optimize out this if statement entirely.
if (MetricType::TakeRoot)
- return (ElemType) pow((double) sum, 1.0 / (double) MetricType::Power);
+ {
+ if (MetricType::Power == 1)
+ return sum;
+ else if (MetricType::Power == 2)
+ return std::sqrt(sum);
+ else
+ return (ElemType) pow((double) sum, 1.0 / (double) MetricType::Power);
+ }
else
return sum;
}
@@ -307,14 +381,37 @@ HRectBound<MetricType, ElemType>::RangeDistance(
vLo = (v2 > 0) ? v2 : 0; // Force to be 0 if negative.
}
- loSum += pow(vLo, (ElemType) MetricType::Power);
- hiSum += pow(vHi, (ElemType) MetricType::Power);
+ // The compiler should optimize out this if statement entirely.
+ if (MetricType::Power == 1)
+ {
+ loSum += vLo; // vLo is non-negative.
+ hiSum += vHi; // vHi is non-negative.
+ }
+ else if (MetricType::Power == 2)
+ {
+ loSum += vLo * vLo;
+ hiSum += vHi * vHi;
+ }
+ else
+ {
+ loSum += std::pow(vLo, (ElemType) MetricType::Power);
+ hiSum += std::pow(vHi, (ElemType) MetricType::Power);
+ }
}
if (MetricType::TakeRoot)
- return math::RangeType<ElemType>(
- (ElemType) pow((double) loSum, 1.0 / (double) MetricType::Power),
- (ElemType) pow((double) hiSum, 1.0 / (double) MetricType::Power));
+ {
+ if (MetricType::Power == 1)
+ return math::RangeType<ElemType>(loSum, hiSum);
+ else if (MetricType::Power == 2)
+ return math::RangeType<ElemType>(std::sqrt(loSum), std::sqrt(hiSum));
+ else
+ {
+ return math::RangeType<ElemType>(
+ (ElemType) pow((double) loSum, 1.0 / (double) MetricType::Power),
+ (ElemType) pow((double) hiSum, 1.0 / (double) MetricType::Power));
+ }
+ }
else
return math::RangeType<ElemType>(loSum, hiSum);
}
@@ -359,14 +456,37 @@ HRectBound<MetricType, ElemType>::RangeDistance(
}
}
- loSum += pow(vLo, (ElemType) MetricType::Power);
- hiSum += pow(vHi, (ElemType) MetricType::Power);
+ // The compiler should optimize out this if statement entirely.
+ if (MetricType::Power == 1)
+ {
+ loSum += vLo; // vLo is non-negative.
+ hiSum += vHi; // vHi is non-negative.
+ }
+ else if (MetricType::Power == 2)
+ {
+ loSum += vLo * vLo;
+ hiSum += vHi * vHi;
+ }
+ else
+ {
+ loSum += std::pow(vLo, (ElemType) MetricType::Power);
+ hiSum += std::pow(vHi, (ElemType) MetricType::Power);
+ }
}
if (MetricType::TakeRoot)
- return math::RangeType<ElemType>(
- (ElemType) pow((double) loSum, 1.0 / (double) MetricType::Power),
- (ElemType) pow((double) hiSum, 1.0 / (double) MetricType::Power));
+ {
+ if (MetricType::Power == 1)
+ return math::RangeType<ElemType>(loSum, hiSum);
+ else if (MetricType::Power == 2)
+ return math::RangeType<ElemType>(std::sqrt(loSum), std::sqrt(hiSum));
+ else
+ {
+ return math::RangeType<ElemType>(
+ (ElemType) pow((double) loSum, 1.0 / (double) MetricType::Power),
+ (ElemType) pow((double) hiSum, 1.0 / (double) MetricType::Power));
+ }
+ }
else
return math::RangeType<ElemType>(loSum, hiSum);
}
More information about the mlpack-git mailing list