[mlpack-git] master: Coalesce and decoalesce the tree. (928fd6d)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:04:33 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44
>---------------------------------------------------------------
commit 928fd6dde0a6b8187cc8455e3cd4a8ac0cb16c55
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Feb 18 09:57:28 2015 -0500
Coalesce and decoalesce the tree.
>---------------------------------------------------------------
928fd6dde0a6b8187cc8455e3cd4a8ac0cb16c55
src/mlpack/methods/kmeans/dtnn_kmeans.hpp | 3 ++
src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 69 ++++++++++++++++++++++++++
src/mlpack/methods/kmeans/dtnn_statistic.hpp | 22 +++++++-
3 files changed, 92 insertions(+), 2 deletions(-)
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
index 26c90df..bf83533 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
@@ -104,6 +104,9 @@ class DTNNKMeans
arma::Col<size_t>& newCounts,
std::vector<size_t>& oldFromNewCentroids,
arma::mat& centroids);
+
+ void CoalesceTree(TreeType& node, const size_t child = 0);
+ void DecoalesceTree(TreeType& node);
};
//! A template typedef for the DTNNKMeans algorithm with the default tree type
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 985c359..0ab1bd4 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -121,11 +121,19 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
typename TreeType::template BreadthFirstDualTreeTraverser<RuleType>
traverser(rules);
+ Timer::Start("tree_mod");
+ CoalesceTree(*tree);
+ Timer::Stop("tree_mod");
+
// Set the number of pruned centroids in the root to 0.
tree->Stat().Pruned() = 0;
traverser.Traverse(*tree, *centroidTree);
distanceCalculations += rules.BaseCases() + rules.Scores();
+ Timer::Start("tree_mod");
+ DecoalesceTree(*tree);
+ Timer::Stop("tree_mod");
+
// Now we need to extract the clusters.
newCentroids.zeros(centroids.n_rows, centroids.n_cols);
counts.zeros(centroids.n_cols);
@@ -383,6 +391,67 @@ assignments[node.Point(0)] << " with ub " << upperBounds[node.Point(0)] <<
}
}
+template<typename MetricType, typename MatType, typename TreeType>
+void DTNNKMeans<MetricType, MatType, TreeType>::CoalesceTree(
+ TreeType& node,
+ const size_t child /* Index of this node in its parent's child list (used as ChildPtr(child) on the parent). */)
+{
+ // Hide (splice out) this node when exactly one of its two children is marked StaticPruned(): the unpruned child is linked directly to this node's parent.
+ // NOTE: this assumes a binary tree such as BinarySpaceTree; only Child(0) and Child(1) are ever examined.
+ if (node.NumChildren() == 0)
+ return; // Leaf node: there are no children to relink.
+
+ // Only a non-root node (one that has a parent) can be hidden; the root has no parent to relink a child to.
+ if (node.Parent() != NULL)
+ {
+ if (node.Child(0).Stat().StaticPruned() &&
+ !node.Child(1).Stat().StaticPruned())
+ {
+ CoalesceTree(node.Child(1), 1); // Coalesce the surviving subtree first; this may replace node.ChildPtr(1).
+
+ // Splice this node out: the right child (index 1) now hangs off this node's parent in slot `child`.
+ node.Child(1).Parent() = node.Parent();
+ node.Parent()->ChildPtr(child) = node.ChildPtr(1);
+ }
+ else if (!node.Child(0).Stat().StaticPruned() &&
+ node.Child(1).Stat().StaticPruned())
+ {
+ CoalesceTree(node.Child(0), 0); // Coalesce the surviving subtree first; this may replace node.ChildPtr(0).
+
+ // Splice this node out: the left child (index 0) now hangs off this node's parent in slot `child`.
+ node.Child(0).Parent() = node.Parent();
+ node.Parent()->ChildPtr(child) = node.ChildPtr(0);
+
+ }
+ else if (!node.Child(0).Stat().StaticPruned() &&
+ !node.Child(1).Stat().StaticPruned())
+ {
+ // Neither child is pruned: keep this node and recurse into both children. (Author's note: the conditional is probably not necessary.) When both children are pruned, no branch runs and the node is left untouched.
+ CoalesceTree(node.Child(0), 0);
+ CoalesceTree(node.Child(1), 1);
+ }
+ }
+ else // Root node: it is never hidden, so just recurse into both children regardless of their pruned state.
+ {
+ CoalesceTree(node.Child(0), 0);
+ CoalesceTree(node.Child(1), 1);
+ }
+}
+
+template<typename MetricType, typename MatType, typename TreeType>
+void DTNNKMeans<MetricType, MatType, TreeType>::DecoalesceTree(TreeType& node)
+{
+ node.Parent() = (TreeType*) node.Stat().TrueParent(); // Restore the parent/child links recorded in the node's statistic (stored there as void*, hence the casts).
+ node.ChildPtr(0) = (TreeType*) node.Stat().TrueLeft();
+ node.ChildPtr(1) = (TreeType*) node.Stat().TrueRight();
+
+ if (node.NumChildren() > 0) // Checked after the links above are restored; recurse only when the node has children.
+ {
+ DecoalesceTree(node.Child(0));
+ DecoalesceTree(node.Child(1));
+ }
+}
+
} // namespace kmeans
} // namespace mlpack
diff --git a/src/mlpack/methods/kmeans/dtnn_statistic.hpp b/src/mlpack/methods/kmeans/dtnn_statistic.hpp
index 2601378..51f4f60 100644
--- a/src/mlpack/methods/kmeans/dtnn_statistic.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_statistic.hpp
@@ -25,7 +25,10 @@ class DTNNStatistic : public
staticPruned(false),
staticUpperBoundMovement(0.0),
staticLowerBoundMovement(0.0),
- centroid()
+ centroid(),
+ trueParent(NULL),
+ trueLeft(NULL),
+ trueRight(NULL)
{
// Nothing to do.
}
@@ -39,7 +42,10 @@ class DTNNStatistic : public
pruned(size_t(-1)),
staticPruned(false),
staticUpperBoundMovement(0.0),
- staticLowerBoundMovement(0.0)
+ staticLowerBoundMovement(0.0),
+ trueParent(node.Parent()),
+ trueLeft((node.NumChildren() == 0) ? NULL : &node.Child(0)),
+ trueRight((node.NumChildren() == 0) ? NULL : &node.Child(1))
{
// Empirically calculate the centroid.
centroid.zeros(node.Dataset().n_rows);
@@ -77,6 +83,15 @@ class DTNNStatistic : public
double StaticLowerBoundMovement() const { return staticLowerBoundMovement; }
double& StaticLowerBoundMovement() { return staticLowerBoundMovement; }
+ void* TrueParent() const { return trueParent; }
+ void*& TrueParent() { return trueParent; }
+
+ void* TrueLeft() const { return trueLeft; }
+ void*& TrueLeft() { return trueLeft; }
+
+ void* TrueRight() const { return trueRight; }
+ void*& TrueRight() { return trueRight; }
+
std::string ToString() const
{
std::ostringstream o;
@@ -98,6 +113,9 @@ class DTNNStatistic : public
double staticUpperBoundMovement;
double staticLowerBoundMovement;
arma::vec centroid;
+ void* trueParent;
+ void* trueLeft;
+ void* trueRight;
};
} // namespace kmeans
More information about the mlpack-git
mailing list