[mlpack-svn] r12943 - mlpack/trunk/src/mlpack/core/tree
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Jun 4 19:10:37 EDT 2012
Author: rcurtin
Date: 2012-06-04 19:10:37 -0400 (Mon, 04 Jun 2012)
New Revision: 12943
Modified:
mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp
mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp
Log:
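Clean up the cover tree implementation a little. Allow instantiated metrics
(useful for hyperbolic tangent or Gaussian kernels, or any other kernel with
parameters): the root CoverTree constructor now takes an optional metric
pointer, and if none is given a default-constructed metric is used.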
Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp 2012-06-04 21:00:04 UTC (rev 12942)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp 2012-06-04 23:10:37 UTC (rev 12943)
@@ -103,7 +103,8 @@
* building (default 2.0).
*/
CoverTree(const arma::mat& dataset,
- const double expansionConstant = 2.0);
+ const double expansionConstant = 2.0,
+ MetricType* metric = NULL);
/**
* Construct a child cover tree node. This constructor is not meant to be
@@ -144,7 +145,8 @@
arma::vec& distances,
size_t nearSetSize,
size_t& farSetSize,
- size_t& usedSetSize);
+ size_t& usedSetSize,
+ MetricType& metric = NULL);
/**
* Delete this cover tree node and its children.
@@ -266,9 +268,6 @@
//! The instantiated statistic.
StatisticType stat;
- //! The instantiated metric. Either the user passes it in, or we build it.
- MetricType* metric;
-
//! Distance to the parent.
double parentDistance;
@@ -289,7 +288,8 @@
void ComputeDistances(const size_t pointIndex,
const arma::Col<size_t>& indices,
arma::vec& distances,
- const size_t pointSetSize);
+ const size_t pointSetSize,
+ MetricType& metric);
/**
* Split the given indices and distances into a near and a far set, returning
* the number of points in the near set. The distances must already be
Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp 2012-06-04 21:00:04 UTC (rev 12942)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp 2012-06-04 23:10:37 UTC (rev 12943)
@@ -17,13 +17,22 @@
template<typename MetricType, typename RootPointPolicy, typename StatisticType>
CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
const arma::mat& dataset,
- const double expansionConstant) :
+ const double expansionConstant,
+ MetricType* metric) :
dataset(dataset),
point(RootPointPolicy::ChooseRoot(dataset)),
expansionConstant(expansionConstant),
parentDistance(0),
furthestDescendantDistance(0)
{
+ // If we need to create a metric, do that. We'll just do it on the heap.
+ bool localMetric = false;
+ if (metric == NULL)
+ {
+ localMetric = true; // So we know we need to free it.
+ metric = new MetricType();
+ }
+
// Kick off the building. Create the indices array and the distances array.
arma::Col<size_t> indices = arma::linspace<arma::Col<size_t> >(1,
dataset.n_cols - 1, dataset.n_cols - 1);
@@ -35,7 +44,7 @@
arma::vec distances(dataset.n_cols - 1);
// Build the initial distances.
- ComputeDistances(point, indices, distances, dataset.n_cols - 1);
+ ComputeDistances(point, indices, distances, dataset.n_cols - 1, *metric);
// Now determine the scale factor of the root node.
const double maxDistance = max(distances);
@@ -50,7 +59,8 @@
size_t childFarSetSize = (dataset.n_cols - 1) - childNearSetSize;
size_t usedSetSize = 0;
children.push_back(new CoverTree(dataset, expansionConstant, point, scale - 1,
- 0, indices, distances, childNearSetSize, childFarSetSize, usedSetSize));
+ 0, indices, distances, childNearSetSize, childFarSetSize, usedSetSize,
+ *metric));
furthestDescendantDistance = children[0]->FurthestDescendantDistance();
@@ -80,7 +90,7 @@
{
// We want to select the furthest point in the near set as the next child.
size_t newPointIndex = nearSetSize - 1;
-
+
// Swap to front if necessary.
if (newPointIndex != 0)
{
@@ -106,7 +116,7 @@
size_t childFarSetSize = 0;
children.push_back(new CoverTree(dataset, expansionConstant,
indices[0], scale - 1, distances[0], indices, distances,
- childNearSetSize, childFarSetSize, usedSetSize));
+ childNearSetSize, childFarSetSize, usedSetSize, *metric));
// And we're done.
break;
@@ -122,7 +132,7 @@
// Build distances for the child.
ComputeDistances(indices[0], childIndices, childDistances,
- nearSetSize - 1);
+ nearSetSize - 1, *metric);
childDistances(nearSetSize - 1) = 0;
// Split into near and far sets for this point.
@@ -134,7 +144,7 @@
childFarSetSize = ((nearSetSize - 1) - childNearSetSize);
children.push_back(new CoverTree(dataset, expansionConstant, indices[0],
scale - 1, distances[0], childIndices, childDistances, childNearSetSize,
- childFarSetSize, childUsedSetSize));
+ childFarSetSize, childUsedSetSize, *metric));
// If we created an implicit node, take its self-child instead (this could
// happen multiple times).
@@ -163,13 +173,16 @@
MoveToUsedSet(indices, distances, nearSetSize, farSetSize, usedSetSize,
childIndices, childFarSetSize, childUsedSetSize);
}
-
+
// Calculate furthest descendant.
for (size_t i = 0; i < usedSetSize; ++i)
if (distances[i] > furthestDescendantDistance)
furthestDescendantDistance = distances[i];
Log::Assert(furthestDescendantDistance <= pow(expansionConstant, scale + 1));
+
+ if (localMetric)
+ delete metric;
}
template<typename MetricType, typename RootPointPolicy, typename StatisticType>
@@ -183,7 +196,8 @@
arma::vec& distances,
size_t nearSetSize,
size_t& farSetSize,
- size_t& usedSetSize) :
+ size_t& usedSetSize,
+ MetricType& metric) :
dataset(dataset),
point(pointIndex),
scale(scale),
@@ -193,8 +207,11 @@
{
// If the size of the near set is 0, this is a leaf.
if (nearSetSize == 0)
+ {
+ this->scale = INT_MIN;
return;
-
+ }
+
// Determine the next scale level. This should be the first level where there
// are any points in the far set. So, if we know the maximum distance in the
// distances array, this will be the largest i such that
@@ -203,21 +220,22 @@
// implicit node. If the maximum distance is 0, every point in the near set
// will be created as a leaf, and a child to this node. We also do not need
// to change the furthestChildDistance or furthestDescendantDistance.
- const double maxDistance = max(distances.rows(0, nearSetSize + farSetSize - 1));
+ const double maxDistance = max(distances.rows(0,
+ nearSetSize + farSetSize - 1));
if (maxDistance == 0)
{
// Make the self child at the lowest possible level.
// This should not modify farSetSize or usedSetSize.
size_t tempSize = 0;
children.push_back(new CoverTree(dataset, expansionConstant, pointIndex,
- INT_MIN, 0, indices, distances, 0, tempSize, usedSetSize));
+ INT_MIN, 0, indices, distances, 0, tempSize, usedSetSize, metric));
// Every point in the near set should be a leaf.
for (size_t i = 0; i < nearSetSize; ++i)
{
// farSetSize and usedSetSize will not be modified.
children.push_back(new CoverTree(dataset, expansionConstant, indices[i],
- INT_MIN, 0, indices, distances, 0, tempSize, usedSetSize));
+ INT_MIN, 0, indices, distances, 0, tempSize, usedSetSize, metric));
usedSetSize++;
}
@@ -247,8 +265,8 @@
size_t childUsedSetSize = 0;
children.push_back(new CoverTree(dataset, expansionConstant, pointIndex,
nextScale, 0, indices, distances, childNearSetSize, childFarSetSize,
- childUsedSetSize));
-
+ childUsedSetSize, metric));
+
// The self-child can't modify the furthestChildDistance away from 0, but it
// can modify the furthestDescendantDistance.
furthestDescendantDistance = children[0]->FurthestDescendantDistance();
@@ -279,7 +297,7 @@
// is what we are trying to make.
SortPointSet(indices, distances, childFarSetSize, childUsedSetSize,
farSetSize);
-
+
// Update size of near set and used set.
nearSetSize -= childUsedSetSize;
usedSetSize += childUsedSetSize;
@@ -315,7 +333,7 @@
size_t childNearSetSize = 0;
children.push_back(new CoverTree(dataset, expansionConstant,
indices[0], nextScale, distances[0], indices, distances,
- childNearSetSize, farSetSize, usedSetSize));
+ childNearSetSize, farSetSize, usedSetSize, metric));
// Because the far set size is 0, we don't have to do any swapping to
// move the point into the used set.
@@ -335,7 +353,7 @@
// Build distances for the child.
ComputeDistances(indices[0], childIndices, childDistances, nearSetSize
- + farSetSize - 1);
+ + farSetSize - 1, metric);
// Split into near and far sets for this point.
childNearSetSize = SplitNearFar(childIndices, childDistances, bound,
@@ -343,7 +361,7 @@
childFarSetSize = PruneFarSet(childIndices, childDistances,
expansionConstant * bound, childNearSetSize,
(nearSetSize + farSetSize - 1));
-
+
// Now that we know the near and far set sizes, we can put the used point
// (the self point) in the correct place; now, when we call
// MoveToUsedSet(), it will move the self-point correctly. The distance
@@ -355,7 +373,7 @@
childUsedSetSize = 1; // Mark self point as used.
children.push_back(new CoverTree(dataset, expansionConstant, indices[0],
nextScale, distances[0], childIndices, childDistances, childNearSetSize,
- childFarSetSize, childUsedSetSize));
+ childFarSetSize, childUsedSetSize, metric));
// If we created an implicit node, take its self-child instead (this could
// happen multiple times).
@@ -383,9 +401,10 @@
MoveToUsedSet(indices, distances, nearSetSize, farSetSize, usedSetSize,
childIndices, childFarSetSize, childUsedSetSize);
}
-
+
// Calculate furthest descendant.
- for (size_t i = (nearSetSize + farSetSize); i < (nearSetSize + farSetSize + usedSetSize); ++i)
+ for (size_t i = (nearSetSize + farSetSize); i < (nearSetSize + farSetSize +
+ usedSetSize); ++i)
if (distances[i] > furthestDescendantDistance)
furthestDescendantDistance = distances[i];
@@ -495,13 +514,14 @@
const size_t pointIndex,
const arma::Col<size_t>& indices,
arma::vec& distances,
- const size_t pointSetSize)
+ const size_t pointSetSize,
+ MetricType& metric)
{
// For each point, rebuild the distances. The indices do not need to be
// modified.
for (size_t i = 0; i < pointSetSize; ++i)
{
- distances[i] = MetricType::Evaluate(dataset.unsafe_col(pointIndex),
+ distances[i] = metric.Evaluate(dataset.unsafe_col(pointIndex),
dataset.unsafe_col(indices[i]));
}
}
@@ -561,7 +581,7 @@
delete[] indicesBuffer;
delete[] distancesBuffer;
-
+
// This returns the complete size of the far set.
return (childFarSetSize + farSetSize);
}
@@ -674,7 +694,7 @@
if (childIndices[childFarSetSize + j] == indices[i + nearSetSize])
{
// We have found a point to swap.
-
+
// Perform the swap.
size_t tempIndex = indices[nearSetSize + farSetSize - 1];
double tempDist = distances[nearSetSize + farSetSize - 1];
@@ -741,7 +761,7 @@
// The far set size is the left pointer, with the near set size subtracted
// from it.
- return (left - nearSetSize);
+ return (left - nearSetSize);
}
}; // namespace tree
More information about the mlpack-svn mailing list