[mlpack-git] master: Refactor for new TreeType API. (4af30af)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Jul 29 16:42:34 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/f8ceffae0613b350f4d6bdd46c6c8633a40b4897...6ee21879488fe98612a4619b17f8b51e8da5215b

>---------------------------------------------------------------

commit 4af30af35f3cbf844d7806ac9038748c89a90e38
Author: ryan <ryan at ratml.org>
Date:   Sun Jul 26 23:07:29 2015 -0400

    Refactor for new TreeType API.


>---------------------------------------------------------------

4af30af35f3cbf844d7806ac9038748c89a90e38
 src/mlpack/methods/kmeans/dual_tree_kmeans.hpp     | 26 +++++---
 .../methods/kmeans/dual_tree_kmeans_impl.hpp       | 78 ++++++++++++++--------
 src/mlpack/methods/kmeans/pelleg_moore_kmeans.hpp  |  4 +-
 3 files changed, 68 insertions(+), 40 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
index a0c69a7..4da3f08 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
@@ -29,13 +29,20 @@ namespace kmeans {
 template<
     typename MetricType,
     typename MatType,
-    typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>,
-        DualTreeKMeansStatistic> >
+    template<typename MetricType, typename StatisticType, typename MatType>
+        class TreeType = tree::KDTree>
 class DualTreeKMeans
 {
  public:
+  //! Convenience typedef.
+  typedef TreeType<MetricType, DualTreeKMeansStatistic, MatType> Tree;
+    
+  template<typename TMetricType, typename TStatisticType, typename TMatType>
+  using NNSTreeType = TreeType<TMetricType, DualTreeKMeansStatistic, TMatType>;
+
   /**
-   * Construct the DualTreeKMeans object, which will construct a tree on the points.
+   * Construct the DualTreeKMeans object, which will construct a tree on the
+   * points.
    */
   DualTreeKMeans(const MatType& dataset, MetricType& metric);
 
@@ -72,7 +79,7 @@ class DualTreeKMeans
   MetricType metric;
 
   //! The tree built on the points.
-  TreeType* tree;
+  Tree* tree;
 
   //! Track distance calculations.
   size_t distanceCalculations;
@@ -98,7 +105,7 @@ class DualTreeKMeans
 
   //! Update the bounds in the tree before the next iteration.
   //! centroids is the current (not yet searched) centroids.
-  void UpdateTree(TreeType& node,
+  void UpdateTree(Tree& node,
                   const arma::mat& centroids,
                   const double parentUpperBound = 0.0,
                   const double adjustedParentUpperBound = DBL_MAX,
@@ -106,13 +113,13 @@ class DualTreeKMeans
                   const double adjustedParentLowerBound = 0.0);
 
   //! Extract the centroids of the clusters.
-  void ExtractCentroids(TreeType& node,
+  void ExtractCentroids(Tree& node,
                         arma::mat& newCentroids,
                         arma::Col<size_t>& newCounts,
                         arma::mat& centroids);
 
-  void CoalesceTree(TreeType& node, const size_t child = 0);
-  void DecoalesceTree(TreeType& node);
+  void CoalesceTree(Tree& node, const size_t child = 0);
+  void DecoalesceTree(Tree& node);
 };
 
 //! Utility function for hiding children.  This actually does something, and is
@@ -153,8 +160,7 @@ using DefaultDualTreeKMeans = DualTreeKMeans<MetricType, MatType>;
 //! type.
 template<typename MetricType, typename MatType>
 using CoverTreeDualTreeKMeans = DualTreeKMeans<MetricType, MatType,
-    tree::CoverTree<metric::EuclideanDistance, tree::FirstPointIsRoot,
-    DualTreeKMeansStatistic> >;
+    tree::StandardCoverTree>;
 
 } // namespace kmeans
 } // namespace mlpack
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
index 70a5bdf..3cee7d3 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
@@ -44,12 +44,15 @@ TreeType* BuildTree(
   return new TreeType(dataset);
 }
 
-template<typename MetricType, typename MatType, typename TreeType>
+template<typename MetricType,
+         typename MatType,
+         template<typename MetricType, typename StatisticType, typename MatType>
+             class TreeType>
 DualTreeKMeans<MetricType, MatType, TreeType>::DualTreeKMeans(
     const MatType& dataset,
     MetricType& metric) :
     datasetOrig(dataset),
-    dataset(tree::TreeTraits<TreeType>::RearrangesDataset ? datasetCopy :
+    dataset(tree::TreeTraits<Tree>::RearrangesDataset ? datasetCopy :
         datasetOrig),
     metric(metric),
     distanceCalculations(0),
@@ -63,11 +66,11 @@ DualTreeKMeans<MetricType, MatType, TreeType>::DualTreeKMeans(
   Timer::Start("tree_building");
 
   // Copy the dataset, if necessary.
-  if (tree::TreeTraits<TreeType>::RearrangesDataset)
+  if (tree::TreeTraits<Tree>::RearrangesDataset)
     datasetCopy = datasetOrig;
 
   // Now build the tree.  We don't need any mappings.
-  tree = new TreeType(const_cast<typename TreeType::Mat&>(this->dataset));
+  tree = new Tree(const_cast<MatType&>(this->dataset));
 
   Timer::Stop("tree_building");
 
@@ -81,7 +84,10 @@ DualTreeKMeans<MetricType, MatType, TreeType>::DualTreeKMeans(
   lowerBounds.fill(DBL_MAX);
 }
 
-template<typename MetricType, typename MatType, typename TreeType>
+template<typename MetricType,
+         typename MatType,
+         template<typename MetricType, typename StatisticType, typename MatType>
+             class TreeType>
 DualTreeKMeans<MetricType, MatType, TreeType>::~DualTreeKMeans()
 {
   if (tree)
@@ -89,7 +95,10 @@ DualTreeKMeans<MetricType, MatType, TreeType>::~DualTreeKMeans()
 }
 
 // Run a single iteration.
-template<typename MetricType, typename MatType, typename TreeType>
+template<typename MetricType,
+         typename MatType,
+         template<typename MetricType, typename StatisticType, typename MatType>
+             class TreeType>
 double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
     const arma::mat& centroids,
     arma::mat& newCentroids,
@@ -98,21 +107,23 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
   // Build a tree on the centroids.
   arma::mat oldCentroids(centroids); // Slow. :(
   std::vector<size_t> oldFromNewCentroids;
-  TreeType* centroidTree = BuildTree<TreeType>(
-      const_cast<typename TreeType::Mat&>(centroids), oldFromNewCentroids);
+  Tree* centroidTree = BuildTree<Tree>(const_cast<MatType&>(centroids),
+      oldFromNewCentroids);
 
   // Reset information in the tree, if we need to.
   if (iteration > 0)
   {
     Timer::Start("knn");
 
-    // Find the nearest neighbors of each of the clusters.
-    neighbor::NeighborSearch<neighbor::NearestNeighborSort, MetricType,
-        TreeType> nns(centroidTree);
+    // Find the nearest neighbors of each of the clusters.  We have to make our
+    // own TreeType (NNSTreeType), which is a bit of an abuse, but we know for
+    // sure the StatisticType we have will work.
+    neighbor::NeighborSearch<neighbor::NearestNeighborSort, MetricType, MatType,
+        NNSTreeType> nns(centroidTree);
 
     // If the tree maps points, we need an intermediate result matrix.
     arma::mat* interclusterDistancesTemp =
-        (tree::TreeTraits<TreeType>::RearrangesDataset) ?
+        (tree::TreeTraits<Tree>::RearrangesDataset) ?
         new arma::mat(1, centroids.n_elem) : &interclusterDistances;
 
     arma::Mat<size_t> closestClusters; // We don't actually care about these.
@@ -120,7 +131,7 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
     distanceCalculations += nns.BaseCases() + nns.Scores();
 
     // We need to do the unmapping ourselves, if the tree does mapping.
-    if (tree::TreeTraits<TreeType>::RearrangesDataset)
+    if (tree::TreeTraits<Tree>::RearrangesDataset)
     {
       for (size_t i = 0; i < interclusterDistances.n_elem; ++i)
         interclusterDistances[oldFromNewCentroids[i]] =
@@ -145,11 +156,11 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
 
   // We won't use the AllkNN class here because we have our own set of rules.
   lastIterationCentroids = oldCentroids;
-  typedef DualTreeKMeansRules<MetricType, TreeType> RuleType;
+  typedef DualTreeKMeansRules<MetricType, Tree> RuleType;
   RuleType rules(centroids, dataset, assignments, upperBounds, lowerBounds,
       metric, prunedPoints, oldFromNewCentroids, visited);
 
-  typename TreeType::template BreadthFirstDualTreeTraverser<RuleType>
+  typename Tree::template BreadthFirstDualTreeTraverser<RuleType>
       traverser(rules);
 
   Timer::Start("tree_mod");
@@ -176,7 +187,7 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
   for (size_t c = 0; c < centroids.n_cols; ++c)
   {
     // Get the mapping to the old cluster, if necessary.
-    const size_t old = (tree::TreeTraits<TreeType>::RearrangesDataset) ?
+    const size_t old = (tree::TreeTraits<Tree>::RearrangesDataset) ?
         oldFromNewCentroids[c] : c;
     if (counts[old] == 0)
     {
@@ -204,9 +215,12 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
   return std::sqrt(residual);
 }
 
-template<typename MetricType, typename MatType, typename TreeType>
+template<typename MetricType,
+         typename MatType,
+         template<typename MetricType, typename StatisticType, typename MatType>
+             class TreeType>
 void DualTreeKMeans<MetricType, MatType, TreeType>::UpdateTree(
-    TreeType& node,
+    Tree& node,
     const arma::mat& centroids,
     const double parentUpperBound,
     const double adjustedParentUpperBound,
@@ -344,7 +358,7 @@ visited[node.Descendant(i)] << ".\n";
   }
 
   bool allPointsPruned = true;
-  if (tree::TreeTraits<TreeType>::HasSelfChildren && node.NumChildren() > 0)
+  if (tree::TreeTraits<Tree>::HasSelfChildren && node.NumChildren() > 0)
   {
     // If this tree type has self-children, then we have already adjusted the
     // point bounds at a lower level, and we can determine if all of our points
@@ -406,7 +420,7 @@ visited[node.Descendant(i)] << ".\n";
           // lower level, though.  If that's the case, then we shouldn't
           // invalidate the bounds we've got -- it will happen at the lower
           // level.
-          if (!tree::TreeTraits<TreeType>::HasSelfChildren ||
+          if (!tree::TreeTraits<Tree>::HasSelfChildren ||
               node.NumChildren() == 0)
           {
             upperBounds[index] = DBL_MAX;
@@ -465,9 +479,12 @@ visited[node.Descendant(i)] << ".\n";
   }
 }
 
-template<typename MetricType, typename MatType, typename TreeType>
+template<typename MetricType,
+         typename MatType,
+         template<typename MetricType, typename StatisticType, typename MatType>
+             class TreeType>
 void DualTreeKMeans<MetricType, MatType, TreeType>::ExtractCentroids(
-    TreeType& node,
+    Tree& node,
     arma::mat& newCentroids,
     arma::Col<size_t>& newCounts,
     arma::mat& centroids)
@@ -554,9 +571,12 @@ assignments[node.Point(i)] << " with ub " << upperBounds[node.Point(i)] <<
   }
 }
 
-template<typename MetricType, typename MatType, typename TreeType>
+template<typename MetricType,
+         typename MatType,
+         template<typename MetricType, typename StatisticType, typename MatType>
+             class TreeType>
 void DualTreeKMeans<MetricType, MatType, TreeType>::CoalesceTree(
-    TreeType& node,
+    Tree& node,
     const size_t child /* Which child are we? */)
 {
   // If all children except one are pruned, we can hide this node.
@@ -599,11 +619,13 @@ void DualTreeKMeans<MetricType, MatType, TreeType>::CoalesceTree(
   }
 }
 
-template<typename MetricType, typename MatType, typename TreeType>
-void DualTreeKMeans<MetricType, MatType, TreeType>::DecoalesceTree(
-    TreeType& node)
+template<typename MetricType,
+         typename MatType,
+         template<typename MetricType, typename StatisticType, typename MatType>
+             class TreeType>
+void DualTreeKMeans<MetricType, MatType, TreeType>::DecoalesceTree(Tree& node)
 {
-  node.Parent() = (TreeType*) node.Stat().TrueParent();
+  node.Parent() = (Tree*) node.Stat().TrueParent();
   RestoreChildren(node);
 
   for (size_t i = 0; i < node.NumChildren(); ++i)
diff --git a/src/mlpack/methods/kmeans/pelleg_moore_kmeans.hpp b/src/mlpack/methods/kmeans/pelleg_moore_kmeans.hpp
index ce813b3..c7ac3f9 100644
--- a/src/mlpack/methods/kmeans/pelleg_moore_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/pelleg_moore_kmeans.hpp
@@ -65,8 +65,8 @@ class PellegMooreKMeans
   size_t& DistanceCalculations() { return distanceCalculations; }
 
   //! Convenience typedef for the tree.
-  typedef tree::BinarySpaceTree<bound::HRectBound<2, true>,
-      PellegMooreKMeansStatistic, MatType> TreeType;
+  typedef tree::KDTree<MetricType, PellegMooreKMeansStatistic, MatType>
+      TreeType;
 
  private:
   //! The original dataset reference.



More information about the mlpack-git mailing list