[mlpack-git] master: Refactor UpdateTree() to pass parent bounds. This is a little bit tighter, I think, and is necessary for cover trees, where the MaxDistance() to a node is not necessarily bounded above by the MaxDistance() to the node's parent. (64b52e0)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed May 20 23:05:53 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/77d750c8fd46140b1d6060424f68768a21c89377...7e9cd46afb53817ae93ccbd02637d7726137ce4d
>---------------------------------------------------------------
commit 64b52e0c1296e4b4c9edd9ec05d780f4afedbcff
Author: Ryan Curtin <ryan at ratml.org>
Date: Mon May 18 23:40:54 2015 -0400
Refactor UpdateTree() to pass parent bounds.
This is a little bit tighter, I think, and is necessary for cover trees, where the MaxDistance() to a node is not necessarily bounded above by the MaxDistance() to the node's parent.
>---------------------------------------------------------------
64b52e0c1296e4b4c9edd9ec05d780f4afedbcff
src/mlpack/methods/kmeans/dual_tree_kmeans.hpp | 6 +-
.../methods/kmeans/dual_tree_kmeans_impl.hpp | 124 +++++++++++++++++----
2 files changed, 110 insertions(+), 20 deletions(-)
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
index d849379..f4fd3aa 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
@@ -100,7 +100,11 @@ class DualTreeKMeans
//! centroids is the current (not yet searched) centroids.
void UpdateTree(TreeType& node,
const arma::mat& centroids,
- const arma::vec& interclusterDistances);
+ const arma::vec& interclusterDistances,
+ const double parentUpperBound = 0.0,
+ const double adjustedParentUpperBound = DBL_MAX,
+ const double parentLowerBound = DBL_MAX,
+ const double adjustedParentLowerBound = 0.0);
//! Extract the centroids of the clusters.
void ExtractCentroids(TreeType& node,
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
index 8ccdb43..27e7591 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
@@ -136,7 +136,7 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
}
// We won't use the AllkNN class here because we have our own set of rules.
- //lastIterationCentroids = oldCentroids;
+ lastIterationCentroids = oldCentroids;
typedef DualTreeKMeansRules<MetricType, TreeType> RuleType;
RuleType rules(centroids, dataset, assignments, upperBounds, lowerBounds,
metric, prunedPoints, oldFromNewCentroids, visited);
@@ -192,6 +192,8 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
delete centroidTree;
++iteration;
+ Log::Debug << counts.t();
+ Log::Debug << arma::accu(counts) << "!\n";
return std::sqrt(residual);
}
@@ -200,25 +202,46 @@ template<typename MetricType, typename MatType, typename TreeType>
void DualTreeKMeans<MetricType, MatType, TreeType>::UpdateTree(
TreeType& node,
const arma::mat& centroids,
- const arma::vec& interclusterDistances)
+ const arma::vec& interclusterDistances,
+ const double parentUpperBound,
+ const double adjustedParentUpperBound,
+ const double parentLowerBound,
+ const double adjustedParentLowerBound)
{
const bool prunedLastIteration = node.Stat().StaticPruned();
node.Stat().StaticPruned() = false;
// Grab information from the parent, if we can.
if (node.Parent() != NULL &&
- node.Parent()->Stat().Pruned() == centroids.n_cols)
+ node.Parent()->Stat().Pruned() == centroids.n_cols &&
+ node.Parent()->Stat().Owner() < centroids.n_cols)
{
- node.Stat().UpperBound() = node.Parent()->Stat().UpperBound();
- node.Stat().LowerBound() = node.Parent()->Stat().LowerBound() +
- clusterDistances[centroids.n_cols];
+ if (node.Point(0) == 10475 || node.Point(0) == 12756)
+ Log::Debug << "Update upper bound for node " << node.Point(0) << "c" <<
+node.NumDescendants() << " from parent " << node.Parent()->Point(0) << "c" <<
+node.Parent()->NumDescendants() << " from " << node.Stat().UpperBound() << " to " << parentUpperBound << ".\n";
+ // When taking bounds from the parent, note that the parent has already
+ // adjusted the bounds according to the cluster movements, so we need to
+ // de-adjust them since we'll adjust them again. Maybe there is a smarter
+ // way to do this...
+ if (node.Stat().UpperBound() < parentUpperBound && !prunedLastIteration)
+ {
+ Log::Warn << node;
+ Log::Fatal << "Wat, upper bound not DBL_MAX.\n";
+ }
+ node.Stat().UpperBound() = parentUpperBound;
+ node.Stat().LowerBound() = parentLowerBound;
node.Stat().Pruned() = node.Parent()->Stat().Pruned();
node.Stat().Owner() = node.Parent()->Stat().Owner();
}
-
+ const double unadjustedUpperBound = node.Stat().UpperBound();
+ double adjustedUpperBound = adjustedParentUpperBound;
+ const double unadjustedLowerBound = node.Stat().LowerBound();
+ double adjustedLowerBound = adjustedParentLowerBound;
// Exhaustive lower bound check. Sigh.
-/* if (!prunedLastIteration)
+
+ if (!prunedLastIteration)
{
for (size_t i = 0; i < node.NumDescendants(); ++i)
{
@@ -239,7 +262,6 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::UpdateTree(
else if (dist < secondClosest)
secondClosest = dist;
}
-
if (closest - 1e-10 > node.Stat().UpperBound())
{
Log::Warn << distances.t();
@@ -265,31 +287,82 @@ visited[node.Descendant(i)] << ".\n";
}
}
}
- }*/
+ }
+
if ((node.Stat().Pruned() == centroids.n_cols) &&
(node.Stat().Owner() < centroids.n_cols))
{
// Adjust bounds.
+ if (node.Point(0) == 10475 || node.Point(0) == 12756)
+ Log::Debug << "Update upper bound for node " << node.Point(0) << "c" <<
+node.NumDescendants() << " from " << node.Stat().UpperBound() << " to " <<
+node.Stat().UpperBound() + clusterDistances[node.Stat().Owner()] << " as cluster"
+ << " adjustment.\n";
node.Stat().UpperBound() += clusterDistances[node.Stat().Owner()];
node.Stat().LowerBound() -= clusterDistances[centroids.n_cols];
+ if (node.Point(0) == 10475 || node.Point(0) == 12756)
+ Log::Debug << node;
+
+ if (adjustedParentUpperBound < node.Stat().UpperBound())
+ {
+ if (node.Point(0) == 10475 || node.Point(0) == 12756)
+ {
+ Log::Debug << "Take adjusted parent upper bound of " <<
+adjustedParentUpperBound << " vs. adjusted bound of " <<
+node.Stat().UpperBound() << ".\n";
+ }
+ node.Stat().UpperBound() = adjustedParentUpperBound;
+ if (node.Point(0) == 10475 || node.Point(0) == 12756)
+ Log::Debug << node;
+ }
+
+ if (adjustedParentLowerBound > node.Stat().LowerBound())
+ node.Stat().LowerBound() = adjustedParentLowerBound;
+
const double interclusterBound = interclusterDistances[node.Stat().Owner()]
/ 2.0;
if (interclusterBound > node.Stat().LowerBound())
+ {
node.Stat().LowerBound() = interclusterBound;
+ adjustedLowerBound = node.Stat().LowerBound();
+ }
if (node.Stat().UpperBound() < node.Stat().LowerBound())
{
+ if (node.Point(0) == 10475 || node.Point(0) == 12756)
+ Log::Warn << "Mark r" << node.Point(0) << "c" << node.NumDescendants()
+<< " as statically pruned.\n" << node;
node.Stat().StaticPruned() = true;
}
else
{
// Tighten bound.
+ if (node.Point(0) == 10475 || node.Point(0) == 12756)
+ {
+ Log::Debug << centroids.col(node.Stat().Owner()).t() << ".\n";
+ Log::Debug << dataset.col(node.Point(0)).t() << ".\n";
+ Log::Debug << "FDD " << node.FurthestDescendantDistance() << ".\n";
+ Log::Debug << metric.Evaluate(centroids.col(node.Stat().Owner()),
+dataset.col(node.Point(0))) << ".\n";
+ Log::Debug << "Tighten upper bound for node " << node.Point(0) << "c" <<
+node.NumDescendants() << " from " << node.Stat().UpperBound() << " to " <<
+node.MaxDistance(centroids.col(node.Stat().Owner())) << " with owner " <<
+node.Stat().Owner() << ".\n";
+ }
node.Stat().UpperBound() =
- node.MaxDistance(centroids.col(node.Stat().Owner()));
+ std::min(node.Stat().UpperBound(),
+ node.MaxDistance(centroids.col(node.Stat().Owner())));
+ adjustedUpperBound = node.Stat().UpperBound();
+ if (node.Point(0) == 10475 || node.Point(0) == 12756)
+ Log::Debug << node;
+
++distanceCalculations;
if (node.Stat().UpperBound() < node.Stat().LowerBound())
{
+ if (node.Point(0) == 10475 || node.Point(0) == 12756)
+ Log::Warn << "Mark r" << node.Point(0) << "c" << node.NumDescendants()
+<< " as statically pruned.\n" << node;
node.Stat().StaticPruned() = true;
}
}
@@ -362,14 +435,20 @@ visited[node.Descendant(i)] << ".\n";
bool allChildrenPruned = true;
for (size_t i = 0; i < node.NumChildren(); ++i)
{
- UpdateTree(node.Child(i), centroids, interclusterDistances);
+ UpdateTree(node.Child(i), centroids, interclusterDistances,
+ unadjustedUpperBound, adjustedUpperBound, unadjustedLowerBound,
+ adjustedLowerBound);
if (!node.Child(i).Stat().StaticPruned())
allChildrenPruned = false;
+// node.Child(i).Stat().StaticPruned() = true;
}
if (node.Stat().StaticPruned() && !allChildrenPruned)
{
Log::Warn << node;
+ for (size_t i = 0; i < node.NumChildren(); ++i)
+ Log::Warn << "child " << i << ":\n" << node.Child(i);
+// Log::Warn << "grandchild: " << node.Child(1).Child(1);
Log::Fatal << "Node is statically pruned but not all its children are!\n";
}
@@ -408,6 +487,9 @@ visited[node.Descendant(i)] << ".\n";
clusterDistances[centroids.n_cols];
}
}
+
+ if (node.Point(0) == 12756)
+ Log::Debug << "Node at end of iteration:\n" << node;
}
template<typename MetricType, typename MatType, typename TreeType>
@@ -426,7 +508,7 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::ExtractCentroids(
newCounts[owner] += node.NumDescendants();
// Perform the sanity check here.
-/*
+
for (size_t i = 0; i < node.NumDescendants(); ++i)
{
const size_t index = node.Descendant(i);
@@ -450,7 +532,7 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::ExtractCentroids(
node.Stat().UpperBound() << " and owner " << node.Stat().Owner() << ".\n";
}
}
-*/
+
}
else
{
@@ -464,7 +546,7 @@ node.Stat().UpperBound() << " and owner " << node.Stat().Owner() << ".\n";
newCentroids.col(owner) += dataset.col(node.Point(i));
++newCounts[owner];
-/*
+
const size_t index = node.Point(i);
arma::vec trueDistances(centroids.n_cols);
for (size_t j = 0; j < centroids.n_cols; ++j)
@@ -489,7 +571,7 @@ assignments[node.Point(i)] << " with ub " << upperBounds[node.Point(i)] <<
(visited[node.Point(i)] ? "true"
: "false") << ".\n";
}
-*/
+
}
}
@@ -501,9 +583,10 @@ assignments[node.Point(i)] << " with ub " << upperBounds[node.Point(i)] <<
template<typename MetricType, typename MatType, typename TreeType>
void DualTreeKMeans<MetricType, MatType, TreeType>::CoalesceTree(
- TreeType& node,
- const size_t child /* Which child are we? */)
+ TreeType& /*node*/,
+ const size_t /*child /* Which child are we? */)
{
+/*
// If all children except one are pruned, we can hide this node.
if (node.NumChildren() == 0)
return; // We can't do anything.
@@ -544,12 +627,14 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::CoalesceTree(
for (size_t i = 0; i < node.NumChildren(); ++i)
CoalesceTree(node.Child(i), i);
}
+*/
}
template<typename MetricType, typename MatType, typename TreeType>
void DualTreeKMeans<MetricType, MatType, TreeType>::DecoalesceTree(
- TreeType& node)
+ TreeType& /*node*/)
{
+/*
node.Parent() = (TreeType*) node.Stat().TrueParent();
for (size_t i = 0; i < node.NumChildren(); ++i)
node.ChildPtr(i) = (TreeType*) node.Stat().TrueChild(i);
@@ -559,6 +644,7 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::DecoalesceTree(
DecoalesceTree(node.Child(0));
DecoalesceTree(node.Child(1));
}
+*/
}
} // namespace kmeans
More information about the mlpack-git
mailing list