[mlpack-git] master: Fix unnecessary copy. (4c71640)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed May 20 23:06:16 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/77d750c8fd46140b1d6060424f68768a21c89377...7e9cd46afb53817ae93ccbd02637d7726137ce4d

>---------------------------------------------------------------

commit 4c71640e73bc58a58074f31f3f8dbab56b9da0d7
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed May 20 10:16:13 2015 -0400

    Fix unnecessary copy.


>---------------------------------------------------------------

4c71640e73bc58a58074f31f3f8dbab56b9da0d7
 src/mlpack/methods/kmeans/dual_tree_kmeans.hpp      |  3 +--
 src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp | 14 ++++++--------
 2 files changed, 7 insertions(+), 10 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
index f4fd3aa..ae9cb71 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans.hpp
@@ -94,13 +94,12 @@ class DualTreeKMeans
 
   arma::vec clusterDistances; // The amount the clusters moved last iteration.
 
-  arma::vec interclusterDistances; // Static storage for intercluster distances.
+  arma::mat interclusterDistances; // Static storage for intercluster distances.
 
   //! Update the bounds in the tree before the next iteration.
   //! centroids is the current (not yet searched) centroids.
   void UpdateTree(TreeType& node,
                   const arma::mat& centroids,
-                  const arma::vec& interclusterDistances,
                   const double parentUpperBound = 0.0,
                   const double adjustedParentUpperBound = DBL_MAX,
                   const double parentLowerBound = DBL_MAX,
diff --git a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
index d29e290..71ffe2d 100644
--- a/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dual_tree_kmeans_impl.hpp
@@ -112,8 +112,8 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
 
     // If the tree maps points, we need an intermediate result matrix.
     arma::mat* interclusterDistancesTemp =
-        (tree::TreeTraits<TreeType>::RearrangesDataset) ? new arma::mat :
-        &interclusterDistances;
+        (tree::TreeTraits<TreeType>::RearrangesDataset) ?
+        new arma::mat(1, centroids.n_elem) : &interclusterDistances;
 
     arma::Mat<size_t> closestClusters; // We don't actually care about these.
     nns.Search(1, closestClusters, *interclusterDistancesTemp);
@@ -131,7 +131,7 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
 
     Timer::Stop("knn");
 
-    UpdateTree(*tree, oldCentroids, interclusterDistances);
+    UpdateTree(*tree, oldCentroids);
 
     for (size_t i = 0; i < dataset.n_cols; ++i)
       visited[i] = false;
@@ -140,7 +140,7 @@ double DualTreeKMeans<MetricType, MatType, TreeType>::Iterate(
   {
     // Not initialized yet.
     clusterDistances.set_size(centroids.n_cols + 1);
-    interclusterDistances.set_size(centroids.n_cols);
+    interclusterDistances.set_size(1, centroids.n_cols);
   }
 
   // We won't use the AllkNN class here because we have our own set of rules.
@@ -210,7 +210,6 @@ template<typename MetricType, typename MatType, typename TreeType>
 void DualTreeKMeans<MetricType, MatType, TreeType>::UpdateTree(
     TreeType& node,
     const arma::mat& centroids,
-    const arma::vec& interclusterDistances,
     const double parentUpperBound,
     const double adjustedParentUpperBound,
     const double parentLowerBound,
@@ -408,9 +407,8 @@ visited[node.Descendant(i)] << ".\n";
   bool allChildrenPruned = true;
   for (size_t i = 0; i < node.NumChildren(); ++i)
   {
-    UpdateTree(node.Child(i), centroids, interclusterDistances,
-        unadjustedUpperBound, adjustedUpperBound, unadjustedLowerBound,
-        adjustedLowerBound);
+    UpdateTree(node.Child(i), centroids, unadjustedUpperBound,
+        adjustedUpperBound, unadjustedLowerBound, adjustedLowerBound);
     if (!node.Child(i).Stat().StaticPruned())
       allChildrenPruned = false;
   }



More information about the mlpack-git mailing list