[mlpack-git] master: Optimize bound-to-point and bound-to-bound distance calculations. The same changes as in HRectBound. (9bd634a)
gitdub at mlpack.org
gitdub at mlpack.org
Thu Aug 18 13:03:12 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/1797a49c8f76d65814fec4a122d0d2fea01fc2d9...9e5cd0ac9c5cde9ac141bc84e7327bd11e19d42e
>---------------------------------------------------------------
commit 9bd634a459974976dc976dad530c4f716e3810c6
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date: Thu Aug 18 20:03:12 2016 +0300
Optimize bound-to-point and bound-to-bound distance calculations.
Apply the same changes that were previously made to HRectBound.
>---------------------------------------------------------------
9bd634a459974976dc976dad530c4f716e3810c6
src/mlpack/core/tree/cellbound_impl.hpp | 178 +++++++++++++++++++++++++++-----
1 file changed, 152 insertions(+), 26 deletions(-)
diff --git a/src/mlpack/core/tree/cellbound_impl.hpp b/src/mlpack/core/tree/cellbound_impl.hpp
index 13f02e0..150e2a2 100644
--- a/src/mlpack/core/tree/cellbound_impl.hpp
+++ b/src/mlpack/core/tree/cellbound_impl.hpp
@@ -461,8 +461,21 @@ inline ElemType CellBound<MetricType, ElemType>::MinDistance(
// Since only one of 'lower' or 'higher' is negative, if we add each's
// absolute value to itself and then sum those two, our result is the
// nonnegative half of the equation times two; then we raise to power Power.
- sum += pow((lower + fabs(lower)) + (higher + fabs(higher)),
- (ElemType) MetricType::Power);
+ if (MetricType::Power == 1)
+ sum += lower + std::fabs(lower) + higher + std::fabs(higher);
+ else if (MetricType::Power == 2)
+ {
+ ElemType dist = lower + std::fabs(lower) + higher + std::fabs(higher);
+ sum += dist * dist;
+ }
+ else
+ {
+ sum += pow((lower + fabs(lower)) + (higher + fabs(higher)),
+ (ElemType) MetricType::Power);
+ }
+
+ if (sum >= minSum)
+ break;
}
if (sum < minSum)
@@ -473,11 +486,23 @@ inline ElemType CellBound<MetricType, ElemType>::MinDistance(
// to be); then cancel out the constant of 2 (which may have been squared now)
// that was introduced earlier. The compiler should optimize out the if
// statement entirely.
- if (MetricType::TakeRoot)
- return (ElemType) pow((double) minSum,
- 1.0 / (double) MetricType::Power) / 2.0;
+ if (MetricType::Power == 1)
+ return minSum * 0.5;
+ else if (MetricType::Power == 2)
+ {
+ if (MetricType::TakeRoot)
+ return (ElemType) std::sqrt(minSum) * 0.5;
+ else
+ return minSum * 0.25;
+ }
else
- return minSum / pow(2.0, MetricType::Power);
+ {
+ if (MetricType::TakeRoot)
+ return (ElemType) pow((double) minSum,
+ 1.0 / (double) MetricType::Power) / 2.0;
+ else
+ return minSum / pow(2.0, MetricType::Power);
+ }
}
/**
@@ -504,9 +529,23 @@ ElemType CellBound<MetricType, ElemType>::MinDistance(const CellBound& other)
// We invoke the following:
// x + fabs(x) = max(x * 2, 0)
// (x * 2)^2 / 4 = x^2
- sum += pow((lower + fabs(lower)) + (higher + fabs(higher)),
- (ElemType) MetricType::Power);
+ // The compiler should optimize out this if statement entirely.
+ if (MetricType::Power == 1)
+ sum += (lower + std::fabs(lower)) + (higher + std::fabs(higher));
+ else if (MetricType::Power == 2)
+ {
+ ElemType dist = lower + std::fabs(lower) + higher + std::fabs(higher);
+ sum += dist * dist;
+ }
+ else
+ {
+ sum += pow((lower + fabs(lower)) + (higher + fabs(higher)),
+ (ElemType) MetricType::Power);
+ }
+
+ if (sum >= minSum)
+ break;
}
if (sum < minSum)
@@ -514,11 +553,23 @@ ElemType CellBound<MetricType, ElemType>::MinDistance(const CellBound& other)
}
// The compiler should optimize out this if statement entirely.
- if (MetricType::TakeRoot)
- return (ElemType) pow((double) minSum,
- 1.0 / (double) MetricType::Power) / 2.0;
+ if (MetricType::Power == 1)
+ return minSum * 0.5;
+ else if (MetricType::Power == 2)
+ {
+ if (MetricType::TakeRoot)
+ return (ElemType) std::sqrt(minSum) * 0.5;
+ else
+ return minSum * 0.25;
+ }
else
- return minSum / pow(2.0, MetricType::Power);
+ {
+ if (MetricType::TakeRoot)
+ return (ElemType) pow((double) minSum,
+ 1.0 / (double) MetricType::Power) / 2.0;
+ else
+ return minSum / pow(2.0, MetricType::Power);
+ }
}
/**
@@ -541,7 +592,13 @@ inline ElemType CellBound<MetricType, ElemType>::MaxDistance(
{
ElemType v = std::max(fabs(point[d] - loBound(d, i)),
fabs(hiBound(d, i) - point[d]));
- sum += pow(v, (ElemType) MetricType::Power);
+
+ if (MetricType::Power == 1)
+ sum += v; // v is non-negative.
+ else if (MetricType::Power == 2)
+ sum += v * v;
+ else
+ sum += std::pow(v, (ElemType) MetricType::Power);
}
if (sum > maxSum)
@@ -550,7 +607,14 @@ inline ElemType CellBound<MetricType, ElemType>::MaxDistance(
// The compiler should optimize out this if statement entirely.
if (MetricType::TakeRoot)
- return (ElemType) pow((double) maxSum, 1.0 / (double) MetricType::Power);
+ {
+ if (MetricType::Power == 1)
+ return maxSum;
+ else if (MetricType::Power == 2)
+ return (ElemType) std::sqrt(maxSum);
+ else
+ return (ElemType) pow((double) maxSum, 1.0 / (double) MetricType::Power);
+ }
else
return maxSum;
}
@@ -576,7 +640,14 @@ inline ElemType CellBound<MetricType, ElemType>::MaxDistance(
{
v = std::max(fabs(other.hiBound(d, j) - loBound(d, i)),
fabs(hiBound(d, i) - other.loBound(d, j)));
- sum += pow(v, (ElemType) MetricType::Power); // v is non-negative.
+
+ // The compiler should optimize out this if statement entirely.
+ if (MetricType::Power == 1)
+ sum += v; // v is non-negative.
+ else if (MetricType::Power == 2)
+ sum += v * v;
+ else
+ sum += std::pow(v, (ElemType) MetricType::Power);
}
if (sum > maxSum)
@@ -585,7 +656,14 @@ inline ElemType CellBound<MetricType, ElemType>::MaxDistance(
// The compiler should optimize out this if statement entirely.
if (MetricType::TakeRoot)
- return (ElemType) pow((double) maxSum, 1.0 / (double) MetricType::Power);
+ {
+ if (MetricType::Power == 1)
+ return maxSum;
+ else if (MetricType::Power == 2)
+ return (ElemType) std::sqrt(maxSum);
+ else
+ return (ElemType) pow((double) maxSum, 1.0 / (double) MetricType::Power);
+ }
else
return maxSum;
}
@@ -626,8 +704,22 @@ CellBound<MetricType, ElemType>::RangeDistance(
vLo = (v2 > 0) ? v2 : 0; // Force to be 0 if negative.
}
- loSum += pow(vLo, (ElemType) MetricType::Power);
- hiSum += pow(vHi, (ElemType) MetricType::Power);
+ // The compiler should optimize out this if statement entirely.
+ if (MetricType::Power == 1)
+ {
+ loSum += vLo; // vLo is non-negative.
+ hiSum += vHi; // vHi is non-negative.
+ }
+ else if (MetricType::Power == 2)
+ {
+ loSum += vLo * vLo;
+ hiSum += vHi * vHi;
+ }
+ else
+ {
+ loSum += std::pow(vLo, (ElemType) MetricType::Power);
+ hiSum += std::pow(vHi, (ElemType) MetricType::Power);
+ }
}
if (loSum < minLoSum)
@@ -637,9 +729,19 @@ CellBound<MetricType, ElemType>::RangeDistance(
}
if (MetricType::TakeRoot)
- return math::RangeType<ElemType>(
- (ElemType) pow((double) minLoSum, 1.0 / (double) MetricType::Power),
- (ElemType) pow((double) maxHiSum, 1.0 / (double) MetricType::Power));
+ {
+ if (MetricType::Power == 1)
+ return math::RangeType<ElemType>(minLoSum, maxHiSum);
+ else if (MetricType::Power == 2)
+ return math::RangeType<ElemType>((ElemType) std::sqrt(minLoSum),
+ (ElemType) std::sqrt(maxHiSum));
+ else
+ {
+ return math::RangeType<ElemType>(
+ (ElemType) pow((double) minLoSum, 1.0 / (double) MetricType::Power),
+ (ElemType) pow((double) maxHiSum, 1.0 / (double) MetricType::Power));
+ }
+ }
else
return math::RangeType<ElemType>(minLoSum, maxHiSum);
}
@@ -688,8 +790,22 @@ CellBound<MetricType, ElemType>::RangeDistance(
}
}
- loSum += pow(vLo, (ElemType) MetricType::Power);
- hiSum += pow(vHi, (ElemType) MetricType::Power);
+ // The compiler should optimize out this if statement entirely.
+ if (MetricType::Power == 1)
+ {
+ loSum += vLo; // vLo is non-negative.
+ hiSum += vHi; // vHi is non-negative.
+ }
+ else if (MetricType::Power == 2)
+ {
+ loSum += vLo * vLo;
+ hiSum += vHi * vHi;
+ }
+ else
+ {
+ loSum += std::pow(vLo, (ElemType) MetricType::Power);
+ hiSum += std::pow(vHi, (ElemType) MetricType::Power);
+ }
}
if (loSum < minLoSum)
minLoSum = loSum;
@@ -698,9 +814,19 @@ CellBound<MetricType, ElemType>::RangeDistance(
}
if (MetricType::TakeRoot)
- return math::RangeType<ElemType>(
- (ElemType) pow((double) minLoSum, 1.0 / (double) MetricType::Power),
- (ElemType) pow((double) maxHiSum, 1.0 / (double) MetricType::Power));
+ {
+ if (MetricType::Power == 1)
+ return math::RangeType<ElemType>(minLoSum, maxHiSum);
+ else if (MetricType::Power == 2)
+ return math::RangeType<ElemType>((ElemType) std::sqrt(minLoSum),
+ (ElemType) std::sqrt(maxHiSum));
+ else
+ {
+ return math::RangeType<ElemType>(
+ (ElemType) pow((double) minLoSum, 1.0 / (double) MetricType::Power),
+ (ElemType) pow((double) maxHiSum, 1.0 / (double) MetricType::Power));
+ }
+ }
else
return math::RangeType<ElemType>(minLoSum, maxHiSum);
}
More information about the mlpack-git
mailing list