[mlpack-git] master: Coalesce the tree before kNN. Speedup! This is a fairly significant speedup, actually. (e36e25f)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:03:42 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44
>---------------------------------------------------------------
commit e36e25f2d9aa1dd4b113a2d036563eeee15069f7
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Feb 4 14:15:57 2015 -0500
Coalesce the tree before kNN. Speedup! This is a fairly significant speedup, actually.
>---------------------------------------------------------------
e36e25f2d9aa1dd4b113a2d036563eeee15069f7
.../tree/binary_space_tree/binary_space_tree.hpp | 3 +
src/mlpack/core/tree/cover_tree/cover_tree.hpp | 2 +
src/mlpack/methods/kmeans/dtnn_kmeans.hpp | 3 +
src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 72 ++++++++++++++++++++++
src/mlpack/methods/kmeans/dtnn_rules_impl.hpp | 10 +++
src/mlpack/methods/kmeans/dtnn_statistic.hpp | 22 ++++++-
6 files changed, 110 insertions(+), 2 deletions(-)
diff --git a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
index db1ece7..6f98c6c 100644
--- a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
@@ -332,6 +332,9 @@ class BinarySpaceTree
*/
BinarySpaceTree& Child(const size_t child) const;
+ BinarySpaceTree*& ChildPtr(const size_t child) // Returns a modifiable reference to a child pointer, so a caller can relink it.
+ { return (child == 0) ? left : right; } // Index 0 selects left; any other index selects right.
+
//! Return the number of points in this node (0 if not a leaf).
size_t NumPoints() const;
diff --git a/src/mlpack/core/tree/cover_tree/cover_tree.hpp b/src/mlpack/core/tree/cover_tree/cover_tree.hpp
index bad4b8b..41db6fb 100644
--- a/src/mlpack/core/tree/cover_tree/cover_tree.hpp
+++ b/src/mlpack/core/tree/cover_tree/cover_tree.hpp
@@ -224,6 +224,8 @@ class CoverTree
//! Modify a particular child node.
CoverTree& Child(const size_t index) { return *children[index]; }
+ CoverTree*& ChildPtr(const size_t index) { return children[index]; } // Modifiable reference to the stored child pointer; index is not validated in this accessor.
+
//! Get the number of children.
size_t NumChildren() const { return children.size(); }
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
index 547842b..e6ccd58 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
@@ -108,6 +108,9 @@ class DTNNKMeans
const std::vector<size_t>& newFromOldCentroids);
void PrecalculateCentroids(TreeType& node);
+
+ void CoalesceTree(TreeType& node, const size_t child = 0);
+ void DecoalesceTree(TreeType& node);
};
//! A template typedef for the DTNNKMeans algorithm with the default tree type
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 40dec27..f30f85d 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -118,6 +118,7 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
newFromOldCentroids[oldFromNewCentroids[i]] = i;
}
+ Timer::Start("knn");
// Find the nearest neighbors of each of the clusters.
neighbor::NeighborSearch<neighbor::NearestNeighborSort, MetricType, TreeType>
nns(centroidTree, centroids);
@@ -125,21 +126,29 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
arma::Mat<size_t> closestClusters; // We don't actually care about these.
nns.Search(1, closestClusters, interclusterDistances);
distanceCalculations += nns.BaseCases() + nns.Scores();
+ Timer::Stop("knn");
if (iteration != 0)
{
// Do the tree update for the previous iteration.
// Reset centroids and counts for things we will collect during pruning.
+ Timer::Start("it_update");
prunedCentroids.zeros(centroids.n_rows, centroids.n_cols);
prunedCounts.zeros(centroids.n_cols);
UpdateTree(*tree, oldCentroids, interclusterDistances, newFromOldCentroids);
PrecalculateCentroids(*tree);
+ Timer::Stop("it_update");
}
+ Timer::Start("tree_mod");
+ CoalesceTree(*tree);
+ Timer::Stop("tree_mod");
+
// We won't use the AllkNN class here because we have our own set of rules.
// This is a lot of overhead. We don't need the distances.
+ Timer::Start("knn");
typedef DTNNKMeansRules<MetricType, TreeType> RuleType;
RuleType rules(centroids, dataset, assignments, distances, metric,
prunedPoints, oldFromNewCentroids, visited);
@@ -148,6 +157,11 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
traverser.Traverse(*tree, *centroidTree);
+ Timer::Stop("knn");
+
+ Timer::Start("tree_mod");
+ DecoalesceTree(*tree);
+ Timer::Stop("tree_mod");
Log::Info << "This iteration: " << rules.BaseCases() << " base cases, " <<
rules.Scores() << " scores.\n";
@@ -482,6 +496,64 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
}
template<typename MetricType, typename MatType, typename TreeType>
+void DTNNKMeans<MetricType, MatType, TreeType>::CoalesceTree(
+ TreeType& node,
+ const size_t child /* Index of node within its parent; passed to the parent's ChildPtr(). */)
+{
+ // If exactly one of the two children is pruned, hide this node by linking the
+ // unpruned child straight to this node's parent. Assumes a binary tree (BinarySpaceTree).
+ if (node.NumChildren() == 0)
+ return; // No children, so there is nothing to coalesce.
+
+ // The root node (no parent) is never hidden; only non-root nodes can be spliced out.
+ if (node.Parent() != NULL)
+ {
+ if (node.Child(0).Stat().Pruned() && !node.Child(1).Stat().Pruned()) // Only the left child is pruned.
+ {
+ CoalesceTree(node.Child(1), 1); // Coalesce the surviving subtree before relinking it.
+
+ // Splice this node out: its current right child takes this node's slot in the parent.
+ node.Child(1).Parent() = node.Parent();
+ node.Parent()->ChildPtr(child) = node.ChildPtr(1);
+ }
+ else if (!node.Child(0).Stat().Pruned() && node.Child(1).Stat().Pruned()) // Only the right child is pruned.
+ {
+ CoalesceTree(node.Child(0), 0); // Coalesce the surviving subtree before relinking it.
+
+ // Splice this node out: its current left child takes this node's slot in the parent.
+ node.Child(0).Parent() = node.Parent();
+ node.Parent()->ChildPtr(child) = node.ChildPtr(0);
+
+ }
+ else if (!node.Child(0).Stat().Pruned() && !node.Child(1).Stat().Pruned())
+ {
+ // Neither child is pruned: keep this node and recurse. If both are pruned, no branch runs.
+ CoalesceTree(node.Child(0), 0);
+ CoalesceTree(node.Child(1), 1);
+ }
+ }
+ else
+ {
+ CoalesceTree(node.Child(0), 0); // Root: recurse into both children without checking Pruned().
+ CoalesceTree(node.Child(1), 1);
+ }
+}
+
+template<typename MetricType, typename MatType, typename TreeType>
+void DTNNKMeans<MetricType, MatType, TreeType>::DecoalesceTree(TreeType& node)
+{ // Resets this node's parent/child links from the pointers held in its statistic, then recurses.
+ node.Parent() = (TreeType*) node.Stat().TrueParent(); // The statistic stores these as void*, hence the casts.
+ node.ChildPtr(0) = (TreeType*) node.Stat().TrueLeft();
+ node.ChildPtr(1) = (TreeType*) node.Stat().TrueRight();
+
+ if (node.NumChildren() > 0) // Evaluated after the child pointers were reset above.
+ {
+ DecoalesceTree(node.Child(0));
+ DecoalesceTree(node.Child(1));
+ }
+}
+
+template<typename MetricType, typename MatType, typename TreeType>
void DTNNKMeans<MetricType, MatType, TreeType>::PrecalculateCentroids(
TreeType& node)
{
diff --git a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
index 7e4e48b..a526c94 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
@@ -37,8 +37,11 @@ inline force_inline double DTNNKMeansRules<MetricType, TreeType>::BaseCase(
const size_t referenceIndex)
{
// We'll check if the query point has been pruned. If so, don't continue.
+// Log::Debug << "Base case " << queryIndex << ", " << referenceIndex <<
+//".\n";
if (prunedPoints[queryIndex])
return 0.0; // Returning 0 shouldn't be a problem.
+// Log::Debug << "(not pruned.)\n";
// Any base cases imply that we will get a result.
visited[queryIndex] = true;
@@ -107,8 +110,11 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
TreeType& referenceNode)
{
// If the query point has already been pruned, then don't recurse further.
+// Log::Debug << "Score " << queryIndex << ", r" << referenceNode.Point(0) << "c"
+// << referenceNode.NumDescendants() << ".\n";
if (prunedPoints[queryIndex])
return DBL_MAX;
+// Log::Debug << "(not pruned)\n";
return neighbor::NeighborSearchRules<neighbor::NearestNeighborSort,
MetricType, TreeType>::Score(queryIndex, referenceNode);
@@ -119,8 +125,12 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
TreeType& queryNode,
TreeType& referenceNode)
{
+// Log::Debug << "Score q" << queryNode.Point(0) << "c" <<
+//queryNode.NumDescendants() << ", r" << referenceNode.Point(0) << "c" <<
+//referenceNode.NumDescendants() << ".\n";
if (queryNode.Stat().Pruned())
return DBL_MAX;
+// Log::Debug << "(not pruned.)\n";
// Check if the query node is Hamerly pruned, and if not, then don't continue.
return neighbor::NeighborSearchRules<neighbor::NearestNeighborSort,
diff --git a/src/mlpack/methods/kmeans/dtnn_statistic.hpp b/src/mlpack/methods/kmeans/dtnn_statistic.hpp
index 5ae7f30..691a643 100644
--- a/src/mlpack/methods/kmeans/dtnn_statistic.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_statistic.hpp
@@ -23,7 +23,10 @@ class DTNNStatistic : public
maxClusterDistance(DBL_MAX),
secondClusterBound(0.0),
owner(size_t(-1)),
- centroid()
+ centroid(),
+ trueLeft(NULL),
+ trueRight(NULL),
+ trueParent(NULL)
{
// Nothing to do.
}
@@ -35,7 +38,10 @@ class DTNNStatistic : public
iteration(0),
maxClusterDistance(DBL_MAX),
secondClusterBound(0.0),
- owner(size_t(-1))
+ owner(size_t(-1)),
+ trueLeft((void*) &node.Child(0)),
+ trueRight((void*) &node.Child(1)),
+ trueParent((void*) node.Parent())
{
// Empirically calculate the centroid.
centroid.zeros(node.Dataset().n_rows);
@@ -67,6 +73,15 @@ class DTNNStatistic : public
const arma::vec& Centroid() const { return centroid; }
arma::vec& Centroid() { return centroid; }
+ const void* TrueLeft() const { return trueLeft; } // Read-only access to the saved left-child pointer.
+ void*& TrueLeft() { return trueLeft; } // Modifiable reference to the same field.
+
+ const void* TrueRight() const { return trueRight; } // Read-only access to the saved right-child pointer.
+ void*& TrueRight() { return trueRight; } // Modifiable reference to the same field.
+
+ const void* TrueParent() const { return trueParent; } // Read-only access to the saved parent pointer.
+ void*& TrueParent() { return trueParent; } // Modifiable reference to the same field.
+
std::string ToString() const
{
std::ostringstream o;
@@ -86,6 +101,9 @@ class DTNNStatistic : public
double secondClusterBound;
size_t owner;
arma::vec centroid;
+ void* trueLeft;
+ void* trueRight;
+ void* trueParent;
};
} // namespace kmeans
More information about the mlpack-git
mailing list