[mlpack-git] master: Serialization for RectangleTree. Not working---committed in order to work on another system. Also has debugging output. (168a49a)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Oct 14 05:02:48 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/81e72d4410ae417f7a8536bd3c61865e2f62c934...ce49a4b5f0b7d12d4955c09e45c69891a6f83e8a

>---------------------------------------------------------------

commit 168a49a3783db14a3d49865d71c68790bd929511
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Oct 12 15:12:15 2015 -0400

    Serialization for RectangleTree. Not working---committed in order to work on another system. Also has debugging output.


>---------------------------------------------------------------

168a49a3783db14a3d49865d71c68790bd929511
 .../core/tree/rectangle_tree/rectangle_tree.hpp    |  13 +++
 .../tree/rectangle_tree/rectangle_tree_impl.hpp    | 107 ++++++++++++++++++-
 src/mlpack/tests/serialization_test.cpp            | 117 +++++++++++++++++++++
 3 files changed, 236 insertions(+), 1 deletion(-)

diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
index 97c7468..e54ae1a 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
@@ -61,6 +61,13 @@ class RectangleTree
       for (int i = 0; i < dim; i++)
         history[i] = false;
     }
+    //! Save or load the lastDimension and history members via the archive.
+    template<typename Archive>
+    void Serialize(Archive& ar, const unsigned int /* version */)
+    {
+      ar & data::CreateNVP(lastDimension, "lastDimension");
+      ar & data::CreateNVP(history, "history");
+    }
   } SplitHistoryStruct;
 
  private:
@@ -576,6 +583,12 @@ class RectangleTree
    * Returns a string representation of this object.
    */
   std::string ToString() const;
+
+  /**
+   * Serialize the tree: save it to, or load it from, the given archive.
+   */
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */);
 };
 
 } // namespace tree
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
index d325ab4..4a44435 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
@@ -540,6 +540,8 @@ template<typename MetricType,
 inline size_t RectangleTree<MetricType, StatisticType, MatType, SplitType,
                             DescentType>::NumDescendants() const
 {
+  std::cout << "NumDescendants() [" << this << "], with " << numChildren 
+      << "children.\n";
   if (numChildren == 0)
   {
     return count;
@@ -548,8 +550,10 @@ inline size_t RectangleTree<MetricType, StatisticType, MatType, SplitType,
   {
     size_t n = 0;
     for (size_t i = 0; i < numChildren; i++)
+    {
+      std::cout << "child " << i << ": " << children[i] << ".\n";
       n += children[i]->NumDescendants();
-
+    }
     return n;
   }
 }
@@ -632,6 +636,7 @@ RectangleTree() :
     count(0),
     maxLeafSize(0),
     minLeafSize(0),
+    splitHistory(0),
     parentDistance(0.0),
     furthestDescendantDistance(0.0),
     dataset(NULL),
@@ -930,6 +935,106 @@ std::string RectangleTree<MetricType, StatisticType, MatType, SplitType,
   return convert.str();
 }
 
+/**
+ * Save or load this node and its children; a load first frees existing state.
+ */
+template<typename MetricType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType,
+         typename DescentType>
+template<typename Archive>
+void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
+    Serialize(Archive& ar,
+              const unsigned int /* version */)
+{
+  using data::CreateNVP;
+
+  // When loading, free what this node currently holds; it is reloaded below.
+  if (Archive::is_loading::value)
+  {
+    for (size_t i = 0; i < numChildren; i++)
+      delete children[i];
+    children.clear();
+
+    if (ownsDataset && dataset)
+      delete dataset;
+
+    if (localDataset)
+      delete localDataset;
+  }
+
+  ar & CreateNVP(maxNumChildren, "maxNumChildren");
+  ar & CreateNVP(minNumChildren, "minNumChildren");
+  ar & CreateNVP(numChildren, "numChildren");
+  ar & CreateNVP(parent, "parent"); // Pointer; links are repaired below.
+  ar & CreateNVP(begin, "begin");
+  ar & CreateNVP(count, "count");
+  ar & CreateNVP(maxLeafSize, "maxLeafSize");
+  ar & CreateNVP(minLeafSize, "minLeafSize");
+  ar & CreateNVP(bound, "bound");
+  ar & CreateNVP(stat, "stat");
+  ar & CreateNVP(splitHistory, "splitHistory");
+  ar & CreateNVP(parentDistance, "parentDistance");
+  ar & CreateNVP(furthestDescendantDistance, "furthestDescendantDistance");
+  ar & CreateNVP(dataset, "dataset");
+
+  // If we are loading and we are the root, we own the dataset.
+  if (Archive::is_loading::value && parent == NULL)
+    ownsDataset = true;
+
+  ar & CreateNVP(points, "points");
+  ar & CreateNVP(localDataset, "localDataset");
+
+  // 'children' holds mlpack types (that have Serialize()), so instead of the
+  // std::vector serialization each child is archived by hand as "child<i>".
+  if (Archive::is_loading::value)
+    children.resize(numChildren);
+  for (size_t i = 0; i < numChildren; ++i)
+  {
+    std::ostringstream oss;
+    oss << "child" << i;
+    ar & CreateNVP(children[i], oss.str());
+  }
+
+  // Due to quirks of boost::serialization, if a tree is saved as an object and
+  // not a pointer, the first level of the tree will be duplicated on load.
+  // Therefore each child's parent link is checked: a child that points at some
+  // other node is re-linked to this one, and that duplicated node is deleted.
+  // NOTE(review): the original intent was "if we are the root", but no root
+  // check is made here---this runs for every node that is loaded.
+  if (Archive::is_loading::value)
+  {
+    // Look through each child individually.
+    for (size_t i = 0; i < children.size(); ++i)
+    {
+      if (children[i]->Parent() != this)
+      {
+        // Disallow the duplicate parent from deleting anything.  But only
+        // delete the parent if this is the first child (we are assuming that
+        // each of the other children has the same incorrect parent).
+        if (i == 0)
+        {
+          children[i]->Parent()->ownsDataset = false;
+          children[i]->Parent()->children.clear();
+          delete children[i]->Parent();
+        }
+
+        // Fix the child's parent link.
+        children[i]->Parent() = this;
+      }
+    }
+  }
+
+  if (Archive::is_loading::value) // Temporary debugging output.
+  {
+    std::cout << "loaded node " << this << " with " << numChildren << " (" <<
+        children.size() << ") children\n";
+    for (size_t i = 0; i < numChildren; ++i)
+      std::cout << "child " << i << ": " << children[i] << ".\n";
+  }
+}
+
 } // namespace tree
 } // namespace mlpack
 
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index 5a71632..bc84a5d 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -22,6 +22,7 @@
 #include <mlpack/core/metrics/mahalanobis_distance.hpp>
 #include <mlpack/core/tree/binary_space_tree.hpp>
 #include <mlpack/core/tree/cover_tree.hpp>
+#include <mlpack/core/tree/rectangle_tree.hpp>
 
 #include <mlpack/methods/perceptron/perceptron.hpp>
 #include <mlpack/methods/logistic_regression/logistic_regression.hpp>
@@ -610,6 +611,8 @@ void CheckTrees(TreeType& tree,
                 TreeType& binaryTree)
 {
   const typename TreeType::Mat* dataset = &tree.Dataset();
+  std::cout << "check tree node " << tree.NumChildren() << " desc " <<
+tree.NumDescendants() << ".\n";
 
   // Make sure that the data matrices are the same.
   if (tree.Parent() == NULL)
@@ -631,6 +634,9 @@ void CheckTrees(TreeType& tree,
   BOOST_REQUIRE_EQUAL(tree.NumChildren(), binaryTree.NumChildren());
 
   // Make sure the number of descendants is the same.
+  std::cout << "xmltree numdesc\n";
+  const size_t numDesc = binaryTree.NumDescendants();
+  std::cout << "xmltree numdesc done.\n";
   BOOST_REQUIRE_EQUAL(tree.NumDescendants(), xmlTree.NumDescendants());
   BOOST_REQUIRE_EQUAL(tree.NumDescendants(), textTree.NumDescendants());
   BOOST_REQUIRE_EQUAL(tree.NumDescendants(), binaryTree.NumDescendants());
@@ -829,6 +835,117 @@ BOOST_AUTO_TEST_CASE(CoverTreeOverwriteTest)
   }
 }
 
+BOOST_AUTO_TEST_CASE(RectangleTreeTest)
+{
+  arma::mat data;
+  data.randu(3, 1000);
+  typedef RTree<EuclideanDistance, EmptyStatistic, arma::mat> TreeType;
+  TreeType tree(data);
+
+  // Filled in by SerializePointerObjectAll(); owned and freed by this test.
+  TreeType* xmlTree = NULL;
+  TreeType* textTree = NULL;
+  TreeType* binaryTree = NULL;
+  SerializePointerObjectAll(&tree, xmlTree, textTree, binaryTree);
+  CheckTrees(tree, *xmlTree, *textTree, *binaryTree);
+
+  // Check a few other things too (only the pushed root nodes are visited).
+  std::stack<TreeType*> stack, xmlStack, textStack, binaryStack;
+  stack.push(&tree);
+  xmlStack.push(xmlTree);
+  textStack.push(textTree);
+  binaryStack.push(binaryTree);
+  while (!stack.empty())
+  {
+    TreeType* node = stack.top();
+    TreeType* xmlNode = xmlStack.top();
+    TreeType* textNode = textStack.top();
+    TreeType* binaryNode = binaryStack.top();
+    stack.pop();
+    xmlStack.pop();
+    textStack.pop();
+    binaryStack.pop();
+
+    CheckMatrices(node->LocalDataset(), xmlNode->LocalDataset(),
+        textNode->LocalDataset(), binaryNode->LocalDataset());
+
+    BOOST_REQUIRE_EQUAL(node->MaxLeafSize(), xmlNode->MaxLeafSize());
+    BOOST_REQUIRE_EQUAL(node->MaxLeafSize(), textNode->MaxLeafSize());
+    BOOST_REQUIRE_EQUAL(node->MaxLeafSize(), binaryNode->MaxLeafSize());
+
+    BOOST_REQUIRE_EQUAL(node->MinLeafSize(), xmlNode->MinLeafSize());
+    BOOST_REQUIRE_EQUAL(node->MinLeafSize(), textNode->MinLeafSize());
+    BOOST_REQUIRE_EQUAL(node->MinLeafSize(), binaryNode->MinLeafSize());
+
+    BOOST_REQUIRE_EQUAL(node->MaxNumChildren(), xmlNode->MaxNumChildren());
+    BOOST_REQUIRE_EQUAL(node->MaxNumChildren(), textNode->MaxNumChildren());
+    BOOST_REQUIRE_EQUAL(node->MaxNumChildren(), binaryNode->MaxNumChildren());
+
+    BOOST_REQUIRE_EQUAL(node->MinNumChildren(), xmlNode->MinNumChildren());
+    BOOST_REQUIRE_EQUAL(node->MinNumChildren(), textNode->MinNumChildren());
+    BOOST_REQUIRE_EQUAL(node->MinNumChildren(), binaryNode->MinNumChildren());
+  }
+  delete xmlTree;
+  delete textTree;
+  delete binaryTree;
+}
+
+BOOST_AUTO_TEST_CASE(RectangleTreeOverwriteTest)
+{
+  typedef RTree<EuclideanDistance, EmptyStatistic, arma::mat> TreeType;
+  // The reference tree, whose serialized form is loaded into the others.
+  arma::mat refData;
+  refData.randu(3, 1000);
+  TreeType tree(refData);
+
+  // The targets start out as trees on different data (5 x 50).
+  arma::mat decoyData;
+  decoyData.randu(5, 50);
+  TreeType xmlTree(decoyData);
+  TreeType textTree(decoyData);
+  TreeType binaryTree(textTree);
+
+  SerializeObjectAll(tree, xmlTree, textTree, binaryTree);
+  CheckTrees(tree, xmlTree, textTree, binaryTree);
+
+  // Compare some further per-node settings of the nodes pushed below.
+  std::stack<TreeType*> pending, xmlPending, textPending, binaryPending;
+  pending.push(&tree);
+  xmlPending.push(&xmlTree);
+  textPending.push(&textTree);
+  binaryPending.push(&binaryTree);
+  while (!pending.empty())
+  {
+    TreeType* node = pending.top();
+    pending.pop();
+    TreeType* xmlNode = xmlPending.top();
+    xmlPending.pop();
+    TreeType* textNode = textPending.top();
+    textPending.pop();
+    TreeType* binaryNode = binaryPending.top();
+    binaryPending.pop();
+
+    CheckMatrices(node->LocalDataset(), xmlNode->LocalDataset(),
+        textNode->LocalDataset(), binaryNode->LocalDataset());
+
+    BOOST_REQUIRE_EQUAL(node->MaxLeafSize(), xmlNode->MaxLeafSize());
+    BOOST_REQUIRE_EQUAL(node->MaxLeafSize(), textNode->MaxLeafSize());
+    BOOST_REQUIRE_EQUAL(node->MaxLeafSize(), binaryNode->MaxLeafSize());
+
+    BOOST_REQUIRE_EQUAL(node->MinLeafSize(), xmlNode->MinLeafSize());
+    BOOST_REQUIRE_EQUAL(node->MinLeafSize(), textNode->MinLeafSize());
+    BOOST_REQUIRE_EQUAL(node->MinLeafSize(), binaryNode->MinLeafSize());
+
+    BOOST_REQUIRE_EQUAL(node->MaxNumChildren(), xmlNode->MaxNumChildren());
+    BOOST_REQUIRE_EQUAL(node->MaxNumChildren(), textNode->MaxNumChildren());
+    BOOST_REQUIRE_EQUAL(node->MaxNumChildren(), binaryNode->MaxNumChildren());
+
+    BOOST_REQUIRE_EQUAL(node->MinNumChildren(), xmlNode->MinNumChildren());
+    BOOST_REQUIRE_EQUAL(node->MinNumChildren(), textNode->MinNumChildren());
+    BOOST_REQUIRE_EQUAL(node->MinNumChildren(), binaryNode->MinNumChildren());
+  }
+}
+
 BOOST_AUTO_TEST_CASE(PerceptronTest)
 {
   // Create a perceptron.  Train it randomly.  Then check that it hasn't



More information about the mlpack-git mailing list