[mlpack-git] master: Refactor so holding a dataset internally is possible. (997d03b)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Sun Oct 11 16:26:14 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/7e6fe3f1c25445dfd41c747839aa1fbf1c31a776...6e9a7465d7739e05e6b4aa650e1f87c45e9cd656
>---------------------------------------------------------------
commit 997d03baf19eaf73cb3f51b38d5c5b8b877d353a
Author: Ryan Curtin <ryan at ratml.org>
Date: Sun Oct 11 05:18:32 2015 -0400
Refactor so holding a dataset internally is possible.
>---------------------------------------------------------------
997d03baf19eaf73cb3f51b38d5c5b8b877d353a
src/mlpack/core/tree/cover_tree/cover_tree.hpp | 27 ++++++------
.../core/tree/cover_tree/cover_tree_impl.hpp | 48 +++++++++++++---------
src/mlpack/tests/tree_test.cpp | 15 ++++---
3 files changed, 50 insertions(+), 40 deletions(-)
diff --git a/src/mlpack/core/tree/cover_tree/cover_tree.hpp b/src/mlpack/core/tree/cover_tree/cover_tree.hpp
index e5e78fc..1f27a60 100644
--- a/src/mlpack/core/tree/cover_tree/cover_tree.hpp
+++ b/src/mlpack/core/tree/cover_tree/cover_tree.hpp
@@ -191,7 +191,7 @@ class CoverTree
/**
* Create a cover tree from another tree. Be careful! This may use a lot of
- * memory and take a lot of time.
+ * memory and take a lot of time. This will also make a copy of the dataset.
*
* @param other Cover tree to copy from.
*/
@@ -215,7 +215,7 @@ class CoverTree
using BreadthFirstDualTreeTraverser = DualTreeTraverser<RuleType>;
//! Get a reference to the dataset.
- const MatType& Dataset() const { return dataset; }
+ const MatType& Dataset() const { return *dataset; }
//! Get the index of the point which this node represents.
size_t Point() const { return point; }
@@ -335,7 +335,7 @@ class CoverTree
//! Get the center of the node and store it in the given vector.
void Center(arma::vec& center) const
{
- center = arma::vec(dataset.col(point));
+ center = arma::vec(dataset->col(point));
}
//! Get the instantiated metric.
@@ -343,38 +343,29 @@ class CoverTree
private:
//! Reference to the matrix which this tree is built on.
- const MatType& dataset;
-
+ const MatType* dataset;
//! Index of the point in the matrix which this node represents.
size_t point;
-
//! The list of children; the first is the self-child.
std::vector<CoverTree*> children;
-
//! Scale level of the node.
int scale;
-
//! The base used to construct the tree.
double base;
-
//! The instantiated statistic.
StatisticType stat;
-
//! The number of descendant points.
size_t numDescendants;
-
//! The parent node (NULL if this is the root of the tree).
CoverTree* parent;
-
//! Distance to the parent.
double parentDistance;
-
//! Distance to the furthest descendant.
double furthestDescendantDistance;
-
//! Whether or not we need to destroy the metric in the destructor.
bool localMetric;
-
+ //! If true, we own the dataset and need to destroy it in the destructor.
+ bool localDataset;
//! The metric used for this tree.
MetricType* metric;
@@ -472,6 +463,12 @@ class CoverTree
*/
std::string ToString() const;
+ /**
+ * Serialize the tree.
+ */
+ template<typename Archive>
+ void Serialize(Archive& ar, const unsigned int /* version */);
+
size_t DistanceComps() const { return distanceComps; }
size_t& DistanceComps() { return distanceComps; }
diff --git a/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp b/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
index cd634d7..e9212fc 100644
--- a/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
+++ b/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
@@ -27,7 +27,7 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
const MatType& dataset,
const double base,
MetricType* metric) :
- dataset(dataset),
+ dataset(&dataset),
point(RootPointPolicy::ChooseRoot(dataset)),
scale(INT_MAX),
base(base),
@@ -36,6 +36,7 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
parentDistance(0),
furthestDescendantDistance(0),
localMetric(metric == NULL),
+ localDataset(false),
metric(metric),
distanceComps(0)
{
@@ -114,7 +115,7 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
const MatType& dataset,
MetricType& metric,
const double base) :
- dataset(dataset),
+ dataset(&dataset),
point(RootPointPolicy::ChooseRoot(dataset)),
scale(INT_MAX),
base(base),
@@ -123,6 +124,7 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
parentDistance(0),
furthestDescendantDistance(0),
localMetric(false),
+ localDataset(false),
metric(&metric),
distanceComps(0)
{
@@ -207,7 +209,7 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
size_t& farSetSize,
size_t& usedSetSize,
MetricType& metric) :
- dataset(dataset),
+ dataset(&dataset),
point(pointIndex),
scale(scale),
base(base),
@@ -216,6 +218,7 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
parentDistance(parentDistance),
furthestDescendantDistance(0),
localMetric(false),
+ localDataset(false),
metric(&metric),
distanceComps(0)
{
@@ -251,7 +254,7 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
const double parentDistance,
const double furthestDescendantDistance,
MetricType* metric) :
- dataset(dataset),
+ dataset(&dataset),
point(pointIndex),
scale(scale),
base(base),
@@ -260,6 +263,7 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
parentDistance(parentDistance),
furthestDescendantDistance(furthestDescendantDistance),
localMetric(metric == NULL),
+ localDataset(false),
metric(metric),
distanceComps(0)
{
@@ -279,7 +283,7 @@ template<
>
CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
const CoverTree& other) :
- dataset(other.dataset),
+ dataset((other.parent == NULL) ? new MatType(*other.dataset) : NULL),
point(other.point),
scale(other.scale),
base(other.base),
@@ -289,6 +293,7 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
parentDistance(other.parentDistance),
furthestDescendantDistance(other.furthestDescendantDistance),
localMetric(false),
+ localDataset(other.parent == NULL),
metric(other.metric),
distanceComps(0)
{
@@ -297,6 +302,7 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CoverTree(
{
children.push_back(new CoverTree(other.Child(i)));
children[i]->Parent() = this;
+ children[i]->dataset = this->dataset;
}
}
@@ -315,6 +321,10 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::~CoverTree()
// Delete the local metric, if necessary.
if (localMetric)
delete metric;
+
+ // Delete the local dataset, if necessary.
+ if (localDataset)
+ delete dataset;
}
//! Return the number of descendant points.
@@ -373,7 +383,7 @@ double CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
MinDistance(const CoverTree* other) const
{
// Every cover tree node will contain points up to base^(scale + 1) away.
- return std::max(metric->Evaluate(dataset.col(point),
+ return std::max(metric->Evaluate(dataset->col(point),
other->Dataset().col(other->Point())) -
furthestDescendantDistance - other->FurthestDescendantDistance(), 0.0);
}
@@ -401,7 +411,7 @@ template<
double CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
MinDistance(const arma::vec& other) const
{
- return std::max(metric->Evaluate(dataset.col(point), other) -
+ return std::max(metric->Evaluate(dataset->col(point), other) -
furthestDescendantDistance, 0.0);
}
@@ -426,7 +436,7 @@ template<
double CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
MaxDistance(const CoverTree* other) const
{
- return metric->Evaluate(dataset.col(point),
+ return metric->Evaluate(dataset->col(point),
other->Dataset().col(other->Point())) +
furthestDescendantDistance + other->FurthestDescendantDistance();
}
@@ -454,7 +464,7 @@ template<
double CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
MaxDistance(const arma::vec& other) const
{
- return metric->Evaluate(dataset.col(point), other) +
+ return metric->Evaluate(dataset->col(point), other) +
furthestDescendantDistance;
}
@@ -480,7 +490,7 @@ template<
math::Range CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
RangeDistance(const CoverTree* other) const
{
- const double distance = metric->Evaluate(dataset.col(point),
+ const double distance = metric->Evaluate(dataset->col(point),
other->Dataset().col(other->Point()));
math::Range result;
@@ -523,7 +533,7 @@ template<
math::Range CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
RangeDistance(const arma::vec& other) const
{
- const double distance = metric->Evaluate(dataset.col(point), other);
+ const double distance = metric->Evaluate(dataset->col(point), other);
return math::Range(distance - furthestDescendantDistance,
distance + furthestDescendantDistance);
@@ -575,7 +585,7 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CreateChildren(
// Make the self child at the lowest possible level.
// This should not modify farSetSize or usedSetSize.
size_t tempSize = 0;
- children.push_back(new CoverTree(dataset, base, point, INT_MIN, this, 0,
+ children.push_back(new CoverTree(*dataset, base, point, INT_MIN, this, 0,
indices, distances, 0, tempSize, usedSetSize, *metric));
distanceComps += children.back()->DistanceComps();
@@ -583,7 +593,7 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CreateChildren(
for (size_t i = 0; i < nearSetSize; ++i)
{
// farSetSize and usedSetSize will not be modified.
- children.push_back(new CoverTree(dataset, base, indices[i],
+ children.push_back(new CoverTree(*dataset, base, indices[i],
INT_MIN, this, distances[i], indices, distances, 0, tempSize,
usedSetSize, *metric));
distanceComps += children.back()->DistanceComps();
@@ -615,7 +625,7 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CreateChildren(
// Build the self child (recursively).
size_t childFarSetSize = nearSetSize - childNearSetSize;
size_t childUsedSetSize = 0;
- children.push_back(new CoverTree(dataset, base, point, nextScale, this, 0,
+ children.push_back(new CoverTree(*dataset, base, point, nextScale, this, 0,
indices, distances, childNearSetSize, childFarSetSize, childUsedSetSize,
*metric));
// Don't double-count the self-child (so, subtract one).
@@ -673,7 +683,7 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CreateChildren(
if ((nearSetSize == 1) && (farSetSize == 0))
{
size_t childNearSetSize = 0;
- children.push_back(new CoverTree(dataset, base, indices[0], nextScale,
+ children.push_back(new CoverTree(*dataset, base, indices[0], nextScale,
this, distances[0], indices, distances, childNearSetSize, farSetSize,
usedSetSize, *metric));
distanceComps += children.back()->DistanceComps();
@@ -715,7 +725,7 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::CreateChildren(
// Build this child (recursively).
childUsedSetSize = 1; // Mark self point as used.
- children.push_back(new CoverTree(dataset, base, indices[0], nextScale,
+ children.push_back(new CoverTree(*dataset, base, indices[0], nextScale,
this, distances[0], childIndices, childDistances, childNearSetSize,
childFarSetSize, childUsedSetSize, *metric));
numDescendants += children.back()->NumDescendants();
@@ -816,8 +826,8 @@ void CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
distanceComps += pointSetSize;
for (size_t i = 0; i < pointSetSize; ++i)
{
- distances[i] = metric->Evaluate(dataset.col(pointIndex),
- dataset.col(indices[i]));
+ distances[i] = metric->Evaluate(dataset->col(pointIndex),
+ dataset->col(indices[i]));
}
}
@@ -1125,7 +1135,7 @@ std::string CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
{
std::ostringstream convert;
convert << "CoverTree [" << this << "]" << std::endl;
- convert << " dataset: " << &dataset << std::endl;
+ convert << " dataset: " << dataset << std::endl;
convert << " point: " << point << std::endl;
convert << " scale: " << scale << std::endl;
convert << " base: " << base << std::endl;
diff --git a/src/mlpack/tests/tree_test.cpp b/src/mlpack/tests/tree_test.cpp
index 782e10d..5f6b1c4 100644
--- a/src/mlpack/tests/tree_test.cpp
+++ b/src/mlpack/tests/tree_test.cpp
@@ -1831,8 +1831,9 @@ BOOST_AUTO_TEST_CASE(CoverTreeCopyConstructor)
TreeType d = c;
- // Check that everything is the same.
- BOOST_REQUIRE_EQUAL(c.Dataset().memptr(), d.Dataset().memptr());
+ // Check that everything is the same, except the dataset, which should have
+ // been copied.
+ BOOST_REQUIRE_NE(c.Dataset().memptr(), d.Dataset().memptr());
BOOST_REQUIRE_CLOSE(c.Base(), d.Base(), 1e-50);
BOOST_REQUIRE_EQUAL(c.Point(), d.Point());
BOOST_REQUIRE_EQUAL(c.Scale(), d.Scale());
@@ -1850,8 +1851,9 @@ BOOST_AUTO_TEST_CASE(CoverTreeCopyConstructor)
BOOST_REQUIRE_EQUAL(d.Child(1).Parent(), &d);
// Check that the children are okay.
- BOOST_REQUIRE_EQUAL(c.Child(0).Dataset().memptr(),
- d.Child(0).Dataset().memptr());
+ BOOST_REQUIRE_NE(c.Child(0).Dataset().memptr(),
+ d.Child(0).Dataset().memptr());
+ BOOST_REQUIRE_EQUAL(c.Child(0).Dataset().memptr(), c.Dataset().memptr());
BOOST_REQUIRE_CLOSE(c.Child(0).Base(), d.Child(0).Base(), 1e-50);
BOOST_REQUIRE_EQUAL(c.Child(0).Point(), d.Child(0).Point());
BOOST_REQUIRE_EQUAL(c.Child(0).Scale(), d.Child(0).Scale());
@@ -1860,8 +1862,9 @@ BOOST_AUTO_TEST_CASE(CoverTreeCopyConstructor)
d.Child(0).FurthestDescendantDistance());
BOOST_REQUIRE_EQUAL(c.Child(0).NumChildren(), d.Child(0).NumChildren());
- BOOST_REQUIRE_EQUAL(c.Child(1).Dataset().memptr(),
- d.Child(1).Dataset().memptr());
+ BOOST_REQUIRE_NE(c.Child(1).Dataset().memptr(),
+ d.Child(1).Dataset().memptr());
+ BOOST_REQUIRE_EQUAL(c.Child(1).Dataset().memptr(), c.Dataset().memptr());
BOOST_REQUIRE_CLOSE(c.Child(1).Base(), d.Child(1).Base(), 1e-50);
BOOST_REQUIRE_EQUAL(c.Child(1).Point(), d.Child(1).Point());
BOOST_REQUIRE_EQUAL(c.Child(1).Scale(), d.Child(1).Scale());
More information about the mlpack-git mailing list