[mlpack-git] master: Better speedups, provide more output on prunes. (c413ad2)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:02:18 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44
>---------------------------------------------------------------
commit c413ad23996be0c7cf504700da17b7cf41b6a3e2
Author: Ryan Curtin <ryan at ratml.org>
Date: Thu Jan 29 14:29:56 2015 -0500
Better speedups, provide more output on prunes.
>---------------------------------------------------------------
c413ad23996be0c7cf504700da17b7cf41b6a3e2
src/mlpack/methods/kmeans/dual_tree_kmeans.hpp | 4 +-
.../methods/kmeans/dual_tree_kmeans_impl.hpp | 69 +++++++++++++++++++---
.../methods/kmeans/dual_tree_kmeans_rules_impl.hpp | 17 +-----
.../methods/kmeans/dual_tree_kmeans_statistic.hpp | 16 ++---
4 files changed, 74 insertions(+), 32 deletions(-)
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
index 1d26273..6b2c81f 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
@@ -72,7 +72,9 @@ class DualTreeKMeans
const arma::mat& oldCentroids,
const arma::mat& dataset,
const std::vector<size_t>& oldFromNew,
- size_t& hamerlyPruned);
+ size_t& hamerlyPruned,
+ size_t& hamerlyPrunedNodes,
+ size_t& totalNodes);
};
template<typename MetricType, typename MatType>
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
index 756cceb..dcf0aa5 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
@@ -99,7 +99,9 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
traverser(rules);
tree->Stat().ClustersPruned() = 0; // The constructor sets this to -1.
+ Log::Info << "Traversal begins.\n";
traverser.Traverse(*centroidTree, *tree);
+ Log::Info << "Traversal done.\n";
distanceCalculations += rules.DistanceCalculations();
@@ -127,9 +129,14 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
// Update the tree with the centroid movement information.
size_t hamerlyPruned = 0;
+ size_t hamerlyPrunedNodes = 0;
+ size_t totalNodes = 0;
+ Log::Info << "Update tree.\n";
UpdateOwner(tree, centroids.n_cols, assignments);
TreeUpdate(tree, centroids.n_cols, clusterDistances, assignments,
- oldCentroids, dataset, oldFromNewCentroids, hamerlyPruned);
+ oldCentroids, dataset, oldFromNewCentroids, hamerlyPruned,
+ hamerlyPrunedNodes, totalNodes);
+ Log::Info << "Update tree done.\n";
delete centroidTree;
@@ -224,18 +231,29 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::TreeUpdate(
const arma::mat& centroids,
const arma::mat& dataset,
const std::vector<size_t>& oldFromNew,
- size_t& hamerlyPruned)
+ size_t& hamerlyPruned,
+ size_t& hamerlyPrunedNodes,
+ size_t& totalNodes)
{
// This is basically IterationUpdate(), but pulled out to be separate from the
// actual dual-tree algorithm.
const bool prunedLastIteration = node->Stat().HamerlyPruned();
node->Stat().HamerlyPruned() = false;
+ ++totalNodes;
// The easy case: this node had an owner.
if (node->Stat().Owner() < clusters)
{
+/*
// Verify correctness...
- /*
+ for (size_t i = 0; i < node->NumPoints(); ++i)
+ {
+ if (!prunedLastIteration &&
+ distanceIteration[node->Descendant(i)] < iteration)
+ Log::Fatal << "Point " << node->Descendant(i) << " was never visited!"
+<< " (" << distanceIteration[node->Descendant(i)] << ", " << prunedLastIteration
+<< ")\n";
+ }
for (size_t i = 0; i < node->NumDescendants(); ++i)
{
size_t closest = clusters;
@@ -262,7 +280,7 @@ closest << "! It's part of node r" << node->Begin() << "c" << node->Count() <<
".\n";
}
}
- */
+*/
// During the last iteration, this node was pruned.
const size_t owner = node->Stat().Owner();
@@ -293,8 +311,27 @@ closest << "! It's part of node r" << node->Begin() << "c" << node->Count() <<
node->Stat().HamerlyPruned() = true;
if (!node->Parent()->Stat().HamerlyPruned())
hamerlyPruned += node->NumDescendants();
+ ++hamerlyPrunedNodes;
+ }
+ }
+
+ if (!node->Stat().HamerlyPruned())
+ {
+ if (node->Parent() != NULL && node->Parent()->Stat().HamerlyPruned())
+ {
+ node->Stat().HamerlyPruned() = true;
+ node->Stat().MinQueryNodeDistance() = DBL_MAX;
+ }
+ else
+ {
+ if (node->Stat().SecondMaxQueryNodeDistance() != DBL_MAX)
+ node->Stat().SecondMaxQueryNodeDistance() += clusterDistances[clusters];
+ if (node->Stat().SecondMinQueryNodeDistance() != DBL_MAX)
+ node->Stat().SecondMinQueryNodeDistance() += clusterDistances[clusters];
}
}
+ else
+ node->Stat().MinQueryNodeDistance() = DBL_MAX;
}
else
{
@@ -306,6 +343,10 @@ closest << "! It's part of node r" << node->Begin() << "c" << node->Count() <<
node->Stat().MaxQueryNodeDistance() += clusterDistances[clusters];
if (node->Stat().MinQueryNodeDistance() != DBL_MAX)
node->Stat().MinQueryNodeDistance() += clusterDistances[clusters];
+ if (node->Stat().SecondMaxQueryNodeDistance() != DBL_MAX)
+ node->Stat().SecondMaxQueryNodeDistance() += clusterDistances[clusters];
+ if (node->Stat().SecondMinQueryNodeDistance() != DBL_MAX)
+ node->Stat().SecondMinQueryNodeDistance() += clusterDistances[clusters];
// Since the node didn't have an owner, it can't be Hamerly pruned.
node->Stat().HamerlyPruned() = false;
@@ -317,7 +358,8 @@ closest << "! It's part of node r" << node->Begin() << "c" << node->Count() <<
for (size_t i = 0; i < node->NumChildren(); ++i)
{
TreeUpdate(&node->Child(i), clusters, clusterDistances, assignments,
- centroids, dataset, oldFromNew, hamerlyPruned);
+ centroids, dataset, oldFromNew, hamerlyPruned, hamerlyPrunedNodes,
+ totalNodes);
if (!node->Child(i).Stat().HamerlyPruned())
allPruned = false;
else if (owner == clusters)
@@ -330,30 +372,41 @@ closest << "! It's part of node r" << node->Begin() << "c" << node->Count() <<
allPruned = false;
if (allPruned && owner < clusters && !node->Stat().HamerlyPruned())
+ {
+ node->Stat().MinQueryNodeDistance() = DBL_MAX;
node->Stat().HamerlyPruned() = true;
+ hamerlyPrunedNodes++;
+ }
node->Stat().Iteration() = iteration;
node->Stat().ClustersPruned() = (node->Parent() == NULL) ? 0 : -1;
// We have to set the closest query node to NULL because the cluster tree will
// be rebuilt.
- node->Stat().ClosestQueryNode() = NULL;
+// node->Stat().ClosestQueryNode() = NULL;
if (prunedLastIteration)
node->Stat().LastSecondClosestBound() -= clusterDistances[clusters];
else
node->Stat().LastSecondClosestBound() =
node->Stat().SecondMinQueryNodeDistance() - clusterDistances[clusters];
+// node->Stat().MinQueryNodeDistance() = DBL_MAX;
node->Stat().MinQueryNodeDistance() = DBL_MAX;
+ node->Stat().SecondMinQueryNodeDistance() = DBL_MAX;
if (prunedLastIteration && !node->Stat().HamerlyPruned())
+ {
node->Stat().MaxQueryNodeDistance() = DBL_MAX;
- node->Stat().SecondMinQueryNodeDistance() = DBL_MAX;
- node->Stat().SecondMaxQueryNodeDistance() = DBL_MAX;
+ node->Stat().SecondMaxQueryNodeDistance() = DBL_MAX;
+ }
// This should change later, but I'm not yet sure how to do it.
// node->Stat().SecondClosestBound() = DBL_MAX;
// node->Stat().SecondClosestQueryNode() = NULL;
if (node->Parent() == NULL)
+ {
Log::Info << "Total Hamerly pruned points: " << hamerlyPruned << ".\n";
+ Log::Info << "Total pruned Hamerly nodes: " << hamerlyPrunedNodes << ".\n";
+ Log::Info << "Total nodes in tree: " << totalNodes << ".\n";
+ }
}
} // namespace kmeans
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
index ced3ffa..3064d88 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
@@ -128,46 +128,33 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
// Calculate distance to node.
// This costs about the same (in terms of runtime) as a single MinDistance()
// call, so there only need to add one distance computation.
- math::Range distances = referenceNode.RangeDistance(&queryNode);
+ const math::Range distances = referenceNode.RangeDistance(&queryNode);
++distanceCalculations;
// Is this closer than the current best query node?
if (distances.Lo() < referenceNode.Stat().MinQueryNodeDistance())
{
// This is the new closest node.
- referenceNode.Stat().SecondClosestQueryNode() =
- referenceNode.Stat().ClosestQueryNode();
referenceNode.Stat().SecondMinQueryNodeDistance() =
referenceNode.Stat().MinQueryNodeDistance();
referenceNode.Stat().SecondMaxQueryNodeDistance() =
referenceNode.Stat().MaxQueryNodeDistance();
- referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
referenceNode.Stat().MinQueryNodeDistance() = distances.Lo();
referenceNode.Stat().MaxQueryNodeDistance() = distances.Hi();
}
else if (distances.Lo() < referenceNode.Stat().SecondMinQueryNodeDistance())
{
// This is the new second closest node.
- referenceNode.Stat().SecondClosestQueryNode() = (void*) &queryNode;
referenceNode.Stat().SecondMinQueryNodeDistance() = distances.Lo();
referenceNode.Stat().SecondMaxQueryNodeDistance() = distances.Hi();
}
else if (distances.Lo() > referenceNode.Stat().SecondMaxQueryNodeDistance())
{
- // This is a Pelleg-Moore type prune.
-// Log::Warn << "Pelleg-Moore prune: " << distances.Lo() << "/" <<
-//distances.Hi() << ", r" << referenceNode.Begin() << "c" << referenceNode.Count()
-//<< ", q" << queryNode.Begin() << "c" << queryNode.Count() << "; mQND " <<
-//referenceNode.Stat().MinQueryNodeDistance() << ", MQND " <<
-//referenceNode.Stat().MaxQueryNodeDistance() << ", smQND " <<
-//referenceNode.Stat().SecondMinQueryNodeDistance() << ", sMQND " <<
-//referenceNode.Stat().SecondMaxQueryNodeDistance() << ".\n";
-
referenceNode.Stat().ClustersPruned() += queryNode.NumDescendants();
return DBL_MAX;
}
- return 0.0; // No pruning allowed at this time.
+ return distances.Lo(); // No pruning allowed at this time.
}
template<typename MetricType, typename TreeType>
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
index c42eabe..8d762fd 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
@@ -17,8 +17,8 @@ class DualTreeKMeansStatistic
template<typename TreeType>
DualTreeKMeansStatistic(TreeType& node) :
- closestQueryNode(NULL),
- secondClosestQueryNode(NULL),
+// closestQueryNode(NULL),
+// secondClosestQueryNode(NULL),
minQueryNodeDistance(DBL_MAX),
maxQueryNodeDistance(DBL_MAX),
secondMinQueryNodeDistance(DBL_MAX),
@@ -52,14 +52,14 @@ class DualTreeKMeansStatistic
arma::vec& Centroid() { return centroid; }
//! Get the current closest query node.
- void* ClosestQueryNode() const { return closestQueryNode; }
+// void* ClosestQueryNode() const { return closestQueryNode; }
//! Modify the current closest query node.
- void*& ClosestQueryNode() { return closestQueryNode; }
+// void*& ClosestQueryNode() { return closestQueryNode; }
//! Get the second closest query node.
- void* SecondClosestQueryNode() const { return secondClosestQueryNode; }
+// void* SecondClosestQueryNode() const { return secondClosestQueryNode; }
//! Modify the second closest query node.
- void*& SecondClosestQueryNode() { return secondClosestQueryNode; }
+// void*& SecondClosestQueryNode() { return secondClosestQueryNode; }
//! Get the minimum distance to the closest query node.
double MinQueryNodeDistance() const { return minQueryNodeDistance; }
@@ -136,9 +136,9 @@ class DualTreeKMeansStatistic
arma::vec centroid;
//! The current closest query node to this reference node.
- void* closestQueryNode;
+// void* closestQueryNode;
//! The second closest query node.
- void* secondClosestQueryNode;
+// void* secondClosestQueryNode;
//! The minimum distance to the closest query node.
double minQueryNodeDistance;
//! The maximum distance to the closest query node.
More information about the mlpack-git mailing list