[mlpack-git] master: Slight refactoring of dual-tree k-means rules. (c63a518)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:01:31 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44
>---------------------------------------------------------------
commit c63a518c8b6bf01e393028c5e91c7f7a95df662a
Author: ryan <ryan at ratml.org>
Date: Thu Jan 1 18:56:00 2015 -0500
Slight refactoring of dual-tree k-means rules.
This is an attempt to clean up this code because it quickly became unwieldy and
hard to work with. A more modular approach could lead to a better, faster
solution.
>---------------------------------------------------------------
c63a518c8b6bf01e393028c5e91c7f7a95df662a
.../methods/kmeans/dual_tree_kmeans_impl.hpp | 75 ++++++++++++++++++++++
.../methods/kmeans/dual_tree_kmeans_rules.hpp | 2 +-
.../methods/kmeans/dual_tree_kmeans_rules_impl.hpp | 57 +++++++++++-----
3 files changed, 118 insertions(+), 16 deletions(-)
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
index fe63498..088d565 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
@@ -112,6 +112,81 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
return std::sqrt(residual);
}
+/*
+template<typename MetricType, typename MatType, typename TreeType>
+void DualTreeKMeans<MetricType, MatType, TreeType>::ClusterTreeUpdate(
+ TreeType* node)
+{
+ // We will abuse stat.owner to hold the cluster with the most change.
+ // stat.minQueryNodeDistance will hold the distance.
+ double maxChange = 0.0;
+ size_t maxChangeCluster = 0;
+
+ for (size_t i = 0; i < node->NumChildren(); ++i)
+ {
+ ClusterTreeUpdate(&node->Child(i));
+
+ const double nodeChange = node->Child(i).Stat().MinQueryNodeDistance();
+ if (nodeChange > maxChange)
+ {
+ maxChange = nodeChange;
+ maxChangeCluster = node->Child(i).Stat().Owner();
+ }
+ }
+
+ for (size_t i = 0; i < node->NumPoints(); ++i)
+ {
+ const size_t cluster = oldFromNewCentroids[node->Point(i)];
+ const double pointChange = clusterDistances[cluster];
+ if (pointChange > maxChange)
+ {
+ maxChange = pointChange;
+ maxChangeCluster = cluster;
+ }
+ }
+
+ node->Stat().Owner() = maxChangeCluster;
+ node->Stat().MinQueryNodeDistance() = maxChange;
+}
+
+template<typename MetricType, typename MatType, typename TreeType>
+void DualTreeKMeans<MetricType, MatType, TreeType>::TreeUpdate(
+ TreeType* node)
+{
+ // This is basically IterationUpdate(), but pulled out to be separate from the
+ // actual dual-tree algorithm.
+
+ // First, update the iteration.
+ const size_t itDiff = node->Stat().Iteration() - iteration;
+
+ if (itDiff == 1)
+ {
+ // The easy case.
+ if (node->Stat().Owner() < centroids.n_cols)
+ {
+ // During the last iteration, this node was pruned. In addition, we have
+ // cached a lower bound on the second closest cluster. So, use the
+ // triangle inequality: if the maximum distance between the point and the
+ // cluster centroid plus the distance that centroid moved is less than the
+ // lower bound minus the maximum moving centroid, then this cluster *must*
+ // still have the same owner.
+ const size_t owner = node->Stat().Owner();
+ const double closestUpperBound = node->Stat().MaxQueryNodeDistance() +
+ clusterDistances[owner];
+ const TreeType* nonOwner = (TreeType*) node->Stat().ClosestNonOwner();
+ const double tightestLowerBound = node->Stat().ClosestNonOwnerDistance() -
+ nonOwner->Stat().MinQueryNodeDistance() /* abused from earlier *;
+ if (closestUpperBound <= tightestLowerBound)
+ {
+ // Then the owner must not have changed.
+
+ }
+ }
+ }
+}
+*/
+
+
} // namespace kmeans
} // namespace mlpack
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp
index 978c12c..24cad31 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp
@@ -66,7 +66,7 @@ class DualTreeKMeansRules
TraversalInfoType traversalInfo;
- size_t IterationUpdate(TreeType& referenceNode) const;
+ double IterationUpdate(TreeType& referenceNode);
bool IsDescendantOf(const TreeType& potentialParent, const TreeType&
potentialChild) const;
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
index 33ea4ae..7676e1e 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
@@ -112,7 +112,11 @@ double DualTreeKMeansRules<MetricType, TreeType>::Score(
TreeType& queryNode,
TreeType& referenceNode)
{
- IterationUpdate(referenceNode);
+ if (IterationUpdate(referenceNode) == DBL_MAX)
+ {
+ // The iteration update showed that the owner could not possibly change.
+ return DBL_MAX;
+ }
traversalInfo.LastReferenceNode() = &referenceNode;
@@ -201,12 +205,13 @@ double DualTreeKMeansRules<MetricType, TreeType>::Rescore(
}
template<typename MetricType, typename TreeType>
-inline size_t DualTreeKMeansRules<MetricType, TreeType>::IterationUpdate(
- TreeType& referenceNode) const
+inline double DualTreeKMeansRules<MetricType, TreeType>::IterationUpdate(
+ TreeType& referenceNode)
{
if (referenceNode.Stat().Iteration() == iteration)
return 0;
+ const size_t itDiff = iteration - referenceNode.Stat().Iteration();
referenceNode.Stat().Iteration() = iteration;
referenceNode.Stat().ClustersPruned() = (referenceNode.Parent() == NULL) ?
0 : referenceNode.Parent()->Stat().ClustersPruned();
@@ -214,15 +219,22 @@ inline size_t DualTreeKMeansRules<MetricType, TreeType>::IterationUpdate(
NULL : referenceNode.Parent()->Stat().ClosestQueryNode();
if (referenceNode.Stat().ClosestQueryNode() != NULL)
+ {
referenceNode.Stat().MinQueryNodeDistance() =
referenceNode.MinDistance((TreeType*)
referenceNode.Stat().ClosestQueryNode());
+ referenceNode.Stat().MaxQueryNodeDistance() =
+ referenceNode.MaxDistance((TreeType*)
+ referenceNode.Stat().ClosestQueryNode());
+ distanceCalculations += 2;
+ }
+
- const size_t itDiff = iteration - referenceNode.Stat().Iteration();
if (itDiff > 1)
{
- // Maybe this can be tighter?
+// referenceNode.Stat().BestMaxDistance() = DBL_MAX;
referenceNode.Stat().MinQueryNodeDistance() = DBL_MAX;
+ referenceNode.Stat().MaxQueryNodeDistance() = DBL_MAX;
}
else
{
@@ -231,24 +243,39 @@ inline size_t DualTreeKMeansRules<MetricType, TreeType>::IterationUpdate(
// Update the distance to the closest query node. If this node has an
// owner, we know how far to increase the bound. Otherwise, increase it
// by the furthest amount that any centroid moved.
-// if (referenceNode.Stat().Owner() < centroids.n_cols)
-// referenceNode.Stat().MinQueryNodeDistance() +=
-// clusterDistances(referenceNode.Stat().Owner());
-// else
-// referenceNode.Stat().MinQueryNodeDistance() = DBL_MAX;
-// clusterDistances(centroids.n_cols);
- if (referenceNode.Stat().MaxQueryNodeDistance() == DBL_MAX)
- referenceNode.Stat().MinQueryNodeDistance() = DBL_MAX;
+ if (referenceNode.Stat().Owner() < centroids.n_cols)
+ {
+ referenceNode.Stat().MinQueryNodeDistance() +=
+ clusterDistances(referenceNode.Stat().Owner());
+ referenceNode.Stat().MaxQueryNodeDistance() +=
+ clusterDistances(referenceNode.Stat().Owner());
+ }
else
{
referenceNode.Stat().MinQueryNodeDistance() +=
clusterDistances(centroids.n_cols);
-//referenceNode.Stat().MaxQueryNodeDistance() +
-//clusterDistances(centroids.n_cols);
+ referenceNode.Stat().MaxQueryNodeDistance() +=
+ clusterDistances(centroids.n_cols);
}
}
+ else
+ {
+ referenceNode.Stat().MinQueryNodeDistance() = DBL_MAX;
+ referenceNode.Stat().MaxQueryNodeDistance() = DBL_MAX;
+ }
}
+// if (referenceNode.Stat().BestMaxDistance() != DBL_MAX)
+// {
+// if (referenceNode.Stat().Owner() < centroids.n_cols)
+// referenceNode.Stat().BestMaxDistance() +=
+// clusterDistances(referenceNode.Stat().Owner());
+// else
+// referenceNode.Stat().BestMaxDistance() +=
+// clusterDistances(centroids.n_cols);
+// }
+// }
+
return 1;
}
More information about the mlpack-git mailing list