[mlpack-git] master: Some fairly serious refactoring, but it works. (56cbc1c)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:01:39 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44
>---------------------------------------------------------------
commit 56cbc1c3bbf7fc08fb7b4b1b8bb607902b57a7a6
Author: Ryan Curtin <ryan at ratml.org>
Date: Mon Jan 12 17:07:50 2015 -0500
Some fairly serious refactoring, but it works.
There are significant optimizations still to be made, but at this point we get
some acceleration over the naive algorithm.
>---------------------------------------------------------------
56cbc1c3bbf7fc08fb7b4b1b8bb607902b57a7a6
.../methods/kmeans/dual_tree_kmeans_impl.hpp | 5 +-
.../methods/kmeans/dual_tree_kmeans_rules.hpp | 2 -
.../methods/kmeans/dual_tree_kmeans_rules_impl.hpp | 158 +++++----------------
.../methods/kmeans/dual_tree_kmeans_statistic.hpp | 2 +-
4 files changed, 37 insertions(+), 130 deletions(-)
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
index ff6ddf8..c59b496 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
@@ -88,6 +88,7 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
typename TreeType::template BreadthFirstDualTreeTraverser<RulesType>
traverser(rules);
+ tree->Stat().ClustersPruned() = 0; // The constructor sets this to -1.
traverser.Traverse(*centroidTree, *tree);
distanceCalculations += rules.DistanceCalculations();
@@ -220,8 +221,8 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::TreeUpdate(
// We have to set the closest query node to NULL because the cluster tree will
// be rebuilt.
node->Stat().ClosestQueryNode() = NULL;
-// node->Stat().MaxQueryNodeDistance() = DBL_MAX;
-// node->Stat().MinQueryNodeDistance() = DBL_MAX;
+ node->Stat().MaxQueryNodeDistance() = DBL_MAX;
+ node->Stat().MinQueryNodeDistance() = DBL_MAX;
for (size_t i = 0; i < node->NumChildren(); ++i)
TreeUpdate(&node->Child(i), clusters, clusterDistances);
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp
index fe88edc..2fedea0 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp
@@ -68,8 +68,6 @@ class DualTreeKMeansRules
TraversalInfoType traversalInfo;
- double IterationUpdate(TreeType& referenceNode);
-
bool IsDescendantOf(const TreeType& potentialParent, const TreeType&
potentialChild) const;
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
index f0c4c00..b13e544 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
@@ -50,13 +50,9 @@ inline force_inline double DualTreeKMeansRules<MetricType, TreeType>::BaseCase(
const size_t queryIndex,
const size_t referenceIndex)
{
-// Log::Info << "Base case, query " << queryIndex << " (" << mappings[queryIndex]
-// << "), reference " << referenceIndex << ".\n";
-
// Collect the number of clusters that have been pruned during the traversal.
// The ternary operator may not be necessary.
const size_t traversalPruned = (traversalInfo.LastReferenceNode() != NULL) ?
-// traversalInfo.LastReferenceNode()->Stat().Iteration() == iteration) ?
traversalInfo.LastReferenceNode()->Stat().ClustersPruned() : 0;
// It's possible that the reference node has been pruned before we got to the
@@ -87,8 +83,6 @@ inline force_inline double DualTreeKMeansRules<MetricType, TreeType>::BaseCase(
if (visited[referenceIndex] + traversalPruned == centroids.n_cols)
{
-// Log::Warn << "Commit reference index " << referenceIndex << " to cluster "
-// << assignments[referenceIndex] << ".\n";
newCentroids.col(assignments[referenceIndex]) +=
dataset.col(referenceIndex);
++counts(assignments[referenceIndex]);
@@ -99,12 +93,9 @@ inline force_inline double DualTreeKMeansRules<MetricType, TreeType>::BaseCase(
template<typename MetricType, typename TreeType>
double DualTreeKMeansRules<MetricType, TreeType>::Score(
- const size_t /* queryIndex */,
+ const size_t queryIndex,
TreeType& referenceNode)
{
- // Update from previous iteration, if necessary.
-// IterationUpdate(referenceNode);
-
// No pruning here, for now.
return 0.0;
}
@@ -114,6 +105,7 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
TreeType& queryNode,
TreeType& referenceNode)
{
+ // This won't happen with the root since it is explicitly set to 0.
if (referenceNode.Stat().ClustersPruned() == size_t(-1))
referenceNode.Stat().ClustersPruned() =
referenceNode.Parent()->Stat().ClustersPruned();
@@ -124,35 +116,34 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
// We also have to update things if the closest query node is null. This can
// probably be improved.
- if (score != DBL_MAX || referenceNode.Stat().ClosestQueryNode() == NULL)
+ const double minDistance = referenceNode.MinDistance(&queryNode);
+ const double maxDistance = referenceNode.MaxDistance(&queryNode);
+ distanceCalculations += 2;
+ score = PellegMooreScore(queryNode, referenceNode, minDistance);
+
+ if (referenceNode.Stat().MaxQueryNodeDistance() == DBL_MAX &&
+ referenceNode.Parent() != NULL &&
+ referenceNode.Parent()->Stat().MaxQueryNodeDistance() != DBL_MAX)
{
- // Can we update the minimum query node distance for this reference node?
- const double minDistance = referenceNode.MinDistance(&queryNode);
- const double maxDistance = referenceNode.MaxDistance(&queryNode);
- distanceCalculations += 2;
- if (maxDistance < referenceNode.Stat().MaxQueryNodeDistance())
- {
- referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
- referenceNode.Stat().MinQueryNodeDistance() = minDistance;
- referenceNode.Stat().MaxQueryNodeDistance() = maxDistance;
-// referenceNode.MaxDistance(&queryNode);
-// ++distanceCalculations;
- return 0.0; // Pruning is not possible.
- }
-
- else if (IsDescendantOf(
- *((TreeType*) referenceNode.Stat().ClosestQueryNode()), queryNode))
- {
- // Just update.
- referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
- referenceNode.Stat().MinQueryNodeDistance() = minDistance;
- referenceNode.Stat().MaxQueryNodeDistance() =
- referenceNode.MaxDistance(&queryNode);
- ++distanceCalculations;
- return 0.0; // Pruning is not possible.
- }
+ referenceNode.Stat().ClosestQueryNode() =
+ referenceNode.Parent()->Stat().ClosestQueryNode();
+ referenceNode.Stat().MaxQueryNodeDistance() =
+ referenceNode.Parent()->Stat().MaxQueryNodeDistance();
+ }
- score = PellegMooreScore(queryNode, referenceNode, minDistance);
+ if (maxDistance < referenceNode.Stat().MaxQueryNodeDistance() ||
+ referenceNode.Stat().ClosestQueryNode() == NULL)
+ {
+ referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
+ referenceNode.Stat().MinQueryNodeDistance() = minDistance;
+ referenceNode.Stat().MaxQueryNodeDistance() = maxDistance;
+ }
+ else if (IsDescendantOf(*((TreeType*)
+ referenceNode.Stat().ClosestQueryNode()), queryNode))
+ {
+ referenceNode.Stat().ClosestQueryNode() == (void*) &queryNode;
+ referenceNode.Stat().MinQueryNodeDistance() = minDistance;
+ referenceNode.Stat().MaxQueryNodeDistance() = maxDistance;
}
if (score == DBL_MAX)
@@ -176,17 +167,16 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
else if (referenceNode.Stat().ClustersPruned() +
visited[referenceNode.Descendant(0)] == centroids.n_cols)
{
- for (size_t i = 0; i < referenceNode.NumPoints(); ++i)
+ for (size_t i = 0; i < referenceNode.NumDescendants(); ++i)
{
- const size_t cluster = assignments[referenceNode.Point(i)];
- newCentroids.col(cluster) += dataset.col(referenceNode.Point(i));
+ const size_t cluster = assignments[referenceNode.Descendant(i)];
+ newCentroids.col(cluster) += dataset.col(referenceNode.Descendant(i));
counts(cluster)++;
}
}
}
return score;
-// return 0.0;
}
template<typename MetricType, typename TreeType>
@@ -205,87 +195,6 @@ double DualTreeKMeansRules<MetricType, TreeType>::Rescore(
const double oldScore) const
{
return oldScore;
-
-// if (oldScore == DBL_MAX)
-// return oldScore; // We can't unprune something. This shouldn't happen.
-
-// return ElkanTypeScore(queryNode, referenceNode, oldScore);
-}
-
-template<typename MetricType, typename TreeType>
-inline double DualTreeKMeansRules<MetricType, TreeType>::IterationUpdate(
- TreeType& referenceNode)
-{
- Log::Fatal << "Update! Why!\n";
- if (referenceNode.Stat().Iteration() == iteration)
- return 0;
-
- const size_t itDiff = iteration - referenceNode.Stat().Iteration();
- referenceNode.Stat().Iteration() = iteration;
- referenceNode.Stat().ClustersPruned() = (referenceNode.Parent() == NULL) ?
- 0 : referenceNode.Parent()->Stat().ClustersPruned();
- referenceNode.Stat().ClosestQueryNode() = (referenceNode.Parent() == NULL) ?
- NULL : referenceNode.Parent()->Stat().ClosestQueryNode();
-
- if (referenceNode.Stat().ClosestQueryNode() != NULL)
- {
- referenceNode.Stat().MinQueryNodeDistance() =
- referenceNode.MinDistance((TreeType*)
- referenceNode.Stat().ClosestQueryNode());
- referenceNode.Stat().MaxQueryNodeDistance() =
- referenceNode.MaxDistance((TreeType*)
- referenceNode.Stat().ClosestQueryNode());
- distanceCalculations += 2;
- }
-
-
- if (itDiff > 1)
- {
-// referenceNode.Stat().BestMaxDistance() = DBL_MAX;
- referenceNode.Stat().MinQueryNodeDistance() = DBL_MAX;
- referenceNode.Stat().MaxQueryNodeDistance() = DBL_MAX;
- }
- else
- {
- if (referenceNode.Stat().MinQueryNodeDistance() != DBL_MAX)
- {
- // Update the distance to the closest query node. If this node has an
- // owner, we know how far to increase the bound. Otherwise, increase it
- // by the furthest amount that any centroid moved.
- if (referenceNode.Stat().Owner() < centroids.n_cols)
- {
- referenceNode.Stat().MinQueryNodeDistance() +=
- clusterDistances(referenceNode.Stat().Owner());
- referenceNode.Stat().MaxQueryNodeDistance() +=
- clusterDistances(referenceNode.Stat().Owner());
- }
- else
- {
- referenceNode.Stat().MinQueryNodeDistance() +=
- clusterDistances(centroids.n_cols);
- referenceNode.Stat().MaxQueryNodeDistance() +=
- clusterDistances(centroids.n_cols);
- }
- }
- else
- {
- referenceNode.Stat().MinQueryNodeDistance() = DBL_MAX;
- referenceNode.Stat().MaxQueryNodeDistance() = DBL_MAX;
- }
- }
-
-// if (referenceNode.Stat().BestMaxDistance() != DBL_MAX)
-// {
-// if (referenceNode.Stat().Owner() < centroids.n_cols)
-// referenceNode.Stat().BestMaxDistance() +=
-// clusterDistances(referenceNode.Stat().Owner());
-// else
-// referenceNode.Stat().BestMaxDistance() +=
-// clusterDistances(centroids.n_cols);
-// }
-// }
-
- return 1;
}
template<typename MetricType, typename TreeType>
@@ -308,12 +217,11 @@ double DualTreeKMeansRules<MetricType, TreeType>::ElkanTypeScore(
{
// We have to calculate the minimum distance between the query node and the
// reference node's best query node. First, try to use the cached distance.
-// const double minQueryDistance = queryNode.Stat().FirstBound();
+ const double minQueryDistance = queryNode.Stat().FirstBound();
if (queryNode.NumDescendants() == 1)
{
const double score = ElkanTypeScore(queryNode, referenceNode,
interclusterDistances[queryNode.Descendant(0)]);
-// Log::Warn << "Elkan scoring: " << score << ".\n";
return score;
}
else
@@ -344,7 +252,7 @@ double DualTreeKMeansRules<MetricType, TreeType>::ElkanTypeScore(
template<typename MetricType, typename TreeType>
double DualTreeKMeansRules<MetricType, TreeType>::PellegMooreScore(
- TreeType& /* queryNode */,
+ TreeType& queryNode,
TreeType& referenceNode,
const double minDistance) const
{
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
index 6d394b7..4874d5b 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
@@ -20,7 +20,7 @@ class DualTreeKMeansStatistic
closestQueryNode(NULL),
minQueryNodeDistance(DBL_MAX),
maxQueryNodeDistance(DBL_MAX),
- clustersPruned(0),
+ clustersPruned(size_t(-1)),
iteration(size_t() - 1),
firstBound(DBL_MAX),
secondBound(DBL_MAX),
More information about the mlpack-git mailing list