[mlpack-git] master: Add BinarySpaceTree::Serialize() and tests. Also add a constructor to allow constructing from an iarchive. (763f377)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Fri Jul 10 19:00:22 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/4a97187bbba7ce8a6191b714949dd818ef0f37d2...e5905e62c15d1bcff21e6359b11efcd7ab6d7ca0
>---------------------------------------------------------------
commit 763f3774b1da47d4a41c62e05064064b54d74e16
Author: ryan <ryan at ratml.org>
Date: Tue Apr 28 12:13:25 2015 -0400
Add BinarySpaceTree::Serialize() and tests.
Also add a constructor to allow constructing from an iarchive.
>---------------------------------------------------------------
763f3774b1da47d4a41c62e05064064b54d74e16
.../tree/binary_space_tree/binary_space_tree.hpp | 39 ++++-
.../binary_space_tree/binary_space_tree_impl.hpp | 147 +++++++++++++++++-
src/mlpack/tests/serialization_test.cpp | 166 ++++++++++++++++++++-
3 files changed, 347 insertions(+), 5 deletions(-)
diff --git a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
index d17532d..729143a 100644
--- a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
@@ -65,8 +65,6 @@ class BinarySpaceTree
//! The worst possible distance to the furthest descendant, cached to speed
//! things up.
double furthestDescendantDistance;
- //! The minimum distance from the center to any edge of the bound.
- double minimumBoundDistance;
//! The dataset. If we are the root of the tree, we own the dataset and must
//! delete it.
MatType* dataset;
@@ -210,9 +208,26 @@ class BinarySpaceTree
BinarySpaceTree(const BinarySpaceTree& other);
/**
+ * Initialize the tree from a boost::serialization archive.
+ *
+ * @param ar Archive to load tree from. Must be an iarchive, not an oarchive!
+ */
+ template<typename Archive>
+ BinarySpaceTree(Archive& ar,
+ const typename boost::enable_if<typename Archive::is_loading>::type* =
+0);
+
+ /**
* Deletes this node, deallocating the memory for the children and calling
* their destructors in turn. This will invalidate any pointers or references
- * to any nodes which are children of this one.
+ * to any nodes which are children of this one. Also, if this is the root of
+ * the tree (specifically, if Parent() == NULL), it will delete the dataset.
+ *
+ * So, if you want to preserve other parts of the tree, the easiest way is
+ * probably to set Left() or Right() to NULL and maintain those branches
+ * yourself. If you want to preserve the dataset, the easiest way is probably
+ * to set Parent() to something non-NULL before deleting the object. Either way,
+ * if you find yourself doing that, exercise care and caution.
*/
~BinarySpaceTree();
@@ -443,8 +458,26 @@ class BinarySpaceTree
void SplitNode(std::vector<size_t>& oldFromNew,
const size_t maxLeafSize);
+ protected:
+ /**
+ * A default constructor. This is meant to only be used with
+ * boost::serialization, which is allowed with the friend declaration below.
+ * This does not construct a valid tree! The constructor must be protected, so
+ * that the serialization shim can work with the default constructor.
+ */
+ BinarySpaceTree();
+
+ //! Friend access is given for the default constructor.
+ friend class boost::serialization::access;
+
public:
/**
+ * Serialize the tree.
+ */
+ template<typename Archive>
+ void Serialize(Archive& ar, const unsigned int version);
+
+ /**
* Returns a string representation of this object.
*/
std::string ToString() const;
diff --git a/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp b/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
index ebc4c23..863fa75 100644
--- a/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
@@ -12,6 +12,7 @@
#include <mlpack/core/util/cli.hpp>
#include <mlpack/core/util/log.hpp>
#include <mlpack/core/util/string_util.hpp>
+#include <queue>
namespace mlpack {
namespace tree {
@@ -212,7 +213,8 @@ BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
stat(other.stat),
parentDistance(other.parentDistance),
furthestDescendantDistance(other.furthestDescendantDistance),
- dataset(new MatType(*other.dataset)) // Copy matrix.
+ // Copy matrix, but only if we are the root.
+ dataset((other.parent == NULL) ? new MatType(*other.dataset) : NULL)
{
// Create left and right children (if any).
if (other.Left())
@@ -226,6 +228,50 @@ BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
right = new BinarySpaceTree(*other.Right());
right->Parent() = this; // Set parent to this, not other tree.
}
+
+ // Propagate matrix, but only if we are the root.
+ if (parent == NULL)
+ {
+ std::queue<BinarySpaceTree*> queue;
+ if (left)
+ queue.push(left);
+ if (right)
+ queue.push(right);
+ while (!queue.empty())
+ {
+ BinarySpaceTree* node = queue.front();
+ queue.pop();
+
+ node->dataset = dataset;
+ if (node->left)
+ queue.push(node->left);
+ if (node->right)
+ queue.push(node->right);
+ }
+ }
+}
+
+/**
+ * Initialize the tree by loading it from a boost::serialization input archive.
+ */
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
+template<typename Archive>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+ BinarySpaceTree(Archive& ar,
+ const typename boost::enable_if<typename Archive::is_loading>::type*) : BinarySpaceTree()
+{
+ if (!Archive::is_loading::value)
+ {
+ throw std::invalid_argument("Archive::is_loading is false; use an iarchive,"
+ " not an oarchive!");
+ }
+
+ // We have delegated to the default constructor, which gives us an empty tree;
+ // now load the archived tree from ar into this object.
+ ar >> data::CreateNVP(*this, "tree");
}
/**
@@ -620,6 +666,105 @@ void BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::SplitNode(
right->ParentDistance() = rightParentDistance;
}
+// Default constructor (protected), for boost::serialization and for the archive-loading constructor that delegates to it.
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+ BinarySpaceTree() :
+ left(NULL),
+ right(NULL),
+ parent(NULL),
+ begin(0),
+ count(0),
+ stat(*this),
+ parentDistance(0),
+ furthestDescendantDistance(0),
+ dataset(NULL)
+{
+ // Nothing to do: the initializer list above already describes an empty node.
+}
+
+/**
+ * Serialize the tree: save it to, or load it from, the given archive.
+ */
+template<typename BoundType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType>
+template<typename Archive>
+void BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::Serialize(
+ Archive& ar,
+ const unsigned int /* version */)
+{
+ using data::CreateNVP;
+
+ // If loading, delete existing children first (and the dataset, if we have no parent).
+ if (Archive::is_loading::value)
+ {
+ if (left)
+ delete left;
+ if (right)
+ delete right;
+ if (!parent)
+ delete dataset;
+ }
+
+ ar & CreateNVP(parent, "parent");
+ ar & CreateNVP(begin, "begin");
+ ar & CreateNVP(count, "count");
+ ar & CreateNVP(bound, "bound");
+ ar & CreateNVP(stat, "statistic");
+ ar & CreateNVP(parentDistance, "parentDistance");
+ ar & CreateNVP(furthestDescendantDistance, "furthestDescendantDistance");
+ ar & CreateNVP(dataset, "dataset");
+
+ // Save children last; otherwise boost::serialization gets confused.
+ ar & CreateNVP(left, "left");
+ ar & CreateNVP(right, "right");
+
+ // Due to quirks of boost::serialization, if a tree is saved as an object and
+ // not a pointer, the first level of the tree will be duplicated on load.
+ // Therefore, if we are the root of the tree, then we need to make sure our
+ // children's parent links are correct, and delete the duplicated node if
+ // necessary.
+ if (Archive::is_loading::value)
+ {
+ // Get parents of left and right children, or, NULL, if they don't exist.
+ BinarySpaceTree* leftParent = left ? left->Parent() : NULL;
+ BinarySpaceTree* rightParent = right ? right->Parent() : NULL;
+
+ // Reassign parent links if necessary.
+ if (left && left->Parent() != this)
+ left->Parent() = this;
+ if (right && right->Parent() != this)
+ right->Parent() = this;
+
+ // Do we need to delete the left parent?
+ if (leftParent != NULL && leftParent != this)
+ {
+ // Sever the duplicate parent's children. Ensure we don't delete the
+ // dataset, by faking the duplicated parent's parent (that is, we need to
+ // set the parent to something non-NULL; 'this' works).
+ leftParent->Parent() = this;
+ leftParent->Left() = NULL;
+ leftParent->Right() = NULL;
+ delete leftParent;
+ }
+
+ // Do we need to delete the right parent?
+ if (rightParent != NULL && rightParent != this && rightParent != leftParent)
+ {
+ // Sever the duplicate parent's children, in the same way as above.
+ rightParent->Parent() = this;
+ rightParent->Left() = NULL;
+ rightParent->Right() = NULL;
+ delete rightParent;
+ }
+ }
+}
+
/**
* Returns a string representation of this object.
*/
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index 23f07cb..96498ec 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -20,12 +20,14 @@
#include <mlpack/core/tree/ballbound.hpp>
#include <mlpack/core/tree/hrectbound.hpp>
#include <mlpack/core/metrics/mahalanobis_distance.hpp>
+#include <mlpack/core/tree/binary_space_tree.hpp>
using namespace mlpack;
using namespace mlpack::distribution;
using namespace mlpack::regression;
using namespace mlpack::bound;
using namespace mlpack::metric;
+using namespace mlpack::tree;
using namespace arma;
using namespace boost;
using namespace boost::archive;
@@ -212,9 +214,53 @@ void SerializeObject(T& t, T& newT)
template<typename T>
void SerializeObjectAll(T& t, T& xmlT, T& textT, T& binaryT)
{
- SerializeObject<T, xml_iarchive, xml_oarchive>(t, xmlT);
SerializeObject<T, text_iarchive, text_oarchive>(t, textT);
SerializeObject<T, binary_iarchive, binary_oarchive>(t, binaryT);
+ SerializeObject<T, xml_iarchive, xml_oarchive>(t, xmlT);
+}
+
+// Save *t, then load it back into newT via new T(i), i.e. T is built directly from the input archive rather than default-constructed.
+template<typename T, typename IArchiveType, typename OArchiveType>
+void SerializePointerObject(T* t, T*& newT)
+{
+ ofstream ofs("test");
+ OArchiveType o(ofs);
+
+ bool success = true;
+ try
+ {
+ o << data::CreateNVP(*t, "t");
+ }
+ catch (archive_exception& e)
+ {
+ success = false;
+ }
+ ofs.close();
+
+ BOOST_REQUIRE_EQUAL(success, true);
+
+ ifstream ifs("test");
+ IArchiveType i(ifs);
+
+ try
+ {
+ newT = new T(i);
+ }
+ catch (std::exception& e)
+ {
+ success = false;
+ }
+ ifs.close();
+
+ BOOST_REQUIRE_EQUAL(success, true);
+}
+
+template<typename T>
+void SerializePointerObjectAll(T* t, T*& xmlT, T*& textT, T*& binaryT)
+{
+ SerializePointerObject<T, text_iarchive, text_oarchive>(t, textT); // Same archive order as SerializeObjectAll: text, binary, then XML.
+ SerializePointerObject<T, binary_iarchive, binary_oarchive>(t, binaryT);
+ SerializePointerObject<T, xml_iarchive, xml_oarchive>(t, xmlT);
}
// Utility function to check the equality of two Armadillo matrices.
@@ -521,4 +567,122 @@ BOOST_AUTO_TEST_CASE(HRectBoundTest)
BOOST_REQUIRE_CLOSE(b.MinWidth(), binaryB.MinWidth(), 1e-8);
}
+template<typename TreeType>
+void CheckTrees(TreeType& tree,
+ TreeType& xmlTree,
+ TreeType& textTree,
+ TreeType& binaryTree)
+{
+ const typename TreeType::Mat* dataset = &tree.Dataset();
+
+ // At the root only, make sure that the data matrices are the same.
+ if (tree.Parent() == NULL)
+ {
+ CheckMatrices(*dataset,
+ xmlTree.Dataset(),
+ textTree.Dataset(),
+ binaryTree.Dataset());
+
+ // The loaded trees must be roots as well, so their parents must be NULL too.
+ BOOST_REQUIRE_EQUAL(xmlTree.Parent(), (TreeType*) NULL);
+ BOOST_REQUIRE_EQUAL(textTree.Parent(), (TreeType*) NULL);
+ BOOST_REQUIRE_EQUAL(binaryTree.Parent(), (TreeType*) NULL);
+ }
+
+ // Make sure the number of children is the same.
+ BOOST_REQUIRE_EQUAL(tree.NumChildren(), xmlTree.NumChildren());
+ BOOST_REQUIRE_EQUAL(tree.NumChildren(), textTree.NumChildren());
+ BOOST_REQUIRE_EQUAL(tree.NumChildren(), binaryTree.NumChildren());
+
+ // Make sure the number of descendants is the same.
+ BOOST_REQUIRE_EQUAL(tree.NumDescendants(), xmlTree.NumDescendants());
+ BOOST_REQUIRE_EQUAL(tree.NumDescendants(), textTree.NumDescendants());
+ BOOST_REQUIRE_EQUAL(tree.NumDescendants(), binaryTree.NumDescendants());
+
+ // Make sure the number of points is the same.
+ BOOST_REQUIRE_EQUAL(tree.NumPoints(), xmlTree.NumPoints());
+ BOOST_REQUIRE_EQUAL(tree.NumPoints(), textTree.NumPoints());
+ BOOST_REQUIRE_EQUAL(tree.NumPoints(), binaryTree.NumPoints());
+
+ // Check that each point is the same.
+ for (size_t i = 0; i < tree.NumPoints(); ++i)
+ {
+ BOOST_REQUIRE_EQUAL(tree.Point(i), xmlTree.Point(i));
+ BOOST_REQUIRE_EQUAL(tree.Point(i), textTree.Point(i));
+ BOOST_REQUIRE_EQUAL(tree.Point(i), binaryTree.Point(i));
+ }
+
+ // Check that the parent distance is the same.
+ BOOST_REQUIRE_CLOSE(tree.ParentDistance(), xmlTree.ParentDistance(), 1e-8);
+ BOOST_REQUIRE_CLOSE(tree.ParentDistance(), textTree.ParentDistance(), 1e-8);
+ BOOST_REQUIRE_CLOSE(tree.ParentDistance(), binaryTree.ParentDistance(), 1e-8);
+
+ // Check that the furthest descendant distance is the same.
+ BOOST_REQUIRE_CLOSE(tree.FurthestDescendantDistance(),
+ xmlTree.FurthestDescendantDistance(), 1e-8);
+ BOOST_REQUIRE_CLOSE(tree.FurthestDescendantDistance(),
+ textTree.FurthestDescendantDistance(), 1e-8);
+ BOOST_REQUIRE_CLOSE(tree.FurthestDescendantDistance(),
+ binaryTree.FurthestDescendantDistance(), 1e-8);
+
+ // Check that the minimum bound distance is the same.
+ BOOST_REQUIRE_CLOSE(tree.MinimumBoundDistance(),
+ xmlTree.MinimumBoundDistance(), 1e-8);
+ BOOST_REQUIRE_CLOSE(tree.MinimumBoundDistance(),
+ textTree.MinimumBoundDistance(), 1e-8);
+ BOOST_REQUIRE_CLOSE(tree.MinimumBoundDistance(),
+ binaryTree.MinimumBoundDistance(), 1e-8);
+
+ // Check each child, then recurse into it.
+ for (size_t i = 0; i < tree.NumChildren(); ++i)
+ {
+ // Check that each child refers to the same dataset object as its parent.
+ BOOST_REQUIRE_EQUAL(&xmlTree.Dataset(), &xmlTree.Child(i).Dataset());
+ BOOST_REQUIRE_EQUAL(&textTree.Dataset(), &textTree.Child(i).Dataset());
+ BOOST_REQUIRE_EQUAL(&binaryTree.Dataset(), &binaryTree.Child(i).Dataset());
+
+ // Make sure the parent link points back at the node we came from.
+ BOOST_REQUIRE_EQUAL(xmlTree.Child(i).Parent(), &xmlTree);
+ BOOST_REQUIRE_EQUAL(textTree.Child(i).Parent(), &textTree);
+ BOOST_REQUIRE_EQUAL(binaryTree.Child(i).Parent(), &binaryTree);
+
+ CheckTrees(tree.Child(i), xmlTree.Child(i), textTree.Child(i),
+ binaryTree.Child(i));
+ }
+}
+
+BOOST_AUTO_TEST_CASE(BinarySpaceTreeTest)
+{
+ arma::mat data;
+ data.randu(3, 100);
+ BinarySpaceTree<HRectBound<2>> tree(data);
+
+ BinarySpaceTree<HRectBound<2>>* xmlTree;
+ BinarySpaceTree<HRectBound<2>>* textTree;
+ BinarySpaceTree<HRectBound<2>>* binaryTree;
+
+ SerializePointerObjectAll(&tree, xmlTree, textTree, binaryTree); // Allocates the three loaded trees with new.
+
+ CheckTrees(tree, *xmlTree, *textTree, *binaryTree);
+
+ delete xmlTree; // The loaded trees are owned by this test, so free them here.
+ delete textTree;
+ delete binaryTree;
+}
+
+BOOST_AUTO_TEST_CASE(BinarySpaceTreeOverwriteTest)
+{
+ arma::mat data;
+ data.randu(3, 100);
+ BinarySpaceTree<HRectBound<2>> tree(data);
+
+ BinarySpaceTree<HRectBound<2>> xmlTree(tree); // Start from full copies, so each target already holds a tree.
+ BinarySpaceTree<HRectBound<2>> textTree(tree);
+ BinarySpaceTree<HRectBound<2>> binaryTree(tree);
+
+ SerializeObjectAll(tree, xmlTree, textTree, binaryTree); // Presumably loads into the existing objects (SerializeObject body not shown here) - TODO confirm.
+
+ CheckTrees(tree, xmlTree, textTree, binaryTree);
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git
mailing list