[mlpack-svn] r15035 - mlpack/trunk/src/mlpack/core/tree/cover_tree
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed May 8 17:22:20 EDT 2013
Author: rcurtin
Date: 2013-05-08 17:22:19 -0400 (Wed, 08 May 2013)
New Revision: 15035
Modified:
mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree.hpp
mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
Log:
Add Descendant() and NumDescendants() functions.
Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree.hpp 2013-05-08 19:26:14 UTC (rev 15034)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree.hpp 2013-05-08 21:22:19 UTC (rev 15035)
@@ -238,6 +238,12 @@
//! Modify the children manually (maybe not a great idea).
std::vector<CoverTree*>& Children() { return children; }
+ //! Get the number of descendant points (this node's own point included).
+ size_t NumDescendants() const;
+
+ //! Get the dataset index of the given descendant (descendant 0 is this node's point).
+ size_t Descendant(const size_t index) const;
+
//! Get the scale of this node.
int Scale() const { return scale; }
//! Modify the scale of this node. Be careful...
@@ -341,6 +347,9 @@
//! The instantiated statistic.
StatisticType stat;
+ //! The number of descendant points (the value returned by NumDescendants()).
+ size_t numDescendants;
+
//! The parent node (NULL if this is the root of the tree).
CoverTree* parent;
Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp 2013-05-08 19:26:14 UTC (rev 15034)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp 2013-05-08 21:22:19 UTC (rev 15035)
@@ -26,6 +26,7 @@
point(RootPointPolicy::ChooseRoot(dataset)),
scale(INT_MAX),
base(base),
+ numDescendants(0),
parent(NULL),
parentDistance(0),
furthestDescendantDistance(0),
@@ -76,6 +77,7 @@
point(RootPointPolicy::ChooseRoot(dataset)),
scale(INT_MAX),
base(base),
+ numDescendants(0),
parent(NULL),
parentDistance(0),
furthestDescendantDistance(0),
@@ -131,6 +133,7 @@
point(pointIndex),
scale(scale),
base(base),
+ numDescendants(0),
parent(parent),
parentDistance(parentDistance),
furthestDescendantDistance(0),
@@ -141,6 +144,7 @@
if (nearSetSize == 0)
{
this->scale = INT_MIN;
+ numDescendants = 1; // A leaf holds only its own point.
stat = StatisticType(*this);
return;
}
@@ -152,7 +156,234 @@
stat = StatisticType(*this);
}
+// Manually create a cover tree node.
template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
+ const arma::mat& dataset,
+ const double base,
+ const size_t pointIndex,
+ const int scale,
+ CoverTree* parent,
+ const double parentDistance,
+ const double furthestDescendantDistance,
+ MetricType* metric) :
+ dataset(dataset),
+ point(pointIndex),
+ scale(scale),
+ base(base),
+ numDescendants(0), // NOTE(review): no visible way to update this for manually built nodes -- confirm.
+ parent(parent),
+ parentDistance(parentDistance),
+ furthestDescendantDistance(furthestDescendantDistance),
+ localMetric(metric == NULL),
+ metric(metric)
+{
+ // If necessary, create a local metric.
+ if (localMetric)
+ this->metric = new MetricType();
+
+ // Initialize the statistic.
+ stat = StatisticType(*this);
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
+ const CoverTree& other) :
+ dataset(other.dataset),
+ point(other.point),
+ scale(other.scale),
+ base(other.base),
+ stat(other.stat),
+ numDescendants(other.numDescendants),
+ parent(other.parent),
+ parentDistance(other.parentDistance),
+ furthestDescendantDistance(other.furthestDescendantDistance),
+ localMetric(false),
+ metric(other.metric)
+{
+ // Copy each child by hand.
+ for (size_t i = 0; i < other.NumChildren(); ++i)
+ {
+ children.push_back(new CoverTree(other.Child(i)));
+ children[i]->Parent() = this;
+ }
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+CoverTree<MetricType, RootPointPolicy, StatisticType>::~CoverTree()
+{
+ // Delete each child.
+ for (size_t i = 0; i < children.size(); ++i)
+ delete children[i];
+
+ // Delete the local metric, if necessary.
+ if (localMetric)
+ delete metric;
+}
+
+//! Return the number of descendant points.
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+inline size_t
+CoverTree<MetricType, RootPointPolicy, StatisticType>::NumDescendants() const
+{
+ return numDescendants;
+}
+
+//! Return the dataset index of descendant number `index` (valid for index < NumDescendants()).
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+inline size_t
+CoverTree<MetricType, RootPointPolicy, StatisticType>::Descendant(
+ const size_t index) const
+{
+ // The first descendant is the point contained within this node.
+ if (index == 0)
+ return point;
+
+ // Is it in the self-child (children[0])? Assumes this node has children when index > 0.
+ if (index < children[0]->NumDescendants())
+ return children[0]->Descendant(index);
+
+ // Now check the other children.
+ size_t sum = children[0]->NumDescendants();
+ for (size_t i = 1; i < children.size(); ++i)
+ {
+ if (index - sum < children[i]->NumDescendants())
+ return children[i]->Descendant(index - sum);
+ sum += children[i]->NumDescendants();
+ }
+
+ // Only reached if index >= NumDescendants(): return an invalid sentinel.
+ return (size_t() - 1);
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
+ const CoverTree<MetricType, RootPointPolicy, StatisticType>* other) const
+{
+ // Point-to-point distance minus both furthest descendant distances, clamped at 0.
+ return std::max(metric->Evaluate(dataset.unsafe_col(point),
+ other->Dataset().unsafe_col(other->Point())) -
+ furthestDescendantDistance - other->FurthestDescendantDistance(), 0.0);
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
+ const CoverTree<MetricType, RootPointPolicy, StatisticType>* other,
+ const double distance) const
+{
+ // We already have the distance as evaluated by the metric.
+ return std::max(distance - furthestDescendantDistance -
+ other->FurthestDescendantDistance(), 0.0);
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
+ const arma::vec& other) const
+{
+ return std::max(metric->Evaluate(dataset.unsafe_col(point), other) -
+ furthestDescendantDistance, 0.0);
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
+ const arma::vec& /* other */,
+ const double distance) const
+{
+ return std::max(distance - furthestDescendantDistance, 0.0);
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
+ const CoverTree<MetricType, RootPointPolicy, StatisticType>* other) const
+{
+ return metric->Evaluate(dataset.unsafe_col(point),
+ other->Dataset().unsafe_col(other->Point())) +
+ furthestDescendantDistance + other->FurthestDescendantDistance();
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
+ const CoverTree<MetricType, RootPointPolicy, StatisticType>* other,
+ const double distance) const
+{
+ // We already have the distance as evaluated by the metric.
+ return distance + furthestDescendantDistance +
+ other->FurthestDescendantDistance();
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
+ const arma::vec& other) const
+{
+ return metric->Evaluate(dataset.unsafe_col(point), other) +
+ furthestDescendantDistance;
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
+ const arma::vec& /* other */,
+ const double distance) const
+{
+ return distance + furthestDescendantDistance;
+}
+
+//! Return the minimum and maximum distance to another node (Lo() is not clamped at 0, unlike MinDistance()).
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+math::Range CoverTree<MetricType, RootPointPolicy, StatisticType>::
+ RangeDistance(const CoverTree* other) const
+{
+ const double distance = metric->Evaluate(dataset.unsafe_col(point),
+ other->Dataset().unsafe_col(other->Point()));
+
+ math::Range result;
+ result.Lo() = distance - furthestDescendantDistance -
+ other->FurthestDescendantDistance();
+ result.Hi() = distance + furthestDescendantDistance +
+ other->FurthestDescendantDistance();
+
+ return result;
+}
+
+//! Return the minimum and maximum distance to another node given that the
+//! point-to-point distance has already been calculated (Lo() is not clamped at 0).
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+math::Range CoverTree<MetricType, RootPointPolicy, StatisticType>::
+ RangeDistance(const CoverTree* other,
+ const double distance) const
+{
+ math::Range result;
+ result.Lo() = distance - furthestDescendantDistance -
+ other->FurthestDescendantDistance();
+ result.Hi() = distance + furthestDescendantDistance +
+ other->FurthestDescendantDistance();
+
+ return result;
+}
+
+//! Return the minimum and maximum distance to another point (Lo() is not clamped at 0).
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+math::Range CoverTree<MetricType, RootPointPolicy, StatisticType>::
+ RangeDistance(const arma::vec& other) const
+{
+ const double distance = metric->Evaluate(dataset.unsafe_col(point), other);
+
+ return math::Range(distance - furthestDescendantDistance,
+ distance + furthestDescendantDistance);
+}
+
+//! Return the minimum and maximum distance to another point given that the
+//! point-to-point distance has already been calculated (Lo() is not clamped at 0).
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+math::Range CoverTree<MetricType, RootPointPolicy, StatisticType>::
+ RangeDistance(const arma::vec& other,
+ const double distance) const
+{
+ return math::Range(distance - furthestDescendantDistance,
+ distance + furthestDescendantDistance);
+}
+
+//! For a newly initialized node, create children using the near and far set.
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
inline void
CoverTree<MetricType, RootPointPolicy, StatisticType>::CreateChildren(
arma::Col<size_t>& indices,
@@ -189,15 +420,16 @@
usedSetSize++;
}
+ // The number of descendants is just the number of children, because each of
+ // them is a leaf containing exactly one point.
+ numDescendants = children.size();
+
// Re-sort the dataset. We have
// [ used | far | other used ]
// and we want
// [ far | all used ].
SortPointSet(indices, distances, 0, usedSetSize, farSetSize);
- // Initialize the statistic.
- stat = StatisticType(*this);
-
return;
}
@@ -216,6 +448,8 @@
children.push_back(new CoverTree(dataset, base, point, nextScale, this, 0,
indices, distances, childNearSetSize, childFarSetSize, childUsedSetSize,
*metric));
+ // This node's point is counted only through the self-child; no subtraction needed.
+ numDescendants += children[0]->NumDescendants();
// The self-child can't modify the furthestChildDistance away from 0, but it
// can modify the furthestDescendantDistance.
@@ -287,6 +521,7 @@
children.push_back(new CoverTree(dataset, base, indices[0], nextScale,
this, distances[0], indices, distances, childNearSetSize, farSetSize,
usedSetSize, *metric));
+ numDescendants += children[children.size() - 1]->NumDescendants();
// Because the far set size is 0, we don't have to do any swapping to
// move the point into the used set.
@@ -327,6 +562,7 @@
children.push_back(new CoverTree(dataset, base, indices[0], nextScale,
this, distances[0], childIndices, childDistances, childNearSetSize,
childFarSetSize, childUsedSetSize, *metric));
+ numDescendants += children[children.size() - 1]->NumDescendants();
// If we created an implicit node, take its self-child instead (this could
// happen multiple times).
@@ -365,196 +601,7 @@
furthestDescendantDistance = distances[i];
}
-// Manually create a cover tree node.
template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
- const arma::mat& dataset,
- const double base,
- const size_t pointIndex,
- const int scale,
- CoverTree* parent,
- const double parentDistance,
- const double furthestDescendantDistance,
- MetricType* metric) :
- dataset(dataset),
- point(pointIndex),
- scale(scale),
- base(base),
- parent(parent),
- parentDistance(parentDistance),
- furthestDescendantDistance(furthestDescendantDistance),
- localMetric(metric == NULL),
- metric(metric)
-{
- // If necessary, create a local metric.
- if (localMetric)
- this->metric = new MetricType();
-
- // Initialize the statistic.
- stat = StatisticType(*this);
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
- const CoverTree& other) :
- dataset(other.dataset),
- point(other.point),
- scale(other.scale),
- base(other.base),
- stat(other.stat),
- parent(other.parent),
- parentDistance(other.parentDistance),
- furthestDescendantDistance(other.furthestDescendantDistance),
- localMetric(false),
- metric(other.metric)
-{
- // Copy each child by hand.
- for (size_t i = 0; i < other.NumChildren(); ++i)
- {
- children.push_back(new CoverTree(other.Child(i)));
- children[i]->Parent() = this;
- }
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::~CoverTree()
-{
- // Delete each child.
- for (size_t i = 0; i < children.size(); ++i)
- delete children[i];
-
- // Delete the local metric, if necessary.
- if (localMetric)
- delete metric;
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
- const CoverTree<MetricType, RootPointPolicy, StatisticType>* other) const
-{
- // Every cover tree node will contain points up to EC^(scale + 1) away.
- return std::max(metric->Evaluate(dataset.unsafe_col(point),
- other->Dataset().unsafe_col(other->Point())) -
- furthestDescendantDistance - other->FurthestDescendantDistance(), 0.0);
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
- const CoverTree<MetricType, RootPointPolicy, StatisticType>* other,
- const double distance) const
-{
- // We already have the distance as evaluated by the metric.
- return std::max(distance - furthestDescendantDistance -
- other->FurthestDescendantDistance(), 0.0);
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
- const arma::vec& other) const
-{
- return std::max(metric->Evaluate(dataset.unsafe_col(point), other) -
- furthestDescendantDistance, 0.0);
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
- const arma::vec& /* other */,
- const double distance) const
-{
- return std::max(distance - furthestDescendantDistance, 0.0);
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
- const CoverTree<MetricType, RootPointPolicy, StatisticType>* other) const
-{
- return metric->Evaluate(dataset.unsafe_col(point),
- other->Dataset().unsafe_col(other->Point())) +
- furthestDescendantDistance + other->FurthestDescendantDistance();
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
- const CoverTree<MetricType, RootPointPolicy, StatisticType>* other,
- const double distance) const
-{
- // We already have the distance as evaluated by the metric.
- return distance + furthestDescendantDistance +
- other->FurthestDescendantDistance();
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
- const arma::vec& other) const
-{
- return metric->Evaluate(dataset.unsafe_col(point), other) +
- furthestDescendantDistance;
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
- const arma::vec& /* other */,
- const double distance) const
-{
- return distance + furthestDescendantDistance;
-}
-
-//! Return the minimum and maximum distance to another node.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-math::Range CoverTree<MetricType, RootPointPolicy, StatisticType>::
- RangeDistance(const CoverTree* other) const
-{
- const double distance = metric->Evaluate(dataset.unsafe_col(point),
- other->Dataset().unsafe_col(other->Point()));
-
- math::Range result;
- result.Lo() = distance - furthestDescendantDistance -
- other->FurthestDescendantDistance();
- result.Hi() = distance + furthestDescendantDistance +
- other->FurthestDescendantDistance();
-
- return result;
-}
-
-//! Return the minimum and maximum distance to another node given that the
-//! point-to-point distance has already been calculated.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-math::Range CoverTree<MetricType, RootPointPolicy, StatisticType>::
- RangeDistance(const CoverTree* other,
- const double distance) const
-{
- math::Range result;
- result.Lo() = distance - furthestDescendantDistance -
- other->FurthestDescendantDistance();
- result.Hi() = distance + furthestDescendantDistance +
- other->FurthestDescendantDistance();
-
- return result;
-}
-
-//! Return the minimum and maximum distance to another point.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-math::Range CoverTree<MetricType, RootPointPolicy, StatisticType>::
- RangeDistance(const arma::vec& other) const
-{
- const double distance = metric->Evaluate(dataset.unsafe_col(point), other);
-
- return math::Range(distance - furthestDescendantDistance,
- distance + furthestDescendantDistance);
-}
-
-//! Return the minimum and maximum distance to another point given that the
-//! point-to-point distance has already been calculated.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-math::Range CoverTree<MetricType, RootPointPolicy, StatisticType>::
- RangeDistance(const arma::vec& other,
- const double distance) const
-{
- return math::Range(distance - furthestDescendantDistance,
- distance + furthestDescendantDistance);
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
size_t CoverTree<MetricType, RootPointPolicy, StatisticType>::SplitNearFar(
arma::Col<size_t>& indices,
arma::vec& distances,
@@ -865,7 +912,8 @@
* Returns a string representation of this object.
*/
template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-std::string CoverTree<MetricType, RootPointPolicy, StatisticType>::ToString() const
+std::string CoverTree<MetricType, RootPointPolicy, StatisticType>::ToString()
+ const
{
std::ostringstream convert;
convert << "CoverTree [" << this << "]" << std::endl;
More information about the mlpack-svn mailing list