[mlpack-git] master: Refactor CoverTree to allow sparse datasets. (90a9d93)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Tue Mar 31 11:25:55 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/0c14444f7fde2a24c230d3536b544a3b5c1e1658...90a9d939482856608467bf174f824bc441ec2c21

>---------------------------------------------------------------

commit 90a9d939482856608467bf174f824bc441ec2c21
Author: ryan <ryan at ratml.org>
Date:   Tue Mar 31 11:25:31 2015 -0400

    Refactor CoverTree to allow sparse datasets.


>---------------------------------------------------------------

90a9d939482856608467bf174f824bc441ec2c21
 src/mlpack/core/tree/cover_tree/cover_tree.hpp     |  19 +-
 .../core/tree/cover_tree/cover_tree_impl.hpp       | 357 ++++++++++++++-------
 .../core/tree/cover_tree/dual_tree_traverser.hpp   |  13 +-
 .../tree/cover_tree/dual_tree_traverser_impl.hpp   |  52 ++-
 .../core/tree/cover_tree/first_point_is_root.hpp   |   3 +-
 .../core/tree/cover_tree/single_tree_traverser.hpp |  10 +-
 .../tree/cover_tree/single_tree_traverser_impl.hpp |  31 +-
 src/mlpack/tests/tree_test.cpp                     |  36 ++-
 8 files changed, 373 insertions(+), 148 deletions(-)

diff --git a/src/mlpack/core/tree/cover_tree/cover_tree.hpp b/src/mlpack/core/tree/cover_tree/cover_tree.hpp
index 8536718..547d3f7 100644
--- a/src/mlpack/core/tree/cover_tree/cover_tree.hpp
+++ b/src/mlpack/core/tree/cover_tree/cover_tree.hpp
@@ -81,14 +81,17 @@ namespace tree {
  * @tparam MetricType Metric type to use during tree construction.
  * @tparam RootPointPolicy Determines which point to use as the root node.
  * @tparam StatisticType Statistic to be used during tree creation.
+ * @tparam MatType Type of matrix to build the tree on (generally mat or
+ *      sp_mat).
  */
 template<typename MetricType = metric::LMetric<2, true>,
          typename RootPointPolicy = FirstPointIsRoot,
-         typename StatisticType = EmptyStatistic>
+         typename StatisticType = EmptyStatistic,
+         typename MatType = arma::mat>
 class CoverTree
 {
  public:
-  typedef arma::mat Mat;
+  typedef MatType Mat;
 
   /**
    * Create the cover tree with the given dataset and given base.
@@ -100,7 +103,7 @@ class CoverTree
    * @param dataset Reference to the dataset to build a tree on.
    * @param base Base to use during tree building (default 2.0).
    */
-  CoverTree(const arma::mat& dataset,
+  CoverTree(const MatType& dataset,
             const double base = 2.0,
             MetricType* metric = NULL);
 
@@ -113,7 +116,7 @@ class CoverTree
    * @param metric Instantiated metric to use during tree building.
    * @param base Base to use during tree building (default 2.0).
    */
-  CoverTree(const arma::mat& dataset,
+  CoverTree(const MatType& dataset,
             MetricType& metric,
             const double base = 2.0);
 
@@ -148,7 +151,7 @@ class CoverTree
    *     any points in the far set).
    * @param usedSetSize The number of points used will be added to this number.
    */
-  CoverTree(const arma::mat& dataset,
+  CoverTree(const MatType& dataset,
             const double base,
             const size_t pointIndex,
             const int scale,
@@ -177,7 +180,7 @@ class CoverTree
    * @param furthestDescendantDistance Distance to furthest descendant point.
    * @param metric Instantiated metric (optional).
    */
-  CoverTree(const arma::mat& dataset,
+  CoverTree(const MatType& dataset,
             const double base,
             const size_t pointIndex,
             const int scale,
@@ -212,7 +215,7 @@ class CoverTree
   using BreadthFirstDualTreeTraverser = DualTreeTraverser<RuleType>;
 
   //! Get a reference to the dataset.
-  const arma::mat& Dataset() const { return dataset; }
+  const MatType& Dataset() const { return dataset; }
 
   //! Get the index of the point which this node represents.
   size_t Point() const { return point; }
@@ -337,7 +340,7 @@ class CoverTree
 
  private:
   //! Reference to the matrix which this tree is built on.
-  const arma::mat& dataset;
+  const MatType& dataset;
 
   //! Index of the point in the matrix which this node represents.
   size_t point;
diff --git a/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp b/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
index ade0eb4..24ed095 100644
--- a/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
+++ b/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
@@ -17,9 +17,14 @@ namespace mlpack {
 namespace tree {
 
 // Create the cover tree.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
-    const arma::mat& dataset,
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
+CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::CoverTree(
+    const MatType& dataset,
     const double base,
     MetricType* metric) :
     dataset(dataset),
@@ -98,9 +103,14 @@ CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
       << "construction." << std::endl;
 }
 
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
-    const arma::mat& dataset,
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
+CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::CoverTree(
+    const MatType& dataset,
     MetricType& metric,
     const double base) :
     dataset(dataset),
@@ -175,9 +185,14 @@ CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
       << "construction." << std::endl;
 }
 
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
-    const arma::mat& dataset,
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
+CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::CoverTree(
+    const MatType& dataset,
     const double base,
     const size_t pointIndex,
     const int scale,
@@ -218,9 +233,14 @@ CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
 }
 
 // Manually create a cover tree node.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
-    const arma::mat& dataset,
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
+CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::CoverTree(
+    const MatType& dataset,
     const double base,
     const size_t pointIndex,
     const int scale,
@@ -248,8 +268,13 @@ CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
   stat = StatisticType(*this);
 }
 
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
+CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::CoverTree(
     const CoverTree& other) :
     dataset(other.dataset),
     point(other.point),
@@ -272,8 +297,13 @@ CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
   }
 }
 
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::~CoverTree()
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
+CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::~CoverTree()
 {
   // Delete each child.
   for (size_t i = 0; i < children.size(); ++i)
@@ -285,17 +315,28 @@ CoverTree<MetricType, RootPointPolicy, StatisticType>::~CoverTree()
 }
 
 //! Return the number of descendant points.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
 inline size_t
-CoverTree<MetricType, RootPointPolicy, StatisticType>::NumDescendants() const
+CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+    NumDescendants() const
 {
   return numDescendants;
 }
 
 //! Return the index of a particular descendant point.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
 inline size_t
-CoverTree<MetricType, RootPointPolicy, StatisticType>::Descendant(
+CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::Descendant(
     const size_t index) const
 {
   // The first descendant is the point contained within this node.
@@ -319,84 +360,125 @@ CoverTree<MetricType, RootPointPolicy, StatisticType>::Descendant(
   return (size_t() - 1);
 }
 
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
-    const CoverTree<MetricType, RootPointPolicy, StatisticType>* other) const
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
+double CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+    MinDistance(const CoverTree* other) const
 {
   // Every cover tree node will contain points up to base^(scale + 1) away.
-  return std::max(metric->Evaluate(dataset.unsafe_col(point),
-      other->Dataset().unsafe_col(other->Point())) -
+  return std::max(metric->Evaluate(dataset.col(point),
+      other->Dataset().col(other->Point())) -
       furthestDescendantDistance - other->FurthestDescendantDistance(), 0.0);
 }
 
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
-    const CoverTree<MetricType, RootPointPolicy, StatisticType>* other,
-    const double distance) const
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
+double CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+    MinDistance(const CoverTree* other, const double distance) const
 {
   // We already have the distance as evaluated by the metric.
   return std::max(distance - furthestDescendantDistance -
       other->FurthestDescendantDistance(), 0.0);
 }
 
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
-    const arma::vec& other) const
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
+double CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+    MinDistance(const arma::vec& other) const
 {
-  return std::max(metric->Evaluate(dataset.unsafe_col(point), other) -
+  return std::max(metric->Evaluate(dataset.col(point), other) -
       furthestDescendantDistance, 0.0);
 }
 
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
-    const arma::vec& /* other */,
-    const double distance) const
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
+double CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+    MinDistance(const arma::vec& /* other */, const double distance) const
 {
   return std::max(distance - furthestDescendantDistance, 0.0);
 }
 
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
-    const CoverTree<MetricType, RootPointPolicy, StatisticType>* other) const
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
+double CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+    MaxDistance(const CoverTree* other) const
 {
-  return metric->Evaluate(dataset.unsafe_col(point),
-      other->Dataset().unsafe_col(other->Point())) +
+  return metric->Evaluate(dataset.col(point),
+      other->Dataset().col(other->Point())) +
       furthestDescendantDistance + other->FurthestDescendantDistance();
 }
 
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
-    const CoverTree<MetricType, RootPointPolicy, StatisticType>* other,
-    const double distance) const
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
+double CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+    MaxDistance(const CoverTree* other, const double distance) const
 {
   // We already have the distance as evaluated by the metric.
   return distance + furthestDescendantDistance +
       other->FurthestDescendantDistance();
 }
 
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
-    const arma::vec& other) const
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
+double CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+    MaxDistance(const arma::vec& other) const
 {
-  return metric->Evaluate(dataset.unsafe_col(point), other) +
+  return metric->Evaluate(dataset.col(point), other) +
       furthestDescendantDistance;
 }
 
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
-    const arma::vec& /* other */,
-    const double distance) const
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
+double CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+    MaxDistance(const arma::vec& /* other */, const double distance) const
 {
   return distance + furthestDescendantDistance;
 }
 
 //! Return the minimum and maximum distance to another node.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-math::Range CoverTree<MetricType, RootPointPolicy, StatisticType>::
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
+math::Range CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
     RangeDistance(const CoverTree* other) const
 {
-  const double distance = metric->Evaluate(dataset.unsafe_col(point),
-      other->Dataset().unsafe_col(other->Point()));
+  const double distance = metric->Evaluate(dataset.col(point),
+      other->Dataset().col(other->Point()));
 
   math::Range result;
   result.Lo() = distance - furthestDescendantDistance -
@@ -409,8 +491,13 @@ math::Range CoverTree<MetricType, RootPointPolicy, StatisticType>::
 
 //! Return the minimum and maximum distance to another node given that the
 //! point-to-point distance has already been calculated.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-math::Range CoverTree<MetricType, RootPointPolicy, StatisticType>::
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
+math::Range CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
     RangeDistance(const CoverTree* other,
                   const double distance) const
 {
@@ -424,11 +511,16 @@ math::Range CoverTree<MetricType, RootPointPolicy, StatisticType>::
 }
 
 //! Return the minimum and maximum distance to another point.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-math::Range CoverTree<MetricType, RootPointPolicy, StatisticType>::
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
+math::Range CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
     RangeDistance(const arma::vec& other) const
 {
-  const double distance = metric->Evaluate(dataset.unsafe_col(point), other);
+  const double distance = metric->Evaluate(dataset.col(point), other);
 
   return math::Range(distance - furthestDescendantDistance,
                      distance + furthestDescendantDistance);
@@ -436,8 +528,13 @@ math::Range CoverTree<MetricType, RootPointPolicy, StatisticType>::
 
 //! Return the minimum and maximum distance to another point given that the
 //! point-to-point distance has already been calculated.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-math::Range CoverTree<MetricType, RootPointPolicy, StatisticType>::
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
+math::Range CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
     RangeDistance(const arma::vec& /* other */,
                   const double distance) const
 {
@@ -446,9 +543,14 @@ math::Range CoverTree<MetricType, RootPointPolicy, StatisticType>::
 }
 
 //! For a newly initialized node, create children using the near and far set.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
 inline void
-CoverTree<MetricType, RootPointPolicy, StatisticType>::CreateChildren(
+CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::CreateChildren(
     arma::Col<size_t>& indices,
     arma::vec& distances,
     size_t nearSetSize,
@@ -636,12 +738,17 @@ CoverTree<MetricType, RootPointPolicy, StatisticType>::CreateChildren(
       furthestDescendantDistance = distances[i];
 }
 
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-size_t CoverTree<MetricType, RootPointPolicy, StatisticType>::SplitNearFar(
-    arma::Col<size_t>& indices,
-    arma::vec& distances,
-    const double bound,
-    const size_t pointSetSize)
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
+size_t CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+    SplitNearFar(arma::Col<size_t>& indices,
+                 arma::vec& distances,
+                 const double bound,
+                 const size_t pointSetSize)
 {
   // Sanity check; there is no guarantee that this condition will not be true.
   // ...or is there?
@@ -689,30 +796,40 @@ size_t CoverTree<MetricType, RootPointPolicy, StatisticType>::SplitNearFar(
 }
 
 // Returns the maximum distance between points.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-void CoverTree<MetricType, RootPointPolicy, StatisticType>::ComputeDistances(
-    const size_t pointIndex,
-    const arma::Col<size_t>& indices,
-    arma::vec& distances,
-    const size_t pointSetSize)
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
+void CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+    ComputeDistances(const size_t pointIndex,
+                     const arma::Col<size_t>& indices,
+                     arma::vec& distances,
+                     const size_t pointSetSize)
 {
   // For each point, rebuild the distances.  The indices do not need to be
   // modified.
   distanceComps += pointSetSize;
   for (size_t i = 0; i < pointSetSize; ++i)
   {
-    distances[i] = metric->Evaluate(dataset.unsafe_col(pointIndex),
-        dataset.unsafe_col(indices[i]));
+    distances[i] = metric->Evaluate(dataset.col(pointIndex),
+        dataset.col(indices[i]));
   }
 }
 
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-size_t CoverTree<MetricType, RootPointPolicy, StatisticType>::SortPointSet(
-    arma::Col<size_t>& indices,
-    arma::vec& distances,
-    const size_t childFarSetSize,
-    const size_t childUsedSetSize,
-    const size_t farSetSize)
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
+size_t CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+    SortPointSet(arma::Col<size_t>& indices,
+                 arma::vec& distances,
+                 const size_t childFarSetSize,
+                 const size_t childUsedSetSize,
+                 const size_t farSetSize)
 {
   // We'll use low-level memcpy calls ourselves, just to ensure it's done
   // quickly and the way we want it to be.  Unfortunately this takes up more
@@ -766,16 +883,21 @@ size_t CoverTree<MetricType, RootPointPolicy, StatisticType>::SortPointSet(
   return (childFarSetSize + farSetSize);
 }
 
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-void CoverTree<MetricType, RootPointPolicy, StatisticType>::MoveToUsedSet(
-    arma::Col<size_t>& indices,
-    arma::vec& distances,
-    size_t& nearSetSize,
-    size_t& farSetSize,
-    size_t& usedSetSize,
-    arma::Col<size_t>& childIndices,
-    const size_t childFarSetSize, // childNearSetSize is 0 in this case.
-    const size_t childUsedSetSize)
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
+void CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+    MoveToUsedSet(arma::Col<size_t>& indices,
+                  arma::vec& distances,
+                  size_t& nearSetSize,
+                  size_t& farSetSize,
+                  size_t& usedSetSize,
+                  arma::Col<size_t>& childIndices,
+                  const size_t childFarSetSize, // childNearSetSize is 0 here.
+                  const size_t childUsedSetSize)
 {
   const size_t originalSum = nearSetSize + farSetSize + usedSetSize;
 
@@ -907,13 +1029,18 @@ void CoverTree<MetricType, RootPointPolicy, StatisticType>::MoveToUsedSet(
   Log::Assert(originalSum == (nearSetSize + farSetSize + usedSetSize));
 }
 
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-size_t CoverTree<MetricType, RootPointPolicy, StatisticType>::PruneFarSet(
-    arma::Col<size_t>& indices,
-    arma::vec& distances,
-    const double bound,
-    const size_t nearSetSize,
-    const size_t pointSetSize)
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
+size_t CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+    PruneFarSet(arma::Col<size_t>& indices,
+                arma::vec& distances,
+                const double bound,
+                const size_t nearSetSize,
+                const size_t pointSetSize)
 {
   // What we are trying to do is remove any points greater than the bound from
   // the far set.  We don't care what happens to those indices and distances...
@@ -948,8 +1075,13 @@ size_t CoverTree<MetricType, RootPointPolicy, StatisticType>::PruneFarSet(
  * Take a look at the last child (the most recently created one) and remove any
  * implicit nodes that have been created.
  */
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-inline void CoverTree<MetricType, RootPointPolicy, StatisticType>::
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
+inline void CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
     RemoveNewImplicitNodes()
 {
   // If we created an implicit node, take its self-child instead (this could
@@ -978,9 +1110,14 @@ inline void CoverTree<MetricType, RootPointPolicy, StatisticType>::
 /**
  * Returns a string representation of this object.
  */
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-std::string CoverTree<MetricType, RootPointPolicy, StatisticType>::ToString()
-    const
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
+std::string CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+    ToString() const
 {
   std::ostringstream convert;
   convert << "CoverTree [" << this << "]" << std::endl;
diff --git a/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp b/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp
index e09d2d0..963099c 100644
--- a/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp
+++ b/src/mlpack/core/tree/cover_tree/dual_tree_traverser.hpp
@@ -13,9 +13,15 @@
 namespace mlpack {
 namespace tree {
 
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
 template<typename RuleType>
-class CoverTree<MetricType, RootPointPolicy, StatisticType>::DualTreeTraverser
+class CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+    DualTreeTraverser
 {
  public:
   /**
@@ -53,7 +59,8 @@ class CoverTree<MetricType, RootPointPolicy, StatisticType>::DualTreeTraverser
   struct DualCoverTreeMapEntry
   {
     //! The node this entry refers to.
-    CoverTree<MetricType, RootPointPolicy, StatisticType>* referenceNode;
+    CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>*
+        referenceNode;
     //! The score of the node.
     double score;
     //! The base case.
diff --git a/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp
index 45849ab..aa28537 100644
--- a/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/cover_tree/dual_tree_traverser_impl.hpp
@@ -13,20 +13,29 @@
 namespace mlpack {
 namespace tree {
 
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
 template<typename RuleType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::
+CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
 DualTreeTraverser<RuleType>::DualTreeTraverser(RuleType& rule) :
     rule(rule),
     numPrunes(0)
 { /* Nothing to do. */ }
 
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
 template<typename RuleType>
-void CoverTree<MetricType, RootPointPolicy, StatisticType>::
-DualTreeTraverser<RuleType>::Traverse(
-    CoverTree<MetricType, RootPointPolicy, StatisticType>& queryNode,
-    CoverTree<MetricType, RootPointPolicy, StatisticType>& referenceNode)
+void CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+DualTreeTraverser<RuleType>::Traverse(CoverTree& queryNode,
+                                      CoverTree& referenceNode)
 {
   // Start by creating a map and adding the reference root node to it.
   std::map<int, std::vector<DualCoverTreeMapEntry> > refMap;
@@ -46,11 +55,16 @@ DualTreeTraverser<RuleType>::Traverse(
   Traverse(queryNode, refMap);
 }
 
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
 template<typename RuleType>
-void CoverTree<MetricType, RootPointPolicy, StatisticType>::
+void CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
 DualTreeTraverser<RuleType>::Traverse(
-    CoverTree<MetricType, RootPointPolicy, StatisticType>& queryNode,
+    CoverTree& queryNode,
     std::map<int, std::vector<DualCoverTreeMapEntry> >& referenceMap)
 {
   if (referenceMap.size() == 0)
@@ -128,9 +142,14 @@ DualTreeTraverser<RuleType>::Traverse(
   }
 }
 
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
 template<typename RuleType>
-void CoverTree<MetricType, RootPointPolicy, StatisticType>::
+void CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
 DualTreeTraverser<RuleType>::PruneMap(
     CoverTree& queryNode,
     std::map<int, std::vector<DualCoverTreeMapEntry> >& referenceMap,
@@ -246,9 +265,14 @@ DualTreeTraverser<RuleType>::PruneMap(
   }
 }
 
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
 template<typename RuleType>
-void CoverTree<MetricType, RootPointPolicy, StatisticType>::
+void CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
 DualTreeTraverser<RuleType>::ReferenceRecursion(
     CoverTree& queryNode,
     std::map<int, std::vector<DualCoverTreeMapEntry> >& referenceMap)
diff --git a/src/mlpack/core/tree/cover_tree/first_point_is_root.hpp b/src/mlpack/core/tree/cover_tree/first_point_is_root.hpp
index 5e8979c..18b449d 100644
--- a/src/mlpack/core/tree/cover_tree/first_point_is_root.hpp
+++ b/src/mlpack/core/tree/cover_tree/first_point_is_root.hpp
@@ -28,7 +28,8 @@ class FirstPointIsRoot
    * Return the point to be used as the root point of the cover tree.  This just
    * returns 0.
    */
-  static size_t ChooseRoot(const arma::mat& /* dataset */) { return 0; }
+  template<typename MatType>
+  static size_t ChooseRoot(const MatType& /* dataset */) { return 0; }
 };
 
 }; // namespace tree
diff --git a/src/mlpack/core/tree/cover_tree/single_tree_traverser.hpp b/src/mlpack/core/tree/cover_tree/single_tree_traverser.hpp
index c767f8c..62aefe5 100644
--- a/src/mlpack/core/tree/cover_tree/single_tree_traverser.hpp
+++ b/src/mlpack/core/tree/cover_tree/single_tree_traverser.hpp
@@ -16,9 +16,15 @@
 namespace mlpack {
 namespace tree {
 
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
 template<typename RuleType>
-class CoverTree<MetricType, RootPointPolicy, StatisticType>::SingleTreeTraverser
+class CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
+    SingleTreeTraverser
 {
  public:
   /**
diff --git a/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp b/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp
index 3b60a93..3a49cfb 100644
--- a/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/cover_tree/single_tree_traverser_impl.hpp
@@ -17,11 +17,16 @@ namespace mlpack {
 namespace tree {
 
 //! This is the structure the cover tree map will use for traversal.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
 struct CoverTreeMapEntry
 {
   //! The node this entry refers to.
-  CoverTree<MetricType, RootPointPolicy, StatisticType>* node;
+  CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>* node;
   //! The score of the node.
   double score;
   //! The index of the parent node.
@@ -36,24 +41,34 @@ struct CoverTreeMapEntry
   }
 };
 
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
 template<typename RuleType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::
+CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
 SingleTreeTraverser<RuleType>::SingleTreeTraverser(RuleType& rule) :
     rule(rule),
     numPrunes(0)
 { /* Nothing to do. */ }
 
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+template<
+    typename MetricType,
+    typename RootPointPolicy,
+    typename StatisticType,
+    typename MatType
+>
 template<typename RuleType>
-void CoverTree<MetricType, RootPointPolicy, StatisticType>::
+void CoverTree<MetricType, RootPointPolicy, StatisticType, MatType>::
 SingleTreeTraverser<RuleType>::Traverse(
     const size_t queryIndex,
-    CoverTree<MetricType, RootPointPolicy, StatisticType>& referenceNode)
+    CoverTree& referenceNode)
 {
   // This is a non-recursive implementation (which should be faster than a
   // recursive implementation).
-  typedef CoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType>
+  typedef CoverTreeMapEntry<MetricType, RootPointPolicy, StatisticType, MatType>
       MapEntryType;
 
   // We will use this map as a priority queue.  Each key represents the scale,
diff --git a/src/mlpack/tests/tree_test.cpp b/src/mlpack/tests/tree_test.cpp
index 6125a14..e842d93 100644
--- a/src/mlpack/tests/tree_test.cpp
+++ b/src/mlpack/tests/tree_test.cpp
@@ -1573,7 +1573,7 @@ void CheckCovering(const TreeType& node)
   if (node.NumChildren() == 0)
     return;
 
-  const arma::mat& dataset = node.Dataset();
+  const typename TreeType::Mat& dataset = node.Dataset();
   const size_t nodePoint = node.Point();
 
   // To ensure that this node satisfies the covering principle, we must ensure
@@ -1620,7 +1620,7 @@ void CheckIndividualSeparation(const TreeType& constantNode,
     return;
 
   // Now we know we are at the same scale, so make the comparison.
-  const arma::mat& dataset = constantNode.Dataset();
+  const typename TreeType::Mat& dataset = constantNode.Dataset();
   const size_t constantPoint = constantNode.Point();
   const size_t nodePoint = node.Point();
 
@@ -1735,6 +1735,38 @@ BOOST_AUTO_TEST_CASE(CoverTreeConstructionTest)
 }
 
 /**
+ * Create a cover tree on sparse data and make sure it's accurate.
+ */
+BOOST_AUTO_TEST_CASE(SparseCoverTreeConstructionTest)
+{
+  arma::sp_mat dataset;
+  // 50-dimensional, 1000 points.
+  dataset.sprandu(50, 1000, 0.3);
+
+  typedef CoverTree<EuclideanDistance, FirstPointIsRoot, EmptyStatistic,
+      arma::sp_mat> TreeType;
+  TreeType tree(dataset);
+
+  // Ensure each leaf is only created once.
+  arma::vec counts;
+  counts.zeros(1000);
+  RecurseTreeCountLeaves(tree, counts);
+
+  for (size_t i = 0; i < 1000; ++i)
+    BOOST_REQUIRE_EQUAL(counts[i], 1);
+
+  // Each non-leaf should have a self-child.
+  CheckSelfChild<TreeType>(tree);
+
+  // Each node must satisfy the covering principle (its children must be less
+  // than or equal to a certain distance apart).
+  CheckCovering<TreeType, LMetric<2, true> >(tree);
+
+  // Each node's children must be separated by at least a certain value.
+  CheckSeparation<TreeType, LMetric<2, true> >(tree, tree);
+}
+
+/**
  * Test the manual constructor.
  */
 BOOST_AUTO_TEST_CASE(CoverTreeManualConstructorTest)



More information about the mlpack-git mailing list