[mlpack-git] master: Add Serialize() to CoverTree and tests too. (95088aa)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Sun Oct 11 16:26:18 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/7e6fe3f1c25445dfd41c747839aa1fbf1c31a776...6e9a7465d7739e05e6b4aa650e1f87c45e9cd656
>---------------------------------------------------------------
commit 95088aa9f1dbad1c9a29931c72a51111880745d3
Author: Ryan Curtin <ryan at ratml.org>
Date: Sun Oct 11 14:16:43 2015 -0400
Add Serialize() to CoverTree and tests too.
>---------------------------------------------------------------
95088aa9f1dbad1c9a29931c72a51111880745d3
src/mlpack/core/tree/cover_tree/cover_tree.hpp | 24 +++-
.../core/tree/cover_tree/cover_tree_impl.hpp | 135 ++++++++++++++++++++-
src/mlpack/tests/serialization_test.cpp | 113 ++++++++++++++++-
3 files changed, 265 insertions(+), 7 deletions(-)
diff --git a/src/mlpack/core/tree/cover_tree/cover_tree.hpp b/src/mlpack/core/tree/cover_tree/cover_tree.hpp
index 1f27a60..ca7bbbd 100644
--- a/src/mlpack/core/tree/cover_tree/cover_tree.hpp
+++ b/src/mlpack/core/tree/cover_tree/cover_tree.hpp
@@ -198,6 +198,14 @@ class CoverTree
CoverTree(const CoverTree& other);
/**
+ * Create a cover tree from a boost::serialization archive.
+ */
+ template<typename Archive>
+ CoverTree(
+ Archive& ar,
+ const typename boost::enable_if<typename Archive::is_loading>::type* = 0);
+
+ /**
* Delete this cover tree node and its children.
*/
~CoverTree();
@@ -457,6 +465,18 @@ class CoverTree
*/
void RemoveNewImplicitNodes();
+ protected:
+ /**
+ * A default constructor. This is meant to be used only by
+ * boost::serialization, which is granted access by the friend declaration
+ * below. This does not produce a valid tree! This constructor must be
+ * protected, so that the serialization shim can work with it.
+ */
+ CoverTree();
+
+ //! Friend access is given for the default constructor.
+ friend class boost::serialization::access;
+
public:
/**
* Returns a string representation of this object.
@@ -476,8 +496,8 @@ class CoverTree
size_t distanceComps;
};
-}; // namespace tree
-}; // namespace mlpack
+} // namespace tree
+} // namespace mlpack
// Include implementation.
#include "cover_tree_impl.hpp"
diff --git a/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp b/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
index e9212fc..a36fb83 100644
--- a/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
+++ b/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
@@ -306,6 +306,24 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
}
}
+// Construct from a boost::serialization archive.
+template<
+ typename MetricType,
+ typename StatisticType,
+ typename MatType,
+ typename RootPointPolicy
+>
+template<typename Archive>
+CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
+ Archive& ar,
+ const typename boost::enable_if<typename Archive::is_loading>::type*) :
+ CoverTree() // Create an empty CoverTree.
+{
+ // Now, load the serialized tree from the archive into our empty tree.
+ ar >> data::CreateNVP(*this, "tree");
+}
+
+
template<
typename MetricType,
typename StatisticType,
@@ -1122,6 +1140,31 @@ inline void CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
}
/**
+ * Default constructor, only for use with boost::serialization.
+ */
+template<
+ typename MetricType,
+ typename StatisticType,
+ typename MatType,
+ typename RootPointPolicy
+>
+CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree() :
+ dataset(NULL),
+ point(0),
+ scale(INT_MIN),
+ base(0.0),
+ numDescendants(0),
+ parent(NULL),
+ parentDistance(0.0),
+ furthestDescendantDistance(0.0),
+ localMetric(false),
+ localDataset(false),
+ metric(NULL)
+{
+ // Nothing to do.
+}
+
+/**
* Returns a string representation of this object.
*/
template<
@@ -1156,7 +1199,95 @@ std::string CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
return convert.str();
}
-}; // namespace tree
-}; // namespace mlpack
+/**
+ * Serialize to/from a boost::serialization archive.
+ */
+template<
+ typename MetricType,
+ typename StatisticType,
+ typename MatType,
+ typename RootPointPolicy
+>
+template<typename Archive>
+void CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::Serialize(
+ Archive& ar,
+ const unsigned int /* version */)
+{
+ using data::CreateNVP;
+
+ // If we're loading, and we have children, they need to be deleted. We may
+ // also need to delete the local metric and dataset.
+ if (Archive::is_loading::value)
+ {
+ for (size_t i = 0; i < children.size(); ++i)
+ delete children[i];
+
+ if (localMetric && metric)
+ delete metric;
+ if (localDataset && dataset)
+ delete dataset;
+ }
+
+ ar & CreateNVP(dataset, "dataset");
+ ar & CreateNVP(point, "point");
+ ar & CreateNVP(scale, "scale");
+ ar & CreateNVP(base, "base");
+ ar & CreateNVP(stat, "stat");
+ ar & CreateNVP(numDescendants, "numDescendants");
+ ar & CreateNVP(parent, "parent");
+ ar & CreateNVP(parentDistance, "parentDistance");
+ ar & CreateNVP(furthestDescendantDistance, "furthestDescendantDistance");
+ ar & CreateNVP(metric, "metric");
+
+ if (Archive::is_loading::value && parent == NULL)
+ {
+ localMetric = true;
+ localDataset = true;
+ }
+
+ // Lastly, serialize the children.
+ size_t numChildren = children.size();
+ ar & CreateNVP(numChildren, "numChildren");
+ if (Archive::is_loading::value)
+ children.resize(numChildren);
+ for (size_t i = 0; i < numChildren; ++i)
+ {
+ std::ostringstream oss;
+ oss << "child" << i;
+ ar & CreateNVP(children[i], oss.str());
+ }
+
+ // Due to quirks of boost::serialization, if a tree is saved as an object and
+ // not a pointer, the first level of the tree will be duplicated on load.
+ // Therefore, if we are the root of the tree, then we need to make sure our
+ // children's parent links are correct, and delete the duplicated node if
+ // necessary.
+ if (Archive::is_loading::value)
+ {
+ // Look through each child individually.
+ for (size_t i = 0; i < children.size(); ++i)
+ {
+ if (children[i]->Parent() != this)
+ {
+ // Disallow the duplicate parent from deleting anything. But only
+ // delete the parent if this is the first child (we are assuming that
+ // each of the other children has the same incorrect parent).
+ if (i == 0)
+ {
+ children[i]->Parent()->localMetric = false;
+ children[i]->Parent()->localDataset = false;
+ children[i]->Parent()->children.clear();
+ delete children[i]->Parent();
+ }
+
+ // Fix the child's parent link.
+ children[i]->Parent() = this;
+ }
+ }
+ }
+}
+
+} // namespace tree
+} // namespace mlpack
#endif
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index 8a92130..5a71632 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -21,6 +21,7 @@
#include <mlpack/core/tree/hrectbound.hpp>
#include <mlpack/core/metrics/mahalanobis_distance.hpp>
#include <mlpack/core/tree/binary_space_tree.hpp>
+#include <mlpack/core/tree/cover_tree.hpp>
#include <mlpack/methods/perceptron/perceptron.hpp>
#include <mlpack/methods/logistic_regression/logistic_regression.hpp>
@@ -713,15 +714,121 @@ BOOST_AUTO_TEST_CASE(BinarySpaceTreeOverwriteTest)
typedef KDTree<EuclideanDistance, EmptyStatistic, arma::mat> TreeType;
TreeType tree(data);
- TreeType xmlTree(tree);
- TreeType textTree(tree);
- TreeType binaryTree(tree);
+ arma::mat otherData;
+ otherData.randu(5, 50);
+ TreeType xmlTree(otherData);
+ TreeType textTree(xmlTree);
+ TreeType binaryTree(xmlTree);
SerializeObjectAll(tree, xmlTree, textTree, binaryTree);
CheckTrees(tree, xmlTree, textTree, binaryTree);
}
+BOOST_AUTO_TEST_CASE(CoverTreeTest)
+{
+ arma::mat data;
+ data.randu(3, 100);
+ typedef StandardCoverTree<EuclideanDistance, EmptyStatistic, arma::mat>
+ TreeType;
+ TreeType tree(data);
+
+ TreeType* xmlTree;
+ TreeType* textTree;
+ TreeType* binaryTree;
+
+ SerializePointerObjectAll(&tree, xmlTree, textTree, binaryTree);
+
+ CheckTrees(tree, *xmlTree, *textTree, *binaryTree);
+
+ // Also check the cover-tree-specific members (scale and base) of each node.
+ std::stack<TreeType*> stack, xmlStack, textStack, binaryStack;
+ stack.push(&tree);
+ xmlStack.push(xmlTree);
+ textStack.push(textTree);
+ binaryStack.push(binaryTree);
+ while (!stack.empty())
+ {
+ TreeType* node = stack.top();
+ TreeType* xmlNode = xmlStack.top();
+ TreeType* textNode = textStack.top();
+ TreeType* binaryNode = binaryStack.top();
+ stack.pop();
+ xmlStack.pop();
+ textStack.pop();
+ binaryStack.pop();
+
+ BOOST_REQUIRE_EQUAL(node->Scale(), xmlNode->Scale());
+ BOOST_REQUIRE_EQUAL(node->Scale(), textNode->Scale());
+ BOOST_REQUIRE_EQUAL(node->Scale(), binaryNode->Scale());
+
+ BOOST_REQUIRE_CLOSE(node->Base(), xmlNode->Base(), 1e-5);
+ BOOST_REQUIRE_CLOSE(node->Base(), textNode->Base(), 1e-5);
+ BOOST_REQUIRE_CLOSE(node->Base(), binaryNode->Base(), 1e-5);
+
+ for (size_t i = 0; i < node->NumChildren(); ++i)
+ {
+ stack.push(&node->Child(i));
+ xmlStack.push(&xmlNode->Child(i));
+ textStack.push(&textNode->Child(i));
+ binaryStack.push(&binaryNode->Child(i));
+ }
+ }
+}
+
+BOOST_AUTO_TEST_CASE(CoverTreeOverwriteTest)
+{
+ arma::mat data;
+ data.randu(3, 100);
+ typedef StandardCoverTree<EuclideanDistance, EmptyStatistic, arma::mat>
+ TreeType;
+ TreeType tree(data);
+
+ arma::mat otherData;
+ otherData.randu(5, 50);
+ TreeType xmlTree(otherData);
+ TreeType textTree(xmlTree);
+ TreeType binaryTree(xmlTree);
+
+ SerializeObjectAll(tree, xmlTree, textTree, binaryTree);
+
+ CheckTrees(tree, xmlTree, textTree, binaryTree);
+
+ // Also check the cover-tree-specific members (scale and base) of each node.
+ std::stack<TreeType*> stack, xmlStack, textStack, binaryStack;
+ stack.push(&tree);
+ xmlStack.push(&xmlTree);
+ textStack.push(&textTree);
+ binaryStack.push(&binaryTree);
+ while (!stack.empty())
+ {
+ TreeType* node = stack.top();
+ TreeType* xmlNode = xmlStack.top();
+ TreeType* textNode = textStack.top();
+ TreeType* binaryNode = binaryStack.top();
+ stack.pop();
+ xmlStack.pop();
+ textStack.pop();
+ binaryStack.pop();
+
+ BOOST_REQUIRE_EQUAL(node->Scale(), xmlNode->Scale());
+ BOOST_REQUIRE_EQUAL(node->Scale(), textNode->Scale());
+ BOOST_REQUIRE_EQUAL(node->Scale(), binaryNode->Scale());
+
+ BOOST_REQUIRE_CLOSE(node->Base(), xmlNode->Base(), 1e-5);
+ BOOST_REQUIRE_CLOSE(node->Base(), textNode->Base(), 1e-5);
+ BOOST_REQUIRE_CLOSE(node->Base(), binaryNode->Base(), 1e-5);
+
+ for (size_t i = 0; i < node->NumChildren(); ++i)
+ {
+ stack.push(&node->Child(i));
+ xmlStack.push(&xmlNode->Child(i));
+ textStack.push(&textNode->Child(i));
+ binaryStack.push(&binaryNode->Child(i));
+ }
+ }
+}
+
BOOST_AUTO_TEST_CASE(PerceptronTest)
{
// Create a perceptron. Train it randomly. Then check that it hasn't
More information about the mlpack-git mailing list