[mlpack-git] master: Refactor to hold dataset internally (for serialization). (6e9a746)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Sun Oct 11 16:26:16 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/7e6fe3f1c25445dfd41c747839aa1fbf1c31a776...6e9a7465d7739e05e6b4aa650e1f87c45e9cd656

>---------------------------------------------------------------

commit 6e9a7465d7739e05e6b4aa650e1f87c45e9cd656
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sun Oct 11 14:44:07 2015 -0400

    Refactor to hold dataset internally (for serialization).


>---------------------------------------------------------------

6e9a7465d7739e05e6b4aa650e1f87c45e9cd656
 .../core/tree/rectangle_tree/rectangle_tree.hpp    | 32 +++++++--
 .../tree/rectangle_tree/rectangle_tree_impl.hpp    | 76 ++++++++++++++++++----
 src/mlpack/tests/rectangle_tree_test.cpp           |  6 +-
 3 files changed, 95 insertions(+), 19 deletions(-)

diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
index ce69dea..97c7468 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
@@ -24,8 +24,7 @@ namespace tree /** Trees and tree-building procedures. */ {
  * the constructor with the dataset to build the tree on, and the entire tree
  * will be built.
  *
- * This tree does allow growth, so you can add and delete nodes
- * from it.
+ * This tree does allow growth, so you can add and delete nodes from it.
  *
  * @tparam MetricType This *must* be EuclideanDistance, but the template
  *     parameter is required to satisfy the TreeType API.
@@ -98,7 +97,10 @@ class RectangleTree
   //! The discance to the furthest descendant, cached to speed things up.
   double furthestDescendantDistance;
   //! The dataset.
-  const MatType& dataset;
+  const MatType* dataset;
+  //! Whether or not we are responsible for deleting the dataset.  This is
+  //! probably not aligned well...
+  bool ownsDataset;
   //! The mapping to the dataset
   std::vector<size_t> points;
   //! The local dataset
@@ -156,6 +158,14 @@ class RectangleTree
   RectangleTree(const RectangleTree& other, const bool deepCopy = true);
 
   /**
+   * Construct the tree from a boost::serialization archive.
+   */
+  template<typename Archive>
+  RectangleTree(
+      Archive& ar,
+      const typename boost::enable_if<typename Archive::is_loading>::type* = 0);
+
+  /**
    * Deletes this node, deallocating the memory for the children and calling
    * their destructors in turn.  This will invalidate any younters or references
    * to any nodes which are children of this one.
@@ -307,9 +317,9 @@ class RectangleTree
   RectangleTree*& Parent() { return parent; }
 
   //! Get the dataset which the tree is built on.
-  const MatType& Dataset() const { return dataset; }
+  const MatType& Dataset() const { return *dataset; }
   //! Modify the dataset which the tree is built on.  Be careful!
-  MatType& Dataset() { return const_cast<MatType&>(dataset); }
+  MatType& Dataset() { return const_cast<MatType&>(*dataset); }
 
   //! Get the points vector for this node.
   const std::vector<size_t>& Points() const { return points; }
@@ -511,6 +521,18 @@ class RectangleTree
    */
   void SplitNode(std::vector<bool>& relevels);
 
+ protected:
+  /**
+   * A default constructor.  This is meant to only be used with
+   * boost::serialization, which is allowed with the friend declaration below.
+   * This does not return a valid tree!  This method must be protected, so that
+   * the serialization shim can work with the default constructor.
+   */
+  RectangleTree();
+
+  //! Friend access is given for the default constructor.
+  friend class boost::serialization::access;
+
  public:
   /**
    * Condense the bounding rectangles for this node based on the removal of the
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
index c733817..d325ab4 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
@@ -41,7 +41,8 @@ RectangleTree(const MatType& data,
     bound(data.n_rows),
     splitHistory(bound.Dim()),
     parentDistance(0),
-    dataset(data),
+    dataset(new MatType(data)),
+    ownsDataset(true),
     points(maxLeafSize + 1), // Add one to make splitting the node simpler.
     localDataset(new MatType(data.n_rows, static_cast<int> (maxLeafSize) + 1))
 {
@@ -75,7 +76,8 @@ RectangleTree(
     bound(parentNode->Bound().Dim()),
     splitHistory(bound.Dim()),
     parentDistance(0),
-    dataset(parentNode->Dataset()),
+    dataset(&parentNode->Dataset()),
+    ownsDataset(false),
     points(maxLeafSize + 1), // Add one to make splitting the node simpler.
     localDataset(new MatType(static_cast<int> (parentNode->Bound().Dim()),
                              static_cast<int> (maxLeafSize) + 1))
@@ -108,7 +110,8 @@ RectangleTree(
     bound(other.bound),
     splitHistory(other.SplitHistory()),
     parentDistance(other.ParentDistance()),
-    dataset(other.dataset),
+    dataset(new MatType(*other.dataset)),
+    ownsDataset(true),
     points(other.Points()),
     localDataset(NULL)
 {
@@ -135,6 +138,25 @@ RectangleTree(
 }
 
 /**
+ * Construct the tree from a boost::serialization archive.
+ */
+template<typename MetricType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType,
+         typename DescentType>
+template<typename Archive>
+RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
+RectangleTree(
+    Archive& ar,
+    const typename boost::enable_if<typename Archive::is_loading>::type*) :
+    RectangleTree() // Use default constructor.
+{
+  // Now serialize.
+  ar >> data::CreateNVP(*this, "tree");
+}
+
+/**
  * Deletes this node, deallocating the memory for the children and calling
  * their destructors in turn.  This will invalidate any pointers or references
  * to any nodes which are children of this one.
@@ -150,6 +172,9 @@ RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
   for (size_t i = 0; i < numChildren; i++)
     delete children[i];
 
+  if (ownsDataset)
+    delete dataset;
+
   delete localDataset;
 }
 
@@ -201,7 +226,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
     InsertPoint(const size_t point)
 {
   // Expand the bound regardless of whether it is a leaf node.
-  bound |= dataset.col(point);
+  bound |= dataset->col(point);
 
   std::vector<bool> lvls(TreeDepth());
   for (size_t i = 0; i < lvls.size(); i++)
@@ -210,7 +235,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
   // If this is a leaf node, we stop here and add the point.
   if (numChildren == 0)
   {
-    localDataset->col(count) = dataset.col(point);
+    localDataset->col(count) = dataset->col(point);
     points[count++] = point;
     SplitNode(lvls);
     return;
@@ -219,7 +244,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
   // If it is not a leaf node, we use the DescentHeuristic to choose a child
   // to which we recurse.
   const size_t descentNode = DescentType::ChooseDescentNode(this,
-      dataset.col(point));
+      dataset->col(point));
   children[descentNode]->InsertPoint(point, lvls);
 }
 
@@ -238,12 +263,12 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
     InsertPoint(const size_t point, std::vector<bool>& relevels)
 {
   // Expand the bound regardless of whether it is a leaf node.
-  bound |= dataset.col(point);
+  bound |= dataset->col(point);
 
   // If this is a leaf node, we stop here and add the point.
   if (numChildren == 0)
   {
-    localDataset->col(count) = dataset.col(point);
+    localDataset->col(count) = dataset->col(point);
     points[count++] = point;
     SplitNode(relevels);
     return;
@@ -252,7 +277,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
   // If it is not a leaf node, we use the DescentHeuristic to choose a child
   // to which we recurse.
   const size_t descentNode = DescentType::ChooseDescentNode(this,
-      dataset.col(point));
+      dataset->col(point));
   children[descentNode]->InsertPoint(point, relevels);
 }
 
@@ -320,14 +345,14 @@ bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
         localDataset->col(i) = localDataset->col(--count); // Decrement count.
         points[i] = points[count];
         // This function wil ensure that minFill is satisfied.
-        CondenseTree(dataset.col(point), lvls, true);
+        CondenseTree(dataset->col(point), lvls, true);
         return true;
       }
     }
   }
 
   for (size_t i = 0; i < numChildren; i++)
-    if (children[i]->Bound().Contains(dataset.col(point)))
+    if (children[i]->Bound().Contains(dataset->col(point)))
       if (children[i]->DeletePoint(point, lvls))
         return true;
 
@@ -355,14 +380,14 @@ bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
         localDataset->col(i) = localDataset->col(--count);
         points[i] = points[count];
         // This function will ensure that minFill is satisfied.
-        CondenseTree(dataset.col(point), relevels, true);
+        CondenseTree(dataset->col(point), relevels, true);
         return true;
       }
     }
   }
 
   for (size_t i = 0; i < numChildren; i++)
-    if (children[i]->Bound().Contains(dataset.col(point)))
+    if (children[i]->Bound().Contains(dataset->col(point)))
       if (children[i]->DeletePoint(point, relevels))
         return true;
 
@@ -591,6 +616,31 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
   }
 }
 
+//! Default constructor for boost::serialization.
+template<typename MetricType,
+         typename StatisticType,
+         typename MatType,
+         typename SplitType,
+         typename DescentType>
+RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
+RectangleTree() :
+    maxNumChildren(0), // Try to give sensible defaults, but it shouldn't matter
+    minNumChildren(0), // because this tree isn't valid anyway and is only used
+    numChildren(0),    // by boost::serialization.
+    parent(NULL),
+    begin(0),
+    count(0),
+    maxLeafSize(0),
+    minLeafSize(0),
+    parentDistance(0.0),
+    furthestDescendantDistance(0.0),
+    dataset(NULL),
+    ownsDataset(false),
+    localDataset(NULL)
+{
+  // Nothing to do.
+}
+
 /**
  * Condense the tree.  This shrinks the bounds and moves up the tree if
  * applicable.  If a node goes below minimum fill, this code will deal with it.
diff --git a/src/mlpack/tests/rectangle_tree_test.cpp b/src/mlpack/tests/rectangle_tree_test.cpp
index d03b406..9f42c15 100644
--- a/src/mlpack/tests/rectangle_tree_test.cpp
+++ b/src/mlpack/tests/rectangle_tree_test.cpp
@@ -467,12 +467,16 @@ BOOST_AUTO_TEST_CASE(PointDynamicAdd)
       arma::mat> TreeType;
   TreeType tree(dataset, 20, 6, 5, 2, 0);
 
-  // Add numIter new points to the dataset.
+  // Add numIter new points to the dataset.  The tree copies the dataset, so we
+  // must modify both the original dataset and the one that the tree holds.
+  // (This API is clunky.  It should be redone sometime.)
+  tree.Dataset().reshape(8, 1000 + numIter);
   dataset.reshape(8, 1000 + numIter);
   arma::mat tmpData;
   tmpData.randu(8, numIter);
   for (int i = 0; i < numIter; i++)
   {
+    tree.Dataset().col(1000 + i) = tmpData.col(i);
     dataset.col(1000 + i) = tmpData.col(i);
     tree.InsertPoint(1000 + i);
   }



More information about the mlpack-git mailing list