[mlpack-git] master: Add serialization to DTree. (0e910b1)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Oct 2 19:20:47 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/7a8b0e1292677b71888fad313772c63bcf0e7b80...de88672879a1893ebfc131538c64e7755251337c

>---------------------------------------------------------------

commit 0e910b151b0f28f5622d4b2be77e0a4fcd88e190
Author: Ryan Curtin <ryan at ratml.org>
Date:   Fri Oct 2 21:39:01 2015 +0000

    Add serialization to DTree.


>---------------------------------------------------------------

0e910b151b0f28f5622d4b2be77e0a4fcd88e190
 src/mlpack/methods/det/dtree.hpp        |  39 +++++-
 src/mlpack/tests/serialization_test.cpp | 214 ++++++++++++++++++++++++++++++++
 2 files changed, 251 insertions(+), 2 deletions(-)

diff --git a/src/mlpack/methods/det/dtree.hpp b/src/mlpack/methods/det/dtree.hpp
index be04c5b..429d971 100644
--- a/src/mlpack/methods/det/dtree.hpp
+++ b/src/mlpack/methods/det/dtree.hpp
@@ -283,6 +283,41 @@ class DTree
    */
   std::string ToString() const;
 
+  /**
+   * Serialize this node and its left and right children to or from ar.
+   */
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */)
+  {
+    using data::CreateNVP;
+
+    ar & CreateNVP(start, "start");
+    ar & CreateNVP(end, "end");
+    ar & CreateNVP(maxVals, "maxVals");
+    ar & CreateNVP(minVals, "minVals");
+    ar & CreateNVP(splitDim, "splitDim");
+    ar & CreateNVP(splitValue, "splitValue");
+    ar & CreateNVP(logNegError, "logNegError");
+    ar & CreateNVP(subtreeLeavesLogNegError, "subtreeLeavesLogNegError");
+    ar & CreateNVP(subtreeLeaves, "subtreeLeaves");
+    ar & CreateNVP(root, "root");
+    ar & CreateNVP(ratio, "ratio");
+    ar & CreateNVP(logVolume, "logVolume");
+    ar & CreateNVP(bucketTag, "bucketTag");
+    ar & CreateNVP(alphaUpper, "alphaUpper");
+
+    if (Archive::is_loading::value)
+    {
+      // Drop old children and null them so a failed load cannot dangle.
+      delete left;
+      delete right;
+      left = right = NULL;
+    }
+
+    ar & CreateNVP(left, "left");
+    ar & CreateNVP(right, "right");
+  }
+
  private:
 
   // Utility methods.
@@ -307,7 +342,7 @@ class DTree
 
 };
 
-}; // namespace det
-}; // namespace mlpack
+} // namespace det
+} // namespace mlpack
 
 #endif // __MLPACK_METHODS_DET_DTREE_HPP
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index 773b70e..424126c 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -25,6 +25,7 @@
 #include <mlpack/methods/perceptron/perceptron.hpp>
 #include <mlpack/methods/logistic_regression/logistic_regression.hpp>
 #include <mlpack/methods/neighbor_search/neighbor_search.hpp>
+#include <mlpack/methods/det/dtree.hpp>
 
 using namespace mlpack;
 using namespace mlpack::distribution;
@@ -799,4 +800,217 @@ BOOST_AUTO_TEST_CASE(AllkNNTest)
   CheckMatrices(neighbors, xmlNeighbors, textNeighbors, binaryNeighbors);
 }
 
+BOOST_AUTO_TEST_CASE(DETTest)
+{
+  using det::DTree;
+
+  // Build the reference density estimation tree on a random 25 x 5000 matrix.
+  arma::mat dataset = arma::randu<arma::mat>(25, 5000);
+
+  DTree tree(dataset);
+
+  arma::mat otherDataset = arma::randu<arma::mat>(5, 100);
+  DTree xmlTree, binaryTree, textTree(otherDataset);
+
+  SerializeObjectAll(tree, xmlTree, binaryTree, textTree);
+
+  std::stack<DTree*> stack, xmlStack, binaryStack, textStack;
+  stack.push(&tree);
+  xmlStack.push(&xmlTree);
+  binaryStack.push(&binaryTree);
+  textStack.push(&textTree);
+
+  while (!stack.empty())
+  {
+    // Take the corresponding node from the top of each of the four stacks.
+    DTree* node = stack.top();
+    DTree* xmlNode = xmlStack.top();
+    DTree* binaryNode = binaryStack.top();
+    DTree* textNode = textStack.top();
+
+    stack.pop();
+    xmlStack.pop();
+    binaryStack.pop();
+    textStack.pop();
+
+    // Every member of the original node must match in all three loaded copies.
+    BOOST_REQUIRE_EQUAL(node->Start(), xmlNode->Start());
+    BOOST_REQUIRE_EQUAL(node->Start(), binaryNode->Start());
+    BOOST_REQUIRE_EQUAL(node->Start(), textNode->Start());
+
+    BOOST_REQUIRE_EQUAL(node->End(), xmlNode->End());
+    BOOST_REQUIRE_EQUAL(node->End(), binaryNode->End());
+    BOOST_REQUIRE_EQUAL(node->End(), textNode->End());
+
+    BOOST_REQUIRE_EQUAL(node->SplitDim(), xmlNode->SplitDim());
+    BOOST_REQUIRE_EQUAL(node->SplitDim(), binaryNode->SplitDim());
+    BOOST_REQUIRE_EQUAL(node->SplitDim(), textNode->SplitDim());
+
+    if (std::abs(node->SplitValue()) < 1e-5)
+    {
+      BOOST_REQUIRE_SMALL(xmlNode->SplitValue(), 1e-5);
+      BOOST_REQUIRE_SMALL(binaryNode->SplitValue(), 1e-5);
+      BOOST_REQUIRE_SMALL(textNode->SplitValue(), 1e-5);
+    }
+    else
+    {
+      BOOST_REQUIRE_CLOSE(node->SplitValue(), xmlNode->SplitValue(), 1e-5);
+      BOOST_REQUIRE_CLOSE(node->SplitValue(), binaryNode->SplitValue(), 1e-5);
+      BOOST_REQUIRE_CLOSE(node->SplitValue(), textNode->SplitValue(), 1e-5);
+    }
+
+    if (std::abs(node->LogNegError()) < 1e-5)
+    {
+      BOOST_REQUIRE_SMALL(xmlNode->LogNegError(), 1e-5);
+      BOOST_REQUIRE_SMALL(binaryNode->LogNegError(), 1e-5);
+      BOOST_REQUIRE_SMALL(textNode->LogNegError(), 1e-5);
+    }
+    else
+    {
+      BOOST_REQUIRE_CLOSE(node->LogNegError(), xmlNode->LogNegError(), 1e-5);
+      BOOST_REQUIRE_CLOSE(node->LogNegError(), binaryNode->LogNegError(), 1e-5);
+      BOOST_REQUIRE_CLOSE(node->LogNegError(), textNode->LogNegError(), 1e-5);
+    }
+
+    if (std::abs(node->SubtreeLeavesLogNegError()) < 1e-5)
+    {
+      BOOST_REQUIRE_SMALL(xmlNode->SubtreeLeavesLogNegError(), 1e-5);
+      BOOST_REQUIRE_SMALL(binaryNode->SubtreeLeavesLogNegError(), 1e-5);
+      BOOST_REQUIRE_SMALL(textNode->SubtreeLeavesLogNegError(), 1e-5);
+    }
+    else
+    {
+      BOOST_REQUIRE_CLOSE(node->SubtreeLeavesLogNegError(),
+          xmlNode->SubtreeLeavesLogNegError(), 1e-5);
+      BOOST_REQUIRE_CLOSE(node->SubtreeLeavesLogNegError(),
+          binaryNode->SubtreeLeavesLogNegError(), 1e-5);
+      BOOST_REQUIRE_CLOSE(node->SubtreeLeavesLogNegError(),
+          textNode->SubtreeLeavesLogNegError(), 1e-5);
+    }
+
+    BOOST_REQUIRE_EQUAL(node->SubtreeLeaves(), xmlNode->SubtreeLeaves());
+    BOOST_REQUIRE_EQUAL(node->SubtreeLeaves(), binaryNode->SubtreeLeaves());
+    BOOST_REQUIRE_EQUAL(node->SubtreeLeaves(), textNode->SubtreeLeaves());
+
+    if (std::abs(node->Ratio()) < 1e-5)
+    {
+      BOOST_REQUIRE_SMALL(xmlNode->Ratio(), 1e-5);
+      BOOST_REQUIRE_SMALL(binaryNode->Ratio(), 1e-5);
+      BOOST_REQUIRE_SMALL(textNode->Ratio(), 1e-5);
+    }
+    else
+    {
+      BOOST_REQUIRE_CLOSE(node->Ratio(), xmlNode->Ratio(), 1e-5);
+      BOOST_REQUIRE_CLOSE(node->Ratio(), binaryNode->Ratio(), 1e-5);
+      BOOST_REQUIRE_CLOSE(node->Ratio(), textNode->Ratio(), 1e-5);
+    }
+
+    if (std::abs(node->LogVolume()) < 1e-5)
+    {
+      BOOST_REQUIRE_SMALL(xmlNode->LogVolume(), 1e-5);
+      BOOST_REQUIRE_SMALL(binaryNode->LogVolume(), 1e-5);
+      BOOST_REQUIRE_SMALL(textNode->LogVolume(), 1e-5);
+    }
+    else
+    {
+      BOOST_REQUIRE_CLOSE(node->LogVolume(), xmlNode->LogVolume(), 1e-5);
+      BOOST_REQUIRE_CLOSE(node->LogVolume(), binaryNode->LogVolume(), 1e-5);
+      BOOST_REQUIRE_CLOSE(node->LogVolume(), textNode->LogVolume(), 1e-5);
+    }
+
+    if (node->Left() == NULL)
+    {
+      BOOST_REQUIRE(xmlNode->Left() == NULL);
+      BOOST_REQUIRE(binaryNode->Left() == NULL);
+      BOOST_REQUIRE(textNode->Left() == NULL);
+    }
+    else
+    {
+      BOOST_REQUIRE(xmlNode->Left() != NULL);
+      BOOST_REQUIRE(binaryNode->Left() != NULL);
+      BOOST_REQUIRE(textNode->Left() != NULL);
+
+      // Push all four left children so the stacks stay in lockstep.
+      stack.push(node->Left());
+      xmlStack.push(xmlNode->Left());
+      binaryStack.push(binaryNode->Left());
+      textStack.push(textNode->Left());
+    }
+
+    if (node->Right() == NULL)
+    {
+      BOOST_REQUIRE(xmlNode->Right() == NULL);
+      BOOST_REQUIRE(binaryNode->Right() == NULL);
+      BOOST_REQUIRE(textNode->Right() == NULL);
+    }
+    else
+    {
+      BOOST_REQUIRE(xmlNode->Right() != NULL);
+      BOOST_REQUIRE(binaryNode->Right() != NULL);
+      BOOST_REQUIRE(textNode->Right() != NULL);
+
+      // Push all four right children so the stacks stay in lockstep.
+      stack.push(node->Right());
+      xmlStack.push(xmlNode->Right());
+      binaryStack.push(binaryNode->Right());
+      textStack.push(textNode->Right());
+    }
+
+    BOOST_REQUIRE_EQUAL(node->Root(), xmlNode->Root());
+    BOOST_REQUIRE_EQUAL(node->Root(), binaryNode->Root());
+    BOOST_REQUIRE_EQUAL(node->Root(), textNode->Root());
+
+    if (std::abs(node->AlphaUpper()) < 1e-5)
+    {
+      BOOST_REQUIRE_SMALL(xmlNode->AlphaUpper(), 1e-5);
+      BOOST_REQUIRE_SMALL(binaryNode->AlphaUpper(), 1e-5);
+      BOOST_REQUIRE_SMALL(textNode->AlphaUpper(), 1e-5);
+    }
+    else
+    {
+      BOOST_REQUIRE_CLOSE(node->AlphaUpper(), xmlNode->AlphaUpper(), 1e-5);
+      BOOST_REQUIRE_CLOSE(node->AlphaUpper(), binaryNode->AlphaUpper(), 1e-5);
+      BOOST_REQUIRE_CLOSE(node->AlphaUpper(), textNode->AlphaUpper(), 1e-5);
+    }
+
+    BOOST_REQUIRE_EQUAL(node->MaxVals().n_elem, xmlNode->MaxVals().n_elem);
+    BOOST_REQUIRE_EQUAL(node->MaxVals().n_elem, binaryNode->MaxVals().n_elem);
+    BOOST_REQUIRE_EQUAL(node->MaxVals().n_elem, textNode->MaxVals().n_elem);
+    for (size_t i = 0; i < node->MaxVals().n_elem; ++i)
+    {
+      if (std::abs(node->MaxVals()[i]) < 1e-5)
+      {
+        BOOST_REQUIRE_SMALL(xmlNode->MaxVals()[i], 1e-5);
+        BOOST_REQUIRE_SMALL(binaryNode->MaxVals()[i], 1e-5);
+        BOOST_REQUIRE_SMALL(textNode->MaxVals()[i], 1e-5);
+      }
+      else
+      {
+        BOOST_REQUIRE_CLOSE(node->MaxVals()[i], xmlNode->MaxVals()[i], 1e-5);
+        BOOST_REQUIRE_CLOSE(node->MaxVals()[i], binaryNode->MaxVals()[i], 1e-5);
+        BOOST_REQUIRE_CLOSE(node->MaxVals()[i], textNode->MaxVals()[i], 1e-5);
+      }
+    }
+
+    BOOST_REQUIRE_EQUAL(node->MinVals().n_elem, xmlNode->MinVals().n_elem);
+    BOOST_REQUIRE_EQUAL(node->MinVals().n_elem, binaryNode->MinVals().n_elem);
+    BOOST_REQUIRE_EQUAL(node->MinVals().n_elem, textNode->MinVals().n_elem);
+    for (size_t i = 0; i < node->MinVals().n_elem; ++i)
+    {
+      if (std::abs(node->MinVals()[i]) < 1e-5)
+      {
+        BOOST_REQUIRE_SMALL(xmlNode->MinVals()[i], 1e-5);
+        BOOST_REQUIRE_SMALL(binaryNode->MinVals()[i], 1e-5);
+        BOOST_REQUIRE_SMALL(textNode->MinVals()[i], 1e-5);
+      }
+      else
+      {
+        BOOST_REQUIRE_CLOSE(node->MinVals()[i], xmlNode->MinVals()[i], 1e-5);
+        BOOST_REQUIRE_CLOSE(node->MinVals()[i], binaryNode->MinVals()[i], 1e-5);
+        BOOST_REQUIRE_CLOSE(node->MinVals()[i], textNode->MinVals()[i], 1e-5);
+      }
+    }
+  }
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list