[mlpack-git] master: Add copy, move, and serialization, plus tests for the octree. (b305938)

gitdub at mlpack.org gitdub at mlpack.org
Fri Sep 23 17:21:46 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/9ef7339d40550a974b3939e9fcb966fac2c09065...ebdb5abeaa3fd621a06ae663862bb72df76d2b40

>---------------------------------------------------------------

commit b305938d5887c0b8f6bcb246f789bd34725cf4a4
Author: Ryan Curtin <ryan at ratml.org>
Date:   Fri Sep 23 17:21:46 2016 -0400

    Add copy, move, and serialization, plus tests for the octree.


>---------------------------------------------------------------

b305938d5887c0b8f6bcb246f789bd34725cf4a4
 src/mlpack/core/tree/octree/octree.hpp      |  46 +++++++
 src/mlpack/core/tree/octree/octree_impl.hpp | 156 ++++++++++++++++++++++
 src/mlpack/tests/octree_test.cpp            | 194 ++++++++++++++++++++++++++++
 3 files changed, 396 insertions(+)

diff --git a/src/mlpack/core/tree/octree/octree.hpp b/src/mlpack/core/tree/octree/octree.hpp
index 3fec915..f7fa385 100644
--- a/src/mlpack/core/tree/octree/octree.hpp
+++ b/src/mlpack/core/tree/octree/octree.hpp
@@ -193,6 +193,31 @@ class Octree
          const size_t maxLeafSize = 20);
 
   /**
+   * Copy the given tree.  Be careful!  This may use a lot of memory, since the
+   * children are copied recursively and, when other is a root node, its dataset
+   * is copied as well.
+   *
+   * If other is not a root node, the copy has no parent and its dataset pointer
+   * is NULL, so copies should normally be taken from the root of a tree.
+   *
+   * @param other Tree to copy from.
+   */
+  Octree(const Octree& other);
+
+  /**
+   * Move the given tree.  The children, bound, statistic, metric, and dataset
+   * pointer of other are taken over by this node.  Afterwards other is left as
+   * an empty node (no children, zero begin/count, no parent, and a newly
+   * allocated empty dataset) and should not be used as a tree any more.
+   *
+   * If other had a parent, this node keeps that parent pointer, but the
+   * parent's list of children is not updated to point at this node.
+   *
+   * @param other Tree to move.
+   */
+  Octree(Octree&& other);
+
+  /**
+   * Initialize the tree from a boost::serialization archive.  The node is first
+   * default-constructed as an empty node and then loaded from the archive entry
+   * named "tree".
+   *
+   * @param ar Archive to load tree from.  Must be an iarchive, not an oarchive.
+   */
+  template<typename Archive>
+  Octree(
+      Archive& ar,
+      const typename boost::enable_if<typename Archive::is_loading>::type* = 0);
+
+  /**
    * Destroy the tree.
    */
   ~Octree();
@@ -200,6 +225,11 @@ class Octree
   //! Return the dataset used by this node.
   const MatType& Dataset() const { return *dataset; }
 
+  //! Return the parent of this node (NULL for the root of a tree).
+  Octree* Parent() const { return this->parent; }
+  //! Return a modifiable reference to the parent pointer (be careful!).  Changing
+  //! it does not update the children list of either the old or the new parent.
+  Octree*& Parent() { return this->parent; }
+
   //! Return the bound object for this node.
   const bound::HRectBound<MetricType>& Bound() const { return bound; }
   //! Modify the bound object for this node.
@@ -334,6 +364,22 @@ class Octree
   //! Store the center of the bounding region in the given vector.
   void Center(arma::vec& center) const { bound.Center(center); }
 
+  /**
+   * Save the tree to, or load it from, a boost::serialization archive.  When
+   * loading, any existing children of this node are deleted first, as is the
+   * dataset if this node has no parent.
+   *
+   * @param ar Archive to serialize to or from.
+   */
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */);
+
+ protected:
+  /**
+   * A default constructor.  This is meant to only be used with
+   * boost::serialization, which is allowed with the friend declaration below.
+   * It creates an empty node (begin and count of zero, a bound constructed with
+   * dimension 0, a new empty dataset, and no parent); this is not a valid tree!
+   * The method must be protected so that the serialization shim can work with
+   * the default constructor.
+   */
+  Octree();
+
+  //! Friend access is given for the default constructor.
+  friend class boost::serialization::access;
+
  private:
   /**
    * Split the node, using the given center and the given maximum width of this
diff --git a/src/mlpack/core/tree/octree/octree_impl.hpp b/src/mlpack/core/tree/octree/octree_impl.hpp
index 33ec3e2..987d139 100644
--- a/src/mlpack/core/tree/octree/octree_impl.hpp
+++ b/src/mlpack/core/tree/octree/octree_impl.hpp
@@ -292,6 +292,8 @@ Octree<MetricType, StatisticType, MatType>::Octree(
   parent->Bound().Center(parentCenter);
   parentDistance = metric.Evaluate(trueCenter, parentCenter);
 
+  furthestDescendantDistance = 0.5 * bound.Diameter();
+
   // Initialize the statistic.
   stat = StatisticType(*this);
 }
@@ -325,10 +327,85 @@ Octree<MetricType, StatisticType, MatType>::Octree(
   parent->Bound().Center(parentCenter);
   parentDistance = metric.Evaluate(trueCenter, parentCenter);
 
+  furthestDescendantDistance = 0.5 * bound.Diameter();
+
   // Initialize the statistic.
   stat = StatisticType(*this);
 }
 
+//! Copy the given tree.
+template<typename MetricType, typename StatisticType, typename MatType>
+Octree<MetricType, StatisticType, MatType>::Octree(const Octree& other) :
+    begin(other.begin),
+    count(other.count),
+    bound(other.bound),
+    dataset((other.parent == NULL) ? new MatType(*other.dataset) : NULL),
+    parent(NULL),
+    stat(other.stat),
+    parentDistance(other.parentDistance),
+    furthestDescendantDistance(other.furthestDescendantDistance),
+    metric(other.metric)
+{
+  // The dataset is only duplicated when other is a root node (see the
+  // initializer list).  A copy of a non-root node starts with a NULL dataset and
+  // no parent; when that copy is one of our children, we fix both below.
+
+  // Recursively copy the children and link them back to this node.
+  for (size_t i = 0; i < other.NumChildren(); ++i)
+  {
+    children.push_back(new Octree(other.Child(i)));
+    children[i]->parent = this;
+  }
+
+  // If this is the root of the copy, point every descendant at the new
+  // dataset.  Fixing only the direct children is not enough: each child's own
+  // copy constructor ran while that child's dataset was still NULL, so it
+  // handed NULL down to grandchildren and deeper nodes.
+  if (other.parent == NULL)
+  {
+    std::vector<Octree*> stack(children.begin(), children.end());
+    while (!stack.empty())
+    {
+      Octree* node = stack.back();
+      stack.pop_back();
+      node->dataset = dataset;
+      stack.insert(stack.end(), node->children.begin(), node->children.end());
+    }
+  }
+}
+
+//! Move the given tree, leaving the source as an empty, parentless node.
+template<typename MetricType, typename StatisticType, typename MatType>
+Octree<MetricType, StatisticType, MatType>::Octree(Octree&& other) :
+    children(std::move(other.children)),
+    begin(other.begin),
+    count(other.count),
+    bound(std::move(other.bound)),
+    dataset(other.dataset),
+    parent(other.parent),
+    stat(std::move(other.stat)),
+    parentDistance(other.parentDistance),
+    furthestDescendantDistance(other.furthestDescendantDistance),
+    metric(std::move(other.metric))
+{
+  // The children were taken over wholesale, so they must now point at us.
+  for (Octree* child : children)
+    child->parent = this;
+
+  // Empty out the source.  It receives a fresh empty dataset before its parent
+  // link is cleared.
+  other.begin = 0;
+  other.count = 0;
+  other.dataset = new MatType();
+  other.parent = NULL;
+  other.furthestDescendantDistance = 0.0;
+  other.parentDistance = 0.0;
+}
+
+//! Build an empty, parentless node; only meaningful as a deserialization target.
+template<typename MetricType, typename StatisticType, typename MatType>
+Octree<MetricType, StatisticType, MatType>::Octree() :
+    bound(0),
+    dataset(new MatType()),
+    parent(NULL)
+{
+  // No points and no distances yet; stat and metric are default-constructed.
+  begin = count = 0;
+  parentDistance = furthestDescendantDistance = 0.0;
+}
+
+//! Load a tree from a boost::serialization input archive.
+template<typename MetricType, typename StatisticType, typename MatType>
+template<typename Archive>
+Octree<MetricType, StatisticType, MatType>::Octree(
+    Archive& ar,
+    const typename boost::enable_if<typename Archive::is_loading>::type*) :
+    Octree() // Start from an empty node, then fill it in from the archive.
+{
+  using data::CreateNVP;
+  ar >> CreateNVP(*this, "tree");
+}
+
 template<typename MetricType, typename StatisticType, typename MatType>
 Octree<MetricType, StatisticType, MatType>::~Octree()
 {
@@ -542,6 +619,85 @@ Octree<MetricType, StatisticType, MatType>::RangeDistance(
   return bound.RangeDistance(point);
 }
 
+//! Serialize the tree.
+template<typename MetricType, typename StatisticType, typename MatType>
+template<typename Archive>
+void Octree<MetricType, StatisticType, MatType>::Serialize(
+    Archive& ar,
+    const unsigned int /* version */)
+{
+  using data::CreateNVP;
+
+  // If we're loading, any existing children must be deleted.  A node without a
+  // parent also releases its dataset; the pointer is cleared so that it never
+  // dangles while the rest of the tree is being loaded.
+  if (Archive::is_loading::value)
+  {
+    for (size_t i = 0; i < children.size(); ++i)
+      delete children[i];
+    children.clear();
+
+    if (!parent)
+    {
+      delete dataset;
+      dataset = NULL;
+    }
+  }
+
+  ar & CreateNVP(begin, "begin");
+  ar & CreateNVP(count, "count");
+  ar & CreateNVP(bound, "bound");
+  ar & CreateNVP(stat, "stat");
+  ar & CreateNVP(parentDistance, "parentDistance");
+  ar & CreateNVP(furthestDescendantDistance, "furthestDescendantDistance");
+  ar & CreateNVP(metric, "metric");
+
+  // Due to quirks of boost::serialization, depending on how the user
+  // serializes the tree, it's possible that the root of the tree will
+  // accidentally be serialized twice.  So if we are a first-level child, we
+  // avoid serializing the parent.  The true (non-duplicated) parent will fix
+  // the parent link.
+  bool hasFakeParent = false;
+  if (Archive::is_saving::value && parent != NULL && parent->parent == NULL)
+  {
+    Octree* fakeParent = NULL;
+    hasFakeParent = true;
+    ar & CreateNVP(fakeParent, "parent");
+    ar & CreateNVP(hasFakeParent, "hasFakeParent");
+  }
+  else
+  {
+    ar & CreateNVP(parent, "parent");
+    ar & CreateNVP(hasFakeParent, "hasFakeParent");
+  }
+
+  // Only serialize the dataset if we don't have a fake parent.  Otherwise, the
+  // real parent will come and set it later.
+  if (!hasFakeParent)
+    ar & CreateNVP(dataset, "dataset");
+
+  size_t numChildren = 0;
+  if (Archive::is_saving::value)
+    numChildren = children.size();
+  ar & CreateNVP(numChildren, "numChildren");
+  if (Archive::is_loading::value)
+    children.resize(numChildren);
+
+  for (size_t i = 0; i < numChildren; ++i)
+  {
+    std::ostringstream oss;
+    oss << "child" << i;
+    ar & CreateNVP(children[i], oss.str());
+  }
+
+  if (Archive::is_loading::value)
+  {
+    // Every loaded child belongs to this node.  This is required for the
+    // children of the root, whose parent was written as a fake NULL.  For all
+    // other children, the loaded parent pointer should already be this node.
+    for (size_t i = 0; i < children.size(); ++i)
+      children[i]->parent = this;
+
+    // Only the true root (no parent and no fake parent) is guaranteed to hold
+    // a valid dataset at this point, so it pushes that dataset down to every
+    // descendant.  A first-level child (which has a fake parent) also has
+    // parent == NULL here.  Its own dataset is not valid yet, so it must not
+    // hand that pointer down to its children.
+    if (parent == NULL && !hasFakeParent)
+    {
+      std::vector<Octree*> stack(children.begin(), children.end());
+      while (!stack.empty())
+      {
+        Octree* node = stack.back();
+        stack.pop_back();
+        node->dataset = dataset;
+        stack.insert(stack.end(), node->children.begin(),
+            node->children.end());
+      }
+    }
+  }
+}
+
 //! Split the node.
 template<typename MetricType, typename StatisticType, typename MatType>
 void Octree<MetricType, StatisticType, MatType>::SplitNode(
diff --git a/src/mlpack/tests/octree_test.cpp b/src/mlpack/tests/octree_test.cpp
index b2647b3..f237016 100644
--- a/src/mlpack/tests/octree_test.cpp
+++ b/src/mlpack/tests/octree_test.cpp
@@ -9,6 +9,7 @@
 
 #include <boost/test/unit_test.hpp>
 #include "test_tools.hpp"
+#include "serialization.hpp"
 
 using namespace mlpack;
 using namespace mlpack::math;
@@ -145,4 +146,197 @@ BOOST_AUTO_TEST_CASE(ReverseMappingsTest)
   }
 }
 
+/**
+ * Recursively verify that no two sibling nodes have overlapping bounds.
+ */
+template<typename TreeType>
+void CheckOverlap(TreeType& node)
+{
+  const size_t numChildren = node.NumChildren();
+
+  // Compare every unordered pair of siblings; we need exact equality here.
+  for (size_t a = 0; a + 1 < numChildren; ++a)
+  {
+    for (size_t b = a + 1; b < numChildren; ++b)
+    {
+      BOOST_REQUIRE_EQUAL(
+          node.Child(a).Bound().Overlap(node.Child(b).Bound()), 0.0);
+    }
+  }
+
+  // Then check each subtree the same way.
+  for (size_t c = 0; c < numChildren; ++c)
+    CheckOverlap(node.Child(c));
+}
+
+/**
+ * Build one tree with the copying constructor and one with the moving
+ * constructor, and make sure neither has overlapping siblings.
+ */
+BOOST_AUTO_TEST_CASE(OverlapTest)
+{
+  arma::mat points(3, 300, arma::fill::randu);
+
+  Octree<> copiedTree(points);           // Copies the data.
+  Octree<> movedTree(std::move(points)); // Takes the data.
+
+  CheckOverlap(copiedTree);
+  CheckOverlap(movedTree);
+}
+
+/**
+ * Recursively verify that every point held directly by a node lies within the
+ * node's furthest point distance of its center, and that every descendant lies
+ * within its furthest descendant distance.
+ */
+template<typename TreeType>
+void CheckFurthestDistances(TreeType& node)
+{
+  arma::vec center;
+  node.Center(center);
+
+  const auto& data = node.Dataset();
+
+  // Points held directly in this node.  The (1 + 1e-5) factor absorbs
+  // floating-point inaccuracies.
+  for (size_t p = 0; p < node.NumPoints(); ++p)
+  {
+    const double dist = metric::EuclideanDistance::Evaluate(
+        data.col(node.Point(p)), center);
+    BOOST_REQUIRE_LE(dist, node.FurthestPointDistance() * (1 + 1e-5));
+  }
+
+  // Every descendant of this node, with the same tolerance.
+  for (size_t d = 0; d < node.NumDescendants(); ++d)
+  {
+    const double dist = metric::EuclideanDistance::Evaluate(
+        data.col(node.Descendant(d)), center);
+    BOOST_REQUIRE_LE(dist, node.FurthestDescendantDistance() * (1 + 1e-5));
+  }
+
+  for (size_t c = 0; c < node.NumChildren(); ++c)
+    CheckFurthestDistances(node.Child(c));
+}
+
+/**
+ * Check the furthest point/descendant distances on trees built with both the
+ * copying and the moving constructor.
+ */
+BOOST_AUTO_TEST_CASE(FurthestDistanceTest)
+{
+  arma::mat points(3, 500, arma::fill::randu);
+
+  Octree<> copiedTree(points);           // Copies the data.
+  Octree<> movedTree(std::move(points)); // Takes the data.
+
+  CheckFurthestDistances(copiedTree);
+  CheckFurthestDistances(movedTree);
+}
+
+/**
+ * A node in d dimensions can have at most 2^d children.  Recursively check
+ * that no node in the tree exceeds that limit.
+ */
+template<typename TreeType>
+void CheckNumChildren(TreeType& node)
+{
+  const double maxChildren = std::pow(2, node.Dataset().n_rows);
+  BOOST_REQUIRE_LE(node.NumChildren(), maxChildren);
+
+  for (size_t child = 0; child < node.NumChildren(); ++child)
+    CheckNumChildren(node.Child(child));
+}
+
+/**
+ * Build trees in 1 through 9 dimensions and check the child-count limit.
+ */
+BOOST_AUTO_TEST_CASE(MaxNumChildrenTest)
+{
+  for (size_t dim = 1; dim <= 9; ++dim)
+  {
+    arma::mat points(dim, 1000 * dim, arma::fill::randu);
+    Octree<> tree(std::move(points));
+    CheckNumChildren(tree);
+  }
+}
+
+/**
+ * Recursively check that two trees have the same structure and contents, but
+ * do not share any nodes or datasets.  Used to validate copying, moving, and
+ * serialization of the whole tree, not just its root.
+ */
+template<typename TreeType>
+void CheckSameNode(TreeType& node1, TreeType& node2)
+{
+  BOOST_REQUIRE_EQUAL(node1.NumChildren(), node2.NumChildren());
+
+  // The datasets must be distinct objects of the same size.  Reading the sizes
+  // dereferences each node's dataset pointer, so a NULL or dangling dataset in
+  // any node is likely to make the test fail rather than pass silently.
+  BOOST_REQUIRE_NE(&node1.Dataset(), &node2.Dataset());
+  BOOST_REQUIRE_EQUAL(node1.Dataset().n_rows, node2.Dataset().n_rows);
+  BOOST_REQUIRE_EQUAL(node1.Dataset().n_cols, node2.Dataset().n_cols);
+
+  // Make sure the children actually got copied.
+  for (size_t i = 0; i < node1.NumChildren(); ++i)
+    BOOST_REQUIRE_NE(&node1.Child(i), &node2.Child(i));
+
+  // Check that all the points are the same.
+  BOOST_REQUIRE_EQUAL(node1.NumPoints(), node2.NumPoints());
+  BOOST_REQUIRE_EQUAL(node1.NumDescendants(), node2.NumDescendants());
+  for (size_t i = 0; i < node1.NumPoints(); ++i)
+    BOOST_REQUIRE_EQUAL(node1.Point(i), node2.Point(i));
+  for (size_t i = 0; i < node1.NumDescendants(); ++i)
+    BOOST_REQUIRE_EQUAL(node1.Descendant(i), node2.Descendant(i));
+
+  // Check that the bound is the same.
+  BOOST_REQUIRE_EQUAL(node1.Bound().Dim(), node2.Bound().Dim());
+  for (size_t d = 0; d < node1.Bound().Dim(); ++d)
+  {
+    BOOST_REQUIRE_CLOSE(node1.Bound()[d].Lo(), node2.Bound()[d].Lo(), 1e-5);
+    BOOST_REQUIRE_CLOSE(node1.Bound()[d].Hi(), node2.Bound()[d].Hi(), 1e-5);
+  }
+
+  // Check that the furthest point and descendant distance are the same.
+  BOOST_REQUIRE_CLOSE(node1.FurthestPointDistance(),
+      node2.FurthestPointDistance(), 1e-5);
+  BOOST_REQUIRE_CLOSE(node1.FurthestDescendantDistance(),
+      node2.FurthestDescendantDistance(), 1e-5);
+
+  // Recurse so that the entire tree is compared.
+  for (size_t i = 0; i < node1.NumChildren(); ++i)
+    CheckSameNode(node1.Child(i), node2.Child(i));
+}
+
+/**
+ * Test the copy constructor.
+ */
+BOOST_AUTO_TEST_CASE(CopyConstructorTest)
+{
+  // Use a random dataset that is large enough for the tree to have more than
+  // one level below the root (assuming the default maximum leaf size of 20).
+  // This way the copy of grandchildren and deeper nodes is exercised too.
+  arma::mat dataset(3, 1000, arma::fill::randu);
+
+  Octree<> t(dataset);
+  Octree<> t2(t);
+
+  CheckSameNode(t, t2);
+}
+
+/**
+ * Test the move constructor.  The moved-from tree must be left empty, and the
+ * moved-to tree must match a copy taken before the move.
+ */
+BOOST_AUTO_TEST_CASE(MoveConstructorTest)
+{
+  arma::mat points(3, 100, arma::fill::randu);
+
+  Octree<> source(std::move(points));
+  Octree<> snapshot(source); // Deep copy, for comparison after the move.
+
+  Octree<> destination(std::move(source));
+
+  // Everything should have been taken out of the moved-from tree.
+  const auto& leftover = source.Dataset();
+  BOOST_REQUIRE_EQUAL(leftover.n_rows, 0);
+  BOOST_REQUIRE_EQUAL(leftover.n_cols, 0);
+  BOOST_REQUIRE_EQUAL(source.NumChildren(), 0);
+  BOOST_REQUIRE_EQUAL(source.NumPoints(), 0);
+  BOOST_REQUIRE_EQUAL(source.NumDescendants(), 0);
+  BOOST_REQUIRE_SMALL(source.FurthestPointDistance(), 1e-5);
+  BOOST_REQUIRE_SMALL(source.FurthestDescendantDistance(), 1e-5);
+  BOOST_REQUIRE_EQUAL(source.Bound().Dim(), 0);
+
+  // The moved-to tree should be indistinguishable from the snapshot.
+  CheckSameNode(snapshot, destination);
+}
+
+/**
+ * Round-trip a tree through XML, binary, and text archives, and check that
+ * each loaded tree matches the original.
+ */
+BOOST_AUTO_TEST_CASE(SerializationTest)
+{
+  arma::mat points(3, 500, arma::fill::randu);
+  Octree<> tree(std::move(points));
+
+  Octree<>* xmlTree = NULL;
+  Octree<>* binaryTree = NULL;
+  Octree<>* textTree = NULL;
+
+  SerializePointerObjectAll(&tree, xmlTree, binaryTree, textTree);
+
+  // Compare every loaded tree first, then release them all.
+  Octree<>* loaded[] = { xmlTree, binaryTree, textTree };
+  for (Octree<>* other : loaded)
+    CheckSameNode(tree, *other);
+  for (Octree<>* other : loaded)
+    delete other;
+}
+
 BOOST_AUTO_TEST_SUITE_END();




More information about the mlpack-git mailing list