[mlpack-git] master: I don't think this is worth saving. It also doesn't work very well, but I learned a lot about the bookkeeping I need to do. (de01b8f)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:01:51 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44
>---------------------------------------------------------------
commit de01b8f67ae3c54ab259069ac5c369723dcebfd0
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Jan 28 15:17:53 2015 -0500
I don't think this is worth saving. It also doesn't work very well, but I learned a lot about the bookkeeping I need to do.
>---------------------------------------------------------------
de01b8f67ae3c54ab259069ac5c369723dcebfd0
.../methods/kmeans/dual_tree_kmeans_impl.hpp | 322 +++++++--------------
.../methods/kmeans/dual_tree_kmeans_rules_impl.hpp | 253 ++++------------
.../methods/kmeans/dual_tree_kmeans_statistic.hpp | 39 ++-
src/mlpack/methods/kmeans/hamerly_kmeans_impl.hpp | 5 +
4 files changed, 204 insertions(+), 415 deletions(-)
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
index f08b4e1..7a65d3e 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
@@ -39,7 +39,7 @@ DualTreeKMeans<MetricType, MatType, TreeType>::DualTreeKMeans(
datasetCopy = datasetOrig;
// Now build the tree. We don't need any mappings.
- tree = new TreeType(const_cast<typename TreeType::Mat&>(this->dataset));
+ tree = new TreeType(const_cast<typename TreeType::Mat&>(this->dataset), 1);
Timer::Stop("tree_building");
}
@@ -186,9 +186,14 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::TreeUpdate(
{
// This is basically IterationUpdate(), but pulled out to be separate from the
// actual dual-tree algorithm.
-
+ if (node->Begin() == 26038)
+ Log::Warn << "r26038c" << node->Count() << " has owner " <<
+node->Stat().Owner() << ".\n";
if (node->Parent() != NULL && node->Parent()->Stat().Owner() < clusters)
node->Stat().Owner() = node->Parent()->Stat().Owner();
+ if (node->Begin() == 26038)
+ Log::Warn << "r26038c" << node->Count() << " has owner " <<
+node->Stat().Owner() << " after parent check.\n";
const size_t cluster = assignments[node->Descendant(0)];
bool allSame = true;
@@ -203,242 +208,102 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::TreeUpdate(
if (allSame)
node->Stat().Owner() = cluster;
+ else
+ node->Stat().Owner() = centroids.n_cols;
+ if (node->Begin() == 26038)
+ Log::Warn << "r26038c" << node->Count() << " has manually set owner " <<
+node->Stat().Owner() << ".\n";
const bool prunedLastIteration = node->Stat().HamerlyPruned();
node->Stat().HamerlyPruned() = false;
- if (node->Begin() == 23058)
- Log::Warn << "r23058c" << node->Count() << " has owner " <<
+ if (node->Begin() == 26038)
+ Log::Warn << "r26038c" << node->Count() << " has owner " <<
node->Stat().Owner() << ".\n";
// The easy case: this node had an owner.
if (node->Stat().Owner() < clusters)
{
- // During the last iteration, this node was pruned.
- const size_t owner = node->Stat().Owner();
- if (node->Stat().MaxQueryNodeDistance() != DBL_MAX)
- node->Stat().MaxQueryNodeDistance() += clusterDistances[owner];
- if (node->Stat().MinQueryNodeDistance() != DBL_MAX)
- node->Stat().MinQueryNodeDistance() += clusterDistances[owner];
-
- // Check if we can perform a Hamerly prune: if the node has an owner, and
- // the second closest cluster could not have moved close enough that any
- // points could have changed assignment, then this node *must* belong to the
- // same owner in the next iteration. Note that MaxQueryNodeDistance() has
- // already been adjusted for cluster movement.
-
- // Re-set second closest bound if necessary.
- if (node->Stat().SecondClosestBound() == DBL_MAX && node->Parent() == NULL)
- node->Stat().SecondClosestBound() = 0.0; // Don't prune the root.
-
- if (node->Begin() == 23058)
- Log::Warn << "r23058c" << node->Count() << " scb " <<
-node->Stat().SecondClosestBound() << " and lscb " <<
-node->Stat().LastSecondClosestBound() << ".\n";
-
- // If both the second closest bound and last second closest bound are valid,
- // we have the option of taking the better of the two bounds. But if only
- // one is valid, take the minimum of the two (which will be the valid one).
- // If neither is valid, then we end up with a second closest bound of
- // DBL_MAX.
- const double scb = node->Stat().SecondClosestBound();
- const double lscb = node->Stat().LastSecondClosestBound();
- if (scb != DBL_MAX && lscb != DBL_MAX)
- node->Stat().SecondClosestBound() = std::max(scb, lscb);
- else
- node->Stat().SecondClosestBound() = std::min(scb, lscb);
-
- // But if we were Hamerly pruned last time, we can't trust the second
- // closest bound and thus have to take last iteration's.
- if (prunedLastIteration)
- node->Stat().SecondClosestBound() = lscb;
- else
+ // Verify correctness...
+ for (size_t i = 0; i < node->NumDescendants(); ++i)
{
- // Now, we must ensure that we don't need to take the parent's second
- // closest bound. We surely do if the current bound is DBL_MAX. We
- // already took care of the root node earlier so we don't need to check if
- // Parent() is NULL.
- if (node->Stat().SecondClosestBound() == DBL_MAX)
- node->Stat().SecondClosestBound() =
- node->Parent()->Stat().SecondClosestBound();
-
- // There may exist a case where the true second closest query node got
- // pruned by the parent, and was thus never visited with this node. This
- // situation occurs if the second closest query node is not a descendant
- // of the second closest query node of the parent.
- if (node->Stat().SecondClosestQueryNode() != NULL)
+ size_t closest = clusters;
+ double closestDistance = DBL_MAX;
+ arma::vec distances(centroids.n_cols);
+ for (size_t j = 0; j < centroids.n_cols; ++j)
{
- if (node->Begin() == 23058)
+ const double distance = metric.Evaluate(centroids.col(j),
+ dataset.col(node->Descendant(i)));
+ if (distance < closestDistance)
{
- Log::Warn << "Second closest query node is q" << ((TreeType*)
-node->Stat().SecondClosestQueryNode())->Begin() << "c" << ((TreeType*)
-node->Stat().SecondClosestQueryNode())->Count() << ", with scb " <<
-node->Stat().SecondClosestBound() << ".\n";
- Log::Warn << "True SCB to this node should be " <<
-node->MinDistance((TreeType*) node->Stat().SecondClosestQueryNode()) << ".\n";
+ closest = j;
+ closestDistance = distance;
}
+ distances(j) = distance;
}
- if (node->Stat().ClosestQueryNode() != NULL)
- if (node->Begin() == 23058)
- Log::Warn << "Closest query node: q" << ((TreeType*)
-node->Stat().ClosestQueryNode())->Begin() << "c" << ((TreeType*)
-node->Stat().ClosestQueryNode())->Count() << ", with MQND " <<
-node->Stat().MaxQueryNodeDistance() << " and mQND " <<
-node->Stat().MinQueryNodeDistance() << ".\n";
-
- // If the closest query node contains more than one descendant, we have to
- // find the closest...
- TreeType* cqn = (TreeType*) node->Stat().ClosestQueryNode();
- if (cqn != NULL && cqn->NumDescendants() > 1)
+ if (closest != node->Stat().Owner())
{
- size_t closest = centroids.n_cols;
- double closestDistance = DBL_MAX;
- size_t secondClosest = centroids.n_cols;
- double secondClosestDistance = DBL_MAX;
- for (size_t i = 0; i < cqn->NumDescendants(); ++i)
- {
- const size_t index = cqn->Descendant(i);
- const double distance =
- node->MinDistance(centroids.col(oldFromNew[index]));
-// Log::Info << "Index " << index << ", distance " << distance << " (i "
-// << i + cqn->Begin() << ").\n";
- ++distanceCalculations;
- if (distance < closestDistance)
- {
- secondClosest = closest;
- secondClosestDistance = closestDistance;
- closest = index;
- closestDistance = distance;
- }
- else if (distance < secondClosestDistance)
- {
- secondClosest = index;
- secondClosestDistance = distance;
- }
- }
-
- // Recalculate maximum distance.
- const double maxDistance = node->MaxDistance(centroids.col(closest));
- ++distanceCalculations;
-
- node->Stat().MinQueryNodeDistance() = closestDistance;
- node->Stat().MaxQueryNodeDistance() = maxDistance;
- if (secondClosestDistance < node->Stat().SecondClosestBound())
- node->Stat().SecondClosestBound() = secondClosestDistance;
-
- if (node->Begin() == 23058)
- Log::Warn << "After recalculation, closest for r" << node->Begin() << "c" << node->Count()
-<< " is " << closest << ", with mQND " << node->Stat().MinQueryNodeDistance() <<
-", MQND" << node->Stat().MaxQueryNodeDistance() << ", and scb " <<
-node->Stat().SecondClosestBound() << ", " << secondClosest << ".\n";
- }
-
-// if (node->Parent() != NULL &&
-//node->Parent()->Stat().SecondClosestQueryNode() != NULL)
-// if (node->Begin() == 23058)
-// Log::Warn << "Parent's (r" << node->Parent()->Begin() << "c"
-//<< node->Parent()->Count() << ") second closest query node is q" << ((TreeType*)
-//node->Parent()->Stat().SecondClosestQueryNode())->Begin() << "c" << ((TreeType*)
-//node->Parent()->Stat().SecondClosestQueryNode())->Count() << ", with scb " <<
-//node->Parent()->Stat().SecondClosestBound() << ".\n";
-
- // Suppose that the true second closest query node was pruned by the
- // parent, and thus was never seen by this node. To ensure the
- // correctness of the second bound in this situation, we'll take the
- // parent's second closest bound only if the parent's second closest query
- // node is on a separate subtree than the node's second closest query node
- // _and_ the node's closest query node.
- TreeType* parent = (TreeType*) node->Parent();
- TreeType* scqn = (TreeType*) node->Stat().SecondClosestQueryNode();
- TreeType* parentScqn = (parent == NULL) ? NULL :
- (TreeType*) parent->Stat().SecondClosestQueryNode();
- TreeType* parentCqn = (parent == NULL) ? NULL :
- (TreeType*) parent->Stat().ClosestQueryNode();
- if (parentScqn != NULL && node->Begin() == 23058)
- Log::Warn << "Parent (" << parent->Begin() << "c" << parent->Count() <<
-") SCB is " << parent->Stat().SecondClosestBound() << ", "
- << "with q" << parentScqn->Begin() << "c" << parentScqn->Count() <<
+ Log::Warn << distances.t();
+ Log::Fatal << "Point " << node->Descendant(i) << " mistakenly assigned "
+ << "to cluster " << node->Stat().Owner() << ", but should be " <<
+closest << "! It's part of node r" << node->Begin() << "c" << node->Count() <<
".\n";
- if (scqn != NULL && parentScqn != NULL &&
- !IsDescendantOf(*parentScqn, *scqn) &&
- !IsDescendantOf(*parentCqn, *scqn) &&
- (parent->Stat().SecondClosestBound() <
- node->Stat().SecondClosestBound()))
- {
- if (node->Begin() == 23058)
- Log::Warn << "Take parent's SCB of " <<
-parent->Stat().SecondClosestBound() << "; parent SCQN is " <<
-parentScqn->Begin() << "c" << parentScqn->Count() << ", parent CQN is " <<
-parentCqn->Begin() << "c" << parentCqn->Count() << ".\n";
- node->Stat().SecondClosestBound() = parent->Stat().SecondClosestBound();
- node->Stat().SecondClosestQueryNode() = parentScqn;
}
}
- if (node->Begin() == 23058)
- {
- Log::Warn << "Attempt Hamerly prune on r23058c" << node->Count() <<
- " with MQND " << node->Stat().MaxQueryNodeDistance() << ", scb "
- << node->Stat().SecondClosestBound() << ", owner " <<
-node->Stat().Owner() << ", and clusterDistances " << clusterDistances[clusters]
-<< ".\n";
- }
+ // During the last iteration, this node was pruned.
+ const size_t owner = node->Stat().Owner();
+ if (node->Stat().MaxQueryNodeDistance() != DBL_MAX)
+ node->Stat().MaxQueryNodeDistance() += clusterDistances[owner];
+ if (node->Stat().MinQueryNodeDistance() != DBL_MAX)
+ node->Stat().MinQueryNodeDistance() += clusterDistances[owner];
- // Check the second bound. (This is time-consuming...)
- arma::vec minDistances(centroids.n_cols);
- for (size_t j = 0; j < node->NumDescendants(); ++j)
+ if (prunedLastIteration)
{
- arma::vec distances(centroids.n_cols);
- double secondClosestDist = DBL_MAX;
- for (size_t i = 0; i < centroids.n_cols; ++i)
+ // Can we continue being Hamerly pruned? If not, we'll have to update the
+ // bound next iteration.
+ if (node->Begin() == 26038)
+ Log::Warn << "r26038c" << node->Count() << ": check sustained Hamerly "
+ << "prune with MQND " << node->Stat().MaxQueryNodeDistance() << ", "
+ << "lscb " << node->Stat().LastSecondClosestBound() << ", cd "
+ << clusterDistances[clusters] << ".\n";
+ if (node->Stat().MaxQueryNodeDistance() <
+ node->Stat().LastSecondClosestBound() - clusterDistances[clusters])
{
- if (j == 0)
- minDistances[i] = node->MinDistance(centroids.col(i));
-
- const double distance = MetricType::Evaluate(centroids.col(i),
- dataset.col(node->Descendant(j)));
- if (distance < secondClosestDist && i != node->Stat().Owner())
- secondClosestDist = distance;
-
- distances(i) = distance;
+ node->Stat().HamerlyPruned() = true;
+ if (!node->Parent()->Stat().HamerlyPruned())
+ hamerlyPruned += node->NumDescendants();
}
-
- if (j == 0)
- if (node->Begin() == 23058)
- Log::Warn << "r23058c" << node->Count() << ": " << minDistances.t();
- if (secondClosestDist < node->Stat().SecondClosestBound() - 1e-15)
+ }
+ else
+ {
+ if (node->Begin() == 26038)
{
- Log::Warn << "r" << node->Begin() << "c" << node->Count() << ":\n";
- Log::Warn << "Owner " << node->Stat().Owner() << ", mqnd " <<
-node->Stat().MaxQueryNodeDistance() << ", mnqnd " <<
-node->Stat().MinQueryNodeDistance() << ".\n";
- Log::Warn << distances.t();
- Log::Fatal << "Second closest bound " <<
-node->Stat().SecondClosestBound() << " is too loose! -- " << secondClosestDist
- << "! (" << node->Stat().SecondClosestBound() - secondClosestDist
-<< ")\n";
+ if (node->Stat().ClosestQueryNode() != NULL)
+ Log::Warn << "r26038c" << node->Count() << " CQN: " << ((TreeType*)
+ node->Stat().ClosestQueryNode())->Begin() << "c" << ((TreeType*)
+ node->Stat().ClosestQueryNode())->Count() << ".\n";
+ if (node->Stat().SecondClosestQueryNode() != NULL)
+ Log::Warn << "r26038c" << node->Count() << " SCQN: " << ((TreeType*)
+ node->Stat().SecondClosestQueryNode())->Begin() << "c" << ((TreeType*)
+ node->Stat().SecondClosestQueryNode())->Count() << ".\n";
+ Log::Warn << "Attempt hamerly prune r26038c" << node->Count() << " with "
+ << "MQND " << node->Stat().MaxQueryNodeDistance() << " and smqnd "
+ << node->Stat().SecondMinQueryNodeDistance() << " and cluster d "
+ << clusterDistances[clusters] << ".\n";
}
- }
-
- if (node->Stat().MaxQueryNodeDistance() < node->Stat().SecondClosestBound()
- - clusterDistances[clusters])
- {
- node->Stat().HamerlyPruned() = true;
- if (!node->Parent()->Stat().HamerlyPruned())
+ // Now we check for a Hamerly prune. We know that we have an accurate
+ // second bound since nothing can be pruned.
+ if (node->Stat().MaxQueryNodeDistance() /* already adjusted */ <
+ node->Stat().SecondMinQueryNodeDistance() - clusterDistances[clusters])
{
- if (node->Begin() == 23058)
- Log::Warn << "Mark r" << node->Begin() << "c" << node->Count() << " as "
- << "Hamerly pruned.\n";
- hamerlyPruned += node->NumDescendants();
+ node->Stat().HamerlyPruned() = true;
+ if (!node->Parent()->Stat().HamerlyPruned())
+ hamerlyPruned += node->NumDescendants();
}
}
- else
- {
- // No Hamerly prune, so we don't have a known owner.
- node->Stat().Owner() = clusters;
- }
}
else
{
@@ -456,9 +321,34 @@ node->Stat().SecondClosestBound() << " is too loose! -- " << secondClosestDist
node->Stat().Owner() = centroids.n_cols;
}
+ bool allPruned = true;
+ size_t owner = clusters;
for (size_t i = 0; i < node->NumChildren(); ++i)
+ {
TreeUpdate(&node->Child(i), clusters, clusterDistances, assignments,
centroids, dataset, oldFromNew, hamerlyPruned);
+ if (!node->Child(i).Stat().HamerlyPruned())
+ allPruned = false;
+ else if (owner == clusters)
+ owner = node->Child(i).Stat().Owner();
+ else if (owner < clusters && owner != node->Child(i).Stat().Owner())
+ owner = clusters + 1;
+ }
+
+ if (node->NumChildren() == 0 && !node->Stat().HamerlyPruned())
+ allPruned = false;
+
+ if (allPruned && owner < clusters && !node->Stat().HamerlyPruned())
+ {
+ if (node->Begin() == 26038)
+ Log::Warn << "Set r" << node->Begin() << "c" << node->Count() << " to be "
+ << "Hamerly pruned.\n";
+ node->Stat().HamerlyPruned() = true;
+ }
+
+ if (node->Begin() == 26038 && node->Stat().HamerlyPruned())
+ Log::Warn << "r" << node->Begin() << "c" << node->Count() << " is Hamerly "
+ << "pruned.\n";
node->Stat().Iteration() = iteration;
node->Stat().ClustersPruned() = (node->Parent() == NULL) ? 0 : -1;
@@ -466,11 +356,19 @@ node->Stat().SecondClosestBound() << " is too loose! -- " << secondClosestDist
// be rebuilt.
node->Stat().ClosestQueryNode() = NULL;
- node->Stat().LastSecondClosestBound() = node->Stat().SecondClosestBound() -
- clusterDistances[clusters];
+ if (prunedLastIteration)
+ node->Stat().LastSecondClosestBound() -= clusterDistances[clusters];
+ else
+ node->Stat().LastSecondClosestBound() =
+ node->Stat().SecondMinQueryNodeDistance() - clusterDistances[clusters];
+ node->Stat().MinQueryNodeDistance() = DBL_MAX;
+ if (prunedLastIteration && !node->Stat().HamerlyPruned())
+ node->Stat().MaxQueryNodeDistance() = DBL_MAX;
+ node->Stat().SecondMinQueryNodeDistance() = DBL_MAX;
+ node->Stat().SecondMaxQueryNodeDistance() = DBL_MAX;
// This should change later, but I'm not yet sure how to do it.
- node->Stat().SecondClosestBound() = DBL_MAX;
- node->Stat().SecondClosestQueryNode() = NULL;
+// node->Stat().SecondClosestBound() = DBL_MAX;
+// node->Stat().SecondClosestQueryNode() = NULL;
if (node->Parent() == NULL)
Log::Info << "Total Hamerly pruned points: " << hamerlyPruned << ".\n";
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
index 28b6695..4615b80 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
@@ -50,6 +50,9 @@ inline force_inline double DualTreeKMeansRules<MetricType, TreeType>::BaseCase(
const size_t queryIndex,
const size_t referenceIndex)
{
+ if (referenceIndex == 26038)
+ Log::Warn << "Visit 26038 with query " << queryIndex << ".\n";
+
// Collect the number of clusters that have been pruned during the traversal.
// The ternary operator may not be necessary.
const size_t traversalPruned = (traversalInfo.LastReferenceNode() != NULL) ?
@@ -72,17 +75,26 @@ inline force_inline double DualTreeKMeansRules<MetricType, TreeType>::BaseCase(
distanceIteration[referenceIndex] = iteration;
distances[referenceIndex] = distance;
assignments[referenceIndex] = mappings[queryIndex];
+ if (referenceIndex == 26038)
+ Log::Warn << "assignment for point " << referenceIndex << " set to " <<
+mappings[queryIndex] << ".\n";
}
else if (distance < distances[referenceIndex])
{
distances[referenceIndex] = distance;
assignments[referenceIndex] = mappings[queryIndex];
+ if (referenceIndex == 26038)
+ Log::Warn << "assignment for point " << referenceIndex << " set to " <<
+mappings[queryIndex] << ".\n";
}
++visited[referenceIndex];
if (visited[referenceIndex] + traversalPruned == centroids.n_cols)
{
+ if (referenceIndex == 26038)
+ Log::Warn << "assignment for point " << referenceIndex << " committed to " <<
+assignments[referenceIndex] << ".\n";
newCentroids.col(assignments[referenceIndex]) +=
dataset.col(referenceIndex);
++counts(assignments[referenceIndex]);
@@ -105,206 +117,69 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
TreeType& queryNode,
TreeType& referenceNode)
{
+// if (referenceNode.Begin() == 2432)
+// Log::Warn << "Visit q" << queryNode.Begin() << "c" << queryNode.Count() <<
+//", r" << referenceNode.Begin() << "c" << referenceNode.Count() << ".\n";
+
// This won't happen with the root since it is explicitly set to 0.
const size_t origPruned = referenceNode.Stat().ClustersPruned();
if (referenceNode.Stat().ClustersPruned() == size_t(-1))
referenceNode.Stat().ClustersPruned() =
referenceNode.Parent()->Stat().ClustersPruned();
- traversalInfo.LastReferenceNode() = &referenceNode;
-
- if (referenceNode.Begin() == 23058)
- Log::Warn << "Visit r23058c" << referenceNode.Count() << ", q" <<
-queryNode.Begin() << "c" << queryNode.Count() << ".\n";
-
- // If there's no closest query node assigned, but the parent has one, take
- // that one.
- if (referenceNode.Stat().ClosestQueryNode() == NULL &&
- referenceNode.Parent() != NULL &&
- referenceNode.Parent()->Stat().ClosestQueryNode() != NULL)
- {
- if (referenceNode.Begin() == 23058)
- Log::Warn << "Update closest query node for r23058c" <<
-referenceNode.Count() << " to parent's, which is "
- << ((TreeType*)
-referenceNode.Parent()->Stat().ClosestQueryNode())->Begin() << "c" <<
-((TreeType*) referenceNode.Parent()->Stat().ClosestQueryNode())->Count() <<
-".\n";
-
- referenceNode.Stat().ClosestQueryNode() =
- referenceNode.Parent()->Stat().ClosestQueryNode();
- referenceNode.Stat().MaxQueryNodeDistance() = std::min(
- referenceNode.Parent()->Stat().MaxQueryNodeDistance(),
- referenceNode.Stat().MaxQueryNodeDistance());
-// referenceNode.Stat().SecondClosestBound() = std::min(
-// referenceNode.Parent()->Stat().SecondClosestBound(),
-// referenceNode.Stat().SecondClosestBound());
-// if (referenceNode.Begin() == 23058)
-// Log::Warn << "Update second closest bound for r23058c" <<
-//referenceNode.Count() << " to parent's, which "
-// << "is " << referenceNode.Stat().SecondClosestBound() << ".\n";
- }
-
- double score = HamerlyTypeScore(referenceNode);
- if (score == DBL_MAX)
+ if (referenceNode.Stat().HamerlyPruned())
{
- if (referenceNode.Begin() == 23058)
- Log::Warn << "Hamerly prune for r23058c" << referenceNode.Count() << ", q" << queryNode.Begin() << "c" <<
-queryNode.Count() << ".\n";
- if (origPruned == size_t(-1))
+ // Add to centroids if necessary.
+ if (referenceNode.Stat().MinQueryNodeDistance() == DBL_MAX /* hack */)
{
- const size_t cluster = referenceNode.Stat().Owner();
- newCentroids.col(cluster) += referenceNode.Stat().Centroid() *
- referenceNode.NumDescendants();
-// Log::Warn << "Hamerly prune: r" << referenceNode.Begin() << "c" <<
-// referenceNode.Count() << ".\n";
- counts(cluster) += referenceNode.NumDescendants();
- referenceNode.Stat().ClustersPruned() += queryNode.NumDescendants();
+ if (referenceNode.Begin() == 26038)
+ Log::Warn << "Add centroid mass for r26038c" << referenceNode.Count() <<
+".\n";
+ newCentroids.col(referenceNode.Stat().Owner()) +=
+ referenceNode.NumDescendants() * referenceNode.Stat().Centroid();
+ counts(referenceNode.Stat().Owner()) += referenceNode.NumDescendants();
+ referenceNode.Stat().MinQueryNodeDistance() = 0.0;
}
- return DBL_MAX; // No other bookkeeping to do.
+ return DBL_MAX; // No need to go further.
}
- if (score != DBL_MAX)
- {
- score = ElkanTypeScore(queryNode, referenceNode);
+ traversalInfo.LastReferenceNode() = &referenceNode;
- if (score != DBL_MAX)
- {
- // We also have to update things if the closest query node is null. This
- // can probably be improved.
- const double minDistance = referenceNode.MinDistance(&queryNode);
- ++distanceCalculations;
- score = PellegMooreScore(queryNode, referenceNode, minDistance);
- if (referenceNode.Begin() == 23058)
- Log::Warn << "mQND for r23058c" << referenceNode.Count() << " is "
- << referenceNode.Stat().MinQueryNodeDistance() << "; minDistance "
- << minDistance << ", scb " <<
-referenceNode.Stat().SecondClosestBound() << ".\n";
-
- if (minDistance < referenceNode.Stat().MinQueryNodeDistance())
- {
- const double maxDistance = referenceNode.MaxDistance(&queryNode);
- if (!IsDescendantOf(*((TreeType*)
- referenceNode.Stat().ClosestQueryNode()), queryNode) &&
- referenceNode.Stat().MinQueryNodeDistance() != DBL_MAX &&
- referenceNode.Stat().MinQueryNodeDistance() <
- referenceNode.Stat().SecondClosestBound() &&
- &queryNode != referenceNode.Stat().ClosestQueryNode())
- {
- referenceNode.Stat().SecondClosestBound() =
- referenceNode.Stat().MinQueryNodeDistance();
- referenceNode.Stat().SecondClosestQueryNode() =
- referenceNode.Stat().ClosestQueryNode();
- if (referenceNode.Begin() == 23058)
- Log::Warn << "scb for r23058c" << referenceNode.Count() << " taken "
- << "from minDistance, which is " <<
-referenceNode.Stat().MinQueryNodeDistance() << ".\n";
- }
-
- if (referenceNode.Stat().MinQueryNodeDistance() == DBL_MAX &&
- score == DBL_MAX &&
- minDistance < referenceNode.Stat().SecondClosestBound())
- {
- referenceNode.Stat().SecondClosestBound() = minDistance;
- referenceNode.Stat().SecondClosestQueryNode() = &queryNode;
- if (referenceNode.Begin() == 23058)
- Log::Warn << "scb for r23058c" << referenceNode.Count() << " taken "
- << "from minDistance for pruned query node, which is " <<
-minDistance << ".\n";
- }
-
- if (score != DBL_MAX)
- {
- ++distanceCalculations;
- referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
- referenceNode.Stat().MinQueryNodeDistance() = minDistance;
- referenceNode.Stat().MaxQueryNodeDistance() = maxDistance;
-
- if (referenceNode.Begin() == 23058)
- Log::Warn << "mQND for r23058c" << referenceNode.Count() << " updated to " << minDistance << " and "
- << "MQND to " << maxDistance << " with furthest query node " <<
- queryNode.Begin() << "c" << queryNode.Count() << ".\n";
- }
- }
- else if (IsDescendantOf(*((TreeType*)
- referenceNode.Stat().ClosestQueryNode()), queryNode))
- {
- if (referenceNode.Begin() == 23058)
- Log::Warn << "Old closest for r23058c" << referenceNode.Count() <<
- " is q" << ((TreeType*)
-referenceNode.Stat().ClosestQueryNode())->Begin() << "c" << ((TreeType*)
-referenceNode.Stat().ClosestQueryNode())->Count() << " with mQND " <<
-referenceNode.Stat().MinQueryNodeDistance() << " and MQND " <<
-referenceNode.Stat().MaxQueryNodeDistance() << ".\n";
- const double maxDistance = referenceNode.MaxDistance(&queryNode);
- ++distanceCalculations;
- referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
- referenceNode.Stat().MinQueryNodeDistance() = minDistance;
- referenceNode.Stat().MaxQueryNodeDistance() = maxDistance;
-
- if (referenceNode.Begin() == 23058)
- Log::Warn << "mQND for r23058c" << referenceNode.Count() << " updated to " << minDistance << " and "
- << "MQND to " << maxDistance << " via descendant with fqn " <<
- queryNode.Begin() << "c" << queryNode.Count() << ".\n";
- }
- else if (minDistance < referenceNode.Stat().SecondClosestBound())
- {
- referenceNode.Stat().SecondClosestBound() = minDistance;
- referenceNode.Stat().SecondClosestQueryNode() = &queryNode;
- if (referenceNode.Begin() == 23058)
- Log::Warn << "scb for r23058c" << referenceNode.Count() << " updated to " << minDistance << " via "
- << queryNode.Begin() << "c" << queryNode.Count() << ".\n";
- }
- }
- else
- {
- // There was an Elkan prune, but we still need to check the second closest
- // bound.
- const double minDistance = referenceNode.MinDistance(&queryNode);
- ++distanceCalculations;
- if (minDistance < referenceNode.Stat().SecondClosestBound())
- {
- if (referenceNode.Begin() == 23058)
- Log::Warn << "After Elkan prune, update scb to " << minDistance <<
-".\n";
+ // Calculate distance to node.
+ // This costs about the same (in terms of runtime) as a single MinDistance()
+ // call, so we only need to add one distance computation.
+ math::Range distances = referenceNode.RangeDistance(&queryNode);
+ ++distanceCalculations;
- referenceNode.Stat().SecondClosestBound() = minDistance;
- referenceNode.Stat().SecondClosestQueryNode() = (void*) &queryNode;
- }
- }
+ // Is this closer than the current best query node?
+ if (distances.Lo() < referenceNode.Stat().MinQueryNodeDistance())
+ {
+ if (referenceNode.Begin() == 26038)
+ Log::Warn << "r26038c" << referenceNode.Count() << ": new CQN " <<
+queryNode.Begin() << "c" << queryNode.Count() << ".\n";
+ // This is the new closest node.
+ referenceNode.Stat().SecondClosestQueryNode() =
+ referenceNode.Stat().ClosestQueryNode();
+ referenceNode.Stat().SecondMinQueryNodeDistance() =
+ referenceNode.Stat().MinQueryNodeDistance();
+ referenceNode.Stat().SecondMaxQueryNodeDistance() =
+ referenceNode.Stat().MaxQueryNodeDistance();
+ referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
+ referenceNode.Stat().MinQueryNodeDistance() = distances.Lo();
+ referenceNode.Stat().MaxQueryNodeDistance() = distances.Hi();
}
-
-// if (((TreeType*) referenceNode.Stat().ClosestQueryNode())->NumDescendants() > 1)
-// {
-// referenceNode.Stat().SecondClosestBound() =
-// referenceNode.Stat().MinQueryNodeDistance();
-// referenceNode.Stat().SecondClosestQueryNode() =
-// referenceNode.Stat().ClosestQueryNode();
-// }
-
- if (score == DBL_MAX)
+ else if (distances.Lo() < referenceNode.Stat().SecondMinQueryNodeDistance())
{
- referenceNode.Stat().ClustersPruned() += queryNode.NumDescendants();
- if (referenceNode.Begin() == 23058)
- Log::Warn << "For r23058c" << referenceNode.Count() << ", q" <<
-queryNode.Begin() << "c" << queryNode.Count() << " is pruned. Min distance is"
- << " " << queryNode.MinDistance(&referenceNode) << " and scb is " <<
-referenceNode.Stat().SecondClosestBound() << ".\n";
-
- // Have we pruned everything?
- if (referenceNode.Stat().ClustersPruned() +
- visited[referenceNode.Descendant(0)] == centroids.n_cols)
- {
- for (size_t i = 0; i < referenceNode.NumDescendants(); ++i)
- {
- const size_t cluster = assignments[referenceNode.Descendant(i)];
- newCentroids.col(cluster) += dataset.col(referenceNode.Descendant(i));
- counts(cluster)++;
- }
- }
+ if (referenceNode.Begin() == 26038)
+ Log::Warn << "r26038c" << referenceNode.Count() << ": new SCQN " <<
+queryNode.Begin() << "c" << queryNode.Count() << ".\n";
+ // This is the new second closest node.
+ referenceNode.Stat().SecondClosestQueryNode() = (void*) &queryNode;
+ referenceNode.Stat().SecondMinQueryNodeDistance() = distances.Lo();
+ referenceNode.Stat().SecondMaxQueryNodeDistance() = distances.Hi();
}
- return score;
+ return 0.0; // No pruning allowed at this time.
}
template<typename MetricType, typename TreeType>
@@ -344,7 +219,7 @@ double DualTreeKMeansRules<MetricType, TreeType>::HamerlyTypeScore(
{
if (referenceNode.Stat().HamerlyPruned())
{
-// if (referenceNode.Begin() == 23058)
+// if (referenceNode.Begin() == 26038)
// Log::Warn << "Hamerly prune! r" << referenceNode.Begin() << "c" <<
//referenceNode.Count() << ".\n";
return DBL_MAX;
@@ -392,8 +267,8 @@ double DualTreeKMeansRules<MetricType, TreeType>::ElkanTypeScore(
queryNode)) &&
(&queryNode != (TreeType*) referenceNode.Stat().ClosestQueryNode()))
{
- if (referenceNode.Begin() == 23058)
- Log::Warn << "Elkan prune r23058c" << referenceNode.Count() << ", q" <<
+ if (referenceNode.Begin() == 26038)
+ Log::Warn << "Elkan prune r26038c" << referenceNode.Count() << ", q" <<
queryNode.Begin() << "c" << queryNode.Count() << "!\n";
// Then we can conclude d_max(best(N_r), N_r) <= d_min(N_q, N_r) which
// means that N_q cannot possibly hold any clusters that own any points in
@@ -413,14 +288,14 @@ double DualTreeKMeansRules<MetricType, TreeType>::PellegMooreScore(
// If the minimum distance to the node is greater than the bound, then every
// cluster in the query node cannot possibly be the nearest neighbor of any of
// the points in the reference node.
-// if (referenceNode.Begin() == 23058)
-// Log::Warn << "Pelleg-Moore prune attempt r23058c" << referenceNode.Count() << ", "
+// if (referenceNode.Begin() == 26038)
+// Log::Warn << "Pelleg-Moore prune attempt r26038c" << referenceNode.Count() << ", "
// << "q" << queryNode.Begin() << "c" << queryNode.Count() << "; "
// << "minDistance " << minDistance << ", MQND " <<
//referenceNode.Stat().MaxQueryNodeDistance() << ".\n";
if (minDistance > referenceNode.Stat().MaxQueryNodeDistance())
{
-// if (referenceNode.Begin() == 23058)
+// if (referenceNode.Begin() == 26038)
// Log::Warn << "Attempt successful!\n";
return DBL_MAX;
}
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
index 0b01fa6..c42eabe 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
@@ -18,10 +18,11 @@ class DualTreeKMeansStatistic
template<typename TreeType>
DualTreeKMeansStatistic(TreeType& node) :
closestQueryNode(NULL),
+ secondClosestQueryNode(NULL),
minQueryNodeDistance(DBL_MAX),
maxQueryNodeDistance(DBL_MAX),
- secondClosestBound(DBL_MAX),
- secondClosestQueryNode(NULL),
+ secondMinQueryNodeDistance(DBL_MAX),
+ secondMaxQueryNodeDistance(DBL_MAX),
lastSecondClosestBound(DBL_MAX),
hamerlyPruned(false),
clustersPruned(size_t(-1)),
@@ -55,6 +56,11 @@ class DualTreeKMeansStatistic
//! Modify the current closest query node.
void*& ClosestQueryNode() { return closestQueryNode; }
+ //! Get the second closest query node.
+ void* SecondClosestQueryNode() const { return secondClosestQueryNode; }
+ //! Modify the second closest query node.
+ void*& SecondClosestQueryNode() { return secondClosestQueryNode; }
+
//! Get the minimum distance to the closest query node.
double MinQueryNodeDistance() const { return minQueryNodeDistance; }
//! Modify the minimum distance to the closest query node.
@@ -65,15 +71,17 @@ class DualTreeKMeansStatistic
//! Modify the maximum distance to the closest query node.
double& MaxQueryNodeDistance() { return maxQueryNodeDistance; }
- //! Get a lower bound on the second closest cluster distance.
- double SecondClosestBound() const { return secondClosestBound; }
- //! Modify the lower bound on the second closest cluster distance.
- double& SecondClosestBound() { return secondClosestBound; }
+ //! Get the minimum distance to the second closest query node.
+ double SecondMinQueryNodeDistance() const
+ { return secondMinQueryNodeDistance; }
+ //! Modify the minimum distance to the second closest query node.
+ double& SecondMinQueryNodeDistance() { return secondMinQueryNodeDistance; }
- //! Get the second closest query node.
- void* SecondClosestQueryNode() const { return secondClosestQueryNode; }
- //! Modify the second closest query node.
- void*& SecondClosestQueryNode() { return secondClosestQueryNode; }
+ //! Get the maximum distance to the second closest query node.
+ double SecondMaxQueryNodeDistance() const
+ { return secondMaxQueryNodeDistance; }
+ //! Modify the maximum distance to the second closest query node.
+ double& SecondMaxQueryNodeDistance() { return secondMaxQueryNodeDistance; }
//! Get last iteration's second closest bound.
double LastSecondClosestBound() const { return lastSecondClosestBound; }
@@ -129,14 +137,17 @@ class DualTreeKMeansStatistic
//! The current closest query node to this reference node.
void* closestQueryNode;
+ //! The second closest query node.
+ void* secondClosestQueryNode;
//! The minimum distance to the closest query node.
double minQueryNodeDistance;
//! The maximum distance to the closest query node.
double maxQueryNodeDistance;
- //! A lower bound on the distance to the second closest cluster.
- double secondClosestBound;
- //! The second closest query node.
- void* secondClosestQueryNode;
+ //! The minimum distance to the second closest query node.
+ double secondMinQueryNodeDistance;
+ //! The maximum distance to the second closest query node.
+ double secondMaxQueryNodeDistance;
+
//! The second closest lower bound, on the previous iteration.
double lastSecondClosestBound;
//! Whether or not this node is pruned for the next iteration.
diff --git a/src/mlpack/methods/kmeans/hamerly_kmeans_impl.hpp b/src/mlpack/methods/kmeans/hamerly_kmeans_impl.hpp
index 06edfb0..b33a0eb 100644
--- a/src/mlpack/methods/kmeans/hamerly_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/hamerly_kmeans_impl.hpp
@@ -28,6 +28,8 @@ double HamerlyKMeans<MetricType, MatType>::Iterate(const arma::mat& centroids,
arma::mat& newCentroids,
arma::Col<size_t>& counts)
{
+ size_t hamerlyPruned = 0;
+
// If this is the first iteration, we need to set all the bounds.
if (minClusterDistances.n_elem != centroids.n_cols)
{
@@ -68,6 +70,7 @@ double HamerlyKMeans<MetricType, MatType>::Iterate(const arma::mat& centroids,
// First bound test.
if (upperBounds(i) <= m)
{
+ ++hamerlyPruned;
newCentroids.col(assignments[i]) += dataset.col(i);
++counts(assignments[i]);
continue;
@@ -161,6 +164,8 @@ double HamerlyKMeans<MetricType, MatType>::Iterate(const arma::mat& centroids,
lowerBounds(i) -= furthestMovement;
}
+ Log::Info << "Hamerly prunes: " << hamerlyPruned << ".\n";
+
return std::sqrt(centroidMovement);
}
More information about the mlpack-git mailing list