[mlpack-git] master: Adapt to trees that aren't just binary. (55a7ba8)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed May 20 23:06:04 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/77d750c8fd46140b1d6060424f68768a21c89377...7e9cd46afb53817ae93ccbd02637d7726137ce4d
>---------------------------------------------------------------
commit 55a7ba8909ed26d5f9ce4427aae129d58f88cd27
Author: Ryan Curtin <ryan at ratml.org>
Date: Fri May 8 13:13:34 2015 -0400
Adapt to trees that aren't just binary.
>---------------------------------------------------------------
55a7ba8909ed26d5f9ce4427aae129d58f88cd27
.../methods/kmeans/dual_tree_kmeans_impl.hpp | 51 +++++++++++-----------
.../methods/kmeans/dual_tree_kmeans_statistic.hpp | 23 +++++-----
2 files changed, 36 insertions(+), 38 deletions(-)
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
index b9c9c5b..fa66007 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
@@ -509,37 +509,38 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::CoalesceTree(
// If this is the root node, we can't coalesce.
if (node.Parent() != NULL)
{
- if (node.Child(0).Stat().StaticPruned() &&
- !node.Child(1).Stat().StaticPruned())
+ // First, we should coalesce those nodes that aren't statically pruned.
+ size_t numStaticallyPruned = 0;
+ size_t notPrunedIndex = 0;
+ for (size_t i = 0; i < node.NumChildren(); ++i)
{
- CoalesceTree(node.Child(1), 1);
-
- // Link the right child to the parent.
- node.Child(1).Parent() = node.Parent();
- node.Parent()->ChildPtr(child) = node.ChildPtr(1);
+ if (node.Child(i).Stat().StaticPruned())
+ {
+ ++numStaticallyPruned;
+ }
+ else
+ {
+ CoalesceTree(node.Child(i), i);
+ notPrunedIndex = i;
+ }
}
- else if (!node.Child(0).Stat().StaticPruned() &&
- node.Child(1).Stat().StaticPruned())
- {
- CoalesceTree(node.Child(0), 0);
-
- // Link the left child to the parent.
- node.Child(0).Parent() = node.Parent();
- node.Parent()->ChildPtr(child) = node.ChildPtr(0);
- }
- else if (!node.Child(0).Stat().StaticPruned() &&
- !node.Child(1).Stat().StaticPruned())
+ // If we've pruned all but one child, then notPrunedIndex will contain the
+ // index of that child, and we can coalesce this node entirely. Note that
+ // the case where all children are statically pruned should not happen,
+ // because then this node should itself be statically pruned.
+ if (numStaticallyPruned == node.NumChildren() - 1)
{
- // The conditional is probably not necessary.
- CoalesceTree(node.Child(0), 0);
- CoalesceTree(node.Child(1), 1);
+ node.Child(notPrunedIndex).Parent() = node.Parent();
+ node.Parent()->ChildPtr(child) = node.ChildPtr(notPrunedIndex);
}
}
else
{
- CoalesceTree(node.Child(0), 0);
- CoalesceTree(node.Child(1), 1);
+ // We can't coalesce the root, so call the children individually and
+ // coalesce them.
+ for (size_t i = 0; i < node.NumChildren(); ++i)
+ CoalesceTree(node.Child(i), i);
}
}
@@ -548,8 +549,8 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::DecoalesceTree(
TreeType& node)
{
node.Parent() = (TreeType*) node.Stat().TrueParent();
- node.ChildPtr(0) = (TreeType*) node.Stat().TrueLeft();
- node.ChildPtr(1) = (TreeType*) node.Stat().TrueRight();
+ for (size_t i = 0; i < node.NumChildren(); ++i)
+ node.ChildPtr(i) = (TreeType*) node.Stat().TrueChild(i);
if (node.NumChildren() > 0)
{
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
index 8aaec4b..09a14a0 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
@@ -26,9 +26,7 @@ class DualTreeKMeansStatistic : public
staticUpperBoundMovement(0.0),
staticLowerBoundMovement(0.0),
centroid(),
- trueParent(NULL),
- trueLeft(NULL),
- trueRight(NULL)
+ trueParent(NULL)
{
// Nothing to do.
}
@@ -43,9 +41,7 @@ class DualTreeKMeansStatistic : public
staticPruned(false),
staticUpperBoundMovement(0.0),
staticLowerBoundMovement(0.0),
- trueParent(node.Parent()),
- trueLeft((node.NumChildren() == 0) ? NULL : &node.Child(0)),
- trueRight((node.NumChildren() == 0) ? NULL : &node.Child(1))
+ trueParent(node.Parent())
{
// Empirically calculate the centroid.
centroid.zeros(node.Dataset().n_rows);
@@ -57,6 +53,11 @@ class DualTreeKMeansStatistic : public
node.Child(i).Stat().Centroid();
centroid /= node.NumDescendants();
+
+ // Set the true children correctly.
+ trueChildren.resize(node.NumChildren());
+ for (size_t i = 0; i < node.NumChildren(); ++i)
+ trueChildren[i] = &node.Child(i);
}
double UpperBound() const { return upperBound; }
@@ -86,11 +87,8 @@ class DualTreeKMeansStatistic : public
void* TrueParent() const { return trueParent; }
void*& TrueParent() { return trueParent; }
- void* TrueLeft() const { return trueLeft; }
- void*& TrueLeft() { return trueLeft; }
-
- void* TrueRight() const { return trueRight; }
- void*& TrueRight() { return trueRight; }
+ void* TrueChild(const size_t i) const { return trueChildren[i]; }
+ void*& TrueChild(const size_t i) { return trueChildren[i]; }
std::string ToString() const
{
@@ -114,8 +112,7 @@ class DualTreeKMeansStatistic : public
double staticLowerBoundMovement;
arma::vec centroid;
void* trueParent;
- void* trueLeft;
- void* trueRight;
+ std::vector<void*> trueChildren;
};
} // namespace kmeans
More information about the mlpack-git mailing list