[mlpack-git] master: Refactor to take a MetricType. (47cc6fe)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Jul 29 16:41:44 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/f8ceffae0613b350f4d6bdd46c6c8633a40b4897...6ee21879488fe98612a4619b17f8b51e8da5215b

>---------------------------------------------------------------

commit 47cc6fe6f072e72ef28604e203e9873c2e1038e8
Author: ryan <ryan at ratml.org>
Date:   Wed Jul 22 22:07:47 2015 -0400

    Refactor to take a MetricType.


>---------------------------------------------------------------

47cc6fe6f072e72ef28604e203e9873c2e1038e8
 src/mlpack/core/tree/hrectbound.hpp      |  40 ++++++---
 src/mlpack/core/tree/hrectbound_impl.hpp | 142 ++++++++++++++++---------------
 2 files changed, 100 insertions(+), 82 deletions(-)

diff --git a/src/mlpack/core/tree/hrectbound.hpp b/src/mlpack/core/tree/hrectbound.hpp
index ea525f3..13bf1cc 100644
--- a/src/mlpack/core/tree/hrectbound.hpp
+++ b/src/mlpack/core/tree/hrectbound.hpp
@@ -17,6 +17,26 @@
 namespace mlpack {
 namespace bound {
 
+namespace meta /** Metaprogramming utilities. */ {
+
+//! Utility struct where Value is true if and only if the argument is of type
+//! LMetric.
+template<typename MetricType>
+struct IsLMetric
+{
+  static const bool Value = false;
+};
+
+//! Partial specialization for IsLMetric when the argument is any LMetric; as a
+//! partial (not explicit) specialization it takes no extra template<> prefix.
+template<int Power, bool TakeRoot>
+struct IsLMetric<metric::LMetric<Power, TakeRoot>>
+{
+  static const bool Value = true;
+};
+
+} // namespace meta
+
 /**
  * Hyper-rectangle bound for an L-metric.  This should be used in conjunction
  * with the LMetric class.  Be sure to use the same template parameters for
@@ -26,13 +46,14 @@ namespace bound {
  * @tparam TakeRoot Whether or not the root should be taken (see LMetric
  *     documentation).
  */
-template<int Power = 2, bool TakeRoot = true>
+template<typename MetricType = metric::LMetric<2, true>>
 class HRectBound
 {
- public:
-  //! This is the metric type that this bound is using.
-  typedef metric::LMetric<Power, TakeRoot> MetricType;
+  // It is required that HRectBound have an LMetric as the given MetricType.
+  static_assert(meta::IsLMetric<MetricType>::Value == true,
+      "HRectBound can only be used with the LMetric<> metric type.");
 
+ public:
   /**
    * Empty constructor; creates a bound of dimensionality 0.
    */
@@ -174,13 +195,6 @@ class HRectBound
    */
   std::string ToString() const;
 
-  /**
-   * Return the metric associated with this bound.  Because it is an LMetric, it
-   * cannot store state, so we can make it on the fly.  It is also static
-   * because the metric is only dependent on the template arguments.
-   */
-  static MetricType Metric() { return metric::LMetric<Power, TakeRoot>(); }
-
  private:
   //! The dimensionality of the bound.
   size_t dim;
@@ -191,8 +205,8 @@ class HRectBound
 };
 
 // A specialization of BoundTraits for this class.
-template<int Power, bool TakeRoot>
-struct BoundTraits<HRectBound<Power, TakeRoot>>
+template<typename MetricType>
+struct BoundTraits<HRectBound<MetricType>>
 {
   //! These bounds are always tight for each dimension.
   const static bool HasTightBounds = true;
diff --git a/src/mlpack/core/tree/hrectbound_impl.hpp b/src/mlpack/core/tree/hrectbound_impl.hpp
index 43a7982..0b9f76c 100644
--- a/src/mlpack/core/tree/hrectbound_impl.hpp
+++ b/src/mlpack/core/tree/hrectbound_impl.hpp
@@ -20,8 +20,8 @@ namespace bound {
 /**
  * Empty constructor.
  */
-template<int Power, bool TakeRoot>
-inline HRectBound<Power, TakeRoot>::HRectBound() :
+template<typename MetricType>
+inline HRectBound<MetricType>::HRectBound() :
     dim(0),
     bounds(NULL),
     minWidth(0)
@@ -31,8 +31,8 @@ inline HRectBound<Power, TakeRoot>::HRectBound() :
  * Initializes to specified dimensionality with each dimension the empty
  * set.
  */
-template<int Power, bool TakeRoot>
-inline HRectBound<Power, TakeRoot>::HRectBound(const size_t dimension) :
+template<typename MetricType>
+inline HRectBound<MetricType>::HRectBound(const size_t dimension) :
     dim(dimension),
     bounds(new math::Range[dim]),
     minWidth(0)
@@ -41,8 +41,8 @@ inline HRectBound<Power, TakeRoot>::HRectBound(const size_t dimension) :
 /***
  * Copy constructor necessary to prevent memory leaks.
  */
-template<int Power, bool TakeRoot>
-inline HRectBound<Power, TakeRoot>::HRectBound(const HRectBound& other) :
+template<typename MetricType>
+inline HRectBound<MetricType>::HRectBound(const HRectBound& other) :
     dim(other.Dim()),
     bounds(new math::Range[dim]),
     minWidth(other.MinWidth())
@@ -55,8 +55,8 @@ inline HRectBound<Power, TakeRoot>::HRectBound(const HRectBound& other) :
 /***
  * Same as the copy constructor.
  */
-template<int Power, bool TakeRoot>
-inline HRectBound<Power, TakeRoot>& HRectBound<Power, TakeRoot>::operator=(
+template<typename MetricType>
+inline HRectBound<MetricType>& HRectBound<MetricType>::operator=(
     const HRectBound& other)
 {
   if (dim != other.Dim())
@@ -81,8 +81,8 @@ inline HRectBound<Power, TakeRoot>& HRectBound<Power, TakeRoot>::operator=(
 /**
  * Destructor: clean up memory.
  */
-template<int Power, bool TakeRoot>
-inline HRectBound<Power, TakeRoot>::~HRectBound()
+template<typename MetricType>
+inline HRectBound<MetricType>::~HRectBound()
 {
   if (bounds)
     delete[] bounds;
@@ -91,8 +91,8 @@ inline HRectBound<Power, TakeRoot>::~HRectBound()
 /**
  * Resets all dimensions to the empty set.
  */
-template<int Power, bool TakeRoot>
-inline void HRectBound<Power, TakeRoot>::Clear()
+template<typename MetricType>
+inline void HRectBound<MetricType>::Clear()
 {
   for (size_t i = 0; i < dim; i++)
     bounds[i] = math::Range();
@@ -104,8 +104,8 @@ inline void HRectBound<Power, TakeRoot>::Clear()
  *
  * @param centroid Vector which the centroid will be written to.
  */
-template<int Power, bool TakeRoot>
-inline void HRectBound<Power, TakeRoot>::Center(arma::vec& center) const
+template<typename MetricType>
+inline void HRectBound<MetricType>::Center(arma::vec& center) const
 {
   // Set size correctly if necessary.
   if (!(center.n_elem == dim))
@@ -120,8 +120,8 @@ inline void HRectBound<Power, TakeRoot>::Center(arma::vec& center) const
  *
  * @return Volume of the hyperrectangle.
  */
-template<int Power, bool TakeRoot>
-inline double HRectBound<Power, TakeRoot>::Volume() const
+template<typename MetricType>
+inline double HRectBound<MetricType>::Volume() const
 {
   double volume = 1.0;
   for (size_t i = 0; i < dim; ++i)
@@ -133,9 +133,9 @@ inline double HRectBound<Power, TakeRoot>::Volume() const
 /**
  * Calculates minimum bound-to-point squared distance.
  */
-template<int Power, bool TakeRoot>
+template<typename MetricType>
 template<typename VecType>
-inline double HRectBound<Power, TakeRoot>::MinDistance(
+inline double HRectBound<MetricType>::MinDistance(
     const VecType& point,
     typename boost::enable_if<IsVector<VecType> >* /* junk */) const
 {
@@ -152,24 +152,25 @@ inline double HRectBound<Power, TakeRoot>::MinDistance(
     // Since only one of 'lower' or 'higher' is negative, if we add each's
     // absolute value to itself and then sum those two, our result is the
     // nonnegative half of the equation times two; then we raise to power Power.
-    sum += pow((lower + fabs(lower)) + (higher + fabs(higher)), (double) Power);
+    sum += pow((lower + fabs(lower)) + (higher + fabs(higher)),
+        (double) MetricType::Power);
   }
 
   // Now take the Power'th root (but make sure our result is squared if it needs
   // to be); then cancel out the constant of 2 (which may have been squared now)
   // that was introduced earlier.  The compiler should optimize out the if
   // statement entirely.
-  if (TakeRoot)
-    return pow(sum, 1.0 / (double) Power) / 2.0;
+  if (MetricType::TakeRoot)
+    return pow(sum, 1.0 / (double) MetricType::Power) / 2.0;
   else
-    return sum / pow(2.0, Power);
+    return sum / pow(2.0, MetricType::Power);
 }
 
 /**
  * Calculates minimum bound-to-bound squared distance.
  */
-template<int Power, bool TakeRoot>
-double HRectBound<Power, TakeRoot>::MinDistance(const HRectBound& other) const
+template<typename MetricType>
+double HRectBound<MetricType>::MinDistance(const HRectBound& other) const
 {
   Log::Assert(dim == other.dim);
 
@@ -185,7 +186,8 @@ double HRectBound<Power, TakeRoot>::MinDistance(const HRectBound& other) const
     // We invoke the following:
     //   x + fabs(x) = max(x * 2, 0)
     //   (x * 2)^2 / 4 = x^2
-    sum += pow((lower + fabs(lower)) + (higher + fabs(higher)), (double) Power);
+    sum += pow((lower + fabs(lower)) + (higher + fabs(higher)),
+        (double) MetricType::Power);
 
     // Move bound pointers.
     mbound++;
@@ -193,18 +195,18 @@ double HRectBound<Power, TakeRoot>::MinDistance(const HRectBound& other) const
   }
 
   // The compiler should optimize out this if statement entirely.
-  if (TakeRoot)
-    return pow(sum, 1.0 / (double) Power) / 2.0;
+  if (MetricType::TakeRoot)
+    return pow(sum, 1.0 / (double) MetricType::Power) / 2.0;
   else
-    return sum / pow(2.0, Power);
+    return sum / pow(2.0, MetricType::Power);
 }
 
 /**
  * Calculates maximum bound-to-point squared distance.
  */
-template<int Power, bool TakeRoot>
+template<typename MetricType>
 template<typename VecType>
-inline double HRectBound<Power, TakeRoot>::MaxDistance(
+inline double HRectBound<MetricType>::MaxDistance(
     const VecType& point,
     typename boost::enable_if<IsVector<VecType> >* /* junk */) const
 {
@@ -216,12 +218,12 @@ inline double HRectBound<Power, TakeRoot>::MaxDistance(
   {
     double v = std::max(fabs(point[d] - bounds[d].Lo()),
         fabs(bounds[d].Hi() - point[d]));
-    sum += pow(v, (double) Power);
+    sum += pow(v, (double) MetricType::Power);
   }
 
   // The compiler should optimize out this if statement entirely.
-  if (TakeRoot)
-    return pow(sum, 1.0 / (double) Power);
+  if (MetricType::TakeRoot)
+    return pow(sum, 1.0 / (double) MetricType::Power);
   else
     return sum;
 }
@@ -229,8 +231,8 @@ inline double HRectBound<Power, TakeRoot>::MaxDistance(
 /**
  * Computes maximum distance.
  */
-template<int Power, bool TakeRoot>
-inline double HRectBound<Power, TakeRoot>::MaxDistance(const HRectBound& other)
+template<typename MetricType>
+inline double HRectBound<MetricType>::MaxDistance(const HRectBound& other)
     const
 {
   double sum = 0;
@@ -242,12 +244,12 @@ inline double HRectBound<Power, TakeRoot>::MaxDistance(const HRectBound& other)
   {
     v = std::max(fabs(other.bounds[d].Hi() - bounds[d].Lo()),
         fabs(bounds[d].Hi() - other.bounds[d].Lo()));
-    sum += pow(v, (double) Power); // v is non-negative.
+    sum += pow(v, (double) MetricType::Power); // v is non-negative.
   }
 
   // The compiler should optimize out this if statement entirely.
-  if (TakeRoot)
-    return pow(sum, 1.0 / (double) Power);
+  if (MetricType::TakeRoot)
+    return pow(sum, 1.0 / (double) MetricType::Power);
   else
     return sum;
 }
@@ -255,8 +257,8 @@ inline double HRectBound<Power, TakeRoot>::MaxDistance(const HRectBound& other)
 /**
  * Calculates minimum and maximum bound-to-bound squared distance.
  */
-template<int Power, bool TakeRoot>
-inline math::Range HRectBound<Power, TakeRoot>::RangeDistance(
+template<typename MetricType>
+inline math::Range HRectBound<MetricType>::RangeDistance(
     const HRectBound& other) const
 {
   double loSum = 0;
@@ -281,13 +283,13 @@ inline math::Range HRectBound<Power, TakeRoot>::RangeDistance(
       vLo = (v2 > 0) ? v2 : 0; // Force to be 0 if negative.
     }
 
-    loSum += pow(vLo, (double) Power);
-    hiSum += pow(vHi, (double) Power);
+    loSum += pow(vLo, (double) MetricType::Power);
+    hiSum += pow(vHi, (double) MetricType::Power);
   }
 
-  if (TakeRoot)
-    return math::Range(pow(loSum, 1.0 / (double) Power),
-                       pow(hiSum, 1.0 / (double) Power));
+  if (MetricType::TakeRoot)
+    return math::Range(pow(loSum, 1.0 / (double) MetricType::Power),
+                       pow(hiSum, 1.0 / (double) MetricType::Power));
   else
     return math::Range(loSum, hiSum);
 }
@@ -295,9 +297,9 @@ inline math::Range HRectBound<Power, TakeRoot>::RangeDistance(
 /**
  * Calculates minimum and maximum bound-to-point squared distance.
  */
-template<int Power, bool TakeRoot>
+template<typename MetricType>
 template<typename VecType>
-inline math::Range HRectBound<Power, TakeRoot>::RangeDistance(
+inline math::Range HRectBound<MetricType>::RangeDistance(
     const VecType& point,
     typename boost::enable_if<IsVector<VecType> >* /* junk */) const
 {
@@ -331,13 +333,13 @@ inline math::Range HRectBound<Power, TakeRoot>::RangeDistance(
       }
     }
 
-    loSum += pow(vLo, (double) Power);
-    hiSum += pow(vHi, (double) Power);
+    loSum += pow(vLo, (double) MetricType::Power);
+    hiSum += pow(vHi, (double) MetricType::Power);
   }
 
-  if (TakeRoot)
-    return math::Range(pow(loSum, 1.0 / (double) Power),
-                       pow(hiSum, 1.0 / (double) Power));
+  if (MetricType::TakeRoot)
+    return math::Range(pow(loSum, 1.0 / (double) MetricType::Power),
+                       pow(hiSum, 1.0 / (double) MetricType::Power));
   else
     return math::Range(loSum, hiSum);
 }
@@ -345,9 +347,9 @@ inline math::Range HRectBound<Power, TakeRoot>::RangeDistance(
 /**
  * Expands this region to include a new point.
  */
-template<int Power, bool TakeRoot>
+template<typename MetricType>
 template<typename MatType>
-inline HRectBound<Power, TakeRoot>& HRectBound<Power, TakeRoot>::operator|=(
+inline HRectBound<MetricType>& HRectBound<MetricType>::operator|=(
     const MatType& data)
 {
   Log::Assert(data.n_rows == dim);
@@ -370,8 +372,8 @@ inline HRectBound<Power, TakeRoot>& HRectBound<Power, TakeRoot>::operator|=(
 /**
  * Expands this region to encompass another bound.
  */
-template<int Power, bool TakeRoot>
-inline HRectBound<Power, TakeRoot>& HRectBound<Power, TakeRoot>::operator|=(
+template<typename MetricType>
+inline HRectBound<MetricType>& HRectBound<MetricType>::operator|=(
     const HRectBound& other)
 {
   assert(other.dim == dim);
@@ -391,9 +393,9 @@ inline HRectBound<Power, TakeRoot>& HRectBound<Power, TakeRoot>::operator|=(
 /**
  * Determines if a point is within this bound.
  */
-template<int Power, bool TakeRoot>
+template<typename MetricType>
 template<typename VecType>
-inline bool HRectBound<Power, TakeRoot>::Contains(const VecType& point) const
+inline bool HRectBound<MetricType>::Contains(const VecType& point) const
 {
   for (size_t i = 0; i < point.n_elem; i++)
   {
@@ -407,23 +409,24 @@ inline bool HRectBound<Power, TakeRoot>::Contains(const VecType& point) const
 /**
  * Returns the diameter of the hyperrectangle (that is, the longest diagonal).
  */
-template<int Power, bool TakeRoot>
-inline double HRectBound<Power, TakeRoot>::Diameter() const
+template<typename MetricType>
+inline double HRectBound<MetricType>::Diameter() const
 {
   double d = 0;
   for (size_t i = 0; i < dim; ++i)
-    d += std::pow(bounds[i].Hi() - bounds[i].Lo(), (double) Power);
+    d += std::pow(bounds[i].Hi() - bounds[i].Lo(),
+        (double) MetricType::Power);
 
-  if (TakeRoot)
-    return std::pow(d, 1.0 / (double) Power);
+  if (MetricType::TakeRoot)
+    return std::pow(d, 1.0 / (double) MetricType::Power);
   else
     return d;
 }
 
 //! Serialize the bound object.
-template<int Power, bool TakeRoot>
+template<typename MetricType>
 template<typename Archive>
-void HRectBound<Power, TakeRoot>::Serialize(Archive& ar,
+void HRectBound<MetricType>::Serialize(Archive& ar,
                                             const unsigned int /* version */)
 {
   ar & data::CreateNVP(dim, "dim");
@@ -443,13 +446,14 @@ void HRectBound<Power, TakeRoot>::Serialize(Archive& ar,
 /**
  * Returns a string representation of this object.
  */
-template<int Power, bool TakeRoot>
-std::string HRectBound<Power, TakeRoot>::ToString() const
+template<typename MetricType>
+std::string HRectBound<MetricType>::ToString() const
 {
   std::ostringstream convert;
   convert << "HRectBound [" << this << "]" << std::endl;
-  convert << "  Power: " << Power << std::endl;
-  convert << "  TakeRoot: " << (TakeRoot ? "true" : "false") << std::endl;
+  convert << "  Power: " << MetricType::Power << std::endl;
+  convert << "  TakeRoot: " << (MetricType::TakeRoot ? "true" : "false")
+      << std::endl;
   convert << "  Dimensionality: " << dim << std::endl;
   convert << "  Bounds: " << std::endl;
   for (size_t i = 0; i < dim; ++i)



More information about the mlpack-git mailing list