[mlpack-svn] r16854 - mlpack/trunk/src/mlpack/core/tree

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Jul 25 11:52:46 EDT 2014


Author: rcurtin
Date: Fri Jul 25 11:52:46 2014
New Revision: 16854

Log:
Contribution from Yash to solve #250 and make BallBound usable.


Modified:
   mlpack/trunk/src/mlpack/core/tree/ballbound.hpp
   mlpack/trunk/src/mlpack/core/tree/ballbound_impl.hpp

Modified: mlpack/trunk/src/mlpack/core/tree/ballbound.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/ballbound.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/tree/ballbound.hpp	Fri Jul 25 11:52:46 2014
@@ -15,29 +15,52 @@
 namespace bound {
 
 /**
- * Ball bound that works in the regular Euclidean metric space.
+ * Ball bound encloses a set of points that lie within a specific distance
+ * (radius) of a specific point (center). TMetricType is the metric used to
+ * measure that distance; it defaults to the Euclidean (L2) distance.
  *
  * @tparam VecType Type of vector (arma::vec or arma::sp_vec).
+ * @tparam TMetricType Metric type used in the distance measure.
  */
-template<typename VecType = arma::vec>
+template<typename VecType = arma::vec,
+         typename TMetricType = metric::LMetric<2, true> >
 class BallBound
 {
  public:
   typedef VecType Vec;
+  //! The metric type; needed by the binary space partitioning tree.
+  typedef TMetricType MetricType;
 
  private:
+
+  //! The radius of the ball bound.
   double radius;
+
+  //! The center of the ball bound.
   VecType center;
 
+  //! The metric used in this bound.
+  TMetricType* metric;
+
+  /**
+   * Whether this object allocated the metric member itself. This is true
+   * except for objects produced by the copy constructor or the assignment
+   * operator, which share another bound's metric. The destructor uses this to
+   * decide whether it must delete the metric.
+   */
+  bool ownsMetric;
+
  public:
-  BallBound() : radius(0) { }
+
+  //! Empty constructor.
+  BallBound();
 
   /**
    * Create the ball bound with the specified dimensionality.
    *
    * @param dimension Dimensionality of ball bound.
    */
-  BallBound(const size_t dimension) : radius(0), center(dimension) { }
+  BallBound(const size_t dimension);
 
   /**
    * Create the ball bound with the specified radius and center.
@@ -45,8 +68,16 @@
    * @param radius Radius of ball bound.
    * @param center Center of ball bound.
    */
-  BallBound(const double radius, const VecType& center) :
-      radius(radius), center(center) { }
+  BallBound(const double radius, const VecType& center);
+
+  //! Copy constructor. The copy shares (does not own) the metric of other.
+  BallBound(const BallBound& other);
+
+  //! Copy assignment. Like the copy constructor, shares the metric of other.
+  BallBound& operator=(const BallBound& other);
+
+  //! Destructor. Deletes the metric only if this object owns it.
+  ~BallBound();
 
   //! Get the radius of the ball.
   double Radius() const { return radius; }
@@ -58,7 +89,16 @@
   //! Modify the center point of the ball.
   VecType& Center() { return center; }
 
-  // Get the range in a certain dimension.
+  //! Get the dimensionality of the ball.
+  size_t Dim() const { return center.n_elem; }
+
+  /**
+   * Get the minimum width of the bound (this is the same as the diameter).
+   * For ball bounds, the width is the same along every dimension.
+   */
+  double MinWidth() const { return radius * 2.0; }
+
+  //! Get the range in a certain dimension.
   math::Range operator[](const size_t i) const;
 
   /**
@@ -67,13 +107,11 @@
   bool Contains(const VecType& point) const;
 
   /**
-   * Gets the center.
+   * Place the centroid of BallBound into the given vector.
    *
-   * Don't really use this directly.  This is only here for consistency
-   * with DHrectBound, so it can plug in more directly if a "centroid"
-   * is needed.
+   * @param centroid Vector which the centroid will be written to.
    */
-  void CalculateMidpoint(VecType& centroid) const;
+  void Centroid(VecType& centroid) const { centroid = center; }
 
   /**
    * Calculates minimum bound-to-point squared distance.
@@ -133,6 +171,16 @@
   const BallBound& operator|=(const MatType& data);
 
   /**
+   * Returns the diameter of the ballbound.
+   */
+  double Diameter() const { return 2 * radius; }
+
+  /**
+   * Returns a copy of the distance metric used in this bound.
+   */
+  TMetricType Metric() const { return *metric; }
+
+  /**
    * Returns a string representation of this object.
    */
   std::string ToString() const;

Modified: mlpack/trunk/src/mlpack/core/tree/ballbound_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/core/tree/ballbound_impl.hpp	(original)
+++ mlpack/trunk/src/mlpack/core/tree/ballbound_impl.hpp	Fri Jul 25 11:52:46 2014
@@ -17,9 +17,87 @@
 namespace mlpack {
 namespace bound {
 
+//! Empty Constructor.
+template<typename VecType, typename TMetricType>
+BallBound<VecType, TMetricType>::BallBound() :
+    radius(-DBL_MAX),
+    metric(new TMetricType()),
+    ownsMetric(true)
+{ /* Nothing to do. */ }
+
+/**
+ * Create the ball bound with the specified dimensionality.
+ *
+ * @param dimension Dimensionality of ball bound.
+ */
+template<typename VecType, typename TMetricType>
+BallBound<VecType, TMetricType>::BallBound(const size_t dimension) :
+    radius(-DBL_MAX),
+    center(dimension),
+    metric(new TMetricType()),
+    ownsMetric(true)
+{ /* Nothing to do. */ }
+
+/**
+ * Create the ball bound with the specified radius and center.
+ *
+ * @param radius Radius of ball bound.
+ * @param center Center of ball bound.
+ */
+template<typename VecType, typename TMetricType>
+BallBound<VecType, TMetricType>::BallBound(const double radius,
+    const VecType& center) :
+    radius(radius),
+    center(center),
+    metric(new TMetricType()),
+    ownsMetric(true)
+{ /* Nothing to do. */ }
+
+//! Copy constructor. The copy shares other's metric pointer without owning
+//! it, so it must not be used after the owner of that metric is destroyed.
+template<typename VecType, typename TMetricType>
+BallBound<VecType, TMetricType>::BallBound(const BallBound& other) :
+    radius(other.radius),
+    center(other.center),
+    metric(other.metric),
+    ownsMetric(false)
+{ /* Nothing to do. */ }
+
+//! Copy assignment operator. Like the copy constructor, the result shares
+//! other's metric without owning it; a metric this object owned is deleted.
+template<typename VecType, typename TMetricType>
+BallBound<VecType, TMetricType>& BallBound<VecType, TMetricType>::operator=(
+    const BallBound& other)
+{
+  // Self-assignment must be a no-op; otherwise we would delete the metric we
+  // are about to keep using.
+  if (this == &other)
+    return *this;
+
+  // Release a metric we allocated ourselves; overwriting the pointer would
+  // otherwise leak it.
+  if (ownsMetric)
+    delete metric;
+
+  radius = other.radius;
+  center = other.center;
+  metric = other.metric;
+  ownsMetric = false;
+
+  return *this;
+}
+
+//! Destructor. Deletes the metric only if this object allocated it.
+template<typename VecType, typename TMetricType>
+BallBound<VecType, TMetricType>::~BallBound()
+{
+  if (ownsMetric)
+    delete metric;
+}
+
 //! Get the range in a certain dimension.
-template<typename VecType>
-math::Range BallBound<VecType>::operator[](const size_t i) const
+template<typename VecType, typename TMetricType>
+math::Range BallBound<VecType, TMetricType>::operator[](const size_t i) const
 {
   if (radius < 0)
     return math::Range();
@@ -30,56 +94,43 @@
 /**
  * Determines if a point is within the bound.
  */
-template<typename VecType>
-bool BallBound<VecType>::Contains(const VecType& point) const
+template<typename VecType, typename TMetricType>
+bool BallBound<VecType, TMetricType>::Contains(const VecType& point) const
 {
   if (radius < 0)
     return false;
   else
-    return metric::EuclideanDistance::Evaluate(center, point) <= radius;
-}
-
-/**
- * Gets the center.
- *
- * Don't really use this directly.  This is only here for consistency
- * with DHrectBound, so it can plug in more directly if a "centroid"
- * is needed.
- */
-template<typename VecType>
-void BallBound<VecType>::CalculateMidpoint(VecType& centroid) const
-{
-  centroid = center;
+    return metric->Evaluate(center, point) <= radius;
 }
 
 /**
  * Calculates minimum bound-to-point squared distance.
  */
-template<typename VecType>
+template<typename VecType, typename TMetricType>
 template<typename OtherVecType>
-double BallBound<VecType>::MinDistance(
+double BallBound<VecType, TMetricType>::MinDistance(
     const OtherVecType& point,
     typename boost::enable_if<IsVector<OtherVecType> >* /* junk */) const
 {
   if (radius < 0)
     return DBL_MAX;
   else
-    return math::ClampNonNegative(metric::EuclideanDistance::Evaluate(point,
-        center) - radius);
+    return math::ClampNonNegative(metric->Evaluate(point, center) - radius);
 }
 
 /**
  * Calculates minimum bound-to-bound squared distance.
  */
-template<typename VecType>
-double BallBound<VecType>::MinDistance(const BallBound& other) const
+template<typename VecType, typename TMetricType>
+double BallBound<VecType, TMetricType>::MinDistance(const BallBound& other)
+    const
 {
   if (radius < 0)
     return DBL_MAX;
   else
   {
-    double delta = metric::EuclideanDistance::Evaluate(center, other.center)
-        - radius - other.radius;
+    const double delta = metric->Evaluate(center, other.center) - radius -
+        other.radius;
     return math::ClampNonNegative(delta);
   }
 }
@@ -87,29 +137,29 @@
 /**
  * Computes maximum distance.
  */
-template<typename VecType>
+template<typename VecType, typename TMetricType>
 template<typename OtherVecType>
-double BallBound<VecType>::MaxDistance(
+double BallBound<VecType, TMetricType>::MaxDistance(
     const OtherVecType& point,
     typename boost::enable_if<IsVector<OtherVecType> >* /* junk */) const
 {
   if (radius < 0)
     return DBL_MAX;
   else
-    return metric::EuclideanDistance::Evaluate(point, center) + radius;
+    return metric->Evaluate(point, center) + radius;
 }
 
 /**
  * Computes maximum distance.
  */
-template<typename VecType>
-double BallBound<VecType>::MaxDistance(const BallBound& other) const
+template<typename VecType, typename TMetricType>
+double BallBound<VecType, TMetricType>::MaxDistance(const BallBound& other)
+    const
 {
   if (radius < 0)
     return DBL_MAX;
   else
-    return metric::EuclideanDistance::Evaluate(other.center, center) + radius
-        + other.radius;
+    return metric->Evaluate(other.center, center) + radius + other.radius;
 }
 
 /**
@@ -117,9 +167,9 @@
  *
  * Example: bound1.MinDistanceSq(other) for minimum squared distance.
  */
-template<typename VecType>
+template<typename VecType, typename TMetricType>
 template<typename OtherVecType>
-math::Range BallBound<VecType>::RangeDistance(
+math::Range BallBound<VecType, TMetricType>::RangeDistance(
     const OtherVecType& point,
     typename boost::enable_if<IsVector<OtherVecType> >* /* junk */) const
 {
@@ -127,22 +177,22 @@
     return math::Range(DBL_MAX, DBL_MAX);
   else
   {
-    double dist = metric::EuclideanDistance::Evaluate(center, point);
+    const double dist = metric->Evaluate(center, point);
     return math::Range(math::ClampNonNegative(dist - radius),
                                               dist + radius);
   }
 }
 
-template<typename VecType>
-math::Range BallBound<VecType>::RangeDistance(
+template<typename VecType, typename TMetricType>
+math::Range BallBound<VecType, TMetricType>::RangeDistance(
     const BallBound& other) const
 {
   if (radius < 0)
     return math::Range(DBL_MAX, DBL_MAX);
   else
   {
-    double dist = metric::EuclideanDistance::Evaluate(center, other.center);
-    double sumradius = radius + other.radius;
+    const double dist = metric->Evaluate(center, other.center);
+    const double sumradius = radius + other.radius;
     return math::Range(math::ClampNonNegative(dist - sumradius),
                                               dist + sumradius);
   }
@@ -151,12 +201,12 @@
 /**
  * Expand the bound to include the given bound.
  *
-template<typename VecType>
+template<typename VecType, typename TMetricType>
 const BallBound<VecType>&
-BallBound<VecType>::operator|=(
+BallBound<VecType, TMetricType>::operator|=(
     const BallBound<VecType>& other)
 {
-  double dist = metric::EuclideanDistance::Evaluate(center, other);
+  double dist = metric->Evaluate(center, other);
 
   // Now expand the radius as necessary.
   if (dist > radius)
@@ -166,12 +216,15 @@
 }*/
 
 /**
- * Expand the bound to include the given point.
+ * Expand the bound to include every point (column) of the given data. The
+ * algorithm is adapted from Jack Ritter, "An Efficient Bounding Sphere", in
+ * Graphics Gems (1990); it differs only in how the ball bound is initialized,
+ * while the way the bound is expanded is the same.
  */
-template<typename VecType>
+template<typename VecType, typename TMetricType>
 template<typename MatType>
-const BallBound<VecType>&
-BallBound<VecType>::operator|=(const MatType& data)
+const BallBound<VecType, TMetricType>&
+BallBound<VecType, TMetricType>::operator|=(const MatType& data)
 {
   if (radius < 0)
   {
@@ -179,35 +232,37 @@
     radius = 0;
   }
 
-  // Now iteratively add points.  There is probably a closed-form solution to
-  // find the minimum bounding circle, and it is probably faster.
-  for (size_t i = 1; i < data.n_cols; ++i)
+  // Now iteratively add points.
+  for (size_t i = 0; i < data.n_cols; ++i)
   {
-    double dist = metric::EuclideanDistance::Evaluate(center, (VecType)
-        data.col(i)) - radius;
+    const double dist = metric->Evaluate(center, (VecType) data.col(i));
 
-    if (dist > 0)
+    // See if the new point lies outside the bound.
+    if (dist > radius)
     {
-      // Move (dist / 2) towards the new point and increase radius by
-      // (dist / 2).
+      // Move towards the new point and increase the radius just enough to
+      // accommodate the new point.
       arma::vec diff = data.col(i) - center;
-      center += 0.5 * diff;
-      radius += 0.5 * dist;
+      center += ((dist - radius) / (2 * dist)) * diff;
+      radius = 0.5 * (dist + radius);
     }
   }
 
   return *this;
 }
+
 /**
  * Returns a string representation of this object.
  */
-template<typename VecType>
-std::string BallBound<VecType>::ToString() const
+template<typename VecType, typename TMetricType>
+std::string BallBound<VecType, TMetricType>::ToString() const
 {
   std::ostringstream convert;
   convert << "BallBound [" << this << "]" << std::endl;
   convert << "  Radius:  " << radius << std::endl;
   convert << "  Center:" << std::endl << center;
+  convert << "  ownsMetric: " << ownsMetric << std::endl;
+  convert << "  Metric:" << std::endl << metric->ToString();
   return convert.str();
 }
 



More information about the mlpack-svn mailing list