[mlpack-git] master: Serialization for RectangleTree. Not working---committed in order to work on another system. Also has debugging output. (168a49a)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Oct 14 05:02:48 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/81e72d4410ae417f7a8536bd3c61865e2f62c934...ce49a4b5f0b7d12d4955c09e45c69891a6f83e8a
>---------------------------------------------------------------
commit 168a49a3783db14a3d49865d71c68790bd929511
Author: Ryan Curtin <ryan at ratml.org>
Date: Mon Oct 12 15:12:15 2015 -0400
Serialization for RectangleTree. Not working---committed in order to work on another system. Also has debugging output.
>---------------------------------------------------------------
168a49a3783db14a3d49865d71c68790bd929511
.../core/tree/rectangle_tree/rectangle_tree.hpp | 13 +++
.../tree/rectangle_tree/rectangle_tree_impl.hpp | 107 ++++++++++++++++++-
src/mlpack/tests/serialization_test.cpp | 117 +++++++++++++++++++++
3 files changed, 236 insertions(+), 1 deletion(-)
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
index 97c7468..e54ae1a 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
@@ -61,6 +61,13 @@ class RectangleTree
for (int i = 0; i < dim; i++)
history[i] = false;
}
+
+  /**
+   * Save or load this split-history record: the last split dimension is
+   * written first, followed by the per-dimension history flags.
+   */
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */)
+  {
+    using data::CreateNVP;
+
+    ar & CreateNVP(lastDimension, "lastDimension");
+    ar & CreateNVP(history, "history");
+  }
} SplitHistoryStruct;
private:
@@ -576,6 +583,12 @@ class RectangleTree
* Returns a string representation of this object.
*/
std::string ToString() const;
+
+ /**
+ * Serialize the tree (save to or load from a boost::serialization archive).
+ * When loading, the children, owned dataset, and local dataset currently held
+ * by this node are freed and replaced with the archived contents.
+ *
+ * @param ar Archive to save to or load from.
+ */
+ template<typename Archive>
+ void Serialize(Archive& ar, const unsigned int /* version */);
};
} // namespace tree
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
index d325ab4..4a44435 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
@@ -540,6 +540,8 @@ template<typename MetricType,
inline size_t RectangleTree<MetricType, StatisticType, MatType, SplitType,
DescentType>::NumDescendants() const
{
+ std::cout << "NumDescendants() [" << this << "], with " << numChildren
+ << "children.\n";
if (numChildren == 0)
{
return count;
@@ -548,8 +550,10 @@ inline size_t RectangleTree<MetricType, StatisticType, MatType, SplitType,
{
size_t n = 0;
for (size_t i = 0; i < numChildren; i++)
+ {
+ std::cout << "child " << i << ": " << children[i] << ".\n";
n += children[i]->NumDescendants();
-
+ }
return n;
}
}
@@ -632,6 +636,7 @@ RectangleTree() :
count(0),
maxLeafSize(0),
minLeafSize(0),
+ splitHistory(0),
parentDistance(0.0),
furthestDescendantDistance(0.0),
dataset(NULL),
@@ -930,6 +935,106 @@ std::string RectangleTree<MetricType, StatisticType, MatType, SplitType,
return convert.str();
}
+/**
+ * Serialize the tree (save to or load from a boost::serialization archive).
+ *
+ * When loading, the children, owned dataset, and local dataset held by this
+ * node are freed first and then replaced by the archived contents.  Fields are
+ * processed in a fixed order; changing the order of the statements below
+ * breaks compatibility with existing archives.
+ *
+ * @param ar Archive to save to or load from.
+ */
+template<typename MetricType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType,
+         typename DescentType>
+template<typename Archive>
+void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
+    Serialize(Archive& ar,
+              const unsigned int /* version */)
+{
+  using data::CreateNVP;
+
+  // Clean up memory, if necessary.  Freed pointers and counts are reset so
+  // that nothing dangles if loading fails before they are overwritten.
+  if (Archive::is_loading::value)
+  {
+    for (size_t i = 0; i < numChildren; ++i)
+      delete children[i];
+    children.clear();
+    numChildren = 0;
+
+    if (ownsDataset && dataset)
+      delete dataset;
+    dataset = NULL;
+    ownsDataset = false;
+
+    delete localDataset; // Deleting NULL is a no-op.
+    localDataset = NULL;
+  }
+
+  ar & CreateNVP(maxNumChildren, "maxNumChildren");
+  ar & CreateNVP(minNumChildren, "minNumChildren");
+  ar & CreateNVP(numChildren, "numChildren");
+  ar & CreateNVP(parent, "parent");
+  ar & CreateNVP(begin, "begin");
+  ar & CreateNVP(count, "count");
+  ar & CreateNVP(maxLeafSize, "maxLeafSize");
+  ar & CreateNVP(minLeafSize, "minLeafSize");
+  ar & CreateNVP(bound, "bound");
+  ar & CreateNVP(stat, "stat");
+  ar & CreateNVP(splitHistory, "splitHistory");
+  ar & CreateNVP(parentDistance, "parentDistance");
+  ar & CreateNVP(furthestDescendantDistance, "furthestDescendantDistance");
+  ar & CreateNVP(dataset, "dataset");
+
+  // If we are loading, only the root owns the dataset.
+  if (Archive::is_loading::value)
+    ownsDataset = (parent == NULL);
+
+  ar & CreateNVP(points, "points");
+  ar & CreateNVP(localDataset, "localDataset");
+
+  // Because 'children' holds mlpack types (that have Serialize()), we can't use
+  // the std::vector serialization.
+  if (Archive::is_loading::value)
+    children.resize(numChildren);
+  for (size_t i = 0; i < numChildren; ++i)
+  {
+    std::ostringstream oss;
+    oss << "child" << i;
+    ar & CreateNVP(children[i], oss.str());
+  }
+
+  // Due to quirks of boost::serialization, if a tree is saved as an object and
+  // not a pointer, the first level of the tree will be duplicated on load: the
+  // children's parent links point at a separately allocated copy of this node.
+  // Therefore, when loading, make every child's parent link point here, and
+  // free the duplicate (if any) exactly once.
+  if (Archive::is_loading::value)
+  {
+    RectangleTree* duplicate = NULL;
+    for (size_t i = 0; i < numChildren; ++i)
+    {
+      RectangleTree* childParent = children[i]->Parent();
+      if (childParent == this)
+        continue;
+
+      // We assume that every mis-linked child has the same incorrect parent.
+      if (duplicate == NULL)
+        duplicate = childParent;
+
+      // Fix the child's parent link.
+      children[i]->Parent() = this;
+    }
+
+    if (duplicate != NULL)
+    {
+      // Disallow the duplicate from deleting anything: its children and its
+      // dataset are the ones this node now refers to.  numChildren is zeroed
+      // along with the vector because cleanup loops index children up to
+      // numChildren.
+      // NOTE(review): localDataset is assumed to be shared with this node via
+      // boost pointer tracking, like the dataset is; confirm.
+      duplicate->ownsDataset = false;
+      duplicate->children.clear();
+      duplicate->numChildren = 0;
+      duplicate->localDataset = NULL;
+      delete duplicate;
+    }
+  }
+}
+
} // namespace tree
} // namespace mlpack
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index 5a71632..bc84a5d 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -22,6 +22,7 @@
#include <mlpack/core/metrics/mahalanobis_distance.hpp>
#include <mlpack/core/tree/binary_space_tree.hpp>
#include <mlpack/core/tree/cover_tree.hpp>
+#include <mlpack/core/tree/rectangle_tree.hpp>
#include <mlpack/methods/perceptron/perceptron.hpp>
#include <mlpack/methods/logistic_regression/logistic_regression.hpp>
@@ -610,6 +611,8 @@ void CheckTrees(TreeType& tree,
TreeType& binaryTree)
{
const typename TreeType::Mat* dataset = &tree.Dataset();
+ std::cout << "check tree node " << tree.NumChildren() << " desc " <<
+tree.NumDescendants() << ".\n";
// Make sure that the data matrices are the same.
if (tree.Parent() == NULL)
@@ -631,6 +634,9 @@ void CheckTrees(TreeType& tree,
BOOST_REQUIRE_EQUAL(tree.NumChildren(), binaryTree.NumChildren());
// Make sure the number of descendants is the same.
+ std::cout << "xmltree numdesc\n";
+ const size_t numDesc = binaryTree.NumDescendants();
+ std::cout << "xmltree numdesc done.\n";
BOOST_REQUIRE_EQUAL(tree.NumDescendants(), xmlTree.NumDescendants());
BOOST_REQUIRE_EQUAL(tree.NumDescendants(), textTree.NumDescendants());
BOOST_REQUIRE_EQUAL(tree.NumDescendants(), binaryTree.NumDescendants());
@@ -829,6 +835,117 @@ BOOST_AUTO_TEST_CASE(CoverTreeOverwriteTest)
}
}
+/**
+ * Walk four RectangleTrees in lockstep (depth-first, via explicit stacks) and
+ * require that every pair of corresponding nodes has the same local dataset,
+ * leaf-size limits, child-count limits, and number of children.
+ */
+template<typename TreeType>
+void CheckRectangleTreeNodes(TreeType& tree,
+                             TreeType& xmlTree,
+                             TreeType& textTree,
+                             TreeType& binaryTree)
+{
+  std::stack<TreeType*> stack, xmlStack, textStack, binaryStack;
+  stack.push(&tree);
+  xmlStack.push(&xmlTree);
+  textStack.push(&textTree);
+  binaryStack.push(&binaryTree);
+  while (!stack.empty())
+  {
+    TreeType* node = stack.top();
+    TreeType* xmlNode = xmlStack.top();
+    TreeType* textNode = textStack.top();
+    TreeType* binaryNode = binaryStack.top();
+    stack.pop();
+    xmlStack.pop();
+    textStack.pop();
+    binaryStack.pop();
+
+    CheckMatrices(node->LocalDataset(), xmlNode->LocalDataset(),
+        textNode->LocalDataset(), binaryNode->LocalDataset());
+
+    BOOST_REQUIRE_EQUAL(node->MaxLeafSize(), xmlNode->MaxLeafSize());
+    BOOST_REQUIRE_EQUAL(node->MaxLeafSize(), textNode->MaxLeafSize());
+    BOOST_REQUIRE_EQUAL(node->MaxLeafSize(), binaryNode->MaxLeafSize());
+
+    BOOST_REQUIRE_EQUAL(node->MinLeafSize(), xmlNode->MinLeafSize());
+    BOOST_REQUIRE_EQUAL(node->MinLeafSize(), textNode->MinLeafSize());
+    BOOST_REQUIRE_EQUAL(node->MinLeafSize(), binaryNode->MinLeafSize());
+
+    BOOST_REQUIRE_EQUAL(node->MaxNumChildren(), xmlNode->MaxNumChildren());
+    BOOST_REQUIRE_EQUAL(node->MaxNumChildren(), textNode->MaxNumChildren());
+    BOOST_REQUIRE_EQUAL(node->MaxNumChildren(), binaryNode->MaxNumChildren());
+
+    BOOST_REQUIRE_EQUAL(node->MinNumChildren(), xmlNode->MinNumChildren());
+    BOOST_REQUIRE_EQUAL(node->MinNumChildren(), textNode->MinNumChildren());
+    BOOST_REQUIRE_EQUAL(node->MinNumChildren(), binaryNode->MinNumChildren());
+
+    // Compare child counts before descending, so that a mismatch fails the
+    // test instead of causing an out-of-range Child() access below.
+    BOOST_REQUIRE_EQUAL(node->NumChildren(), xmlNode->NumChildren());
+    BOOST_REQUIRE_EQUAL(node->NumChildren(), textNode->NumChildren());
+    BOOST_REQUIRE_EQUAL(node->NumChildren(), binaryNode->NumChildren());
+
+    // Without this, only the root would ever be checked.
+    for (size_t i = 0; i < node->NumChildren(); ++i)
+    {
+      stack.push(&node->Child(i));
+      xmlStack.push(&xmlNode->Child(i));
+      textStack.push(&textNode->Child(i));
+      binaryStack.push(&binaryNode->Child(i));
+    }
+  }
+}
+
+BOOST_AUTO_TEST_CASE(RectangleTreeTest)
+{
+  arma::mat data;
+  data.randu(3, 1000);
+  typedef RTree<EuclideanDistance, EmptyStatistic, arma::mat> TreeType;
+  TreeType tree(data);
+
+  TreeType* xmlTree;
+  TreeType* textTree;
+  TreeType* binaryTree;
+
+  SerializePointerObjectAll(&tree, xmlTree, textTree, binaryTree);
+
+  CheckTrees(tree, *xmlTree, *textTree, *binaryTree);
+
+  // Check a few other things too, at every node.
+  CheckRectangleTreeNodes(tree, *xmlTree, *textTree, *binaryTree);
+
+  // The trees loaded through pointers are heap-allocated; free them.
+  delete xmlTree;
+  delete textTree;
+  delete binaryTree;
+}
+
+BOOST_AUTO_TEST_CASE(RectangleTreeOverwriteTest)
+{
+  arma::mat data;
+  data.randu(3, 1000);
+  typedef RTree<EuclideanDistance, EmptyStatistic, arma::mat> TreeType;
+  TreeType tree(data);
+
+  // Build trees on different data so that loading must overwrite them.
+  arma::mat otherData;
+  otherData.randu(5, 50);
+  TreeType xmlTree(otherData);
+  TreeType textTree(otherData);
+  TreeType binaryTree(textTree);
+
+  SerializeObjectAll(tree, xmlTree, textTree, binaryTree);
+
+  CheckTrees(tree, xmlTree, textTree, binaryTree);
+
+  // Check a few other things too, at every node.
+  CheckRectangleTreeNodes(tree, xmlTree, textTree, binaryTree);
+}
+
BOOST_AUTO_TEST_CASE(PerceptronTest)
{
// Create a perceptron. Train it randomly. Then check that it hasn't
More information about the mlpack-git
mailing list