[mlpack-git] master: Perform tree update at start of iteration. Cache some variables inside DTNNKMeans. (f440ca0)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:04:39 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44
>---------------------------------------------------------------
commit f440ca00db79bac726e2e3e7f825259cf0a575f9
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Feb 18 18:12:11 2015 -0500
Perform tree update at start of iteration. Cache some variables inside DTNNKMeans.
>---------------------------------------------------------------
f440ca00db79bac726e2e3e7f825259cf0a575f9
src/mlpack/methods/kmeans/dtnn_kmeans.hpp | 8 ++--
src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 58 +++++++++++++-------------
2 files changed, 34 insertions(+), 32 deletions(-)
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
index bf83533..6c23bd2 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
@@ -92,17 +92,17 @@ class DTNNKMeans
arma::mat lastIterationCentroids; // For sanity checks.
+ arma::vec clusterDistances; // The amount the clusters moved last iteration.
+
//! Update the bounds in the tree before the next iteration.
+ //! centroids is the current (not yet searched) centroids.
void UpdateTree(TreeType& node,
- arma::vec& clusterDistances,
- std::vector<size_t>& oldFromNewCentroids,
- arma::mat& newCentroids);
+ const arma::mat& centroids);
//! Extract the centroids of the clusters.
void ExtractCentroids(TreeType& node,
arma::mat& newCentroids,
arma::Col<size_t>& newCounts,
- std::vector<size_t>& oldFromNewCentroids,
arma::mat& centroids);
void CoalesceTree(TreeType& node, const size_t child = 0);
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 0ab1bd4..6bc3228 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -71,7 +71,10 @@ DTNNKMeans<MetricType, MatType, TreeType>::DTNNKMeans(const MatType& dataset,
Timer::Stop("tree_building");
for (size_t i = 0; i < dataset.n_cols; ++i)
+ {
prunedPoints[i] = false;
+ visited[i] = false;
+ }
assignments.fill(size_t(-1));
upperBounds.fill(DBL_MAX);
lowerBounds.fill(DBL_MAX);
@@ -91,16 +94,25 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
arma::mat& newCentroids,
arma::Col<size_t>& counts)
{
- // Reset information.
- for (size_t i = 0; i < dataset.n_cols; ++i)
- visited[i] = false;
+ // Reset information, if we need to.
+ if (iteration > 0)
+ {
+ UpdateTree(*tree, centroids);
+
+ for (size_t i = 0; i < dataset.n_cols; ++i)
+ visited[i] = false;
+ }
+ else
+ {
+ // Not initialized yet.
+ clusterDistances.set_size(centroids.n_cols + 1);
+ }
// Build a tree on the centroids.
arma::mat oldCentroids(centroids); // Slow. :(
std::vector<size_t> oldFromNewCentroids;
TreeType* centroidTree = BuildTree<TreeType>(
const_cast<typename TreeType::Mat&>(centroids), oldFromNewCentroids);
-
/*
Timer::Start("knn");
// Find the nearest neighbors of each of the clusters.
@@ -112,7 +124,6 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
distanceCalculations += nns.BaseCases() + nns.Scores();
Timer::Stop("knn");
*/
-
// We won't use the AllkNN class here because we have our own set of rules.
typedef DTNNKMeansRules<MetricType, TreeType> RuleType;
RuleType rules(centroids, dataset, assignments, upperBounds, lowerBounds,
@@ -137,12 +148,10 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
// Now we need to extract the clusters.
newCentroids.zeros(centroids.n_rows, centroids.n_cols);
counts.zeros(centroids.n_cols);
- ExtractCentroids(*tree, newCentroids, counts, oldFromNewCentroids,
- oldCentroids);
+ ExtractCentroids(*tree, newCentroids, counts, oldCentroids);
// Now, calculate how far the clusters moved, after normalizing them.
double residual = 0.0;
- arma::vec clusterDistances(centroids.n_cols + 1);
clusterDistances[centroids.n_cols] = 0.0;
for (size_t c = 0; c < centroids.n_cols; ++c)
{
@@ -168,8 +177,6 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
}
distanceCalculations += centroids.n_cols;
- UpdateTree(*tree, clusterDistances, oldFromNewCentroids, newCentroids);
-
delete centroidTree;
++iteration;
@@ -180,16 +187,14 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
template<typename MetricType, typename MatType, typename TreeType>
void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
TreeType& node,
- arma::vec& clusterDistances,
- std::vector<size_t>& oldFromNewCentroids,
- arma::mat& newCentroids)
+ const arma::mat& centroids)
{
const bool prunedLastIteration = node.Stat().StaticPruned();
node.Stat().StaticPruned() = false;
// Grab information from the parent, if we can.
if (node.Parent() != NULL &&
- node.Parent()->Stat().Pruned() == clusterDistances.n_elem - 1)
+ node.Parent()->Stat().Pruned() == centroids.n_cols)
{
node.Stat().UpperBound() = node.Parent()->Stat().UpperBound();
node.Stat().LowerBound() = node.Parent()->Stat().LowerBound();
@@ -197,12 +202,12 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
node.Stat().Owner() = node.Parent()->Stat().Owner();
}
- if ((node.Stat().Pruned() == clusterDistances.n_elem - 1) &&
- (node.Stat().Owner() < clusterDistances.n_elem - 1))
+ if ((node.Stat().Pruned() == centroids.n_cols) &&
+ (node.Stat().Owner() < centroids.n_cols))
{
// Adjust bounds.
node.Stat().UpperBound() += clusterDistances[node.Stat().Owner()];
- node.Stat().LowerBound() -= clusterDistances[clusterDistances.n_elem - 1];
+ node.Stat().LowerBound() -= clusterDistances[centroids.n_cols];
if (node.Stat().UpperBound() < node.Stat().LowerBound())
{
node.Stat().StaticPruned() = true;
@@ -211,7 +216,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
{
// Tighten bound.
node.Stat().UpperBound() =
- node.MaxDistance(newCentroids.col(node.Stat().Owner()));
+ node.MaxDistance(centroids.col(node.Stat().Owner()));
++distanceCalculations;
if (node.Stat().UpperBound() < node.Stat().LowerBound())
{
@@ -221,7 +226,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
}
else
{
- node.Stat().LowerBound() -= clusterDistances[clusterDistances.n_elem - 1];
+ node.Stat().LowerBound() -= clusterDistances[centroids.n_cols];
}
if (!node.Stat().StaticPruned())
@@ -245,7 +250,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
prunedPoints[index] = false;
const size_t owner = assignments[node.Point(i)];
const double lowerBound = std::min(lowerBounds[index] -
- clusterDistances[newCentroids.n_cols], node.Stat().LowerBound());
+ clusterDistances[centroids.n_cols], node.Stat().LowerBound());
if (upperBounds[index] + clusterDistances[owner] < lowerBound)
{
prunedPoints[index] = true;
@@ -256,7 +261,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
{
// Attempt to tighten the bound.
upperBounds[index] = metric.Evaluate(dataset.col(index),
- newCentroids.col(owner));
+ centroids.col(owner));
++distanceCalculations;
if (upperBounds[index] < lowerBound)
{
@@ -280,7 +285,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
node.Stat().StaticUpperBoundMovement() +=
clusterDistances[node.Stat().Owner()];
node.Stat().StaticLowerBoundMovement() +=
- clusterDistances[newCentroids.n_cols];
+ clusterDistances[centroids.n_cols];
}
else
{
@@ -290,15 +295,14 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
}
for (size_t i = 0; i < node.NumChildren(); ++i)
- UpdateTree(node.Child(i), clusterDistances, oldFromNewCentroids,
- newCentroids);
+ UpdateTree(node.Child(i), centroids);
if (!node.Stat().StaticPruned())
{
node.Stat().UpperBound() = DBL_MAX;
node.Stat().LowerBound() = DBL_MAX;
node.Stat().Pruned() = size_t(-1);
- node.Stat().Owner() = clusterDistances.n_elem - 1;
+ node.Stat().Owner() = centroids.n_cols;
node.Stat().StaticPruned() = false;
}
}
@@ -308,7 +312,6 @@ void DTNNKMeans<MetricType, MatType, TreeType>::ExtractCentroids(
TreeType& node,
arma::mat& newCentroids,
arma::Col<size_t>& newCounts,
- std::vector<size_t>& oldFromNewCentroids,
arma::mat& centroids)
{
// Does this node own points?
@@ -386,8 +389,7 @@ assignments[node.Point(0)] << " with ub " << upperBounds[node.Point(0)] <<
// The node is not entirely owned by a cluster. Recurse.
for (size_t i = 0; i < node.NumChildren(); ++i)
- ExtractCentroids(node.Child(i), newCentroids, newCounts,
- oldFromNewCentroids, centroids);
+ ExtractCentroids(node.Child(i), newCentroids, newCounts, centroids);
}
}
More information about the mlpack-git mailing list