[mlpack-git] master: Correct Pelleg-Moore prunes that finish a node. There were cases where a Pelleg-Moore prune would happen before committing the point. This is actually getting pretty fast in terms of base cases, so I am happy with that (for once). (6300189)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:02:32 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44
>---------------------------------------------------------------
commit 6300189ab44849d38795732ac29d8df20d379075
Author: Ryan Curtin <ryan at ratml.org>
Date: Thu Jan 29 16:59:56 2015 -0500
Correct Pelleg-Moore prunes that finish a node. There were cases where a Pelleg-Moore prune would finish off a node before its points had been committed, so those points were never added to the new centroids. This is now getting pretty fast in terms of the number of base cases, so I am happy with that (for once).
>---------------------------------------------------------------
6300189ab44849d38795732ac29d8df20d379075
src/mlpack/methods/kmeans/dual_tree_kmeans.hpp | 1 +
.../methods/kmeans/dual_tree_kmeans_impl.hpp | 42 ++++++++++++++--------
.../methods/kmeans/dual_tree_kmeans_rules.hpp | 3 +-
.../methods/kmeans/dual_tree_kmeans_rules_impl.hpp | 36 ++++++++++++++++---
4 files changed, 63 insertions(+), 19 deletions(-)
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
index 6b2c81f..68714cd 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
@@ -50,6 +50,7 @@ class DualTreeKMeans
arma::vec clusterDistances;
arma::Col<size_t> assignments;
arma::vec distances;
+ arma::Col<size_t> visited;
arma::Col<size_t> distanceIteration;
//! The current iteration.
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
index dcf0aa5..27d2fd0 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
@@ -30,6 +30,7 @@ DualTreeKMeans<MetricType, MatType, TreeType>::DualTreeKMeans(
distances.set_size(dataset.n_cols);
distances.fill(DBL_MAX);
assignments.zeros(dataset.n_cols);
+ visited.zeros(dataset.n_cols);
distanceIteration.zeros(dataset.n_cols);
Timer::Start("tree_building");
@@ -89,9 +90,10 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
// Now run the dual-tree algorithm.
typedef DualTreeKMeansRules<MetricType, TreeType> RulesType;
+ visited.zeros(dataset.n_cols);
RulesType rules(dataset, centroids, newCentroids, counts, oldFromNewCentroids,
- iteration, clusterDistances, distances, assignments, distanceIteration,
- interclusterDistances, metric);
+ iteration, clusterDistances, distances, assignments, visited,
+ distanceIteration, interclusterDistances, metric);
// Use the dual-tree traverser.
//typename TreeType::template DualTreeTraverser<RulesType> traverser(rules);
@@ -99,9 +101,7 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
traverser(rules);
tree->Stat().ClustersPruned() = 0; // The constructor sets this to -1.
- Log::Info << "Traversal begins.\n";
traverser.Traverse(*centroidTree, *tree);
- Log::Info << "Traversal done.\n";
distanceCalculations += rules.DistanceCalculations();
@@ -131,12 +131,10 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
size_t hamerlyPruned = 0;
size_t hamerlyPrunedNodes = 0;
size_t totalNodes = 0;
- Log::Info << "Update tree.\n";
UpdateOwner(tree, centroids.n_cols, assignments);
TreeUpdate(tree, centroids.n_cols, clusterDistances, assignments,
oldCentroids, dataset, oldFromNewCentroids, hamerlyPruned,
hamerlyPrunedNodes, totalNodes);
- Log::Info << "Update tree done.\n";
delete centroidTree;
@@ -241,19 +239,30 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::TreeUpdate(
node->Stat().HamerlyPruned() = false;
++totalNodes;
+/*
+ for (size_t i = 0; i < node->NumPoints(); ++i)
+ {
+ if (!prunedLastIteration &&
+ distanceIteration[node->Point(i)] < iteration)
+ Log::Warn << "Point " << node->Point(i) << " was never visited!"
+<< " (" << distanceIteration[node->Point(i)] << ", " << prunedLastIteration
+<< ")\n";
+ if (!prunedLastIteration &&
+ node->Stat().ClustersPruned() + visited[node->Point(i)] < clusters)
+ Log::Fatal << "Point " << node->Point(i) << " was only visited " <<
+node->Stat().ClustersPruned() << " + " << visited[node->Point(i)] <<
+" times!\n";
+ }
+*/
+
// The easy case: this node had an owner.
if (node->Stat().Owner() < clusters)
{
/*
+ if (prunedLastIteration && node->Stat().MaxQueryNodeDistance() == DBL_MAX)
+ Log::Fatal << "r" << node->Begin() << "c" << node->Count() << " was "
+ << "Hamerly pruned but was not visited!\n";
// Verify correctness...
- for (size_t i = 0; i < node->NumPoints(); ++i)
- {
- if (!prunedLastIteration &&
- distanceIteration[node->Descendant(i)] < iteration)
- Log::Fatal << "Point " << node->Descendant(i) << " was never visited!"
-<< " (" << distanceIteration[node->Descendant(i)] << ", " << prunedLastIteration
-<< ")\n";
- }
for (size_t i = 0; i < node->NumDescendants(); ++i)
{
size_t closest = clusters;
@@ -319,6 +328,11 @@ closest << "! It's part of node r" << node->Begin() << "c" << node->Count() <<
{
if (node->Parent() != NULL && node->Parent()->Stat().HamerlyPruned())
{
+// Log::Warn << "Extra prune via parent: r" << node->Begin() << "c" <<
+//node->Count() << ".\n";
+ if (node->Stat().Owner() != node->Parent()->Stat().Owner())
+ Log::Fatal << "Holy shit!\n";
+
node->Stat().HamerlyPruned() = true;
node->Stat().MinQueryNodeDistance() = DBL_MAX;
}
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp
index e47651e..8cf4ef5 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp
@@ -23,6 +23,7 @@ class DualTreeKMeansRules
const arma::vec& clusterDistances,
arma::vec& distances,
arma::Col<size_t>& assignments,
+ arma::Col<size_t>& visited,
arma::Col<size_t>& distanceIteration,
const arma::mat& interclusterDistances,
MetricType& metric);
@@ -59,7 +60,7 @@ class DualTreeKMeansRules
const arma::vec& clusterDistances;
arma::vec& distances;
arma::Col<size_t>& assignments;
- arma::Col<size_t> visited;
+ arma::Col<size_t>& visited;
arma::Col<size_t>& distanceIteration;
const arma::mat& interclusterDistances;
MetricType& metric;
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
index 3064d88..7857d88 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
@@ -24,6 +24,7 @@ DualTreeKMeansRules<MetricType, TreeType>::DualTreeKMeansRules(
const arma::vec& clusterDistances,
arma::vec& distances,
arma::Col<size_t>& assignments,
+ arma::Col<size_t>& visited,
arma::Col<size_t>& distanceIteration,
const arma::mat& interclusterDistances,
MetricType& metric) :
@@ -36,14 +37,12 @@ DualTreeKMeansRules<MetricType, TreeType>::DualTreeKMeansRules(
clusterDistances(clusterDistances),
distances(distances),
assignments(assignments),
+ visited(visited),
distanceIteration(distanceIteration),
interclusterDistances(interclusterDistances),
metric(metric),
distanceCalculations(0)
-{
- // Nothing has been visited yet.
- visited.zeros(dataset.n_cols);
-}
+{ }
template<typename MetricType, typename TreeType>
inline force_inline double DualTreeKMeansRules<MetricType, TreeType>::BaseCase(
@@ -57,6 +56,10 @@ inline force_inline double DualTreeKMeansRules<MetricType, TreeType>::BaseCase(
// It's possible that the reference node has been pruned before we got to the
// base case. In that case, don't do the base case, and just return.
+// if (referenceIndex == 37447)
+// Log::Warn << "Visit " << referenceIndex << ", q" << queryIndex << ". " <<
+//traversalInfo.LastReferenceNode()->Stat().ClustersPruned() +
+//visited[referenceIndex] << ".\n";
if (traversalInfo.LastReferenceNode()->Stat().ClustersPruned() +
visited[referenceIndex] == centroids.n_cols)
return 0.0;
@@ -86,6 +89,7 @@ inline force_inline double DualTreeKMeansRules<MetricType, TreeType>::BaseCase(
newCentroids.col(assignments[referenceIndex]) +=
dataset.col(referenceIndex);
++counts(assignments[referenceIndex]);
+// Log::Warn << "Commit base case " << referenceIndex << ".\n";
}
return distance;
@@ -105,6 +109,12 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
TreeType& queryNode,
TreeType& referenceNode)
{
+// if (referenceNode.Begin() == 33313 || referenceNode.Begin() == 37121 ||
+// if (referenceNode.Begin() == 37447)
+// Log::Warn << "Visit r" << referenceNode.Begin() << "c" <<
+//referenceNode.Count() << ", q" << queryNode.Begin() << "c" << queryNode.Count()
+//<< ":\n" << referenceNode.Stat();
+
// This won't happen with the root since it is explicitly set to 0.
if (referenceNode.Stat().ClustersPruned() == size_t(-1))
referenceNode.Stat().ClustersPruned() =
@@ -150,7 +160,25 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
}
else if (distances.Lo() > referenceNode.Stat().SecondMaxQueryNodeDistance())
{
+// if (referenceNode.Begin() == 37447)
+// Log::Warn << "Pelleg-Moore pruned.\n";
referenceNode.Stat().ClustersPruned() += queryNode.NumDescendants();
+
+ // Is everything pruned? Then commit the points.
+ if (referenceNode.Stat().ClustersPruned() +
+ visited[referenceNode.Descendant(0)] == centroids.n_cols)
+ {
+// Log::Warn << "Commit points in r" << referenceNode.Begin() << "c" <<
+//referenceNode.Count() << ".\n";
+ for (size_t i = 0; i < referenceNode.NumDescendants(); ++i)
+ {
+ const size_t index = referenceNode.Descendant(i);
+ const size_t cluster = assignments[index];
+ referenceNode.Stat().Owner() = cluster;
+ newCentroids.col(cluster) += dataset.col(index);
+ ++counts(cluster);
+ }
+ }
return DBL_MAX;
}
More information about the mlpack-git
mailing list