[mlpack-git] master: Refactor BaseCase() to apply mappings. This reduces memory usage. (acd8db8)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:03:54 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit acd8db891c8c67b5c1455cde20ed60ae9b4acd93
Author: Ryan Curtin <ryan at ratml.org>
Date:   Tue Feb 3 11:14:41 2015 -0500

    Refactor BaseCase() to apply mappings. This reduces memory usage.


>---------------------------------------------------------------

acd8db891c8c67b5c1455cde20ed60ae9b4acd93
 src/mlpack/methods/kmeans/dtnn_kmeans.hpp          |   2 +-
 src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp     | 126 +++++++++++----------
 src/mlpack/methods/kmeans/dtnn_rules.hpp           |  30 ++---
 src/mlpack/methods/kmeans/dtnn_rules_impl.hpp      |  78 +++++++++++--
 .../neighbor_search/neighbor_search_rules.hpp      |   2 +-
 5 files changed, 146 insertions(+), 92 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
index b8c2792..b939f32 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
@@ -99,7 +99,7 @@ class DTNNKMeans
   arma::mat distances;
   arma::Mat<size_t> assignments;
 
-  std::vector<size_t> lastOldFromNewCentroids;
+  arma::mat lastIterationCentroids; // For sanity checks.
 
   //! Update the bounds in the tree before the next iteration.
   void UpdateTree(TreeType& node,
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 1f48545..95546d3 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -139,7 +139,7 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
   assignments.fill(size_t(-1));
   typedef DTNNKMeansRules<MetricType, TreeType> RuleType;
   RuleType rules(centroids, dataset, assignments, distances, metric,
-      prunedPoints);
+      prunedPoints, oldFromNewCentroids);
 
   // Now construct the traverser ourselves.
   typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
@@ -155,17 +155,8 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
   {
     if (assignments(0, i) != size_t(-1))
     {
-      if (tree::TreeTraits<TreeType>::RearrangesDataset)
-      {
-        newCentroids.col(oldFromNewCentroids[assignments(0, i)]) +=
-            dataset.col(i);
-        ++counts(oldFromNewCentroids[assignments(0, i)]);
-      }
-      else
-      {
-        newCentroids.col(assignments(0, i)) += dataset.col(i);
-        ++counts(assignments(0, i));
-      }
+      newCentroids.col(assignments(0, i)) += dataset.col(i);
+      ++counts(assignments(0, i));
     }
   }
 
@@ -200,7 +191,7 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
   clusterDistances[centroids.n_cols] = maxMovement;
   distanceCalculations += centroids.n_cols;
 
-  lastOldFromNewCentroids = oldFromNewCentroids;
+//  lastIterationCentroids = oldCentroids;
 
   delete centroidTree;
 
@@ -248,9 +239,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
       // Don't forget to map back from the new cluster index.
       size_t c;
       if (!prunedPoints[node.Point(i)])
-        c = (tree::TreeTraits<TreeType>::RearrangesDataset) ?
-            lastOldFromNewCentroids[assignments(0, node.Point(i))] :
-            assignments(0, node.Point(i));
+        c = assignments(0, node.Point(i));
       else
         c = lastOwners[node.Point(i)];
 
@@ -310,30 +299,42 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
 /*
       for (size_t i = 0; i < node.NumPoints(); ++i)
       {
-        const double ownerDist = metric.Evaluate(dataset.col(node.Point(i)),
-            centroids.col(owner));
+        arma::vec dists(centroids.n_cols);
+        size_t trueOwner = centroids.n_cols;
+        double trueDist = DBL_MAX;
         for (size_t j = 0; j < centroids.n_cols; ++j)
         {
           const double dist = metric.Evaluate(dataset.col(node.Point(i)),
-              centroids.col(j));
-          if (dist < ownerDist)
+              lastIterationCentroids.col(j));
+          dists(j) = dist;
+          if (dist < trueDist)
           {
-            Log::Warn << node << "...\n" << *node.Parent();
+            trueDist = dist;
+            trueOwner = j;
+          }
+        }
+
+        if (trueOwner != owner)
+        {
+          Log::Warn << node << "...\n" << *node.Parent();
+          Log::Warn << dists.t();
+          Log::Warn << "Assignment: " << assignments(0, node.Point(i)) << ".\n";
+          Log::Warn << "Dists: " << distances(0, node.Point(i)) << ", " <<
+distances(1, node.Point(i)) << ".\n";
 //            TreeType* n = node.Parent()->Parent();
 //            while (n != NULL)
 //            {
 //              Log::Warn << "...\n" << *n;
 //              n = n->Parent();
 //            }
-            Log::Fatal << "Point " << node.Point(i) << " was assigned to owner "
-                << owner << " but has true owner " << j << "! [" <<
-oldFromNewCentroids[assignments(0, node.Point(i))] << " -- " <<
+          Log::Fatal << "Point " << node.Point(i) << " was assigned to owner "
+                << owner << " but has true owner " << trueOwner << "! [" <<
+assignments(0, node.Point(i)) << " -- " <<
 metric.Evaluate(dataset.col(node.Point(i)),
-centroids.col(oldFromNewCentroids[assignments(0, node.Point(i))])) << "] " <<
+centroids.col(assignments(0, node.Point(i)))) << "] " <<
 distances(0, node.Point(i)) << " " <<
-oldFromNewCentroids[assignments(0, node.Point(i))] << " " <<
-oldFromNewCentroids[assignments(0, node.Point(i - 1))] << ".\n";
-          }
+assignments(0, node.Point(i)) << " " <<
+assignments(0, node.Point(i - 1)) << ".\n";
         }
       }
 */
@@ -400,34 +401,6 @@ oldFromNewCentroids[assignments(0, node.Point(i - 1))] << ".\n";
     // The node was pruned last iteration.  See if the node can remain pruned.
     singleOwner = false;
 
-/*
-      for (size_t i = 0; i < node.NumPoints(); ++i)
-      {
-        size_t trueOwner = 0;
-        double ownerDist = DBL_MAX;
-        arma::vec distances(centroids.n_cols);
-        for (size_t j = 0; j < centroids.n_cols; ++j)
-        {
-          const double dist = metric.Evaluate(dataset.col(node.Point(i)),
-              centroids.col(j));
-          distances(j) = dist;
-          if (dist < ownerDist)
-          {
-            trueOwner = j;
-            ownerDist = dist;
-          }
-        }
-
-        if (trueOwner != node.Stat().Owner())
-        {
-            Log::Warn << node << "...\n" << *node.Parent();
-            Log::Warn << distances.t();
-            Log::Fatal << "Point " << node.Point(i) << " was assigned to owner "
-                << node.Stat().Owner() << " but has true owner " << trueOwner <<
-"!\n";
-        }
-      }*/
-
     // If it was pruned because all points were pruned, we need to check
     // individually.
     if (node.Stat().Owner() == centroids.n_cols)
@@ -465,6 +438,37 @@ oldFromNewCentroids[assignments(0, node.Point(i - 1))] << ".\n";
         }
       }
     }
+/*
+    if (node.Stat().Pruned() && node.Stat().Owner() != centroids.n_cols)
+    {
+      for (size_t i = 0; i < node.NumPoints(); ++i)
+      {
+        size_t trueOwner = 0;
+        double ownerDist = DBL_MAX;
+        arma::vec distances(centroids.n_cols);
+        for (size_t j = 0; j < centroids.n_cols; ++j)
+        {
+          const double dist = metric.Evaluate(dataset.col(node.Point(i)),
+              lastIterationCentroids.col(j));
+          distances(j) = dist;
+          if (dist < ownerDist)
+          {
+            trueOwner = j;
+            ownerDist = dist;
+          }
+        }
+
+        if (trueOwner != node.Stat().Owner())
+        {
+            Log::Warn << node << "...\n" << *node.Parent();
+            Log::Warn << distances.t();
+            Log::Fatal << "Point " << node.Point(i) << " was assigned to owner "
+                << node.Stat().Owner() << " but has true owner " << trueOwner <<
+"!\n";
+        }
+      }
+    }
+*/
   }
   else
   {
@@ -490,9 +494,7 @@ oldFromNewCentroids[assignments(0, node.Point(i - 1))] << ".\n";
       size_t owner;
       if (!prunedLastIteration && !prunedPoints[index])
       {
-        owner = (tree::TreeTraits<TreeType>::RearrangesDataset) ?
-            lastOldFromNewCentroids[assignments(0, index)] :
-            assignments(0, index);
+        owner = assignments(0, index);
         // Establish bounds, since these points were searched this iteration.
         upperBounds[index] = distances(0, index);
         lowerSecondBounds[index] = distances(1, index);
@@ -516,7 +518,7 @@ oldFromNewCentroids[assignments(0, node.Point(i - 1))] << ".\n";
         arma::vec distances(centroids.n_cols);
         for (size_t j = 0; j < centroids.n_cols; ++j)
         {
-          const double dist = metric.Evaluate(centroids.col(j),
+          const double dist = metric.Evaluate(lastIterationCentroids.col(j),
                                               dataset.col(index));
           distances(j) = dist;
           if (dist < trueDist)
@@ -535,8 +537,8 @@ prunedPoints[index] << ", lastOwner " << lastOwners[index] << ": invalid "
           Log::Warn << distances.t();
           Log::Fatal << "Assigned owner " << owner << " but true owner is "
               << trueOwner << "!\n";
-        }*/
-
+        }
+*/
         prunedPoints[index] = true;
         upperBounds[index] += clusterDistances[owner];
         lastOwners[index] = owner;
diff --git a/src/mlpack/methods/kmeans/dtnn_rules.hpp b/src/mlpack/methods/kmeans/dtnn_rules.hpp
index da2a63d..2ded3ea 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules.hpp
@@ -15,15 +15,17 @@ namespace mlpack {
 namespace kmeans {
 
 template<typename MetricType, typename TreeType>
-class DTNNKMeansRules
+class DTNNKMeansRules : public neighbor::NeighborSearchRules<
+    neighbor::NearestNeighborSort, MetricType, TreeType>
 {
  public:
   DTNNKMeansRules(const arma::mat& centroids,
-                      const arma::mat& dataset,
-                      arma::Mat<size_t>& neighbors,
-                      arma::mat& distances,
-                      MetricType& metric,
-                      const std::vector<bool>& prunedPoints);
+                  const arma::mat& dataset,
+                  arma::Mat<size_t>& neighbors,
+                  arma::mat& distances,
+                  MetricType& metric,
+                  const std::vector<bool>& prunedPoints,
+                  const std::vector<size_t>& oldFromNewCentroids);
 
   double BaseCase(const size_t queryIndex, const size_t referenceIndex);
 
@@ -36,23 +38,11 @@ class DTNNKMeansRules
                  TreeType& referenceNode,
                  const double oldScore);
 
-  typedef neighbor::NeighborSearchTraversalInfo<TreeType> TraversalInfoType;
-
-  size_t Scores() const { return rules.Scores(); }
-  size_t& Scores() { return rules.Scores(); }
-  size_t BaseCases() const { return rules.BaseCases(); }
-  size_t& BaseCases() { return rules.BaseCases(); }
-
-  const TraversalInfoType& TraversalInfo() const
-  { return rules.TraversalInfo(); }
-  TraversalInfoType& TraversalInfo() { return rules.TraversalInfo(); }
-
  private:
 
-  typename neighbor::NeighborSearchRules<neighbor::NearestNeighborSort,
-      MetricType, TreeType> rules;
-
   const std::vector<bool>& prunedPoints;
+
+  const std::vector<size_t>& oldFromNewCentroids;
 };
 
 } // namespace kmeans
diff --git a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
index bce7d47..11c2c3e 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
@@ -19,9 +19,12 @@ DTNNKMeansRules<MetricType, TreeType>::DTNNKMeansRules(
     arma::Mat<size_t>& neighbors,
     arma::mat& distances,
     MetricType& metric,
-    const std::vector<bool>& prunedPoints) :
-    rules(centroids, dataset, neighbors, distances, metric),
-    prunedPoints(prunedPoints)
+    const std::vector<bool>& prunedPoints,
+    const std::vector<size_t>& oldFromNewCentroids) :
+    neighbor::NeighborSearchRules<neighbor::NearestNeighborSort, MetricType,
+        TreeType>(centroids, dataset, neighbors, distances, metric),
+    prunedPoints(prunedPoints),
+    oldFromNewCentroids(oldFromNewCentroids)
 {
   // Nothing to do.
 }
@@ -35,7 +38,62 @@ inline force_inline double DTNNKMeansRules<MetricType, TreeType>::BaseCase(
   if (prunedPoints[queryIndex])
     return 0.0; // Returning 0 shouldn't be a problem.
 
-  return rules.BaseCase(queryIndex, referenceIndex);
+  // This is basically an inlined NeighborSearchRules::BaseCase(), but it
+  // differs in that it records oldFromNewCentroids[referenceIndex] as the
+  // neighbor, so results are mapped already.  We can also skip a check or two.
+
+  // By the way, each 'this->' is necessary because the parent class is a
+  // dependent base class, so unqualified lookup cannot find its members
+  // before type substitution.  The 'this->' turns that member into a dependent
+  // name too (since the type of 'this' is dependent), and thus the compiler
+  // defers the lookup until instantiation and we get no error.  Hooray C++!
+  //
+  // See also:
+  // http://stackoverflow.com/questions/10639053/name-lookups-in-c-templates
+
+  // If we have already performed this base case, do not perform it again.
+  if ((this->lastQueryIndex == queryIndex) &&
+      (this->lastReferenceIndex == referenceIndex))
+    return this->lastBaseCase;
+
+  double distance = this->metric.Evaluate(this->querySet.col(queryIndex),
+      this->referenceSet.col(referenceIndex));
+  ++this->baseCases;
+
+  const size_t cluster = oldFromNewCentroids[referenceIndex];
+
+  // Is this better than either existing candidate?
+  if (distance < this->distances(0, queryIndex))
+  {
+    // Do we need to replace the assignment, or is it an old assignment from a
+    // previous iteration?
+    if (this->neighbors(0, queryIndex) != cluster &&
+        this->neighbors(0, queryIndex) < this->referenceSet.n_cols)
+    {
+      // We must push the old closest assignment down the stack.
+      this->neighbors(1, queryIndex) = this->neighbors(0, queryIndex);
+      this->distances(1, queryIndex) = this->distances(0, queryIndex);
+      this->neighbors(0, queryIndex) = cluster;
+    }
+    else if (this->neighbors(0, queryIndex) >= this->referenceSet.n_cols)
+    {
+      this->neighbors(0, queryIndex) = cluster;
+    }
+
+    this->distances(0, queryIndex) = distance;
+  }
+  else if (distance < this->distances(1, queryIndex))
+  {
+    // Here it doesn't actually matter if the assignment is the same.
+    this->neighbors(1, queryIndex) = cluster;
+    this->distances(1, queryIndex) = distance;
+  }
+
+  this->lastQueryIndex = queryIndex;
+  this->lastReferenceIndex = referenceIndex;
+  this->lastBaseCase = distance;
+
+  return distance;
 }
 
 template<typename MetricType, typename TreeType>
@@ -47,7 +105,8 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
   if (prunedPoints[queryIndex])
     return DBL_MAX;
 
-  return rules.Score(queryIndex, referenceNode);
+  return neighbor::NeighborSearchRules<neighbor::NearestNeighborSort,
+      MetricType, TreeType>::Score(queryIndex, referenceNode);
 }
 
 template<typename MetricType, typename TreeType>
@@ -59,7 +118,8 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
     return DBL_MAX;
 
   // Check if the query node is Hamerly pruned, and if not, then don't continue.
-  return rules.Score(queryNode, referenceNode);
+  return neighbor::NeighborSearchRules<neighbor::NearestNeighborSort,
+      MetricType, TreeType>::Score(queryNode, referenceNode);
 }
 
 template<typename MetricType, typename TreeType>
@@ -68,7 +128,8 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Rescore(
     TreeType& referenceNode,
     const double oldScore)
 {
-  return rules.Rescore(queryIndex, referenceNode, oldScore);
+  return neighbor::NeighborSearchRules<neighbor::NearestNeighborSort,
+      MetricType, TreeType>::Rescore(queryIndex, referenceNode, oldScore);
 }
 
 template<typename MetricType, typename TreeType>
@@ -79,7 +140,8 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Rescore(
 {
   // No need to check for a Hamerly prune.  Because we've already done that in
   // Score().
-  return rules.Rescore(queryNode, referenceNode, oldScore);
+  return neighbor::NeighborSearchRules<neighbor::NearestNeighborSort,
+      MetricType, TreeType>::Rescore(queryNode, referenceNode, oldScore);
 }
 
 } // namespace kmeans
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
index effa4ac..e09df0c 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
@@ -100,7 +100,7 @@ class NeighborSearchRules
   //! Modify the traversal info.
   TraversalInfoType& TraversalInfo() { return traversalInfo; }
 
- private:
+ protected:
   //! The reference set.
   const typename TreeType::Mat& referenceSet;
 



More information about the mlpack-git mailing list