[mlpack-git] master: Refactor BaseCase() to apply mappings. This reduces memory usage. (acd8db8)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:03:54 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44
>---------------------------------------------------------------
commit acd8db891c8c67b5c1455cde20ed60ae9b4acd93
Author: Ryan Curtin <ryan at ratml.org>
Date: Tue Feb 3 11:14:41 2015 -0500
Refactor BaseCase() to apply mappings. This reduces memory usage.
>---------------------------------------------------------------
acd8db891c8c67b5c1455cde20ed60ae9b4acd93
src/mlpack/methods/kmeans/dtnn_kmeans.hpp | 2 +-
src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 126 +++++++++++----------
src/mlpack/methods/kmeans/dtnn_rules.hpp | 30 ++---
src/mlpack/methods/kmeans/dtnn_rules_impl.hpp | 78 +++++++++++--
.../neighbor_search/neighbor_search_rules.hpp | 2 +-
5 files changed, 146 insertions(+), 92 deletions(-)
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
index b8c2792..b939f32 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans.hpp
@@ -99,7 +99,7 @@ class DTNNKMeans
arma::mat distances;
arma::Mat<size_t> assignments;
- std::vector<size_t> lastOldFromNewCentroids;
+ arma::mat lastIterationCentroids; // For sanity checks.
//! Update the bounds in the tree before the next iteration.
void UpdateTree(TreeType& node,
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index 1f48545..95546d3 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -139,7 +139,7 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
assignments.fill(size_t(-1));
typedef DTNNKMeansRules<MetricType, TreeType> RuleType;
RuleType rules(centroids, dataset, assignments, distances, metric,
- prunedPoints);
+ prunedPoints, oldFromNewCentroids);
// Now construct the traverser ourselves.
typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
@@ -155,17 +155,8 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
{
if (assignments(0, i) != size_t(-1))
{
- if (tree::TreeTraits<TreeType>::RearrangesDataset)
- {
- newCentroids.col(oldFromNewCentroids[assignments(0, i)]) +=
- dataset.col(i);
- ++counts(oldFromNewCentroids[assignments(0, i)]);
- }
- else
- {
- newCentroids.col(assignments(0, i)) += dataset.col(i);
- ++counts(assignments(0, i));
- }
+ newCentroids.col(assignments(0, i)) += dataset.col(i);
+ ++counts(assignments(0, i));
}
}
@@ -200,7 +191,7 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
clusterDistances[centroids.n_cols] = maxMovement;
distanceCalculations += centroids.n_cols;
- lastOldFromNewCentroids = oldFromNewCentroids;
+// lastIterationCentroids = oldCentroids;
delete centroidTree;
@@ -248,9 +239,7 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
// Don't forget to map back from the new cluster index.
size_t c;
if (!prunedPoints[node.Point(i)])
- c = (tree::TreeTraits<TreeType>::RearrangesDataset) ?
- lastOldFromNewCentroids[assignments(0, node.Point(i))] :
- assignments(0, node.Point(i));
+ c = assignments(0, node.Point(i));
else
c = lastOwners[node.Point(i)];
@@ -310,30 +299,42 @@ void DTNNKMeans<MetricType, MatType, TreeType>::UpdateTree(
/*
for (size_t i = 0; i < node.NumPoints(); ++i)
{
- const double ownerDist = metric.Evaluate(dataset.col(node.Point(i)),
- centroids.col(owner));
+ arma::vec dists(centroids.n_cols);
+ size_t trueOwner = centroids.n_cols;
+ double trueDist = DBL_MAX;
for (size_t j = 0; j < centroids.n_cols; ++j)
{
const double dist = metric.Evaluate(dataset.col(node.Point(i)),
- centroids.col(j));
- if (dist < ownerDist)
+ lastIterationCentroids.col(j));
+ dists(j) = dist;
+ if (dist < trueDist)
{
- Log::Warn << node << "...\n" << *node.Parent();
+ trueDist = dist;
+ trueOwner = j;
+ }
+ }
+
+ if (trueOwner != owner)
+ {
+ Log::Warn << node << "...\n" << *node.Parent();
+ Log::Warn << dists.t();
+ Log::Warn << "Assignment: " << assignments(0, node.Point(i)) << ".\n";
+ Log::Warn << "Dists: " << distances(0, node.Point(i)) << ", " <<
+distances(1, node.Point(i)) << ".\n";
// TreeType* n = node.Parent()->Parent();
// while (n != NULL)
// {
// Log::Warn << "...\n" << *n;
// n = n->Parent();
// }
- Log::Fatal << "Point " << node.Point(i) << " was assigned to owner "
- << owner << " but has true owner " << j << "! [" <<
-oldFromNewCentroids[assignments(0, node.Point(i))] << " -- " <<
+ Log::Fatal << "Point " << node.Point(i) << " was assigned to owner "
+ << owner << " but has true owner " << trueOwner << "! [" <<
+assignments(0, node.Point(i)) << " -- " <<
metric.Evaluate(dataset.col(node.Point(i)),
-centroids.col(oldFromNewCentroids[assignments(0, node.Point(i))])) << "] " <<
+centroids.col(assignments(0, node.Point(i)))) << "] " <<
distances(0, node.Point(i)) << " " <<
-oldFromNewCentroids[assignments(0, node.Point(i))] << " " <<
-oldFromNewCentroids[assignments(0, node.Point(i - 1))] << ".\n";
- }
+assignments(0, node.Point(i)) << " " <<
+assignments(0, node.Point(i - 1)) << ".\n";
}
}
*/
@@ -400,34 +401,6 @@ oldFromNewCentroids[assignments(0, node.Point(i - 1))] << ".\n";
// The node was pruned last iteration. See if the node can remain pruned.
singleOwner = false;
-/*
- for (size_t i = 0; i < node.NumPoints(); ++i)
- {
- size_t trueOwner = 0;
- double ownerDist = DBL_MAX;
- arma::vec distances(centroids.n_cols);
- for (size_t j = 0; j < centroids.n_cols; ++j)
- {
- const double dist = metric.Evaluate(dataset.col(node.Point(i)),
- centroids.col(j));
- distances(j) = dist;
- if (dist < ownerDist)
- {
- trueOwner = j;
- ownerDist = dist;
- }
- }
-
- if (trueOwner != node.Stat().Owner())
- {
- Log::Warn << node << "...\n" << *node.Parent();
- Log::Warn << distances.t();
- Log::Fatal << "Point " << node.Point(i) << " was assigned to owner "
- << node.Stat().Owner() << " but has true owner " << trueOwner <<
-"!\n";
- }
- }*/
-
// If it was pruned because all points were pruned, we need to check
// individually.
if (node.Stat().Owner() == centroids.n_cols)
@@ -465,6 +438,37 @@ oldFromNewCentroids[assignments(0, node.Point(i - 1))] << ".\n";
}
}
}
+/*
+ if (node.Stat().Pruned() && node.Stat().Owner() != centroids.n_cols)
+ {
+ for (size_t i = 0; i < node.NumPoints(); ++i)
+ {
+ size_t trueOwner = 0;
+ double ownerDist = DBL_MAX;
+ arma::vec distances(centroids.n_cols);
+ for (size_t j = 0; j < centroids.n_cols; ++j)
+ {
+ const double dist = metric.Evaluate(dataset.col(node.Point(i)),
+ lastIterationCentroids.col(j));
+ distances(j) = dist;
+ if (dist < ownerDist)
+ {
+ trueOwner = j;
+ ownerDist = dist;
+ }
+ }
+
+ if (trueOwner != node.Stat().Owner())
+ {
+ Log::Warn << node << "...\n" << *node.Parent();
+ Log::Warn << distances.t();
+ Log::Fatal << "Point " << node.Point(i) << " was assigned to owner "
+ << node.Stat().Owner() << " but has true owner " << trueOwner <<
+"!\n";
+ }
+ }
+ }
+*/
}
else
{
@@ -490,9 +494,7 @@ oldFromNewCentroids[assignments(0, node.Point(i - 1))] << ".\n";
size_t owner;
if (!prunedLastIteration && !prunedPoints[index])
{
- owner = (tree::TreeTraits<TreeType>::RearrangesDataset) ?
- lastOldFromNewCentroids[assignments(0, index)] :
- assignments(0, index);
+ owner = assignments(0, index);
// Establish bounds, since these points were searched this iteration.
upperBounds[index] = distances(0, index);
lowerSecondBounds[index] = distances(1, index);
@@ -516,7 +518,7 @@ oldFromNewCentroids[assignments(0, node.Point(i - 1))] << ".\n";
arma::vec distances(centroids.n_cols);
for (size_t j = 0; j < centroids.n_cols; ++j)
{
- const double dist = metric.Evaluate(centroids.col(j),
+ const double dist = metric.Evaluate(lastIterationCentroids.col(j),
dataset.col(index));
distances(j) = dist;
if (dist < trueDist)
@@ -535,8 +537,8 @@ prunedPoints[index] << ", lastOwner " << lastOwners[index] << ": invalid "
Log::Warn << distances.t();
Log::Fatal << "Assigned owner " << owner << " but true owner is "
<< trueOwner << "!\n";
- }*/
-
+ }
+*/
prunedPoints[index] = true;
upperBounds[index] += clusterDistances[owner];
lastOwners[index] = owner;
diff --git a/src/mlpack/methods/kmeans/dtnn_rules.hpp b/src/mlpack/methods/kmeans/dtnn_rules.hpp
index da2a63d..2ded3ea 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules.hpp
@@ -15,15 +15,17 @@ namespace mlpack {
namespace kmeans {
template<typename MetricType, typename TreeType>
-class DTNNKMeansRules
+class DTNNKMeansRules : public neighbor::NeighborSearchRules<
+ neighbor::NearestNeighborSort, MetricType, TreeType>
{
public:
DTNNKMeansRules(const arma::mat& centroids,
- const arma::mat& dataset,
- arma::Mat<size_t>& neighbors,
- arma::mat& distances,
- MetricType& metric,
- const std::vector<bool>& prunedPoints);
+ const arma::mat& dataset,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances,
+ MetricType& metric,
+ const std::vector<bool>& prunedPoints,
+ const std::vector<size_t>& oldFromNewCentroids);
double BaseCase(const size_t queryIndex, const size_t referenceIndex);
@@ -36,23 +38,11 @@ class DTNNKMeansRules
TreeType& referenceNode,
const double oldScore);
- typedef neighbor::NeighborSearchTraversalInfo<TreeType> TraversalInfoType;
-
- size_t Scores() const { return rules.Scores(); }
- size_t& Scores() { return rules.Scores(); }
- size_t BaseCases() const { return rules.BaseCases(); }
- size_t& BaseCases() { return rules.BaseCases(); }
-
- const TraversalInfoType& TraversalInfo() const
- { return rules.TraversalInfo(); }
- TraversalInfoType& TraversalInfo() { return rules.TraversalInfo(); }
-
private:
- typename neighbor::NeighborSearchRules<neighbor::NearestNeighborSort,
- MetricType, TreeType> rules;
-
const std::vector<bool>& prunedPoints;
+
+ const std::vector<size_t>& oldFromNewCentroids;
};
} // namespace kmeans
diff --git a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
index bce7d47..11c2c3e 100644
--- a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
@@ -19,9 +19,12 @@ DTNNKMeansRules<MetricType, TreeType>::DTNNKMeansRules(
arma::Mat<size_t>& neighbors,
arma::mat& distances,
MetricType& metric,
- const std::vector<bool>& prunedPoints) :
- rules(centroids, dataset, neighbors, distances, metric),
- prunedPoints(prunedPoints)
+ const std::vector<bool>& prunedPoints,
+ const std::vector<size_t>& oldFromNewCentroids) :
+ neighbor::NeighborSearchRules<neighbor::NearestNeighborSort, MetricType,
+ TreeType>(centroids, dataset, neighbors, distances, metric),
+ prunedPoints(prunedPoints),
+ oldFromNewCentroids(oldFromNewCentroids)
{
// Nothing to do.
}
@@ -35,7 +38,62 @@ inline force_inline double DTNNKMeansRules<MetricType, TreeType>::BaseCase(
if (prunedPoints[queryIndex])
return 0.0; // Returning 0 shouldn't be a problem.
- return rules.BaseCase(queryIndex, referenceIndex);
+ // This is basically an inlined NeighborSearchRules::BaseCase(), but it
+ // differs in that it applies the mappings to the results automatically.
+ // We can also skip a check or two.
+
+ // By the way, all of the this-> is necessary because the parent class is a
+ // dependent name, so all of the members of that parent aren't resolvable
+ // before type substitution. The 'this->' turns that member into a dependent
+ // name too (since the type of 'this' is dependent), and thus the compiler
+ // resolves the name later and we get no error. Hooray C++!
+ //
+ // See also:
+ // http://stackoverflow.com/questions/10639053/name-lookups-in-c-templates
+
+ // If we have already performed this base case, do not perform it again.
+ if ((this->lastQueryIndex == queryIndex) &&
+ (this->lastReferenceIndex == referenceIndex))
+ return this->lastBaseCase;
+
+ double distance = this->metric.Evaluate(this->querySet.col(queryIndex),
+ this->referenceSet.col(referenceIndex));
+ ++this->baseCases;
+
+ const size_t cluster = oldFromNewCentroids[referenceIndex];
+
+ // Is this better than either existing candidate?
+ if (distance < this->distances(0, queryIndex))
+ {
+ // Do we need to replace the assignment, or is it an old assignment from a
+ // previous iteration?
+ if (this->neighbors(0, queryIndex) != cluster &&
+ this->neighbors(0, queryIndex) < this->referenceSet.n_cols)
+ {
+ // We must push the old closest assignment down the stack.
+ this->neighbors(1, queryIndex) = this->neighbors(0, queryIndex);
+ this->distances(1, queryIndex) = this->distances(0, queryIndex);
+ this->neighbors(0, queryIndex) = cluster;
+ }
+ else if (this->neighbors(0, queryIndex) >= this->referenceSet.n_cols)
+ {
+ this->neighbors(0, queryIndex) = cluster;
+ }
+
+ this->distances(0, queryIndex) = distance;
+ }
+ else if (distance < this->distances(1, queryIndex))
+ {
+ // Here it doesn't actually matter if the assignment is the same.
+ this->neighbors(1, queryIndex) = cluster;
+ this->distances(1, queryIndex) = distance;
+ }
+
+ this->lastQueryIndex = queryIndex;
+ this->lastReferenceIndex = referenceIndex;
+ this->lastBaseCase = distance;
+
+ return distance;
}
template<typename MetricType, typename TreeType>
@@ -47,7 +105,8 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
if (prunedPoints[queryIndex])
return DBL_MAX;
- return rules.Score(queryIndex, referenceNode);
+ return neighbor::NeighborSearchRules<neighbor::NearestNeighborSort,
+ MetricType, TreeType>::Score(queryIndex, referenceNode);
}
template<typename MetricType, typename TreeType>
@@ -59,7 +118,8 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Score(
return DBL_MAX;
// Check if the query node is Hamerly pruned, and if not, then don't continue.
- return rules.Score(queryNode, referenceNode);
+ return neighbor::NeighborSearchRules<neighbor::NearestNeighborSort,
+ MetricType, TreeType>::Score(queryNode, referenceNode);
}
template<typename MetricType, typename TreeType>
@@ -68,7 +128,8 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Rescore(
TreeType& referenceNode,
const double oldScore)
{
- return rules.Rescore(queryIndex, referenceNode, oldScore);
+ return neighbor::NeighborSearchRules<neighbor::NearestNeighborSort,
+ MetricType, TreeType>::Rescore(queryIndex, referenceNode, oldScore);
}
template<typename MetricType, typename TreeType>
@@ -79,7 +140,8 @@ inline double DTNNKMeansRules<MetricType, TreeType>::Rescore(
{
// No need to check for a Hamerly prune. Because we've already done that in
// Score().
- return rules.Rescore(queryNode, referenceNode, oldScore);
+ return neighbor::NeighborSearchRules<neighbor::NearestNeighborSort,
+ MetricType, TreeType>::Rescore(queryNode, referenceNode, oldScore);
}
} // namespace kmeans
diff --git a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
index effa4ac..e09df0c 100644
--- a/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
+++ b/src/mlpack/methods/neighbor_search/neighbor_search_rules.hpp
@@ -100,7 +100,7 @@ class NeighborSearchRules
//! Modify the traversal info.
TraversalInfoType& TraversalInfo() { return traversalInfo; }
- private:
+ protected:
//! The reference set.
const typename TreeType::Mat& referenceSet;
More information about the mlpack-git mailing list