[mlpack-git] master: A first attempt at a working Hamerly prune. The bounds tighten too much and don't reset, so there's not much speedup, but it's a start. (c83b94b)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:02:28 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44
>---------------------------------------------------------------
commit c83b94bc2243dcb3143cbcb521b4bcbe2aa757db
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Jan 21 16:50:07 2015 -0500
A first attempt at a working Hamerly prune. The bounds tighten too much and don't reset, so there's not much speedup, but it's a start.
>---------------------------------------------------------------
c83b94bc2243dcb3143cbcb521b4bcbe2aa757db
src/mlpack/methods/kmeans/dual_tree_kmeans.hpp | 5 +-
.../methods/kmeans/dual_tree_kmeans_impl.hpp | 116 ++++++++++++++++-----
.../methods/kmeans/dual_tree_kmeans_rules_impl.hpp | 90 ++++++++++------
3 files changed, 152 insertions(+), 59 deletions(-)
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
index 27bcf25..b7e3c61 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
@@ -63,7 +63,10 @@ class DualTreeKMeans
void TreeUpdate(TreeType* node,
const size_t clusters,
- const arma::vec& clusterDistances);
+ const arma::vec& clusterDistances,
+ const arma::Col<size_t>& assignments,
+ const arma::mat& oldCentroids,
+ const arma::mat& dataset);
};
template<typename MetricType, typename MatType>
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
index 068a74f..747e69f 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
@@ -66,6 +66,7 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
}
// Build a tree on the centroids.
+ arma::mat oldCentroids(centroids);
std::vector<size_t> oldFromNewCentroids;
TreeType* centroidTree = BuildTree<TreeType>(
const_cast<typename TreeType::Mat&>(centroids), oldFromNewCentroids);
@@ -120,10 +121,10 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
residual += std::pow(dist, 2.0);
}
}
-// Log::Info << clusterDistances.t();
// Update the tree with the centroid movement information.
- TreeUpdate(tree, centroids.n_cols, clusterDistances);
+ TreeUpdate(tree, centroids.n_cols, clusterDistances, assignments,
+ oldCentroids, dataset);
delete centroidTree;
@@ -157,7 +158,10 @@ template<typename MetricType, typename MatType, typename TreeType>
void DualTreeKMeans<MetricType, MatType, TreeType>::TreeUpdate(
TreeType* node,
const size_t clusters,
- const arma::vec& clusterDistances)
+ const arma::vec& clusterDistances,
+ const arma::Col<size_t>& assignments,
+ const arma::mat& centroids,
+ const arma::mat& dataset)
{
// This is basically IterationUpdate(), but pulled out to be separate from the
// actual dual-tree algorithm.
@@ -165,6 +169,22 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::TreeUpdate(
if (node->Parent() != NULL && node->Parent()->Stat().Owner() < clusters)
node->Stat().Owner() = node->Parent()->Stat().Owner();
+ const size_t cluster = assignments[node->Descendant(0)];
+ bool allSame = true;
+ for (size_t i = 1; i < node->NumDescendants(); ++i)
+ {
+ if (assignments[node->Descendant(i)] != cluster)
+ {
+ allSame = false;
+ break;
+ }
+ }
+
+ if (allSame)
+ node->Stat().Owner() = cluster;
+
+ node->Stat().HamerlyPruned() = false;
+
// The easy case: this node had an owner.
if (node->Stat().Owner() < clusters)
{
@@ -175,24 +195,62 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::TreeUpdate(
if (node->Stat().MinQueryNodeDistance() != DBL_MAX)
node->Stat().MinQueryNodeDistance() += clusterDistances[owner];
-/*
- // During the last iteration, this node was pruned. In addition, we have
- // cached a lower bound on the second closest cluster. So, use the
- // triangle inequality: if the maximum distance between the point and the
- // cluster centroid plus the distance that centroid moved is less than the
- // lower bound minus the maximum moving centroid, then this cluster *must*
- // still have the same owner.
- const size_t owner = node->Stat().Owner();
- const double closestUpperBound = node->Stat().MaxQueryNodeDistance() +
- clusterDistances[owner];
- const TreeType* nonOwner = (TreeType*) node->Stat().ClosestNonOwner();
- const double tightestLowerBound = node->Stat().ClosestNonOwnerDistance() -
- nonOwner->Stat().MinQueryNodeDistance();
- if (closestUpperBound <= tightestLowerBound)
+ // Check if we can perform a Hamerly prune: if the node has an owner, and
+ // the second closest cluster could not have moved close enough that any
+ // points could have changed assignment, then this node *must* belong to the
+ // same owner in the next iteration. Note that MaxQueryNodeDistance() has
+ // already been adjusted for cluster movement.
+
+ if (node->Stat().MaxQueryNodeDistance() < node->Stat().SecondClosestBound()
+ - clusterDistances[clusters])
{
- // Then the owner must not have changed.
+ node->Stat().HamerlyPruned() = true;
+ Log::Warn << "Mark r" << node->Begin() << "c" << node->Count() << " as "
+ << "Hamerly pruned.\n";
+
+ // Check the second bound. (This is time-consuming...)
+ for (size_t j = 0; j < node->NumDescendants(); ++j)
+ {
+ arma::vec distances(centroids.n_cols);
+ double secondClosestDist = DBL_MAX;
+ for (size_t i = 0; i < centroids.n_cols; ++i)
+ {
+ const double distance = MetricType::Evaluate(centroids.col(i),
+ dataset.col(node->Descendant(j)));
+ if (distance < secondClosestDist && i != node->Stat().Owner())
+ secondClosestDist = distance;
+
+ distances(i) = distance;
+ }
+
+ if (secondClosestDist < node->Stat().SecondClosestBound() - 1e-15)
+ {
+ Log::Warn << "Owner " << node->Stat().Owner() << ", mqnd " <<
+node->Stat().MaxQueryNodeDistance() << ", mnqnd " <<
+node->Stat().MinQueryNodeDistance() << ".\n";
+ Log::Warn << distances.t();
+ Log::Fatal << "Second closest bound " <<
+node->Stat().SecondClosestBound() << " is too loose! -- " << secondClosestDist
+ << "! (" << node->Stat().SecondClosestBound() - secondClosestDist
+<< ")\n";
+ }
+// if (node->Begin() == 37591)
+// Log::Warn << "r37591c" << node->Count() << ": " << distances.t();
+ }
}
-*/
+// else
+// {
+// Log::Warn << "Failed Hamerly prune for r" << node->Begin() << "c" <<
+// node->Count() << "; mqnd " << node->Stat().MaxQueryNodeDistance() <<
+// ", scb " << node->Stat().SecondClosestBound() << ".\n";
+// }
+
+// if (node->Stat().SecondClosestBound() == DBL_MAX)
+// {
+// Log::Warn << "r" << node->Begin() << "c" << node->Count() << " never had "
+// << "the second bound updated.\n";
+// }
+
}
else
{
@@ -204,6 +262,9 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::TreeUpdate(
node->Stat().MaxQueryNodeDistance() += clusterDistances[clusters];
if (node->Stat().MinQueryNodeDistance() != DBL_MAX)
node->Stat().MinQueryNodeDistance() += clusterDistances[clusters];
+
+ // Since the node didn't have an owner, it can't be Hamerly pruned.
+ node->Stat().HamerlyPruned() = false;
}
node->Stat().Iteration() = iteration;
@@ -211,11 +272,18 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::TreeUpdate(
// We have to set the closest query node to NULL because the cluster tree will
// be rebuilt.
node->Stat().ClosestQueryNode() = NULL;
-// node->Stat().MaxQueryNodeDistance() = DBL_MAX;
-// node->Stat().MinQueryNodeDistance() = DBL_MAX;
-
- for (size_t i = 0; i < node->NumChildren(); ++i)
- TreeUpdate(&node->Child(i), clusters, clusterDistances);
+ node->Stat().SecondClosestBound() -= clusterDistances[clusters];
+ if (node->Stat().SecondClosestBound() < 0)
+ node->Stat().SecondClosestBound() = 0;
+
+// if (node->Begin() == 37591)
+// Log::Warn << "scb for r37591c" << node->Count() << " updated to " <<
+//node->Stat().SecondClosestBound() << ".\n";
+
+// if (!node->Stat().HamerlyPruned())
+ for (size_t i = 0; i < node->NumChildren(); ++i)
+ TreeUpdate(&node->Child(i), clusters, clusterDistances, assignments,
+ centroids, dataset);
}
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
index 15ea586..6e8cdb2 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
@@ -106,6 +106,7 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
TreeType& referenceNode)
{
// This won't happen with the root since it is explicitly set to 0.
+ const size_t origPruned = referenceNode.Stat().ClustersPruned();
if (referenceNode.Stat().ClustersPruned() == size_t(-1))
referenceNode.Stat().ClustersPruned() =
referenceNode.Parent()->Stat().ClustersPruned();
@@ -123,34 +124,67 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
referenceNode.Stat().MaxQueryNodeDistance() = std::min(
referenceNode.Parent()->Stat().MaxQueryNodeDistance(),
referenceNode.Stat().MaxQueryNodeDistance());
+ referenceNode.Stat().SecondClosestBound() = std::min(
+ referenceNode.Parent()->Stat().SecondClosestBound(),
+ referenceNode.Stat().SecondClosestBound());
}
- double score = ElkanTypeScore(queryNode, referenceNode);
+ double score = HamerlyTypeScore(referenceNode);
+ if (score == DBL_MAX)
+ {
+ if (origPruned == size_t(-1))
+ {
+ const size_t cluster = referenceNode.Stat().Owner();
+ newCentroids.col(cluster) += referenceNode.Stat().Centroid() *
+ referenceNode.NumDescendants();
+ counts(cluster) += referenceNode.NumDescendants();
+ referenceNode.Stat().ClustersPruned() += queryNode.NumDescendants();
+ }
+ return DBL_MAX; // No other bookkeeping to do.
+ }
if (score != DBL_MAX)
{
- // We also have to update things if the closest query node is null. This
- // can probably be improved.
- const double minDistance = referenceNode.MinDistance(&queryNode);
- ++distanceCalculations;
- score = PellegMooreScore(queryNode, referenceNode, minDistance);
+ score = ElkanTypeScore(queryNode, referenceNode);
- if (minDistance < referenceNode.Stat().MinQueryNodeDistance())
+ if (score != DBL_MAX)
{
- const double maxDistance = referenceNode.MaxDistance(&queryNode);
+ // We also have to update things if the closest query node is null. This
+ // can probably be improved.
+ const double minDistance = referenceNode.MinDistance(&queryNode);
++distanceCalculations;
- referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
- referenceNode.Stat().MinQueryNodeDistance() = minDistance;
- referenceNode.Stat().MaxQueryNodeDistance() = maxDistance;
- }
- else if (IsDescendantOf(*((TreeType*)
- referenceNode.Stat().ClosestQueryNode()), queryNode))
- {
- const double maxDistance = referenceNode.MaxDistance(&queryNode);
- ++distanceCalculations;
- referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
- referenceNode.Stat().MinQueryNodeDistance() = minDistance;
- referenceNode.Stat().MaxQueryNodeDistance() = maxDistance;
+ score = PellegMooreScore(queryNode, referenceNode, minDistance);
+
+ if (minDistance < referenceNode.Stat().MinQueryNodeDistance())
+ {
+ const double maxDistance = referenceNode.MaxDistance(&queryNode);
+ // Only take the previous minimum query node distance in some
+ // circumstances.
+ if (!IsDescendantOf(*((TreeType*)
+ referenceNode.Stat().ClosestQueryNode()), queryNode) &&
+ referenceNode.Stat().MinQueryNodeDistance() != DBL_MAX &&
+ referenceNode.Stat().MinQueryNodeDistance() <
+ referenceNode.Stat().SecondClosestBound())
+ referenceNode.Stat().SecondClosestBound() =
+ referenceNode.Stat().MinQueryNodeDistance();
+ ++distanceCalculations;
+ referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
+ referenceNode.Stat().MinQueryNodeDistance() = minDistance;
+ referenceNode.Stat().MaxQueryNodeDistance() = maxDistance;
+ }
+ else if (IsDescendantOf(*((TreeType*)
+ referenceNode.Stat().ClosestQueryNode()), queryNode))
+ {
+ const double maxDistance = referenceNode.MaxDistance(&queryNode);
+ ++distanceCalculations;
+ referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
+ referenceNode.Stat().MinQueryNodeDistance() = minDistance;
+ referenceNode.Stat().MaxQueryNodeDistance() = maxDistance;
+ }
+ else if (minDistance < referenceNode.Stat().SecondClosestBound())
+ {
+ referenceNode.Stat().SecondClosestBound() = minDistance;
+ }
}
}
@@ -209,20 +243,8 @@ template<typename MetricType, typename TreeType>
double DualTreeKMeansRules<MetricType, TreeType>::HamerlyTypeScore(
TreeType& referenceNode)
{
- // Does the reference node have an owner?
- if (referenceNode.Owner() < centroids.n_cols)
- {
- // Has the owner stayed stationary enough and no other centroids moved
- // enough that this owner _must_ be the continued owner?
- if (referenceNode.MaxQueryNodeDistance() +
- clusterDistances[referenceNode.Owner()] <
- referenceNode.SecondClosestQueryNodeDistance() -
- clusterDistances[centroids.n_cols])
- {
- return DBL_MAX;
- // Not yet handled: when to add this to the finished counts?
- }
- }
+ if (referenceNode.Stat().HamerlyPruned()) // Flag is set by TreeUpdate() at the end of the previous Iterate() call.
+ return DBL_MAX; // DBL_MAX tells Score() to prune this node without further distance checks.
return 0.0;
}
More information about the mlpack-git
mailing list