[mlpack-git] master: Refactor to handle arbitrary element types. (1d7cc4a)

gitdub at mlpack.org gitdub at mlpack.org
Mon Mar 7 14:59:45 EST 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/f45c17bc4d70ee5d82bf11a91850a34b814eccff...a69871c4eb63087c825502fd2277565453720568

>---------------------------------------------------------------

commit 1d7cc4a098b9167f2917897665d1b808075b3860
Author: Ryan Curtin <ryan at ratml.org>
Date:   Thu Jan 28 12:04:34 2016 +0000

    Refactor to handle arbitrary element types.


>---------------------------------------------------------------

1d7cc4a098b9167f2917897665d1b808075b3860
 src/mlpack/core/tree/ballbound.hpp      | 41 ++++++++++++-----------
 src/mlpack/core/tree/ballbound_impl.hpp | 58 +++++++++++++++++++--------------
 2 files changed, 56 insertions(+), 43 deletions(-)

diff --git a/src/mlpack/core/tree/ballbound.hpp b/src/mlpack/core/tree/ballbound.hpp
index 2201fc4..c6018e6 100644
--- a/src/mlpack/core/tree/ballbound.hpp
+++ b/src/mlpack/core/tree/ballbound.hpp
@@ -19,21 +19,24 @@ namespace bound {
 * specific point (center). TMetricType is the custom metric type that defaults
  * to the Euclidean (L2) distance.
  *
- * @tparam VecType Type of vector (arma::vec or arma::sp_vec).
+ * @tparam VecType Vector type (arma::vec, arma::sp_vec, ...) with elem_type.
  * @tparam TMetricType metric type used in the distance measure.
  */
 template<typename VecType = arma::vec,
-         typename TMetricType = metric::LMetric<2, true> >
+         typename TMetricType = metric::LMetric<2, true>>
 class BallBound
 {
  public:
+  //! Element type of VecType; also the type of the radius and distances.
+  typedef typename VecType::elem_type ElemType;
+  //! Public alias for the vector type (VecType).
   typedef VecType Vec;
   //! Needed for BinarySpaceTree.
   typedef TMetricType MetricType;
 
  private:
   //! The radius of the ball bound.
-  double radius;
+  ElemType radius;
   //! The center of the ball bound.
   VecType center;
   //! The metric used in this bound.
@@ -65,7 +68,7 @@ class BallBound
    * @param radius Radius of ball bound.
    * @param center Center of ball bound.
    */
-  BallBound(const double radius, const VecType& center);
+  BallBound(const ElemType radius, const VecType& center);
 
   //! Copy constructor. To prevent memory leaks.
   BallBound(const BallBound& other);
@@ -80,9 +83,9 @@ class BallBound
   ~BallBound();
 
   //! Get the radius of the ball.
-  double Radius() const { return radius; }
+  ElemType Radius() const { return radius; }
   //! Modify the radius of the ball.
-  double& Radius() { return radius; }
+  ElemType& Radius() { return radius; }
 
   //! Get the center point of the ball.
   const VecType& Center() const { return center; }
@@ -90,16 +93,16 @@ class BallBound
   VecType& Center() { return center; }
 
   //! Get the dimensionality of the ball.
-  double Dim() const { return center.n_elem; }
+  size_t Dim() const { return center.n_elem; }
 
   /**
-   * Get the minimum width of the bound (this is same as the diameter).
-   * For ball bounds, width along all dimensions remain same.
+   * Get the minimum width of the bound (the same as the diameter).
+   * For ball bounds, the width is the same along every dimension.
    */
-  double MinWidth() const { return radius * 2.0; }
+  ElemType MinWidth() const { return 2 * radius; }
 
   //! Get the range in a certain dimension.
-  math::Range operator[](const size_t i) const;
+  math::RangeType<ElemType> operator[](const size_t i) const;
 
   /**
    * Determines if a point is within this bound.
@@ -117,42 +120,42 @@ class BallBound
-   * Calculates minimum bound-to-point squared distance.
+   * Calculates minimum bound-to-point distance.
    */
   template<typename OtherVecType>
-  double MinDistance(const OtherVecType& point,
-                     typename boost::enable_if<IsVector<OtherVecType> >* = 0)
+  ElemType MinDistance(const OtherVecType& point,
+                       typename boost::enable_if<IsVector<OtherVecType>>* = 0)
       const;
 
   /**
-   * Calculates minimum bound-to-bound squared distance.
+   * Calculates minimum bound-to-bound distance.
    */
-  double MinDistance(const BallBound& other) const;
+  ElemType MinDistance(const BallBound& other) const;
 
   /**
    * Computes maximum distance.
    */
   template<typename OtherVecType>
-  double MaxDistance(const OtherVecType& point,
-                     typename boost::enable_if<IsVector<OtherVecType> >* = 0)
+  ElemType MaxDistance(const OtherVecType& point,
+                       typename boost::enable_if<IsVector<OtherVecType>>* = 0)
       const;
 
   /**
    * Computes maximum distance.
    */
-  double MaxDistance(const BallBound& other) const;
+  ElemType MaxDistance(const BallBound& other) const;
 
   /**
    * Calculates minimum and maximum bound-to-point distance.
    */
   template<typename OtherVecType>
-  math::Range RangeDistance(
+  math::RangeType<ElemType> RangeDistance(
       const OtherVecType& other,
-      typename boost::enable_if<IsVector<OtherVecType> >* = 0) const;
+      typename boost::enable_if<IsVector<OtherVecType>>* = 0) const;
 
   /**
    * Calculates minimum and maximum bound-to-bound distance.
    *
-   * Example: bound1.MinDistanceSq(other) for minimum distance.
+   * The low end of the returned range equals MinDistance(other).
    */
-  math::Range RangeDistance(const BallBound& other) const;
+  math::RangeType<ElemType> RangeDistance(const BallBound& other) const;
 
   /**
    * Expand the bound to include the given node.
@@ -173,7 +176,7 @@ class BallBound
   /**
    * Returns the diameter of the ballbound.
    */
-  double Diameter() const { return 2 * radius; }
+  ElemType Diameter() const { return 2 * radius; }
 
   //! Returns the distance metric used in this bound.
   const TMetricType& Metric() const { return *metric; }
diff --git a/src/mlpack/core/tree/ballbound_impl.hpp b/src/mlpack/core/tree/ballbound_impl.hpp
index 94eac4c..66d6a12 100644
--- a/src/mlpack/core/tree/ballbound_impl.hpp
+++ b/src/mlpack/core/tree/ballbound_impl.hpp
@@ -20,7 +20,7 @@ namespace bound {
 //! Empty Constructor.
 template<typename VecType, typename TMetricType>
 BallBound<VecType, TMetricType>::BallBound() :
-    radius(-DBL_MAX),
+    radius(std::numeric_limits<ElemType>::lowest()), // Marks an empty bound.
     metric(new TMetricType()),
     ownsMetric(true)
 { /* Nothing to do. */ }
@@ -32,7 +32,7 @@ BallBound<VecType, TMetricType>::BallBound() :
  */
 template<typename VecType, typename TMetricType>
 BallBound<VecType, TMetricType>::BallBound(const size_t dimension) :
-    radius(-DBL_MAX),
+    radius(std::numeric_limits<ElemType>::lowest()), // Marks an empty bound.
     center(dimension),
     metric(new TMetricType()),
     ownsMetric(true)
@@ -45,8 +45,8 @@ BallBound<VecType, TMetricType>::BallBound(const size_t dimension) :
  * @param center Center of ball bound.
  */
 template<typename VecType, typename TMetricType>
-BallBound<VecType, TMetricType>::BallBound(const double radius,
-    const VecType& center) :
+BallBound<VecType, TMetricType>::BallBound(const ElemType radius,
+                                           const VecType& center) :
     radius(radius),
     center(center),
     metric(new TMetricType()),
@@ -98,7 +98,8 @@ BallBound<VecType, TMetricType>::~BallBound()
 
 //! Get the range in a certain dimension.
 template<typename VecType, typename TMetricType>
-math::Range BallBound<VecType, TMetricType>::operator[](const size_t i) const
+math::RangeType<typename BallBound<VecType, TMetricType>::ElemType>
+BallBound<VecType, TMetricType>::operator[](const size_t i) const
 {
   if (radius < 0)
     return math::Range();
@@ -123,12 +124,13 @@ bool BallBound<VecType, TMetricType>::Contains(const VecType& point) const
  */
 template<typename VecType, typename TMetricType>
 template<typename OtherVecType>
-double BallBound<VecType, TMetricType>::MinDistance(
+typename BallBound<VecType, TMetricType>::ElemType
+BallBound<VecType, TMetricType>::MinDistance(
     const OtherVecType& point,
-    typename boost::enable_if<IsVector<OtherVecType> >* /* junk */) const
+    typename boost::enable_if<IsVector<OtherVecType>>* /* junk */) const
 {
   if (radius < 0)
-    return DBL_MAX;
+    return std::numeric_limits<ElemType>::max(); // Empty bound.
   else
     return math::ClampNonNegative(metric->Evaluate(point, center) - radius);
 }
@@ -137,13 +139,15 @@ double BallBound<VecType, TMetricType>::MinDistance(
- * Calculates minimum bound-to-bound squared distance.
+ * Calculates minimum bound-to-bound distance.
  */
 template<typename VecType, typename TMetricType>
-double BallBound<VecType, TMetricType>::MinDistance(const BallBound& other) const
+typename BallBound<VecType, TMetricType>::ElemType
+BallBound<VecType, TMetricType>::MinDistance(const BallBound& other)
+    const
 {
   if (radius < 0)
-    return DBL_MAX;
+    return std::numeric_limits<ElemType>::max(); // Empty bound.
   else
   {
-    const double delta = metric->Evaluate(center, other.center) - radius -
+    const ElemType delta = metric->Evaluate(center, other.center) - radius -
         other.radius;
     return math::ClampNonNegative(delta);
   }
@@ -154,12 +158,13 @@ double BallBound<VecType, TMetricType>::MinDistance(const BallBound& other) cons
  */
 template<typename VecType, typename TMetricType>
 template<typename OtherVecType>
-double BallBound<VecType, TMetricType>::MaxDistance(
+typename BallBound<VecType, TMetricType>::ElemType
+BallBound<VecType, TMetricType>::MaxDistance(
     const OtherVecType& point,
-    typename boost::enable_if<IsVector<OtherVecType> >* /* junk */) const
+    typename boost::enable_if<IsVector<OtherVecType>>* /* junk */) const
 {
   if (radius < 0)
-    return DBL_MAX;
+    return std::numeric_limits<ElemType>::max(); // Empty bound.
   else
     return metric->Evaluate(point, center) + radius;
 }
@@ -168,11 +173,12 @@ double BallBound<VecType, TMetricType>::MaxDistance(
  * Computes maximum distance.
  */
 template<typename VecType, typename TMetricType>
-double BallBound<VecType, TMetricType>::MaxDistance(const BallBound& other)
+typename BallBound<VecType, TMetricType>::ElemType
+BallBound<VecType, TMetricType>::MaxDistance(const BallBound& other)
     const
 {
   if (radius < 0)
-    return DBL_MAX;
+    return std::numeric_limits<ElemType>::max(); // Empty bound.
   else
     return metric->Evaluate(other.center, center) + radius + other.radius;
 }
@@ -184,30 +190,34 @@ double BallBound<VecType, TMetricType>::MaxDistance(const BallBound& other)
  */
 template<typename VecType, typename TMetricType>
 template<typename OtherVecType>
-math::Range BallBound<VecType, TMetricType>::RangeDistance(
+math::RangeType<typename BallBound<VecType, TMetricType>::ElemType>
+BallBound<VecType, TMetricType>::RangeDistance(
     const OtherVecType& point,
-    typename boost::enable_if<IsVector<OtherVecType> >* /* junk */) const
+    typename boost::enable_if<IsVector<OtherVecType>>* /* junk */) const
 {
   if (radius < 0)
-    return math::Range(DBL_MAX, DBL_MAX);
+    return math::RangeType<ElemType>(std::numeric_limits<ElemType>::max(),
+                                     std::numeric_limits<ElemType>::max());
   else
   {
-    const double dist = metric->Evaluate(center, point);
-    return math::Range(math::ClampNonNegative(dist - radius),
-                                              dist + radius);
+    const ElemType dist = metric->Evaluate(center, point);
+    return math::RangeType<ElemType>(math::ClampNonNegative(dist - radius),
+                                     dist + radius);
   }
 }
 
 template<typename VecType, typename TMetricType>
-math::Range BallBound<VecType, TMetricType>::RangeDistance(
+math::RangeType<typename BallBound<VecType, TMetricType>::ElemType>
+BallBound<VecType, TMetricType>::RangeDistance(
     const BallBound& other) const
 {
   if (radius < 0)
-    return math::Range(DBL_MAX, DBL_MAX);
+    return math::RangeType<ElemType>(std::numeric_limits<ElemType>::max(),
+                                     std::numeric_limits<ElemType>::max());
   else
   {
-    const double dist = metric->Evaluate(center, other.center);
-    const double sumradius = radius + other.radius;
-    return math::Range(math::ClampNonNegative(dist - sumradius),
-                                              dist + sumradius);
+    const ElemType dist = metric->Evaluate(center, other.center);
+    const ElemType sumradius = radius + other.radius;
+    return math::RangeType<ElemType>(math::ClampNonNegative(dist - sumradius),
+                                     dist + sumradius);
   }
@@ -250,14 +260,14 @@ BallBound<VecType, TMetricType>::operator|=(const MatType& data)
   // Now iteratively add points.
   for (size_t i = 0; i < data.n_cols; ++i)
   {
-    const double dist = metric->Evaluate(center, (VecType) data.col(i));
+    const ElemType dist = metric->Evaluate(center, (VecType) data.col(i));
 
     // See if the new point lies outside the bound.
     if (dist > radius)
     {
       // Move towards the new point and increase the radius just enough to
       // accommodate the new point.
-      arma::vec diff = data.col(i) - center;
+      const VecType diff = data.col(i) - center; // From center to point i.
       center += ((dist - radius) / (2 * dist)) * diff;
       radius = 0.5 * (dist + radius);
     }




More information about the mlpack-git mailing list