[mlpack-git] master: HRectBound improvements. (55e1579)

gitdub at mlpack.org gitdub at mlpack.org
Fri Aug 12 14:10:00 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/8a0ad2fbe4db5614a6aad27cfb5e101ae8b1db96...dc6bae4e8634486b384b67e3ae7a690f34bdc677

>---------------------------------------------------------------

commit 55e15792e4dc2135c6e3421c5dae1322d07b0dcc
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date:   Fri Aug 12 21:10:00 2016 +0300

    HRectBound improvements.


>---------------------------------------------------------------

55e15792e4dc2135c6e3421c5dae1322d07b0dcc
 src/mlpack/core/tree/hrectbound_impl.hpp | 168 ++++++++++++++++++++++++++-----
 1 file changed, 144 insertions(+), 24 deletions(-)

diff --git a/src/mlpack/core/tree/hrectbound_impl.hpp b/src/mlpack/core/tree/hrectbound_impl.hpp
index 60aa2b8..196067f 100644
--- a/src/mlpack/core/tree/hrectbound_impl.hpp
+++ b/src/mlpack/core/tree/hrectbound_impl.hpp
@@ -173,18 +173,40 @@ inline ElemType HRectBound<MetricType, ElemType>::MinDistance(
     // Since only one of 'lower' or 'higher' is negative, if we add each's
     // absolute value to itself and then sum those two, our result is the
     // nonnegative half of the equation times two; then we raise to power Power.
-    sum += pow((lower + fabs(lower)) + (higher + fabs(higher)),
-        (ElemType) MetricType::Power);
+    if (MetricType::Power == 1)
+      sum += (lower + std::fabs(lower)) + (higher + std::fabs(higher));
+    else if (MetricType::Power == 2)
+    {
+      ElemType dist = (lower + std::fabs(lower)) + (higher + std::fabs(higher));
+      sum += dist * dist;
+    }
+    else
+    {
+      sum += pow((lower + fabs(lower)) + (higher + fabs(higher)),
+          (ElemType) MetricType::Power);
+    }
   }
 
   // Now take the Power'th root (but make sure our result is squared if it needs
   // to be); then cancel out the constant of 2 (which may have been squared now)
   // that was introduced earlier.  The compiler should optimize out the if
   // statement entirely.
-  if (MetricType::TakeRoot)
-    return (ElemType) pow((double) sum, 1.0 / (double) MetricType::Power) / 2.0;
+  if (MetricType::Power == 1)
+    return sum * 0.5;
+  else if (MetricType::Power == 2)
+  {
+    if (MetricType::TakeRoot)
+      return std::sqrt(sum) * 0.5;
+    else
+      return sum * 0.25;
+  }
   else
-    return sum / pow(2.0, MetricType::Power);
+  {
+    if (MetricType::TakeRoot)
+      return (ElemType) pow((double) sum, 1.0 / (double) MetricType::Power) / 2.0;
+    else
+      return sum / pow(2.0, MetricType::Power);
+  }
 }
 
 /**
@@ -208,8 +230,20 @@ ElemType HRectBound<MetricType, ElemType>::MinDistance(const HRectBound& other)
     // We invoke the following:
     //   x + fabs(x) = max(x * 2, 0)
     //   (x * 2)^2 / 4 = x^2
-    sum += pow((lower + fabs(lower)) + (higher + fabs(higher)),
-        (ElemType) MetricType::Power);
+
+    // The compiler should optimize out this if statement entirely.
+    if (MetricType::Power == 1)
+      sum += (lower + std::fabs(lower)) + (higher + std::fabs(higher));
+    else if (MetricType::Power == 2)
+    {
+      ElemType dist = (lower + std::fabs(lower)) + (higher + std::fabs(higher));
+      sum += dist * dist;
+    }
+    else
+    {
+      sum += pow((lower + fabs(lower)) + (higher + fabs(higher)),
+          (ElemType) MetricType::Power);
+    }
 
     // Move bound pointers.
     mbound++;
@@ -217,10 +251,22 @@ ElemType HRectBound<MetricType, ElemType>::MinDistance(const HRectBound& other)
   }
 
   // The compiler should optimize out this if statement entirely.
-  if (MetricType::TakeRoot)
-    return (ElemType) pow((double) sum, 1.0 / (double) MetricType::Power) / 2.0;
+  if (MetricType::Power == 1)
+    return sum * 0.5;
+  else if (MetricType::Power == 2)
+  {
+    if (MetricType::TakeRoot)
+      return std::sqrt(sum) * 0.5;
+    else
+      return sum * 0.25;
+  }
   else
-    return sum / pow(2.0, MetricType::Power);
+  {
+    if (MetricType::TakeRoot)
+      return (ElemType) pow((double) sum, 1.0 / (double) MetricType::Power) / 2.0;
+    else
+      return sum / pow(2.0, MetricType::Power);
+  }
 }
 
 /**
@@ -240,12 +286,26 @@ inline ElemType HRectBound<MetricType, ElemType>::MaxDistance(
   {
     ElemType v = std::max(fabs(point[d] - bounds[d].Lo()),
         fabs(bounds[d].Hi() - point[d]));
-    sum += pow(v, (ElemType) MetricType::Power);
+
+    // The compiler should optimize out this if statement entirely.
+    if (MetricType::Power == 1)
+      sum += v; // v is non-negative.
+    else if (MetricType::Power == 2)
+      sum += v * v;
+    else
+      sum += std::pow(v, (ElemType) MetricType::Power);
   }
 
   // The compiler should optimize out this if statement entirely.
   if (MetricType::TakeRoot)
-    return (ElemType) pow((double) sum, 1.0 / (double) MetricType::Power);
+  {
+    if (MetricType::Power == 1)
+      return sum;
+    else if (MetricType::Power == 2)
+      return std::sqrt(sum);
+    else
+      return (ElemType) pow((double) sum, 1.0 / (double) MetricType::Power);
+  }
   else
     return sum;
 }
@@ -267,12 +327,26 @@ inline ElemType HRectBound<MetricType, ElemType>::MaxDistance(
   {
     v = std::max(fabs(other.bounds[d].Hi() - bounds[d].Lo()),
         fabs(bounds[d].Hi() - other.bounds[d].Lo()));
-    sum += pow(v, (ElemType) MetricType::Power); // v is non-negative.
+
+    // The compiler should optimize out this if statement entirely.
+    if (MetricType::Power == 1)
+      sum += v; // v is non-negative.
+    else if (MetricType::Power == 2)
+      sum += v * v;
+    else
+      sum += std::pow(v, (ElemType) MetricType::Power);
   }
 
   // The compiler should optimize out this if statement entirely.
   if (MetricType::TakeRoot)
-    return (ElemType) pow((double) sum, 1.0 / (double) MetricType::Power);
+  {
+    if (MetricType::Power == 1)
+      return sum;
+    else if (MetricType::Power == 2)
+      return std::sqrt(sum);
+    else
+      return (ElemType) pow((double) sum, 1.0 / (double) MetricType::Power);
+  }
   else
     return sum;
 }
@@ -307,14 +381,37 @@ HRectBound<MetricType, ElemType>::RangeDistance(
       vLo = (v2 > 0) ? v2 : 0; // Force to be 0 if negative.
     }
 
-    loSum += pow(vLo, (ElemType) MetricType::Power);
-    hiSum += pow(vHi, (ElemType) MetricType::Power);
+    // The compiler should optimize out this if statement entirely.
+    if (MetricType::Power == 1)
+    {
+      loSum += vLo; // vLo is non-negative.
+      hiSum += vHi; // vHi is non-negative.
+    }
+    else if (MetricType::Power == 2)
+    {
+      loSum += vLo * vLo;
+      hiSum += vHi * vHi;
+    }
+    else
+    {
+      loSum += std::pow(vLo, (ElemType) MetricType::Power);
+      hiSum += std::pow(vHi, (ElemType) MetricType::Power);
+    }
   }
 
   if (MetricType::TakeRoot)
-    return math::RangeType<ElemType>(
-        (ElemType) pow((double) loSum, 1.0 / (double) MetricType::Power),
-        (ElemType) pow((double) hiSum, 1.0 / (double) MetricType::Power));
+  {
+    if (MetricType::Power == 1)
+      return math::RangeType<ElemType>(loSum, hiSum);
+    else if (MetricType::Power == 2)
+      return math::RangeType<ElemType>(std::sqrt(loSum), std::sqrt(hiSum));
+    else
+    {
+      return math::RangeType<ElemType>(
+          (ElemType) pow((double) loSum, 1.0 / (double) MetricType::Power),
+          (ElemType) pow((double) hiSum, 1.0 / (double) MetricType::Power));
+    }
+  }
   else
     return math::RangeType<ElemType>(loSum, hiSum);
 }
@@ -359,14 +456,37 @@ HRectBound<MetricType, ElemType>::RangeDistance(
       }
     }
 
-    loSum += pow(vLo, (ElemType) MetricType::Power);
-    hiSum += pow(vHi, (ElemType) MetricType::Power);
+    // The compiler should optimize out this if statement entirely.
+    if (MetricType::Power == 1)
+    {
+      loSum += vLo; // vLo is non-negative.
+      hiSum += vHi; // vHi is non-negative.
+    }
+    else if (MetricType::Power == 2)
+    {
+      loSum += vLo * vLo;
+      hiSum += vHi * vHi;
+    }
+    else
+    {
+      loSum += std::pow(vLo, (ElemType) MetricType::Power);
+      hiSum += std::pow(vHi, (ElemType) MetricType::Power);
+    }
   }
 
   if (MetricType::TakeRoot)
-    return math::RangeType<ElemType>(
-        (ElemType) pow((double) loSum, 1.0 / (double) MetricType::Power),
-        (ElemType) pow((double) hiSum, 1.0 / (double) MetricType::Power));
+  {
+    if (MetricType::Power == 1)
+      return math::RangeType<ElemType>(loSum, hiSum);
+    else if (MetricType::Power == 2)
+      return math::RangeType<ElemType>(std::sqrt(loSum), std::sqrt(hiSum));
+    else
+    {
+      return math::RangeType<ElemType>(
+          (ElemType) pow((double) loSum, 1.0 / (double) MetricType::Power),
+          (ElemType) pow((double) hiSum, 1.0 / (double) MetricType::Power));
+    }
+  }
   else
     return math::RangeType<ElemType>(loSum, hiSum);
 }




More information about the mlpack-git mailing list