[mlpack-git] master: Refactor: do tree update before kNN search. This will allow us to make tighter prunes. (d631986)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:03:52 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44
>---------------------------------------------------------------
commit d631986bf08dd5fc00f5b6a928e04a432c1e4b0a
Author: Ryan Curtin <ryan at ratml.org>
Date: Mon Feb 2 20:57:51 2015 -0500
Refactor: do tree update before kNN search. This will allow us to make tighter prunes.
>---------------------------------------------------------------
d631986bf08dd5fc00f5b6a928e04a432c1e4b0a
src/mlpack/methods/kmeans/dtnn_kmeans.hpp | 13 ++++---
src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 54 +++++++++++++-------------
2 files changed, 34 insertions(+), 33 deletions(-)
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
index abc8236..b8c2792 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
@@ -84,6 +84,9 @@ class DTNNKMeans
//! Counts from pruning. Not normalized.
arma::Col<size_t> prunedCounts;
+ //! Distances that the clusters moved last iteration.
+ arma::vec clusterDistances;
+
//! Upper bounds on cluster distances for each point.
arma::vec upperBounds;
//! Lower bounds on second closest cluster distance for each point.
@@ -93,14 +96,14 @@ class DTNNKMeans
//! The last cluster each point was assigned to.
arma::Col<size_t> lastOwners;
+ arma::mat distances;
+ arma::Mat<size_t> assignments;
+
+ std::vector<size_t> lastOldFromNewCentroids;
+
//! Update the bounds in the tree before the next iteration.
void UpdateTree(TreeType& node,
- const double tolerance,
const arma::mat& centroids,
- const arma::Mat<size_t>& assignments,
- const arma::mat& distances,
- const arma::mat& clusterDistances,
- const std::vector<size_t>& oldFromNewCentroids,
const arma::mat& interclusterDistances,
const std::vector<size_t>& newFromOldCentroids);
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 305fbc1..407b16c 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -52,7 +52,9 @@ DTNNKMeans<MetricType, MatType, TreeType>::DTNNKMeans(const MatType& dataset,
datasetOrig),
metric(metric),
distanceCalculations(0),
- iteration(0)
+ iteration(0),
+ distances(2, dataset.n_cols),
+ assignments(2, dataset.n_cols)
{
prunedPoints.resize(dataset.n_cols, false); // Fill with false.
upperBounds.set_size(dataset.n_cols);
@@ -90,6 +92,8 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
{
prunedCentroids.zeros(centroids.n_rows, centroids.n_cols);
prunedCounts.zeros(centroids.n_cols);
+ // The last element stores the maximum.
+ clusterDistances.zeros(centroids.n_cols + 1);
}
newCentroids.zeros(centroids.n_rows, centroids.n_cols);
@@ -117,10 +121,21 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
nns.Search(1, closestClusters, interclusterDistances);
distanceCalculations += nns.BaseCases() + nns.Scores();
+ if (iteration != 0)
+ {
+ // Do the tree update for the previous iteration.
+ Log::Warn << "Performing tree update.\n";
+
+ // Reset centroids and counts for things we will collect during pruning.
+ prunedCentroids.zeros(centroids.n_rows, centroids.n_cols);
+ prunedCounts.zeros(centroids.n_cols);
+ UpdateTree(*tree, oldCentroids, interclusterDistances, newFromOldCentroids);
+
+ PrecalculateCentroids(*tree);
+ }
+
// We won't use the AllkNN class here because we have our own set of rules.
// This is a lot of overhead. We don't need the distances.
- arma::mat distances(2, dataset.n_cols);
- arma::Mat<size_t> assignments(2, dataset.n_cols);
distances.fill(DBL_MAX);
assignments.fill(size_t(-1));
typedef DTNNKMeansRules<MetricType, TreeType> RuleType;
@@ -161,7 +176,6 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
// Now, calculate how far the clusters moved, after normalizing them.
double residual = 0.0;
double maxMovement = 0.0;
- arma::vec clusterDistances(centroids.n_cols + 1);
for (size_t c = 0; c < centroids.n_cols; ++c)
{
// Get the mapping to the old cluster, if necessary.
@@ -187,14 +201,7 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
clusterDistances[centroids.n_cols] = maxMovement;
distanceCalculations += centroids.n_cols;
- // Reset centroids and counts for things we will collect during pruning.
- prunedCentroids.zeros(centroids.n_rows, centroids.n_cols);
- prunedCounts.zeros(centroids.n_cols);
- UpdateTree(*tree, maxMovement, oldCentroids, assignments, distances,
- clusterDistances, oldFromNewCentroids, interclusterDistances,
- newFromOldCentroids);
-
- PrecalculateCentroids(*tree);
+ lastOldFromNewCentroids = oldFromNewCentroids;
delete centroidTree;
@@ -206,12 +213,7 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
template<typename MetricType, typename MatType, typename TreeType>
void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
TreeType& node,
- const double tolerance,
const arma::mat& centroids,
- const arma::Mat<size_t>& assignments,
- const arma::mat& distances,
- const arma::mat& clusterDistances,
- const std::vector<size_t>& oldFromNewCentroids,
const arma::mat& interclusterDistances,
const std::vector<size_t>& newFromOldCentroids)
{
@@ -225,8 +227,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
bool childrenPruned = true;
for (size_t i = 0; i < node.NumChildren(); ++i)
{
- UpdateTree(node.Child(i), tolerance, centroids, assignments, distances,
- clusterDistances, oldFromNewCentroids, interclusterDistances,
+ UpdateTree(node.Child(i), centroids, interclusterDistances,
newFromOldCentroids);
if (!node.Child(i).Stat().Pruned())
childrenPruned = false; // Not all children are pruned.
@@ -249,7 +250,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
size_t c;
if (!prunedPoints[node.Point(i)])
c = (tree::TreeTraits<TreeType>::RearrangesDataset) ?
- oldFromNewCentroids[assignments(0, node.Point(i))] :
+ lastOldFromNewCentroids[assignments(0, node.Point(i))] :
assignments(0, node.Point(i));
else
c = lastOwners[node.Point(i)];
@@ -293,10 +294,6 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
newSecondClusterBound = node.Child(i).Stat().SecondClusterBound();
}
-// if (node.NumChildren() > 0)
-// Log::Warn << "Node:\n" << node << "single owner: " << singleOwner <<
-//".l\n" << node.Child(0) << ".r\n" << node.Child(1) << ".\n";
-
// What do we do with the new cluster bounds?
if (newMaxClusterDistance > 0.0 && newMaxClusterDistance <
node.Stat().MaxClusterDistance())
@@ -483,7 +480,8 @@ node.Child(0) << ", r\n" << node.Child(1) << ".\n";
if (!prunedLastIteration && !prunedPoints[index])
{
owner = (tree::TreeTraits<TreeType>::RearrangesDataset) ?
- oldFromNewCentroids[assignments(0, index)] : assignments(0, index);
+ lastOldFromNewCentroids[assignments(0, index)] :
+ assignments(0, index);
// Establish bounds, since these points were searched this iteration.
upperBounds[index] = distances(0, index);
lowerSecondBounds[index] = distances(1, index);
@@ -564,11 +562,11 @@ prunedPoints[index] << ", lastOwner " << lastOwners[index] << ": invalid "
}
if (node.Stat().FirstBound() != DBL_MAX)
- node.Stat().FirstBound() += tolerance;
+ node.Stat().FirstBound() += clusterDistances[centroids.n_cols];
if (node.Stat().SecondBound() != DBL_MAX)
- node.Stat().SecondBound() += tolerance;
+ node.Stat().SecondBound() += clusterDistances[centroids.n_cols];
if (node.Stat().Bound() != DBL_MAX)
- node.Stat().Bound() += tolerance;
+ node.Stat().Bound() += clusterDistances[centroids.n_cols];
}
template<typename MetricType, typename MatType, typename TreeType>
More information about the mlpack-git mailing list