[mlpack-svn] r12829 - mlpack/trunk/src/mlpack/core/tree
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon May 28 17:21:06 EDT 2012
Author: rcurtin
Date: 2012-05-28 17:21:06 -0400 (Mon, 28 May 2012)
New Revision: 12829
Modified:
mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp
mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp
Log:
Parent distance is the interesting one to track, not furthest child distance.
Fix a few bugs so that furthest descendant distance is correct.
Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp 2012-05-28 21:15:35 UTC (rev 12828)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp 2012-05-28 21:21:06 UTC (rev 12829)
@@ -139,6 +139,7 @@
const double expansionConstant,
const size_t pointIndex,
const int scale,
+ const double parentDistance,
arma::Col<size_t>& indices,
arma::vec& distances,
size_t nearSetSize,
@@ -239,8 +240,8 @@
//! Returns true: this tree does have self-children.
static bool HasSelfChildren() { return true; }
- //! Get the distance to the furthest child.
- double FurthestChildDistance() const { return furthestChildDistance; }
+ //! Get the distance to the parent.
+ double ParentDistance() const { return parentDistance; }
//! Get the distance to the furthest descendant.
double FurthestDescendantDistance() const
@@ -268,8 +269,8 @@
//! The instantiated metric. Either the user passes it in, or we build it.
MetricType* metric;
- //! Distance to the furthest child.
- double furthestChildDistance;
+ //! Distance to the parent.
+ double parentDistance;
//! Distance to the furthest descendant.
double furthestDescendantDistance;
Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp 2012-05-28 21:15:35 UTC (rev 12828)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp 2012-05-28 21:21:06 UTC (rev 12829)
@@ -21,7 +21,7 @@
dataset(dataset),
point(RootPointPolicy::ChooseRoot(dataset)),
expansionConstant(expansionConstant),
- furthestChildDistance(0),
+ parentDistance(0),
furthestDescendantDistance(0)
{
// Kick off the building. Create the indices array and the distances array.
@@ -50,7 +50,7 @@
size_t childFarSetSize = (dataset.n_cols - 1) - childNearSetSize;
size_t usedSetSize = 0;
children.push_back(new CoverTree(dataset, expansionConstant, point, scale - 1,
- indices, distances, childNearSetSize, childFarSetSize, usedSetSize));
+ 0, indices, distances, childNearSetSize, childFarSetSize, usedSetSize));
furthestDescendantDistance = children[0]->FurthestDescendantDistance();
@@ -94,8 +94,8 @@
distances[0] = tempDist;
}
- if (distances[0] > furthestChildDistance)
- furthestChildDistance = distances[0];
+ if (distances[0] > furthestDescendantDistance)
+ furthestDescendantDistance = distances[0];
size_t childUsedSetSize = 0;
@@ -105,8 +105,8 @@
size_t childNearSetSize = 0;
size_t childFarSetSize = 0;
children.push_back(new CoverTree(dataset, expansionConstant,
- indices[0], scale - 1, indices, distances, childNearSetSize,
- childFarSetSize, usedSetSize));
+ indices[0], scale - 1, distances[0], indices, distances,
+ childNearSetSize, childFarSetSize, usedSetSize));
// And we're done.
break;
@@ -118,11 +118,12 @@
// Put the current point into the used set, so when we move our indices to
// the used set, this will be done for us.
childIndices(nearSetSize - 1) = indices[0];
- arma::vec childDistances(nearSetSize - 1);
+ arma::vec childDistances(nearSetSize);
// Build distances for the child.
ComputeDistances(indices[0], childIndices, childDistances,
nearSetSize - 1);
+ childDistances(nearSetSize - 1) = 0;
// Split into near and far sets for this point.
childNearSetSize = SplitNearFar(childIndices, childDistances, bound,
@@ -132,7 +133,7 @@
childUsedSetSize = 1; // Mark self point as used.
childFarSetSize = ((nearSetSize - 1) - childNearSetSize);
children.push_back(new CoverTree(dataset, expansionConstant, indices[0],
- scale - 1, childIndices, childDistances, childNearSetSize,
+ scale - 1, distances[0], childIndices, childDistances, childNearSetSize,
childFarSetSize, childUsedSetSize));
// If we created an implicit node, take its self-child instead (this could
@@ -162,6 +163,13 @@
MoveToUsedSet(indices, distances, nearSetSize, farSetSize, usedSetSize,
childIndices, childFarSetSize, childUsedSetSize);
}
+
+ // Calculate furthest descendant.
+ for (size_t i = 0; i < usedSetSize; ++i)
+ if (distances[i] > furthestDescendantDistance)
+ furthestDescendantDistance = distances[i];
+
+ Log::Assert(furthestDescendantDistance <= pow(expansionConstant, scale + 1));
}
template<typename MetricType, typename RootPointPolicy, typename StatisticType>
@@ -170,6 +178,7 @@
const double expansionConstant,
const size_t pointIndex,
const int scale,
+ const double parentDistance,
arma::Col<size_t>& indices,
arma::vec& distances,
size_t nearSetSize,
@@ -179,13 +188,13 @@
point(pointIndex),
scale(scale),
expansionConstant(expansionConstant),
- furthestChildDistance(0), // This stays the case if this is a leaf.
+ parentDistance(parentDistance),
furthestDescendantDistance(0)
{
// If the size of the near set is 0, this is a leaf.
if (nearSetSize == 0)
return;
-
+
// Determine the next scale level. This should be the first level where there
// are any points in the far set. So, if we know the maximum distance in the
// distances array, this will be the largest i such that
@@ -201,14 +210,14 @@
// This should not modify farSetSize or usedSetSize.
size_t tempSize = 0;
children.push_back(new CoverTree(dataset, expansionConstant, pointIndex,
- INT_MIN, indices, distances, 0, tempSize, usedSetSize));
+ INT_MIN, 0, indices, distances, 0, tempSize, usedSetSize));
// Every point in the near set should be a leaf.
for (size_t i = 0; i < nearSetSize; ++i)
{
// farSetSize and usedSetSize will not be modified.
children.push_back(new CoverTree(dataset, expansionConstant, indices[i],
- INT_MIN, indices, distances, 0, tempSize, usedSetSize));
+ INT_MIN, 0, indices, distances, 0, tempSize, usedSetSize));
usedSetSize++;
}
@@ -237,9 +246,9 @@
size_t childFarSetSize = nearSetSize - childNearSetSize;
size_t childUsedSetSize = 0;
children.push_back(new CoverTree(dataset, expansionConstant, pointIndex,
- nextScale, indices, distances, childNearSetSize, childFarSetSize,
+ nextScale, 0, indices, distances, childNearSetSize, childFarSetSize,
childUsedSetSize));
-
+
// The self-child can't modify the furthestChildDistance away from 0, but it
// can modify the furthestDescendantDistance.
furthestDescendantDistance = children[0]->FurthestDescendantDistance();
@@ -270,7 +279,7 @@
// is what we are trying to make.
SortPointSet(indices, distances, childFarSetSize, childUsedSetSize,
farSetSize);
-
+
// Update size of near set and used set.
nearSetSize -= childUsedSetSize;
usedSetSize += childUsedSetSize;
@@ -297,8 +306,6 @@
}
// Will this be a new furthest child?
- if (distances[0] > furthestChildDistance)
- furthestChildDistance = distances[0];
if (distances[0] > furthestDescendantDistance)
furthestDescendantDistance = distances[0];
@@ -307,8 +314,8 @@
{
size_t childNearSetSize = 0;
children.push_back(new CoverTree(dataset, expansionConstant,
- indices[0], nextScale, indices, distances, childNearSetSize,
- farSetSize, usedSetSize));
+ indices[0], nextScale, distances[0], indices, distances,
+ childNearSetSize, farSetSize, usedSetSize));
// Because the far set size is 0, we don't have to do any swapping to
// move the point into the used set.
@@ -342,11 +349,12 @@
// MoveToUsedSet(), it will move the self-point correctly. The distance
// does not matter.
childIndices(childNearSetSize + childFarSetSize) = indices[0];
+ childDistances(childNearSetSize + childFarSetSize) = 0;
// Build this child (recursively).
childUsedSetSize = 1; // Mark self point as used.
children.push_back(new CoverTree(dataset, expansionConstant, indices[0],
- nextScale, childIndices, childDistances, childNearSetSize,
+ nextScale, distances[0], childIndices, childDistances, childNearSetSize,
childFarSetSize, childUsedSetSize));
// If we created an implicit node, take its self-child instead (this could
@@ -375,8 +383,12 @@
MoveToUsedSet(indices, distances, nearSetSize, farSetSize, usedSetSize,
childIndices, childFarSetSize, childUsedSetSize);
}
+
+ // Calculate furthest descendant.
+ for (size_t i = (nearSetSize + farSetSize); i < (nearSetSize + farSetSize + usedSetSize); ++i)
+ if (distances[i] > furthestDescendantDistance)
+ furthestDescendantDistance = distances[i];
- Log::Assert(furthestChildDistance <= pow(expansionConstant, scale));
Log::Assert(furthestDescendantDistance <= pow(expansionConstant, scale + 1));
}
@@ -545,11 +557,11 @@
memcpy(indices.memptr() + bufferToLocation, indicesBuffer,
sizeof(size_t) * bufferSize);
memcpy(distances.memptr() + bufferToLocation, distancesBuffer,
- sizeof(size_t) * bufferSize);
+ sizeof(double) * bufferSize);
delete[] indicesBuffer;
delete[] distancesBuffer;
-
+
// This returns the complete size of the far set.
return (childFarSetSize + farSetSize);
}
@@ -565,6 +577,8 @@
const size_t childFarSetSize, // childNearSetSize is 0 in this case.
const size_t childUsedSetSize)
{
+ const size_t originalSum = nearSetSize + farSetSize + usedSetSize;
+
// Loop across the set. We will swap points as we need. It should be noted
// that farSetSize and nearSetSize may change with each iteration of this loop
// (depending on if we make a swap or not).
@@ -578,10 +592,6 @@
{
// We have found a point; a swap is necessary.
- // Check if this point is a new furthest descendant.
- if (distances[i] > furthestDescendantDistance)
- furthestDescendantDistance = distances[i];
-
// Since this point is from the near set, to preserve the near set, we
// must do a swap.
if (farSetSize > 0)
@@ -665,10 +675,6 @@
{
// We have found a point to swap.
- // Check if this point is a new furthest descendant.
- if (distances[i + nearSetSize] > furthestDescendantDistance)
- furthestDescendantDistance = distances[i + nearSetSize];
-
// Perform the swap.
size_t tempIndex = indices[nearSetSize + farSetSize - 1];
double tempDist = distances[nearSetSize + farSetSize - 1];
@@ -697,6 +703,8 @@
// Update used set size.
usedSetSize += childUsedSetSize;
+
+ Log::Assert(originalSum == (nearSetSize + farSetSize + usedSetSize));
}
template<typename MetricType, typename RootPointPolicy, typename StatisticType>
More information about the mlpack-svn mailing list