[mlpack-git] master: Handle SecondClosestBound() a little better. Debugging information is still there, and it is going to need to be seriously refactored. We still don't have properly working Hamerly prunes; they go away after a couple iterations incorrectly. (db0d792)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:01:56 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44
>---------------------------------------------------------------
commit db0d792ef83a231ca777a6c12a5450745d6cffdf
Author: Ryan Curtin <ryan at ratml.org>
Date: Mon Jan 26 15:30:01 2015 -0500
Handle SecondClosestBound() a little better. Debugging information is still there, and it is going to need to be seriously refactored. We still don't have properly working Hamerly prunes; they go away after a couple iterations incorrectly.
>---------------------------------------------------------------
db0d792ef83a231ca777a6c12a5450745d6cffdf
.../methods/kmeans/dual_tree_kmeans_impl.hpp | 97 ++++++++++++++++++----
.../methods/kmeans/dual_tree_kmeans_rules_impl.hpp | 69 ++++++++++++---
.../methods/kmeans/dual_tree_kmeans_statistic.hpp | 16 ++++
3 files changed, 157 insertions(+), 25 deletions(-)
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
index ab293af..592a175 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
@@ -70,6 +70,9 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
std::vector<size_t> oldFromNewCentroids;
TreeType* centroidTree = BuildTree<TreeType>(
const_cast<typename TreeType::Mat&>(centroids), oldFromNewCentroids);
+ for (size_t i = 0; i < oldFromNewCentroids.size(); ++i)
+ Log::Warn << oldFromNewCentroids[i] << " ";
+ Log::Warn << "\n";
// Now calculate distances between centroids.
neighbor::NeighborSearch<neighbor::NearestNeighborSort, MetricType, TreeType>
@@ -154,6 +157,19 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::ClusterTreeUpdate(
node->Stat().FirstBound() = firstBound;
}
+template<typename TreeType>
+bool IsDescendantOf(
+ const TreeType& potentialParent,
+ const TreeType& potentialChild)
+{
+ if (potentialChild.Parent() == &potentialParent)
+ return true;
+ else if (potentialChild.Parent() == NULL)
+ return false;
+ else
+ return IsDescendantOf(potentialParent, *potentialChild.Parent());
+}
+
template<typename MetricType, typename MatType, typename TreeType>
void DualTreeKMeans<MetricType, MatType, TreeType>::TreeUpdate(
TreeType* node,
@@ -201,12 +217,64 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::TreeUpdate(
// same owner in the next iteration. Note that MaxQueryNodeDistance() has
// already been adjusted for cluster movement.
+ // Re-set second closest bound if necessary.
+ if (node->Stat().SecondClosestBound() == DBL_MAX)
+ {
+ if (node->Parent() == NULL)
+ node->Stat().SecondClosestBound() = 0.0; // Don't prune the root.
+
+ else
+ {
+ if (node->Parent()->Stat().SecondClosestBound() != DBL_MAX &&
+node->Stat().LastSecondClosestBound() != DBL_MAX)
+ node->Stat().SecondClosestBound() =
+std::max(node->Parent()->Stat().SecondClosestBound(),
+node->Stat().LastSecondClosestBound());
+ else
+ node->Stat().SecondClosestBound() =
+std::min(node->Parent()->Stat().SecondClosestBound(),
+node->Stat().LastSecondClosestBound());
+ }
+// if (node->Begin() == 35871)
+// Log::Warn << "Update second closest bound for r35871c" <<
+//node->Count() << " to " << node->Stat().SecondClosestBound() << ", which could "
+// << "have been parent's (" << node->Parent()->Stat().SecondClosestBound()
+//<< ") or adjusted last iteration's (" << node->Stat().LastSecondClosestBound()
+//<< ").\n";
+ }
+
+// if (node->Begin() == 35871)
+// Log::Warn << "r35871c" << node->Count() << " has second bound " <<
+//node->Stat().SecondClosestBound() << " (q" << ((TreeType*)
+//node->Stat().SecondClosestQueryNode())->Begin() << "c" << ((TreeType*)
+//node->Stat().SecondClosestQueryNode())->Count() << ") and parent has second "
+// << "bound " << node->Parent()->Stat().SecondClosestBound() << " (q"
+// << ((TreeType*)
+//node->Parent()->Stat().SecondClosestQueryNode())->Begin() << "c" << ((TreeType*)
+//node->Parent()->Stat().SecondClosestQueryNode())->Count() << ").\n";
+
+ if (node->Parent() != NULL &&
+node->Parent()->Stat().SecondClosestQueryNode() != NULL &&
+node->Stat().SecondClosestQueryNode() != NULL && !IsDescendantOf(*((TreeType*)
+node->Stat().SecondClosestQueryNode()), *((TreeType*)
+node->Parent()->Stat().SecondClosestQueryNode())) &&
+node->Parent()->Stat().SecondClosestBound() < node->Stat().SecondClosestBound())
+ {
+// if (node->Begin() == 35871)
+// Log::Warn << "Take second closest bound for r35871c" <<
+//node->Count() << " from parent: " << node->Parent()->Stat().SecondClosestBound()
+//<< " (was " << node->Stat().SecondClosestBound() << ").\n";
+ node->Stat().SecondClosestBound() =
+node->Parent()->Stat().SecondClosestBound();
+ }
+
if (node->Stat().MaxQueryNodeDistance() < node->Stat().SecondClosestBound()
- clusterDistances[clusters])
{
node->Stat().HamerlyPruned() = true;
- Log::Warn << "Mark r" << node->Begin() << "c" << node->Count() << " as "
- << "Hamerly pruned.\n";
+// if (node->Begin() == 35871)
+ Log::Warn << "Mark r" << node->Begin() << "c" << node->Count() << " as "
+ << "Hamerly pruned.\n";
// Check the second bound. (This is time-consuming...)
for (size_t j = 0; j < node->NumDescendants(); ++j)
@@ -223,13 +291,6 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::TreeUpdate(
distances(i) = distance;
}
- // Re-set second closest bound if necessary.
- if (node->Stat().ClustersPruned() == size_t(-1))
- {
-// Log::Warn << "Update second closest bound!\n";
- node->Stat().SecondClosestBound() = node->Parent()->Stat().SecondClosestBound();
- }
-
if (secondClosestDist < node->Stat().SecondClosestBound() - 1e-15)
{
Log::Warn << "Owner " << node->Stat().Owner() << ", mqnd " <<
@@ -240,9 +301,10 @@ node->Stat().MinQueryNodeDistance() << ".\n";
node->Stat().SecondClosestBound() << " is too loose! -- " << secondClosestDist
<< "! (" << node->Stat().SecondClosestBound() - secondClosestDist
<< ")\n";
+
}
-// if (node->Begin() == 37591)
-// Log::Warn << "r37591c" << node->Count() << ": " << distances.t();
+// if (node->Begin() == 35871)
+// Log::Warn << "r35871c" << node->Count() << ": " << distances.t();
}
}
// else
@@ -280,16 +342,21 @@ node->Stat().SecondClosestBound() << " is too loose! -- " << secondClosestDist
// be rebuilt.
node->Stat().ClosestQueryNode() = NULL;
-// if (node->Begin() == 37591)
-// Log::Warn << "scb for r37591c" << node->Count() << " updated to " <<
+// if (node->Begin() == 35871)
+// Log::Warn << "scb for r35871c" << node->Count() << " updated to " <<
//node->Stat().SecondClosestBound() << ".\n";
-// if (!node->Stat().HamerlyPruned())
+ if (!node->Stat().HamerlyPruned())
for (size_t i = 0; i < node->NumChildren(); ++i)
TreeUpdate(&node->Child(i), clusters, clusterDistances, assignments,
centroids, dataset);
-}
+ node->Stat().LastSecondClosestBound() = node->Stat().SecondClosestBound() -
+ clusterDistances[clusters];
+ // This should change later, but I'm not yet sure how to do it.
+ node->Stat().SecondClosestBound() = DBL_MAX;
+ node->Stat().SecondClosestQueryNode() = NULL;
+}
} // namespace kmeans
} // namespace mlpack
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
index f51c380..0a84376 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
@@ -136,9 +136,9 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
referenceNode.Stat().MaxQueryNodeDistance() = std::min(
referenceNode.Parent()->Stat().MaxQueryNodeDistance(),
referenceNode.Stat().MaxQueryNodeDistance());
- referenceNode.Stat().SecondClosestBound() = std::min(
- referenceNode.Parent()->Stat().SecondClosestBound(),
- referenceNode.Stat().SecondClosestBound());
+// referenceNode.Stat().SecondClosestBound() = std::min(
+// referenceNode.Parent()->Stat().SecondClosestBound(),
+// referenceNode.Stat().SecondClosestBound());
// if (referenceNode.Begin() == 37591)
// Log::Warn << "Update second closest bound for r37591c" <<
//referenceNode.Count() << " to parent's, which "
@@ -175,6 +175,9 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
const double minDistance = referenceNode.MinDistance(&queryNode);
++distanceCalculations;
score = PellegMooreScore(queryNode, referenceNode, minDistance);
+// if (referenceNode.Begin() == 37591)
+// Log::Warn << "mQND for r37591c" << referenceNode.Count() << " is "
+// << referenceNode.Stat().MinQueryNodeDistance() << ".\n";
if (minDistance < referenceNode.Stat().MinQueryNodeDistance())
{
@@ -187,24 +190,49 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
{
referenceNode.Stat().SecondClosestBound() =
referenceNode.Stat().MinQueryNodeDistance();
+ referenceNode.Stat().SecondClosestQueryNode() =
+ referenceNode.Stat().ClosestQueryNode();
// if (referenceNode.Begin() == 37591)
// Log::Warn << "scb for r37591c" << referenceNode.Count() << " taken "
// << "from minDistance, which is " <<
//referenceNode.Stat().MinQueryNodeDistance() << ".\n";
}
- ++distanceCalculations;
- referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
- referenceNode.Stat().MinQueryNodeDistance() = minDistance;
- referenceNode.Stat().MaxQueryNodeDistance() = maxDistance;
-// if (referenceNode.Begin() == 37591)
-// Log::Warn << "mQND for r37591c" << referenceNode.Count() << " updated to " << minDistance << " and "
+ if (referenceNode.Stat().MinQueryNodeDistance() == DBL_MAX &&
+ score == DBL_MAX &&
+ minDistance < referenceNode.Stat().SecondClosestBound())
+ {
+ referenceNode.Stat().SecondClosestBound() = minDistance;
+ referenceNode.Stat().SecondClosestQueryNode() = &queryNode;
+// if (referenceNode.Begin() == 37591)
+// Log::Warn << "scb for r37591c" << referenceNode.Count() << " taken "
+// << "from minDistance for pruned query node, which is " <<
+//minDistance << ".\n";
+ }
+
+ if (score != DBL_MAX)
+ {
+ ++distanceCalculations;
+ referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
+ referenceNode.Stat().MinQueryNodeDistance() = minDistance;
+ referenceNode.Stat().MaxQueryNodeDistance() = maxDistance;
+
+// if (referenceNode.Begin() == 37591)
+// Log::Warn << "mQND for r37591c" << referenceNode.Count() << " updated to " << minDistance << " and "
// << "MQND to " << maxDistance << " with furthest query node " <<
// queryNode.Begin() << "c" << queryNode.Count() << ".\n";
+ }
}
else if (IsDescendantOf(*((TreeType*)
referenceNode.Stat().ClosestQueryNode()), queryNode))
{
+// if (referenceNode.Begin() == 37591)
+// Log::Warn << "Old closest for r37591c" << referenceNode.Count() <<
+// " is q" << ((TreeType*)
+//referenceNode.Stat().ClosestQueryNode())->Begin() << "c" << ((TreeType*)
+//referenceNode.Stat().ClosestQueryNode())->Count() << " with mQND " <<
+//referenceNode.Stat().MinQueryNodeDistance() << " and MQND " <<
+//referenceNode.Stat().MaxQueryNodeDistance() << ".\n";
const double maxDistance = referenceNode.MaxDistance(&queryNode);
++distanceCalculations;
referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
@@ -219,6 +247,7 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
else if (minDistance < referenceNode.Stat().SecondClosestBound())
{
referenceNode.Stat().SecondClosestBound() = minDistance;
+ referenceNode.Stat().SecondClosestQueryNode() = &queryNode;
// if (referenceNode.Begin() == 37591)
// Log::Warn << "scb for r37591c" << referenceNode.Count() << " updated to " << minDistance << " via "
// << queryNode.Begin() << "c" << queryNode.Count() << ".\n";
@@ -226,6 +255,14 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
}
}
+ if (((TreeType*) referenceNode.Stat().ClosestQueryNode())->NumDescendants() > 1)
+ {
+ referenceNode.Stat().SecondClosestBound() =
+ referenceNode.Stat().MinQueryNodeDistance();
+ referenceNode.Stat().SecondClosestQueryNode() =
+ referenceNode.Stat().ClosestQueryNode();
+ }
+
if (score == DBL_MAX)
{
referenceNode.Stat().ClustersPruned() += queryNode.NumDescendants();
@@ -335,6 +372,9 @@ double DualTreeKMeansRules<MetricType, TreeType>::ElkanTypeScore(
queryNode)) &&
(&queryNode != (TreeType*) referenceNode.Stat().ClosestQueryNode()))
{
+// if (referenceNode.Begin() == 37591)
+// Log::Warn << "Elkan prune r37591c" << referenceNode.Count() << ", q" <<
+//queryNode.Begin() << "c" << queryNode.Count() << "!\n";
// Then we can conclude d_max(best(N_r), N_r) <= d_min(N_q, N_r) which
// means that N_q cannot possibly hold any clusters that own any points in
// N_r.
@@ -346,15 +386,24 @@ double DualTreeKMeansRules<MetricType, TreeType>::ElkanTypeScore(
template<typename MetricType, typename TreeType>
double DualTreeKMeansRules<MetricType, TreeType>::PellegMooreScore(
- TreeType& /* queryNode */,
+ TreeType& queryNode,
TreeType& referenceNode,
const double minDistance) const
{
// If the minimum distance to the node is greater than the bound, then every
// cluster in the query node cannot possibly be the nearest neighbor of any of
// the points in the reference node.
+// if (referenceNode.Begin() == 37591)
+// Log::Warn << "Pelleg-Moore prune attempt r37591c" << referenceNode.Count() << ", "
+// << "q" << queryNode.Begin() << "c" << queryNode.Count() << "; "
+// << "minDistance " << minDistance << ", MQND " <<
+//referenceNode.Stat().MaxQueryNodeDistance() << ".\n";
if (minDistance > referenceNode.Stat().MaxQueryNodeDistance())
+ {
+// if (referenceNode.Begin() == 37591)
+// Log::Warn << "Attempt successful!\n";
return DBL_MAX;
+ }
return minDistance;
}
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
index 87e4368..0b01fa6 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
@@ -21,6 +21,8 @@ class DualTreeKMeansStatistic
minQueryNodeDistance(DBL_MAX),
maxQueryNodeDistance(DBL_MAX),
secondClosestBound(DBL_MAX),
+ secondClosestQueryNode(NULL),
+ lastSecondClosestBound(DBL_MAX),
hamerlyPruned(false),
clustersPruned(size_t(-1)),
iteration(size_t() - 1),
@@ -68,6 +70,16 @@ class DualTreeKMeansStatistic
//! Modify the lower bound on the second closest cluster distance.
double& SecondClosestBound() { return secondClosestBound; }
+ //! Get the second closest query node.
+ void* SecondClosestQueryNode() const { return secondClosestQueryNode; }
+ //! Modify the second closest query node.
+ void*& SecondClosestQueryNode() { return secondClosestQueryNode; }
+
+ //! Get last iteration's second closest bound.
+ double LastSecondClosestBound() const { return lastSecondClosestBound; }
+ //! Modify last iteration's second closest bound.
+ double& LastSecondClosestBound() { return lastSecondClosestBound; }
+
//! Get whether or not this node is Hamerly pruned this iteration.
bool HamerlyPruned() const { return hamerlyPruned; }
//! Modify whether or not this node is Hamerly pruned this iteration.
@@ -123,6 +135,10 @@ class DualTreeKMeansStatistic
double maxQueryNodeDistance;
//! A lower bound on the distance to the second closest cluster.
double secondClosestBound;
+ //! The second closest query node.
+ void* secondClosestQueryNode;
+ //! The second closest lower bound, on the previous iteration.
+ double lastSecondClosestBound;
//! Whether or not this node is pruned for the next iteration.
bool hamerlyPruned;
More information about the mlpack-git mailing list