[mlpack-svn] r14417 - mlpack/trunk/src/mlpack/core/tree/cover_tree
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Feb 28 14:15:26 EST 2013
Author: rcurtin
Date: 2013-02-28 14:15:25 -0500 (Thu, 28 Feb 2013)
New Revision: 14417
Modified:
mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree.hpp
mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
Log:
Add Metric() function, and refactor a little bit. The metric is now stored in
the tree; if none is given, a local metric is created and deleted in the destructor.
Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree.hpp 2013-02-28 19:14:56 UTC (rev 14416)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree.hpp 2013-02-28 19:15:25 UTC (rev 14417)
@@ -160,6 +160,7 @@
* @param parent Parent node (NULL indicates no parent).
* @param parentDistance Distance to parent node point.
* @param furthestDescendantDistance Distance to furthest descendant point.
+ * @param metric Instantiated metric (optional).
*/
CoverTree(const arma::mat& dataset,
const double base,
@@ -167,7 +168,8 @@
const int scale,
CoverTree* parent,
const double parentDistance,
- const double furthestDescendantDistance);
+ const double furthestDescendantDistance,
+ MetricType* metric = NULL);
/**
* Create a cover tree from another tree. Be careful! This may use a lot of
@@ -311,6 +313,12 @@
//! Distance to the furthest descendant.
double furthestDescendantDistance;
+ //! Whether or not we need to destroy the metric in the destructor.
+ bool localMetric;
+
+ //! The metric used for this tree.
+ MetricType* metric;
+
/**
* Fill the vector of distances with the distances between the point specified
* by pointIndex and each point in the indices array. The distances of the
@@ -325,8 +333,7 @@
void ComputeDistances(const size_t pointIndex,
const arma::Col<size_t>& indices,
arma::vec& distances,
- const size_t pointSetSize,
- MetricType& metric);
+ const size_t pointSetSize);
/**
* Split the given indices and distances into a near and a far set, returning
* the number of points in the near set. The distances must already be
Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp 2013-02-28 19:14:56 UTC (rev 14416)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp 2013-02-28 19:15:25 UTC (rev 14417)
@@ -27,15 +27,13 @@
base(base),
parent(NULL),
parentDistance(0),
- furthestDescendantDistance(0)
+ furthestDescendantDistance(0),
+ localMetric(metric == NULL),
+ metric(metric)
{
// If we need to create a metric, do that. We'll just do it on the heap.
- bool localMetric = false;
- if (metric == NULL)
- {
- localMetric = true; // So we know we need to free it.
- metric = new MetricType();
- }
+ if (localMetric)
+ this->metric = new MetricType();
// Kick off the building. Create the indices array and the distances array.
arma::Col<size_t> indices = arma::linspace<arma::Col<size_t> >(1,
@@ -48,7 +46,7 @@
arma::vec distances(dataset.n_cols - 1);
// Build the initial distances.
- ComputeDistances(point, indices, distances, dataset.n_cols - 1, *metric);
+ ComputeDistances(point, indices, distances, dataset.n_cols - 1);
// Now determine the scale factor of the root node.
const double maxDistance = max(distances);
@@ -139,7 +137,7 @@
// Build distances for the child.
ComputeDistances(indices[0], childIndices, childDistances,
- nearSetSize - 1, *metric);
+ nearSetSize - 1);
childDistances(nearSetSize - 1) = 0;
// Split into near and far sets for this point.
@@ -191,9 +189,6 @@
Log::Assert(furthestDescendantDistance <= pow(base, scale + 1));
- if (localMetric)
- delete metric;
-
// Initialize statistic.
stat = StatisticType(*this);
}
@@ -218,7 +213,9 @@
base(base),
parent(parent),
parentDistance(parentDistance),
- furthestDescendantDistance(0)
+ furthestDescendantDistance(0),
+ localMetric(false),
+ metric(&metric)
{
// If the size of the near set is 0, this is a leaf.
if (nearSetSize == 0)
@@ -376,7 +373,7 @@
// Build distances for the child.
ComputeDistances(indices[0], childIndices, childDistances, nearSetSize
- + farSetSize - 1, metric);
+ + farSetSize - 1);
// Split into near and far sets for this point.
childNearSetSize = SplitNearFar(childIndices, childDistances, bound,
@@ -449,15 +446,22 @@
const int scale,
CoverTree* parent,
const double parentDistance,
- const double furthestDescendantDistance) :
+ const double furthestDescendantDistance,
+ MetricType* metric) :
dataset(dataset),
point(pointIndex),
scale(scale),
base(base),
parent(parent),
parentDistance(parentDistance),
- furthestDescendantDistance(furthestDescendantDistance)
+ furthestDescendantDistance(furthestDescendantDistance),
+ localMetric(metric == NULL),
+ metric(metric)
{
+ // If necessary, create a local metric.
+ if (localMetric)
+ this->metric = new MetricType();
+
// Initialize the statistic.
stat = StatisticType(*this);
}
@@ -472,7 +476,9 @@
stat(other.stat),
parent(other.parent),
parentDistance(other.parentDistance),
- furthestDescendantDistance(other.furthestDescendantDistance)
+ furthestDescendantDistance(other.furthestDescendantDistance),
+ localMetric(false),
+ metric(other.metric)
{
// Copy each child by hand.
for (size_t i = 0; i < other.NumChildren(); ++i)
@@ -488,6 +494,10 @@
// Delete each child.
for (size_t i = 0; i < children.size(); ++i)
delete children[i];
+
+ // Delete the local metric, if necessary.
+ if (localMetric)
+ delete metric;
}
template<typename MetricType, typename RootPointPolicy, typename StatisticType>
@@ -495,7 +505,7 @@
const CoverTree<MetricType, RootPointPolicy, StatisticType>* other) const
{
// Every cover tree node will contain points up to EC^(scale + 1) away.
- return std::max(MetricType::Evaluate(dataset.unsafe_col(point),
+ return std::max(metric->Evaluate(dataset.unsafe_col(point),
other->Dataset().unsafe_col(other->Point())) -
furthestDescendantDistance - other->FurthestDescendantDistance(), 0.0);
}
@@ -514,7 +524,7 @@
double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
const arma::vec& other) const
{
- return std::max(MetricType::Evaluate(dataset.unsafe_col(point), other) -
+ return std::max(metric->Evaluate(dataset.unsafe_col(point), other) -
furthestDescendantDistance, 0.0);
}
@@ -530,7 +540,7 @@
double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
const CoverTree<MetricType, RootPointPolicy, StatisticType>* other) const
{
- return MetricType::Evaluate(dataset.unsafe_col(point),
+ return metric->Evaluate(dataset.unsafe_col(point),
other->Dataset().unsafe_col(other->Point())) +
furthestDescendantDistance + other->FurthestDescendantDistance();
}
@@ -549,7 +559,7 @@
double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
const arma::vec& other) const
{
- return MetricType::Evaluate(dataset.unsafe_col(point), other) +
+ return metric->Evaluate(dataset.unsafe_col(point), other) +
furthestDescendantDistance;
}
@@ -619,14 +629,13 @@
const size_t pointIndex,
const arma::Col<size_t>& indices,
arma::vec& distances,
- const size_t pointSetSize,
- MetricType& metric)
+ const size_t pointSetSize)
{
// For each point, rebuild the distances. The indices do not need to be
// modified.
for (size_t i = 0; i < pointSetSize; ++i)
{
- distances[i] = metric.Evaluate(dataset.unsafe_col(pointIndex),
+ distances[i] = metric->Evaluate(dataset.unsafe_col(pointIndex),
dataset.unsafe_col(indices[i]));
}
}
More information about the mlpack-svn mailing list