[mlpack-git] master: Make DTNNKMeans work again. Next up, tree coalescion. (Is that a word?) (45a731f)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:04:50 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit 45a731f7b47a033f78a3a06ac58647d53277c478
Author: Ryan Curtin <ryan at ratml.org>
Date:   Tue Feb 17 14:14:01 2015 -0500

    Make DTNNKMeans work again. Next up, tree coalescion. (Is that a word?)


>---------------------------------------------------------------

45a731f7b47a033f78a3a06ac58647d53277c478
 src/mlpack/core/tree/cover_tree/cover_tree.hpp |  3 ++
 src/mlpack/methods/kmeans/dtnn_kmeans.hpp      |  7 ++-
 src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 68 +++++++++++++++++++-------
 src/mlpack/methods/kmeans/dtnn_rules.hpp       |  6 +++
 src/mlpack/methods/kmeans/dtnn_rules_impl.hpp  |  4 +-
 src/mlpack/methods/kmeans/dtnn_statistic.hpp   |  7 +++
 6 files changed, 74 insertions(+), 21 deletions(-)

diff --git a/src/mlpack/core/tree/cover_tree/cover_tree.hpp b/src/mlpack/core/tree/cover_tree/cover_tree.hpp
index 41db6fb..8536718 100644
--- a/src/mlpack/core/tree/cover_tree/cover_tree.hpp
+++ b/src/mlpack/core/tree/cover_tree/cover_tree.hpp
@@ -208,6 +208,9 @@ class CoverTree
   template<typename RuleType>
   class DualTreeTraverser;
 
+  template<typename RuleType>
+  using BreadthFirstDualTreeTraverser = DualTreeTraverser<RuleType>;
+
   //! Get a reference to the dataset.
   const arma::mat& Dataset() const { return dataset; }
 
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
index d27988e..e756669 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
@@ -93,12 +93,15 @@ class DTNNKMeans
   arma::mat lastIterationCentroids; // For sanity checks.
 
   //! Update the bounds in the tree before the next iteration.
-  void UpdateTree(TreeType& node);
+  void UpdateTree(TreeType& node,
+                  arma::vec& clusterDistances,
+                  std::vector<size_t>& oldFromNewCentroids);
 
   //! Extract the centroids of the clusters.
   void ExtractCentroids(TreeType& node,
                         arma::mat& newCentroids,
-                        arma::Col<size_t>& newCounts);
+                        arma::Col<size_t>& newCounts,
+                        std::vector<size_t>& oldFromNewCentroids);
 };
 
 //! A template typedef for the DTNNKMeans algorithm with the default tree type
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index bd7cb80..38e7650 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -29,7 +29,7 @@ TreeType* BuildTree(
 {
   // This is a hack.  I know this will be BinarySpaceTree, so force a leaf size
   // of two.
-  return new TreeType(dataset, oldFromNew);
+  return new TreeType(dataset, oldFromNew, 2);
 }
 
 //! Call the tree constructor that does not do mapping.
@@ -99,14 +99,6 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
   std::vector<size_t> oldFromNewCentroids;
   TreeType* centroidTree = BuildTree<TreeType>(
       const_cast<typename TreeType::Mat&>(centroids), oldFromNewCentroids);
-  // Calculate new from old mappings.
-  std::vector<size_t> newFromOldCentroids;
-  if (tree::TreeTraits<TreeType>::RearrangesDataset)
-  {
-    newFromOldCentroids.resize(centroids.n_cols);
-    for (size_t i = 0; i < centroids.n_cols; ++i)
-      newFromOldCentroids[oldFromNewCentroids[i]] = i;
-  }
 
 /*
   Timer::Start("knn");
@@ -121,7 +113,7 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
 */
 
   // We won't use the AllkNN class here because we have our own set of rules.
-  typedef typename DTNNKMeansRules<MetricType, TreeType> RuleType;
+  typedef DTNNKMeansRules<MetricType, TreeType> RuleType;
   RuleType rules(centroids, dataset, assignments, upperBounds, lowerBounds,
       metric, prunedPoints, oldFromNewCentroids, visited);
 
@@ -131,14 +123,18 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
   // Set the number of pruned centroids in the root to 0.
   tree->Stat().Pruned() = 0;
   traverser.Traverse(*tree, *centroidTree);
+  distanceCalculations += rules.BaseCases() + rules.Scores();
 
   // Now we need to extract the clusters.
   newCentroids.zeros(centroids.n_rows, centroids.n_cols);
   counts.zeros(centroids.n_cols);
-  ExtractCentroids(*tree, newCentroids, counts);
+  ExtractCentroids(*tree, newCentroids, counts, oldFromNewCentroids);
+  Log::Warn << "New counts: " << counts.t();
+  Log::Warn << accu(counts) << ".\n";
 
   // Now, calculate how far the clusters moved, after normalizing them.
   double residual = 0.0;
+  arma::vec clusterDistances(centroids.n_cols + 1);
   clusterDistances[centroids.n_cols] = 0.0;
   for (size_t c = 0; c < centroids.n_cols; ++c)
   {
@@ -164,6 +160,8 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
   }
   distanceCalculations += centroids.n_cols;
 
+  UpdateTree(*tree, clusterDistances, oldFromNewCentroids);
+
   delete centroidTree;
 
   ++iteration;
@@ -173,28 +171,61 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
 
 template<typename MetricType, typename MatType, typename TreeType>
 void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
-    TreeType& node)
+    TreeType& node,
+    arma::vec& clusterDistances,
+    std::vector<size_t>& oldFromNewCentroids)
 {
   // Simply reset the bounds.
   node.Stat().UpperBound() = DBL_MAX;
   node.Stat().LowerBound() = DBL_MAX;
+  if ((node.Stat().Pruned() == clusterDistances.n_elem - 1) &&
+      (node.Stat().Owner() < clusterDistances.n_elem - 1))
+  {
+    const size_t owner = oldFromNewCentroids[node.Stat().Owner()];
+
+    node.Stat().LastUpperBound() = node.Stat().UpperBound() +
+        clusterDistances[owner];
+
+    // Update child bounds, at least a little.
+    for (size_t i = 0; i < node.NumChildren(); ++i)
+    {
+      node.Child(i).Stat().UpperBound() = node.Stat().UpperBound();
+      node.Child(i).Stat().LowerBound() = node.Stat().LowerBound();
+      node.Child(i).Stat().Owner() = node.Stat().Owner();
+      node.Child(i).Stat().Pruned() = node.Stat().Pruned();
+    }
+  }
+  else if ((node.Stat().Pruned() == clusterDistances.n_elem - 1) &&
+           (node.Stat().Owner() >= clusterDistances.n_elem - 1))
+  {
+    Log::Warn << clusterDistances.n_cols - 1 << ".\n";
+    Log::Warn << node;
+    Log::Fatal << "Node is pruned, but has no owner!\n";
+  }
+  else
+  {
+    node.Stat().LastUpperBound() = node.Stat().UpperBound() +
+        clusterDistances[clusterDistances.n_elem - 1];
+  }
   node.Stat().Pruned() = size_t(-1);
   node.Stat().Owner() = size_t(-1);
+  node.Stat().LowerBound() = DBL_MAX;
 
   for (size_t i = 0; i < node.NumChildren(); ++i)
-    UpdateTree(node.Child(i));
+    UpdateTree(node.Child(i), clusterDistances, oldFromNewCentroids);
 }
 
 template<typename MetricType, typename MatType, typename TreeType>
 void DTNNKMeans<MetricType, MatType, TreeType>::ExtractCentroids(
     TreeType& node,
     arma::mat& newCentroids,
-    arma::Col<size_t>& newCounts)
+    arma::Col<size_t>& newCounts,
+    std::vector<size_t>& oldFromNewCentroids)
 {
   // Does this node own points?
-  if (node.Stat().Owner() < newCentroids.n_cols)
+  if (node.Stat().Pruned() == newCentroids.n_cols)
   {
-    const size_t owner = node.Stat().Owner();
+    const size_t owner = oldFromNewCentroids[node.Stat().Owner()];
     newCentroids.col(owner) += node.Stat().Centroid() * node.NumDescendants();
     newCounts[owner] += node.NumDescendants();
   }
@@ -203,14 +234,15 @@ void DTNNKMeans<MetricType, MatType, TreeType>::ExtractCentroids(
     // Check each point held in the node.
     for (size_t i = 0; i < node.NumPoints(); ++i)
     {
-      const size_t owner = assignments[node.Point(i)];
+      const size_t owner = oldFromNewCentroids[assignments[node.Point(i)]];
       newCentroids.col(owner) += dataset.col(node.Point(i));
       ++newCounts[owner];
     }
 
     // The node is not entirely owned by a cluster.  Recurse.
     for (size_t i = 0; i < node.NumChildren(); ++i)
-      ExtractCentroids(node.Child(i), newCentroids, newCounts);
+      ExtractCentroids(node.Child(i), newCentroids, newCounts,
+          oldFromNewCentroids);
   }
 }
 
diff --git a/src/mlpack/methods/kmeans/dtnn_rules.hpp b/src/mlpack/methods/kmeans/dtnn_rules.hpp
index 1c3cda4..5252050 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules.hpp
@@ -44,6 +44,12 @@ class DTNNKMeansRules
   TraversalInfoType& TraversalInfo() { return traversalInfo; }
   const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
 
+  size_t BaseCases() const { return baseCases; }
+  size_t& BaseCases() { return baseCases; }
+
+  size_t Scores() const { return scores; }
+  size_t& Scores() { return scores; }
+
  private:
   const arma::mat& centroids;
   const arma::mat& dataset;
diff --git a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
index 7e1904f..6a57273 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
@@ -91,6 +91,7 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
   {
     queryNode.Stat().Pruned() = queryNode.Parent()->Stat().Pruned();
     queryNode.Stat().LowerBound() = queryNode.Parent()->Stat().LowerBound();
+    queryNode.Stat().Owner() = queryNode.Parent()->Stat().Owner();
   }
 
   if (queryNode.Stat().Pruned() == centroids.n_cols)
@@ -102,7 +103,8 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
   math::Range distances = queryNode.RangeDistance(&referenceNode);
   double score = distances.Lo();
   ++scores;
-  if (distances.Lo() > queryNode.Stat().UpperBound())
+  if (distances.Lo() > queryNode.Stat().UpperBound() ||
+      distances.Lo() > queryNode.Stat().LastUpperBound())
   {
     // The reference node can own no points in this query node.  We may improve
     // the lower bound on pruned nodes, though.
diff --git a/src/mlpack/methods/kmeans/dtnn_statistic.hpp b/src/mlpack/methods/kmeans/dtnn_statistic.hpp
index b81a17a..ee24b23 100644
--- a/src/mlpack/methods/kmeans/dtnn_statistic.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_statistic.hpp
@@ -20,6 +20,7 @@ class DTNNStatistic : public
       neighbor::NeighborSearchStat<neighbor::NearestNeighborSort>(),
       upperBound(DBL_MAX),
       lowerBound(DBL_MAX),
+      lastUpperBound(DBL_MAX),
       owner(size_t(-1)),
       pruned(size_t(-1)),
       centroid()
@@ -32,6 +33,7 @@ class DTNNStatistic : public
       neighbor::NeighborSearchStat<neighbor::NearestNeighborSort>(),
       upperBound(DBL_MAX),
       lowerBound(DBL_MAX),
+      lastUpperBound(DBL_MAX),
       owner(size_t(-1)),
       pruned(size_t(-1))
   {
@@ -53,6 +55,9 @@ class DTNNStatistic : public
   double LowerBound() const { return lowerBound; }
   double& LowerBound() { return lowerBound; }
 
+  double LastUpperBound() const { return lastUpperBound; }
+  double& LastUpperBound() { return lastUpperBound; }
+
   const arma::vec& Centroid() const { return centroid; }
   arma::vec& Centroid() { return centroid; }
 
@@ -68,6 +73,7 @@ class DTNNStatistic : public
     o << "DTNNStatistic [" << this << "]:\n";
     o << "  Upper bound: " << upperBound << ".\n";
     o << "  Lower bound: " << lowerBound << ".\n";
+    o << "  Last upper bound: " << lastUpperBound << ".\n";
     o << "  Pruned: " << pruned << ".\n";
     o << "  Owner: " << owner << ".\n";
     return o.str();
@@ -76,6 +82,7 @@ class DTNNStatistic : public
  private:
   double upperBound;
   double lowerBound;
+  double lastUpperBound;
   size_t owner;
   size_t pruned;
   arma::vec centroid;



More information about the mlpack-git mailing list