[mlpack-git] master: Refactor to re-use distances to clusters. This is better than setting the distances to DBL_MAX every iteration, and provides a reasonable speedup. (76e10e0)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:03:19 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44
>---------------------------------------------------------------
commit 76e10e046328e8f24b1db03d80b70e35cbf4512c
Author: Ryan Curtin <ryan at ratml.org>
Date: Tue Feb 3 17:22:27 2015 -0500
Refactor to re-use distances to clusters. This is better than setting the distances to DBL_MAX every iteration, and provides a reasonable speedup.
>---------------------------------------------------------------
76e10e046328e8f24b1db03d80b70e35cbf4512c
src/mlpack/methods/kmeans/dtnn_kmeans.hpp | 2 ++
src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 38 ++++++++++++++++++++++----
src/mlpack/methods/kmeans/dtnn_rules.hpp | 5 +++-
src/mlpack/methods/kmeans/dtnn_rules_impl.hpp | 9 ++++--
4 files changed, 45 insertions(+), 9 deletions(-)
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
index b939f32..f9b35bb 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
@@ -99,6 +99,8 @@ class DTNNKMeans
arma::mat distances;
arma::Mat<size_t> assignments;
+ std::vector<bool> visited; // Was the point visited this iteration?
+
arma::mat lastIterationCentroids; // For sanity checks.
//! Update the bounds in the tree before the next iteration.
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 95546d3..ae2dde7 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -62,6 +62,13 @@ DTNNKMeans<MetricType, MatType, TreeType>::DTNNKMeans(const MatType& dataset,
lowerSecondBounds.zeros(dataset.n_cols);
lastOwners.zeros(dataset.n_cols);
+ assignments.set_size(2, dataset.n_cols);
+ assignments.fill(size_t(-1));
+ distances.set_size(2, dataset.n_cols);
+ distances.fill(DBL_MAX);
+
+ visited.resize(dataset.n_cols, false);
+
Timer::Start("tree_building");
// Copy the dataset, if necessary.
@@ -135,11 +142,9 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
// We won't use the AllkNN class here because we have our own set of rules.
// This is a lot of overhead. We don't need the distances.
- distances.fill(DBL_MAX);
- assignments.fill(size_t(-1));
typedef DTNNKMeansRules<MetricType, TreeType> RuleType;
RuleType rules(centroids, dataset, assignments, distances, metric,
- prunedPoints, oldFromNewCentroids);
+ prunedPoints, oldFromNewCentroids, visited);
// Now construct the traverser ourselves.
typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
@@ -153,10 +158,12 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
// From the assignments, calculate the new centroids and counts.
for (size_t i = 0; i < dataset.n_cols; ++i)
{
- if (assignments(0, i) != size_t(-1))
+ if (visited[i])
{
newCentroids.col(assignments(0, i)) += dataset.col(i);
++counts(assignments(0, i));
+ // Reset for next iteration.
+ visited[i] = false;
}
}
@@ -541,7 +548,9 @@ prunedPoints[index] << ", lastOwner " << lastOwners[index] << ": invalid "
*/
prunedPoints[index] = true;
upperBounds[index] += clusterDistances[owner];
+ distances(0, index) += clusterDistances[owner];
lastOwners[index] = owner;
+ distances(1, index) += clusterDistances[centroids.n_cols];
lowerSecondBounds[index] -= clusterDistances[centroids.n_cols];
prunedCentroids.col(owner) += dataset.col(index);
prunedCounts(owner)++;
@@ -551,7 +560,9 @@ prunedPoints[index] << ", lastOwner " << lastOwners[index] << ": invalid "
{
prunedPoints[index] = true;
upperBounds[index] += clusterDistances[owner];
+ distances(0, index) += clusterDistances[owner];
lastOwners[index] = owner;
+ distances(1, index) += clusterDistances[centroids.n_cols];
lowerSecondBounds[index] -= clusterDistances[centroids.n_cols];
prunedCentroids.col(owner) += dataset.col(index);
prunedCounts(owner)++;
@@ -559,24 +570,29 @@ prunedPoints[index] << ", lastOwner " << lastOwners[index] << ": invalid "
else
{
// Attempt to tighten the lower bound.
- upperBounds[index] = metric.Evaluate(centroids.col(owner),
+ distances(0, index) = metric.Evaluate(centroids.col(owner),
dataset.col(index));
+ upperBounds[index] = distances(0, index);
++distanceCalculations;
if (upperBounds[index] < lowerSecondBounds[index] -
clusterDistances[centroids.n_cols])
{
prunedPoints[index] = true;
lastOwners[index] = owner;
+ upperBounds[index] += clusterDistances[owner];
lowerSecondBounds[index] -= clusterDistances[centroids.n_cols];
+ distances(1, index) += clusterDistances[centroids.n_cols];
prunedCentroids.col(owner) += dataset.col(index);
prunedCounts(owner)++;
}
else if (upperBounds[index] < 0.5 *
- interclusterDistances[newFromOldCentroids[owner]])
+ interclusterDistances[newFromOldCentroids[owner]])
{
prunedPoints[index] = true;
lastOwners[index] = owner;
+ upperBounds[index] += clusterDistances[owner];
lowerSecondBounds[index] -= clusterDistances[centroids.n_cols];
+ distances(1, index) += clusterDistances[centroids.n_cols];
prunedCentroids.col(owner) += dataset.col(index);
prunedCounts(owner)++;
}
@@ -584,6 +600,9 @@ prunedPoints[index] << ", lastOwner " << lastOwners[index] << ": invalid "
{
prunedPoints[index] = false;
allPruned = false;
+ // Still update these anyway.
+ distances(0, index) += clusterDistances[owner];
+ distances(1, index) += clusterDistances[centroids.n_cols];
}
}
}
@@ -607,6 +626,13 @@ prunedPoints[index] << ", lastOwner " << lastOwners[index] << ": invalid "
}
}
+ // Make sure all the point bounds are updated.
+ for (size_t i = 0; i < node.NumPoints(); ++i)
+ {
+ distances(0, node.Point(i)) += clusterDistances[centroids.n_cols];
+ distances(1, node.Point(i)) += clusterDistances[centroids.n_cols];
+ }
+
if (node.Stat().FirstBound() != DBL_MAX)
node.Stat().FirstBound() += clusterDistances[centroids.n_cols];
if (node.Stat().SecondBound() != DBL_MAX)
diff --git a/src/mlpack/methods/kmeans/dtnn_rules.hpp b/src/mlpack/methods/kmeans/dtnn_rules.hpp
index 2ded3ea..0183e9e 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules.hpp
@@ -25,7 +25,8 @@ class DTNNKMeansRules : public neighbor::NeighborSearchRules<
arma::mat& distances,
MetricType& metric,
const std::vector<bool>& prunedPoints,
- const std::vector<size_t>& oldFromNewCentroids);
+ const std::vector<size_t>& oldFromNewCentroids,
+ std::vector<bool>& visited);
double BaseCase(const size_t queryIndex, const size_t referenceIndex);
@@ -43,6 +44,8 @@ class DTNNKMeansRules : public neighbor::NeighborSearchRules<
const std::vector<bool>& prunedPoints;
const std::vector<size_t>& oldFromNewCentroids;
+
+ std::vector<bool>& visited;
};
} // namespace kmeans
diff --git a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
index 11c2c3e..7e4e48b 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
@@ -20,11 +20,13 @@ DTNNKMeansRules<MetricType, TreeType>::DTNNKMeansRules(
arma::mat& distances,
MetricType& metric,
const std::vector<bool>& prunedPoints,
- const std::vector<size_t>& oldFromNewCentroids) :
+ const std::vector<size_t>& oldFromNewCentroids,
+ std::vector<bool>& visited) :
neighbor::NeighborSearchRules<neighbor::NearestNeighborSort, MetricType,
TreeType>(centroids, dataset, neighbors, distances, metric),
prunedPoints(prunedPoints),
- oldFromNewCentroids(oldFromNewCentroids)
+ oldFromNewCentroids(oldFromNewCentroids),
+ visited(visited)
{
// Nothing to do.
}
@@ -38,6 +40,9 @@ inline force_inline double DTNNKMeansRules<MetricType, TreeType>::BaseCase(
if (prunedPoints[queryIndex])
return 0.0; // Returning 0 shouldn't be a problem.
+ // Any base cases imply that we will get a result.
+ visited[queryIndex] = true;
+
// This is basically an inlined NeighborSearchRules::BaseCase(), but it
// differs in that it applies the mappings to the results automatically.
// We can also skip a check or two.
More information about the mlpack-git mailing list