[mlpack-git] master: Refactor to use new set of rules. How many times will I restart writing this algorithm until I actually get it working well? (4db93bf)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 12 16:03:48 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/eddd7167d69b6c88b271ef2e51d1c20e13f1acd8...70342dd8e5c17e0c164cfb8189748671e9c0dd44

>---------------------------------------------------------------

commit 4db93bfa4c4e67bad7a3f1778018ece2b6fb4fc0
Author: Ryan Curtin <ryan at ratml.org>
Date:   Thu Jan 29 22:34:00 2015 -0500

    Refactor to use new set of rules. How many times will I restart writing this algorithm until I actually get it working well?


>---------------------------------------------------------------

4db93bfa4c4e67bad7a3f1778018ece2b6fb4fc0
 src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp | 33 +++++++----
 src/mlpack/methods/kmeans/dtnn_rules.hpp       | 60 ++++++++++++++++++++
 src/mlpack/methods/kmeans/dtnn_rules_impl.hpp  | 78 ++++++++++++++++++++++++++
 3 files changed, 159 insertions(+), 12 deletions(-)

diff --git a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
index fde2b44..6112ca7 100644
--- a/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
+++ b/src/mlpack/methods/kmeans/dtnn_kmeans_impl.hpp
@@ -13,6 +13,8 @@
 // In case it hasn't been included yet.
 #include "dtnn_kmeans.hpp"
 
+#include "dtnn_rules.hpp"
+
 namespace mlpack {
 namespace kmeans {
 
@@ -85,28 +87,35 @@ double DTNNKMeans<MetricType, MatType, TreeType>::Iterate(
   TreeType* centroidTree = BuildTree<TreeType>(
       const_cast<typename TreeType::Mat&>(centroids), oldFromNewCentroids);
 
-  typedef neighbor::NeighborSearch<neighbor::NearestNeighborSort, MetricType,
-      TreeType> AllkNNType;
-  AllkNNType allknn(centroidTree, tree, centroids, dataset, false, metric);
-
+  // We won't use the AllkNN class here because we have our own set of rules.
   // This is a lot of overhead.  We don't need the distances.
-  arma::mat distances;
-  arma::Mat<size_t> assignments;
-  allknn.Search(1, assignments, distances);
-  distanceCalculations += allknn.BaseCases() + allknn.Scores();
+  // Only the nearest centroid (row 0) is used, so search for k = 1 neighbor.
+  arma::mat distances(1, dataset.n_cols);
+  arma::Mat<size_t> assignments(1, dataset.n_cols);
+  distances.fill(DBL_MAX);
+  assignments.fill(size_t(-1));
+  typedef DTNNKMeansRules<MetricType, TreeType> RuleType;
+  RuleType rules(centroids, dataset, assignments, distances, metric);
+
+  // Now construct the traverser ourselves and run the search.
+  typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+  traverser.Traverse(*tree, *centroidTree);
+
+  distanceCalculations += rules.BaseCases() + rules.Scores();
 
   // From the assignments, calculate the new centroids and counts.
   for (size_t i = 0; i < dataset.n_cols; ++i)
   {
     if (tree::TreeTraits<TreeType>::RearrangesDataset)
     {
-      newCentroids.col(oldFromNewCentroids[assignments[i]]) += dataset.col(i);
-      ++counts(oldFromNewCentroids[assignments[i]]);
+      newCentroids.col(oldFromNewCentroids[assignments(0, i)]) +=
+          dataset.col(i);
+      ++counts(oldFromNewCentroids[assignments(0, i)]);
     }
     else
     {
-      newCentroids.col(assignments[i]) += dataset.col(i);
-      ++counts(assignments[i]);
+      newCentroids.col(assignments(0, i)) += dataset.col(i);
+      ++counts(assignments(0, i));
     }
   }
 
diff --git a/src/mlpack/methods/kmeans/dtnn_rules.hpp b/src/mlpack/methods/kmeans/dtnn_rules.hpp
new file mode 100644
index 0000000..44647ce
--- /dev/null
+++ b/src/mlpack/methods/kmeans/dtnn_rules.hpp
@@ -0,0 +1,60 @@
+/**
+ * @file dtnn_rules.hpp
+ * @author Ryan Curtin
+ *
+ * A set of rules for the dual-tree k-means algorithm which uses dual-tree
+ * nearest neighbor search.  Currently every call is simply forwarded to a
+ * wrapped NeighborSearchRules object (with NearestNeighborSort).
+ */
+#ifndef __MLPACK_METHODS_KMEANS_DTNN_RULES_HPP
+#define __MLPACK_METHODS_KMEANS_DTNN_RULES_HPP
+
+#include <mlpack/methods/neighbor_search/neighbor_search.hpp>
+
+namespace mlpack {
+namespace kmeans {
+
+template<typename MetricType, typename TreeType>
+class DTNNKMeansRules
+{
+ public:
+  DTNNKMeansRules(const arma::mat& centroids,
+                  const arma::mat& dataset,
+                  arma::Mat<size_t>& neighbors,
+                  arma::mat& distances,
+                  MetricType& metric);
+  //! Evaluate a base case; forwarded to NeighborSearchRules::BaseCase().
+  double BaseCase(const size_t queryIndex, const size_t referenceIndex);
+  //! Score/Rescore node combinations; forwarded to NeighborSearchRules.
+  double Score(const size_t queryIndex, TreeType& referenceNode);
+  double Score(TreeType& queryNode, TreeType& referenceNode);
+  double Rescore(const size_t queryIndex,
+                 TreeType& referenceNode,
+                 const double oldScore);
+  double Rescore(TreeType& queryNode,
+                 TreeType& referenceNode,
+                 const double oldScore);
+
+  typedef neighbor::NeighborSearchTraversalInfo<TreeType> TraversalInfoType;
+
+  size_t Scores() const { return rules.Scores(); }
+  size_t& Scores() { return rules.Scores(); }
+  size_t BaseCases() const { return rules.BaseCases(); }
+  size_t& BaseCases() { return rules.BaseCases(); }
+
+  const TraversalInfoType& TraversalInfo() const
+  { return rules.TraversalInfo(); }
+  TraversalInfoType& TraversalInfo() { return rules.TraversalInfo(); }
+
+ private:
+  //! Wrapped nearest neighbor rules; all methods above delegate to this.
+  typename neighbor::NeighborSearchRules<neighbor::NearestNeighborSort,
+      MetricType, TreeType> rules;
+};
+
+} // namespace kmeans
+} // namespace mlpack
+
+#include "dtnn_rules_impl.hpp"
+
+#endif
diff --git a/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
new file mode 100644
index 0000000..c4492c1
--- /dev/null
+++ b/src/mlpack/methods/kmeans/dtnn_rules_impl.hpp
@@ -0,0 +1,78 @@
+/**
+ * @file dtnn_rules_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of DTNNKMeansRules.
+ */
+#ifndef __MLPACK_METHODS_KMEANS_DTNN_RULES_IMPL_HPP
+#define __MLPACK_METHODS_KMEANS_DTNN_RULES_IMPL_HPP
+
+#include "dtnn_rules.hpp"
+
+namespace mlpack {
+namespace kmeans {
+
+template<typename MetricType, typename TreeType>
+DTNNKMeansRules<MetricType, TreeType>::DTNNKMeansRules(
+    const arma::mat& centroids,
+    const arma::mat& dataset,
+    arma::Mat<size_t>& neighbors,
+    arma::mat& distances,
+    MetricType& metric) :
+    rules(centroids, dataset, neighbors, distances, metric)
+{
+  // Nothing else to do; all arguments are handed to the wrapped rules object.
+}
+
+template<typename MetricType, typename TreeType>
+inline force_inline double DTNNKMeansRules<MetricType, TreeType>::BaseCase(
+    const size_t queryIndex,
+    const size_t referenceIndex)
+{
+  // TODO: if the query point has been Hamerly pruned, skip this base case.
+  // Not implemented yet; for now this just forwards to NeighborSearchRules.
+
+  return rules.BaseCase(queryIndex, referenceIndex);
+}
+
+template<typename MetricType, typename TreeType>
+inline double DTNNKMeansRules<MetricType, TreeType>::Score(
+    const size_t queryIndex,
+    TreeType& referenceNode)
+{
+  return rules.Score(queryIndex, referenceNode);
+}
+
+template<typename MetricType, typename TreeType>
+inline double DTNNKMeansRules<MetricType, TreeType>::Score(
+    TreeType& queryNode,
+    TreeType& referenceNode)
+{
+  // TODO: if the query node has been Hamerly pruned, prune it (not yet done).
+  return rules.Score(queryNode, referenceNode);
+}
+
+template<typename MetricType, typename TreeType>
+inline double DTNNKMeansRules<MetricType, TreeType>::Rescore(
+    const size_t queryIndex,
+    TreeType& referenceNode,
+    const double oldScore)
+{
+  return rules.Rescore(queryIndex, referenceNode, oldScore);
+}
+
+template<typename MetricType, typename TreeType>
+inline double DTNNKMeansRules<MetricType, TreeType>::Rescore(
+    TreeType& queryNode,
+    TreeType& referenceNode,
+    const double oldScore)
+{
+  // TODO: no Hamerly prune check is needed here once Score() performs it;
+  // currently Score() does not, so this just forwards to NeighborSearchRules.
+  return rules.Rescore(queryNode, referenceNode, oldScore);
+}
+
+} // namespace kmeans
+} // namespace mlpack
+
+#endif



More information about the mlpack-git mailing list