[mlpack-git] master: Refactor CoverTree to allow sparse datasets. (90a9d93)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Tue Mar 31 11:25:55 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/0c14444f7fde2a24c230d3536b544a3b5c1e1658...90a9d939482856608467bf174f824bc441ec2c21
>---------------------------------------------------------------
commit 90a9d939482856608467bf174f824bc441ec2c21
Author: ryan <ryan at ratml.org>
Date: Tue Mar 31 11:25:31 2015 -0400
Refactor CoverTree to allow sparse datasets.
>---------------------------------------------------------------
90a9d939482856608467bf174f824bc441ec2c21
src/mlpack/core/tree/cover_tree/cover_tree.hpp | 19 +-
.../core/tree/cover_tree/cover_tree_impl.hpp | 357 ++++++++++++++-------
.../core/tree/cover_tree/dual_tree_traverser.hpp | 13 +-
.../tree/cover_tree/dual_tree_traverser_impl.hpp | 52 ++-
.../core/tree/cover_tree/first_point_is_root.hpp | 3 +-
.../core/tree/cover_tree/single_tree_traverser.hpp | 10 +-
.../tree/cover_tree/single_tree_traverser_impl.hpp | 31 +-
src/mlpack/tests/tree_test.cpp | 36 ++-
8 files changed, 373 insertions(+), 148 deletions(-)
diff --git a/src/mlpack/core/tree/cover_tree/cover_tree.hpp b/src/mlpack/core/tree/cover_tree/cover_tree.hpp
index 8536718..547d3f7 100644
--- a/src/mlpack/core/tree/cover_tree/cover_tree.hpp
+++ b/src/mlpack/core/tree/cover_tree/cover_tree.hpp
@@ -81,14 +81,17 @@ namespace tree {
* @tparam MetricType Metric type to use during tree construction.
* @tparam RootPointPolicy Determines which point to use as the root node.
* @tparam StatisticType Statistic to be used during tree creation.
+ * @tparam MatType Type of matrix to build the tree on (generally mat or
+ * sp_mat).
*/
template<typename MetricType = metric::LMetric<2, true>,
typename RootPointPolicy = FirstPointIsRoot,
- typename StatisticType = EmptyStatistic>
+ typename StatisticType = EmptyStatistic,
+ typename MatType = arma::mat>
class CoverTree
{
public:
- typedef arma::mat Mat;
+ typedef MatType Mat;
/**
* Create the cover tree with the given dataset and given base.
@@ -100,7 +103,7 @@ class CoverTree
* @param dataset Reference to the dataset to build a tree on.
* @param base Base to use during tree building (default 2.0).
*/
- CoverTree(const arma::mat& dataset,
+ CoverTree(const MatType& dataset,
const double base = 2.0,
MetricType* metric = NULL);
@@ -113,7 +116,7 @@ class CoverTree
* @param metric Instantiated metric to use during tree building.
* @param base Base to use during tree building (default 2.0).
*/
- CoverTree(const arma::mat& dataset,
+ CoverTree(const MatType& dataset,
MetricType& metric,
const double base = 2.0);
@@ -148,7 +151,7 @@ class CoverTree
* any points in the far set).
* @param usedSetSize The number of points used will be added to this number.
*/
- CoverTree(const arma::mat& dataset,
+ CoverTree(const MatType& dataset,
const double base,
const size_t pointIndex,
const int scale,
@@ -177,7 +180,7 @@ class CoverTree
* @param furthestDescendantDistance Distance to furthest descendant point.
* @param metric Instantiated metric (optional).
*/
- CoverTree(const arma::mat& dataset,
+ CoverTree(const MatType& dataset,
const double base,
const size_t pointIndex,
const int scale,
@@ -212,7 +215,7 @@ class CoverTree
using BreadthFirstDualTreeTraverser = DualTreeTraverser<RuleType>;
//! Get a reference to the dataset.
- const arma::mat& Dataset() const { return dataset; }
+ const MatType& Dataset() const { return dataset; }
//! Get the index of the point which this node represents.
size_t Point() const { return point; }
@@ -337,7 +340,7 @@ class CoverTree
private:
//! Reference to the matrix which this tree is built on.
- const arma::mat& dataset;
+ const MatType& dataset;
//! Index of the point in the matrix which this node represents.
size_t point;
diff --git a/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp b/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
index ade0eb4..24ed095 100644
--- a/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
+++ b/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
@@ -17,9 +17,14 @@ namespace mlpack {
namespace tree {
// Create the cover tree.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
- const arma::mat& dataset,
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
+CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::CoverTree(
+ const MatType& dataset,
const double base,
MetricType* metric) :
dataset(dataset),
@@ -98,9 +103,14 @@ CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
<< "construction." << std::endl;
}
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
- const arma::mat& dataset,
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
+CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::CoverTree(
+ const MatType& dataset,
MetricType& metric,
const double base) :
dataset(dataset),
@@ -175,9 +185,14 @@ CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
<< "construction." << std::endl;
}
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
- const arma::mat& dataset,
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
+CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::CoverTree(
+ const MatType& dataset,
const double base,
const size_t pointIndex,
const int scale,
@@ -218,9 +233,14 @@ CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
}
// Manually create a cover tree node.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
- const arma::mat& dataset,
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
+CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::CoverTree(
+ const MatType& dataset,
const double base,
const size_t pointIndex,
const int scale,
@@ -248,8 +268,13 @@ CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
stat = StatisticType(*this);
}
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
+CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::CoverTree(
const CoverTree& other) :
dataset(other.dataset),
point(other.point),
@@ -272,8 +297,13 @@ CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
}
}
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::~CoverTree()
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
+CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::~CoverTree()
{
// Delete each child.
for (size_t i = 0; i < children.size(); ++i)
@@ -285,17 +315,28 @@ CoverTree<MetricType, RootPointPolicy, StatisticType>::~CoverTree()
}
//! Return the number of descendant points.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
inline size_t
-CoverTree<MetricType, RootPointPolicy, StatisticType>::NumDescendants() const
+CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+ NumDescendants() const
{
return numDescendants;
}
//! Return the index of a particular descendant point.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
inline size_t
-CoverTree<MetricType, RootPointPolicy, StatisticType>::Descendant(
+CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::Descendant(
const size_t index) const
{
// The first descendant is the point contained within this node.
@@ -319,84 +360,125 @@ CoverTree<MetricType, RootPointPolicy, StatisticType>::Descendant(
return (size_t() - 1);
}
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
- const CoverTree<MetricType, RootPointPolicy, StatisticType>* other) const
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
+double CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+ MinDistance(const CoverTree* other) const
{
// Every cover tree node will contain points up to base^(scale + 1) away.
- return std::max(metric->Evaluate(dataset.unsafe_col(point),
- other->Dataset().unsafe_col(other->Point())) -
+ return std::max(metric->Evaluate(dataset.col(point),
+ other->Dataset().col(other->Point())) -
furthestDescendantDistance - other->FurthestDescendantDistance(), 0.0);
}
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
- const CoverTree<MetricType, RootPointPolicy, StatisticType>* other,
- const double distance) const
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
+double CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+ MinDistance(const CoverTree* other, const double distance) const
{
// We already have the distance as evaluated by the metric.
return std::max(distance - furthestDescendantDistance -
other->FurthestDescendantDistance(), 0.0);
}
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
- const arma::vec& other) const
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
+double CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+ MinDistance(const arma::vec& other) const
{
- return std::max(metric->Evaluate(dataset.unsafe_col(point), other) -
+ return std::max(metric->Evaluate(dataset.col(point), other) -
furthestDescendantDistance, 0.0);
}
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
- const arma::vec& /* other */,
- const double distance) const
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
+double CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+ MinDistance(const arma::vec& /* other */, const double distance) const
{
return std::max(distance - furthestDescendantDistance, 0.0);
}
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
- const CoverTree<MetricType, RootPointPolicy, StatisticType>* other) const
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
+double CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+ MaxDistance(const CoverTree* other) const
{
- return metric->Evaluate(dataset.unsafe_col(point),
- other->Dataset().unsafe_col(other->Point())) +
+ return metric->Evaluate(dataset.col(point),
+ other->Dataset().col(other->Point())) +
furthestDescendantDistance + other->FurthestDescendantDistance();
}
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
- const CoverTree<MetricType, RootPointPolicy, StatisticType>* other,
- const double distance) const
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
+double CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+ MaxDistance(const CoverTree* other, const double distance) const
{
// We already have the distance as evaluated by the metric.
return distance + furthestDescendantDistance +
other->FurthestDescendantDistance();
}
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
- const arma::vec& other) const
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
+double CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+ MaxDistance(const arma::vec& other) const
{
- return metric->Evaluate(dataset.unsafe_col(point), other) +
+ return metric->Evaluate(dataset.col(point), other) +
furthestDescendantDistance;
}
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
- const arma::vec& /* other */,
- const double distance) const
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
+double CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+ MaxDistance(const arma::vec& /* other */, const double distance) const
{
return distance + furthestDescendantDistance;
}
//! Return the minimum and maximum distance to another node.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-math::Range CoverTree<MetricType, RootPointPolicy, StatisticType>::
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
+math::Range CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
RangeDistance(const CoverTree* other) const
{
- const double distance = metric->Evaluate(dataset.unsafe_col(point),
- other->Dataset().unsafe_col(other->Point()));
+ const double distance = metric->Evaluate(dataset.col(point),
+ other->Dataset().col(other->Point()));
math::Range result;
result.Lo() = distance - furthestDescendantDistance -
@@ -409,8 +491,13 @@ math::Range CoverTree<MetricType, RootPointPolicy, StatisticType>::
//! Return the minimum and maximum distance to another node given that the
//! point-to-point distance has already been calculated.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-math::Range CoverTree<MetricType, RootPointPolicy, StatisticType>::
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
+math::Range CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
RangeDistance(const CoverTree* other,
const double distance) const
{
@@ -424,11 +511,16 @@ math::Range CoverTree<MetricType, RootPointPolicy, StatisticType>::
}
//! Return the minimum and maximum distance to another point.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-math::Range CoverTree<MetricType, RootPointPolicy, StatisticType>::
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
+math::Range CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
RangeDistance(const arma::vec& other) const
{
- const double distance = metric->Evaluate(dataset.unsafe_col(point), other);
+ const double distance = metric->Evaluate(dataset.col(point), other);
return math::Range(distance - furthestDescendantDistance,
distance + furthestDescendantDistance);
@@ -436,8 +528,13 @@ math::Range CoverTree<MetricType, RootPointPolicy, StatisticType>::
//! Return the minimum and maximum distance to another point given that the
//! point-to-point distance has already been calculated.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-math::Range CoverTree<MetricType, RootPointPolicy, StatisticType>::
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
+math::Range CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
RangeDistance(const arma::vec& /* other */,
const double distance) const
{
@@ -446,9 +543,14 @@ math::Range CoverTree<MetricType, RootPointPolicy, StatisticType>::
}
//! For a newly initialized node, create children using the near and far set.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
inline void
-CoverTree<MetricType, RootPointPolicy, StatisticType>::CreateChildren(
+CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::CreateChildren(
arma::Col<size_t>& indices,
arma::vec& distances,
size_t nearSetSize,
@@ -636,12 +738,17 @@ CoverTree<MetricType, RootPointPolicy, StatisticType>::CreateChildren(
furthestDescendantDistance = distances[i];
}
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-size_t CoverTree<MetricType, RootPointPolicy, StatisticType>::SplitNearFar(
- arma::Col<size_t>& indices,
- arma::vec& distances,
- const double bound,
- const size_t pointSetSize)
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
+size_t CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+ SplitNearFar(arma::Col<size_t>& indices,
+ arma::vec& distances,
+ const double bound,
+ const size_t pointSetSize)
{
// Sanity check; there is no guarantee that this condition will not be true.
// ...or is there?
@@ -689,30 +796,40 @@ size_t CoverTree<MetricType, RootPointPolicy, StatisticType>::SplitNearFar(
}
// Returns the maximum distance between points.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-void CoverTree<MetricType, RootPointPolicy, StatisticType>::ComputeDistances(
- const size_t pointIndex,
- const arma::Col<size_t>& indices,
- arma::vec& distances,
- const size_t pointSetSize)
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
+void CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+ ComputeDistances(const size_t pointIndex,
+ const arma::Col<size_t>& indices,
+ arma::vec& distances,
+ const size_t pointSetSize)
{
// For each point, rebuild the distances. The indices do not need to be
// modified.
distanceComps += pointSetSize;
for (size_t i = 0; i < pointSetSize; ++i)
{
- distances[i] = metric->Evaluate(dataset.unsafe_col(pointIndex),
- dataset.unsafe_col(indices[i]));
+ distances[i] = metric->Evaluate(dataset.col(pointIndex),
+ dataset.col(indices[i]));
}
}
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-size_t CoverTree<MetricType, RootPointPolicy, StatisticType>::SortPointSet(
- arma::Col<size_t>& indices,
- arma::vec& distances,
- const size_t childFarSetSize,
- const size_t childUsedSetSize,
- const size_t farSetSize)
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
+size_t CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+ SortPointSet(arma::Col<size_t>& indices,
+ arma::vec& distances,
+ const size_t childFarSetSize,
+ const size_t childUsedSetSize,
+ const size_t farSetSize)
{
// We'll use low-level memcpy calls ourselves, just to ensure it's done
// quickly and the way we want it to be. Unfortunately this takes up more
@@ -766,16 +883,21 @@ size_t CoverTree<MetricType, RootPointPolicy, StatisticType>::SortPointSet(
return (childFarSetSize + farSetSize);
}
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-void CoverTree<MetricType, RootPointPolicy, StatisticType>::MoveToUsedSet(
- arma::Col<size_t>& indices,
- arma::vec& distances,
- size_t& nearSetSize,
- size_t& farSetSize,
- size_t& usedSetSize,
- arma::Col<size_t>& childIndices,
- const size_t childFarSetSize, // childNearSetSize is 0 in this case.
- const size_t childUsedSetSize)
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
+void CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+ MoveToUsedSet(arma::Col<size_t>& indices,
+ arma::vec& distances,
+ size_t& nearSetSize,
+ size_t& farSetSize,
+ size_t& usedSetSize,
+ arma::Col<size_t>& childIndices,
+ const size_t childFarSetSize, // childNearSetSize is 0 here.
+ const size_t childUsedSetSize)
{
const size_t originalSum = nearSetSize + farSetSize + usedSetSize;
@@ -907,13 +1029,18 @@ void CoverTree<MetricType, RootPointPolicy, StatisticType>::MoveToUsedSet(
Log::Assert(originalSum == (nearSetSize + farSetSize + usedSetSize));
}
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-size_t CoverTree<MetricType, RootPointPolicy, StatisticType>::PruneFarSet(
- arma::Col<size_t>& indices,
- arma::vec& distances,
- const double bound,
- const size_t nearSetSize,
- const size_t pointSetSize)
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
+size_t CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+ PruneFarSet(arma::Col<size_t>& indices,
+ arma::vec& distances,
+ const double bound,
+ const size_t nearSetSize,
+ const size_t pointSetSize)
{
// What we are trying to do is remove any points greater than the bound from
// the far set. We don't care what happens to those indices and distances...
@@ -948,8 +1075,13 @@ size_t CoverTree<MetricType, RootPointPolicy, StatisticType>::PruneFarSet(
* Take a look at the last child (the most recently created one) and remove any
* implicit nodes that have been created.
*/
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-inline void CoverTree<MetricType, RootPointPolicy, StatisticType>::
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
+inline void CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
RemoveNewImplicitNodes()
{
// If we created an implicit node, take its self-child instead (this could
@@ -978,9 +1110,14 @@ inline void CoverTree<MetricType, RootPointPolicy, StatisticType>::
/**
* Returns a string representation of this object.
*/
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-std::string CoverTree<MetricType, RootPointPolicy, StatisticType>::ToString()
- const
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
+std::string CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+ ToString() const
{
std::ostringstream convert;
convert << "CoverTree [" << this << "]" << std::endl;
diff --git a/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp b/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp
index e09d2d0..963099c 100644
--- a/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp
+++ b/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp
@@ -13,9 +13,15 @@
namespace mlpack {
namespace tree {
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
template<typename RuleType>
-class CoverTree<MetricType, RootPointPolicy, StatisticType>::DualTreeTraverser
+class CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+ DualTreeTraverser
{
public:
/**
@@ -53,7 +59,8 @@ class CoverTree<MetricType, RootPointPolicy, StatisticType>::DualTreeTraverser
struct DualCoverTreeMapEntry
{
//! The node this entry refers to.
- CoverTree<MetricType, RootPointPolicy, StatisticType>* referenceNode;
+ CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>*
+ referenceNode;
//! The score of the node.
double score;
//! The base case.
diff --git a/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp
index 45849ab..aa28537 100644
--- a/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp
@@ -13,20 +13,29 @@
namespace mlpack {
namespace tree {
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
template<typename RuleType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::
+CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
DualTreeTraverser<RuleType>::DualTreeTraverser(RuleType& rule) :
rule(rule),
numPrunes(0)
{ /* Nothing to do. */ }
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
template<typename RuleType>
-void CoverTree<MetricType, RootPointPolicy, StatisticType>::
-DualTreeTraverser<RuleType>::Traverse(
- CoverTree<MetricType, RootPointPolicy, StatisticType>& queryNode,
- CoverTree<MetricType, RootPointPolicy, StatisticType>& referenceNode)
+void CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+DualTreeTraverser<RuleType>::Traverse(CoverTree& queryNode,
+ CoverTree& referenceNode)
{
// Start by creating a map and adding the reference root node to it.
std::map<int, std::vector<DualCoverTreeMapEntry> > refMap;
@@ -46,11 +55,16 @@ DualTreeTraverser<RuleType>::Traverse(
Traverse(queryNode, refMap);
}
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
template<typename RuleType>
-void CoverTree<MetricType, RootPointPolicy, StatisticType>::
+void CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
DualTreeTraverser<RuleType>::Traverse(
- CoverTree<MetricType, RootPointPolicy, StatisticType>& queryNode,
+ CoverTree& queryNode,
std::map<int, std::vector<DualCoverTreeMapEntry> >& referenceMap)
{
if (referenceMap.size() == 0)
@@ -128,9 +142,14 @@ DualTreeTraverser<RuleType>::Traverse(
}
}
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
template<typename RuleType>
-void CoverTree<MetricType, RootPointPolicy, StatisticType>::
+void CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
DualTreeTraverser<RuleType>::PruneMap(
CoverTree& queryNode,
std::map<int, std::vector<DualCoverTreeMapEntry> >& referenceMap,
@@ -246,9 +265,14 @@ DualTreeTraverser<RuleType>::PruneMap(
}
}
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
template<typename RuleType>
-void CoverTree<MetricType, RootPointPolicy, StatisticType>::
+void CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
DualTreeTraverser<RuleType>::ReferenceRecursion(
CoverTree& queryNode,
std::map<int, std::vector<DualCoverTreeMapEntry> >& referenceMap)
diff --git a/src/mlpack/core/tree/cover_tree/first_point_is_root.hpp b/src/mlpack/core/tree/cover_tree/first_point_is_root.hpp
index 5e8979c..18b449d 100644
--- a/src/mlpack/core/tree/cover_tree/first_point_is_root.hpp
+++ b/src/mlpack/core/tree/cover_tree/first_point_is_root.hpp
@@ -28,7 +28,8 @@ class FirstPointIsRoot
* Return the point to be used as the root point of the cover tree. This just
* returns 0.
*/
- static size_t ChooseRoot(const arma::mat& /* dataset */) { return 0; }
+ template<typename MatType>
+ static size_t ChooseRoot(const MatType& /* dataset */) { return 0; }
};
}; // namespace tree
diff --git a/src/mlpack/core/tree/cover_tree/single_tree_traverser.hpp b/src/mlpack/core/tree/cover_tree/single_tree_traverser.hpp
index c767f8c..62aefe5 100644
--- a/src/mlpack/core/tree/cover_tree/single_tree_traverser.hpp
+++ b/src/mlpack/core/tree/cover_tree/single_tree_traverser.hpp
@@ -16,9 +16,15 @@
namespace mlpack {
namespace tree {
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
template<typename RuleType>
-class CoverTree<MetricType, RootPointPolicy, StatisticType>::SingleTreeTraverser
+class CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+ SingleTreeTraverser
{
public:
/**
diff --git a/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp b/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp
index 3b60a93..3a49cfb 100644
--- a/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp
@@ -17,11 +17,16 @@ namespace mlpack {
namespace tree {
//! This is the structure the cover tree map will use for traversal.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
struct CoverTreeMapEntry
{
//! The node this entry refers to.
- CoverTree<MetricType, RootPointPolicy, StatisticType>* node;
+ CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>* node;
//! The score of the node.
double score;
//! The index of the parent node.
@@ -36,24 +41,34 @@ struct CoverTreeMapEntry
}
};
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
template<typename RuleType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::
+CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
SingleTreeTraverser<RuleType>::SingleTreeTraverser(RuleType& rule) :
rule(rule),
numPrunes(0)
{ /* Nothing to do. */ }
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<
+ typename MetricType,
+ typename RootPointPolicy,
+ typename StatisticType,
+ typename MatType
+>
template<typename RuleType>
-void CoverTree<MetricType, RootPointPolicy, StatisticType>::
+void CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
SingleTreeTraverser<RuleType>::Traverse(
const size_t queryIndex,
- CoverTree<MetricType, RootPointPolicy, StatisticType>& referenceNode)
+ CoverTree& referenceNode)
{
// This is a non-recursive implementation (which should be faster than a
// recursive implementation).
- typedef CoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
+ typedef CoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType, MatType>
MapEntryType;
// We will use this map as a priority queue. Each key represents the scale,
diff --git a/src/mlpack/tests/tree_test.cpp b/src/mlpack/tests/tree_test.cpp
index 6125a14..e842d93 100644
--- a/src/mlpack/tests/tree_test.cpp
+++ b/src/mlpack/tests/tree_test.cpp
@@ -1573,7 +1573,7 @@ void CheckCovering(const TreeType& node)
if (node.NumChildren() == 0)
return;
- const arma::mat& dataset = node.Dataset();
+ const typename TreeType::Mat& dataset = node.Dataset();
const size_t nodePoint = node.Point();
// To ensure that this node satisfies the covering principle, we must ensure
@@ -1620,7 +1620,7 @@ void CheckIndividualSeparation(const TreeType& constantNode,
return;
// Now we know we are at the same scale, so make the comparison.
- const arma::mat& dataset = constantNode.Dataset();
+ const typename TreeType::Mat& dataset = constantNode.Dataset();
const size_t constantPoint = constantNode.Point();
const size_t nodePoint = node.Point();
@@ -1735,6 +1735,38 @@ BOOST_AUTO_TEST_CASE(CoverTreeConstructionTest)
}
/**
+ * Create a cover tree on sparse data and make sure it's accurate.
+ */
+BOOST_AUTO_TEST_CASE(SparseCoverTreeConstructionTest)
+{
+ arma::sp_mat dataset;
+ // 50-dimensional, 1000 point.
+ dataset.sprandu(50, 1000, 0.3);
+
+ typedef CoverTree<EuclideanDistance, FirstPointIsRoot, EmptyStatistic,
+ arma::sp_mat> TreeType;
+ TreeType tree(dataset);
+
+ // Ensure each leaf is only created once.
+ arma::vec counts;
+ counts.zeros(1000);
+ RecurseTreeCountLeaves(tree, counts);
+
+ for (size_t i = 0; i < 1000; ++i)
+ BOOST_REQUIRE_EQUAL(counts[i], 1);
+
+ // Each non-leaf should have a self-child.
+ CheckSelfChild<TreeType>(tree);
+
+ // Each node must satisfy the covering principle (its children must be less
+ // than or equal to a certain distance apart).
+ CheckCovering<TreeType, LMetric<2, true> >(tree);
+
+ // Each node's children must be separated by at least a certain value.
+ CheckSeparation<TreeType, LMetric<2, true> >(tree, tree);
+}
+
+/**
* Test the manual constructor.
*/
BOOST_AUTO_TEST_CASE(CoverTreeManualConstructorTest)
More information about the mlpack-git mailing list