[mlpack-git] master: Refactor to use new set of rules. How many times will I restart writing this algorithm until I actually get it working well? (4db93bf)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:03:48 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44
>---------------------------------------------------------------
commit 4db93bfa4c4e67bad7a3f1778018ece2b6fb4fc0
Author: Ryan Curtin <ryan at ratml.org>
Date: Thu Jan 29 22:34:00 2015 -0500
Refactor to use new set of rules. How many times will I restart writing this algorithm until I actually get it working well?
>---------------------------------------------------------------
4db93bfa4c4e67bad7a3f1778018ece2b6fb4fc0
src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 33 +++++++----
src/mlpack/methods/kmeans/dtnn_rules.hpp | 60 ++++++++++++++++++++
src/mlpack/methods/kmeans/dtnn_rules_impl.hpp | 78 ++++++++++++++++++++++++++
3 files changed, 159 insertions(+), 12 deletions(-)
diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index fde2b44..6112ca7 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -13,6 +13,8 @@
// In case it hasn't been included yet.
#include "dtnn_kmeans.hpp"
+#include "dtnn_rules.hpp"
+
namespace mlpack {
namespace kmeans {
@@ -85,28 +87,35 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
TreeType* centroidTree = BuildTree<TreeType>(
const_cast<typename TreeType::Mat&>(centroids), oldFromNewCentroids);
- typedef neighbor::NeighborSearch<neighbor::NearestNeighborSort, MetricType,
- TreeType> AllkNNType;
- AllkNNType allknn(centroidTree, tree, centroids, dataset, false, metric);
-
+ // We won't use the AllkNN class here because we have our own set of rules.
// This is a lot of overhead. We don't need the distances.
- arma::mat distances;
- arma::Mat<size_t> assignments;
- allknn.Search(1, assignments, distances);
- distanceCalculations += allknn.BaseCases() + allknn.Scores();
+ arma::mat distances(5, dataset.n_cols);
+ arma::Mat<size_t> assignments(5, dataset.n_cols);
+ distances.fill(DBL_MAX);
+ assignments.fill(size_t(-1));
+ typedef DTNNKMeansRules<MetricType, TreeType> RuleType;
+ RuleType rules(centroids, dataset, assignments, distances, metric);
+
+ // Now construct the traverser ourselves.
+ typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+
+ traverser.Traverse(*tree, *centroidTree);
+
+ distanceCalculations += rules.BaseCases() + rules.Scores();
// From the assignments, calculate the new centroids and counts.
for (size_t i = 0; i < dataset.n_cols; ++i)
{
if (tree::TreeTraits<TreeType>::RearrangesDataset)
{
- newCentroids.col(oldFromNewCentroids[assignments[i]]) += dataset.col(i);
- ++counts(oldFromNewCentroids[assignments[i]]);
+ newCentroids.col(oldFromNewCentroids[assignments(0, i)]) +=
+ dataset.col(i);
+ ++counts(oldFromNewCentroids[assignments(0, i)]);
}
else
{
- newCentroids.col(assignments[i]) += dataset.col(i);
- ++counts(assignments[i]);
+ newCentroids.col(assignments(0, i)) += dataset.col(i);
+ ++counts(assignments(0, i));
}
}
diff --git a/src/mlpack/methods/kmeans/dtnn_rules.hpp b/src/mlpack/methods/kmeans/dtnn_rules.hpp
new file mode 100644
index 0000000..44647ce
--- /dev/null
+++ b/src/mlpack/methods/kmeans/dtnn_rules.hpp
@@ -0,0 +1,60 @@
+/**
+ * @file dtnn_rules.hpp
+ * @author Ryan Curtin
+ *
+ * A set of rules for the dual-tree k-means algorithm which uses dual-tree
+ * nearest neighbor search. For the most part we'll call out to
+ * NeighborSearchRules when we can.
+ */
+#ifndef __MLPACK_METHODS_KMEANS_DTNN_RULES_HPP
+#define __MLPACK_METHODS_KMEANS_DTNN_RULES_HPP
+
+#include <mlpack/methods/neighbor_search/neighbor_search.hpp>
+
+namespace mlpack {
+namespace kmeans {
+
+/**
+ * Traversal rules for dual-tree k-means built on top of dual-tree nearest
+ * neighbor search.  At present every callback simply forwards to a wrapped
+ * neighbor::NeighborSearchRules object (with NearestNeighborSort), so the
+ * traversal behaves exactly like a nearest neighbor search; k-means-specific
+ * (e.g. Hamerly) pruning is not implemented yet.
+ *
+ * @tparam MetricType Metric used for distance evaluations.
+ * @tparam TreeType Tree type used for both the query and the reference tree.
+ */
+template<typename MetricType, typename TreeType>
+class DTNNKMeansRules
+{
+ public:
+ /**
+ * Construct the rules.  All arguments are passed, unchanged and in the same
+ * order, to the wrapped NeighborSearchRules object.
+ *
+ * @param centroids Current centroids (presumably the reference set --
+ * NOTE(review): confirm against the NeighborSearchRules constructor).
+ * @param dataset Points being assigned (presumably the query set).
+ * @param neighbors Matrix that receives neighbor (centroid) indices;
+ * presumably written in place during the traversal.
+ * @param distances Matrix that receives the corresponding distances;
+ * presumably written in place during the traversal.
+ * @param metric Metric instance used for distance evaluations.
+ */
+ DTNNKMeansRules(const arma::mat& centroids,
+ const arma::mat& dataset,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances,
+ MetricType& metric);
+
+ //! Base case between a query point and a reference point (forwarded).
+ double BaseCase(const size_t queryIndex, const size_t referenceIndex);
+
+ //! Score a query point against a reference node (forwarded).
+ double Score(const size_t queryIndex, TreeType& referenceNode);
+ //! Score a query node against a reference node (forwarded).
+ double Score(TreeType& queryNode, TreeType& referenceNode);
+ //! Re-evaluate an earlier point/node score (forwarded).
+ double Rescore(const size_t queryIndex,
+ TreeType& referenceNode,
+ const double oldScore);
+ //! Re-evaluate an earlier node/node score (forwarded).
+ double Rescore(TreeType& queryNode,
+ TreeType& referenceNode,
+ const double oldScore);
+
+ //! Traversal information type, shared with nearest neighbor search.
+ typedef neighbor::NeighborSearchTraversalInfo<TreeType> TraversalInfoType;
+
+ //! Score counter of the wrapped rules (read-only / modifiable).
+ size_t Scores() const { return rules.Scores(); }
+ size_t& Scores() { return rules.Scores(); }
+ //! Base case counter of the wrapped rules (read-only / modifiable).
+ size_t BaseCases() const { return rules.BaseCases(); }
+ size_t& BaseCases() { return rules.BaseCases(); }
+
+ //! Traversal information held by the wrapped rules.
+ const TraversalInfoType& TraversalInfo() const
+ { return rules.TraversalInfo(); }
+ TraversalInfoType& TraversalInfo() { return rules.TraversalInfo(); }
+
+ private:
+
+ //! Nearest neighbor search rules that every callback delegates to.
+ typename neighbor::NeighborSearchRules<neighbor::NearestNeighborSort,
+ MetricType, TreeType> rules;
+};
+
+} // namespace kmeans
+} // namespace mlpack
+
+#include "dtnn_rules_impl.hpp"
+
+#endif
diff --git a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
new file mode 100644
index 0000000..c4492c1
--- /dev/null
+++ b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
@@ -0,0 +1,78 @@
+/**
+ * @file dtnn_rules_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of DTNNKMeansRules.
+ */
+#ifndef __MLPACK_METHODS_KMEANS_DTNN_RULES_IMPL_HPP
+#define __MLPACK_METHODS_KMEANS_DTNN_RULES_IMPL_HPP
+
+#include "dtnn_rules.hpp"
+
+namespace mlpack {
+namespace kmeans {
+
+// Construct the rules by forwarding every argument, unchanged and in the same
+// order, to the wrapped NeighborSearchRules member.
+template<typename MetricType, typename TreeType>
+DTNNKMeansRules<MetricType, TreeType>::DTNNKMeansRules(
+ const arma::mat& centroids,
+ const arma::mat& dataset,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances,
+ MetricType& metric) :
+ rules(centroids, dataset, neighbors, distances, metric)
+{
+ // Nothing else to do; all state lives in the wrapped rules.
+}
+
+// Base case between one query point and one reference point.  Forwards
+// directly to NeighborSearchRules::BaseCase() and returns its result.
+template<typename MetricType, typename TreeType>
+inline force_inline double DTNNKMeansRules<MetricType, TreeType>::BaseCase(
+ const size_t queryIndex,
+ const size_t referenceIndex)
+{
+ // NOTE(review): no Hamerly-prune check is implemented yet.  TODO: skip query
+ // points that have already been Hamerly pruned before running the base case.
+
+ return rules.BaseCase(queryIndex, referenceIndex);
+}
+
+// Score a single query point against a reference node.  Forwards directly to
+// NeighborSearchRules::Score() and returns its result.
+template<typename MetricType, typename TreeType>
+inline double DTNNKMeansRules<MetricType, TreeType>::Score(
+ const size_t queryIndex,
+ TreeType& referenceNode)
+{
+ return rules.Score(queryIndex, referenceNode);
+}
+
+// Score a query node against a reference node.  Forwards directly to
+// NeighborSearchRules::Score() and returns its result.
+template<typename MetricType, typename TreeType>
+inline double DTNNKMeansRules<MetricType, TreeType>::Score(
+ TreeType& queryNode,
+ TreeType& referenceNode)
+{
+ // TODO: if the query node has been Hamerly pruned, prune this combination
+ // instead of continuing.  Not implemented yet, so this only forwards.
+ return rules.Score(queryNode, referenceNode);
+}
+
+// Re-evaluate a previously computed point/node score.  Forwards directly to
+// NeighborSearchRules::Rescore() and returns its result.
+template<typename MetricType, typename TreeType>
+inline double DTNNKMeansRules<MetricType, TreeType>::Rescore(
+ const size_t queryIndex,
+ TreeType& referenceNode,
+ const double oldScore)
+{
+ return rules.Rescore(queryIndex, referenceNode, oldScore);
+}
+
+// Re-evaluate a previously computed node/node score.  Forwards directly to
+// NeighborSearchRules::Rescore() and returns its result.
+template<typename MetricType, typename TreeType>
+inline double DTNNKMeansRules<MetricType, TreeType>::Rescore(
+ TreeType& queryNode,
+ TreeType& referenceNode,
+ const double oldScore)
+{
+ // Any Hamerly-prune check is intended to live in Score(), so Rescore() does
+ // not repeat it.  NOTE(review): Score() does not perform that check yet.
+ return rules.Rescore(queryNode, referenceNode, oldScore);
+}
+
+} // namespace kmeans
+} // namespace mlpack
+
+#endif
More information about the mlpack-git
mailing list