[mlpack-git] master: Add some simple octree tests. (60bbe7e)

gitdub at mlpack.org gitdub at mlpack.org
Fri Sep 23 10:19:19 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/9ef7339d40550a974b3939e9fcb966fac2c09065...ebdb5abeaa3fd621a06ae663862bb72df76d2b40

>---------------------------------------------------------------

commit 60bbe7e1344ed474eeb7803bab38d793f17bac8b
Author: Ryan Curtin <ryan at ratml.org>
Date:   Fri Sep 23 10:19:19 2016 -0400

    Add some simple octree tests.


>---------------------------------------------------------------

60bbe7e1344ed474eeb7803bab38d793f17bac8b
 src/mlpack/core/tree/CMakeLists.txt         |   4 +
 src/mlpack/core/tree/octree/octree.hpp      | 174 ++++++++-
 src/mlpack/core/tree/octree/octree_impl.hpp | 542 ++++++++++++++++++++++++++--
 src/mlpack/tests/CMakeLists.txt             |   1 +
 src/mlpack/tests/octree_test.cpp            | 148 ++++++++
 5 files changed, 835 insertions(+), 34 deletions(-)

diff --git a/src/mlpack/core/tree/CMakeLists.txt b/src/mlpack/core/tree/CMakeLists.txt
index 2a849ae..7cf2166 100644
--- a/src/mlpack/core/tree/CMakeLists.txt
+++ b/src/mlpack/core/tree/CMakeLists.txt
@@ -50,6 +50,10 @@ set(SOURCES
   hollow_ball_bound_impl.hpp
   hrectbound.hpp
   hrectbound_impl.hpp
+  octree.hpp
+  octree/octree.hpp
+  octree/octree_impl.hpp
+  octree/traits.hpp
   rectangle_tree.hpp
   rectangle_tree/rectangle_tree.hpp
   rectangle_tree/rectangle_tree_impl.hpp
diff --git a/src/mlpack/core/tree/octree/octree.hpp b/src/mlpack/core/tree/octree/octree.hpp
index 544fa75..3fec915 100644
--- a/src/mlpack/core/tree/octree/octree.hpp
+++ b/src/mlpack/core/tree/octree/octree.hpp
@@ -9,11 +9,12 @@
 
 #include <mlpack/core.hpp>
 #include "../hrectbound.hpp"
+#include "../statistic.hpp"
 
 namespace mlpack {
 namespace tree {
 
-template<typename MetricType,
+template<typename MetricType = metric::EuclideanDistance,
          typename StatisticType = EmptyStatistic,
          typename MatType = arma::mat>
 class Octree
@@ -36,11 +37,19 @@ class Octree
   size_t count;
   //! The minimum bounding rectangle of the points held in the node (and its
   //! children).
-  HRectBound<MeetricType> bound;
+  bound::HRectBound<MetricType> bound;
   //! The dataset.
   MatType* dataset;
   //! The parent (NULL if this node is the root).
   Octree* parent;
+  //! The statistic.
+  StatisticType stat;
+  //! The distance from the center of this node to the center of the parent.
+  ElemType parentDistance;
+  //! The distance to the furthest descendant, cached to speed things up.
+  ElemType furthestDescendantDistance;
+  //! An instantiated metric.
+  MetricType metric;
 
  public:
   /**
@@ -97,7 +106,7 @@ class Octree
    * @param data Dataset to create tree from.  This will be copied!
    * @param maxLeafSize Maximum number of points in a leaf node.
    */
-  Octree(const MatType& data, const size_t maxLeafSize = 20);
+  Octree(MatType&& data, const size_t maxLeafSize = 20);
 
   /**
    * Construct this as the root node of an octree on the given dataset. This
@@ -183,6 +192,148 @@ class Octree
          const double width,
          const size_t maxLeafSize = 20);
 
+  /**
+   * Destroy the tree.
+   */
+  ~Octree();
+
+  //! Return the dataset used by this node.
+  const MatType& Dataset() const { return *dataset; }
+
+  //! Return the bound object for this node.
+  const bound::HRectBound<MetricType>& Bound() const { return bound; }
+  //! Modify the bound object for this node.
+  bound::HRectBound<MetricType>& Bound() { return bound; }
+
+  //! Return the statistic object for this node.
+  const StatisticType& Stat() const { return stat; }
+  //! Modify the statistic object for this node.
+  StatisticType& Stat() { return stat; }
+
+  //! Return the number of children in this node.
+  size_t NumChildren() const;
+
+  //! Return the metric that this tree uses.
+  MetricType Metric() const { return MetricType(); }
+
+  /**
+   * Return the index of the nearest child node to the given query point.  If
+   * this is a leaf node, it will return NumChildren() (invalid index).
+   */
+  template<typename VecType>
+  size_t GetNearestChild(
+      const VecType& point,
+      typename boost::enable_if<IsVector<VecType>>::type* = 0) const;
+
+  /**
+   * Return the index of the furthest child node to the given query point.  If
+   * this is a leaf node, it will return NumChildren() (invalid index).
+   */
+  template<typename VecType>
+  size_t GetFurthestChild(
+      const VecType& point,
+      typename boost::enable_if<IsVector<VecType> >::type* = 0) const;
+
+  /**
+   * Return the index of the nearest child node to the given query node.  If it
+   * can't decide, it will return NumChildren() (invalid index).
+   */
+  size_t GetNearestChild(const Octree& queryNode) const;
+
+  /**
+   * Return the index of the furthest child node to the given query node.  If it
+   * can't decide, it will return NumChildren() (invalid index).
+   */
+  size_t GetFurthestChild(const Octree& queryNode) const;
+
+  /**
+   * Return the furthest distance to a point held in this node.  If this is not
+   * a leaf node, then the distance is 0 because the node holds no points.
+   */
+  ElemType FurthestPointDistance() const;
+
+  /**
+   * Return the furthest possible descendant distance.  This returns the maximum
+   * distance from the centroid to the edge of the bound and not the empirical
+   * quantity which is the actual furthest descendant distance.  So the actual
+   * furthest descendant distance may be less than what this method returns (but
+   * it will never be greater than this).
+   */
+  ElemType FurthestDescendantDistance() const;
+
+  //! Return the minimum distance from the center of the node to any bound edge.
+  ElemType MinimumBoundDistance() const;
+
+  //! Return the distance from the center of this node to the center of the
+  //! parent node.
+  ElemType ParentDistance() const { return parentDistance; }
+  //! Modify the distance from the center of this node to the center of the
+  //! parent node.
+  ElemType& ParentDistance() { return parentDistance; }
+
+  /**
+   * Return the specified child.  If the index is out of bounds, unspecified
+   * behavior will occur.
+   */
+  const Octree& Child(const size_t child) const { return *children[child]; }
+
+  /**
+   * Return the specified child.  If the index is out of bounds, unspecified
+   * behavior will occur.
+   */
+  Octree& Child(const size_t child) { return *children[child]; }
+
+  /**
+   * Return the pointer to the given child.  This allows the child itself to be
+   * modified.
+   */
+  Octree*& ChildPtr(const size_t child) { return children[child]; }
+
+  //! Return the number of points in this node (0 if not a leaf).
+  size_t NumPoints() const;
+
+  //! Return the number of descendants of this node.
+  size_t NumDescendants() const;
+
+  /**
+   * Return the index (with reference to the dataset) of a particular
+   * descendant.
+   */
+  size_t Descendant(const size_t index) const;
+
+  /**
+   * Return the index (with reference to the dataset) of a particular point in
+   * this node.  If the given index is invalid (i.e. if it is greater than
+   * NumPoints()), the indices returned will be invalid.
+   */
+  size_t Point(const size_t index) const;
+
+  //! Return the minimum distance to another node.
+  ElemType MinDistance(const Octree* other) const;
+  //! Return the maximum distance to another node.
+  ElemType MaxDistance(const Octree* other) const;
+  //! Return the minimum and maximum distance to another node.
+  math::RangeType<ElemType> RangeDistance(const Octree* other) const;
+
+  //! Return the minimum distance to the given point.
+  template<typename VecType>
+  ElemType MinDistance(
+      const VecType& point,
+      typename boost::enable_if<IsVector<VecType>>::type* = 0) const;
+  //! Return the maximum distance to the given point.
+  template<typename VecType>
+  ElemType MaxDistance(
+      const VecType& point,
+      typename boost::enable_if<IsVector<VecType>>::type* = 0) const;
+  //! Return the minimum and maximum distance to another node.
+  template<typename VecType>
+  math::RangeType<ElemType> RangeDistance(
+      const VecType& point,
+      typename boost::enable_if<IsVector<VecType>>::type* = 0) const;
+
+  //! Store the center of the bounding region in the given vector.
+  void Center(arma::vec& center) const { bound.Center(center); }
+
  private:
   /**
    * Split the node, using the given center and the given maximum width of this
@@ -190,8 +341,11 @@ class Octree
    *
    * @param center Center of the node.
    * @param width Width of the current node.
+   * @param maxLeafSize Maximum number of points allowed in a leaf.
    */
-  void SplitNode(const arma::vec& center, const double width);
+  void SplitNode(const arma::vec& center,
+                 const double width,
+                 const size_t maxLeafSize);
 
   /**
    * Split the node, using the given center and the given maximum width of this
@@ -200,8 +354,18 @@ class Octree
    * @param center Center of the node.
    * @param width Width of the current node.
    * @param oldFromNew Mappings from old to new.
+   * @param maxLeafSize Maximum number of points allowed in a leaf.
    */
   void SplitNode(const arma::vec& center,
                  const double width,
-                 std::vector<size_t>& oldFromNew);
+                 std::vector<size_t>& oldFromNew,
+                 const size_t maxLeafSize);
 };
+
+} // namespace tree
+} // namespace mlpack
+
+// Include implementation.
+#include "octree_impl.hpp"
+
+#endif
diff --git a/src/mlpack/core/tree/octree/octree_impl.hpp b/src/mlpack/core/tree/octree/octree_impl.hpp
index cd8234a..33ec3e2 100644
--- a/src/mlpack/core/tree/octree/octree_impl.hpp
+++ b/src/mlpack/core/tree/octree/octree_impl.hpp
@@ -9,19 +9,82 @@
 
 #include "octree.hpp"
 
+namespace mlpack {
+namespace tree {
+
 //! Construct the tree.
 template<typename MetricType, typename StatisticType, typename MatType>
 Octree<MetricType, StatisticType, MatType>::Octree(const MatType& dataset,
-                                                   const double maxLeafSize) :
+                                                   const size_t maxLeafSize) :
+    begin(0),
+    count(dataset.n_cols),
+    bound(dataset.n_rows),
     dataset(new MatType(dataset)),
+    parent(NULL),
+    parentDistance(0.0)
+{
+  if (count > 0)
+  {
+    // Calculate empirical center of data.
+    bound |= *this->dataset;
+    arma::vec center;
+    bound.Center(center);
+
+    double maxWidth = 0.0;
+    for (size_t i = 0; i < bound.Dim(); ++i)
+      if (bound[i].Hi() - bound[i].Lo() > maxWidth)
+        maxWidth = bound[i].Hi() - bound[i].Lo();
+
+    SplitNode(center, maxWidth, maxLeafSize);
+
+    furthestDescendantDistance = 0.5 * bound.Diameter();
+  }
+  else
+  {
+    furthestDescendantDistance = 0.0;
+  }
 
+  // Initialize the statistic.
+  stat = StatisticType(*this);
+}
+
+//! Construct the tree.
+template<typename MetricType, typename StatisticType, typename MatType>
+Octree<MetricType, StatisticType, MatType>::Octree(
+    const MatType& dataset,
+    std::vector<size_t>& oldFromNew,
+    const size_t maxLeafSize) :
+    begin(0),
+    count(dataset.n_cols),
+    bound(dataset.n_rows),
+    dataset(new MatType(dataset)),
+    parent(NULL),
+    parentDistance(0.0)
 {
-  // Calculate empirical center of data.
-  bound |= *dataset;
-  arma::vec center = bound.Center();
-  double maxWidth = bound.MaxWidth();
+  oldFromNew.resize(this->dataset->n_cols);
+  for (size_t i = 0; i < this->dataset->n_cols; ++i)
+    oldFromNew[i] = i;
+
+  if (count > 0)
+  {
+    // Calculate empirical center of data.
+    bound |= *this->dataset;
+    arma::vec center;
+    bound.Center(center);
 
-  SplitNode(center, maxWidth);
+    double maxWidth = 0.0;
+    for (size_t i = 0; i < bound.Dim(); ++i)
+      if (bound[i].Hi() - bound[i].Lo() > maxWidth)
+        maxWidth = bound[i].Hi() - bound[i].Lo();
+
+    SplitNode(center, maxWidth, oldFromNew, maxLeafSize);
+
+    furthestDescendantDistance = 0.5 * bound.Diameter();
+  }
+  else
+  {
+    furthestDescendantDistance = 0.0;
+  }
 
   // Initialize the statistic.
   stat = StatisticType(*this);
@@ -32,20 +95,122 @@ template<typename MetricType, typename StatisticType, typename MatType>
 Octree<MetricType, StatisticType, MatType>::Octree(
     const MatType& dataset,
     std::vector<size_t>& oldFromNew,
+    std::vector<size_t>& newFromOld,
     const size_t maxLeafSize) :
+    begin(0),
+    count(dataset.n_cols),
+    bound(dataset.n_rows),
     dataset(new MatType(dataset)),
+    parent(NULL),
+    parentDistance(0.0)
+{
+  oldFromNew.resize(this->dataset->n_cols);
+  for (size_t i = 0; i < this->dataset->n_cols; ++i)
+    oldFromNew[i] = i;
+
+  if (count > 0)
+  {
+    // Calculate empirical center of data.
+    bound |= *this->dataset;
+    arma::vec center;
+    bound.Center(center);
+
+    double maxWidth = 0.0;
+    for (size_t i = 0; i < bound.Dim(); ++i)
+      if (bound[i].Hi() - bound[i].Lo() > maxWidth)
+        maxWidth = bound[i].Hi() - bound[i].Lo();
+
+    SplitNode(center, maxWidth, oldFromNew, maxLeafSize);
+
+    furthestDescendantDistance = 0.5 * bound.Diameter();
+  }
+  else
+  {
+    furthestDescendantDistance = 0.0;
+  }
+
+  // Initialize the statistic.
+  stat = StatisticType(*this);
+
+  // Map the newFromOld indices correctly.
+  newFromOld.resize(this->dataset->n_cols);
+  for (size_t i = 0; i < this->dataset->n_cols; i++)
+    newFromOld[oldFromNew[i]] = i;
+}
 
+//! Construct the tree.
+template<typename MetricType, typename StatisticType, typename MatType>
+Octree<MetricType, StatisticType, MatType>::Octree(MatType&& dataset,
+                                                   const size_t maxLeafSize) :
+    begin(0),
+    count(dataset.n_cols),
+    bound(dataset.n_rows),
+    dataset(new MatType(std::move(dataset))),
+    parent(NULL),
+    parentDistance(0.0)
 {
-  // Calculate empirical center of data.
-  bound |= *dataset;
-  arma::vec center = bound.Center();
-  double maxWidth = bound.MaxWidth();
+  if (count > 0)
+  {
+    // Calculate empirical center of data.
+    bound |= *this->dataset;
+    arma::vec center;
+    bound.Center(center);
+
+    double maxWidth = 0.0;
+    for (size_t i = 0; i < bound.Dim(); ++i)
+      if (bound[i].Hi() - bound[i].Lo() > maxWidth)
+        maxWidth = bound[i].Hi() - bound[i].Lo();
+
+    SplitNode(center, maxWidth, maxLeafSize);
+
+    furthestDescendantDistance = 0.5 * bound.Diameter();
+  }
+  else
+  {
+    furthestDescendantDistance = 0.0;
+  }
 
-  oldFromNew.resize(data.n_cols);
-  for (size_t i = 0; i < data.n_cols; ++i)
+  // Initialize the statistic.
+  stat = StatisticType(*this);
+}
+
+//! Construct the tree.
+template<typename MetricType, typename StatisticType, typename MatType>
+Octree<MetricType, StatisticType, MatType>::Octree(
+    MatType&& dataset,
+    std::vector<size_t>& oldFromNew,
+    const size_t maxLeafSize) :
+    begin(0),
+    count(dataset.n_cols),
+    bound(dataset.n_rows),
+    dataset(new MatType(std::move(dataset))),
+    parent(NULL),
+    parentDistance(0.0)
+{
+  oldFromNew.resize(this->dataset->n_cols);
+  for (size_t i = 0; i < this->dataset->n_cols; ++i)
     oldFromNew[i] = i;
 
-  SplitNode(center, maxWidth, oldFromNew);
+  if (count > 0)
+  {
+    // Calculate empirical center of data.
+    bound |= *this->dataset;
+    arma::vec center;
+    bound.Center(center);
+
+    double maxWidth = 0.0;
+    for (size_t i = 0; i < bound.Dim(); ++i)
+      if (bound[i].Hi() - bound[i].Lo() > maxWidth)
+        maxWidth = bound[i].Hi() - bound[i].Lo();
+
+    SplitNode(center, maxWidth, oldFromNew, maxLeafSize);
+
+    furthestDescendantDistance = 0.5 * bound.Diameter();
+  }
+  else
+  {
+    furthestDescendantDistance = 0.0;
+  }
 
   // Initialize the statistic.
   stat = StatisticType(*this);
@@ -54,34 +219,341 @@ Octree<MetricType, StatisticType, MatType>::Octree(
 //! Construct the tree.
 template<typename MetricType, typename StatisticType, typename MatType>
 Octree<MetricType, StatisticType, MatType>::Octree(
-    const MatType& dataset,
+    MatType&& dataset,
     std::vector<size_t>& oldFromNew,
+    std::vector<size_t>& newFromOld,
     const size_t maxLeafSize) :
-    dataset(new MatType(dataset)),
+    begin(0),
+    count(dataset.n_cols),
+    bound(dataset.n_rows),
+    dataset(new MatType(std::move(dataset))),
+    parent(NULL),
+    parentDistance(0.0)
+{
+  oldFromNew.resize(this->dataset->n_cols);
+  for (size_t i = 0; i < this->dataset->n_cols; ++i)
+    oldFromNew[i] = i;
+
+  if (count > 0)
+  {
+    // Calculate empirical center of data.
+    bound |= *this->dataset;
+    arma::vec center;
+    bound.Center(center);
 
+    double maxWidth = 0.0;
+    for (size_t i = 0; i < bound.Dim(); ++i)
+      if (bound[i].Hi() - bound[i].Lo() > maxWidth)
+        maxWidth = bound[i].Hi() - bound[i].Lo();
+
+    SplitNode(center, maxWidth, oldFromNew, maxLeafSize);
+
+    furthestDescendantDistance = 0.5 * bound.Diameter();
+  }
+  else
+  {
+    furthestDescendantDistance = 0.0;
+  }
+
+  // Initialize the statistic.
+  stat = StatisticType(*this);
+
+  // Map the newFromOld indices correctly.
+  newFromOld.resize(this->dataset->n_cols);
+  for (size_t i = 0; i < this->dataset->n_cols; i++)
+    newFromOld[oldFromNew[i]] = i;
+}
+
+//! Construct a child node.
+template<typename MetricType, typename StatisticType, typename MatType>
+Octree<MetricType, StatisticType, MatType>::Octree(
+    Octree* parent,
+    const size_t begin,
+    const size_t count,
+    const arma::vec& center,
+    const double width,
+    const size_t maxLeafSize) :
+    begin(begin),
+    count(count),
+    bound(parent->dataset->n_rows),
+    dataset(parent->dataset),
+    parent(parent)
 {
   // Calculate empirical center of data.
-  bound |= *dataset;
-  arma::vec center = bound.Center();
-  double maxWidth = bound.MaxWidth();
+  bound |= dataset->cols(begin, begin + count - 1);

-  oldFromNew.resize(data.n_cols);
-  for (size_t i = 0; i < data.n_cols; ++i)
-    oldFromNew[i] = i;
+  // Now split the node.
+  SplitNode(center, width, maxLeafSize);

-  SplitNode(center, maxWidth, oldFromNew);
+  // Cache the center-to-parent-center distance and half the bound diameter.
+  arma::vec trueCenter, parentCenter;
+  bound.Center(trueCenter);
+  parent->Bound().Center(parentCenter);
+  parentDistance = metric.Evaluate(trueCenter, parentCenter);
+  furthestDescendantDistance = 0.5 * bound.Diameter();

   // Initialize the statistic.
   stat = StatisticType(*this);
 }
 
+//! Construct a child node.
+template<typename MetricType, typename StatisticType, typename MatType>
+Octree<MetricType, StatisticType, MatType>::Octree(
+    Octree* parent,
+    const size_t begin,
+    const size_t count,
+    std::vector<size_t>& oldFromNew,
+    const arma::vec& center,
+    const double width,
+    const size_t maxLeafSize) :
+    begin(begin),
+    count(count),
+    bound(parent->dataset->n_rows),
+    dataset(parent->dataset),
+    parent(parent)
+{
+  // Calculate empirical center of data.
+  bound |= dataset->cols(begin, begin + count - 1);
+
+  // Now split the node.
+  SplitNode(center, width, oldFromNew, maxLeafSize);
+
+  // Cache the center-to-parent-center distance and half the bound diameter.
+  arma::vec trueCenter, parentCenter;
+  bound.Center(trueCenter);
+  parent->Bound().Center(parentCenter);
+  parentDistance = metric.Evaluate(trueCenter, parentCenter);
+  furthestDescendantDistance = 0.5 * bound.Diameter();
+
+  // Initialize the statistic.
+  stat = StatisticType(*this);
+}
+
+template<typename MetricType, typename StatisticType, typename MatType>
+Octree<MetricType, StatisticType, MatType>::~Octree()
+{
+  // Only the root (parent == NULL) allocated the dataset, so only it frees it.
+  if (!parent)
+    delete dataset;
+
+  // Now delete each of the children.
+  for (size_t i = 0; i < children.size(); ++i)
+    delete children[i];
+  children.clear();
+}
+
+template<typename MetricType, typename StatisticType, typename MatType>
+size_t Octree<MetricType, StatisticType, MatType>::NumChildren() const
+{
+  return children.size();
+}
+
+template<typename MetricType, typename StatisticType, typename MatType>
+template<typename VecType>
+size_t Octree<MetricType, StatisticType, MatType>::GetNearestChild(
+    const VecType& point,
+    typename boost::enable_if<IsVector<VecType>>::type*) const
+{
+  // It's possible that this could be improved by caching which children we have
+  // and which we don't, but for now this is just a brute force search.
+  ElemType bestDistance = std::numeric_limits<ElemType>::max();
+  size_t bestIndex = NumChildren();
+  for (size_t i = 0; i < NumChildren(); ++i)
+  {
+    const ElemType dist = children[i]->MinDistance(point);
+    if (dist < bestDistance)
+    {
+      bestDistance = dist;
+      bestIndex = i;
+    }
+  }
+
+  return bestIndex;
+}
+
+template<typename MetricType, typename StatisticType, typename MatType>
+template<typename VecType>
+size_t Octree<MetricType, StatisticType, MatType>::GetFurthestChild(
+    const VecType& point,
+    typename boost::enable_if<IsVector<VecType>>::type*) const
+{
+  // It's possible that this could be improved by caching which children we have
+  // and which we don't, but for now this is just a brute force search.
+  ElemType bestDistance = -1.0; // Initialize to invalid distance.
+  size_t bestIndex = NumChildren();
+  for (size_t i = 0; i < NumChildren(); ++i)
+  {
+    const ElemType dist = children[i]->MaxDistance(point);
+    if (dist > bestDistance)
+    {
+      bestDistance = dist;
+      bestIndex = i;
+    }
+  }
+
+  return bestIndex;
+}
+
+template<typename MetricType, typename StatisticType, typename MatType>
+size_t Octree<MetricType, StatisticType, MatType>::GetNearestChild(
+    const Octree& queryNode) const
+{
+  // It's possible that this could be improved by caching which children we have
+  // and which we don't, but for now this is just a brute force search.
+  ElemType bestDistance = std::numeric_limits<ElemType>::max();
+  size_t bestIndex = NumChildren();
+  for (size_t i = 0; i < NumChildren(); ++i)
+  {
+    const ElemType dist = children[i]->MinDistance(&queryNode);
+    if (dist < bestDistance)
+    {
+      bestDistance = dist;
+      bestIndex = i;
+    }
+  }
+
+  return bestIndex;
+}
+
+template<typename MetricType, typename StatisticType, typename MatType>
+size_t Octree<MetricType, StatisticType, MatType>::GetFurthestChild(
+    const Octree& queryNode) const
+{
+  // It's possible that this could be improved by caching which children we have
+  // and which we don't, but for now this is just a brute force search.
+  ElemType bestDistance = -1.0; // Initialize to invalid distance.
+  size_t bestIndex = NumChildren();
+  for (size_t i = 0; i < NumChildren(); ++i)
+  {
+    const ElemType dist = children[i]->MaxDistance(&queryNode);
+    if (dist > bestDistance)
+    {
+      bestDistance = dist;
+      bestIndex = i;
+    }
+  }
+
+  return bestIndex;
+}
+
+template<typename MetricType, typename StatisticType, typename MatType>
+typename Octree<MetricType, StatisticType, MatType>::ElemType
+Octree<MetricType, StatisticType, MatType>::FurthestPointDistance()
+    const
+{
+  // If we are not a leaf, then this distance is 0.  Otherwise, return the
+  // furthest descendant distance.
+  return (children.size() > 0) ? 0.0 : furthestDescendantDistance;
+}
+
+template<typename MetricType, typename StatisticType, typename MatType>
+typename Octree<MetricType, StatisticType, MatType>::ElemType
+Octree<MetricType, StatisticType, MatType>::FurthestDescendantDistance() const
+{
+  return furthestDescendantDistance;
+}
+
+template<typename MetricType, typename StatisticType, typename MatType>
+typename Octree<MetricType, StatisticType, MatType>::ElemType
+Octree<MetricType, StatisticType, MatType>::MinimumBoundDistance() const
+{
+  return bound.MinWidth() / 2.0;
+}
+
+template<typename MetricType, typename StatisticType, typename MatType>
+size_t Octree<MetricType, StatisticType, MatType>::NumPoints() const
+{
+  // We have no points unless we are a leaf;
+  return (children.size() > 0) ? 0 : count;
+}
+
+template<typename MetricType, typename StatisticType, typename MatType>
+size_t Octree<MetricType, StatisticType, MatType>::NumDescendants() const
+{
+  return count;
+}
+
+template<typename MetricType, typename StatisticType, typename MatType>
+size_t Octree<MetricType, StatisticType, MatType>::Descendant(
+    const size_t index) const
+{
+  return begin + index;
+}
+
+template<typename MetricType, typename StatisticType, typename MatType>
+size_t Octree<MetricType, StatisticType, MatType>::Point(const size_t index)
+    const
+{
+  return begin + index;
+}
+
+template<typename MetricType, typename StatisticType, typename MatType>
+typename Octree<MetricType, StatisticType, MatType>::ElemType
+Octree<MetricType, StatisticType, MatType>::MinDistance(const Octree* other)
+    const
+{
+  return bound.MinDistance(other->Bound());
+}
+
+template<typename MetricType, typename StatisticType, typename MatType>
+typename Octree<MetricType, StatisticType, MatType>::ElemType
+Octree<MetricType, StatisticType, MatType>::MaxDistance(const Octree* other)
+    const
+{
+  return bound.MaxDistance(other->Bound());
+}
+
+template<typename MetricType, typename StatisticType, typename MatType>
+math::RangeType<typename Octree<MetricType, StatisticType, MatType>::ElemType>
+Octree<MetricType, StatisticType, MatType>::RangeDistance(const Octree* other)
+    const
+{
+  return bound.RangeDistance(other->Bound());
+}
+
+template<typename MetricType, typename StatisticType, typename MatType>
+template<typename VecType>
+typename Octree<MetricType, StatisticType, MatType>::ElemType
+Octree<MetricType, StatisticType, MatType>::MinDistance(
+    const VecType& point,
+    typename boost::enable_if<IsVector<VecType>>::type*) const
+{
+  return bound.MinDistance(point);
+}
+
+template<typename MetricType, typename StatisticType, typename MatType>
+template<typename VecType>
+typename Octree<MetricType, StatisticType, MatType>::ElemType
+Octree<MetricType, StatisticType, MatType>::MaxDistance(
+    const VecType& point,
+    typename boost::enable_if<IsVector<VecType>>::type*) const
+{
+  return bound.MaxDistance(point);
+}
+
+
+template<typename MetricType, typename StatisticType, typename MatType>
+template<typename VecType>
+math::RangeType<typename Octree<MetricType, StatisticType, MatType>::ElemType>
+Octree<MetricType, StatisticType, MatType>::RangeDistance(
+    const VecType& point,
+    typename boost::enable_if<IsVector<VecType>>::type*) const
+{
+  return bound.RangeDistance(point);
+}
 
 //! Split the node.
 template<typename MetricType, typename StatisticType, typename MatType>
 void Octree<MetricType, StatisticType, MatType>::SplitNode(
     const arma::vec& center,
-    const double width)
+    const double width,
+    const size_t maxLeafSize)
 {
+  // No need to split if we have fewer than the maximum number of points in this
+  // node.
+  if (count <= maxLeafSize)
+    return;
+
   // We must split the dataset by sequentially creating each of the children.
   // We do this in two steps: first we make a pass to count the number of points
   // that will fall into each child; then in the second pass we rearrange the
@@ -101,7 +573,7 @@ void Octree<MetricType, StatisticType, MatType>::SplitNode(
       // the points fall on.  The last dimension represents the most significant
       // bit in the assignment; the bit is '1' if it falls to the right of the
       // center.
-      if (dataset(d, begin + i) > center(d))
+      if ((*dataset)(d, begin + i) > center(d))
         assignments(i) |= (1 << d);
     }
 
@@ -130,7 +602,7 @@ void Octree<MetricType, StatisticType, MatType>::SplitNode(
     for (size_t d = 0; d < center.n_elem; ++d)
     {
       // Is the dimension "right" (1) or "left" (0)?
-      if ((i >> d) & 1 == 0)
+      if (((i >> d) & 1) == 0)
         childCenter[d] = center[d] - childWidth;
       else
         childCenter[d] = center[d] + childWidth;
@@ -148,8 +620,14 @@ template<typename MetricType, typename StatisticType, typename MatType>
 void Octree<MetricType, StatisticType, MatType>::SplitNode(
     const arma::vec& center,
     const double width,
-    std::vector<size_t>& oldFromNew)
+    std::vector<size_t>& oldFromNew,
+    const size_t maxLeafSize)
 {
+  // No need to split if we have fewer than the maximum number of points in this
+  // node.
+  if (count <= maxLeafSize)
+    return;
+
   // We must split the dataset by sequentially creating each of the children.
   // We do this in two steps: first we make a pass to count the number of points
   // that will fall into each child; then in the second pass we rearrange the
@@ -169,7 +647,7 @@ void Octree<MetricType, StatisticType, MatType>::SplitNode(
       // the points fall on.  The last dimension represents the most significant
       // bit in the assignment; the bit is '1' if it falls to the right of the
       // center.
-      if (dataset(d, begin + i) > center(d))
+      if ((*dataset)(d, begin + i) > center(d))
         assignments(i) |= (1 << d);
     }
 
@@ -183,8 +661,9 @@ void Octree<MetricType, StatisticType, MatType>::SplitNode(
   // really a problem.  We use non-contiguous submatrix views to extract the
   // columns in the correct order.
   dataset->cols(begin, begin + count - 1) = dataset->cols(begin + ordering);
+  std::vector<size_t> oldSlice(&oldFromNew[begin], &oldFromNew[begin] + count);
   for (size_t i = 0; i < count; ++i)
-    oldFromNew[ordering[i] + begin] = i + begin;
+    oldFromNew[i + begin] = oldSlice[ordering[i]];
 
   // Now that the dataset is reordered, we can create the children.
   size_t childBegin = begin;
@@ -200,7 +679,7 @@ void Octree<MetricType, StatisticType, MatType>::SplitNode(
     for (size_t d = 0; d < center.n_elem; ++d)
     {
       // Is the dimension "right" (1) or "left" (0)?
-      if ((i >> d) & 1 == 0)
+      if (((i >> d) & 1) == 0)
         childCenter[d] = center[d] - childWidth;
       else
         childCenter[d] = center[d] + childWidth;
@@ -212,3 +691,8 @@ void Octree<MetricType, StatisticType, MatType>::SplitNode(
     childBegin += childCounts[i];
   }
 }
+
+} // namespace tree
+} // namespace mlpack
+
+#endif
diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt
index 9ad4092..099aa29 100644
--- a/src/mlpack/tests/CMakeLists.txt
+++ b/src/mlpack/tests/CMakeLists.txt
@@ -57,6 +57,7 @@ add_executable(mlpack_test
   network_util_test.cpp
   nmf_test.cpp
   nystroem_method_test.cpp
+  octree_test.cpp
   pca_test.cpp
   perceptron_test.cpp
   quic_svd_test.cpp
diff --git a/src/mlpack/tests/octree_test.cpp b/src/mlpack/tests/octree_test.cpp
new file mode 100644
index 0000000..b2647b3
--- /dev/null
+++ b/src/mlpack/tests/octree_test.cpp
@@ -0,0 +1,148 @@
+/**
+ * @file octree_test.cpp
+ * @author Ryan Curtin
+ *
+ * Test various properties of the Octree.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/core/tree/octree.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "test_tools.hpp"
+
+using namespace mlpack;
+using namespace mlpack::math;
+using namespace mlpack::tree;
+using namespace mlpack::metric;
+using namespace mlpack::bound;
+
+BOOST_AUTO_TEST_SUITE(OctreeTest);
+
+/**
+ * Build a quad-tree (2-d octree) on 4 points, and guarantee four points are
+ * created.
+ */
+BOOST_AUTO_TEST_CASE(SimpleQuadtreeTest)
+{
+  // Four corners of the unit square.
+  arma::mat dataset("0 0 1 1; 0 1 0 1");
+
+  Octree<> t(dataset, 1);
+
+  BOOST_REQUIRE_EQUAL(t.NumChildren(), 4);
+  BOOST_REQUIRE_EQUAL(t.Dataset().n_cols, 4);
+  BOOST_REQUIRE_EQUAL(t.Dataset().n_rows, 2);
+  BOOST_REQUIRE_EQUAL(t.NumDescendants(), 4);
+  BOOST_REQUIRE_EQUAL(t.NumPoints(), 0);
+  for (size_t i = 0; i < 4; ++i)
+  {
+    BOOST_REQUIRE_EQUAL(t.Child(i).NumDescendants(), 1);
+    BOOST_REQUIRE_EQUAL(t.Child(i).NumPoints(), 1);
+  }
+}
+
+/**
+ * Build an octree on 3 points and make sure that only three children are
+ * created.
+ */
+BOOST_AUTO_TEST_CASE(OctreeMissingChildTest)
+{
+  // Only three corners of the unit square.
+  arma::mat dataset("0 0 1; 0 1 1");
+
+  Octree<> t(dataset, 1);
+
+  BOOST_REQUIRE_EQUAL(t.NumChildren(), 3);
+  BOOST_REQUIRE_EQUAL(t.Dataset().n_cols, 3);
+  BOOST_REQUIRE_EQUAL(t.Dataset().n_rows, 2);
+  BOOST_REQUIRE_EQUAL(t.NumDescendants(), 3);
+  BOOST_REQUIRE_EQUAL(t.NumPoints(), 0);
+  for (size_t i = 0; i < 3; ++i)
+  {
+    BOOST_REQUIRE_EQUAL(t.Child(i).NumDescendants(), 1);
+    BOOST_REQUIRE_EQUAL(t.Child(i).NumPoints(), 1);
+  }
+}
+
+/**
+ * Ensure that building an empty octree does not fail.
+ */
+BOOST_AUTO_TEST_CASE(EmptyOctreeTest)
+{
+  arma::mat dataset;
+  Octree<> t(dataset);
+
+  BOOST_REQUIRE_EQUAL(t.NumChildren(), 0);
+  BOOST_REQUIRE_EQUAL(t.Dataset().n_cols, 0);
+  BOOST_REQUIRE_EQUAL(t.Dataset().n_rows, 0);
+  BOOST_REQUIRE_EQUAL(t.NumDescendants(), 0);
+  BOOST_REQUIRE_EQUAL(t.NumPoints(), 0);
+}
+
+/**
+ * Ensure that maxLeafSize is respected.
+ */
+BOOST_AUTO_TEST_CASE(MaxLeafSizeTest)
+{
+  arma::mat dataset(5, 15, arma::fill::randu);
+  Octree<> t1(dataset, 20);
+  Octree<> t2(std::move(dataset), 20);
+
+  BOOST_REQUIRE_EQUAL(t1.NumChildren(), 0);
+  BOOST_REQUIRE_EQUAL(t1.NumDescendants(), 15);
+  BOOST_REQUIRE_EQUAL(t1.NumPoints(), 15);
+
+  BOOST_REQUIRE_EQUAL(t2.NumChildren(), 0);
+  BOOST_REQUIRE_EQUAL(t2.NumDescendants(), 15);
+  BOOST_REQUIRE_EQUAL(t2.NumPoints(), 15);
+}
+
+/**
+ * Check that the mappings given are correct.
+ */
+BOOST_AUTO_TEST_CASE(MappingsTest)
+{
+  // Test with both constructors.
+  arma::mat dataset(3, 5, arma::fill::randu);
+  arma::mat datacopy(dataset);
+  std::vector<size_t> oldFromNewCopy, oldFromNewMove;
+
+  Octree<> t1(dataset, oldFromNewCopy, 1);
+  Octree<> t2(std::move(dataset), oldFromNewMove, 1);
+
+  for (size_t i = 0; i < oldFromNewCopy.size(); ++i)
+  {
+    BOOST_REQUIRE_SMALL(arma::norm(datacopy.col(oldFromNewCopy[i]) -
+        t1.Dataset().col(i)), 1e-3);
+    BOOST_REQUIRE_SMALL(arma::norm(datacopy.col(oldFromNewMove[i]) -
+        t2.Dataset().col(i)), 1e-3);
+  }
+}
+
+/**
+ * Check that the reverse mappings are correct too.
+ */
+BOOST_AUTO_TEST_CASE(ReverseMappingsTest)
+{
+  // Test with both constructors.
+  arma::mat dataset(3, 300, arma::fill::randu);
+  arma::mat datacopy(dataset);
+  std::vector<size_t> oldFromNewCopy, oldFromNewMove, newFromOldCopy,
+      newFromOldMove;
+
+  Octree<> t1(dataset, oldFromNewCopy, newFromOldCopy);
+  Octree<> t2(std::move(dataset), oldFromNewMove, newFromOldMove);
+
+  for (size_t i = 0; i < oldFromNewCopy.size(); ++i)
+  {
+    BOOST_REQUIRE_SMALL(arma::norm(datacopy.col(oldFromNewCopy[i]) -
+        t1.Dataset().col(i)), 1e-3);
+    BOOST_REQUIRE_SMALL(arma::norm(datacopy.col(oldFromNewMove[i]) -
+        t2.Dataset().col(i)), 1e-3);
+
+    BOOST_REQUIRE_EQUAL(newFromOldCopy[oldFromNewCopy[i]], i);
+    BOOST_REQUIRE_EQUAL(newFromOldMove[oldFromNewMove[i]], i);
+  }
+}
+
+BOOST_AUTO_TEST_SUITE_END();




More information about the mlpack-git mailing list