[mlpack-git] master: Refactor for new TreeType API. (4af30af)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Jul 29 16:42:34 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/f8ceffae0613b350f4d6bdd46c6c8633a40b4897...6ee21879488fe98612a4619b17f8b51e8da5215b
>---------------------------------------------------------------
commit 4af30af35f3cbf844d7806ac9038748c89a90e38
Author: ryan <ryan at ratml.org>
Date: Sun Jul 26 23:07:29 2015 -0400
Refactor for new TreeType API.
>---------------------------------------------------------------
4af30af35f3cbf844d7806ac9038748c89a90e38
src/mlpack/methods/kmeans/dual_tree_kmeans.hpp | 26 +++++---
.../methods/kmeans/dual_tree_kmeans_impl.hpp | 78 ++++++++++++++--------
src/mlpack/methods/kmeans/pelleg_moore_kmeans.hpp | 4 +-
3 files changed, 68 insertions(+), 40 deletions(-)
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
index a0c69a7..4da3f08 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
@@ -29,13 +29,20 @@ namespace kmeans {
template<
typename MetricType,
typename MatType,
- typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>,
- DualTreeKMeansStatistic> >
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType = tree::KDTree>
class DualTreeKMeans
{
public:
+ //! Convenience typedef.
+ typedef TreeType<MetricType, DualTreeKMeansStatistic, MatType> Tree;
+
+ template<typename TMetricType, typename TStatisticType, typename TMatType>
+ using NNSTreeType = TreeType<TMetricType, DualTreeKMeansStatistic, TMatType>;
+
/**
- * Construct the DualTreeKMeans object, which will construct a tree on the points.
+ * Construct the DualTreeKMeans object, which will construct a tree on the
+ * points.
*/
DualTreeKMeans(const MatType& dataset, MetricType& metric);
@@ -72,7 +79,7 @@ class DualTreeKMeans
MetricType metric;
//! The tree built on the points.
- TreeType* tree;
+ Tree* tree;
//! Track distance calculations.
size_t distanceCalculations;
@@ -98,7 +105,7 @@ class DualTreeKMeans
//! Update the bounds in the tree before the next iteration.
//! centroids is the current (not yet searched) centroids.
- void UpdateTree(TreeType& node,
+ void UpdateTree(Tree& node,
const arma::mat& centroids,
const double parentUpperBound = 0.0,
const double adjustedParentUpperBound = DBL_MAX,
@@ -106,13 +113,13 @@ class DualTreeKMeans
const double adjustedParentLowerBound = 0.0);
//! Extract the centroids of the clusters.
- void ExtractCentroids(TreeType& node,
+ void ExtractCentroids(Tree& node,
arma::mat& newCentroids,
arma::Col<size_t>& newCounts,
arma::mat& centroids);
- void CoalesceTree(TreeType& node, const size_t child = 0);
- void DecoalesceTree(TreeType& node);
+ void CoalesceTree(Tree& node, const size_t child = 0);
+ void DecoalesceTree(Tree& node);
};
//! Utility function for hiding children. This actually does something, and is
@@ -153,8 +160,7 @@ using DefaultDualTreeKMeans = DualTreeKMeans<MetricType, MatType>;
//! type.
template<typename MetricType, typename MatType>
using CoverTreeDualTreeKMeans = DualTreeKMeans<MetricType, MatType,
- tree::CoverTree<metric::EuclideanDistance, tree::FirstPointIsRoot,
- DualTreeKMeansStatistic> >;
+ tree::StandardCoverTree>;
} // namespace kmeans
} // namespace mlpack
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
index 70a5bdf..3cee7d3 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
@@ -44,12 +44,15 @@ TreeType* BuildTree(
return new TreeType(dataset);
}
-template<typename MetricType, typename MatType, typename TreeType>
+template<typename MetricType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType>
DualTreeKMeans<MetricType, MatType, TreeType>::DualTreeKMeans(
const MatType& dataset,
MetricType& metric) :
datasetOrig(dataset),
- dataset(tree::TreeTraits<TreeType>::RearrangesDataset ? datasetCopy :
+ dataset(tree::TreeTraits<Tree>::RearrangesDataset ? datasetCopy :
datasetOrig),
metric(metric),
distanceCalculations(0),
@@ -63,11 +66,11 @@ DualTreeKMeans<MetricType, MatType, TreeType>::DualTreeKMeans(
Timer::Start("tree_building");
// Copy the dataset, if necessary.
- if (tree::TreeTraits<TreeType>::RearrangesDataset)
+ if (tree::TreeTraits<Tree>::RearrangesDataset)
datasetCopy = datasetOrig;
// Now build the tree. We don't need any mappings.
- tree = new TreeType(const_cast<typename TreeType::Mat&>(this->dataset));
+ tree = new Tree(const_cast<MatType&>(this->dataset));
Timer::Stop("tree_building");
@@ -81,7 +84,10 @@ DualTreeKMeans<MetricType, MatType, TreeType>::DualTreeKMeans(
lowerBounds.fill(DBL_MAX);
}
-template<typename MetricType, typename MatType, typename TreeType>
+template<typename MetricType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType>
DualTreeKMeans<MetricType, MatType, TreeType>::~DualTreeKMeans()
{
if (tree)
@@ -89,7 +95,10 @@ DualTreeKMeans<MetricType, MatType, TreeType>::~DualTreeKMeans()
}
// Run a single iteration.
-template<typename MetricType, typename MatType, typename TreeType>
+template<typename MetricType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType>
double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
const arma::mat& centroids,
arma::mat& newCentroids,
@@ -98,21 +107,23 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
// Build a tree on the centroids.
arma::mat oldCentroids(centroids); // Slow. :(
std::vector<size_t> oldFromNewCentroids;
- TreeType* centroidTree = BuildTree<TreeType>(
- const_cast<typename TreeType::Mat&>(centroids), oldFromNewCentroids);
+ Tree* centroidTree = BuildTree<Tree>(const_cast<MatType&>(centroids),
+ oldFromNewCentroids);
// Reset information in the tree, if we need to.
if (iteration > 0)
{
Timer::Start("knn");
- // Find the nearest neighbors of each of the clusters.
- neighbor::NeighborSearch<neighbor::NearestNeighborSort, MetricType,
- TreeType> nns(centroidTree);
+ // Find the nearest neighbors of each of the clusters. We have to make our
+ // own TreeType, which is a little bit abuse, but we know for sure the
+ // StatisticType we have will work.
+ neighbor::NeighborSearch<neighbor::NearestNeighborSort, MetricType, MatType,
+ NNSTreeType> nns(centroidTree);
// If the tree maps points, we need an intermediate result matrix.
arma::mat* interclusterDistancesTemp =
- (tree::TreeTraits<TreeType>::RearrangesDataset) ?
+ (tree::TreeTraits<Tree>::RearrangesDataset) ?
new arma::mat(1, centroids.n_elem) : &interclusterDistances;
arma::Mat<size_t> closestClusters; // We don't actually care about these.
@@ -120,7 +131,7 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
distanceCalculations += nns.BaseCases() + nns.Scores();
// We need to do the unmapping ourselves, if the tree does mapping.
- if (tree::TreeTraits<TreeType>::RearrangesDataset)
+ if (tree::TreeTraits<Tree>::RearrangesDataset)
{
for (size_t i = 0; i < interclusterDistances.n_elem; ++i)
interclusterDistances[oldFromNewCentroids[i]] =
@@ -145,11 +156,11 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
// We won't use the AllkNN class here because we have our own set of rules.
lastIterationCentroids = oldCentroids;
- typedef DualTreeKMeansRules<MetricType, TreeType> RuleType;
+ typedef DualTreeKMeansRules<MetricType, Tree> RuleType;
RuleType rules(centroids, dataset, assignments, upperBounds, lowerBounds,
metric, prunedPoints, oldFromNewCentroids, visited);
- typename TreeType::template BreadthFirstDualTreeTraverser<RuleType>
+ typename Tree::template BreadthFirstDualTreeTraverser<RuleType>
traverser(rules);
Timer::Start("tree_mod");
@@ -176,7 +187,7 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
for (size_t c = 0; c < centroids.n_cols; ++c)
{
// Get the mapping to the old cluster, if necessary.
- const size_t old = (tree::TreeTraits<TreeType>::RearrangesDataset) ?
+ const size_t old = (tree::TreeTraits<Tree>::RearrangesDataset) ?
oldFromNewCentroids[c] : c;
if (counts[old] == 0)
{
@@ -204,9 +215,12 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
return std::sqrt(residual);
}
-template<typename MetricType, typename MatType, typename TreeType>
+template<typename MetricType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType>
void DualTreeKMeans<MetricType, MatType, TreeType>::UpdateTree(
- TreeType& node,
+ Tree& node,
const arma::mat& centroids,
const double parentUpperBound,
const double adjustedParentUpperBound,
@@ -344,7 +358,7 @@ visited[node.Descendant(i)] << ".\n";
}
bool allPointsPruned = true;
- if (tree::TreeTraits<TreeType>::HasSelfChildren && node.NumChildren() > 0)
+ if (tree::TreeTraits<Tree>::HasSelfChildren && node.NumChildren() > 0)
{
// If this tree type has self-children, then we have already adjusted the
// point bounds at a lower level, and we can determine if all of our points
@@ -406,7 +420,7 @@ visited[node.Descendant(i)] << ".\n";
// lower level, though. If that's the case, then we shouldn't
// invalidate the bounds we've got -- it will happen at the lower
// level.
- if (!tree::TreeTraits<TreeType>::HasSelfChildren ||
+ if (!tree::TreeTraits<Tree>::HasSelfChildren ||
node.NumChildren() == 0)
{
upperBounds[index] = DBL_MAX;
@@ -465,9 +479,12 @@ visited[node.Descendant(i)] << ".\n";
}
}
-template<typename MetricType, typename MatType, typename TreeType>
+template<typename MetricType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType>
void DualTreeKMeans<MetricType, MatType, TreeType>::ExtractCentroids(
- TreeType& node,
+ Tree& node,
arma::mat& newCentroids,
arma::Col<size_t>& newCounts,
arma::mat& centroids)
@@ -554,9 +571,12 @@ assignments[node.Point(i)] << " with ub " << upperBounds[node.Point(i)] <<
}
}
-template<typename MetricType, typename MatType, typename TreeType>
+template<typename MetricType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType>
void DualTreeKMeans<MetricType, MatType, TreeType>::CoalesceTree(
- TreeType& node,
+ Tree& node,
const size_t child /* Which child are we? */)
{
// If all children except one are pruned, we can hide this node.
@@ -599,11 +619,13 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::CoalesceTree(
}
}
-template<typename MetricType, typename MatType, typename TreeType>
-void DualTreeKMeans<MetricType, MatType, TreeType>::DecoalesceTree(
- TreeType& node)
+template<typename MetricType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType>
+void DualTreeKMeans<MetricType, MatType, TreeType>::DecoalesceTree(Tree& node)
{
- node.Parent() = (TreeType*) node.Stat().TrueParent();
+ node.Parent() = (Tree*) node.Stat().TrueParent();
RestoreChildren(node);
for (size_t i = 0; i < node.NumChildren(); ++i)
diff --git a/src/mlpack/methods/kmeans/pelleg_moore_kmeans.hpp b/src/mlpack/methods/kmeans/pelleg_moore_kmeans.hpp
index ce813b3..c7ac3f9 100644
--- a/src/mlpack/methods/kmeans/pelleg_moore_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/pelleg_moore_kmeans.hpp
@@ -65,8 +65,8 @@ class PellegMooreKMeans
size_t& DistanceCalculations() { return distanceCalculations; }
//! Convenience typedef for the tree.
- typedef tree::BinarySpaceTree<bound::HRectBound<2, true>,
- PellegMooreKMeansStatistic, MatType> TreeType;
+ typedef tree::KDTree<MetricType, PellegMooreKMeansStatistic, MatType>
+ TreeType;
private:
//! The original dataset reference.
More information about the mlpack-git mailing list