[mlpack-svn] r10830 - mlpack/trunk/src/mlpack/core/tree
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Dec 15 03:43:09 EST 2011
Author: rcurtin
Date: 2011-12-15 03:43:09 -0500 (Thu, 15 Dec 2011)
New Revision: 10830
Added:
mlpack/trunk/src/mlpack/core/tree/ballbound.hpp
mlpack/trunk/src/mlpack/core/tree/ballbound_impl.hpp
Removed:
mlpack/trunk/src/mlpack/core/tree/dballbound.hpp
mlpack/trunk/src/mlpack/core/tree/dballbound_impl.hpp
Modified:
mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt
mlpack/trunk/src/mlpack/core/tree/bounds.hpp
Log:
Refactor BallBound API and clean it up.
Modified: mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt 2011-12-15 08:02:06 UTC (rev 10829)
+++ mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt 2011-12-15 08:43:09 UTC (rev 10830)
@@ -3,11 +3,11 @@
# Define the files we need to compile.
# Anything not in this list will not be compiled into MLPACK.
set(SOURCES
+ ballbound.hpp
+ ballbound_impl.hpp
binary_space_tree.hpp
binary_space_tree_impl.hpp
bounds.hpp
- dballbound.hpp
- dballbound_impl.hpp
hrectbound.hpp
hrectbound_impl.hpp
periodichrectbound.hpp
Copied: mlpack/trunk/src/mlpack/core/tree/ballbound.hpp (from rev 10797, mlpack/trunk/src/mlpack/core/tree/dballbound.hpp)
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/ballbound.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/core/tree/ballbound.hpp 2011-12-15 08:43:09 UTC (rev 10830)
@@ -0,0 +1,147 @@
+/**
+ * @file ballbound.hpp
+ *
+ * Bounds that are useful for binary space partitioning trees.
+ * Interface to a ball bound that works in the Euclidean metric space.
+ */
+
+#ifndef __MLPACK_CORE_TREE_BALLBOUND_HPP
+#define __MLPACK_CORE_TREE_BALLBOUND_HPP
+
+#include <mlpack/core.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
+
+namespace mlpack {
+namespace bound {
+
+/**
+ * Ball bound that works in the regular Euclidean metric space.  The bound is
+ * described by a center point and a radius, and every distance computed by
+ * this class is an (unsquared) Euclidean distance.
+ *
+ * @tparam VecType Type of vector (arma::vec or arma::spvec).
+ */
+template<typename VecType = arma::vec>
+class BallBound
+{
+ public:
+  //! The vector type used for points and for the center.
+  typedef VecType Vec;
+
+ private:
+  //! The radius of the ball.
+  double radius;
+  //! The center of the ball.
+  VecType center;
+
+ public:
+  //! Create a ball bound with zero radius and a default-constructed center.
+  BallBound() : radius(0) { }
+
+  /**
+   * Create the ball bound with the specified dimensionality and zero radius.
+   *
+   * @param dimension Dimensionality of ball bound.
+   */
+  BallBound(const size_t dimension) : radius(0), center(dimension) { }
+
+  /**
+   * Create the ball bound with the specified radius and center.
+   *
+   * @param radius Radius of ball bound.
+   * @param center Center of ball bound.
+   */
+  BallBound(const double radius, const VecType& center) :
+      radius(radius), center(center) { }
+
+  //! Get the radius of the ball.
+  double Radius() const { return radius; }
+  //! Modify the radius of the ball.
+  double& Radius() { return radius; }
+
+  //! Get the center point of the ball.
+  const VecType& Center() const { return center; }
+  //! Modify the center point of the ball.
+  VecType& Center() { return center; }
+
+  /**
+   * Get the range [center[i] - radius, center[i] + radius] covered by the
+   * ball in dimension i.
+   *
+   * @param i Dimension to get the range of.
+   */
+  math::Range operator[](const size_t i) const;
+
+  /**
+   * Determines if a point is within this bound.  Points lying exactly on the
+   * surface of the ball are considered contained.
+   */
+  bool Contains(const VecType& point) const;
+
+  /**
+   * Copies the center of the ball into the given vector.
+   *
+   * Don't really use this directly.  This is only here for consistency
+   * with the hyperrectangle bound, so it can plug in more directly if a
+   * "centroid" is needed.
+   *
+   * @param centroid Vector to store the center in.
+   */
+  void CalculateMidpoint(VecType& centroid) const;
+
+  /**
+   * Calculates the minimum Euclidean distance between this bound and the
+   * given point (zero if the point is inside the ball).
+   */
+  double MinDistance(const VecType& point) const;
+
+  /**
+   * Calculates the minimum Euclidean distance between this bound and another
+   * ball bound (zero if the balls overlap).
+   */
+  double MinDistance(const BallBound& other) const;
+
+  /**
+   * Computes the maximum Euclidean distance between this bound and the given
+   * point.
+   */
+  double MaxDistance(const VecType& point) const;
+
+  /**
+   * Computes the maximum Euclidean distance between this bound and another
+   * ball bound.
+   */
+  double MaxDistance(const BallBound& other) const;
+
+  /**
+   * Calculates the minimum and maximum Euclidean distance between this bound
+   * and the given point, evaluating the center-to-point distance only once.
+   */
+  math::Range RangeDistance(const VecType& point) const;
+
+  /**
+   * Calculates the minimum and maximum Euclidean distance between this bound
+   * and another ball bound, evaluating the center-to-center distance only
+   * once.
+   */
+  math::Range RangeDistance(const BallBound& other) const;
+
+  /**
+   * Expand the bound so that it contains the given ball bound.  The center is
+   * kept fixed; only the radius grows.
+   */
+  const BallBound& operator|=(const BallBound& other);
+
+  /**
+   * Expand the bound so that it contains the given point.  The center is kept
+   * fixed; only the radius grows.
+   */
+  const BallBound& operator|=(const VecType& point);
+};
+
+}; // namespace bound
+}; // namespace mlpack
+
+#include "ballbound_impl.hpp"
+
+#endif // __MLPACK_CORE_TREE_BALLBOUND_HPP
Copied: mlpack/trunk/src/mlpack/core/tree/ballbound_impl.hpp (from rev 10737, mlpack/trunk/src/mlpack/core/tree/dballbound_impl.hpp)
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/ballbound_impl.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/core/tree/ballbound_impl.hpp 2011-12-15 08:43:09 UTC (rev 10830)
@@ -0,0 +1,160 @@
+/**
+ * @file ballbound_impl.hpp
+ *
+ * Bounds that are useful for binary space partitioning trees.
+ * Implementation of BallBound, a ball bound in the Euclidean metric space.
+ *
+ * @experimental
+ */
+#ifndef __MLPACK_CORE_TREE_BALLBOUND_IMPL_HPP
+#define __MLPACK_CORE_TREE_BALLBOUND_IMPL_HPP
+
+// In case it hasn't been included already.
+#include "ballbound.hpp"
+
+namespace mlpack {
+namespace bound {
+
+//! Get the range [center[i] - radius, center[i] + radius] in dimension i.
+template<typename VecType>
+math::Range BallBound<VecType>::operator[](const size_t i) const
+{
+  return math::Range(center[i] - radius, center[i] + radius);
+}
+
+/**
+ * Determines if a point is within the bound.  Points lying exactly on the
+ * surface of the ball are considered contained.
+ */
+template<typename VecType>
+bool BallBound<VecType>::Contains(const VecType& point) const
+{
+  return metric::EuclideanDistance::Evaluate(center, point) <= radius;
+}
+
+/**
+ * Copies the center of the ball into the given vector.
+ *
+ * Don't really use this directly.  This is only here for consistency
+ * with the hyperrectangle bound, so it can plug in more directly if a
+ * "centroid" is needed.
+ */
+template<typename VecType>
+void BallBound<VecType>::CalculateMidpoint(VecType& centroid) const
+{
+  centroid = center;
+}
+
+/**
+ * Calculates the minimum bound-to-point Euclidean distance (zero if the point
+ * lies inside the ball).
+ */
+template<typename VecType>
+double BallBound<VecType>::MinDistance(const VecType& point) const
+{
+  return math::ClampNonNegative(metric::EuclideanDistance::Evaluate(point,
+      center) - radius);
+}
+
+/**
+ * Calculates the minimum bound-to-bound Euclidean distance (zero if the two
+ * balls overlap).
+ */
+template<typename VecType>
+double BallBound<VecType>::MinDistance(const BallBound& other) const
+{
+  const double delta = metric::EuclideanDistance::Evaluate(center,
+      other.center) - radius - other.radius;
+  return math::ClampNonNegative(delta);
+}
+
+/**
+ * Computes the maximum bound-to-point Euclidean distance.
+ */
+template<typename VecType>
+double BallBound<VecType>::MaxDistance(const VecType& point) const
+{
+  return metric::EuclideanDistance::Evaluate(point, center) + radius;
+}
+
+/**
+ * Computes the maximum bound-to-bound Euclidean distance.
+ */
+template<typename VecType>
+double BallBound<VecType>::MaxDistance(const BallBound& other) const
+{
+  return metric::EuclideanDistance::Evaluate(other.center, center) + radius
+      + other.radius;
+}
+
+/**
+ * Calculates the minimum and maximum bound-to-point Euclidean distance,
+ * evaluating the center-to-point distance only once.
+ */
+template<typename VecType>
+math::Range BallBound<VecType>::RangeDistance(const VecType& point)
+    const
+{
+  const double dist = metric::EuclideanDistance::Evaluate(center, point);
+  return math::Range(math::ClampNonNegative(dist - radius),
+      dist + radius);
+}
+
+/**
+ * Calculates the minimum and maximum bound-to-bound Euclidean distance,
+ * evaluating the center-to-center distance only once.
+ */
+template<typename VecType>
+math::Range BallBound<VecType>::RangeDistance(
+    const BallBound& other) const
+{
+  const double dist = metric::EuclideanDistance::Evaluate(center,
+      other.center);
+  const double sumradius = radius + other.radius;
+  return math::Range(math::ClampNonNegative(dist - sumradius),
+      dist + sumradius);
+}
+
+/**
+ * Expand the bound to include the given bound.  The center is kept fixed and
+ * the radius grows, if necessary, to reach the point of the other ball that
+ * is farthest from our center, which lies at a distance of
+ * ||center - other.center|| + other.radius.
+ */
+template<typename VecType>
+const BallBound<VecType>&
+BallBound<VecType>::operator|=(
+    const BallBound<VecType>& other)
+{
+  // Distance from our center to the farthest point of the other ball.
+  const double farthest = metric::EuclideanDistance::Evaluate(center,
+      other.center) + other.radius;
+
+  // Now expand the radius as necessary.
+  if (farthest > radius)
+    radius = farthest;
+
+  return *this;
+}
+
+/**
+ * Expand the bound to include the given point.  The center is kept fixed and
+ * the radius grows, if necessary, to reach the point.
+ */
+template<typename VecType>
+const BallBound<VecType>&
+BallBound<VecType>::operator|=(const VecType& point)
+{
+  const double dist = metric::EuclideanDistance::Evaluate(center, point);
+
+  // Now expand the radius as necessary.
+  if (dist > radius)
+    radius = dist;
+
+  return *this;
+}
+
+}; // namespace bound
+}; // namespace mlpack
+
+#endif // __MLPACK_CORE_TREE_BALLBOUND_IMPL_HPP
Modified: mlpack/trunk/src/mlpack/core/tree/bounds.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/bounds.hpp 2011-12-15 08:02:06 UTC (rev 10829)
+++ mlpack/trunk/src/mlpack/core/tree/bounds.hpp 2011-12-15 08:43:09 UTC (rev 10830)
@@ -12,6 +12,6 @@
#include "hrectbound.hpp"
#include "periodichrectbound.hpp"
-#include "dballbound.hpp"
+#include "ballbound.hpp"
#endif // __MLPACK_CORE_TREE_BOUNDS_HPP
Deleted: mlpack/trunk/src/mlpack/core/tree/dballbound.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/dballbound.hpp 2011-12-15 08:02:06 UTC (rev 10829)
+++ mlpack/trunk/src/mlpack/core/tree/dballbound.hpp 2011-12-15 08:43:09 UTC (rev 10830)
@@ -1,139 +0,0 @@
-/**
- * @file dballbound.hpp
- *
- * Bounds that are useful for binary space partitioning trees.
- * Interface to a ball bound that works in arbitrary metric spaces.
- *
- * @experimental
- */
-
-#ifndef __MLPACK_CORE_TREE_DBALLBOUND_HPP
-#define __MLPACK_CORE_TREE_DBALLBOUND_HPP
-
-#include <mlpack/core.hpp>
-#include <mlpack/core/metrics/lmetric.hpp>
-
-namespace mlpack {
-namespace bound {
-
-/**
- * Ball bound that works in arbitrary metric spaces.
- *
- * See LMetric for an example metric template parameter.  NOTE(review): with
- * the default (squared) metric, radius_ is combined with squared distances.
- * To initialize this, set the radius with @c set_radius
- * and set the center by modifying @c center() directly.
- */
-template<typename TMetric = mlpack::metric::SquaredEuclideanDistance,
- typename TPoint = arma::vec>
-class DBallBound
-{
- public:
- typedef TPoint Point;
- typedef TMetric Metric;
-
- private:
- double radius_;
- TPoint center_;
-
- public:
- /**
- * Return the radius of the ball bound.
- */
- double radius() const { return radius_; }
-
- /**
- * Set the radius of the bound.
- */
- void set_radius(double d) { radius_ = d; }
-
- /**
- * Return the center point.
- */
- const TPoint& center() const { return center_; }
-
- /**
- * Return a modifiable reference to the center point.
- */
- TPoint& center() { return center_; }
-
- /**
- * Determines if a point is within this bound.
- */
- bool Contains(const Point& point) const;
-
- /**
- * Copies the center into *centroid.
- *
- * Don't really use this directly. This is only here for consistency
- * with DHrectBound, so it can plug in more directly if a "centroid"
- * is needed.
- */
- void CalculateMidpoint(Point *centroid) const;
-
- /**
- * Calculates minimum bound-to-point distance (MinDistanceSq squares it).
- */
- double MinDistance(const Point& point) const;
- double MinDistanceSq(const Point& point) const;
-
- /**
- * Calculates minimum bound-to-bound distance (MinDistanceSq squares it).
- */
- double MinDistance(const DBallBound& other) const;
- double MinDistanceSq(const DBallBound& other) const;
-
- /**
- * Computes maximum bound-to-point distance (MaxDistanceSq squares it).
- */
- double MaxDistance(const Point& point) const;
- double MaxDistanceSq(const Point& point) const;
-
- /**
- * Computes maximum bound-to-bound distance (MaxDistanceSq squares it).
- */
- double MaxDistance(const DBallBound& other) const;
- double MaxDistanceSq(const DBallBound& other) const;
-
- /**
- * Calculates minimum and maximum bound-to-bound distance as a Range.
- *
- * RangeDistanceSq returns the same range with both endpoints squared.
- */
- math::Range RangeDistance(const DBallBound& other) const;
- math::Range RangeDistanceSq(const DBallBound& other) const;
-
- /**
- * Calculates the minimum distance from this ball to the other bound's
- * center, i.e. takes the other bound's center and finds the minimum
- * ball-to-point distance.
- *
- * Equivalent to:
- * <code>
- * return MinDistance(other.center());
- * </code>
- * MinToMidSq returns the square of this value.
- */
- double MinToMid(const DBallBound& other) const;
- double MinToMidSq(const DBallBound& other) const;
-
- /**
- * Computes minimax distance (other node avoiding me); the Sq form squares it.
- */
- double MinimaxDistance(const DBallBound& other) const;
- double MinimaxDistanceSq(const DBallBound& other) const;
-
- /**
- * Calculates center-to-center (or center-to-point) distance under TMetric.
- */
- double MidDistance(const DBallBound& other) const;
- double MidDistanceSq(const DBallBound& other) const;
- double MidDistance(const Point& point) const;
-};
-
-}; // namespace bound
-}; // namespace mlpack
-
-#include "dballbound_impl.hpp"
-
-#endif // __MLPACK_CORE_TREE_DBALLBOUND_HPP
Deleted: mlpack/trunk/src/mlpack/core/tree/dballbound_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/dballbound_impl.hpp 2011-12-15 08:02:06 UTC (rev 10829)
+++ mlpack/trunk/src/mlpack/core/tree/dballbound_impl.hpp 2011-12-15 08:43:09 UTC (rev 10830)
@@ -1,193 +0,0 @@
-/**
- * @file dballbound_impl.hpp
- *
- * Bounds that are useful for binary space partitioning trees.
- * Implementation of DBallBound ball bound metric policy class.
- *
- * @experimental
- */
-#ifndef __MLPACK_CORE_TREE_DBALLBOUND_IMPL_HPP
-#define __MLPACK_CORE_TREE_DBALLBOUND_IMPL_HPP
-
-#include <mlpack/core.hpp>
-
-namespace mlpack {
-namespace bound {
-
-/**
- * Determines if a point is within the bound.
- */
-template<typename TMetric, typename TPoint>
-bool DBallBound<TMetric, TPoint>::Contains(const Point& point) const
-{
- return MidDistance(point) <= radius_;
-}
-
-/**
- * Copies the center into *centroid.
- *
- * Don't really use this directly. This is only here for consistency
- * with DHrectBound, so it can plug in more directly if a "centroid"
- * is needed.
- */
-template<typename TMetric, typename TPoint>
-void DBallBound<TMetric, TPoint>::CalculateMidpoint(Point *centroid) const
-{
- (*centroid) = center_;
-}
-
-/**
- * Calculates minimum bound-to-point distance, clamped to be non-negative.
- */
-template<typename TMetric, typename TPoint>
-double DBallBound<TMetric, TPoint>::MinDistance(const Point& point) const
-{
- return math::ClampNonNegative(MidDistance(point) - radius_);
-}
-
-template<typename TMetric, typename TPoint>
-double DBallBound<TMetric, TPoint>::MinDistanceSq(const Point& point) const
-{
- return std::pow(MinDistance(point), 2);
-}
-
-/**
- * Calculates minimum bound-to-bound distance, clamped to be non-negative.
- */
-template<typename TMetric, typename TPoint>
-double DBallBound<TMetric, TPoint>::MinDistance(const DBallBound& other) const
-{
- double delta = MidDistance(other.center_) - radius_ - other.radius_;
- return math::ClampNonNegative(delta);
-}
-
-template<typename TMetric, typename TPoint>
-double DBallBound<TMetric, TPoint>::MinDistanceSq(const DBallBound& other) const
-{
- return std::pow(MinDistance(other), 2);
-}
-
-/**
- * Computes maximum bound-to-point distance.
- */
-template<typename TMetric, typename TPoint>
-double DBallBound<TMetric, TPoint>::MaxDistance(const Point& point) const
-{
- return MidDistance(point) + radius_;
-}
-
-template<typename TMetric, typename TPoint>
-double DBallBound<TMetric, TPoint>::MaxDistanceSq(const Point& point) const
-{
- return std::pow(MaxDistance(point), 2);
-}
-
-/**
- * Computes maximum bound-to-bound distance.
- */
-template<typename TMetric, typename TPoint>
-double DBallBound<TMetric, TPoint>::MaxDistance(const DBallBound& other) const
-{
- return MidDistance(other.center_) + radius_ + other.radius_;
-}
-
-template<typename TMetric, typename TPoint>
-double DBallBound<TMetric, TPoint>::MaxDistanceSq(const DBallBound& other) const
-{
- return std::pow(MaxDistance(other), 2);
-}
-
-/**
- * Calculates minimum and maximum bound-to-bound distance as a Range.
- *
- * (RangeDistanceSq, below, returns both endpoints of this range squared.)
- */
-template<typename TMetric, typename TPoint>
-math::Range DBallBound<TMetric, TPoint>::RangeDistance(
- const DBallBound& other) const
-{
- double delta = MidDistance(other.center_);
- double sumradius = radius_ + other.radius_;
- return math::Range(
- math::ClampNonNegative(delta - sumradius),
- delta + sumradius);
-}
-
-template<typename TMetric, typename TPoint>
-math::Range DBallBound<TMetric, TPoint>::RangeDistanceSq(
- const DBallBound& other) const
-{
- double delta = MidDistance(other.center_);
- double sumradius = radius_ + other.radius_;
- return math::Range(
- std::pow(math::ClampNonNegative(delta - sumradius), 2),
- std::pow(delta + sumradius, 2));
-}
-
-/**
- * Calculates the minimum distance from this ball to the other bound's
- * center, i.e. takes the other bound's center and finds the minimum
- * ball-to-point distance.
- *
- * Equivalent to:
- * <code>
- * return MinDistance(other.center());
- * </code>
- * MinToMidSq returns the square of this value.
- */
-template<typename TMetric, typename TPoint>
-double DBallBound<TMetric, TPoint>::MinToMid(const DBallBound& other) const
-{
- double delta = MidDistance(other.center_) - radius_;
- return math::ClampNonNegative(delta);
-}
-
-template<typename TMetric, typename TPoint>
-double DBallBound<TMetric, TPoint>::MinToMidSq(const DBallBound& other) const
-{
- return std::pow(MinToMid(other), 2);
-}
-
-/**
- * Computes minimax distance, where the other node is trying to avoid me.
- */
-template<typename TMetric, typename TPoint>
-double DBallBound<TMetric, TPoint>::MinimaxDistance(
- const DBallBound& other) const
-{
- double delta = MidDistance(other.center_) + other.radius_ - radius_;
- return math::ClampNonNegative(delta);
-}
-
-template<typename TMetric, typename TPoint>
-double DBallBound<TMetric, TPoint>::MinimaxDistanceSq(
- const DBallBound& other) const
-{
- return std::pow(MinimaxDistance(other), 2);
-}
-
-/**
- * Calculates center-to-center distance under TMetric.
- */
-template<typename TMetric, typename TPoint>
-double DBallBound<TMetric, TPoint>::MidDistance(const DBallBound& other) const
-{
- return MidDistance(other.center_);
-}
-
-template<typename TMetric, typename TPoint>
-double DBallBound<TMetric, TPoint>::MidDistanceSq(const DBallBound& other) const
-{
- return std::pow(MidDistance(other), 2);
-}
-
-template<typename TMetric, typename TPoint>
-double DBallBound<TMetric, TPoint>::MidDistance(const Point& point) const
-{
- return Metric::Evaluate(center_, point); // NOTE(review): squared under the default SquaredEuclideanDistance, yet the methods above combine it directly with radius_; verify units.
-}
-
-}; // namespace bound
-}; // namespace mlpack
-
-#endif // __MLPACK_CORE_TREE_DBALLBOUND_IMPL_HPP
More information about the mlpack-svn
mailing list