[mlpack-git] master: Refactor DTNNKMeans according to the new algorithm. Lots of stuff torn out. It'll go back in when the time is right. (1d2edb9)

gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:04:21 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit 1d2edb9c20dc8aa162a040cbbd9dfea04d4f8999
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Feb 16 17:18:57 2015 -0500

    Refactor DTNNKMeans according to the new algorithm. Lots of stuff torn out. It'll go back in when the time is right.


>---------------------------------------------------------------

1d2edb9c20dc8aa162a040cbbd9dfea04d4f8999
 src/mlpack/methods/kmeans/dtnn_kmeans.hpp      |  30 +-
 src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 467 +++----------------------
 src/mlpack/methods/kmeans/dtnn_statistic.hpp   |  36 +-
 3 files changed, 62 insertions(+), 471 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
index e6ccd58..d27988e 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
@@ -79,38 +79,26 @@ class DTNNKMeans
   //! Track iteration number.
   size_t iteration;
 
-  //! Centroids from pruning.  Not normalized.
-  arma::mat prunedCentroids;
-  //! Counts from pruning.  Not normalized.
-  arma::Col<size_t> prunedCounts;
-
-  //! Distances that the clusters moved last iteration.
-  arma::vec clusterDistances;
-
+  //! Upper bounds on nearest centroid.
+  arma::vec upperBounds;
   //! Lower bounds on second closest cluster distance for each point.
-  arma::vec lowerSecondBounds;
+  arma::vec lowerBounds;
   //! Indicator of whether or not the point is pruned.
   std::vector<bool> prunedPoints;
-  //! The last cluster each point was assigned to.
-  arma::Col<size_t> lastOwners;
 
-  arma::mat distances;
-  arma::Mat<size_t> assignments;
+  arma::Col<size_t> assignments;
 
   std::vector<bool> visited; // Was the point visited this iteration?
 
   arma::mat lastIterationCentroids; // For sanity checks.
 
   //! Update the bounds in the tree before the next iteration.
-  void UpdateTree(TreeType& node,
-                  const arma::mat& centroids,
-                  const arma::mat& interclusterDistances,
-                  const std::vector<size_t>& newFromOldCentroids);
-
-  void PrecalculateCentroids(TreeType& node);
+  void UpdateTree(TreeType& node);
 
-  void CoalesceTree(TreeType& node, const size_t child = 0);
-  void DecoalesceTree(TreeType& node);
+  //! Extract the centroids of the clusters.
+  void ExtractCentroids(TreeType& node,
+                        arma::mat& newCentroids,
+                        arma::Col<size_t>& newCounts);
 };
 
 //! A template typedef for the DTNNKMeans algorithm with the default tree type
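
[Editorial note: the header changes above replace the old pruned-centroid bookkeeping with a single upper bound (distance to the assigned centroid) and a single lower bound (distance to the second-closest centroid) per point, which is exactly the bookkeeping a Hamerly-style prune needs. As a minimal illustrative sketch -- not the committed DTNNKMeansRules code; the function and parameter names below are hypothetical -- the test these two bounds enable after the centroids move looks like this:]

    // Illustrative only: the per-point prune that a per-point upper bound and
    // lower bound make possible.  After an iteration, the upper bound grows by
    // at most the movement of the point's owner, and the lower bound shrinks
    // by at most the largest movement of any centroid.  If the adjusted upper
    // bound is still below the adjusted lower bound, the point cannot have
    // changed owners, so no distance calculations are needed for it.
    inline bool PointKeepsOwner(const double upperBound,    // to assigned centroid
                                const double lowerBound,    // to second-closest centroid
                                const double ownerMovement, // movement of the owner
                                const double maxMovement)   // largest centroid movement
    {
      return (upperBound + ownerMovement) < (lowerBound - maxMovement);
    }

[Points that fail this test fall through to the dual-tree traversal below, which recomputes and tightens their bounds.]
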
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 3fa25ab..bd7cb80 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -53,20 +53,12 @@ DTNNKMeans<MetricType, MatType, TreeType>::DTNNKMeans(const MatType& dataset,
     metric(metric),
     distanceCalculations(0),
     iteration(0),
-    distances(2, dataset.n_cols),
-    assignments(2, dataset.n_cols)
+    upperBounds(dataset.n_cols),
+    lowerBounds(dataset.n_cols),
+    prunedPoints(dataset.n_cols, false), // Fill with false.
+    assignments(dataset.n_cols),
+    visited(dataset.n_cols, false) // Fill with false.
 {
-  prunedPoints.resize(dataset.n_cols, false); // Fill with false.
-  lowerSecondBounds.zeros(dataset.n_cols);
-  lastOwners.zeros(dataset.n_cols);
-
-  assignments.set_size(2, dataset.n_cols);
-  assignments.fill(size_t(-1));
-  distances.set_size(2, dataset.n_cols);
-  distances.fill(DBL_MAX);
-
-  visited.resize(dataset.n_cols, false);
-
   Timer::Start("tree_building");
 
   // Copy the dataset, if necessary.
@@ -93,17 +85,15 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
     arma::mat& newCentroids,
     arma::Col<size_t>& counts)
 {
-  if (iteration == 0)
+  // Reset information.
+  upperBounds.fill(DBL_MAX);
+  lowerBounds.fill(DBL_MAX);
+  for (size_t i = 0; i < dataset.n_cols; ++i)
   {
-    prunedCentroids.zeros(centroids.n_rows, centroids.n_cols);
-    prunedCounts.zeros(centroids.n_cols);
-    // The last element stores the maximum.
-    clusterDistances.zeros(centroids.n_cols + 1);
+    prunedPoints[i] = false;
+    visited[i] = false;
   }
 
-  newCentroids.zeros(centroids.n_rows, centroids.n_cols);
-  counts.zeros(centroids.n_cols);
-
   // Build a tree on the centroids.
   arma::mat oldCentroids(centroids); // Slow. :(
   std::vector<size_t> oldFromNewCentroids;
@@ -118,6 +108,7 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
       newFromOldCentroids[oldFromNewCentroids[i]] = i;
   }
 
+/*
   Timer::Start("knn");
   // Find the nearest neighbors of each of the clusters.
   neighbor::NeighborSearch<neighbor::NearestNeighborSort, MetricType, TreeType>
@@ -127,50 +118,24 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
   nns.Search(1, closestClusters, interclusterDistances);
   distanceCalculations += nns.BaseCases() + nns.Scores();
   Timer::Stop("knn");
-
-  if (iteration != 0)
-  {
-    // Do the tree update for the previous iteration.
-
-    // Reset centroids and counts for things we will collect during pruning.
-    Timer::Start("it_update");
-    prunedCentroids.zeros(centroids.n_rows, centroids.n_cols);
-    prunedCounts.zeros(centroids.n_cols);
-    UpdateTree(*tree, oldCentroids, interclusterDistances, newFromOldCentroids);
-
-    PrecalculateCentroids(*tree);
-    Timer::Stop("it_update");
-  }
-
-  Timer::Start("tree_mod");
-  CoalesceTree(*tree);
-  Timer::Stop("tree_mod");
+*/
 
   // We won't use the AllkNN class here because we have our own set of rules.
-  // This is a lot of overhead.  We don't need the distances.
+  typedef typename DTNNKMeansRules<MetricType, TreeType> RuleType;
+  RuleType rules(centroids, dataset, assignments, upperBounds, lowerBounds,
+      metric, prunedPoints, oldFromNewCentroids, visited);
 
-  Timer::Start("tree_mod");
-  DecoalesceTree(*tree);
-  Timer::Stop("tree_mod");
+  typename TreeType::template BreadthFirstDualTreeTraverser<RuleType>
+      traverser(rules);
 
-//  Log::Info << "This iteration: " << rules.BaseCases() << " base cases, " <<
-//      rules.Scores() << " scores.\n";
-//  distanceCalculations += rules.BaseCases() + rules.Scores();
+  // Set the number of pruned centroids in the root to 0.
+  tree->Stat().Pruned() = 0;
+  traverser.Traverse(*tree, *centroidTree);
 
-  // From the assignments, calculate the new centroids and counts.
-  for (size_t i = 0; i < dataset.n_cols; ++i)
-  {
-    if (visited[i])
-    {
-      newCentroids.col(assignments(0, i)) += dataset.col(i);
-      ++counts(assignments(0, i));
-      // Reset for next iteration.
-      visited[i] = false;
-    }
-  }
-
-  newCentroids += prunedCentroids;
-  counts += prunedCounts;
+  // Now we need to extract the clusters.
+  newCentroids.zeros(centroids.n_rows, centroids.n_cols);
+  counts.zeros(centroids.n_cols);
+  ExtractCentroids(*tree, newCentroids, counts);
 
   // Now, calculate how far the clusters moved, after normalizing them.
   double residual = 0.0;
@@ -199,8 +164,6 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
   }
   distanceCalculations += centroids.n_cols;
 
-//  lastIterationCentroids = oldCentroids;
-
   delete centroidTree;
 
   ++iteration;
@@ -210,376 +173,44 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
 
 template<typename MetricType, typename MatType, typename TreeType>
 void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
-    TreeType& node,
-    const arma::mat& centroids,
-    const arma::mat& interclusterDistances,
-    const std::vector<size_t>& newFromOldCentroids)
+    TreeType& node)
 {
-/*
-  // Update iteration.
-//  node.Stat().Iteration() = iteration;
+  // Simply reset the bounds.
+  node.Stat().UpperBound() = DBL_MAX;
+  node.Stat().LowerBound() = DBL_MAX;
+  node.Stat().Pruned() = size_t(-1);
+  node.Stat().Owner() = size_t(-1);
 
-  if (node.Stat().Owner() == size_t(-1))
-    node.Stat().Owner() = centroids.n_cols;
-
-  // Do the tree update in a depth-first manner: leaves first.
-  bool childrenPruned = true;
   for (size_t i = 0; i < node.NumChildren(); ++i)
-  {
-    UpdateTree(node.Child(i), centroids, interclusterDistances,
-        newFromOldCentroids);
-    if (!node.Child(i).Stat().Pruned())
-      childrenPruned = false; // Not all children are pruned.
-  }
-
-  const bool prunedLastIteration = node.Stat().Pruned();
-
-  // Does the node have a single owner?
-  // It would be nice if we could do this during the traversal.
-  bool singleOwner = true;
-  size_t owner = centroids.n_cols + 1;
-  if (!node.Stat().Pruned() && childrenPruned)
-  {
-    // Determine the bounds for the points.
-    double newMaxClusterDistance = 0.0;
-    double newSecondClusterBound = DBL_MAX;
-    for (size_t i = 0; i < node.NumPoints(); ++i)
-    {
-      // Don't forget to map back from the new cluster index.
-      size_t c;
-      if (!prunedPoints[node.Point(i)])
-        c = assignments(0, node.Point(i));
-      else
-        c = lastOwners[node.Point(i)];
-
-      if (owner == centroids.n_cols + 1)
-        owner = c;
-      else if (owner != c)
-      {
-        singleOwner = false;
-        break;
-      }
-
-      // Update maximum cluster distance and second cluster bound.
-      if (!prunedPoints[node.Point(i)])
-      {
-        if (distances(0, node.Point(i)) > newMaxClusterDistance)
-          newMaxClusterDistance = distances(0, node.Point(i));
-        if (distances(1, node.Point(i)) < newSecondClusterBound)
-          newSecondClusterBound = distances(1, node.Point(i));
-      }
-      else
-      {
-        // Use the cached bounds.
-        if (distances(0, node.Point(i)) > newMaxClusterDistance)
-          newMaxClusterDistance = distances(0, node.Point(i));
-        if (lowerSecondBounds[node.Point(i)] < newSecondClusterBound)
-          newSecondClusterBound = lowerSecondBounds[node.Point(i)];
-      }
-    }
-
-    for (size_t i = 0; i < node.NumChildren(); ++i)
-    {
-      if (owner == centroids.n_cols + 1)
-        owner = node.Child(i).Stat().Owner();
-      else if ((node.Child(i).Stat().Owner() == centroids.n_cols) ||
-               (owner != node.Child(i).Stat().Owner()))
-      {
-        singleOwner = false;
-        break;
-      }
-
-      // Update maximum cluster distance and second cluster bound.
-      if (node.Child(i).Stat().MaxClusterDistance() > newMaxClusterDistance)
-        newMaxClusterDistance = node.Child(i).Stat().MaxClusterDistance();
-      if (node.Child(i).Stat().SecondClusterBound() < newSecondClusterBound)
-        newSecondClusterBound = node.Child(i).Stat().SecondClusterBound();
-    }
-
-    // Okay, now we know if it's owned or not, and by which cluster.
-    if (singleOwner)
-    {
-      node.Stat().Owner() = owner;
-
-      // What do we do with the new cluster bounds?
-      if (newMaxClusterDistance > 0.0 && newMaxClusterDistance <
-          node.Stat().MaxClusterDistance())
-        node.Stat().MaxClusterDistance() = newMaxClusterDistance;
-      if (newSecondClusterBound != DBL_MAX && newSecondClusterBound >
-          node.Stat().SecondClusterBound())
-        node.Stat().SecondClusterBound() = newSecondClusterBound;
-
-      // Convenience variables to clean up the expressions.
-      const double mcd = node.Stat().MaxClusterDistance();
-      const double scb = node.Stat().SecondClusterBound();
-      const double ownerMovement = clusterDistances[owner];
-      const double maxMovement = clusterDistances[centroids.n_cols];
-      const double closestClusterDistance =
-          interclusterDistances[newFromOldCentroids[owner]];
-      if ((node.NumPoints() == 0 && childrenPruned) ||
-          (mcd + ownerMovement < scb - maxMovement) ||
-          (mcd < 0.5 * closestClusterDistance))
-        node.Stat().Pruned() = true;
-
-      if (!node.Stat().Pruned() && (mcd - ownerMovement) < (scb - maxMovement))
-      {
-        // Calculate the next MCD by hand.
-        const double newDist = node.MaxDistance(centroids.col(owner));
-        ++distanceCalculations;
-        node.Stat().MaxClusterDistance() = newDist;
-
-        if ((newDist < scb - maxMovement) ||
-            (newDist < 0.5 * closestClusterDistance))
-          node.Stat().Pruned() = true;
-        else
-          node.Stat().SecondClusterBound() -= maxMovement;
-      }
-      else
-      {
-        // Adjust bounds for next iteration, regardless of whether or not the
-        // node was pruned.  (Does this adjustment need to happen if there is no
-        // prune?
-        node.Stat().MaxClusterDistance() += ownerMovement;
-        node.Stat().SecondClusterBound() -= maxMovement;
-      }
-    }
-    else if (childrenPruned && node.NumChildren() > 0 && node.NumPoints() == 0)
-    {
-      // The node isn't owned by a single cluster.  But if it has no points and
-      // its children are all pruned, we may prune it too.
-      node.Stat().Pruned() = true;
-      node.Stat().Owner() = centroids.n_cols;
-    }
-  }
-  else if (node.Stat().Pruned())
-  {
-    // The node was pruned last iteration.  See if the node can remain pruned.
-    singleOwner = false;
-
-    // If it was pruned because all points were pruned, we need to check
-    // individually.
-    if (node.Stat().Owner() == centroids.n_cols)
-    {
-      node.Stat().Pruned() = false;
-    }
-    else
-    { 
-      // Will our bounds still work?
-      if (node.Stat().MaxClusterDistance() +
-          clusterDistances[node.Stat().Owner()] <
-          node.Stat().SecondClusterBound() - clusterDistances[centroids.n_cols])
-      {
-        // The node remains pruned.  Adjust the bounds for next iteration.
-        node.Stat().MaxClusterDistance() +=
-            clusterDistances[node.Stat().Owner()];
-        node.Stat().SecondClusterBound() -= clusterDistances[centroids.n_cols];
-      }
-      else
-      {
-        // Attempt other prune.
-        if (node.Stat().MaxClusterDistance() < 0.5 *
-            interclusterDistances[newFromOldCentroids[node.Stat().Owner()]])
-        {
-          // The node remains pruned.  Adjust the bounds for next iteration.
-          node.Stat().MaxClusterDistance() +=
-              clusterDistances[node.Stat().Owner()];
-          node.Stat().SecondClusterBound() -= clusterDistances[centroids.n_cols];
-        }
-        else
-        {
-          node.Stat().Pruned() = false;
-          node.Stat().MaxClusterDistance() = DBL_MAX;
-          node.Stat().SecondClusterBound() = 0.0;
-        }
-      }
-    }
-  }
-  else
-  {
-    // The children haven't been pruned, so we can't.
-    // This node was not pruned last iteration, so we simply need to adjust the
-    // bounds.
-    node.Stat().Owner() = centroids.n_cols;
-    if (node.Stat().MaxClusterDistance() != DBL_MAX)
-      node.Stat().MaxClusterDistance() += clusterDistances[centroids.n_cols];
-    if (node.Stat().SecondClusterBound() != DBL_MAX)
-      node.Stat().SecondClusterBound() = std::max(0.0,
-          node.Stat().SecondClusterBound() -
-          clusterDistances[centroids.n_cols]);
-  }
-
-  // If the node wasn't pruned, try to prune individual points.
-  if (!node.Stat().Pruned())
-  {
-    bool allPruned = true;
-    for (size_t i = 0; i < node.NumPoints(); ++i)
-    {
-      const size_t index = node.Point(i);
-      size_t owner;
-      if (prunedLastIteration && node.Stat().Owner() < centroids.n_cols)
-        owner = node.Stat().Owner();
-      else
-        owner = assignments(0, index);
-
-      // Update lower bound, if possible.
-      if (!prunedLastIteration && !prunedPoints[index])
-        lowerSecondBounds[index] = distances(1, index);
-
-      const double upperPointBound = distances(0, index) +
-          clusterDistances[owner];
-      const double lowerSecondBound = lowerSecondBounds[index] -
-          clusterDistances[centroids.n_cols];
-      const double closestClusterDistance =
-          interclusterDistances[newFromOldCentroids[owner]];
-      if ((upperPointBound < lowerSecondBound) ||
-          (upperPointBound < 0.5 * closestClusterDistance))
-      {
-        prunedPoints[index] = true;
-        distances(0, index) += clusterDistances[owner];
-        lastOwners[index] = owner;
-        distances(1, index) += clusterDistances[centroids.n_cols];
-        lowerSecondBounds[index] -= clusterDistances[centroids.n_cols];
-        prunedCentroids.col(owner) += dataset.col(index);
-        prunedCounts(owner)++;
-      }
-      else
-      {
-        // Attempt to tighten the lower bound.
-        distances(0, index) = metric.Evaluate(centroids.col(owner),
-                                             dataset.col(index));
-        ++distanceCalculations;
-
-        if ((distances(0, index) < lowerSecondBound) ||
-            (distances(0, index) < 0.5 * closestClusterDistance))
-        {
-          prunedPoints[index] = true;
-          lastOwners[index] = owner;
-          lowerSecondBounds[index] -= clusterDistances[centroids.n_cols];
-          distances(1, index) += clusterDistances[centroids.n_cols];
-          prunedCentroids.col(owner) += dataset.col(index);
-          prunedCounts(owner)++;
-        }
-        else
-        {
-          prunedPoints[index] = false;
-          allPruned = false;
-          // Still update these anyway.
-          distances(1, index) += clusterDistances[centroids.n_cols];
-        }
-      }
-    }
-
-    if (allPruned && node.NumPoints() > 0)
-    {
-      // Prune the entire node.
-      node.Stat().Pruned() = true;
-      node.Stat().Owner() = centroids.n_cols;
-    }
-  }
-
-  if (node.Stat().Pruned())
-  {
-    // Update bounds.
-    for (size_t i = 0; i < node.NumPoints(); ++i)
-    {
-      const size_t index = node.Point(i);
-      lowerSecondBounds[index] -= clusterDistances[node.Stat().Owner()];
-    }
-  }
-
-  // Make sure all the point bounds are updated.
-  for (size_t i = 0; i < node.NumPoints(); ++i)
-  {
-    const size_t index = node.Point(i);
-    distances(0, index) += clusterDistances[assignments(0, index)];
-    distances(1, index) += clusterDistances[assignments(1, index)];
-  }
-
-  if (node.Stat().FirstBound() != DBL_MAX)
-    node.Stat().FirstBound() += clusterDistances[centroids.n_cols];
-  if (node.Stat().SecondBound() != DBL_MAX)
-    node.Stat().SecondBound() += clusterDistances[centroids.n_cols];
-  if (node.Stat().Bound() != DBL_MAX)
-    node.Stat().Bound() += clusterDistances[centroids.n_cols];
-*/
+    UpdateTree(node.Child(i));
 }
 
 template<typename MetricType, typename MatType, typename TreeType>
-void DTNNKMeans<MetricType, MatType, TreeType>::CoalesceTree(
+void DTNNKMeans<MetricType, MatType, TreeType>::ExtractCentroids(
     TreeType& node,
-    const size_t child /* Which child are we? */)
+    arma::mat& newCentroids,
+    arma::Col<size_t>& newCounts)
 {
-/*
-  // If one of the two children is pruned, we hide this node.
-  // This assumes the BinarySpaceTree.  (bad Ryan! bad!)
-  if (node.NumChildren() == 0)
-    return; // We can't do anything.
-
-  // If this is the root node, we can't coalesce.
-  if (node.Parent() != NULL)
+  // Does this node own points?
+  if (node.Stat().Owner() < newCentroids.n_cols)
   {
-    if (node.Child(0).Stat().Pruned() && !node.Child(1).Stat().Pruned())
-    {
-      CoalesceTree(node.Child(1), 1);
-
-      // Link the right child to the parent.
-      node.Child(1).Parent() = node.Parent();
-      node.Parent()->ChildPtr(child) = node.ChildPtr(1);
-    }
-    else if (!node.Child(0).Stat().Pruned() && node.Child(1).Stat().Pruned())
-    {
-      CoalesceTree(node.Child(0), 0);
-
-      // Link the left child to the parent.
-      node.Child(0).Parent() = node.Parent();
-      node.Parent()->ChildPtr(child) = node.ChildPtr(0);
-
-    }
-    else if (!node.Child(0).Stat().Pruned() && !node.Child(1).Stat().Pruned())
-    {
-      // The conditional is probably not necessary.
-      CoalesceTree(node.Child(0), 0);
-      CoalesceTree(node.Child(1), 1);
-    }
+    const size_t owner = node.Stat().Owner();
+    newCentroids.col(owner) += node.Stat().Centroid() * node.NumDescendants();
+    newCounts[owner] += node.NumDescendants();
   }
   else
   {
-    CoalesceTree(node.Child(0), 0);
-    CoalesceTree(node.Child(1), 1);
-  }
-*/
-}
-
-template<typename MetricType, typename MatType, typename TreeType>
-void DTNNKMeans<MetricType, MatType, TreeType>::DecoalesceTree(TreeType& node)
-{
-/*
-  node.Parent() = (TreeType*) node.Stat().TrueParent();
-  node.ChildPtr(0) = (TreeType*) node.Stat().TrueLeft();
-  node.ChildPtr(1) = (TreeType*) node.Stat().TrueRight();
-
-  if (node.NumChildren() > 0)
-  {
-    DecoalesceTree(node.Child(0));
-    DecoalesceTree(node.Child(1));
-  }
-*/
-}
+    // Check each point held in the node.
+    for (size_t i = 0; i < node.NumPoints(); ++i)
+    {
+      const size_t owner = assignments[node.Point(i)];
+      newCentroids.col(owner) += dataset.col(node.Point(i));
+      ++newCounts[owner];
+    }
 
-template<typename MetricType, typename MatType, typename TreeType>
-void DTNNKMeans<MetricType, MatType, TreeType>::PrecalculateCentroids(
-    TreeType& node)
-{
-  if (node.Stat().Pruned() && node.Stat().Owner() < prunedCentroids.n_cols)
-  {
-    prunedCentroids.col(node.Stat().Owner()) += node.Stat().Centroid() *
-        node.NumDescendants();
-    prunedCounts(node.Stat().Owner()) += node.NumDescendants();
-  }
-  else
-  {
+    // The node is not entirely owned by a cluster.  Recurse.
     for (size_t i = 0; i < node.NumChildren(); ++i)
-      PrecalculateCentroids(node.Child(i));
+      ExtractCentroids(node.Child(i), newCentroids, newCounts);
   }
 }
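
[Editorial note: the hunk context above elides the normalization step that follows ExtractCentroids(). As a rough standalone sketch of what "calculate how far the clusters moved, after normalizing them" amounts to -- illustrative Armadillo code, not the committed body; the empty-cluster handling here is an assumption:]

    #include <armadillo>
    #include <cmath>

    // Illustrative sketch: turn the per-cluster sums accumulated by
    // ExtractCentroids() into means, then measure how far each centroid moved.
    // The value returned is the residual that Iterate() hands back to the
    // caller to decide convergence.
    double NormalizeAndComputeResidual(const arma::mat& oldCentroids,
                                       arma::mat& newCentroids, // sums on entry
                                       const arma::Col<size_t>& counts)
    {
      double residual = 0.0;
      for (size_t c = 0; c < newCentroids.n_cols; ++c)
      {
        if (counts[c] > 0)
          newCentroids.col(c) /= counts[c];          // sum -> mean
        else
          newCentroids.col(c) = oldCentroids.col(c); // assumption: empty clusters stay put

        const double movement =
            arma::norm(newCentroids.col(c) - oldCentroids.col(c), 2);
        residual += movement * movement;
      }
      return std::sqrt(residual);
    }

[This is also the step where the commit charges one distance calculation per centroid: distanceCalculations += centroids.n_cols.]
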
 
diff --git a/src/mlpack/methods/kmeans/dtnn_statistic.hpp b/src/mlpack/methods/kmeans/dtnn_statistic.hpp
index dfeedbb..b81a17a 100644
--- a/src/mlpack/methods/kmeans/dtnn_statistic.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_statistic.hpp
@@ -22,10 +22,7 @@ class DTNNStatistic : public
       lowerBound(DBL_MAX),
       owner(size_t(-1)),
       pruned(size_t(-1)),
-      centroid(),
-      trueLeft(NULL),
-      trueRight(NULL),
-      trueParent(NULL)
+      centroid()
   {
     // Nothing to do.
   }
@@ -36,8 +33,7 @@ class DTNNStatistic : public
       upperBound(DBL_MAX),
       lowerBound(DBL_MAX),
       owner(size_t(-1)),
-      pruned(size_t(-1)),
-      trueParent((void*) node.Parent())
+      pruned(size_t(-1))
   {
     // Empirically calculate the centroid.
     centroid.zeros(node.Dataset().n_rows);
@@ -49,18 +45,6 @@ class DTNNStatistic : public
           node.Child(i).Stat().Centroid();
 
     centroid /= node.NumDescendants();
-
-    // Do we have children?
-    if (node.NumChildren() >= 2)
-    {
-      trueLeft = &node.Child(0);
-      trueRight = &node.Child(1);
-    }
-    else
-    {
-      trueLeft = NULL;
-      trueRight = NULL;
-    }
   }
 
   double UpperBound() const { return upperBound; }
@@ -72,20 +56,11 @@ class DTNNStatistic : public
   const arma::vec& Centroid() const { return centroid; }
   arma::vec& Centroid() { return centroid; }
 
-  size_t Pruned() const { return pruned; }
-  size_t& Pruned() { return pruned; }
-
   size_t Owner() const { return owner; }
   size_t& Owner() { return owner; }
 
-  const void* TrueLeft() const { return trueLeft; }
-  void*& TrueLeft() { return trueLeft; }
-
-  const void* TrueRight() const { return trueRight; }
-  void*& TrueRight() { return trueRight; }
-
-  const void* TrueParent() const { return trueParent; }
-  void*& TrueParent() { return trueParent; }
+  size_t Pruned() const { return pruned; }
+  size_t& Pruned() { return pruned; }
 
   std::string ToString() const
   {
@@ -104,9 +79,6 @@ class DTNNStatistic : public
   size_t owner;
   size_t pruned;
   arma::vec centroid;
-  void* trueLeft;
-  void* trueRight;
-  void* trueParent;
 };
 
 } // namespace kmeans
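
[Editorial note: with the coalescing pointers (TrueLeft/TrueRight/TrueParent) gone, the statistic is reduced to two bounds, an owner, a pruned counter, and the empirically computed centroid. Pruned() also changes meaning in this commit: Iterate() seeds the root with tree->Stat().Pruned() = 0, so it now reads as a count of pruned centroids rather than the old boolean flag. Below is a hypothetical sketch -- the real logic belongs to DTNNKMeansRules, which is not part of this diff -- of how that counter and Owner() feed ExtractCentroids():]

    // Hypothetical sketch, not the committed rules code: once every centroid
    // but one has been pruned for a node, record the survivor as the owner.
    // ExtractCentroids() then adds the node's cached centroid times
    // NumDescendants() in one step instead of visiting each descendant point.
    template<typename TreeType>
    void MarkOwnerIfDecided(TreeType& node,
                            const size_t clusters,         // total number of centroids
                            const size_t survivingCluster) // the one centroid left
    {
      if (node.Stat().Pruned() == clusters - 1)
        node.Stat().Owner() = survivingCluster;
    }
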


