[mlpack-git] master: Refactor to hold dataset internally (for serialization). (6e9a746)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Sun Oct 11 16:26:16 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/7e6fe3f1c25445dfd41c747839aa1fbf1c31a776...6e9a7465d7739e05e6b4aa650e1f87c45e9cd656
>---------------------------------------------------------------
commit 6e9a7465d7739e05e6b4aa650e1f87c45e9cd656
Author: Ryan Curtin <ryan at ratml.org>
Date: Sun Oct 11 14:44:07 2015 -0400
Refactor to hold dataset internally (for serialization).
>---------------------------------------------------------------
6e9a7465d7739e05e6b4aa650e1f87c45e9cd656
.../core/tree/rectangle_tree/rectangle_tree.hpp | 32 +++++++--
.../tree/rectangle_tree/rectangle_tree_impl.hpp | 76 ++++++++++++++++++----
src/mlpack/tests/rectangle_tree_test.cpp | 6 +-
3 files changed, 95 insertions(+), 19 deletions(-)
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
index ce69dea..97c7468 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
@@ -24,8 +24,7 @@ namespace tree /** Trees and tree-building procedures. */ {
* the constructor with the dataset to build the tree on, and the entire tree
* will be built.
*
- * This tree does allow growth, so you can add and delete nodes
- * from it.
+ * This tree does allow growth, so you can add and delete nodes from it.
*
* @tparam MetricType This *must* be EuclideanDistance, but the template
* parameter is required to satisfy the TreeType API.
@@ -98,7 +97,10 @@ class RectangleTree
//! The discance to the furthest descendant, cached to speed things up.
double furthestDescendantDistance;
//! The dataset.
- const MatType& dataset;
+ const MatType* dataset;
+ //! Whether or not this node owns the dataset and must delete it. This bool
+ //! sits between a pointer and a vector, so it may add padding to the object.
+ bool ownsDataset;
//! The mapping to the dataset
std::vector<size_t> points;
//! The local dataset
@@ -156,6 +158,14 @@ class RectangleTree
RectangleTree(const RectangleTree& other, const bool deepCopy = true);
/**
+ * Construct the tree from a boost::serialization archive.
+ */
+ template<typename Archive>
+ RectangleTree(
+ Archive& ar,
+ const typename boost::enable_if<typename Archive::is_loading>::type* = 0);
+
+ /**
* Deletes this node, deallocating the memory for the children and calling
* their destructors in turn. This will invalidate any younters or references
* to any nodes which are children of this one.
@@ -307,9 +317,9 @@ class RectangleTree
RectangleTree*& Parent() { return parent; }
//! Get the dataset which the tree is built on.
- const MatType& Dataset() const { return dataset; }
+ const MatType& Dataset() const { return *dataset; }
//! Modify the dataset which the tree is built on. Be careful!
- MatType& Dataset() { return const_cast<MatType&>(dataset); }
+ MatType& Dataset() { return const_cast<MatType&>(*dataset); }
//! Get the points vector for this node.
const std::vector<size_t>& Points() const { return points; }
@@ -511,6 +521,18 @@ class RectangleTree
*/
void SplitNode(std::vector<bool>& relevels);
+ protected:
+ /**
+ * A default constructor. This is meant to only be used with
+ * boost::serialization, which is allowed with the friend declaration below.
+ * This does not return a valid tree! This method must be protected, so that
+ * the serialization shim can work with the default constructor.
+ */
+ RectangleTree();
+
+ //! Friend access is given for the default constructor.
+ friend class boost::serialization::access;
+
public:
/**
* Condense the bounding rectangles for this node based on the removal of the
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
index c733817..d325ab4 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
@@ -41,7 +41,8 @@ RectangleTree(const MatType& data,
bound(data.n_rows),
splitHistory(bound.Dim()),
parentDistance(0),
- dataset(data),
+ dataset(new MatType(data)),
+ ownsDataset(true),
points(maxLeafSize + 1), // Add one to make splitting the node simpler.
localDataset(new MatType(data.n_rows, static_cast<int> (maxLeafSize) + 1))
{
@@ -75,7 +76,8 @@ RectangleTree(
bound(parentNode->Bound().Dim()),
splitHistory(bound.Dim()),
parentDistance(0),
- dataset(parentNode->Dataset()),
+ dataset(&parentNode->Dataset()),
+ ownsDataset(false),
points(maxLeafSize + 1), // Add one to make splitting the node simpler.
localDataset(new MatType(static_cast<int> (parentNode->Bound().Dim()),
static_cast<int> (maxLeafSize) + 1))
@@ -108,7 +110,8 @@ RectangleTree(
bound(other.bound),
splitHistory(other.SplitHistory()),
parentDistance(other.ParentDistance()),
- dataset(other.dataset),
+ dataset(new MatType(*other.dataset)),
+ ownsDataset(true),
points(other.Points()),
localDataset(NULL)
{
@@ -135,6 +138,25 @@ RectangleTree(
}
/**
+ * Construct the tree from a boost::serialization archive.
+ */
+template<typename MetricType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType,
+ typename DescentType>
+template<typename Archive>
+RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
+RectangleTree(
+ Archive& ar, // Must be a loading (input) archive; enforced by enable_if below.
+ const typename boost::enable_if<typename Archive::is_loading>::type*) :
+ RectangleTree() // Delegate to the default constructor to get an empty node.
+{
+ // Populate this (initially empty) node from the archive entry named "tree".
+ ar >> data::CreateNVP(*this, "tree");
+}
+
+/**
* Deletes this node, deallocating the memory for the children and calling
* their destructors in turn. This will invalidate any pointers or references
* to any nodes which are children of this one.
@@ -150,6 +172,9 @@ RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
for (size_t i = 0; i < numChildren; i++)
delete children[i];
+ if (ownsDataset)
+ delete dataset;
+
delete localDataset;
}
@@ -201,7 +226,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
InsertPoint(const size_t point)
{
// Expand the bound regardless of whether it is a leaf node.
- bound |= dataset.col(point);
+ bound |= dataset->col(point);
std::vector<bool> lvls(TreeDepth());
for (size_t i = 0; i < lvls.size(); i++)
@@ -210,7 +235,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
// If this is a leaf node, we stop here and add the point.
if (numChildren == 0)
{
- localDataset->col(count) = dataset.col(point);
+ localDataset->col(count) = dataset->col(point);
points[count++] = point;
SplitNode(lvls);
return;
@@ -219,7 +244,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
// If it is not a leaf node, we use the DescentHeuristic to choose a child
// to which we recurse.
const size_t descentNode = DescentType::ChooseDescentNode(this,
- dataset.col(point));
+ dataset->col(point));
children[descentNode]->InsertPoint(point, lvls);
}
@@ -238,12 +263,12 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
InsertPoint(const size_t point, std::vector<bool>& relevels)
{
// Expand the bound regardless of whether it is a leaf node.
- bound |= dataset.col(point);
+ bound |= dataset->col(point);
// If this is a leaf node, we stop here and add the point.
if (numChildren == 0)
{
- localDataset->col(count) = dataset.col(point);
+ localDataset->col(count) = dataset->col(point);
points[count++] = point;
SplitNode(relevels);
return;
@@ -252,7 +277,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
// If it is not a leaf node, we use the DescentHeuristic to choose a child
// to which we recurse.
const size_t descentNode = DescentType::ChooseDescentNode(this,
- dataset.col(point));
+ dataset->col(point));
children[descentNode]->InsertPoint(point, relevels);
}
@@ -320,14 +345,14 @@ bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
localDataset->col(i) = localDataset->col(--count); // Decrement count.
points[i] = points[count];
// This function wil ensure that minFill is satisfied.
- CondenseTree(dataset.col(point), lvls, true);
+ CondenseTree(dataset->col(point), lvls, true);
return true;
}
}
}
for (size_t i = 0; i < numChildren; i++)
- if (children[i]->Bound().Contains(dataset.col(point)))
+ if (children[i]->Bound().Contains(dataset->col(point)))
if (children[i]->DeletePoint(point, lvls))
return true;
@@ -355,14 +380,14 @@ bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
localDataset->col(i) = localDataset->col(--count);
points[i] = points[count];
// This function will ensure that minFill is satisfied.
- CondenseTree(dataset.col(point), relevels, true);
+ CondenseTree(dataset->col(point), relevels, true);
return true;
}
}
}
for (size_t i = 0; i < numChildren; i++)
- if (children[i]->Bound().Contains(dataset.col(point)))
+ if (children[i]->Bound().Contains(dataset->col(point)))
if (children[i]->DeletePoint(point, relevels))
return true;
@@ -591,6 +616,31 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
}
}
+//! Default constructor for boost::serialization only; the result is not a valid tree.
+template<typename MetricType,
+ typename StatisticType,
+ typename MatType,
+ typename SplitType,
+ typename DescentType>
+RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
+RectangleTree() :
+ maxNumChildren(0), // Try to give sensible defaults, but it shouldn't matter
+ minNumChildren(0), // because this tree isn't valid anyway and is only used
+ numChildren(0), // by boost::serialization.
+ parent(NULL),
+ begin(0),
+ count(0),
+ maxLeafSize(0),
+ minLeafSize(0),
+ parentDistance(0.0),
+ furthestDescendantDistance(0.0),
+ dataset(NULL), // No dataset attached yet.
+ ownsDataset(false), // Nothing owned yet, so the destructor deletes no dataset.
+ localDataset(NULL) // Deleting NULL in the destructor is a harmless no-op.
+{
+ // Nothing to do; deserialization is expected to fill in the real state.
+}
+
/**
* Condense the tree. This shrinks the bounds and moves up the tree if
* applicable. If a node goes below minimum fill, this code will deal with it.
diff --git a/src/mlpack/tests/rectangle_tree_test.cpp b/src/mlpack/tests/rectangle_tree_test.cpp
index d03b406..9f42c15 100644
--- a/src/mlpack/tests/rectangle_tree_test.cpp
+++ b/src/mlpack/tests/rectangle_tree_test.cpp
@@ -467,12 +467,16 @@ BOOST_AUTO_TEST_CASE(PointDynamicAdd)
arma::mat> TreeType;
TreeType tree(dataset, 20, 6, 5, 2, 0);
- // Add numIter new points to the dataset.
+ // Add numIter new points to the dataset. The tree copies the dataset, so we
+ // must modify both the original dataset and the one that the tree holds.
+ // (This API is clunky. It should be redone sometime.)
+ tree.Dataset().reshape(8, 1000 + numIter);
dataset.reshape(8, 1000 + numIter);
arma::mat tmpData;
tmpData.randu(8, numIter);
for (int i = 0; i < numIter; i++)
{
+ tree.Dataset().col(1000 + i) = tmpData.col(i);
dataset.col(1000 + i) = tmpData.col(i);
tree.InsertPoint(1000 + i);
}
More information about the mlpack-git mailing list