[mlpack-git] master: Make DTNNKMeans work again. Next up, tree coalescion. (Is that a word?) (45a731f)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:04:50 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44
>---------------------------------------------------------------
commit 45a731f7b47a033f78a3a06ac58647d53277c478
Author: Ryan Curtin <ryan at ratml.org>
Date: Tue Feb 17 14:14:01 2015 -0500
Make DTNNKMeans work again. Next up, tree coalescion. (Is that a word?)
>---------------------------------------------------------------
45a731f7b47a033f78a3a06ac58647d53277c478
src/mlpack/core/tree/cover_tree/cover_tree.hpp | 3 ++
src/mlpack/methods/kmeans/dtnn_kmeans.hpp | 7 ++-
src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 68 +++++++++++++++++++-------
src/mlpack/methods/kmeans/dtnn_rules.hpp | 6 +++
src/mlpack/methods/kmeans/dtnn_rules_impl.hpp | 4 +-
src/mlpack/methods/kmeans/dtnn_statistic.hpp | 7 +++
6 files changed, 74 insertions(+), 21 deletions(-)
diff --git a/src/mlpack/core/tree/cover_tree/cover_tree.hpp b/src/mlpack/core/tree/cover_tree/cover_tree.hpp
index 41db6fb..8536718 100644
--- a/src/mlpack/core/tree/cover_tree/cover_tree.hpp
+++ b/src/mlpack/core/tree/cover_tree/cover_tree.hpp
@@ -208,6 +208,9 @@ class CoverTree
template<typename RuleType>
class DualTreeTraverser;
+ template<typename RuleType>
+ using BreadthFirstDualTreeTraverser = DualTreeTraverser<RuleType>;
+
//! Get a reference to the dataset.
const arma::mat& Dataset() const { return dataset; }
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
index d27988e..e756669 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
@@ -93,12 +93,15 @@ class DTNNKMeans
arma::mat lastIterationCentroids; // For sanity checks.
//! Update the bounds in the tree before the next iteration.
- void UpdateTree(TreeType& node);
+ void UpdateTree(TreeType& node,
+ arma::vec& clusterDistances,
+ std::vector<size_t>& oldFromNewCentroids);
//! Extract the centroids of the clusters.
void ExtractCentroids(TreeType& node,
arma::mat& newCentroids,
- arma::Col<size_t>& newCounts);
+ arma::Col<size_t>& newCounts,
+ std::vector<size_t>& oldFromNewCentroids);
};
//! A template typedef for the DTNNKMeans algorithm with the default tree type
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index bd7cb80..38e7650 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -29,7 +29,7 @@ TreeType* BuildTree(
{
// This is a hack. I know this will be BinarySpaceTree, so force a leaf size
// of two.
- return new TreeType(dataset, oldFromNew);
+ return new TreeType(dataset, oldFromNew, 2);
}
//! Call the tree constructor that does not do mapping.
@@ -99,14 +99,6 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
std::vector<size_t> oldFromNewCentroids;
TreeType* centroidTree = BuildTree<TreeType>(
const_cast<typename TreeType::Mat&>(centroids), oldFromNewCentroids);
- // Calculate new from old mappings.
- std::vector<size_t> newFromOldCentroids;
- if (tree::TreeTraits<TreeType>::RearrangesDataset)
- {
- newFromOldCentroids.resize(centroids.n_cols);
- for (size_t i = 0; i < centroids.n_cols; ++i)
- newFromOldCentroids[oldFromNewCentroids[i]] = i;
- }
/*
Timer::Start("knn");
@@ -121,7 +113,7 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
*/
// We won't use the AllkNN class here because we have our own set of rules.
- typedef typename DTNNKMeansRules<MetricType, TreeType> RuleType;
+ typedef DTNNKMeansRules<MetricType, TreeType> RuleType;
RuleType rules(centroids, dataset, assignments, upperBounds, lowerBounds,
metric, prunedPoints, oldFromNewCentroids, visited);
@@ -131,14 +123,18 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
// Set the number of pruned centroids in the root to 0.
tree->Stat().Pruned() = 0;
traverser.Traverse(*tree, *centroidTree);
+ distanceCalculations += rules.BaseCases() + rules.Scores();
// Now we need to extract the clusters.
newCentroids.zeros(centroids.n_rows, centroids.n_cols);
counts.zeros(centroids.n_cols);
- ExtractCentroids(*tree, newCentroids, counts);
+ ExtractCentroids(*tree, newCentroids, counts, oldFromNewCentroids);
+ Log::Warn << "New counts: " << counts.t();
+ Log::Warn << accu(counts) << ".\n";
// Now, calculate how far the clusters moved, after normalizing them.
double residual = 0.0;
+ arma::vec clusterDistances(centroids.n_cols + 1);
clusterDistances[centroids.n_cols] = 0.0;
for (size_t c = 0; c < centroids.n_cols; ++c)
{
@@ -164,6 +160,8 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
}
distanceCalculations += centroids.n_cols;
+ UpdateTree(*tree, clusterDistances, oldFromNewCentroids);
+
delete centroidTree;
++iteration;
@@ -173,28 +171,61 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
template<typename MetricType, typename MatType, typename TreeType>
void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
- TreeType& node)
+ TreeType& node,
+ arma::vec& clusterDistances,
+ std::vector<size_t>& oldFromNewCentroids)
{
// Simply reset the bounds.
node.Stat().UpperBound() = DBL_MAX;
node.Stat().LowerBound() = DBL_MAX;
+ if ((node.Stat().Pruned() == clusterDistances.n_elem - 1) &&
+ (node.Stat().Owner() < clusterDistances.n_elem - 1))
+ {
+ const size_t owner = oldFromNewCentroids[node.Stat().Owner()];
+
+ node.Stat().LastUpperBound() = node.Stat().UpperBound() +
+ clusterDistances[owner];
+
+ // Update child bounds, at least a little.
+ for (size_t i = 0; i < node.NumChildren(); ++i)
+ {
+ node.Child(i).Stat().UpperBound() = node.Stat().UpperBound();
+ node.Child(i).Stat().LowerBound() = node.Stat().LowerBound();
+ node.Child(i).Stat().Owner() = node.Stat().Owner();
+ node.Child(i).Stat().Pruned() = node.Stat().Pruned();
+ }
+ }
+ else if ((node.Stat().Pruned() == clusterDistances.n_elem - 1) &&
+ (node.Stat().Owner() >= clusterDistances.n_elem - 1))
+ {
+ Log::Warn << clusterDistances.n_cols - 1 << ".\n";
+ Log::Warn << node;
+ Log::Fatal << "Node is pruned, but has no owner!\n";
+ }
+ else
+ {
+ node.Stat().LastUpperBound() = node.Stat().UpperBound() +
+ clusterDistances[clusterDistances.n_elem - 1];
+ }
node.Stat().Pruned() = size_t(-1);
node.Stat().Owner() = size_t(-1);
+ node.Stat().LowerBound() = DBL_MAX;
for (size_t i = 0; i < node.NumChildren(); ++i)
- UpdateTree(node.Child(i));
+ UpdateTree(node.Child(i), clusterDistances, oldFromNewCentroids);
}
template<typename MetricType, typename MatType, typename TreeType>
void DTNNKMeans<MetricType, MatType, TreeType>::ExtractCentroids(
TreeType& node,
arma::mat& newCentroids,
- arma::Col<size_t>& newCounts)
+ arma::Col<size_t>& newCounts,
+ std::vector<size_t>& oldFromNewCentroids)
{
// Does this node own points?
- if (node.Stat().Owner() < newCentroids.n_cols)
+ if (node.Stat().Pruned() == newCentroids.n_cols)
{
- const size_t owner = node.Stat().Owner();
+ const size_t owner = oldFromNewCentroids[node.Stat().Owner()];
newCentroids.col(owner) += node.Stat().Centroid() * node.NumDescendants();
newCounts[owner] += node.NumDescendants();
}
@@ -203,14 +234,15 @@ void DTNNKMeans<MetricType, MatType, TreeType>::ExtractCentroids(
// Check each point held in the node.
for (size_t i = 0; i < node.NumPoints(); ++i)
{
- const size_t owner = assignments[node.Point(i)];
+ const size_t owner = oldFromNewCentroids[assignments[node.Point(i)]];
newCentroids.col(owner) += dataset.col(node.Point(i));
++newCounts[owner];
}
// The node is not entirely owned by a cluster. Recurse.
for (size_t i = 0; i < node.NumChildren(); ++i)
- ExtractCentroids(node.Child(i), newCentroids, newCounts);
+ ExtractCentroids(node.Child(i), newCentroids, newCounts,
+ oldFromNewCentroids);
}
}
diff --git a/src/mlpack/methods/kmeans/dtnn_rules.hpp b/src/mlpack/methods/kmeans/dtnn_rules.hpp
index 1c3cda4..5252050 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules.hpp
@@ -44,6 +44,12 @@ class DTNNKMeansRules
TraversalInfoType& TraversalInfo() { return traversalInfo; }
const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
+ size_t BaseCases() const { return baseCases; }
+ size_t& BaseCases() { return baseCases; }
+
+ size_t Scores() const { return scores; }
+ size_t& Scores() { return scores; }
+
private:
const arma::mat& centroids;
const arma::mat& dataset;
diff --git a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
index 7e1904f..6a57273 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
@@ -91,6 +91,7 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
{
queryNode.Stat().Pruned() = queryNode.Parent()->Stat().Pruned();
queryNode.Stat().LowerBound() = queryNode.Parent()->Stat().LowerBound();
+ queryNode.Stat().Owner() = queryNode.Parent()->Stat().Owner();
}
if (queryNode.Stat().Pruned() == centroids.n_cols)
@@ -102,7 +103,8 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
math::Range distances = queryNode.RangeDistance(&referenceNode);
double score = distances.Lo();
++scores;
- if (distances.Lo() > queryNode.Stat().UpperBound())
+ if (distances.Lo() > queryNode.Stat().UpperBound() ||
+ distances.Lo() > queryNode.Stat().LastUpperBound())
{
// The reference node can own no points in this query node. We may improve
// the lower bound on pruned nodes, though.
diff --git a/src/mlpack/methods/kmeans/dtnn_statistic.hpp b/src/mlpack/methods/kmeans/dtnn_statistic.hpp
index b81a17a..ee24b23 100644
--- a/src/mlpack/methods/kmeans/dtnn_statistic.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_statistic.hpp
@@ -20,6 +20,7 @@ class DTNNStatistic : public
neighbor::NeighborSearchStat<neighbor::NearestNeighborSort>(),
upperBound(DBL_MAX),
lowerBound(DBL_MAX),
+ lastUpperBound(DBL_MAX),
owner(size_t(-1)),
pruned(size_t(-1)),
centroid()
@@ -32,6 +33,7 @@ class DTNNStatistic : public
neighbor::NeighborSearchStat<neighbor::NearestNeighborSort>(),
upperBound(DBL_MAX),
lowerBound(DBL_MAX),
+ lastUpperBound(DBL_MAX),
owner(size_t(-1)),
pruned(size_t(-1))
{
@@ -53,6 +55,9 @@ class DTNNStatistic : public
double LowerBound() const { return lowerBound; }
double& LowerBound() { return lowerBound; }
+ double LastUpperBound() const { return lastUpperBound; }
+ double& LastUpperBound() { return lastUpperBound; }
+
const arma::vec& Centroid() const { return centroid; }
arma::vec& Centroid() { return centroid; }
@@ -68,6 +73,7 @@ class DTNNStatistic : public
o << "DTNNStatistic [" << this << "]:\n";
o << " Upper bound: " << upperBound << ".\n";
o << " Lower bound: " << lowerBound << ".\n";
+ o << " Last upper bound: " << lastUpperBound << ".\n";
o << " Pruned: " << pruned << ".\n";
o << " Owner: " << owner << ".\n";
return o.str();
@@ -76,6 +82,7 @@ class DTNNStatistic : public
private:
double upperBound;
double lowerBound;
+ double lastUpperBound;
size_t owner;
size_t pruned;
arma::vec centroid;
More information about the mlpack-git mailing list