[mlpack-git] master: Refactor into UpdateOwner(), instead of an ugly loop at the beginning of TreeUpdate(). (2633583)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:02:12 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44
>---------------------------------------------------------------
commit 2633583a44a2ef6776721199f28291f02cae75b5
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Jan 28 16:17:01 2015 -0500
Refactor into UpdateOwner(), instead of an ugly loop at the beginning of TreeUpdate().
>---------------------------------------------------------------
2633583a44a2ef6776721199f28291f02cae75b5
src/mlpack/methods/kmeans/dual_tree_kmeans.hpp | 4 +
.../methods/kmeans/dual_tree_kmeans_impl.hpp | 106 +++++++++------------
2 files changed, 48 insertions(+), 62 deletions(-)
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
index 9e0c17a..1d26273 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
@@ -61,6 +61,10 @@ class DualTreeKMeans
void ClusterTreeUpdate(TreeType* node,
const arma::mat& distances);
+ void UpdateOwner(TreeType* node,
+ const size_t clusters,
+ const arma::Col<size_t>& assignments) const;
+
void TreeUpdate(TreeType* node,
const size_t clusters,
const arma::vec& clusterDistances,
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
index 7a65d3e..756cceb 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
@@ -127,6 +127,7 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
// Update the tree with the centroid movement information.
size_t hamerlyPruned = 0;
+ UpdateOwner(tree, centroids.n_cols, assignments);
TreeUpdate(tree, centroids.n_cols, clusterDistances, assignments,
oldCentroids, dataset, oldFromNewCentroids, hamerlyPruned);
@@ -174,6 +175,47 @@ bool IsDescendantOf(
}
template<typename MetricType, typename MatType, typename TreeType>
+void DualTreeKMeans<MetricType, MatType, TreeType>::UpdateOwner(
+ TreeType* node,
+ const size_t clusters,
+ const arma::Col<size_t>& assignments) const
+{
+ size_t owner = clusters + 1;
+ bool same = true;
+ for (size_t i = 0; i < node->NumChildren(); ++i)
+ {
+ UpdateOwner(&node->Child(i), clusters, assignments);
+ if (owner == clusters + 1)
+ owner = node->Child(i).Stat().Owner();
+ else if (owner != node->Child(i).Stat().Owner())
+ {
+ same = false;
+ owner = clusters;
+ break;
+ }
+ }
+
+ if (same)
+ {
+ for (size_t i = 0; i < node->NumPoints(); ++i)
+ {
+ if (owner == clusters + 1)
+ owner = assignments[node->Point(i)];
+ else if (owner != assignments[node->Point(i)])
+ {
+ same = false;
+ break;
+ }
+ }
+ }
+
+ if (same)
+ node->Stat().Owner() = owner;
+ else
+ node->Stat().Owner() = clusters;
+}
+
+template<typename MetricType, typename MatType, typename TreeType>
void DualTreeKMeans<MetricType, MatType, TreeType>::TreeUpdate(
TreeType* node,
const size_t clusters,
@@ -186,45 +228,14 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::TreeUpdate(
{
// This is basically IterationUpdate(), but pulled out to be separate from the
// actual dual-tree algorithm.
- if (node->Begin() == 26038)
- Log::Warn << "r26038c" << node->Count() << " has owner " <<
-node->Stat().Owner() << ".\n";
- if (node->Parent() != NULL && node->Parent()->Stat().Owner() < clusters)
- node->Stat().Owner() = node->Parent()->Stat().Owner();
- if (node->Begin() == 26038)
- Log::Warn << "r26038c" << node->Count() << " has owner " <<
-node->Stat().Owner() << " after parent check.\n";
-
- const size_t cluster = assignments[node->Descendant(0)];
- bool allSame = true;
- for (size_t i = 1; i < node->NumDescendants(); ++i)
- {
- if (assignments[node->Descendant(i)] != cluster)
- {
- allSame = false;
- break;
- }
- }
-
- if (allSame)
- node->Stat().Owner() = cluster;
- else
- node->Stat().Owner() = centroids.n_cols;
- if (node->Begin() == 26038)
- Log::Warn << "r26038c" << node->Count() << " has manually set owner " <<
-node->Stat().Owner() << ".\n";
-
const bool prunedLastIteration = node->Stat().HamerlyPruned();
node->Stat().HamerlyPruned() = false;
- if (node->Begin() == 26038)
- Log::Warn << "r26038c" << node->Count() << " has owner " <<
-node->Stat().Owner() << ".\n";
-
// The easy case: this node had an owner.
if (node->Stat().Owner() < clusters)
{
// Verify correctness...
+ /*
for (size_t i = 0; i < node->NumDescendants(); ++i)
{
size_t closest = clusters;
@@ -251,6 +262,7 @@ closest << "! It's part of node r" << node->Begin() << "c" << node->Count() <<
".\n";
}
}
+ */
// During the last iteration, this node was pruned.
const size_t owner = node->Stat().Owner();
@@ -263,11 +275,6 @@ closest << "! It's part of node r" << node->Begin() << "c" << node->Count() <<
{
// Can we continue being Hamerly pruned? If not, we'll have to update the
// bound next iteration.
- if (node->Begin() == 26038)
- Log::Warn << "r26038c" << node->Count() << ": check sustained Hamerly "
- << "prune with MQND " << node->Stat().MaxQueryNodeDistance() << ", "
- << "lscb " << node->Stat().LastSecondClosestBound() << ", cd "
- << clusterDistances[clusters] << ".\n";
if (node->Stat().MaxQueryNodeDistance() <
node->Stat().LastSecondClosestBound() - clusterDistances[clusters])
{
@@ -278,22 +285,6 @@ closest << "! It's part of node r" << node->Begin() << "c" << node->Count() <<
}
else
{
- if (node->Begin() == 26038)
- {
- if (node->Stat().ClosestQueryNode() != NULL)
- Log::Warn << "r26038c" << node->Count() << " CQN: " << ((TreeType*)
- node->Stat().ClosestQueryNode())->Begin() << "c" << ((TreeType*)
- node->Stat().ClosestQueryNode())->Count() << ".\n";
- if (node->Stat().SecondClosestQueryNode() != NULL)
- Log::Warn << "r26038c" << node->Count() << " SCQN: " << ((TreeType*)
- node->Stat().SecondClosestQueryNode())->Begin() << "c" << ((TreeType*)
- node->Stat().SecondClosestQueryNode())->Count() << ".\n";
- Log::Warn << "Attempt hamerly prune r26038c" << node->Count() << " with "
- << "MQND " << node->Stat().MaxQueryNodeDistance() << " and smqnd "
- << node->Stat().SecondMinQueryNodeDistance() << " and cluster d "
- << clusterDistances[clusters] << ".\n";
- }
-
// Now we check for a Hamerly prune. We know that we have an accurate
// second bound since nothing can be pruned.
if (node->Stat().MaxQueryNodeDistance() /* already adjusted */ <
@@ -339,16 +330,7 @@ closest << "! It's part of node r" << node->Begin() << "c" << node->Count() <<
allPruned = false;
if (allPruned && owner < clusters && !node->Stat().HamerlyPruned())
- {
- if (node->Begin() == 26038)
- Log::Warn << "Set r" << node->Begin() << "c" << node->Count() << " to be "
- << "Hamerly pruned.\n";
node->Stat().HamerlyPruned() = true;
- }
-
- if (node->Begin() == 26038 && node->Stat().HamerlyPruned())
- Log::Warn << "r" << node->Begin() << "c" << node->Count() << " is Hamerly "
- << "pruned.\n";
node->Stat().Iteration() = iteration;
node->Stat().ClustersPruned() = (node->Parent() == NULL) ? 0 : -1;
More information about the mlpack-git mailing list