[mlpack-git] master: This is an experimental method that I am working on. Right now it is not very useful as I have not implemented all of the pruning strategies that I intend to. (b35e8c2)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 22:02:24 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit b35e8c2f6b8f96376b84aab85f2ded67fb0a831d
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Nov 5 19:36:49 2014 +0000

    This is an experimental method that I am working on.  Right now it is not very
    useful as I have not implemented all of the pruning strategies that I intend to.


>---------------------------------------------------------------

b35e8c2f6b8f96376b84aab85f2ded67fb0a831d
 src/mlpack/methods/kmeans/dual_tree_kmeans.hpp     |  71 +++++
 .../methods/kmeans/dual_tree_kmeans_impl.hpp       | 116 ++++++++
 .../methods/kmeans/dual_tree_kmeans_rules.hpp      |  80 ++++++
 .../methods/kmeans/dual_tree_kmeans_rules_impl.hpp | 317 +++++++++++++++++++++
 .../methods/kmeans/dual_tree_kmeans_statistic.hpp  |  96 +++++++
 5 files changed, 680 insertions(+)

diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
new file mode 100644
index 0000000..f2b6376
--- /dev/null
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
@@ -0,0 +1,71 @@
+/**
+ * @file dual_tree_kmeans.hpp
+ * @author Ryan Curtin
+ *
+ * A dual-tree algorithm for a single k-means iteration.
+ */
+#ifndef __MLPACK_METHODS_KMEANS_DUAL_TREE_KMEANS_HPP
+#define __MLPACK_METHODS_KMEANS_DUAL_TREE_KMEANS_HPP
+
+#include "dual_tree_kmeans_statistic.hpp"
+
+namespace mlpack {
+namespace kmeans {
+
+/**
+ * An experimental dual-tree implementation of a single k-means iteration.
+ * Not all of the intended pruning strategies are implemented yet.
+ */
+template<
+    typename MetricType,
+    typename MatType,
+    typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>,
+        DualTreeKMeansStatistic>
+>
+class DualTreeKMeans
+{
+ public:
+  //! Construct the object and build a tree on the dataset.
+  DualTreeKMeans(const MatType& dataset, MetricType& metric);
+
+  //! Delete the tree built on the points.
+  ~DualTreeKMeans();
+
+  //! Run a single iteration of dual-tree k-means, returning the residual
+  //! (the root of the summed squared distances moved by the centroids).
+  double Iterate(const arma::mat& centroids,
+                 arma::mat& newCentroids,
+                 arma::Col<size_t>& counts);
+
+  //! Return the number of distance calculations.
+  size_t DistanceCalculations() const { return distanceCalculations; }
+  //! Modify the number of distance calculations.
+  size_t& DistanceCalculations() { return distanceCalculations; }
+
+ private:
+  //! The original dataset reference.
+  const MatType& datasetOrig;
+  //! The dataset we are using.
+  const MatType& dataset;
+  //! A copy of the dataset, if necessary.
+  MatType datasetCopy;
+  //! The metric.
+  MetricType metric;
+
+  //! The tree built on the points.
+  TreeType* tree;
+
+  //! How far each cluster moved in the last iteration; the last element holds
+  //! the maximum movement over all clusters.
+  arma::vec clusterDistances;
+  //! Current cluster assignment of each point.
+  arma::Col<size_t> assignments;
+  //! Distance from each point to its closest centroid seen so far.
+  arma::vec distances;
+  //! The iteration in which each point's distance was last updated.
+  arma::Col<size_t> distanceIteration;
+
+  //! The current iteration.
+  size_t iteration;
+
+  //! Track distance calculations.
+  size_t distanceCalculations;
+};
+
+template<typename MetricType, typename MatType>
+using DefaultDualTreeKMeans = DualTreeKMeans<MetricType, MatType>;
+
+} // namespace kmeans
+} // namespace mlpack
+
+// Include implementation.
+#include "dual_tree_kmeans_impl.hpp"
+
+#endif
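
Roughly, the class above is meant to be driven by an outer convergence loop;
that glue is not part of this commit.  A sketch of what I have in mind (the
`dataset`, `initialCentroids`, `maxIterations`, and `tol` names below are
placeholders supplied by the caller, not anything defined in this commit):

    mlpack::metric::EuclideanDistance metric;
    mlpack::kmeans::DualTreeKMeans<mlpack::metric::EuclideanDistance,
        arma::mat> dtkm(dataset, metric);

    arma::mat centroids = initialCentroids; // One column per cluster.
    arma::mat newCentroids;
    arma::Col<size_t> counts;

    for (size_t i = 0; i < maxIterations; ++i)
    {
      // One dual-tree pass; the return value is the square root of the summed
      // squared distances that the centroids moved.
      const double residual = dtkm.Iterate(centroids, newCentroids, counts);
      centroids = newCentroids;
      if (residual < tol)
        break;
    }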
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
new file mode 100644
index 0000000..35a3a9d
--- /dev/null
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
@@ -0,0 +1,116 @@
+/**
+ * @file dual_tree_kmeans_impl.hpp
+ * @author Ryan Curtin
+ *
+ * A dual-tree algorithm for a single k-means iteration.
+ */
+#ifndef __MLPACK_METHODS_KMEANS_DUAL_TREE_KMEANS_IMPL_HPP
+#define __MLPACK_METHODS_KMEANS_DUAL_TREE_KMEANS_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "dual_tree_kmeans.hpp"
+#include "dual_tree_kmeans_rules.hpp"
+
+namespace mlpack {
+namespace kmeans {
+
+template<typename MetricType, typename MatType, typename TreeType>
+DualTreeKMeans<MetricType, MatType, TreeType>::DualTreeKMeans(
+    const MatType& dataset,
+    MetricType& metric) :
+    datasetOrig(dataset),
+    dataset(tree::TreeTraits<TreeType>::RearrangesDataset ? datasetCopy :
+        datasetOrig),
+    metric(metric),
+    iteration(0),
+    distanceCalculations(0)
+{
+  distances.set_size(dataset.n_cols);
+  distances.fill(DBL_MAX);
+  assignments.zeros(dataset.n_cols);
+  distanceIteration.zeros(dataset.n_cols);
+
+  Timer::Start("tree_building");
+
+  // Copy the dataset, if necessary.
+  if (tree::TreeTraits<TreeType>::RearrangesDataset)
+    datasetCopy = datasetOrig;
+
+  // Now build the tree.  We don't need any mappings.
+  tree = new TreeType(const_cast<typename TreeType::Mat&>(this->dataset));
+
+  Timer::Stop("tree_building");
+}
+
+template<typename MetricType, typename MatType, typename TreeType>
+DualTreeKMeans<MetricType, MatType, TreeType>::~DualTreeKMeans()
+{
+  if (tree)
+    delete tree;
+}
+
+template<typename MetricType, typename MatType, typename TreeType>
+double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
+    const arma::mat& centroids,
+    arma::mat& newCentroids,
+    arma::Col<size_t>& counts)
+{
+  newCentroids.zeros(centroids.n_rows, centroids.n_cols);
+  counts.zeros(centroids.n_cols);
+  if (clusterDistances.n_elem != centroids.n_cols + 1)
+  {
+    clusterDistances.set_size(centroids.n_cols + 1);
+    clusterDistances.fill(DBL_MAX / 2.0); // To prevent overflow.
+  }
+
+  // Build a tree on the centroids.
+  std::vector<size_t> oldFromNewCentroids;
+  TreeType* centroidTree = BuildTree<TreeType>(
+      const_cast<typename TreeType::Mat&>(centroids), oldFromNewCentroids);
+
+  // Now run the dual-tree algorithm.
+  typedef DualTreeKMeansRules<MetricType, TreeType> RulesType;
+  RulesType rules(dataset, centroids, newCentroids, counts, oldFromNewCentroids,
+      iteration, clusterDistances, distances, assignments, distanceIteration,
+      metric);
+
+  // Use the dual-tree traverser.
+//typename TreeType::template DualTreeTraverser<RulesType> traverser(rules);
+  typename TreeType::template BreadthFirstDualTreeTraverser<RulesType>
+      traverser(rules);
+
+  traverser.Traverse(*centroidTree, *tree);
+
+  distanceCalculations += rules.DistanceCalculations();
+
+  // Now, calculate how far the clusters moved, after normalizing them.
+  double residual = 0.0;
+  clusterDistances.zeros();
+  for (size_t c = 0; c < centroids.n_cols; ++c)
+  {
+    // Map from the index in the (rearranged) centroid tree back to the
+    // original cluster index; counts and newCentroids are indexed by the
+    // original cluster indices.
+    const size_t oldCluster = oldFromNewCentroids[c];
+    if (counts[oldCluster] == 0)
+    {
+      // This cluster lost all of its points; mark it as empty.
+      newCentroids.col(oldCluster).fill(DBL_MAX);
+    }
+    else
+    {
+      newCentroids.col(oldCluster) /= counts(oldCluster);
+      const double dist = metric.Evaluate(centroids.col(c),
+                                          newCentroids.col(oldCluster));
+      if (dist > clusterDistances[centroids.n_cols])
+        clusterDistances[centroids.n_cols] = dist;
+      clusterDistances[oldCluster] = dist;
+      residual += std::pow(dist, 2.0);
+    }
+  }
+  Log::Info << clusterDistances.t();
+
+  // The tree on the centroids is rebuilt each iteration, so free it now.
+  delete centroidTree;
+
+  ++iteration;
+  return std::sqrt(residual);
+}
+
+} // namespace kmeans
+} // namespace mlpack
+
+#endif
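
To make the post-traversal bookkeeping in Iterate() explicit: ignoring the
old-from-new centroid mapping, the per-cluster movement vector and the
returned residual amount to the following standalone sketch (this is not code
from the commit; the function name is illustrative):

    #include <mlpack/core.hpp>
    #include <algorithm>
    #include <cmath>

    // clusterDistances gets one slot per cluster plus an extra final slot
    // holding the maximum movement over all clusters; the return value is the
    // residual reported by Iterate().
    double CentroidMovement(const arma::mat& oldCentroids,
                            const arma::mat& newCentroids,
                            arma::vec& clusterDistances)
    {
      mlpack::metric::EuclideanDistance metric;
      clusterDistances.zeros(oldCentroids.n_cols + 1);
      double residual = 0.0;
      for (size_t c = 0; c < oldCentroids.n_cols; ++c)
      {
        const double dist = metric.Evaluate(oldCentroids.col(c),
                                            newCentroids.col(c));
        clusterDistances[c] = dist;
        clusterDistances[oldCentroids.n_cols] =
            std::max(clusterDistances[oldCentroids.n_cols], dist);
        residual += dist * dist;
      }
      return std::sqrt(residual);
    }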
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp
new file mode 100644
index 0000000..e9320d1
--- /dev/null
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules.hpp
@@ -0,0 +1,80 @@
+/**
+ * @file dual_tree_kmeans_rules.hpp
+ * @author Ryan Curtin
+ *
+ * A set of tree traversal rules for dual-tree k-means clustering.
+ */
+#ifndef __MLPACK_METHODS_KMEANS_DUAL_TREE_KMEANS_RULES_HPP
+#define __MLPACK_METHODS_KMEANS_DUAL_TREE_KMEANS_RULES_HPP
+
+namespace mlpack {
+namespace kmeans {
+
+template<typename MetricType, typename TreeType>
+class DualTreeKMeansRules
+{
+ public:
+  DualTreeKMeansRules(const typename TreeType::Mat& dataset,
+                      const arma::mat& centroids,
+                      arma::mat& newCentroids,
+                      arma::Col<size_t>& counts,
+                      const std::vector<size_t>& mappings,
+                      const size_t iteration,
+                      const arma::vec& clusterDistances,
+                      arma::vec& distances,
+                      arma::Col<size_t>& assignments,
+                      arma::Col<size_t>& distanceIteration,
+                      MetricType& metric);
+
+  double BaseCase(const size_t queryIndex, const size_t referenceIndex);
+
+  double Score(const size_t queryIndex, TreeType& referenceNode);
+
+  double Score(TreeType& queryNode, TreeType& referenceNode);
+
+  double Rescore(const size_t queryIndex,
+                 TreeType& referenceNode,
+                 const double oldScore) const;
+
+  double Rescore(TreeType& queryNode,
+                 TreeType& referenceNode,
+                 const double oldScore) const;
+
+  size_t DistanceCalculations() const { return distanceCalculations; }
+  size_t& DistanceCalculations() { return distanceCalculations; }
+
+  typedef neighbor::NeighborSearchTraversalInfo<TreeType> TraversalInfoType;
+
+  const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
+  TraversalInfoType& TraversalInfo() { return traversalInfo; }
+
+ private:
+  const typename TreeType::Mat& dataset;
+  const arma::mat& centroids;
+  arma::mat& newCentroids;
+  arma::Col<size_t>& counts;
+  const std::vector<size_t>& mappings;
+  const size_t iteration;
+  const arma::vec& clusterDistances;
+  arma::vec& distances;
+  arma::Col<size_t>& assignments;
+  arma::Col<size_t> visited;
+  arma::Col<size_t>& distanceIteration;
+  MetricType& metric;
+
+  size_t distanceCalculations;
+
+  TraversalInfoType traversalInfo;
+
+  size_t IterationUpdate(TreeType& referenceNode) const;
+
+  bool IsDescendantOf(const TreeType& potentialParent, const TreeType&
+      potentialChild) const;
+};
+
+} // namespace kmeans
+} // namespace mlpack
+
+#include "dual_tree_kmeans_rules_impl.hpp"
+
+#endif
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
new file mode 100644
index 0000000..adcedad
--- /dev/null
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_rules_impl.hpp
@@ -0,0 +1,317 @@
+/**
+ * @file dual_tree_kmeans_rules_impl.hpp
+ * @author Ryan Curtin
+ *
+ * A set of tree traversal rules for dual-tree k-means clustering.
+ */
+#ifndef __MLPACK_METHODS_KMEANS_DUAL_TREE_KMEANS_RULES_IMPL_HPP
+#define __MLPACK_METHODS_KMEANS_DUAL_TREE_KMEANS_RULES_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "dual_tree_kmeans_rules.hpp"
+
+namespace mlpack {
+namespace kmeans {
+
+template<typename MetricType, typename TreeType>
+DualTreeKMeansRules<MetricType, TreeType>::DualTreeKMeansRules(
+    const typename TreeType::Mat& dataset,
+    const arma::mat& centroids,
+    arma::mat& newCentroids,
+    arma::Col<size_t>& counts,
+    const std::vector<size_t>& mappings,
+    const size_t iteration,
+    const arma::vec& clusterDistances,
+    arma::vec& distances,
+    arma::Col<size_t>& assignments,
+    arma::Col<size_t>& distanceIteration,
+    MetricType& metric) :
+    dataset(dataset),
+    centroids(centroids),
+    newCentroids(newCentroids),
+    counts(counts),
+    mappings(mappings),
+    iteration(iteration),
+    clusterDistances(clusterDistances),
+    distances(distances),
+    assignments(assignments),
+    distanceIteration(distanceIteration),
+    metric(metric),
+    distanceCalculations(0)
+{
+  // Nothing has been visited yet.
+  visited.zeros(dataset.n_cols);
+}
+
+template<typename MetricType, typename TreeType>
+inline force_inline double DualTreeKMeansRules<MetricType, TreeType>::BaseCase(
+    const size_t queryIndex,
+    const size_t referenceIndex)
+{
+//  Log::Info << "Base case, query " << queryIndex << " (" << mappings[queryIndex]
+//      << "), reference " << referenceIndex << ".\n";
+
+  // Collect the number of clusters that have been pruned during the traversal.
+  // The ternary operator may not be necessary.
+  const size_t traversalPruned = (traversalInfo.LastReferenceNode() != NULL &&
+      traversalInfo.LastReferenceNode()->Stat().Iteration() == iteration) ?
+      traversalInfo.LastReferenceNode()->Stat().ClustersPruned() : 0;
+
+  // It's possible that the reference node has been pruned before we got to the
+  // base case.  In that case, don't do the base case, and just return.  (Use
+  // traversalPruned so that a stale count from a previous iteration, or a NULL
+  // last reference node, can't cause trouble.)
+  if (traversalPruned + visited[referenceIndex] == centroids.n_cols)
+    return 0.0;
+
+  ++distanceCalculations;
+
+  const double distance = metric.Evaluate(centroids.col(queryIndex),
+                                          dataset.col(referenceIndex));
+
+  // Iteration change check.
+  if (distanceIteration[referenceIndex] < iteration)
+  {
+    distanceIteration[referenceIndex] = iteration;
+    distances[referenceIndex] = distance;
+    assignments[referenceIndex] = mappings[queryIndex];
+  }
+  else if (distance < distances[referenceIndex])
+  {
+    distances[referenceIndex] = distance;
+    assignments[referenceIndex] = mappings[queryIndex];
+  }
+
+  ++visited[referenceIndex];
+
+  if (visited[referenceIndex] + traversalPruned == centroids.n_cols)
+  {
+//    Log::Warn << "Commit reference index " << referenceIndex << " to cluster "
+//        << assignments[referenceIndex] << ".\n";
+    newCentroids.col(assignments[referenceIndex]) +=
+        dataset.col(referenceIndex);
+    ++counts(assignments[referenceIndex]);
+  }
+
+  return distance;
+}
+
+template<typename MetricType, typename TreeType>
+double DualTreeKMeansRules<MetricType, TreeType>::Score(
+    const size_t queryIndex,
+    TreeType& referenceNode)
+{
+  // Update from previous iteration, if necessary.
+  IterationUpdate(referenceNode);
+
+  // No pruning here, for now.
+  return 0.0;
+}
+
+template<typename MetricType, typename TreeType>
+double DualTreeKMeansRules<MetricType, TreeType>::Score(
+    TreeType& queryNode,
+    TreeType& referenceNode)
+{
+  IterationUpdate(referenceNode);
+
+  traversalInfo.LastReferenceNode() = &referenceNode;
+
+  // Can we update the minimum query node distance for this reference node?
+  const double minDistance = referenceNode.MinDistance(&queryNode);
+  ++distanceCalculations;
+  if (minDistance < referenceNode.Stat().MinQueryNodeDistance())
+  {
+    referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
+    referenceNode.Stat().MinQueryNodeDistance() = minDistance;
+    referenceNode.Stat().MaxQueryNodeDistance() =
+        referenceNode.MaxDistance(&queryNode);
+    ++distanceCalculations;
+    return 0.0; // Pruning is not possible.
+  }
+  else if (IsDescendantOf(
+      *((TreeType*) referenceNode.Stat().ClosestQueryNode()), queryNode))
+  {
+    // Just update.
+    referenceNode.Stat().ClosestQueryNode() = (void*) &queryNode;
+    referenceNode.Stat().MinQueryNodeDistance() = minDistance;
+    referenceNode.Stat().MaxQueryNodeDistance() =
+        referenceNode.MaxDistance(&queryNode);
+    ++distanceCalculations;
+    return 0.0; // Pruning is not possible.
+  }
+
+  // See if we can do an Elkan-type prune on between-centroid distances.
+  const double maxDistance = referenceNode.Stat().MaxQueryNodeDistance();
+  const double minQueryDistance = queryNode.MinDistance((TreeType*)
+      referenceNode.Stat().ClosestQueryNode());
+  ++distanceCalculations;
+
+  if (minQueryDistance > 2.0 * maxDistance)
+  {
+    // Then we can conclude d_max(best(N_r), N_r) <= d_min(N_q, N_r) which
+    // means that N_q cannot possibly hold any clusters that own any points in
+    // N_r.
+    referenceNode.Stat().ClustersPruned() += queryNode.NumDescendants();
+
+    // Have we pruned everything?
+    if (referenceNode.Stat().ClustersPruned() == centroids.n_cols - 1)
+    {
+      // Then the best query node must contain just one point.
+      const TreeType* bestQueryNode = (TreeType*)
+          referenceNode.Stat().ClosestQueryNode();
+      const size_t cluster = mappings[bestQueryNode->Descendant(0)];
+
+      referenceNode.Stat().Owner() = cluster;
+      newCentroids.col(cluster) += referenceNode.NumDescendants() *
+          referenceNode.Stat().Centroid();
+      counts(cluster) += referenceNode.NumDescendants();
+      referenceNode.Stat().ClustersPruned()++;
+    }
+    else if (referenceNode.Stat().ClustersPruned() +
+        visited[referenceNode.Descendant(0)] == centroids.n_cols)
+    {
+      for (size_t i = 0; i < referenceNode.NumPoints(); ++i)
+      {
+        const size_t cluster = assignments[referenceNode.Point(i)];
+        newCentroids.col(cluster) += dataset.col(referenceNode.Point(i));
+        counts(cluster)++;
+      }
+    }
+
+    return DBL_MAX;
+  }
+
+  return minQueryDistance;
+}
+
+template<typename MetricType, typename TreeType>
+double DualTreeKMeansRules<MetricType, TreeType>::Rescore(
+    const size_t /* queryIndex */,
+    TreeType& /* referenceNode */,
+    const double oldScore) const
+{
+  return oldScore;
+}
+
+template<typename MetricType, typename TreeType>
+double DualTreeKMeansRules<MetricType, TreeType>::Rescore(
+    TreeType& queryNode,
+    TreeType& referenceNode,
+    const double oldScore) const
+{
+  if (oldScore == DBL_MAX)
+    return oldScore; // We can't unprune something.  This shouldn't happen.
+
+  // Can we update the minimum query node distance for this reference node?
+  const double minQueryDistance = oldScore;
+
+  // See if we can do an Elkan-type prune on between-centroid distances.
+  const double maxDistance = referenceNode.Stat().MaxQueryNodeDistance();
+
+  if (minQueryDistance > 2.0 * maxDistance)
+  {
+    // Then we can conclude d_max(best(N_r), N_r) <= d_min(N_q, N_r) which
+    // means that N_q cannot possibly hold any clusters that own any points in
+    // N_r.
+    referenceNode.Stat().ClustersPruned() += queryNode.NumDescendants();
+
+    // Have we pruned everything?
+    if (referenceNode.Stat().ClustersPruned() == centroids.n_cols - 1)
+    {
+      // Then the best query node must contain just one point.
+      const TreeType* bestQueryNode = (TreeType*)
+          referenceNode.Stat().ClosestQueryNode();
+      const size_t cluster = mappings[bestQueryNode->Descendant(0)];
+
+      referenceNode.Stat().Owner() = cluster;
+      newCentroids.col(cluster) += referenceNode.NumDescendants() *
+          referenceNode.Stat().Centroid();
+      counts(cluster) += referenceNode.NumDescendants();
+      referenceNode.Stat().ClustersPruned()++;
+    }
+    else if (referenceNode.Stat().ClustersPruned() +
+        visited[referenceNode.Descendant(0)] == centroids.n_cols)
+    {
+      for (size_t i = 0; i < referenceNode.NumPoints(); ++i)
+      {
+        const size_t cluster = assignments[referenceNode.Point(i)];
+        newCentroids.col(cluster) += dataset.col(referenceNode.Point(i));
+        counts(cluster)++;
+      }
+    }
+
+    return DBL_MAX;
+  }
+
+  return oldScore;
+}
+
+template<typename MetricType, typename TreeType>
+inline size_t DualTreeKMeansRules<MetricType, TreeType>::IterationUpdate(
+    TreeType& referenceNode) const
+{
+  if (referenceNode.Stat().Iteration() == iteration)
+    return 0;
+
+  // Compute how many iterations have passed since this node was last updated
+  // before overwriting the stored iteration number.
+  const size_t itDiff = iteration - referenceNode.Stat().Iteration();
+
+  referenceNode.Stat().Iteration() = iteration;
+  referenceNode.Stat().ClustersPruned() = (referenceNode.Parent() == NULL) ?
+      0 : referenceNode.Parent()->Stat().ClustersPruned();
+  referenceNode.Stat().ClosestQueryNode() = (referenceNode.Parent() == NULL) ?
+      NULL : referenceNode.Parent()->Stat().ClosestQueryNode();
+
+  if (referenceNode.Stat().ClosestQueryNode() != NULL)
+    referenceNode.Stat().MinQueryNodeDistance() =
+        referenceNode.MinDistance((TreeType*)
+        referenceNode.Stat().ClosestQueryNode());
+
+  if (itDiff > 1)
+  {
+    // Maybe this can be tighter?
+    referenceNode.Stat().MinQueryNodeDistance() = DBL_MAX;
+  }
+  else
+  {
+    if (referenceNode.Stat().MinQueryNodeDistance() != DBL_MAX)
+    {
+      // Update the distance to the closest query node.  If this node has an
+      // owner, we know how far to increase the bound.  Otherwise, increase it
+      // by the furthest amount that any centroid moved.
+//      if (referenceNode.Stat().Owner() < centroids.n_cols)
+//        referenceNode.Stat().MinQueryNodeDistance() +=
+//            clusterDistances(referenceNode.Stat().Owner());
+//      else
+//        referenceNode.Stat().MinQueryNodeDistance() = DBL_MAX;
+//            clusterDistances(centroids.n_cols);
+      if (referenceNode.Stat().MaxQueryNodeDistance() == DBL_MAX)
+        referenceNode.Stat().MinQueryNodeDistance() = DBL_MAX;
+      else
+      {
+        referenceNode.Stat().MinQueryNodeDistance() +=
+            clusterDistances(centroids.n_cols);
+//referenceNode.Stat().MaxQueryNodeDistance() +
+//clusterDistances(centroids.n_cols);
+      }
+    }
+  }
+
+  return 1;
+}
+
+template<typename MetricType, typename TreeType>
+bool DualTreeKMeansRules<MetricType, TreeType>::IsDescendantOf(
+    const TreeType& potentialParent,
+    const TreeType& potentialChild) const
+{
+  if (potentialChild.Parent() == &potentialParent)
+    return true;
+  else if (potentialChild.Parent() == NULL)
+    return false;
+  else
+    return IsDescendantOf(potentialParent, *potentialChild.Parent());
+}
+
+} // namespace kmeans
+} // namespace mlpack
+
+#endif
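
To spell out the prune used in Score() and Rescore(): with N_q the query
(centroid) node, N_r the reference node, and best(N_r) the closest query node
recorded in N_r's statistic, the prune fires when
d_min(N_q, best(N_r)) > 2 * d_max(N_r, best(N_r)).  A sketch of why that is
safe (the function and argument names here are illustrative only):

    // For any point p in N_r and any centroid q in N_q, pick any centroid b
    // in best(N_r).  By the triangle inequality,
    //
    //   d(p, q) >= d(q, b) - d(p, b)
    //           >= d_min(N_q, best(N_r)) - d_max(N_r, best(N_r))
    //           >  2 * d_max(N_r, best(N_r)) - d_max(N_r, best(N_r))
    //           =  d_max(N_r, best(N_r))
    //           >= d(p, b),
    //
    // so every point in N_r stays closer to some centroid in best(N_r) than
    // to any centroid in N_q, and N_q cannot own anything in N_r.
    inline bool CanPrune(const double minQueryDistance, // d_min(N_q, best(N_r))
                         const double maxDistance)      // d_max(N_r, best(N_r))
    {
      return minQueryDistance > 2.0 * maxDistance;
    }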
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
new file mode 100644
index 0000000..21481da
--- /dev/null
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_statistic.hpp
@@ -0,0 +1,96 @@
+/**
+ * @file dual_tree_kmeans_statistic.hpp
+ * @author Ryan Curtin
+ *
+ * Statistic for dual-tree k-means traversal.
+ */
+#ifndef __MLPACK_METHODS_KMEANS_DUAL_TREE_KMEANS_STATISTIC_HPP
+#define __MLPACK_METHODS_KMEANS_DUAL_TREE_KMEANS_STATISTIC_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace kmeans {
+
+class DualTreeKMeansStatistic
+{
+ public:
+  DualTreeKMeansStatistic() :
+      closestQueryNode(NULL),
+      minQueryNodeDistance(DBL_MAX),
+      maxQueryNodeDistance(DBL_MAX),
+      clustersPruned(0),
+      iteration(size_t() - 1)
+  { /* Nothing to do. */ }
+
+  template<typename TreeType>
+  DualTreeKMeansStatistic(TreeType& node) :
+      closestQueryNode(NULL),
+      minQueryNodeDistance(DBL_MAX),
+      maxQueryNodeDistance(DBL_MAX),
+      clustersPruned(0),
+      iteration(size_t() - 1)
+  {
+    // Empirically calculate the centroid.
+    centroid.zeros(node.Dataset().n_rows);
+    for (size_t i = 0; i < node.NumPoints(); ++i)
+      centroid += node.Dataset().col(node.Point(i));
+
+    for (size_t i = 0; i < node.NumChildren(); ++i)
+      centroid += node.Child(i).NumDescendants() *
+          node.Child(i).Stat().Centroid();
+
+    centroid /= node.NumDescendants();
+  }
+
+  //! Return the centroid.
+  const arma::vec& Centroid() const { return centroid; }
+  //! Modify the centroid.
+  arma::vec& Centroid() { return centroid; }
+
+  //! Get the current closest query node.
+  void* ClosestQueryNode() const { return closestQueryNode; }
+  //! Modify the current closest query node.
+  void*& ClosestQueryNode() { return closestQueryNode; }
+
+  //! Get the minimum distance to the closest query node.
+  double MinQueryNodeDistance() const { return minQueryNodeDistance; }
+  //! Modify the minimum distance to the closest query node.
+  double& MinQueryNodeDistance() { return minQueryNodeDistance; }
+
+  //! Get the maximum distance to the closest query node.
+  double MaxQueryNodeDistance() const { return maxQueryNodeDistance; }
+  //! Modify the maximum distance to the closest query node.
+  double& MaxQueryNodeDistance() { return maxQueryNodeDistance; }
+
+  //! Get the number of clusters that have been pruned during this iteration.
+  size_t ClustersPruned() const { return clustersPruned; }
+  //! Modify the number of clusters that have been pruned during this iteration.
+  size_t& ClustersPruned() { return clustersPruned; }
+
+  //! Get the current iteration.
+  size_t Iteration() const { return iteration; }
+  //! Modify the current iteration.
+  size_t& Iteration() { return iteration; }
+
+  //! Get the current owner (if any) of these reference points.
+  size_t Owner() const { return owner; }
+  //! Modify the current owner (if any) of these reference points.
+  size_t& Owner() { return owner; }
+
+ private:
+  //! The empirically calculated centroid of the node.
+  arma::vec centroid;
+
+  //! The current closest query node to this reference node.
+  void* closestQueryNode;
+  //! The minimum distance to the closest query node.
+  double minQueryNodeDistance;
+  //! The maximum distance to the closest query node.
+  double maxQueryNodeDistance;
+
+  //! The number of clusters that have been pruned.
+  size_t clustersPruned;
+  //! The current iteration.
+  size_t iteration;
+  //! The owner of these reference nodes (centroids.n_cols if there is no
+  //! owner).
+  size_t owner;
+};
+
+} // namespace kmeans
+} // namespace mlpack
+
+#endif
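
The statistic's constructor builds node centroids bottom-up: a node's own
points are summed directly, and each child contributes its already-computed
centroid weighted by its number of descendants.  Written out on its own (the
names below are illustrative, not part of the commit):

    #include <mlpack/core.hpp>
    #include <vector>

    // Combine a node's own points and its children's centroids into the
    // centroid of all of the node's descendant points.
    arma::vec CombineCentroids(const arma::mat& ownPoints,
                               const std::vector<arma::vec>& childCentroids,
                               const std::vector<size_t>& childDescendants)
    {
      arma::vec centroid;
      centroid.zeros(ownPoints.n_rows);
      size_t totalDescendants = ownPoints.n_cols;

      for (size_t i = 0; i < ownPoints.n_cols; ++i)
        centroid += ownPoints.col(i);

      for (size_t i = 0; i < childCentroids.size(); ++i)
      {
        centroid += childDescendants[i] * childCentroids[i];
        totalDescendants += childDescendants[i];
      }

      return centroid / totalDescendants;
    }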


