[mlpack-svn] r15035 - mlpack/trunk/src/mlpack/core/tree/cover_tree

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed May 8 17:22:20 EDT 2013


Author: rcurtin
Date: 2013-05-08 17:22:19 -0400 (Wed, 08 May 2013)
New Revision: 15035

Modified:
   mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree.hpp
   mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
Log:
Add Descendant() and NumDescendants() functions.


Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree.hpp	2013-05-08 19:26:14 UTC (rev 15034)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree.hpp	2013-05-08 21:22:19 UTC (rev 15035)
@@ -238,6 +238,12 @@
   //! Modify the children manually (maybe not a great idea).
   std::vector<CoverTree*>& Children() { return children; }
 
+  //! Get the number of descendant points.
+  size_t NumDescendants() const;
+
+  //! Get the index of a particular descendant point.
+  size_t Descendant(const size_t index) const;
+
   //! Get the scale of this node.
   int Scale() const { return scale; }
   //! Modify the scale of this node.  Be careful...
@@ -341,6 +347,9 @@
   //! The instantiated statistic.
   StatisticType stat;
 
+  //! The number of descendant points.
+  size_t numDescendants;
+
   //! The parent node (NULL if this is the root of the tree).
   CoverTree* parent;
 

Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp	2013-05-08 19:26:14 UTC (rev 15034)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp	2013-05-08 21:22:19 UTC (rev 15035)
@@ -26,6 +26,7 @@
     point(RootPointPolicy::ChooseRoot(dataset)),
     scale(INT_MAX),
     base(base),
+    numDescendants(0),
     parent(NULL),
     parentDistance(0),
     furthestDescendantDistance(0),
@@ -76,6 +77,7 @@
     point(RootPointPolicy::ChooseRoot(dataset)),
     scale(INT_MAX),
     base(base),
+    numDescendants(0),
     parent(NULL),
     parentDistance(0),
     furthestDescendantDistance(0),
@@ -131,6 +133,7 @@
     point(pointIndex),
     scale(scale),
     base(base),
+    numDescendants(0),
     parent(parent),
     parentDistance(parentDistance),
     furthestDescendantDistance(0),
@@ -141,6 +144,7 @@
   if (nearSetSize == 0)
   {
     this->scale = INT_MIN;
+    numDescendants = 1;
     stat = StatisticType(*this);
     return;
   }
@@ -152,7 +156,234 @@
   stat = StatisticType(*this);
 }
 
+// Manually create a cover tree node.
 template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
+    const arma::mat& dataset,
+    const double base,
+    const size_t pointIndex,
+    const int scale,
+    CoverTree* parent,
+    const double parentDistance,
+    const double furthestDescendantDistance,
+    MetricType* metric) :
+    dataset(dataset),
+    point(pointIndex),
+    scale(scale),
+    base(base),
+    numDescendants(0),
+    parent(parent),
+    parentDistance(parentDistance),
+    furthestDescendantDistance(furthestDescendantDistance),
+    localMetric(metric == NULL),
+    metric(metric)
+{
+  // If necessary, create a local metric.
+  if (localMetric)
+    this->metric = new MetricType();
+
+  // Initialize the statistic.
+  stat = StatisticType(*this);
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
+    const CoverTree& other) :
+    dataset(other.dataset),
+    point(other.point),
+    scale(other.scale),
+    base(other.base),
+    stat(other.stat),
+    numDescendants(other.numDescendants),
+    parent(other.parent),
+    parentDistance(other.parentDistance),
+    furthestDescendantDistance(other.furthestDescendantDistance),
+    localMetric(false),
+    metric(other.metric)
+{
+  // Copy each child by hand.
+  for (size_t i = 0; i < other.NumChildren(); ++i)
+  {
+    children.push_back(new CoverTree(other.Child(i)));
+    children[i]->Parent() = this;
+  }
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+CoverTree<MetricType, RootPointPolicy, StatisticType>::~CoverTree()
+{
+  // Delete each child.
+  for (size_t i = 0; i < children.size(); ++i)
+    delete children[i];
+
+  // Delete the local metric, if necessary.
+  if (localMetric)
+    delete metric;
+}
+
+//! Return the number of descendant points.
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+inline size_t
+CoverTree<MetricType, RootPointPolicy, StatisticType>::NumDescendants() const
+{
+  return numDescendants;
+}
+
+//! Return the index of a particular descendant point.
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+inline size_t
+CoverTree<MetricType, RootPointPolicy, StatisticType>::Descendant(
+    const size_t index) const
+{
+  // The first descendant is the point contained within this node.
+  if (index == 0)
+    return point;
+
+  // Is it in the self-child?
+  if (index < children[0]->NumDescendants())
+    return children[0]->Descendant(index);
+
+  // Now check the other children.
+  size_t sum = children[0]->NumDescendants();
+  for (size_t i = 1; i < children.size(); ++i)
+  {
+    if (index - sum < children[i]->NumDescendants())
+      return children[i]->Descendant(index - sum);
+    sum += children[i]->NumDescendants();
+  }
+
+  // Not reached for a valid index; return size_t(-1) as an invalid marker.
+  return (size_t() - 1);
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
+    const CoverTree<MetricType, RootPointPolicy, StatisticType>* other) const
+{
+  // Every cover tree node will contain points up to EC^(scale + 1) away.
+  return std::max(metric->Evaluate(dataset.unsafe_col(point),
+      other->Dataset().unsafe_col(other->Point())) -
+      furthestDescendantDistance - other->FurthestDescendantDistance(), 0.0);
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
+    const CoverTree<MetricType, RootPointPolicy, StatisticType>* other,
+    const double distance) const
+{
+  // We already have the distance as evaluated by the metric.
+  return std::max(distance - furthestDescendantDistance -
+      other->FurthestDescendantDistance(), 0.0);
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
+    const arma::vec& other) const
+{
+  return std::max(metric->Evaluate(dataset.unsafe_col(point), other) -
+      furthestDescendantDistance, 0.0);
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
+    const arma::vec& /* other */,
+    const double distance) const
+{
+  return std::max(distance - furthestDescendantDistance, 0.0);
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
+    const CoverTree<MetricType, RootPointPolicy, StatisticType>* other) const
+{
+  return metric->Evaluate(dataset.unsafe_col(point),
+      other->Dataset().unsafe_col(other->Point())) +
+      furthestDescendantDistance + other->FurthestDescendantDistance();
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
+    const CoverTree<MetricType, RootPointPolicy, StatisticType>* other,
+    const double distance) const
+{
+  // We already have the distance as evaluated by the metric.
+  return distance + furthestDescendantDistance +
+      other->FurthestDescendantDistance();
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
+    const arma::vec& other) const
+{
+  return metric->Evaluate(dataset.unsafe_col(point), other) +
+      furthestDescendantDistance;
+}
+
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
+    const arma::vec& /* other */,
+    const double distance) const
+{
+  return distance + furthestDescendantDistance;
+}
+
+//! Return the minimum and maximum distance to another node.
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+math::Range CoverTree<MetricType, RootPointPolicy, StatisticType>::
+    RangeDistance(const CoverTree* other) const
+{
+  const double distance = metric->Evaluate(dataset.unsafe_col(point),
+      other->Dataset().unsafe_col(other->Point()));
+
+  math::Range result;
+  result.Lo() = distance - furthestDescendantDistance -
+      other->FurthestDescendantDistance();
+  result.Hi() = distance + furthestDescendantDistance +
+      other->FurthestDescendantDistance();
+
+  return result;
+}
+
+//! Return the minimum and maximum distance to another node given that the
+//! point-to-point distance has already been calculated.
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+math::Range CoverTree<MetricType, RootPointPolicy, StatisticType>::
+    RangeDistance(const CoverTree* other,
+                  const double distance) const
+{
+  math::Range result;
+  result.Lo() = distance - furthestDescendantDistance -
+      other->FurthestDescendantDistance();
+  result.Hi() = distance + furthestDescendantDistance +
+      other->FurthestDescendantDistance();
+
+  return result;
+}
+
+//! Return the minimum and maximum distance to another point.
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+math::Range CoverTree<MetricType, RootPointPolicy, StatisticType>::
+    RangeDistance(const arma::vec& other) const
+{
+  const double distance = metric->Evaluate(dataset.unsafe_col(point), other);
+
+  return math::Range(distance - furthestDescendantDistance,
+                     distance + furthestDescendantDistance);
+}
+
+//! Return the minimum and maximum distance to another point given that the
+//! point-to-point distance has already been calculated.
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+math::Range CoverTree<MetricType, RootPointPolicy, StatisticType>::
+    RangeDistance(const arma::vec& other,
+                  const double distance) const
+{
+  return math::Range(distance - furthestDescendantDistance,
+                     distance + furthestDescendantDistance);
+}
+
+//! For a newly initialized node, create children using the near and far set.
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
 inline void
 CoverTree<MetricType, RootPointPolicy, StatisticType>::CreateChildren(
     arma::Col<size_t>& indices,
@@ -189,15 +420,16 @@
       usedSetSize++;
     }
 
+    // The number of descendants is just the number of children, because each of
+    // them is a leaf and contains one point.
+    numDescendants = children.size();
+
     // Re-sort the dataset.  We have
     // [ used | far | other used ]
     // and we want
     // [ far | all used ].
     SortPointSet(indices, distances, 0, usedSetSize, farSetSize);
 
-    // Initialize the statistic.
-    stat = StatisticType(*this);
-
     return;
   }
 
@@ -216,6 +448,8 @@
   children.push_back(new CoverTree(dataset, base, point, nextScale, this, 0,
       indices, distances, childNearSetSize, childFarSetSize, childUsedSetSize,
       *metric));
+  // The self-child's count already includes this node's own point.
+  numDescendants += children[0]->NumDescendants();
 
   // The self-child can't modify the furthestChildDistance away from 0, but it
   // can modify the furthestDescendantDistance.
@@ -287,6 +521,7 @@
       children.push_back(new CoverTree(dataset, base, indices[0], nextScale,
           this, distances[0], indices, distances, childNearSetSize, farSetSize,
           usedSetSize, *metric));
+      numDescendants += children[children.size() - 1]->NumDescendants();
 
       // Because the far set size is 0, we don't have to do any swapping to
       // move the point into the used set.
@@ -327,6 +562,7 @@
     children.push_back(new CoverTree(dataset, base, indices[0], nextScale,
         this, distances[0], childIndices, childDistances, childNearSetSize,
         childFarSetSize, childUsedSetSize, *metric));
+    numDescendants += children[children.size() - 1]->NumDescendants();
 
     // If we created an implicit node, take its self-child instead (this could
     // happen multiple times).
@@ -365,196 +601,7 @@
       furthestDescendantDistance = distances[i];
 }
 
-// Manually create a cover tree node.
 template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
-    const arma::mat& dataset,
-    const double base,
-    const size_t pointIndex,
-    const int scale,
-    CoverTree* parent,
-    const double parentDistance,
-    const double furthestDescendantDistance,
-    MetricType* metric) :
-    dataset(dataset),
-    point(pointIndex),
-    scale(scale),
-    base(base),
-    parent(parent),
-    parentDistance(parentDistance),
-    furthestDescendantDistance(furthestDescendantDistance),
-    localMetric(metric == NULL),
-    metric(metric)
-{
-  // If necessary, create a local metric.
-  if (localMetric)
-    this->metric = new MetricType();
-
-  // Initialize the statistic.
-  stat = StatisticType(*this);
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
-    const CoverTree& other) :
-    dataset(other.dataset),
-    point(other.point),
-    scale(other.scale),
-    base(other.base),
-    stat(other.stat),
-    parent(other.parent),
-    parentDistance(other.parentDistance),
-    furthestDescendantDistance(other.furthestDescendantDistance),
-    localMetric(false),
-    metric(other.metric)
-{
-  // Copy each child by hand.
-  for (size_t i = 0; i < other.NumChildren(); ++i)
-  {
-    children.push_back(new CoverTree(other.Child(i)));
-    children[i]->Parent() = this;
-  }
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-CoverTree<MetricType, RootPointPolicy, StatisticType>::~CoverTree()
-{
-  // Delete each child.
-  for (size_t i = 0; i < children.size(); ++i)
-    delete children[i];
-
-  // Delete the local metric, if necessary.
-  if (localMetric)
-    delete metric;
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
-    const CoverTree<MetricType, RootPointPolicy, StatisticType>* other) const
-{
-  // Every cover tree node will contain points up to EC^(scale + 1) away.
-  return std::max(metric->Evaluate(dataset.unsafe_col(point),
-      other->Dataset().unsafe_col(other->Point())) -
-      furthestDescendantDistance - other->FurthestDescendantDistance(), 0.0);
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
-    const CoverTree<MetricType, RootPointPolicy, StatisticType>* other,
-    const double distance) const
-{
-  // We already have the distance as evaluated by the metric.
-  return std::max(distance - furthestDescendantDistance -
-      other->FurthestDescendantDistance(), 0.0);
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
-    const arma::vec& other) const
-{
-  return std::max(metric->Evaluate(dataset.unsafe_col(point), other) -
-      furthestDescendantDistance, 0.0);
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
-    const arma::vec& /* other */,
-    const double distance) const
-{
-  return std::max(distance - furthestDescendantDistance, 0.0);
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
-    const CoverTree<MetricType, RootPointPolicy, StatisticType>* other) const
-{
-  return metric->Evaluate(dataset.unsafe_col(point),
-      other->Dataset().unsafe_col(other->Point())) +
-      furthestDescendantDistance + other->FurthestDescendantDistance();
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
-    const CoverTree<MetricType, RootPointPolicy, StatisticType>* other,
-    const double distance) const
-{
-  // We already have the distance as evaluated by the metric.
-  return distance + furthestDescendantDistance +
-      other->FurthestDescendantDistance();
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
-    const arma::vec& other) const
-{
-  return metric->Evaluate(dataset.unsafe_col(point), other) +
-      furthestDescendantDistance;
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
-    const arma::vec& /* other */,
-    const double distance) const
-{
-  return distance + furthestDescendantDistance;
-}
-
-//! Return the minimum and maximum distance to another node.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-math::Range CoverTree<MetricType, RootPointPolicy, StatisticType>::
-    RangeDistance(const CoverTree* other) const
-{
-  const double distance = metric->Evaluate(dataset.unsafe_col(point),
-      other->Dataset().unsafe_col(other->Point()));
-
-  math::Range result;
-  result.Lo() = distance - furthestDescendantDistance -
-      other->FurthestDescendantDistance();
-  result.Hi() = distance + furthestDescendantDistance +
-      other->FurthestDescendantDistance();
-
-  return result;
-}
-
-//! Return the minimum and maximum distance to another node given that the
-//! point-to-point distance has already been calculated.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-math::Range CoverTree<MetricType, RootPointPolicy, StatisticType>::
-    RangeDistance(const CoverTree* other,
-                  const double distance) const
-{
-  math::Range result;
-  result.Lo() = distance - furthestDescendantDistance -
-      other->FurthestDescendantDistance();
-  result.Hi() = distance + furthestDescendantDistance +
-      other->FurthestDescendantDistance();
-
-  return result;
-}
-
-//! Return the minimum and maximum distance to another point.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-math::Range CoverTree<MetricType, RootPointPolicy, StatisticType>::
-    RangeDistance(const arma::vec& other) const
-{
-  const double distance = metric->Evaluate(dataset.unsafe_col(point), other);
-
-  return math::Range(distance - furthestDescendantDistance,
-                     distance + furthestDescendantDistance);
-}
-
-//! Return the minimum and maximum distance to another point given that the
-//! point-to-point distance has already been calculated.
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-math::Range CoverTree<MetricType, RootPointPolicy, StatisticType>::
-    RangeDistance(const arma::vec& other,
-                  const double distance) const
-{
-  return math::Range(distance - furthestDescendantDistance,
-                     distance + furthestDescendantDistance);
-}
-
-template<typename MetricType, typename RootPointPolicy, typename StatisticType>
 size_t CoverTree<MetricType, RootPointPolicy, StatisticType>::SplitNearFar(
     arma::Col<size_t>& indices,
     arma::vec& distances,
@@ -865,7 +912,8 @@
  * Returns a string representation of this object.
  */
 template<typename MetricType, typename RootPointPolicy, typename StatisticType>
-std::string CoverTree<MetricType, RootPointPolicy, StatisticType>::ToString() const
+std::string CoverTree<MetricType, RootPointPolicy, StatisticType>::ToString()
+    const
 {
   std::ostringstream convert;
   convert << "CoverTree [" << this << "]" << std::endl;




More information about the mlpack-svn mailing list