[mlpack-git] master: Add some simple octree tests. (60bbe7e)
gitdub at mlpack.org
gitdub at mlpack.org
Fri Sep 23 10:19:19 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/9ef7339d40550a974b3939e9fcb966fac2c09065...ebdb5abeaa3fd621a06ae663862bb72df76d2b40
>---------------------------------------------------------------
commit 60bbe7e1344ed474eeb7803bab38d793f17bac8b
Author: Ryan Curtin <ryan at ratml.org>
Date: Fri Sep 23 10:19:19 2016 -0400
Add some simple octree tests.
>---------------------------------------------------------------
60bbe7e1344ed474eeb7803bab38d793f17bac8b
src/mlpack/core/tree/CMakeLists.txt | 4 +
src/mlpack/core/tree/octree/octree.hpp | 174 ++++++++-
src/mlpack/core/tree/octree/octree_impl.hpp | 542 ++++++++++++++++++++++++++--
src/mlpack/tests/CMakeLists.txt | 1 +
src/mlpack/tests/octree_test.cpp | 148 ++++++++
5 files changed, 835 insertions(+), 34 deletions(-)
diff --git a/src/mlpack/core/tree/CMakeLists.txt b/src/mlpack/core/tree/CMakeLists.txt
index 2a849ae..7cf2166 100644
--- a/src/mlpack/core/tree/CMakeLists.txt
+++ b/src/mlpack/core/tree/CMakeLists.txt
@@ -50,6 +50,10 @@ set(SOURCES
hollow_ball_bound_impl.hpp
hrectbound.hpp
hrectbound_impl.hpp
+ octree.hpp
+ octree/octree.hpp
+ octree/octree_impl.hpp
+ octree/traits.hpp
rectangle_tree.hpp
rectangle_tree/rectangle_tree.hpp
rectangle_tree/rectangle_tree_impl.hpp
diff --git a/src/mlpack/core/tree/octree/octree.hpp b/src/mlpack/core/tree/octree/octree.hpp
index 544fa75..3fec915 100644
--- a/src/mlpack/core/tree/octree/octree.hpp
+++ b/src/mlpack/core/tree/octree/octree.hpp
@@ -9,11 +9,12 @@
#include <mlpack/core.hpp>
#include "../hrectbound.hpp"
+#include "../statistic.hpp"
namespace mlpack {
namespace tree {
-template<typename MetricType,
+template<typename MetricType = metric::EuclideanDistance,
typename StatisticType = EmptyStatistic,
typename MatType = arma::mat>
class Octree
@@ -36,11 +37,19 @@ class Octree
size_t count;
//! The minimum bounding rectangle of the points held in the node (and its
//! children).
- HRectBound<MeetricType> bound;
+ bound::HRectBound<MetricType> bound;
//! The dataset.
MatType* dataset;
//! The parent (NULL if this node is the root).
Octree* parent;
+ //! The statistic.
+ StatisticType stat;
+ //! The distance from the center of this node to the center of the parent.
+ ElemType parentDistance;
+ //! The distance to the furthest descendant, cached to speed things up.
+ ElemType furthestDescendantDistance;
+ //! An instantiated metric.
+ MetricType metric;
public:
/**
@@ -97,7 +106,7 @@ class Octree
* @param data Dataset to create tree from. This will be moved into the tree, not copied!
* @param maxLeafSize Maximum number of points in a leaf node.
*/
- Octree(const MatType& data, const size_t maxLeafSize = 20);
+ Octree(MatType&& data, const size_t maxLeafSize = 20);
/**
* Construct this as the root node of an octree on the given dataset. This
@@ -183,6 +192,148 @@ class Octree
const double width,
const size_t maxLeafSize = 20);
+ /**
+ * Destroy the tree.
+ */
+ ~Octree();
+
+ //! Return the dataset used by this node.
+ const MatType& Dataset() const { return *dataset; }
+
+ //! Return the bound object for this node.
+ const bound::HRectBound<MetricType>& Bound() const { return bound; }
+ //! Modify the bound object for this node.
+ bound::HRectBound<MetricType>& Bound() { return bound; }
+
+ //! Return the statistic object for this node.
+ const StatisticType& Stat() const { return stat; }
+ //! Modify the statistic object for this node.
+ StatisticType& Stat() { return stat; }
+
+ //! Return the number of children in this node.
+ size_t NumChildren() const;
+
+ //! Return the metric that this tree uses.
+ MetricType Metric() const { return MetricType(); }
+
+ /**
+ * Return the index of the nearest child node to the given query point. If
+ * this is a leaf node, it will return NumChildren() (invalid index).
+ */
+ template<typename VecType>
+ size_t GetNearestChild(
+ const VecType& point,
+ typename boost::enable_if<IsVector<VecType>>::type* = 0) const;
+
+ /**
+ * Return the index of the furthest child node to the given query point. If
+ * this is a leaf node, it will return NumChildren() (invalid index).
+ */
+ template<typename VecType>
+ size_t GetFurthestChild(
+ const VecType& point,
+ typename boost::enable_if<IsVector<VecType> >::type* = 0) const;
+
+ /**
+ * Return the index of the nearest child node to the given query node. If it
+ * can't decide, it will return NumChildren() (invalid index).
+ */
+ size_t GetNearestChild(const Octree& queryNode) const;
+
+ /**
+ * Return the index of the furthest child node to the given query node. If it
+ * can't decide, it will return NumChildren() (invalid index).
+ */
+ size_t GetFurthestChild(const Octree& queryNode) const;
+
+ /**
+ * Return the furthest distance to a point held in this node. If this is not
+ * a leaf node, then the distance is 0 because the node holds no points.
+ */
+ ElemType FurthestPointDistance() const;
+
+ /**
+ * Return the furthest possible descendant distance. This returns the maximum
+ * distance from the centroid to the edge of the bound and not the empirical
+ * quantity which is the actual furthest descendant distance. So the actual
+ * furthest descendant distance may be less than what this method returns (but
+ * it will never be greater than this).
+ */
+ ElemType FurthestDescendantDistance() const;
+
+ //! Return the minimum distance from the center of the node to any bound edge.
+ ElemType MinimumBoundDistance() const;
+
+ //! Return the distance from the center of this node to the center of the
+ //! parent node.
+ ElemType ParentDistance() const { return parentDistance; }
+ //! Modify the distance from the center of this node to the center of the
+ //! parent node.
+ ElemType& ParentDistance() { return parentDistance; }
+
+ /**
+ * Return the specified child. If the index is out of bounds, unspecified
+ * behavior will occur.
+ */
+ const Octree& Child(const size_t child) const { return *children[child]; }
+
+ /**
+ * Return the specified child. If the index is out of bounds, unspecified
+ * behavior will occur.
+ */
+ Octree& Child(const size_t child) { return *children[child]; }
+
+ /**
+ * Return the pointer to the given child. This allows the child itself to be
+ * modified.
+ */
+ Octree*& ChildPtr(const size_t child) { return children[child]; }
+
+ //! Return the number of points in this node (0 if not a leaf).
+ size_t NumPoints() const;
+
+ //! Return the number of descendants of this node.
+ size_t NumDescendants() const;
+
+ /**
+ * Return the index (with reference to the dataset) of a particular
+ * descendant.
+ */
+ size_t Descendant(const size_t index) const;
+
+ /**
+ * Return the index (with reference to the dataset) of a particular point in
+ * this node. If the given index is invalid (i.e. if it is greater than or
+ * equal to NumPoints()), the index returned will be invalid.
+ */
+ size_t Point(const size_t index) const;
+
+ //! Return the minimum distance to another node.
+ ElemType MinDistance(const Octree* other) const;
+ //! Return the maximum distance to another node.
+ ElemType MaxDistance(const Octree* other) const;
+ //! Return the minimum and maximum distance to another node.
+ math::RangeType<ElemType> RangeDistance(const Octree* other) const;
+
+ //! Return the minimum distance to the given point.
+ template<typename VecType>
+ ElemType MinDistance(
+ const VecType& point,
+ typename boost::enable_if<IsVector<VecType>>::type* = 0) const;
+ //! Return the maximum distance to the given point.
+ template<typename VecType>
+ ElemType MaxDistance(
+ const VecType& point,
+ typename boost::enable_if<IsVector<VecType>>::type* = 0) const;
+ //! Return the minimum and maximum distance to the given point.
+ template<typename VecType>
+ math::RangeType<ElemType> RangeDistance(
+ const VecType& point,
+ typename boost::enable_if<IsVector<VecType>>::type* = 0) const;
+
+ //! Store the center of the bounding region in the given vector.
+ void Center(arma::vec& center) const { bound.Center(center); }
+
private:
/**
* Split the node, using the given center and the given maximum width of this
@@ -190,8 +341,11 @@ class Octree
*
* @param center Center of the node.
* @param width Width of the current node.
+ * @param maxLeafSize Maximum number of points allowed in a leaf.
*/
- void SplitNode(const arma::vec& center, const double width);
+ void SplitNode(const arma::vec& center,
+ const double width,
+ const size_t maxLeafSize);
/**
* Split the node, using the given center and the given maximum width of this
@@ -200,8 +354,18 @@ class Octree
* @param center Center of the node.
* @param width Width of the current node.
* @param oldFromNew Mappings from old to new.
+ * @param maxLeafSize Maximum number of points allowed in a leaf.
*/
void SplitNode(const arma::vec& center,
const double width,
- std::vector<size_t>& oldFromNew);
+ std::vector<size_t>& oldFromNew,
+ const size_t maxLeafSize);
};
+
+} // namespace tree
+} // namespace mlpack
+
+// Include implementation.
+#include "octree_impl.hpp"
+
+#endif
diff --git a/src/mlpack/core/tree/octree/octree_impl.hpp b/src/mlpack/core/tree/octree/octree_impl.hpp
index cd8234a..33ec3e2 100644
--- a/src/mlpack/core/tree/octree/octree_impl.hpp
+++ b/src/mlpack/core/tree/octree/octree_impl.hpp
@@ -9,19 +9,82 @@
#include "octree.hpp"
+namespace mlpack {
+namespace tree {
+
//! Construct the root node on a heap-allocated copy of the given dataset.
template<typename MetricType, typename StatisticType, typename MatType>
Octree<MetricType, StatisticType, MatType>::Octree(const MatType& dataset,
- const double maxLeafSize) :
+ const size_t maxLeafSize) :
+ begin(0),
+ count(dataset.n_cols),
+ bound(dataset.n_rows),
dataset(new MatType(dataset)), // The parameter shadows the member; this stores a copy.
+ parent(NULL),
+ parentDistance(0.0)
+{
+ if (count > 0)
+ {
+ // Fit the bound to the stored data; its center is used as the split center.
+ bound |= *this->dataset;
+ arma::vec center;
+ bound.Center(center);
+
+ double maxWidth = 0.0; // Ends up as the largest Hi() - Lo() over all dimensions.
+ for (size_t i = 0; i < bound.Dim(); ++i)
+ if (bound[i].Hi() - bound[i].Lo() > maxWidth)
+ maxWidth = bound[i].Hi() - bound[i].Lo();
+
+ SplitNode(center, maxWidth, maxLeafSize);
+
+ furthestDescendantDistance = 0.5 * bound.Diameter();
+ }
+ else
+ {
+ furthestDescendantDistance = 0.0; // Empty dataset: nothing to bound or split.
+ }
+ // Build the statistic last, from the fully constructed node.
+ stat = StatisticType(*this);
+}
+
+//! Construct the tree.
+template<typename MetricType, typename StatisticType, typename MatType>
+Octree<MetricType, StatisticType, MatType>::Octree(
+ const MatType& dataset,
+ std::vector<size_t>& oldFromNew,
+ const size_t maxLeafSize) :
+ begin(0),
+ count(dataset.n_cols),
+ bound(dataset.n_rows),
+ dataset(new MatType(dataset)),
+ parent(NULL),
+ parentDistance(0.0)
{
- // Calculate empirical center of data.
- bound |= *dataset;
- arma::vec center = bound.Center();
- double maxWidth = bound.MaxWidth();
+ oldFromNew.resize(this->dataset->n_cols);
+ for (size_t i = 0; i < this->dataset->n_cols; ++i)
+ oldFromNew[i] = i;
+
+ if (count > 0)
+ {
+ // Calculate empirical center of data.
+ bound |= *this->dataset;
+ arma::vec center;
+ bound.Center(center);
- SplitNode(center, maxWidth);
+ double maxWidth = 0.0;
+ for (size_t i = 0; i < bound.Dim(); ++i)
+ if (bound[i].Hi() - bound[i].Lo() > maxWidth)
+ maxWidth = bound[i].Hi() - bound[i].Lo();
+
+ SplitNode(center, maxWidth, oldFromNew, maxLeafSize);
+
+ furthestDescendantDistance = 0.5 * bound.Diameter();
+ }
+ else
+ {
+ furthestDescendantDistance = 0.0;
+ }
// Initialize the statistic.
stat = StatisticType(*this);
@@ -32,20 +95,122 @@ template<typename MetricType, typename StatisticType, typename MatType>
Octree<MetricType, StatisticType, MatType>::Octree(
const MatType& dataset,
std::vector<size_t>& oldFromNew,
+ std::vector<size_t>& newFromOld,
const size_t maxLeafSize) :
+ begin(0),
+ count(dataset.n_cols),
+ bound(dataset.n_rows),
dataset(new MatType(dataset)),
+ parent(NULL),
+ parentDistance(0.0)
+{
+ oldFromNew.resize(this->dataset->n_cols);
+ for (size_t i = 0; i < this->dataset->n_cols; ++i)
+ oldFromNew[i] = i;
+
+ if (count > 0)
+ {
+ // Calculate empirical center of data.
+ bound |= *this->dataset;
+ arma::vec center;
+ bound.Center(center);
+
+ double maxWidth = 0.0;
+ for (size_t i = 0; i < bound.Dim(); ++i)
+ if (bound[i].Hi() - bound[i].Lo() > maxWidth)
+ maxWidth = bound[i].Hi() - bound[i].Lo();
+
+ SplitNode(center, maxWidth, oldFromNew, maxLeafSize);
+
+ furthestDescendantDistance = 0.5 * bound.Diameter();
+ }
+ else
+ {
+ furthestDescendantDistance = 0.0;
+ }
+
+ // Initialize the statistic.
+ stat = StatisticType(*this);
+
+ // Map the newFromOld indices correctly.
+ newFromOld.resize(this->dataset->n_cols);
+ for (size_t i = 0; i < this->dataset->n_cols; i++)
+ newFromOld[oldFromNew[i]] = i;
+}
+//! Construct the root node, taking the given dataset by moving from it.
+template<typename MetricType, typename StatisticType, typename MatType>
+Octree<MetricType, StatisticType, MatType>::Octree(MatType&& dataset,
+ const size_t maxLeafSize) :
+ begin(0),
+ count(dataset.n_cols),
+ bound(dataset.n_rows),
+ dataset(new MatType(std::move(dataset))), // Parameter shadows the member; moved, not copied.
+ parent(NULL),
+ parentDistance(0.0)
{
- // Calculate empirical center of data.
- bound |= *dataset;
- arma::vec center = bound.Center();
- double maxWidth = bound.MaxWidth();
+ if (count > 0)
+ {
+ // Fit the bound to the stored data; its center is used as the split center.
+ bound |= *this->dataset;
+ arma::vec center;
+ bound.Center(center);
+
+ double maxWidth = 0.0; // Ends up as the largest Hi() - Lo() over all dimensions.
+ for (size_t i = 0; i < bound.Dim(); ++i)
+ if (bound[i].Hi() - bound[i].Lo() > maxWidth)
+ maxWidth = bound[i].Hi() - bound[i].Lo();
+
+ SplitNode(center, maxWidth, maxLeafSize);
+
+ furthestDescendantDistance = 0.5 * bound.Diameter();
+ }
+ else
+ {
+ furthestDescendantDistance = 0.0; // Empty dataset: nothing to bound or split.
+ }
- oldFromNew.resize(data.n_cols);
- for (size_t i = 0; i < data.n_cols; ++i)
+ // Build the statistic last, from the fully constructed node.
+ stat = StatisticType(*this);
+}
+
+//! Construct the tree.
+template<typename MetricType, typename StatisticType, typename MatType>
+Octree<MetricType, StatisticType, MatType>::Octree(
+ MatType&& dataset,
+ std::vector<size_t>& oldFromNew,
+ const size_t maxLeafSize) :
+ begin(0),
+ count(dataset.n_cols),
+ bound(dataset.n_rows),
+ dataset(new MatType(std::move(dataset))),
+ parent(NULL),
+ parentDistance(0.0)
+{
+ oldFromNew.resize(this->dataset->n_cols);
+ for (size_t i = 0; i < this->dataset->n_cols; ++i)
oldFromNew[i] = i;
- SplitNode(center, maxWidth, oldFromNew);
+ if (count > 0)
+ {
+ // Calculate empirical center of data.
+ bound |= *this->dataset;
+ arma::vec center;
+ bound.Center(center);
+
+ double maxWidth = 0.0;
+ for (size_t i = 0; i < bound.Dim(); ++i)
+ if (bound[i].Hi() - bound[i].Lo() > maxWidth)
+ maxWidth = bound[i].Hi() - bound[i].Lo();
+
+ SplitNode(center, maxWidth, oldFromNew, maxLeafSize);
+
+ furthestDescendantDistance = 0.5 * bound.Diameter();
+ }
+ else
+ {
+ furthestDescendantDistance = 0.0;
+ }
// Initialize the statistic.
stat = StatisticType(*this);
@@ -54,34 +219,341 @@ Octree<MetricType, StatisticType, MatType>::Octree(
//! Construct the root node by moving from the dataset, filling both index mappings.
template<typename MetricType, typename StatisticType, typename MatType>
Octree<MetricType, StatisticType, MatType>::Octree(
- const MatType& dataset,
+ MatType&& dataset,
std::vector<size_t>& oldFromNew,
+ std::vector<size_t>& newFromOld,
const size_t maxLeafSize) :
- dataset(new MatType(dataset)),
+ begin(0),
+ count(dataset.n_cols),
+ bound(dataset.n_rows),
+ dataset(new MatType(std::move(dataset))), // Parameter shadows the member; moved, not copied.
+ parent(NULL),
+ parentDistance(0.0)
+{
+ oldFromNew.resize(this->dataset->n_cols);
+ for (size_t i = 0; i < this->dataset->n_cols; ++i)
+ oldFromNew[i] = i; // Start from the identity mapping; SplitNode() receives it below.
+
+ if (count > 0)
+ {
+ // Fit the bound to the stored data; its center is used as the split center.
+ bound |= *this->dataset;
+ arma::vec center;
+ bound.Center(center);
+ double maxWidth = 0.0; // Ends up as the largest Hi() - Lo() over all dimensions.
+ for (size_t i = 0; i < bound.Dim(); ++i)
+ if (bound[i].Hi() - bound[i].Lo() > maxWidth)
+ maxWidth = bound[i].Hi() - bound[i].Lo();
+
+ SplitNode(center, maxWidth, oldFromNew, maxLeafSize);
+
+ furthestDescendantDistance = 0.5 * bound.Diameter();
+ }
+ else
+ {
+ furthestDescendantDistance = 0.0; // Empty dataset: nothing to bound or split.
+ }
+
+ // Build the statistic from the fully constructed node.
+ stat = StatisticType(*this);
+
+ // Invert oldFromNew so that newFromOld[oldFromNew[i]] == i for every column i.
+ newFromOld.resize(this->dataset->n_cols);
+ for (size_t i = 0; i < this->dataset->n_cols; i++)
+ newFromOld[oldFromNew[i]] = i;
+}
+
+//! Construct a child node over `count` columns of the parent's dataset,
+//! starting at column `begin`. `count` must be nonzero, because the column
+//! range used for the bound ends at `begin + count - 1`. The dataset pointer
+//! is shared with the parent; only a node without a parent deletes it.
+template<typename MetricType, typename StatisticType, typename MatType>
+Octree<MetricType, StatisticType, MatType>::Octree(
+ Octree* parent,
+ const size_t begin,
+ const size_t count,
+ const arma::vec& center,
+ const double width,
+ const size_t maxLeafSize) :
+ begin(begin),
+ count(count),
+ bound(parent->dataset->n_rows),
+ dataset(parent->dataset),
+ parent(parent)
{
// Calculate empirical center of data.
- bound |= *dataset;
- arma::vec center = bound.Center();
- double maxWidth = bound.MaxWidth();
+ bound |= dataset->cols(begin, begin + count - 1);
- oldFromNew.resize(data.n_cols);
- for (size_t i = 0; i < data.n_cols; ++i)
- oldFromNew[i] = i;
+ // Now split the node.
+ SplitNode(center, width, maxLeafSize);
- SplitNode(center, maxWidth, oldFromNew);
+ // Calculate the distance from the empirical center of this node to the
+ // empirical center of the parent.
+ arma::vec trueCenter, parentCenter;
+ bound.Center(trueCenter);
+ parent->Bound().Center(parentCenter);
+ parentDistance = metric.Evaluate(trueCenter, parentCenter);
+
+ // This member was left uninitialized for child nodes even though
+ // FurthestPointDistance() and FurthestDescendantDistance() return it. Use
+ // the same bound-based value as the root constructors.
+ furthestDescendantDistance = 0.5 * bound.Diameter();
+
// Initialize the statistic.
stat = StatisticType(*this);
}
+//! Construct a child node over `count` columns of the parent's dataset,
+//! starting at column `begin`, passing `oldFromNew` on to SplitNode() so the
+//! index mapping follows any reordering. `count` must be nonzero, because the
+//! column range used for the bound ends at `begin + count - 1`. The dataset
+//! pointer is shared with the parent; only a node without a parent deletes it.
+template<typename MetricType, typename StatisticType, typename MatType>
+Octree<MetricType, StatisticType, MatType>::Octree(
+    Octree* parent,
+    const size_t begin,
+    const size_t count,
+    std::vector<size_t>& oldFromNew,
+    const arma::vec& center,
+    const double width,
+    const size_t maxLeafSize) :
+    begin(begin),
+    count(count),
+    bound(parent->dataset->n_rows),
+    dataset(parent->dataset),
+    parent(parent)
+{
+  // Fit the bound to this node's contiguous range of columns.
+  bound |= dataset->cols(begin, begin + count - 1);
+
+  // Now split the node.
+  SplitNode(center, width, oldFromNew, maxLeafSize);
+
+  // Calculate the distance from the empirical center of this node to the
+  // empirical center of the parent.
+  arma::vec trueCenter, parentCenter;
+  bound.Center(trueCenter);
+  parent->Bound().Center(parentCenter);
+  parentDistance = metric.Evaluate(trueCenter, parentCenter);
+
+  // This member was left uninitialized for child nodes even though
+  // FurthestPointDistance() and FurthestDescendantDistance() return it.  Use
+  // the same bound-based value as the root constructors.
+  furthestDescendantDistance = 0.5 * bound.Diameter();
+
+  // Initialize the statistic.
+  stat = StatisticType(*this);
+}
+
+//! Release the dataset (when this node has no parent) and all child nodes.
+template<typename MetricType, typename StatisticType, typename MatType>
+Octree<MetricType, StatisticType, MatType>::~Octree()
+{
+  // Only a node without a parent deletes the dataset; nodes with a parent
+  // merely share the pointer.
+  if (parent == NULL)
+    delete dataset;
+
+  // Destroy every child, then drop the now-dangling pointers.
+  const size_t numChildren = children.size();
+  for (size_t c = 0; c != numChildren; ++c)
+    delete children[c];
+  children.clear();
+}
+
+//! Return how many child nodes this node currently holds.
+template<typename MetricType, typename StatisticType, typename MatType>
+size_t Octree<MetricType, StatisticType, MatType>::NumChildren() const
+{
+  const size_t childCount = children.size();
+  return childCount;
+}
+
+//! Return the index of the child whose MinDistance() to the point is smallest,
+//! or NumChildren() when no child qualifies (e.g. there are no children).
+template<typename MetricType, typename StatisticType, typename MatType>
+template<typename VecType>
+size_t Octree<MetricType, StatisticType, MatType>::GetNearestChild(
+    const VecType& point,
+    typename boost::enable_if<IsVector<VecType>>::type*) const
+{
+  // Plain linear scan over the children; the first child with the smallest
+  // distance wins because the comparison is strict.
+  size_t nearest = NumChildren(); // Invalid index until a child is chosen.
+  ElemType nearestDistance = DBL_MAX;
+  size_t c = 0;
+  while (c < NumChildren())
+  {
+    const double candidate = children[c]->MinDistance(point);
+    if (candidate < nearestDistance)
+    {
+      nearest = c;
+      nearestDistance = candidate;
+    }
+    ++c;
+  }
+
+  return nearest;
+}
+
+//! Return the index of the child whose MaxDistance() to the point is largest,
+//! or NumChildren() when no child qualifies (e.g. there are no children).
+template<typename MetricType, typename StatisticType, typename MatType>
+template<typename VecType>
+size_t Octree<MetricType, StatisticType, MatType>::GetFurthestChild(
+    const VecType& point,
+    typename boost::enable_if<IsVector<VecType>>::type*) const
+{
+  // Plain linear scan over the children; the first child with the largest
+  // distance wins because the comparison is strict.
+  size_t furthest = NumChildren(); // Invalid index until a child is chosen.
+  ElemType furthestDistance = -1.0; // Below any value the scan will accept.
+  size_t c = 0;
+  while (c < NumChildren())
+  {
+    const double candidate = children[c]->MaxDistance(point);
+    if (candidate > furthestDistance)
+    {
+      furthest = c;
+      furthestDistance = candidate;
+    }
+    ++c;
+  }
+
+  return furthest;
+}
+
+//! Return the index of the child nearest to the given query node, judged by
+//! the minimum node-to-node distance, or NumChildren() (an invalid index) if
+//! no child can be chosen.
+template<typename MetricType, typename StatisticType, typename MatType>
+size_t Octree<MetricType, StatisticType, MatType>::GetNearestChild(
+    const Octree& queryNode) const
+{
+  // It's possible that this could be improved by caching which children we have
+  // and which we don't, but for now this is just a brute force search.
+  ElemType bestDistance = DBL_MAX;
+  size_t bestIndex = NumChildren();
+  for (size_t i = 0; i < NumChildren(); ++i)
+  {
+    // "Nearest" must rank by the minimum distance; this previously used
+    // MaxDistance().  The node overloads take a pointer, so pass the address
+    // of the query node rather than the reference.
+    const double dist = children[i]->MinDistance(&queryNode);
+    if (dist < bestDistance)
+    {
+      bestDistance = dist;
+      bestIndex = i;
+    }
+  }
+
+  return bestIndex;
+}
+
+//! Return the index of the child furthest from the given query node, judged
+//! by the maximum node-to-node distance, or NumChildren() (an invalid index)
+//! if no child can be chosen.
+template<typename MetricType, typename StatisticType, typename MatType>
+size_t Octree<MetricType, StatisticType, MatType>::GetFurthestChild(
+    const Octree& queryNode) const
+{
+  // It's possible that this could be improved by caching which children we have
+  // and which we don't, but for now this is just a brute force search.
+  ElemType bestDistance = -1.0; // Initialize to invalid distance.
+  size_t bestIndex = NumChildren();
+  for (size_t i = 0; i < NumChildren(); ++i)
+  {
+    // The node overload of MaxDistance() takes a pointer, so pass the address
+    // of the query node rather than the reference.
+    const double dist = children[i]->MaxDistance(&queryNode);
+    if (dist > bestDistance)
+    {
+      bestDistance = dist;
+      bestIndex = i;
+    }
+  }
+
+  return bestIndex;
+}
+
+//! Return the cached furthest descendant distance for a node without
+//! children, and 0 for a node that has children.
+template<typename MetricType, typename StatisticType, typename MatType>
+typename Octree<MetricType, StatisticType, MatType>::ElemType
+Octree<MetricType, StatisticType, MatType>::FurthestPointDistance()
+    const
+{
+  // A node with children reports 0 here.
+  if (children.size() > 0)
+    return 0.0;
+
+  return furthestDescendantDistance;
+}
+
+//! Return the furthest descendant distance cached at construction time.
+template<typename MetricType, typename StatisticType, typename MatType>
+typename Octree<MetricType, StatisticType, MatType>::ElemType
+Octree<MetricType, StatisticType, MatType>::FurthestDescendantDistance() const
+{
+  const ElemType& cached = furthestDescendantDistance;
+  return cached;
+}
+
+//! Return half of the bound's minimum width.
+template<typename MetricType, typename StatisticType, typename MatType>
+typename Octree<MetricType, StatisticType, MatType>::ElemType
+Octree<MetricType, StatisticType, MatType>::MinimumBoundDistance() const
+{
+  return 0.5 * bound.MinWidth();
+}
+
+//! Return the number of points held directly by this node: `count` when the
+//! node has no children, 0 otherwise.
+template<typename MetricType, typename StatisticType, typename MatType>
+size_t Octree<MetricType, StatisticType, MatType>::NumPoints() const
+{
+  // Points are only reported by nodes without children.
+  if (children.size() > 0)
+    return 0;
+
+  return count;
+}
+
+//! Return the number of points covered by this node (its `count`).
+template<typename MetricType, typename StatisticType, typename MatType>
+size_t Octree<MetricType, StatisticType, MatType>::NumDescendants() const
+{
+  const size_t descendants = count;
+  return descendants;
+}
+
+//! Map a descendant number within this node to a dataset index by offsetting
+//! it with the node's first index.  No range check is performed.
+template<typename MetricType, typename StatisticType, typename MatType>
+size_t Octree<MetricType, StatisticType, MatType>::Descendant(
+    const size_t index) const
+{
+  const size_t datasetIndex = begin + index;
+  return datasetIndex;
+}
+
+//! Map a point number within this node to a dataset index by offsetting it
+//! with the node's first index.  No range check is performed.
+template<typename MetricType, typename StatisticType, typename MatType>
+size_t Octree<MetricType, StatisticType, MatType>::Point(const size_t index)
+    const
+{
+  const size_t datasetIndex = begin + index;
+  return datasetIndex;
+}
+
+//! Return the minimum distance between this node's bound and the other node's
+//! bound.  `other` is dereferenced without a null check.
+template<typename MetricType, typename StatisticType, typename MatType>
+typename Octree<MetricType, StatisticType, MatType>::ElemType
+Octree<MetricType, StatisticType, MatType>::MinDistance(const Octree* other)
+    const
+{
+  const auto& otherBound = other->Bound();
+  return bound.MinDistance(otherBound);
+}
+
+//! Return the maximum distance between this node's bound and the other node's
+//! bound.  `other` is dereferenced without a null check.
+template<typename MetricType, typename StatisticType, typename MatType>
+typename Octree<MetricType, StatisticType, MatType>::ElemType
+Octree<MetricType, StatisticType, MatType>::MaxDistance(const Octree* other)
+    const
+{
+  const auto& otherBound = other->Bound();
+  return bound.MaxDistance(otherBound);
+}
+
+//! Return the range of distances between this node's bound and the other
+//! node's bound.  `other` is dereferenced without a null check.
+template<typename MetricType, typename StatisticType, typename MatType>
+math::RangeType<typename Octree<MetricType, StatisticType, MatType>::ElemType>
+Octree<MetricType, StatisticType, MatType>::RangeDistance(const Octree* other)
+    const
+{
+  const auto& otherBound = other->Bound();
+  return bound.RangeDistance(otherBound);
+}
+
+//! Return the minimum distance from this node's bound to the given point.
+template<typename MetricType, typename StatisticType, typename MatType>
+template<typename VecType>
+typename Octree<MetricType, StatisticType, MatType>::ElemType
+Octree<MetricType, StatisticType, MatType>::MinDistance(
+    const VecType& point,
+    typename boost::enable_if<IsVector<VecType>>::type*) const
+{
+  // Delegate entirely to the bound.
+  const ElemType nearest = bound.MinDistance(point);
+  return nearest;
+}
+
+//! Return the maximum distance from this node's bound to the given point.
+template<typename MetricType, typename StatisticType, typename MatType>
+template<typename VecType>
+typename Octree<MetricType, StatisticType, MatType>::ElemType
+Octree<MetricType, StatisticType, MatType>::MaxDistance(
+    const VecType& point,
+    typename boost::enable_if<IsVector<VecType>>::type*) const
+{
+  // Delegate entirely to the bound.
+  const ElemType furthest = bound.MaxDistance(point);
+  return furthest;
+}
+
+
+//! Return the range of distances from this node's bound to the given point.
+template<typename MetricType, typename StatisticType, typename MatType>
+template<typename VecType>
+math::RangeType<typename Octree<MetricType, StatisticType, MatType>::ElemType>
+Octree<MetricType, StatisticType, MatType>::RangeDistance(
+    const VecType& point,
+    typename boost::enable_if<IsVector<VecType>>::type*) const
+{
+  // Delegate entirely to the bound.
+  const math::RangeType<ElemType> distances = bound.RangeDistance(point);
+  return distances;
+}
//! Split the node.
template<typename MetricType, typename StatisticType, typename MatType>
void Octree<MetricType, StatisticType, MatType>::SplitNode(
const arma::vec& center,
- const double width)
+ const double width,
+ const size_t maxLeafSize)
{
+ // No need to split if we have fewer than the maximum number of points in this
+ // node.
+ if (count <= maxLeafSize)
+ return;
+
// We must split the dataset by sequentially creating each of the children.
// We do this in two steps: first we make a pass to count the number of points
// that will fall into each child; then in the second pass we rearrange the
@@ -101,7 +573,7 @@ void Octree<MetricType, StatisticType, MatType>::SplitNode(
// the points fall on. The last dimension represents the most significant
// bit in the assignment; the bit is '1' if it falls to the right of the
// center.
- if (dataset(d, begin + i) > center(d))
+ if ((*dataset)(d, begin + i) > center(d))
assignments(i) |= (1 << d);
}
@@ -130,7 +602,7 @@ void Octree<MetricType, StatisticType, MatType>::SplitNode(
for (size_t d = 0; d < center.n_elem; ++d)
{
// Is the dimension "right" (1) or "left" (0)?
- if ((i >> d) & 1 == 0)
+ if (((i >> d) & 1) == 0)
childCenter[d] = center[d] - childWidth;
else
childCenter[d] = center[d] + childWidth;
@@ -148,8 +620,14 @@ template<typename MetricType, typename StatisticType, typename MatType>
void Octree<MetricType, StatisticType, MatType>::SplitNode(
const arma::vec& center,
const double width,
- std::vector<size_t>& oldFromNew)
+ std::vector<size_t>& oldFromNew,
+ const size_t maxLeafSize)
{
+ // No need to split if we have fewer than the maximum number of points in this
+ // node.
+ if (count <= maxLeafSize)
+ return;
+
// We must split the dataset by sequentially creating each of the children.
// We do this in two steps: first we make a pass to count the number of points
// that will fall into each child; then in the second pass we rearrange the
@@ -169,7 +647,7 @@ void Octree<MetricType, StatisticType, MatType>::SplitNode(
// the points fall on. The last dimension represents the most significant
// bit in the assignment; the bit is '1' if it falls to the right of the
// center.
- if (dataset(d, begin + i) > center(d))
+ if ((*dataset)(d, begin + i) > center(d))
assignments(i) |= (1 << d);
}
@@ -183,8 +661,9 @@ void Octree<MetricType, StatisticType, MatType>::SplitNode(
// really a problem. We use non-contiguous submatrix views to extract the
// columns in the correct order.
dataset->cols(begin, begin + count - 1) = dataset->cols(begin + ordering);
+ std::vector<size_t> oldFromNewCopy(oldFromNew); // We need the old indices.
for (size_t i = 0; i < count; ++i)
- oldFromNew[ordering[i] + begin] = i + begin;
+ oldFromNew[i + begin] = oldFromNewCopy[ordering[i] + begin];
// Now that the dataset is reordered, we can create the children.
size_t childBegin = begin;
@@ -200,7 +679,7 @@ void Octree<MetricType, StatisticType, MatType>::SplitNode(
for (size_t d = 0; d < center.n_elem; ++d)
{
// Is the dimension "right" (1) or "left" (0)?
- if ((i >> d) & 1 == 0)
+ if (((i >> d) & 1) == 0)
childCenter[d] = center[d] - childWidth;
else
childCenter[d] = center[d] + childWidth;
@@ -212,3 +691,8 @@ void Octree<MetricType, StatisticType, MatType>::SplitNode(
childBegin += childCounts[i];
}
}
+
+} // namespace tree
+} // namespace mlpack
+
+#endif
diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt
index 9ad4092..099aa29 100644
--- a/src/mlpack/tests/CMakeLists.txt
+++ b/src/mlpack/tests/CMakeLists.txt
@@ -57,6 +57,7 @@ add_executable(mlpack_test
network_util_test.cpp
nmf_test.cpp
nystroem_method_test.cpp
+ octree_test.cpp
pca_test.cpp
perceptron_test.cpp
quic_svd_test.cpp
diff --git a/src/mlpack/tests/octree_test.cpp b/src/mlpack/tests/octree_test.cpp
new file mode 100644
index 0000000..b2647b3
--- /dev/null
+++ b/src/mlpack/tests/octree_test.cpp
@@ -0,0 +1,148 @@
+/**
+ * @file octree_test.cpp
+ * @author Ryan Curtin
+ *
+ * Test various properties of the Octree.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/core/tree/octree.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "test_tools.hpp"
+
+using namespace mlpack;
+using namespace mlpack::math;
+using namespace mlpack::tree;
+using namespace mlpack::metric;
+using namespace mlpack::bound;
+
+BOOST_AUTO_TEST_SUITE(OctreeTest);
+
+/**
+ * Build a quadtree (a 2-d octree) with a maximum leaf size of 1 on the four
+ * corners of the unit square, and check that the root gets four children that
+ * each hold exactly one point.
+ */
+BOOST_AUTO_TEST_CASE(SimpleQuadtreeTest)
+{
+  // One column per point: (0, 0), (0, 1), (1, 0), (1, 1).
+  arma::mat corners("0 0 1 1; 0 1 0 1");
+
+  Octree<> root(corners, 1);
+
+  // The root covers all four points but holds none of them itself.
+  BOOST_REQUIRE_EQUAL(root.NumChildren(), 4);
+  BOOST_REQUIRE_EQUAL(root.Dataset().n_cols, 4);
+  BOOST_REQUIRE_EQUAL(root.Dataset().n_rows, 2);
+  BOOST_REQUIRE_EQUAL(root.NumDescendants(), 4);
+  BOOST_REQUIRE_EQUAL(root.NumPoints(), 0);
+
+  // Every child must hold exactly one point.
+  size_t c = 0;
+  while (c < 4)
+  {
+    BOOST_REQUIRE_EQUAL(root.Child(c).NumDescendants(), 1);
+    BOOST_REQUIRE_EQUAL(root.Child(c).NumPoints(), 1);
+    ++c;
+  }
+}
+
+/**
+ * Build a tree with a maximum leaf size of 1 on only three corners of the unit
+ * square, and check that exactly three single-point children are created.
+ */
+BOOST_AUTO_TEST_CASE(OctreeMissingChildTest)
+{
+  // One column per point: (0, 0), (0, 1), (1, 1); the corner (1, 0) is absent.
+  arma::mat corners("0 0 1; 0 1 1");
+
+  Octree<> root(corners, 1);
+
+  // The root covers all three points but holds none of them itself.
+  BOOST_REQUIRE_EQUAL(root.NumChildren(), 3);
+  BOOST_REQUIRE_EQUAL(root.Dataset().n_cols, 3);
+  BOOST_REQUIRE_EQUAL(root.Dataset().n_rows, 2);
+  BOOST_REQUIRE_EQUAL(root.NumDescendants(), 3);
+  BOOST_REQUIRE_EQUAL(root.NumPoints(), 0);
+
+  // Every child must hold exactly one point.
+  size_t c = 0;
+  while (c < 3)
+  {
+    BOOST_REQUIRE_EQUAL(root.Child(c).NumDescendants(), 1);
+    BOOST_REQUIRE_EQUAL(root.Child(c).NumPoints(), 1);
+    ++c;
+  }
+}
+
+/**
+ * Building a tree on an empty matrix must succeed and produce a node with no
+ * children, no points, and an empty dataset.
+ */
+BOOST_AUTO_TEST_CASE(EmptyOctreeTest)
+{
+  arma::mat noData;
+  Octree<> emptyTree(noData);
+
+  BOOST_REQUIRE_EQUAL(emptyTree.NumChildren(), 0);
+  BOOST_REQUIRE_EQUAL(emptyTree.Dataset().n_cols, 0);
+  BOOST_REQUIRE_EQUAL(emptyTree.Dataset().n_rows, 0);
+  BOOST_REQUIRE_EQUAL(emptyTree.NumDescendants(), 0);
+  BOOST_REQUIRE_EQUAL(emptyTree.NumPoints(), 0);
+}
+
+/**
+ * With 15 points and a maximum leaf size of 20, neither the copying nor the
+ * moving constructor may split the root.
+ */
+BOOST_AUTO_TEST_CASE(MaxLeafSizeTest)
+{
+  arma::mat points(5, 15, arma::fill::randu);
+
+  // Build from a copy first; the move below leaves `points` unusable.
+  Octree<> copied(points, 20);
+  Octree<> moved(std::move(points), 20);
+
+  BOOST_REQUIRE_EQUAL(copied.NumChildren(), 0);
+  BOOST_REQUIRE_EQUAL(copied.NumDescendants(), 15);
+  BOOST_REQUIRE_EQUAL(copied.NumPoints(), 15);
+
+  BOOST_REQUIRE_EQUAL(moved.NumChildren(), 0);
+  BOOST_REQUIRE_EQUAL(moved.NumDescendants(), 15);
+  BOOST_REQUIRE_EQUAL(moved.NumPoints(), 15);
+}
+
+/**
+ * Check that the old-from-new mappings produced by the copying and the moving
+ * constructors are correct: column i of the tree's dataset must equal column
+ * oldFromNew[i] of the original data.
+ */
+BOOST_AUTO_TEST_CASE(MappingsTest)
+{
+  // Test with both constructors.
+  arma::mat dataset(3, 5, arma::fill::randu);
+  arma::mat datacopy(dataset);
+  std::vector<size_t> oldFromNewCopy, oldFromNewMove;
+
+  Octree<> t1(dataset, oldFromNewCopy, 1);
+  Octree<> t2(std::move(dataset), oldFromNewMove, 1);
+
+  // Each mapping must cover every point.  Without these checks an empty
+  // mapping would pass vacuously, and a short oldFromNewMove would be indexed
+  // out of range by the loop below.
+  BOOST_REQUIRE_EQUAL(oldFromNewCopy.size(), datacopy.n_cols);
+  BOOST_REQUIRE_EQUAL(oldFromNewMove.size(), datacopy.n_cols);
+
+  for (size_t i = 0; i < oldFromNewCopy.size(); ++i)
+  {
+    BOOST_REQUIRE_SMALL(arma::norm(datacopy.col(oldFromNewCopy[i]) -
+        t1.Dataset().col(i)), 1e-3);
+    BOOST_REQUIRE_SMALL(arma::norm(datacopy.col(oldFromNewMove[i]) -
+        t2.Dataset().col(i)), 1e-3);
+  }
+}
+
+/**
+ * Check that the reverse (new-from-old) mappings are correct too: they must be
+ * the inverse of the old-from-new mappings, for both the copying and the
+ * moving constructors.
+ */
+BOOST_AUTO_TEST_CASE(ReverseMappingsTest)
+{
+  // Test with both constructors.
+  arma::mat dataset(3, 300, arma::fill::randu);
+  arma::mat datacopy(dataset);
+  std::vector<size_t> oldFromNewCopy, oldFromNewMove, newFromOldCopy,
+      newFromOldMove;
+
+  Octree<> t1(dataset, oldFromNewCopy, newFromOldCopy);
+  Octree<> t2(std::move(dataset), oldFromNewMove, newFromOldMove);
+
+  // Every mapping must cover every point.  Without these checks empty
+  // mappings would pass vacuously, and a short vector would be indexed out of
+  // range by the loop below.
+  BOOST_REQUIRE_EQUAL(oldFromNewCopy.size(), datacopy.n_cols);
+  BOOST_REQUIRE_EQUAL(oldFromNewMove.size(), datacopy.n_cols);
+  BOOST_REQUIRE_EQUAL(newFromOldCopy.size(), datacopy.n_cols);
+  BOOST_REQUIRE_EQUAL(newFromOldMove.size(), datacopy.n_cols);
+
+  for (size_t i = 0; i < oldFromNewCopy.size(); ++i)
+  {
+    BOOST_REQUIRE_SMALL(arma::norm(datacopy.col(oldFromNewCopy[i]) -
+        t1.Dataset().col(i)), 1e-3);
+    BOOST_REQUIRE_SMALL(arma::norm(datacopy.col(oldFromNewMove[i]) -
+        t2.Dataset().col(i)), 1e-3);
+
+    BOOST_REQUIRE_EQUAL(newFromOldCopy[oldFromNewCopy[i]], i);
+    BOOST_REQUIRE_EQUAL(newFromOldMove[oldFromNewMove[i]], i);
+  }
+}
+
+BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git
mailing list