[mlpack-git] master: Start applying prune on intercluster distances. Not quite done yet. (95cb008)
gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:05:14 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44
>---------------------------------------------------------------
commit 95cb008c6ddde0e2b5a7b3fcc7027a509f766fd8
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Feb 18 18:34:34 2015 -0500
Start applying prune on intercluster distances. Not quite done yet.
>---------------------------------------------------------------
95cb008c6ddde0e2b5a7b3fcc7027a509f766fd8
src/mlpack/methods/kmeans/dtnn_kmeans.hpp | 3 +-
src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 47 ++++++++++++++------------
2 files changed, 28 insertions(+), 22 deletions(-)
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
index 6c23bd2..ec141db 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
@@ -97,7 +97,8 @@ class DTNNKMeans
//! Update the bounds in the tree before the next iteration.
//! centroids is the current (not yet searched) centroids.
void UpdateTree(TreeType& node,
- const arma::mat& centroids);
+ const arma::mat& centroids,
+ const arma::mat& interclusterDistances);
//! Extract the centroids of the clusters.
void ExtractCentroids(TreeType& node,
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 6bc3228..ef6cc2f 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -94,26 +94,12 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
arma::mat& newCentroids,
arma::Col<size_t>& counts)
{
- // Reset information, if we need to.
- if (iteration > 0)
- {
- UpdateTree(*tree, centroids);
-
- for (size_t i = 0; i < dataset.n_cols; ++i)
- visited[i] = false;
- }
- else
- {
- // Not initialized yet.
- clusterDistances.set_size(centroids.n_cols + 1);
- }
-
// Build a tree on the centroids.
arma::mat oldCentroids(centroids); // Slow. :(
std::vector<size_t> oldFromNewCentroids;
TreeType* centroidTree = BuildTree<TreeType>(
const_cast<typename TreeType::Mat&>(centroids), oldFromNewCentroids);
-/*
+
Timer::Start("knn");
// Find the nearest neighbors of each of the clusters.
neighbor::NeighborSearch<neighbor::NearestNeighborSort, MetricType, TreeType>
@@ -121,9 +107,23 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
arma::mat interclusterDistances;
arma::Mat<size_t> closestClusters; // We don't actually care about these.
nns.Search(1, closestClusters, interclusterDistances);
- distanceCalculations += nns.BaseCases() + nns.Scores();
+// distanceCalculations += nns.BaseCases() + nns.Scores();
Timer::Stop("knn");
-*/
+
+ // Reset information in the tree, if we need to.
+ if (iteration > 0)
+ {
+ UpdateTree(*tree, oldCentroids, interclusterDistances);
+
+ for (size_t i = 0; i < dataset.n_cols; ++i)
+ visited[i] = false;
+ }
+ else
+ {
+ // Not initialized yet.
+ clusterDistances.set_size(centroids.n_cols + 1);
+ }
+
// We won't use the AllkNN class here because we have our own set of rules.
typedef DTNNKMeansRules<MetricType, TreeType> RuleType;
RuleType rules(centroids, dataset, assignments, upperBounds, lowerBounds,
@@ -187,7 +187,8 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
template<typename MetricType, typename MatType, typename TreeType>
void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
TreeType& node,
- const arma::mat& centroids)
+ const arma::mat& centroids,
+ const arma::mat& interclusterDistances)
{
const bool prunedLastIteration = node.Stat().StaticPruned();
node.Stat().StaticPruned() = false;
@@ -208,7 +209,9 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
// Adjust bounds.
node.Stat().UpperBound() += clusterDistances[node.Stat().Owner()];
node.Stat().LowerBound() -= clusterDistances[centroids.n_cols];
- if (node.Stat().UpperBound() < node.Stat().LowerBound())
+ const double lowerBound = std::max(node.Stat().LowerBound(),
+ interclusterDistances[node.Stat().Owner()] / 2.0);
+ if (node.Stat().UpperBound() < lowerBound)
{
node.Stat().StaticPruned() = true;
}
@@ -218,7 +221,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
node.Stat().UpperBound() =
node.MaxDistance(centroids.col(node.Stat().Owner()));
++distanceCalculations;
- if (node.Stat().UpperBound() < node.Stat().LowerBound())
+ if (node.Stat().UpperBound() < lowerBound)
{
node.Stat().StaticPruned() = true;
}
@@ -251,6 +254,8 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
const size_t owner = assignments[node.Point(i)];
const double lowerBound = std::min(lowerBounds[index] -
clusterDistances[centroids.n_cols], node.Stat().LowerBound());
+// const double pruningLowerBound = std::max(lowerBound,
+// interclusterDistances[owner] / 2.0);
if (upperBounds[index] + clusterDistances[owner] < lowerBound)
{
prunedPoints[index] = true;
@@ -295,7 +300,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
}
for (size_t i = 0; i < node.NumChildren(); ++i)
- UpdateTree(node.Child(i), centroids);
+ UpdateTree(node.Child(i), centroids, interclusterDistances);
if (!node.Stat().StaticPruned())
{
More information about the mlpack-git mailing list