[mlpack-git] master: Add Serialize() to CoverTree and tests too. (95088aa)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Sun Oct 11 16:26:18 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/7e6fe3f1c25445dfd41c747839aa1fbf1c31a776...6e9a7465d7739e05e6b4aa650e1f87c45e9cd656

>---------------------------------------------------------------

commit 95088aa9f1dbad1c9a29931c72a51111880745d3
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sun Oct 11 14:16:43 2015 -0400

    Add Serialize() to CoverTree and tests too.


>---------------------------------------------------------------

95088aa9f1dbad1c9a29931c72a51111880745d3
 src/mlpack/core/tree/cover_tree/cover_tree.hpp     |  24 +++-
 .../core/tree/cover_tree/cover_tree_impl.hpp       | 135 ++++++++++++++++++++-
 src/mlpack/tests/serialization_test.cpp            | 113 ++++++++++++++++-
 3 files changed, 265 insertions(+), 7 deletions(-)

diff --git a/src/mlpack/core/tree/cover_tree/cover_tree.hpp b/src/mlpack/core/tree/cover_tree/cover_tree.hpp
index 1f27a60..ca7bbbd 100644
--- a/src/mlpack/core/tree/cover_tree/cover_tree.hpp
+++ b/src/mlpack/core/tree/cover_tree/cover_tree.hpp
@@ -198,6 +198,14 @@ class CoverTree
   CoverTree(const CoverTree& other);
 
   /**
+   * Create a cover tree by loading it from a boost::serialization archive.
+   */
+  template<typename Archive>
+  CoverTree(
+      Archive& ar,
+      const typename boost::enable_if<typename Archive::is_loading>::type* = 0);
+
+  /**
    * Delete this cover tree node and its children.
    */
   ~CoverTree();
@@ -457,6 +465,18 @@ class CoverTree
    */
   void RemoveNewImplicitNodes();
 
+ protected:
+  /**
+   * A default constructor, intended only for boost::serialization, which can
+   * reach it through the friend declaration below.  It does NOT create a valid
+   * tree: the object is unusable until it has been loaded from an archive, so
+   * this constructor must not be made public.
+   */
+  CoverTree();
+
+  //! Lets boost::serialization call the protected default constructor above.
+  friend class boost::serialization::access;
+
  public:
   /**
    * Returns a string representation of this object.
@@ -476,8 +496,8 @@ class CoverTree
   size_t distanceComps;
 };
 
-}; // namespace tree
-}; // namespace mlpack
+} // namespace tree
+} // namespace mlpack
 
 // Include implementation.
 #include "cover_tree_impl.hpp"
diff --git a/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp b/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
index e9212fc..a36fb83 100644
--- a/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
+++ b/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
@@ -306,6 +306,24 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
   }
 }
 
+// Construct a tree by loading it from a boost::serialization archive.
+template<
+    typename MetricType,
+    typename StatisticType,
+    typename MatType,
+    typename RootPointPolicy
+>
+template<typename Archive>
+CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
+    Archive& ar,
+    const typename boost::enable_if<typename Archive::is_loading>::type*) :
+    CoverTree() // Start from an empty, not-yet-valid tree.
+{
+  // Now, load the tree from the archive into this empty object.
+  ar >> data::CreateNVP(*this, "tree");
+}
+
+
 template<
     typename MetricType,
     typename StatisticType,
@@ -1122,6 +1140,31 @@ inline void CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
 }
 
 /**
+ * Default constructor, only for use with boost::serialization.  Every member
+ * is initialized here, since Serialize() does not restore distanceComps.
+ */
+template<
+    typename MetricType,
+    typename StatisticType,
+    typename MatType,
+    typename RootPointPolicy
+>
+CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree() :
+    dataset(NULL),
+    point(0),
+    scale(INT_MIN),
+    base(0.0),
+    numDescendants(0),
+    parent(NULL),
+    parentDistance(0.0),
+    furthestDescendantDistance(0.0),
+    localMetric(false),
+    localDataset(false),
+    metric(NULL),
+    distanceComps(0)
+{ /* Nothing to do. */ }
+
+/**
  * Returns a string representation of this object.
  */
 template<
@@ -1156,7 +1199,95 @@ std::string CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
   return convert.str();
 }
 
-}; // namespace tree
-}; // namespace mlpack
+/**
+ * Serialize to/from a boost::serialization archive.
+ */
+template<
+    typename MetricType,
+    typename StatisticType,
+    typename MatType,
+    typename RootPointPolicy
+>
+template<typename Archive>
+void CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::Serialize(
+    Archive& ar,
+    const unsigned int /* version */)
+{
+  using data::CreateNVP;
+
+  // If we're loading, free our children and any metric/dataset we own, then
+  // clear the pointers so nothing dangles if reading the archive fails below.
+  if (Archive::is_loading::value)
+  {
+    for (size_t i = 0; i < children.size(); ++i)
+      delete children[i];
+    children.clear();
+
+    if (localMetric && metric)
+      delete metric;
+    if (localDataset && dataset)
+      delete dataset;
+    metric = NULL;
+    dataset = NULL;
+  }
+
+  ar & CreateNVP(dataset, "dataset");
+  ar & CreateNVP(point, "point");
+  ar & CreateNVP(scale, "scale");
+  ar & CreateNVP(base, "base");
+  ar & CreateNVP(stat, "stat");
+  ar & CreateNVP(numDescendants, "numDescendants");
+  ar & CreateNVP(parent, "parent");
+  ar & CreateNVP(parentDistance, "parentDistance");
+  ar & CreateNVP(furthestDescendantDistance, "furthestDescendantDistance");
+  ar & CreateNVP(metric, "metric");
+
+  // Only the root owns (and will later free) the metric and dataset.
+  if (Archive::is_loading::value)
+  {
+    localMetric = (parent == NULL);
+    localDataset = (parent == NULL);
+  }
+
+  // Lastly, serialize the children.
+  size_t numChildren = children.size();
+  ar & CreateNVP(numChildren, "numChildren");
+  if (Archive::is_loading::value)
+    children.resize(numChildren);
+  for (size_t i = 0; i < numChildren; ++i)
+  {
+    std::ostringstream oss;
+    oss << "child" << i;
+    ar & CreateNVP(children[i], oss.str());
+  }
+
+  // Due to quirks of boost::serialization, if a tree is saved as an object and
+  // not a pointer, the first level of the tree will be duplicated on load.  So
+  // point each child back at this node, then free the duplicate exactly once
+  // (we assume all mislinked children share the same duplicate parent).
+  if (Archive::is_loading::value)
+  {
+    CoverTree* duplicate = NULL;
+    for (size_t i = 0; i < children.size(); ++i)
+    {
+      if (children[i]->Parent() != this)
+      {
+        duplicate = children[i]->Parent();
+        children[i]->Parent() = this;
+      }
+    }
+    if (duplicate != NULL)
+    {
+      // Disallow the duplicate from deleting anything it shares with us.
+      duplicate->localMetric = false;
+      duplicate->localDataset = false;
+      duplicate->children.clear();
+      delete duplicate;
+    }
+  }
+}
+
+} // namespace tree
+} // namespace mlpack
 
 #endif
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index 8a92130..5a71632 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -21,6 +21,7 @@
 #include <mlpack/core/tree/hrectbound.hpp>
 #include <mlpack/core/metrics/mahalanobis_distance.hpp>
 #include <mlpack/core/tree/binary_space_tree.hpp>
+#include <mlpack/core/tree/cover_tree.hpp>
 
 #include <mlpack/methods/perceptron/perceptron.hpp>
 #include <mlpack/methods/logistic_regression/logistic_regression.hpp>
@@ -713,15 +714,126 @@ BOOST_AUTO_TEST_CASE(BinarySpaceTreeOverwriteTest)
   typedef KDTree<EuclideanDistance, EmptyStatistic, arma::mat> TreeType;
   TreeType tree(data);
 
-  TreeType xmlTree(tree);
-  TreeType textTree(tree);
-  TreeType binaryTree(tree);
+  arma::mat otherData;
+  otherData.randu(5, 50);
+  TreeType xmlTree(otherData);
+  TreeType textTree(xmlTree);
+  TreeType binaryTree(xmlTree);
 
   SerializeObjectAll(tree, xmlTree, textTree, binaryTree);
 
   CheckTrees(tree, xmlTree, textTree, binaryTree);
 }
 
+BOOST_AUTO_TEST_CASE(CoverTreeTest)
+{
+  arma::mat data;
+  data.randu(3, 100);
+  typedef StandardCoverTree<EuclideanDistance, EmptyStatistic, arma::mat>
+      TreeType;
+  TreeType tree(data);
+
+  TreeType* xmlTree = NULL;
+  TreeType* textTree = NULL;
+  TreeType* binaryTree = NULL;
+
+  SerializePointerObjectAll(&tree, xmlTree, textTree, binaryTree);
+
+  CheckTrees(tree, *xmlTree, *textTree, *binaryTree);
+
+  // Also check that every node's scale and base survived serialization.
+  std::stack<TreeType*> stack, xmlStack, textStack, binaryStack;
+  stack.push(&tree);
+  xmlStack.push(xmlTree);
+  textStack.push(textTree);
+  binaryStack.push(binaryTree);
+  while (!stack.empty())
+  {
+    TreeType* node = stack.top();
+    TreeType* xmlNode = xmlStack.top();
+    TreeType* textNode = textStack.top();
+    TreeType* binaryNode = binaryStack.top();
+    stack.pop();
+    xmlStack.pop();
+    textStack.pop();
+    binaryStack.pop();
+
+    BOOST_REQUIRE_EQUAL(node->Scale(), xmlNode->Scale());
+    BOOST_REQUIRE_EQUAL(node->Scale(), textNode->Scale());
+    BOOST_REQUIRE_EQUAL(node->Scale(), binaryNode->Scale());
+
+    BOOST_REQUIRE_CLOSE(node->Base(), xmlNode->Base(), 1e-5);
+    BOOST_REQUIRE_CLOSE(node->Base(), textNode->Base(), 1e-5);
+    BOOST_REQUIRE_CLOSE(node->Base(), binaryNode->Base(), 1e-5);
+
+    for (size_t i = 0; i < node->NumChildren(); ++i)
+    {
+      stack.push(&node->Child(i));
+      xmlStack.push(&xmlNode->Child(i));
+      textStack.push(&textNode->Child(i));
+      binaryStack.push(&binaryNode->Child(i));
+    }
+  }
+
+  // Free the trees that SerializePointerObjectAll() created for this test.
+  delete xmlTree;
+  delete textTree;
+  delete binaryTree;
+}
+
+BOOST_AUTO_TEST_CASE(CoverTreeOverwriteTest)
+{
+  arma::mat data;
+  data.randu(3, 100);
+  typedef StandardCoverTree<EuclideanDistance, EmptyStatistic, arma::mat>
+      TreeType;
+  TreeType tree(data);
+
+  arma::mat otherData;
+  otherData.randu(5, 50);
+  TreeType xmlTree(otherData);
+  TreeType textTree(xmlTree);
+  TreeType binaryTree(xmlTree);
+
+  SerializeObjectAll(tree, xmlTree, textTree, binaryTree);
+
+  CheckTrees(tree, xmlTree, textTree, binaryTree);
+
+  // Also check that every node's scale and base survived the overwrite.
+  std::stack<TreeType*> stack, xmlStack, textStack, binaryStack;
+  stack.push(&tree);
+  xmlStack.push(&xmlTree);
+  textStack.push(&textTree);
+  binaryStack.push(&binaryTree);
+  while (!stack.empty())
+  {
+    TreeType* node = stack.top();
+    TreeType* xmlNode = xmlStack.top();
+    TreeType* textNode = textStack.top();
+    TreeType* binaryNode = binaryStack.top();
+    stack.pop();
+    xmlStack.pop();
+    textStack.pop();
+    binaryStack.pop();
+
+    BOOST_REQUIRE_EQUAL(node->Scale(), xmlNode->Scale());
+    BOOST_REQUIRE_EQUAL(node->Scale(), textNode->Scale());
+    BOOST_REQUIRE_EQUAL(node->Scale(), binaryNode->Scale());
+
+    BOOST_REQUIRE_CLOSE(node->Base(), xmlNode->Base(), 1e-5);
+    BOOST_REQUIRE_CLOSE(node->Base(), textNode->Base(), 1e-5);
+    BOOST_REQUIRE_CLOSE(node->Base(), binaryNode->Base(), 1e-5);
+
+    for (size_t i = 0; i < node->NumChildren(); ++i)
+    {
+      stack.push(&node->Child(i));
+      xmlStack.push(&xmlNode->Child(i));
+      textStack.push(&textNode->Child(i));
+      binaryStack.push(&binaryNode->Child(i));
+    }
+  }
+}
+
 BOOST_AUTO_TEST_CASE(PerceptronTest)
 {
   // Create a perceptron.  Train it randomly.  Then check that it hasn't



More information about the mlpack-git mailing list