[mlpack-git] master: More reasonable handling of bound updating. This speeds things up. (de0a5f1)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:03:44 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44
>---------------------------------------------------------------
commit de0a5f1cf277308ab95767f38da441ac78c6c45c
Author: Ryan Curtin <ryan at ratml.org>
Date: Sun Feb 1 13:56:16 2015 -0500
More reasonable handling of bound updating. This speeds things up.
>---------------------------------------------------------------
de0a5f1cf277308ab95767f38da441ac78c6c45c
src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 126 ++++++++++++++++++++-----
src/mlpack/methods/kmeans/dtnn_rules_impl.hpp | 8 +-
src/mlpack/methods/kmeans/dtnn_statistic.hpp | 4 +-
3 files changed, 110 insertions(+), 28 deletions(-)
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 9aa0be1..0833ec3 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -162,7 +162,6 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
}
}
clusterDistances[centroids.n_cols] = maxMovement;
- Log::Warn << clusterDistances.t();
distanceCalculations += centroids.n_cols;
UpdateTree(*tree, maxMovement, oldCentroids, assignments, distances,
@@ -210,10 +209,10 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
// It would be nice if we could do this during the traversal.
bool singleOwner = true;
size_t owner = centroids.n_cols + 1;
- node.Stat().MaxClusterDistance() = 0.0;
- node.Stat().SecondClusterBound() = DBL_MAX;
if (!node.Stat().Pruned() && childrenPruned)
{
+ double newMaxClusterDistance = 0.0;
+ double newSecondClusterBound = DBL_MAX;
for (size_t i = 0; i < node.NumPoints(); ++i)
{
// Don't forget to map back from the new cluster index.
@@ -225,10 +224,10 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
singleOwner = false;
// Update maximum cluster distance and second cluster bound.
- if (distances(0, node.Point(i)) > node.Stat().MaxClusterDistance())
- node.Stat().MaxClusterDistance() = distances(0, node.Point(i));
- if (distances(1, node.Point(i)) < node.Stat().SecondClusterBound())
- node.Stat().SecondClusterBound() = distances(1, node.Point(i));
+ if (distances(0, node.Point(i)) > newMaxClusterDistance)
+ newMaxClusterDistance = distances(0, node.Point(i));
+ if (distances(1, node.Point(i)) < newSecondClusterBound)
+ newSecondClusterBound = distances(1, node.Point(i));
}
for (size_t i = 0; i < node.NumChildren(); ++i)
@@ -241,22 +240,27 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
singleOwner = false;
// Update maximum cluster distance and second cluster bound.
- if (node.Child(i).Stat().MaxClusterDistance() >
- node.Stat().MaxClusterDistance())
- node.Stat().MaxClusterDistance() =
- node.Child(i).Stat().MaxClusterDistance();
- if (node.Child(i).Stat().SecondClusterBound() <
- node.Stat().SecondClusterBound())
- node.Stat().SecondClusterBound() =
- node.Child(i).Stat().SecondClusterBound();
+ if (node.Child(i).Stat().MaxClusterDistance() > newMaxClusterDistance)
+ newMaxClusterDistance = node.Child(i).Stat().MaxClusterDistance();
+ if (node.Child(i).Stat().SecondClusterBound() < newSecondClusterBound)
+ newSecondClusterBound = node.Child(i).Stat().SecondClusterBound();
}
+ // What do we do with the new cluster bounds?
+ if (newMaxClusterDistance > 0.0 && newMaxClusterDistance <
+ node.Stat().MaxClusterDistance())
+ node.Stat().MaxClusterDistance() = newMaxClusterDistance;
+ if (newSecondClusterBound != DBL_MAX && newSecondClusterBound >
+ node.Stat().SecondClusterBound())
+ node.Stat().SecondClusterBound() = newSecondClusterBound;
+
// Okay, now we know if it's owned or not, and by which cluster.
if (singleOwner)
{
node.Stat().Owner() = owner;
// Sanity check: ensure the owner is right.
+/*
for (size_t i = 0; i < node.NumPoints(); ++i)
{
const double ownerDist = metric.Evaluate(dataset.col(node.Point(i)),
@@ -268,6 +272,12 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
if (dist < ownerDist)
{
Log::Warn << node << "...\n" << *node.Parent();
+// TreeType* n = node.Parent()->Parent();
+// while (n != NULL)
+// {
+// Log::Warn << "...\n" << *n;
+// n = n->Parent();
+// }
Log::Fatal << "Point " << node.Point(i) << " was assigned to owner "
<< owner << " but has true owner " << j << "! [" <<
oldFromNewCentroids[assignments(0, node.Point(i))] << " -- " <<
@@ -279,6 +289,7 @@ oldFromNewCentroids[assignments(0, node.Point(i - 1))] << ".\n";
}
}
}
+*/
// What is the maximum distance to the closest cluster in the node?
if (node.Stat().MaxClusterDistance() +
@@ -287,6 +298,20 @@ oldFromNewCentroids[assignments(0, node.Point(i - 1))] << ".\n";
{
node.Stat().Pruned() = true;
}
+
+ // Adjust for next iteration.
+ node.Stat().MaxClusterDistance() +=
+ clusterDistances[node.Stat().Owner()];
+ node.Stat().SecondClusterBound() -= clusterDistances[centroids.n_cols];
+ }
+ else
+ {
+ // The node isn't owned by a single cluster.
+ // Adjust the bounds for next iteration.
+ node.Stat().MaxClusterDistance() += clusterDistances[centroids.n_cols];
+ node.Stat().SecondClusterBound() = std::max(0.0,
+ node.Stat().SecondClusterBound() -
+ clusterDistances[centroids.n_cols]);
}
}
else if (node.Stat().Pruned())
@@ -294,23 +319,74 @@ oldFromNewCentroids[assignments(0, node.Point(i - 1))] << ".\n";
// The node was pruned last iteration. See if the node can remain pruned.
singleOwner = false;
- node.Stat().Pruned() = false;
- node.Stat().FirstBound() = DBL_MAX;
- node.Stat().SecondBound() = DBL_MAX;
- node.Stat().Bound() = DBL_MAX;
+/*
+ for (size_t i = 0; i < node.NumPoints(); ++i)
+ {
+ size_t trueOwner = 0;
+ double ownerDist = DBL_MAX;
+ arma::vec distances(centroids.n_cols);
+ for (size_t j = 0; j < centroids.n_cols; ++j)
+ {
+ const double dist = metric.Evaluate(dataset.col(node.Point(i)),
+ centroids.col(j));
+ distances(j) = dist;
+ if (dist < ownerDist)
+ {
+ trueOwner = j;
+ ownerDist = dist;
+ }
+ }
+
+ if (trueOwner != node.Stat().Owner())
+ {
+ Log::Warn << node << "...\n" << *node.Parent();
+ Log::Warn << distances.t();
+ Log::Fatal << "Point " << node.Point(i) << " was assigned to owner "
+ << node.Stat().Owner() << " but has true owner " << trueOwner <<
+"!\n";
+ }
+ }
+*/
+
+ // Will our bounds still work?
+ if (node.Stat().MaxClusterDistance() +
+ clusterDistances[node.Stat().Owner()] <
+ node.Stat().SecondClusterBound() - clusterDistances[centroids.n_cols])
+ {
+ // The node remains pruned. Adjust the bounds for next iteration.
+ node.Stat().MaxClusterDistance() += clusterDistances[node.Stat().Owner()];
+ node.Stat().SecondClusterBound() -= clusterDistances[centroids.n_cols];
+ }
+ else
+ {
+ node.Stat().Pruned() = false;
+ node.Stat().FirstBound() = DBL_MAX;
+ node.Stat().SecondBound() = DBL_MAX;
+ node.Stat().Bound() = DBL_MAX;
+ node.Stat().MaxClusterDistance() = DBL_MAX;
+ node.Stat().SecondClusterBound() = 0.0;
+ }
}
else
{
// The children haven't been pruned, so we can't.
// This node was not pruned last iteration, so we simply need to adjust the
// bounds.
- if (node.Stat().FirstBound() != DBL_MAX)
- node.Stat().FirstBound() += tolerance;
- if (node.Stat().SecondBound() != DBL_MAX)
- node.Stat().SecondBound() += tolerance;
- if (node.Stat().Bound() != DBL_MAX)
- node.Stat().Bound() += tolerance;
+ node.Stat().Owner() = centroids.n_cols;
+ if (node.Stat().MaxClusterDistance() != DBL_MAX)
+ node.Stat().MaxClusterDistance() += clusterDistances[centroids.n_cols];
+ if (node.Stat().SecondClusterBound() != DBL_MAX)
+ node.Stat().SecondClusterBound() = std::max(0.0,
+ node.Stat().SecondClusterBound() -
+ clusterDistances[centroids.n_cols]);
}
+
+ if (node.Stat().FirstBound() != DBL_MAX)
+ node.Stat().FirstBound() += tolerance;
+ if (node.Stat().SecondBound() != DBL_MAX)
+ node.Stat().SecondBound() += tolerance;
+ if (node.Stat().Bound() != DBL_MAX)
+ node.Stat().Bound() += tolerance;
}
template<typename MetricType, typename MatType, typename TreeType>
diff --git a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
index eface18..e3d0c55 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
@@ -31,7 +31,8 @@ inline force_inline double DTNNKMeansRules<MetricType, TreeType>::BaseCase(
{
// We'll check if the query point has been Hamerly pruned. If so, don't
// continue.
-
+// if (queryIndex == 27040)
+// Log::Warn << "Visit point 27040 with cluster " << referenceIndex << ".\n";
return rules.BaseCase(queryIndex, referenceIndex);
}
@@ -48,6 +49,11 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
TreeType& queryNode,
TreeType& referenceNode)
{
+// if (queryNode.Point(0) == 27040)
+// Log::Warn << "Visit q27040c1 r" << referenceNode.Point(0) << "c" <<
+//referenceNode.NumDescendants() << ", " << queryNode.Stat().Pruned() << ", " <<
+//queryNode.Stat() << ", " << queryNode.Stat().FirstBound() << "," <<
+//queryNode.Stat().SecondBound() << ", " << queryNode.Stat().Bound() << ".\n";
if (queryNode.Stat().Pruned())
return DBL_MAX;
diff --git a/src/mlpack/methods/kmeans/dtnn_statistic.hpp b/src/mlpack/methods/kmeans/dtnn_statistic.hpp
index ba3af0d..5ae7f30 100644
--- a/src/mlpack/methods/kmeans/dtnn_statistic.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_statistic.hpp
@@ -20,7 +20,7 @@ class DTNNStatistic : public
neighbor::NeighborSearchStat<neighbor::NearestNeighborSort>(),
pruned(false),
iteration(0),
- maxClusterDistance(0.0),
+ maxClusterDistance(DBL_MAX),
secondClusterBound(0.0),
owner(size_t(-1)),
centroid()
@@ -33,7 +33,7 @@ class DTNNStatistic : public
neighbor::NeighborSearchStat<neighbor::NearestNeighborSort>(),
pruned(false),
iteration(0),
- maxClusterDistance(0.0),
+ maxClusterDistance(DBL_MAX),
secondClusterBound(0.0),
owner(size_t(-1))
{
More information about the mlpack-git mailing list