[mlpack-git] master: Add copy, move, and serialization, plus tests for the octree. (b305938)
gitdub at mlpack.org
gitdub at mlpack.org
Fri Sep 23 17:21:46 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/9ef7339d40550a974b3939e9fcb966fac2c09065...ebdb5abeaa3fd621a06ae663862bb72df76d2b40
>---------------------------------------------------------------
commit b305938d5887c0b8f6bcb246f789bd34725cf4a4
Author: Ryan Curtin <ryan at ratml.org>
Date: Fri Sep 23 17:21:46 2016 -0400
Add copy, move, and serialization, plus tests for the octree.
>---------------------------------------------------------------
b305938d5887c0b8f6bcb246f789bd34725cf4a4
src/mlpack/core/tree/octree/octree.hpp | 46 +++++++
src/mlpack/core/tree/octree/octree_impl.hpp | 156 ++++++++++++++++++++++
src/mlpack/tests/octree_test.cpp | 194 ++++++++++++++++++++++++++++
3 files changed, 396 insertions(+)
diff --git a/src/mlpack/core/tree/octree/octree.hpp b/src/mlpack/core/tree/octree/octree.hpp
index 3fec915..f7fa385 100644
--- a/src/mlpack/core/tree/octree/octree.hpp
+++ b/src/mlpack/core/tree/octree/octree.hpp
@@ -193,6 +193,31 @@ class Octree
const size_t maxLeafSize = 20);
/**
+ * Copy the given tree. Be careful! This may use a lot of memory.
+ *
+ * @param other Tree to copy from.
+ */
+ Octree(const Octree& other);
+
+ /**
+ * Move the given tree. The tree passed as a parameter will be emptied and
+ * will not be usable after this call.
+ *
+ * @param other Tree to move.
+ */
+ Octree(Octree&& other);
+
+ /**
+ * Initialize the tree from a boost::serialization archive.
+ *
+ * @param ar Archive to load tree from. Must be an iarchive, not an oarchive.
+ */
+ template<typename Archive>
+ Octree(
+ Archive& ar,
+ const typename boost::enable_if<typename Archive::is_loading>::type* = 0);
+
+ /**
* Destroy the tree.
*/
~Octree();
@@ -200,6 +225,11 @@ class Octree
//! Return the dataset used by this node.
const MatType& Dataset() const { return *dataset; }
+ //! Get the pointer to the parent.
+ Octree* Parent() const { return parent; }
+ //! Modify the pointer to the parent (be careful!).
+ Octree*& Parent() { return parent; }
+
//! Return the bound object for this node.
const bound::HRectBound<MetricType>& Bound() const { return bound; }
//! Modify the bound object for this node.
@@ -334,6 +364,22 @@ class Octree
//! Store the center of the bounding region in the given vector.
void Center(arma::vec& center) const { bound.Center(center); }
+ //! Serialize the tree.
+ template<typename Archive>
+ void Serialize(Archive& ar, const unsigned int /* version */);
+
+ protected:
+ /**
+ * A default constructor. This is meant to only be used with
+ * boost::serialization, which is allowed with the friend declaration below.
+ * This does not return a valid tree! The method must be protected, so that
+ * the serialization shim can work with the default constructor.
+ */
+ Octree();
+
+ //! Friend access is given for the default constructor.
+ friend class boost::serialization::access;
+
private:
/**
* Split the node, using the given center and the given maximum width of this
diff --git a/src/mlpack/core/tree/octree/octree_impl.hpp b/src/mlpack/core/tree/octree/octree_impl.hpp
index 33ec3e2..987d139 100644
--- a/src/mlpack/core/tree/octree/octree_impl.hpp
+++ b/src/mlpack/core/tree/octree/octree_impl.hpp
@@ -292,6 +292,8 @@ Octree<MetricType, StatisticType, MatType>::Octree(
parent->Bound().Center(parentCenter);
parentDistance = metric.Evaluate(trueCenter, parentCenter);
+ furthestDescendantDistance = 0.5 * bound.Diameter();
+
// Initialize the statistic.
stat = StatisticType(*this);
}
@@ -325,10 +327,85 @@ Octree<MetricType, StatisticType, MatType>::Octree(
parent->Bound().Center(parentCenter);
parentDistance = metric.Evaluate(trueCenter, parentCenter);
+ furthestDescendantDistance = 0.5 * bound.Diameter();
+
// Initialize the statistic.
stat = StatisticType(*this);
}
+//! Copy the given tree.
+template<typename MetricType, typename StatisticType, typename MatType>
+Octree<MetricType, StatisticType, MatType>::Octree(const Octree& other) :
+ begin(other.begin),
+ count(other.count),
+ bound(other.bound),
+ dataset((other.parent == NULL) ? new MatType(*other.dataset) : NULL),
+ parent(NULL),
+ stat(other.stat),
+ parentDistance(other.parentDistance),
+ furthestDescendantDistance(other.furthestDescendantDistance),
+ metric(other.metric)
+{
+ // If we have any children, we need to create them, and then ensure that their
+ // parent links are set right.
+ for (size_t i = 0; i < other.NumChildren(); ++i)
+ {
+ children.push_back(new Octree(other.Child(i)));
+ children[i]->parent = this;
+ children[i]->dataset = this->dataset;
+ }
+}
+
+//! Move the given tree.
+template<typename MetricType, typename StatisticType, typename MatType>
+Octree<MetricType, StatisticType, MatType>::Octree(Octree&& other) :
+ children(std::move(other.children)),
+ begin(other.begin),
+ count(other.count),
+ bound(std::move(other.bound)),
+ dataset(other.dataset),
+ parent(other.parent),
+ stat(std::move(other.stat)),
+ parentDistance(other.parentDistance),
+ furthestDescendantDistance(other.furthestDescendantDistance),
+ metric(std::move(other.metric))
+{
+ // Update the parent pointers of the direct children.
+ for (size_t i = 0; i < children.size(); ++i)
+ children[i]->parent = this;
+
+ other.begin = 0;
+ other.count = 0;
+ other.dataset = new MatType();
+ other.parentDistance = 0.0;
+ other.furthestDescendantDistance = 0.0;
+ other.parent = NULL;
+}
+
+template<typename MetricType, typename StatisticType, typename MatType>
+Octree<MetricType, StatisticType, MatType>::Octree() :
+ begin(0),
+ count(0),
+ bound(0),
+ dataset(new MatType()),
+ parent(NULL),
+ parentDistance(0.0),
+ furthestDescendantDistance(0.0)
+{
+ // Nothing to do.
+}
+
+template<typename MetricType, typename StatisticType, typename MatType>
+template<typename Archive>
+Octree<MetricType, StatisticType, MatType>::Octree(
+ Archive& ar,
+ const typename boost::enable_if<typename Archive::is_loading>::type*) :
+ Octree() // Create an empty tree.
+{
+ // De-serialize the tree into this object.
+ ar >> data::CreateNVP(*this, "tree");
+}
+
template<typename MetricType, typename StatisticType, typename MatType>
Octree<MetricType, StatisticType, MatType>::~Octree()
{
@@ -542,6 +619,85 @@ Octree<MetricType, StatisticType, MatType>::RangeDistance(
return bound.RangeDistance(point);
}
+//! Serialize the tree.
+template<typename MetricType, typename StatisticType, typename MatType>
+template<typename Archive>
+void Octree<MetricType, StatisticType, MatType>::Serialize(
+ Archive& ar,
+ const unsigned int /* version */)
+{
+ using data::CreateNVP;
+
+ //
+
+ // If we're loading and we have children, they need to be deleted.
+ if (Archive::is_loading::value)
+ {
+ for (size_t i = 0; i < children.size(); ++i)
+ delete children[i];
+ children.clear();
+
+ if (!parent)
+ delete dataset;
+ }
+
+ ar & CreateNVP(begin, "begin");
+ ar & CreateNVP(count, "count");
+ ar & CreateNVP(bound, "bound");
+ ar & CreateNVP(stat, "stat");
+ ar & CreateNVP(parentDistance, "parentDistance");
+ ar & CreateNVP(furthestDescendantDistance, "furthestDescendantDistance");
+ ar & CreateNVP(metric, "metric");
+
+ // Due to quirks of boost::serialization, depending on how the user
+ // serializes the tree, it's possible that the root of the tree will
+ // accidentally be serialized twice. So if we are a first-level child, we
+ // avoid serializing the parent. The true (non-duplicated) parent will fix
+ // the parent link.
+ bool hasFakeParent = false;
+ if (Archive::is_saving::value && parent != NULL && parent->parent == NULL)
+ {
+ Octree* fakeParent = NULL;
+ hasFakeParent = true;
+ ar & CreateNVP(fakeParent, "parent");
+ ar & CreateNVP(hasFakeParent, "hasFakeParent");
+ }
+ else
+ {
+ ar & CreateNVP(parent, "parent");
+ ar & CreateNVP(hasFakeParent, "hasFakeParent");
+ }
+
+ // Only serialize the dataset if we don't have a fake parent. Otherwise, the
+ // real parent will come and set it later.
+ if (!hasFakeParent)
+ ar & CreateNVP(dataset, "dataset");
+
+ size_t numChildren = 0;
+ if (Archive::is_saving::value)
+ numChildren = children.size();
+ ar & CreateNVP(numChildren, "numChildren");
+ if (Archive::is_loading::value)
+ children.resize(numChildren);
+
+ for (size_t i = 0; i < numChildren; ++i)
+ {
+ std::ostringstream oss;
+ oss << "child" << i;
+ ar & CreateNVP(children[i], oss.str());
+ }
+
+ // Fix the child pointers, if they were set to a fake parent.
+ if (Archive::is_loading::value && parent == NULL)
+ {
+ for (size_t i = 0; i < children.size(); ++i)
+ {
+ children[i]->dataset = this->dataset;
+ children[i]->parent = this;
+ }
+ }
+}
+
//! Split the node.
template<typename MetricType, typename StatisticType, typename MatType>
void Octree<MetricType, StatisticType, MatType>::SplitNode(
diff --git a/src/mlpack/tests/octree_test.cpp b/src/mlpack/tests/octree_test.cpp
index b2647b3..f237016 100644
--- a/src/mlpack/tests/octree_test.cpp
+++ b/src/mlpack/tests/octree_test.cpp
@@ -9,6 +9,7 @@
#include <boost/test/unit_test.hpp>
#include "test_tools.hpp"
+#include "serialization.hpp"
using namespace mlpack;
using namespace mlpack::math;
@@ -145,4 +146,197 @@ BOOST_AUTO_TEST_CASE(ReverseMappingsTest)
}
}
+/**
+ * Make sure no children at the same level are overlapping.
+ */
+template<typename TreeType>
+void CheckOverlap(TreeType& node)
+{
+ // Check each combination of children.
+ for (size_t i = 0; i < node.NumChildren(); ++i)
+ for (size_t j = i + 1; j < node.NumChildren(); ++j)
+ BOOST_REQUIRE_EQUAL(node.Child(i).Bound().Overlap(node.Child(j).Bound()),
+ 0.0); // We need exact equality here.
+
+ for (size_t i = 0; i < node.NumChildren(); ++i)
+ CheckOverlap(node.Child(i));
+}
+
+BOOST_AUTO_TEST_CASE(OverlapTest)
+{
+ // Test with both constructors.
+ arma::mat dataset(3, 300, arma::fill::randu);
+
+ Octree<> t1(dataset);
+ Octree<> t2(std::move(dataset));
+
+ CheckOverlap(t1);
+ CheckOverlap(t2);
+}
+
+/**
+ * Make sure no points are further than the furthest point distance, and that no
+ * descendants are further than the furthest descendant distance.
+ */
+template<typename TreeType>
+void CheckFurthestDistances(TreeType& node)
+{
+ arma::vec center;
+ node.Center(center);
+
+ // Compare points held in the node.
+ for (size_t i = 0; i < node.NumPoints(); ++i)
+ {
+ // Handle floating-point inaccuracies.
+ BOOST_REQUIRE_LE(metric::EuclideanDistance::Evaluate(node.Dataset().col(node.Point(i)),
+ center), node.FurthestPointDistance() * (1 + 1e-5));
+ }
+
+ // Compare descendants held in the node.
+ for (size_t i = 0; i < node.NumDescendants(); ++i)
+ {
+ // Handle floating-point inaccuracies.
+ BOOST_REQUIRE_LE(metric::EuclideanDistance::Evaluate(node.Dataset().col(node.Descendant(i)),
+ center), node.FurthestDescendantDistance() * (1 + 1e-5));
+ }
+
+ for (size_t i = 0; i < node.NumChildren(); ++i)
+ CheckFurthestDistances(node.Child(i));
+}
+
+BOOST_AUTO_TEST_CASE(FurthestDistanceTest)
+{
+ // Test with both constructors.
+ arma::mat dataset(3, 500, arma::fill::randu);
+
+ Octree<> t1(dataset);
+ Octree<> t2(std::move(dataset));
+
+ CheckFurthestDistances(t1);
+ CheckFurthestDistances(t2);
+}
+
+/**
+ * The maximum number of children a node can have is limited by the
+ * dimensionality. So we test to make sure there are no cases where we have too
+ * many children.
+ */
+template<typename TreeType>
+void CheckNumChildren(TreeType& node)
+{
+ BOOST_REQUIRE_LE(node.NumChildren(), std::pow(2, node.Dataset().n_rows));
+ for (size_t i = 0; i < node.NumChildren(); ++i)
+ CheckNumChildren(node.Child(i));
+}
+
+BOOST_AUTO_TEST_CASE(MaxNumChildrenTest)
+{
+ for (size_t d = 1; d < 10; ++d)
+ {
+ arma::mat dataset(d, 1000 * d, arma::fill::randu);
+ Octree<> t(std::move(dataset));
+
+ CheckNumChildren(t);
+ }
+}
+
+/**
+ * Test the copy constructor.
+ */
+template<typename TreeType>
+void CheckSameNode(TreeType& node1, TreeType& node2)
+{
+ BOOST_REQUIRE_EQUAL(node1.NumChildren(), node2.NumChildren());
+ BOOST_REQUIRE_NE(&node1.Dataset(), &node2.Dataset());
+
+ // Make sure the children actually got copied.
+ for (size_t i = 0; i < node1.NumChildren(); ++i)
+ BOOST_REQUIRE_NE(&node1.Child(i), &node2.Child(i));
+
+ // Check that all the points are the same.
+ BOOST_REQUIRE_EQUAL(node1.NumPoints(), node2.NumPoints());
+ BOOST_REQUIRE_EQUAL(node1.NumDescendants(), node2.NumDescendants());
+ for (size_t i = 0; i < node1.NumPoints(); ++i)
+ BOOST_REQUIRE_EQUAL(node1.Point(i), node2.Point(i));
+ for (size_t i = 0; i < node1.NumDescendants(); ++i)
+ BOOST_REQUIRE_EQUAL(node1.Descendant(i), node2.Descendant(i));
+
+ // Check that the bound is the same.
+ BOOST_REQUIRE_EQUAL(node1.Bound().Dim(), node2.Bound().Dim());
+ for (size_t d = 0; d < node1.Bound().Dim(); ++d)
+ {
+ BOOST_REQUIRE_CLOSE(node1.Bound()[d].Lo(), node2.Bound()[d].Lo(), 1e-5);
+ BOOST_REQUIRE_CLOSE(node1.Bound()[d].Hi(), node2.Bound()[d].Hi(), 1e-5);
+ }
+
+ // Check that the furthest point and descendant distance are the same.
+ BOOST_REQUIRE_CLOSE(node1.FurthestPointDistance(),
+ node2.FurthestPointDistance(), 1e-5);
+ BOOST_REQUIRE_CLOSE(node1.FurthestDescendantDistance(),
+ node2.FurthestDescendantDistance(), 1e-5);
+}
+
+BOOST_AUTO_TEST_CASE(CopyConstructorTest)
+{
+ // Use a small random dataset.
+ arma::mat dataset(3, 100, arma::fill::randu);
+
+ Octree<> t(dataset);
+ Octree<> t2(t);
+
+ CheckSameNode(t, t2);
+}
+
+/**
+ * Test the move constructor.
+ */
+BOOST_AUTO_TEST_CASE(MoveConstructorTest)
+{
+ // Use a small random dataset.
+ arma::mat dataset(3, 100, arma::fill::randu);
+
+ Octree<> t(std::move(dataset));
+ Octree<> tcopy(t);
+
+ // Move the tree.
+ Octree<> t2(std::move(t));
+
+ // Make sure the original tree has no data.
+ BOOST_REQUIRE_EQUAL(t.Dataset().n_rows, 0);
+ BOOST_REQUIRE_EQUAL(t.Dataset().n_cols, 0);
+ BOOST_REQUIRE_EQUAL(t.NumChildren(), 0);
+ BOOST_REQUIRE_EQUAL(t.NumPoints(), 0);
+ BOOST_REQUIRE_EQUAL(t.NumDescendants(), 0);
+ BOOST_REQUIRE_SMALL(t.FurthestPointDistance(), 1e-5);
+ BOOST_REQUIRE_SMALL(t.FurthestDescendantDistance(), 1e-5);
+ BOOST_REQUIRE_EQUAL(t.Bound().Dim(), 0);
+
+ // Check that the new tree is the same as our copy.
+ CheckSameNode(tcopy, t2);
+}
+
+/**
+ * Test serialization.
+ */
+BOOST_AUTO_TEST_CASE(SerializationTest)
+{
+ // Use a small random dataset.
+ arma::mat dataset(3, 500, arma::fill::randu);
+ Octree<> t(std::move(dataset));
+
+ Octree<>* xmlTree;
+ Octree<>* binaryTree;
+ Octree<>* textTree;
+
+ SerializePointerObjectAll(&t, xmlTree, binaryTree, textTree);
+
+ CheckSameNode(t, *xmlTree);
+ CheckSameNode(t, *binaryTree);
+ CheckSameNode(t, *textTree);
+
+ delete xmlTree;
+ delete binaryTree;
+ delete textTree;
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git
mailing list