[mlpack-git] master: Add BinarySpaceTree::Serialize() and tests. Also add a constructor to allow constructing from an iarchive. (763f377)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Jul 10 19:00:22 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/4a97187bbba7ce8a6191b714949dd818ef0f37d2...e5905e62c15d1bcff21e6359b11efcd7ab6d7ca0

>---------------------------------------------------------------

commit 763f3774b1da47d4a41c62e05064064b54d74e16
Author: ryan <ryan at ratml.org>
Date:   Tue Apr 28 12:13:25 2015 -0400

    Add BinarySpaceTree::Serialize() and tests.
    Also add a constructor to allow constructing from an iarchive.


>---------------------------------------------------------------

763f3774b1da47d4a41c62e05064064b54d74e16
 .../tree/binary_space_tree/binary_space_tree.hpp   |  39 ++++-
 .../binary_space_tree/binary_space_tree_impl.hpp   | 147 +++++++++++++++++-
 src/mlpack/tests/serialization_test.cpp            | 166 ++++++++++++++++++++-
 3 files changed, 347 insertions(+), 5 deletions(-)

diff --git a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
index d17532d..729143a 100644
--- a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
@@ -65,8 +65,6 @@ class BinarySpaceTree
   //! The worst possible distance to the furthest descendant, cached to speed
   //! things up.
   double furthestDescendantDistance;
-  //! The minimum distance from the center to any edge of the bound.
-  double minimumBoundDistance;
   //! The dataset.  If we are the root of the tree, we own the dataset and must
   //! delete it.
   MatType* dataset;
@@ -210,9 +208,26 @@ class BinarySpaceTree
   BinarySpaceTree(const BinarySpaceTree& other);
 
   /**
+   * Construct the tree by loading it from a boost::serialization archive.
+   *
+   * @param ar Archive to load tree from.  Must be an iarchive, not an oarchive!
+   */
+  template<typename Archive>
+  BinarySpaceTree(
+      Archive& ar,
+      const typename boost::enable_if<typename Archive::is_loading>::type* = 0);
+
+  /**
    * Deletes this node, deallocating the memory for the children and calling
    * their destructors in turn.  This will invalidate any pointers or references
-   * to any nodes which are children of this one.
+   * to any nodes which are children of this one.  Also, if this is the root of
+   * the tree (specifically, if Parent() == NULL), it will delete the dataset.
+   *
+   * So, if you want to preserve other parts of the tree, the easiest way is
+   * probably to set Left() or Right() to NULL and maintain those branches
+   * yourself.  If you want to preserve the dataset, the easiest way is probably
+   * to set Parent() to a non-NULL value before deleting the object.  Either
+   * way, if you find yourself doing that, exercise care and caution.
    */
   ~BinarySpaceTree();
 
@@ -443,8 +458,26 @@ class BinarySpaceTree
   void SplitNode(std::vector<size_t>& oldFromNew,
                  const size_t maxLeafSize);
 
+ protected:
+  /**
+   * A default constructor.  This is meant only for use with
+   * boost::serialization, which gets access via the friend declaration below.
+   * This does not return a valid tree!  The method must be protected, so that
+   * the serialization shim can work with the default constructor.
+   */
+  BinarySpaceTree();
+
+  //! Grant boost::serialization access to the protected default constructor.
+  friend class boost::serialization::access;
+
  public:
   /**
+   * Serialize the tree (save or load, depending on the archive type).
+   */
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int version);
+
+  /**
    * Returns a string representation of this object.
    */
   std::string ToString() const;
diff --git a/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp b/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
index ebc4c23..863fa75 100644
--- a/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
@@ -12,6 +12,7 @@
 #include <mlpack/core/util/cli.hpp>
 #include <mlpack/core/util/log.hpp>
 #include <mlpack/core/util/string_util.hpp>
+#include <queue>
 
 namespace mlpack {
 namespace tree {
@@ -212,7 +213,8 @@ BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
     stat(other.stat),
     parentDistance(other.parentDistance),
     furthestDescendantDistance(other.furthestDescendantDistance),
-    dataset(new MatType(*other.dataset)) // Copy matrix.
+    // Copy matrix, but only if we are the root.
+    dataset((other.parent == NULL) ? new MatType(*other.dataset) : NULL)
 {
   // Create left and right children (if any).
   if (other.Left())
@@ -226,6 +228,50 @@ BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
     right = new BinarySpaceTree(*other.Right());
     right->Parent() = this; // Set parent to this, not other tree.
   }
+
+  // Propagate matrix, but only if we are the root.
+  if (parent == NULL)
+  {
+    std::queue<BinarySpaceTree*> queue;
+    if (left)
+      queue.push(left);
+    if (right)
+      queue.push(right);
+    while (!queue.empty())
+    {
+      BinarySpaceTree* node = queue.front();
+      queue.pop();
+
+      node->dataset = dataset;
+      if (node->left)
+        queue.push(node->left);
+      if (node->right)
+        queue.push(node->right);
+    }
+  }
+}
+
+/**
+ * Initialize the tree by loading it from an iarchive.
+ */
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+template<typename Archive>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::BinarySpaceTree(
+    Archive& ar,
+    const typename boost::enable_if<typename Archive::is_loading>::type*) :
+    BinarySpaceTree() // Delegate: start from an empty tree.
+{
+  // The enable_if in the signature already excludes oarchives; this is a
+  // purely defensive runtime check.
+  if (!Archive::is_loading::value)
+    throw std::invalid_argument("Archive::is_loading is false; use an iarchive,"
+        " not an oarchive!");
+
+  // Now fill the empty tree in from the archive.
+  ar >> data::CreateNVP(*this, "tree");
 }
 
 /**
@@ -620,6 +666,105 @@ void BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::SplitNode(
   right->ParentDistance() = rightParentDistance;
 }
 
+// Protected default constructor (empty, invalid tree), used for serialization.
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::
+    BinarySpaceTree() :
+    left(NULL),
+    right(NULL),
+    parent(NULL),
+    begin(0),
+    count(0),
+    stat(*this),
+    parentDistance(0),
+    furthestDescendantDistance(0),
+    dataset(NULL)
+{
+  // Nothing to do; the members are filled in afterwards by deserialization.
+}
+
+/**
+ * Serialize the tree.  When loading, any existing subtree is freed first.
+ */
+template<typename BoundType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType>
+template<typename Archive>
+void BinarySpaceTree<BoundType, StatisticType, MatType, SplitType>::Serialize(
+    Archive& ar,
+    const unsigned int /* version */)
+{
+  using data::CreateNVP;
+
+  // If we're loading, free the old children (and the dataset, if we own it).
+  if (Archive::is_loading::value)
+  {
+    delete left;
+    delete right;
+    if (!parent)
+      delete dataset;
+    left = right = NULL; // Don't leave dangling pointers if a load throws.
+    dataset = NULL;
+  }
+
+  ar & CreateNVP(parent, "parent");
+  ar & CreateNVP(begin, "begin");
+  ar & CreateNVP(count, "count");
+  ar & CreateNVP(bound, "bound");
+  ar & CreateNVP(stat, "statistic");
+  ar & CreateNVP(parentDistance, "parentDistance");
+  ar & CreateNVP(furthestDescendantDistance, "furthestDescendantDistance");
+  ar & CreateNVP(dataset, "dataset");
+
+  // Save children last; otherwise boost::serialization gets confused.
+  ar & CreateNVP(left, "left");
+  ar & CreateNVP(right, "right");
+
+  // Due to quirks of boost::serialization, if a tree is saved as an object and
+  // not a pointer, the first level of the tree will be duplicated on load.
+  // Therefore, if we are the root of the tree, then we need to make sure our
+  // children's parent links are correct, and delete the duplicated node if
+  // necessary.
+  if (Archive::is_loading::value)
+  {
+    // Remember the children's loaded parents (NULL for a missing child).
+    BinarySpaceTree* leftParent = left ? left->Parent() : NULL;
+    BinarySpaceTree* rightParent = right ? right->Parent() : NULL;
+
+    // Point the children's parent links at this node.
+    if (left)
+      left->Parent() = this;
+    if (right)
+      right->Parent() = this;
+
+    // Do we need to delete the left parent?
+    if (leftParent != NULL && leftParent != this)
+    {
+      // Sever the duplicate parent's children.  Ensure we don't delete the
+      // dataset, by faking the duplicated parent's parent (that is, we need to
+      // set the parent to something non-NULL; 'this' works).
+      leftParent->Parent() = this;
+      leftParent->Left() = NULL;
+      leftParent->Right() = NULL;
+      delete leftParent;
+    }
+
+    // Do we need to delete the right parent?
+    if (rightParent != NULL && rightParent != this && rightParent != leftParent)
+    {
+      // Sever the duplicate parent's children, in the same way as above.
+      rightParent->Parent() = this;
+      rightParent->Left() = NULL;
+      rightParent->Right() = NULL;
+      delete rightParent;
+    }
+  }
+}
+
 /**
  * Returns a string representation of this object.
  */
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index 23f07cb..96498ec 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -20,12 +20,14 @@
 #include <mlpack/core/tree/ballbound.hpp>
 #include <mlpack/core/tree/hrectbound.hpp>
 #include <mlpack/core/metrics/mahalanobis_distance.hpp>
+#include <mlpack/core/tree/binary_space_tree.hpp>
 
 using namespace mlpack;
 using namespace mlpack::distribution;
 using namespace mlpack::regression;
 using namespace mlpack::bound;
 using namespace mlpack::metric;
+using namespace mlpack::tree;
 using namespace arma;
 using namespace boost;
 using namespace boost::archive;
@@ -212,9 +214,53 @@ void SerializeObject(T& t, T& newT)
 template<typename T>
 void SerializeObjectAll(T& t, T& xmlT, T& textT, T& binaryT)
 {
-  SerializeObject<T, xml_iarchive, xml_oarchive>(t, xmlT);
   SerializeObject<T, text_iarchive, text_oarchive>(t, textT);
   SerializeObject<T, binary_iarchive, binary_oarchive>(t, binaryT);
+  SerializeObject<T, xml_iarchive, xml_oarchive>(t, xmlT); // XML goes last.
+}
+
+// Round-trip a non-default-constructible *t; newT is allocated with new T(ar).
+template<typename T, typename IArchiveType, typename OArchiveType>
+void SerializePointerObject(T* t, T*& newT)
+{
+  ofstream ofs("test");
+  bool success = true;
+  {
+    OArchiveType o(ofs); // Destroyed (and flushed) before ofs is closed.
+    try
+    {
+      o << data::CreateNVP(*t, "t");
+    }
+    catch (const archive_exception&)
+    {
+      success = false;
+    }
+  }
+  ofs.close();
+  BOOST_REQUIRE_EQUAL(success, true);
+
+  ifstream ifs("test");
+  {
+    IArchiveType i(ifs);
+    try
+    {
+      newT = new T(i);
+    }
+    catch (const std::exception&)
+    {
+      success = false;
+    }
+  }
+  ifs.close();
+  BOOST_REQUIRE_EQUAL(success, true);
+}
+// Round-trip *t via text, binary and XML archives; the caller owns the results.
+template<typename T>
+void SerializePointerObjectAll(T* t, T*& xmlT, T*& textT, T*& binaryT)
+{
+  SerializePointerObject<T, text_iarchive, text_oarchive>(t, textT);
+  SerializePointerObject<T, binary_iarchive, binary_oarchive>(t, binaryT);
+  SerializePointerObject<T, xml_iarchive, xml_oarchive>(t, xmlT);
 }
 
 // Utility function to check the equality of two Armadillo matrices.
@@ -521,4 +567,122 @@ BOOST_AUTO_TEST_CASE(HRectBoundTest)
   BOOST_REQUIRE_CLOSE(b.MinWidth(), binaryB.MinWidth(), 1e-8);
 }
 
+template<typename TreeType>
+void CheckTrees(TreeType& tree,        // Original tree (or subtree).
+                TreeType& xmlTree,     // Tree loaded from an XML archive.
+                TreeType& textTree,    // Tree loaded from a text archive.
+                TreeType& binaryTree)  // Tree loaded from a binary archive.
+{
+  const typename TreeType::Mat* dataset = &tree.Dataset();
+
+  // At the root, the data matrices must match and no tree may have a parent.
+  if (tree.Parent() == NULL)
+  {
+    CheckMatrices(*dataset,
+                  xmlTree.Dataset(),
+                  textTree.Dataset(),
+                  binaryTree.Dataset());
+
+    // The loaded roots must have no parent either.
+    BOOST_REQUIRE_EQUAL(xmlTree.Parent(), (TreeType*) NULL);
+    BOOST_REQUIRE_EQUAL(textTree.Parent(), (TreeType*) NULL);
+    BOOST_REQUIRE_EQUAL(binaryTree.Parent(), (TreeType*) NULL);
+  }
+
+  // Make sure the number of children is the same.
+  BOOST_REQUIRE_EQUAL(tree.NumChildren(), xmlTree.NumChildren());
+  BOOST_REQUIRE_EQUAL(tree.NumChildren(), textTree.NumChildren());
+  BOOST_REQUIRE_EQUAL(tree.NumChildren(), binaryTree.NumChildren());
+
+  // Make sure the number of descendants is the same.
+  BOOST_REQUIRE_EQUAL(tree.NumDescendants(), xmlTree.NumDescendants());
+  BOOST_REQUIRE_EQUAL(tree.NumDescendants(), textTree.NumDescendants());
+  BOOST_REQUIRE_EQUAL(tree.NumDescendants(), binaryTree.NumDescendants());
+
+  // Make sure the number of points is the same.
+  BOOST_REQUIRE_EQUAL(tree.NumPoints(), xmlTree.NumPoints());
+  BOOST_REQUIRE_EQUAL(tree.NumPoints(), textTree.NumPoints());
+  BOOST_REQUIRE_EQUAL(tree.NumPoints(), binaryTree.NumPoints());
+
+  // Check that each point is the same.
+  for (size_t i = 0; i < tree.NumPoints(); ++i)
+  {
+    BOOST_REQUIRE_EQUAL(tree.Point(i), xmlTree.Point(i));
+    BOOST_REQUIRE_EQUAL(tree.Point(i), textTree.Point(i));
+    BOOST_REQUIRE_EQUAL(tree.Point(i), binaryTree.Point(i));
+  }
+
+  // Check that the parent distance is the same.
+  BOOST_REQUIRE_CLOSE(tree.ParentDistance(), xmlTree.ParentDistance(), 1e-8);
+  BOOST_REQUIRE_CLOSE(tree.ParentDistance(), textTree.ParentDistance(), 1e-8);
+  BOOST_REQUIRE_CLOSE(tree.ParentDistance(), binaryTree.ParentDistance(), 1e-8);
+
+  // Check that the furthest descendant distance is the same.
+  BOOST_REQUIRE_CLOSE(tree.FurthestDescendantDistance(),
+      xmlTree.FurthestDescendantDistance(), 1e-8);
+  BOOST_REQUIRE_CLOSE(tree.FurthestDescendantDistance(),
+      textTree.FurthestDescendantDistance(), 1e-8);
+  BOOST_REQUIRE_CLOSE(tree.FurthestDescendantDistance(),
+      binaryTree.FurthestDescendantDistance(), 1e-8);
+
+  // Check that the minimum bound distance is the same.
+  BOOST_REQUIRE_CLOSE(tree.MinimumBoundDistance(),
+      xmlTree.MinimumBoundDistance(), 1e-8);
+  BOOST_REQUIRE_CLOSE(tree.MinimumBoundDistance(),
+      textTree.MinimumBoundDistance(), 1e-8);
+  BOOST_REQUIRE_CLOSE(tree.MinimumBoundDistance(),
+      binaryTree.MinimumBoundDistance(), 1e-8);
+
+  // Recurse into each child, comparing the corresponding subtrees.
+  for (size_t i = 0; i < tree.NumChildren(); ++i)
+  {
+    // Each child must share its parent's dataset (same pointer, not a copy).
+    BOOST_REQUIRE_EQUAL(&xmlTree.Dataset(), &xmlTree.Child(i).Dataset());
+    BOOST_REQUIRE_EQUAL(&textTree.Dataset(), &textTree.Child(i).Dataset());
+    BOOST_REQUIRE_EQUAL(&binaryTree.Dataset(), &binaryTree.Child(i).Dataset());
+
+    // Each child's parent link must point back at the loaded node itself.
+    BOOST_REQUIRE_EQUAL(xmlTree.Child(i).Parent(), &xmlTree);
+    BOOST_REQUIRE_EQUAL(textTree.Child(i).Parent(), &textTree);
+    BOOST_REQUIRE_EQUAL(binaryTree.Child(i).Parent(), &binaryTree);
+
+    CheckTrees(tree.Child(i), xmlTree.Child(i), textTree.Child(i),
+        binaryTree.Child(i));
+  }
+}
+
+BOOST_AUTO_TEST_CASE(BinarySpaceTreeTest)
+{
+  typedef BinarySpaceTree<HRectBound<2>> TreeType;
+  arma::mat data;
+  data.randu(3, 100);
+  TreeType tree(data);
+
+  // SerializePointerObjectAll() allocates these three trees; we delete them.
+  TreeType* xmlTree;
+  TreeType* textTree;
+  TreeType* binaryTree;
+  SerializePointerObjectAll(&tree, xmlTree, textTree, binaryTree);
+  CheckTrees(tree, *xmlTree, *textTree, *binaryTree);
+
+  delete xmlTree;
+  delete textTree;
+  delete binaryTree;
+}
+
+BOOST_AUTO_TEST_CASE(BinarySpaceTreeOverwriteTest)
+{
+  arma::mat data;
+  data.randu(3, 100);
+  BinarySpaceTree<HRectBound<2>> tree(data);
+
+  // Start from trees built on different data, so a no-op load cannot pass.
+  arma::mat otherData = arma::randu<arma::mat>(3, 50);
+  BinarySpaceTree<HRectBound<2>> xmlTree(otherData);
+  BinarySpaceTree<HRectBound<2>> textTree(otherData);
+  BinarySpaceTree<HRectBound<2>> binaryTree(otherData);
+  SerializeObjectAll(tree, xmlTree, textTree, binaryTree);
+  CheckTrees(tree, xmlTree, textTree, binaryTree);
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list