[mlpack-git] master: Optimize bound-to-point and bound-to-bound distance calculations, applying the same changes as in HRectBound. (9bd634a)

gitdub at mlpack.org gitdub at mlpack.org
Thu Aug 18 13:03:12 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/1797a49c8f76d65814fec4a122d0d2fea01fc2d9...9e5cd0ac9c5cde9ac141bc84e7327bd11e19d42e

>---------------------------------------------------------------

commit 9bd634a459974976dc976dad530c4f716e3810c6
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date:   Thu Aug 18 20:03:12 2016 +0300

    Optimize bound-to-point and bound-to-bound distance calculations.
    This applies the same changes that were previously made to HRectBound.


>---------------------------------------------------------------

9bd634a459974976dc976dad530c4f716e3810c6
 src/mlpack/core/tree/cellbound_impl.hpp | 178 +++++++++++++++++++++++++++-----
 1 file changed, 152 insertions(+), 26 deletions(-)

diff --git a/src/mlpack/core/tree/cellbound_impl.hpp b/src/mlpack/core/tree/cellbound_impl.hpp
index 13f02e0..150e2a2 100644
--- a/src/mlpack/core/tree/cellbound_impl.hpp
+++ b/src/mlpack/core/tree/cellbound_impl.hpp
@@ -461,8 +461,21 @@ inline ElemType CellBound<MetricType, ElemType>::MinDistance(
       // Since only one of 'lower' or 'higher' is negative, if we add each's
       // absolute value to itself and then sum those two, our result is the
       // nonnegative half of the equation times two; then we raise to power Power.
-      sum += pow((lower + fabs(lower)) + (higher + fabs(higher)),
-          (ElemType) MetricType::Power);
+      if (MetricType::Power == 1)
+        sum += lower + std::fabs(lower) + higher + std::fabs(higher);
+      else if (MetricType::Power == 2)
+      {
+        ElemType dist = lower + std::fabs(lower) + higher + std::fabs(higher);
+        sum += dist * dist;
+      }
+      else
+      {
+        sum += pow((lower + fabs(lower)) + (higher + fabs(higher)),
+            (ElemType) MetricType::Power);
+      }
+
+      if (sum >= minSum)
+        break;
     }
 
     if (sum < minSum)
@@ -473,11 +486,23 @@ inline ElemType CellBound<MetricType, ElemType>::MinDistance(
   // to be); then cancel out the constant of 2 (which may have been squared now)
   // that was introduced earlier.  The compiler should optimize out the if
   // statement entirely.
-  if (MetricType::TakeRoot)
-    return (ElemType) pow((double) minSum,
-        1.0 / (double) MetricType::Power) / 2.0;
+  if (MetricType::Power == 1)
+    return minSum * 0.5;
+  else if (MetricType::Power == 2)
+  {
+    if (MetricType::TakeRoot)
+      return (ElemType) std::sqrt(minSum) * 0.5;
+    else
+      return minSum * 0.25;
+  }
   else
-    return minSum / pow(2.0, MetricType::Power);
+  {
+    if (MetricType::TakeRoot)
+      return (ElemType) pow((double) minSum,
+          1.0 / (double) MetricType::Power) / 2.0;
+    else
+      return minSum / pow(2.0, MetricType::Power);
+  }
 }
 
 /**
@@ -504,9 +529,23 @@ ElemType CellBound<MetricType, ElemType>::MinDistance(const CellBound& other)
         // We invoke the following:
         //   x + fabs(x) = max(x * 2, 0)
         //   (x * 2)^2 / 4 = x^2
-        sum += pow((lower + fabs(lower)) + (higher + fabs(higher)),
-            (ElemType) MetricType::Power);
 
+        // The compiler should optimize out this if statement entirely.
+        if (MetricType::Power == 1)
+          sum += (lower + std::fabs(lower)) + (higher + std::fabs(higher));
+        else if (MetricType::Power == 2)
+        {
+          ElemType dist = lower + std::fabs(lower) + higher + std::fabs(higher);
+          sum += dist * dist;
+        }
+        else
+        {
+          sum += pow((lower + fabs(lower)) + (higher + fabs(higher)),
+              (ElemType) MetricType::Power);
+        }
+
+        if (sum >= minSum)
+          break;
       }
 
       if (sum < minSum)
@@ -514,11 +553,23 @@ ElemType CellBound<MetricType, ElemType>::MinDistance(const CellBound& other)
     }
 
   // The compiler should optimize out this if statement entirely.
-  if (MetricType::TakeRoot)
-    return (ElemType) pow((double) minSum,
-        1.0 / (double) MetricType::Power) / 2.0;
+  if (MetricType::Power == 1)
+    return minSum * 0.5;
+  else if (MetricType::Power == 2)
+  {
+    if (MetricType::TakeRoot)
+      return (ElemType) std::sqrt(minSum) * 0.5;
+    else
+      return minSum * 0.25;
+  }
   else
-    return minSum / pow(2.0, MetricType::Power);
+  {
+    if (MetricType::TakeRoot)
+      return (ElemType) pow((double) minSum,
+          1.0 / (double) MetricType::Power) / 2.0;
+    else
+      return minSum / pow(2.0, MetricType::Power);
+  }
 }
 
 /**
@@ -541,7 +592,13 @@ inline ElemType CellBound<MetricType, ElemType>::MaxDistance(
     {
       ElemType v = std::max(fabs(point[d] - loBound(d, i)),
           fabs(hiBound(d, i) - point[d]));
-      sum += pow(v, (ElemType) MetricType::Power);
+
+      if (MetricType::Power == 1)
+        sum += v; // v is non-negative.
+      else if (MetricType::Power == 2)
+        sum += v * v;
+      else
+        sum += std::pow(v, (ElemType) MetricType::Power);
     }
 
     if (sum > maxSum)
@@ -550,7 +607,14 @@ inline ElemType CellBound<MetricType, ElemType>::MaxDistance(
 
   // The compiler should optimize out this if statement entirely.
   if (MetricType::TakeRoot)
-    return (ElemType) pow((double) maxSum, 1.0 / (double) MetricType::Power);
+  {
+    if (MetricType::Power == 1)
+      return maxSum;
+    else if (MetricType::Power == 2)
+      return (ElemType) std::sqrt(maxSum);
+    else
+      return (ElemType) pow((double) maxSum, 1.0 / (double) MetricType::Power);
+  }
   else
     return maxSum;
 }
@@ -576,7 +640,14 @@ inline ElemType CellBound<MetricType, ElemType>::MaxDistance(
       {
         v = std::max(fabs(other.hiBound(d, j) - loBound(d, i)),
             fabs(hiBound(d, i) - other.loBound(d, j)));
-        sum += pow(v, (ElemType) MetricType::Power); // v is non-negative.
+
+        // The compiler should optimize out this if statement entirely.
+        if (MetricType::Power == 1)
+          sum += v; // v is non-negative.
+        else if (MetricType::Power == 2)
+          sum += v * v;
+        else
+          sum += std::pow(v, (ElemType) MetricType::Power);
       }
 
       if (sum > maxSum)
@@ -585,7 +656,14 @@ inline ElemType CellBound<MetricType, ElemType>::MaxDistance(
 
   // The compiler should optimize out this if statement entirely.
   if (MetricType::TakeRoot)
-    return (ElemType) pow((double) maxSum, 1.0 / (double) MetricType::Power);
+  {
+    if (MetricType::Power == 1)
+      return maxSum;
+    else if (MetricType::Power == 2)
+      return (ElemType) std::sqrt(maxSum);
+    else
+      return (ElemType) pow((double) maxSum, 1.0 / (double) MetricType::Power);
+  }
   else
     return maxSum;
 }
@@ -626,8 +704,22 @@ CellBound<MetricType, ElemType>::RangeDistance(
           vLo = (v2 > 0) ? v2 : 0; // Force to be 0 if negative.
         }
 
-        loSum += pow(vLo, (ElemType) MetricType::Power);
-        hiSum += pow(vHi, (ElemType) MetricType::Power);
+        // The compiler should optimize out this if statement entirely.
+        if (MetricType::Power == 1)
+        {
+          loSum += vLo; // vLo is non-negative.
+          hiSum += vHi; // vHi is non-negative.
+        }
+        else if (MetricType::Power == 2)
+        {
+          loSum += vLo * vLo;
+          hiSum += vHi * vHi;
+        }
+        else
+        {
+          loSum += std::pow(vLo, (ElemType) MetricType::Power);
+          hiSum += std::pow(vHi, (ElemType) MetricType::Power);
+        }
       }
 
       if (loSum < minLoSum)
@@ -637,9 +729,19 @@ CellBound<MetricType, ElemType>::RangeDistance(
     }
 
   if (MetricType::TakeRoot)
-    return math::RangeType<ElemType>(
-        (ElemType) pow((double) minLoSum, 1.0 / (double) MetricType::Power),
-        (ElemType) pow((double) maxHiSum, 1.0 / (double) MetricType::Power));
+  {
+    if (MetricType::Power == 1)
+      return math::RangeType<ElemType>(minLoSum, maxHiSum);
+    else if (MetricType::Power == 2)
+      return math::RangeType<ElemType>((ElemType) std::sqrt(minLoSum),
+                                       (ElemType) std::sqrt(maxHiSum));
+    else
+    {
+      return math::RangeType<ElemType>(
+          (ElemType) pow((double) minLoSum, 1.0 / (double) MetricType::Power),
+          (ElemType) pow((double) maxHiSum, 1.0 / (double) MetricType::Power));
+    }
+  }
   else
     return math::RangeType<ElemType>(minLoSum, maxHiSum);
 }
@@ -688,8 +790,22 @@ CellBound<MetricType, ElemType>::RangeDistance(
         }
       }
 
-      loSum += pow(vLo, (ElemType) MetricType::Power);
-      hiSum += pow(vHi, (ElemType) MetricType::Power);
+      // The compiler should optimize out this if statement entirely.
+      if (MetricType::Power == 1)
+      {
+        loSum += vLo; // vLo is non-negative.
+        hiSum += vHi; // vHi is non-negative.
+      }
+      else if (MetricType::Power == 2)
+      {
+        loSum += vLo * vLo;
+        hiSum += vHi * vHi;
+      }
+      else
+      {
+        loSum += std::pow(vLo, (ElemType) MetricType::Power);
+        hiSum += std::pow(vHi, (ElemType) MetricType::Power);
+      }
     }
     if (loSum < minLoSum)
       minLoSum = loSum;
@@ -698,9 +814,19 @@ CellBound<MetricType, ElemType>::RangeDistance(
   }
 
   if (MetricType::TakeRoot)
-    return math::RangeType<ElemType>(
-        (ElemType) pow((double) minLoSum, 1.0 / (double) MetricType::Power),
-        (ElemType) pow((double) maxHiSum, 1.0 / (double) MetricType::Power));
+  {
+    if (MetricType::Power == 1)
+      return math::RangeType<ElemType>(minLoSum, maxHiSum);
+    else if (MetricType::Power == 2)
+      return math::RangeType<ElemType>((ElemType) std::sqrt(minLoSum),
+                                       (ElemType) std::sqrt(maxHiSum));
+    else
+    {
+      return math::RangeType<ElemType>(
+          (ElemType) pow((double) minLoSum, 1.0 / (double) MetricType::Power),
+          (ElemType) pow((double) maxHiSum, 1.0 / (double) MetricType::Power));
+    }
+  }
   else
     return math::RangeType<ElemType>(minLoSum, maxHiSum);
 }




More information about the mlpack-git mailing list