[mlpack-git] master: Remove debugging output. Cover trees work. (722d375)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed May 20 23:05:58 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/77d750c8fd46140b1d6060424f68768a21c89377...7e9cd46afb53817ae93ccbd02637d7726137ce4d
>---------------------------------------------------------------
commit 722d37594b6472399e09daa74d7dc531ab80368c
Author: Ryan Curtin <ryan at ratml.org>
Date: Tue May 19 15:27:17 2015 -0400
Remove debugging output. Cover trees work.
>---------------------------------------------------------------
722d37594b6472399e09daa74d7dc531ab80368c
.../methods/kmeans/dual_tree_kmeans_impl.hpp | 104 +++++++--------------
.../methods/kmeans/dual_tree_kmeans_rules_impl.hpp | 2 +
2 files changed, 34 insertions(+), 72 deletions(-)
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
index 27e7591..697fb26 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
@@ -120,6 +120,11 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
interclusterDistances[oldFromNewCentroids[i]] =
interclusterDistancesTemp[i];
}
+ else
+ {
+ // TODO: avoid copy.
+ interclusterDistances = interclusterDistancesTemp;
+ }
Timer::Stop("knn");
@@ -216,19 +221,10 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::UpdateTree(
node.Parent()->Stat().Pruned() == centroids.n_cols &&
node.Parent()->Stat().Owner() < centroids.n_cols)
{
- if (node.Point(0) == 10475 || node.Point(0) == 12756)
- Log::Debug << "Update upper bound for node " << node.Point(0) << "c" <<
-node.NumDescendants() << " from parent " << node.Parent()->Point(0) << "c" <<
-node.Parent()->NumDescendants() << " from " << node.Stat().UpperBound() << " to " << parentUpperBound << ".\n";
// When taking bounds from the parent, note that the parent has already
// adjusted the bounds according to the cluster movements, so we need to
// de-adjust them since we'll adjust them again. Maybe there is a smarter
// way to do this...
- if (node.Stat().UpperBound() < parentUpperBound && !prunedLastIteration)
- {
- Log::Warn << node;
- Log::Fatal << "Wat, upper bound not DBL_MAX.\n";
- }
node.Stat().UpperBound() = parentUpperBound;
node.Stat().LowerBound() = parentLowerBound;
node.Stat().Pruned() = node.Parent()->Stat().Pruned();
@@ -240,7 +236,7 @@ node.Parent()->NumDescendants() << " from " << node.Stat().UpperBound() << " to
double adjustedLowerBound = adjustedParentLowerBound;
// Exhaustive lower bound check. Sigh.
-
+/*
if (!prunedLastIteration)
{
for (size_t i = 0; i < node.NumDescendants(); ++i)
@@ -277,6 +273,7 @@ node.Stat().UpperBound() << " with closest cluster distance " << closest <<
node.Stat().LowerBound()))
{
Log::Warn << distances.t();
+ Log::Warn << node;
Log::Fatal << "Point " << node.Descendant(i) << " in " << node.Point(0) <<
"c" << node.NumDescendants() << " invalidates lower bound " <<
std::min(lowerBounds[node.Descendant(i)], node.Stat().LowerBound()) << " (" <<
@@ -288,39 +285,26 @@ visited[node.Descendant(i)] << ".\n";
}
}
}
-
+*/
if ((node.Stat().Pruned() == centroids.n_cols) &&
(node.Stat().Owner() < centroids.n_cols))
{
// Adjust bounds.
- if (node.Point(0) == 10475 || node.Point(0) == 12756)
- Log::Debug << "Update upper bound for node " << node.Point(0) << "c" <<
-node.NumDescendants() << " from " << node.Stat().UpperBound() << " to " <<
-node.Stat().UpperBound() + clusterDistances[node.Stat().Owner()] << " as cluster"
- << " adjustment.\n";
node.Stat().UpperBound() += clusterDistances[node.Stat().Owner()];
node.Stat().LowerBound() -= clusterDistances[centroids.n_cols];
if (node.Point(0) == 10475 || node.Point(0) == 12756)
Log::Debug << node;
if (adjustedParentUpperBound < node.Stat().UpperBound())
- {
- if (node.Point(0) == 10475 || node.Point(0) == 12756)
- {
- Log::Debug << "Take adjusted parent upper bound of " <<
-adjustedParentUpperBound << " vs. adjusted bound of " <<
-node.Stat().UpperBound() << ".\n";
- }
node.Stat().UpperBound() = adjustedParentUpperBound;
- if (node.Point(0) == 10475 || node.Point(0) == 12756)
- Log::Debug << node;
- }
if (adjustedParentLowerBound > node.Stat().LowerBound())
node.Stat().LowerBound() = adjustedParentLowerBound;
+ // Try to use the inter-cluster distances to produce a better lower bound,
+ // if possible.
const double interclusterBound = interclusterDistances[node.Stat().Owner()]
/ 2.0;
if (interclusterBound > node.Stat().LowerBound())
@@ -328,43 +312,22 @@ node.Stat().UpperBound() << ".\n";
node.Stat().LowerBound() = interclusterBound;
adjustedLowerBound = node.Stat().LowerBound();
}
+
if (node.Stat().UpperBound() < node.Stat().LowerBound())
{
- if (node.Point(0) == 10475 || node.Point(0) == 12756)
- Log::Warn << "Mark r" << node.Point(0) << "c" << node.NumDescendants()
-<< " as statically pruned.\n" << node;
node.Stat().StaticPruned() = true;
}
else
{
// Tighten bound.
- if (node.Point(0) == 10475 || node.Point(0) == 12756)
- {
- Log::Debug << centroids.col(node.Stat().Owner()).t() << ".\n";
- Log::Debug << dataset.col(node.Point(0)).t() << ".\n";
- Log::Debug << "FDD " << node.FurthestDescendantDistance() << ".\n";
- Log::Debug << metric.Evaluate(centroids.col(node.Stat().Owner()),
-dataset.col(node.Point(0))) << ".\n";
- Log::Debug << "Tighten upper bound for node " << node.Point(0) << "c" <<
-node.NumDescendants() << " from " << node.Stat().UpperBound() << " to " <<
-node.MaxDistance(centroids.col(node.Stat().Owner())) << " with owner " <<
-node.Stat().Owner() << ".\n";
- }
node.Stat().UpperBound() =
std::min(node.Stat().UpperBound(),
node.MaxDistance(centroids.col(node.Stat().Owner())));
adjustedUpperBound = node.Stat().UpperBound();
- if (node.Point(0) == 10475 || node.Point(0) == 12756)
- Log::Debug << node;
++distanceCalculations;
if (node.Stat().UpperBound() < node.Stat().LowerBound())
- {
- if (node.Point(0) == 10475 || node.Point(0) == 12756)
- Log::Warn << "Mark r" << node.Point(0) << "c" << node.NumDescendants()
-<< " as statically pruned.\n" << node;
node.Stat().StaticPruned() = true;
- }
}
}
else
@@ -421,9 +384,16 @@ node.Stat().Owner() << ".\n";
}
else
{
- // Point cannot be pruned.
- upperBounds[index] = DBL_MAX;
- lowerBounds[index] = DBL_MAX;
+ // Point cannot be pruned. We may have to inspect the point at a
+ // lower level, though. If that's the case, then we shouldn't
+ // invalidate the bounds we've got -- it will happen at the lower
+ // level.
+ if (!tree::TreeTraits<TreeType>::HasSelfChildren ||
+ node.NumChildren() == 0)
+ {
+ upperBounds[index] = DBL_MAX;
+ lowerBounds[index] = DBL_MAX;
+ }
allPointsPruned = false;
}
}
@@ -440,17 +410,17 @@ node.Stat().Owner() << ".\n";
adjustedLowerBound);
if (!node.Child(i).Stat().StaticPruned())
allChildrenPruned = false;
-// node.Child(i).Stat().StaticPruned() = true;
}
+/*
if (node.Stat().StaticPruned() && !allChildrenPruned)
{
Log::Warn << node;
for (size_t i = 0; i < node.NumChildren(); ++i)
Log::Warn << "child " << i << ":\n" << node.Child(i);
-// Log::Warn << "grandchild: " << node.Child(1).Child(1);
Log::Fatal << "Node is statically pruned but not all its children are!\n";
}
+*/
// If all of the children and points are pruned, we may mark this node as
// pruned.
@@ -487,9 +457,6 @@ node.Stat().Owner() << ".\n";
clusterDistances[centroids.n_cols];
}
}
-
- if (node.Point(0) == 12756)
- Log::Debug << "Node at end of iteration:\n" << node;
}
template<typename MetricType, typename MatType, typename TreeType>
@@ -508,7 +475,7 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::ExtractCentroids(
newCounts[owner] += node.NumDescendants();
// Perform the sanity check here.
-
+/*
for (size_t i = 0; i < node.NumDescendants(); ++i)
{
const size_t index = node.Descendant(i);
@@ -532,7 +499,7 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::ExtractCentroids(
node.Stat().UpperBound() << " and owner " << node.Stat().Owner() << ".\n";
}
}
-
+*/
}
else
{
@@ -546,7 +513,7 @@ node.Stat().UpperBound() << " and owner " << node.Stat().Owner() << ".\n";
newCentroids.col(owner) += dataset.col(node.Point(i));
++newCounts[owner];
-
+/*
const size_t index = node.Point(i);
arma::vec trueDistances(centroids.n_cols);
for (size_t j = 0; j < centroids.n_cols; ++j)
@@ -571,7 +538,7 @@ assignments[node.Point(i)] << " with ub " << upperBounds[node.Point(i)] <<
(visited[node.Point(i)] ? "true"
: "false") << ".\n";
}
-
+*/
}
}
@@ -583,10 +550,9 @@ assignments[node.Point(i)] << " with ub " << upperBounds[node.Point(i)] <<
template<typename MetricType, typename MatType, typename TreeType>
void DualTreeKMeans<MetricType, MatType, TreeType>::CoalesceTree(
- TreeType& /*node*/,
- const size_t /*child /* Which child are we? */)
+ TreeType& node,
+ const size_t child /* Which child are we? */)
{
-/*
// If all children except one are pruned, we can hide this node.
if (node.NumChildren() == 0)
return; // We can't do anything.
@@ -627,24 +593,18 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::CoalesceTree(
for (size_t i = 0; i < node.NumChildren(); ++i)
CoalesceTree(node.Child(i), i);
}
-*/
}
template<typename MetricType, typename MatType, typename TreeType>
void DualTreeKMeans<MetricType, MatType, TreeType>::DecoalesceTree(
- TreeType& /*node*/)
+ TreeType& node)
{
-/*
node.Parent() = (TreeType*) node.Stat().TrueParent();
for (size_t i = 0; i < node.NumChildren(); ++i)
node.ChildPtr(i) = (TreeType*) node.Stat().TrueChild(i);
- if (node.NumChildren() > 0)
- {
- DecoalesceTree(node.Child(0));
- DecoalesceTree(node.Child(1));
- }
-*/
+ for (size_t i = 0; i < node.NumChildren(); ++i)
+ DecoalesceTree(node.Child(i));
}
} // namespace kmeans
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
index 8fc883e..c690679 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
@@ -269,12 +269,14 @@ inline double DualTreeKMeansRules<MetricType, TreeType>::Score(
}
// Is everything pruned?
+
if (queryNode.Stat().Pruned() == centroids.n_cols - 1)
{
queryNode.Stat().Pruned() = centroids.n_cols; // Owner() is already set.
return DBL_MAX;
}
+
// Set traversal information.
traversalInfo.LastQueryNode() = &queryNode;
traversalInfo.LastReferenceNode() = &referenceNode;
More information about the mlpack-git mailing list