[mlpack-svn] r14417 - mlpack/trunk/src/mlpack/core/tree/cover_tree

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Feb 28 14:15:26 EST 2013


Author: rcurtin
Date: 2013-02-28 14:15:25 -0500 (Thu, 28 Feb 2013)
New Revision: 14417

Modified:
   mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree.hpp
   mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
Log:
Add Metric() function, and refactor a little bit.  The metric is now stored as
a member of the tree; if none is given, a local metric is created and then
deleted in the destructor.


Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree.hpp	2013-02-28 19:14:56 UTC (rev 14416)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree.hpp	2013-02-28 19:15:25 UTC (rev 14417)
@@ -160,6 +160,7 @@
    * @param parent Parent node (NULL indicates no parent).
    * @param parentDistance Distance to parent node point.
    * @param furthestDescendantDistance Distance to furthest descendant point.
+   * @param metric Instantiated metric (optional).
    */
   CoverTree(const arma::mat& dataset,
             const double base,
@@ -167,7 +168,8 @@
             const int scale,
             CoverTree* parent,
             const double parentDistance,
-            const double furthestDescendantDistance);
+            const double furthestDescendantDistance,
+            MetricType* metric = NULL);
 
   /**
    * Create a cover tree from another tree.  Be careful!  This may use a lot of
@@ -311,6 +313,12 @@
   //! Distance to the furthest descendant.
   double furthestDescendantDistance;
 
+  //! Whether or not we need to destroy the metric in the destructor.
+  bool localMetric;
+
+  //! The metric used for this tree.
+  MetricType* metric;
+
   /**
    * Fill the vector of distances with the distances between the point specified
    * by pointIndex and each point in the indices array.  The distances of the
@@ -325,8 +333,7 @@
   void ComputeDistances(const size_t pointIndex,
                         const arma::Col<size_t>& indices,
                         arma::vec& distances,
-                        const size_t pointSetSize,
-                        MetricType& metric);
+                        const size_t pointSetSize);
   /**
    * Split the given indices and distances into a near and a far set, returning
    * the number of points in the near set.  The distances must already be

Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp	2013-02-28 19:14:56 UTC (rev 14416)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp	2013-02-28 19:15:25 UTC (rev 14417)
@@ -27,15 +27,13 @@
     base(base),
     parent(NULL),
     parentDistance(0),
-    furthestDescendantDistance(0)
+    furthestDescendantDistance(0),
+    localMetric(metric == NULL),
+    metric(metric)
 {
   // If we need to create a metric, do that.  We'll just do it on the heap.
-  bool localMetric = false;
-  if (metric == NULL)
-  {
-    localMetric = true; // So we know we need to free it.
-    metric = new MetricType();
-  }
+  if (localMetric)
+    this->metric = new MetricType();
 
   // Kick off the building.  Create the indices array and the distances array.
   arma::Col<size_t> indices = arma::linspace<arma::Col<size_t> >(1,
@@ -48,7 +46,7 @@
   arma::vec distances(dataset.n_cols - 1);
 
   // Build the initial distances.
-  ComputeDistances(point, indices, distances, dataset.n_cols - 1, *metric);
+  ComputeDistances(point, indices, distances, dataset.n_cols - 1);
 
   // Now determine the scale factor of the root node.
   const double maxDistance = max(distances);
@@ -139,7 +137,7 @@
 
     // Build distances for the child.
     ComputeDistances(indices[0], childIndices, childDistances,
-        nearSetSize - 1, *metric);
+        nearSetSize - 1);
     childDistances(nearSetSize - 1) = 0;
 
     // Split into near and far sets for this point.
@@ -191,9 +189,6 @@
 
   Log::Assert(furthestDescendantDistance <= pow(base, scale + 1));
 
-  if (localMetric)
-    delete metric;
-
   // Initialize statistic.
   stat = StatisticType(*this);
 }
@@ -218,7 +213,9 @@
     base(base),
     parent(parent),
     parentDistance(parentDistance),
-    furthestDescendantDistance(0)
+    furthestDescendantDistance(0),
+    localMetric(false),
+    metric(&metric)
 {
   // If the size of the near set is 0, this is a leaf.
   if (nearSetSize == 0)
@@ -376,7 +373,7 @@
 
     // Build distances for the child.
     ComputeDistances(indices[0], childIndices, childDistances, nearSetSize
-        + farSetSize - 1, metric);
+        + farSetSize - 1);
 
     // Split into near and far sets for this point.
     childNearSetSize = SplitNearFar(childIndices, childDistances, bound,
@@ -449,15 +446,22 @@
     const int scale,
     CoverTree* parent,
     const double parentDistance,
-    const double furthestDescendantDistance) :
+    const double furthestDescendantDistance,
+    MetricType* metric) :
     dataset(dataset),
     point(pointIndex),
     scale(scale),
     base(base),
     parent(parent),
     parentDistance(parentDistance),
-    furthestDescendantDistance(furthestDescendantDistance)
+    furthestDescendantDistance(furthestDescendantDistance),
+    localMetric(metric == NULL),
+    metric(metric)
 {
+  // If necessary, create a local metric.
+  if (localMetric)
+    this->metric = new MetricType();
+
   // Initialize the statistic.
   stat = StatisticType(*this);
 }
@@ -472,7 +476,9 @@
     stat(other.stat),
     parent(other.parent),
     parentDistance(other.parentDistance),
-    furthestDescendantDistance(other.furthestDescendantDistance)
+    furthestDescendantDistance(other.furthestDescendantDistance),
+    localMetric(false),
+    metric(other.metric)
 {
   // Copy each child by hand.
   for (size_t i = 0; i < other.NumChildren(); ++i)
@@ -488,6 +494,10 @@
   // Delete each child.
   for (size_t i = 0; i < children.size(); ++i)
     delete children[i];
+
+  // Delete the local metric, if necessary.
+  if (localMetric)
+    delete metric;
 }
 
 template<typename MetricType, typename RootPointPolicy, typename StatisticType>
@@ -495,7 +505,7 @@
     const CoverTree<MetricType, RootPointPolicy, StatisticType>* other) const
 {
   // Every cover tree node will contain points up to EC^(scale + 1) away.
-  return std::max(MetricType::Evaluate(dataset.unsafe_col(point),
+  return std::max(metric->Evaluate(dataset.unsafe_col(point),
       other->Dataset().unsafe_col(other->Point())) -
       furthestDescendantDistance - other->FurthestDescendantDistance(), 0.0);
 }
@@ -514,7 +524,7 @@
 double CoverTree<MetricType, RootPointPolicy, StatisticType>::MinDistance(
     const arma::vec& other) const
 {
-  return std::max(MetricType::Evaluate(dataset.unsafe_col(point), other) -
+  return std::max(metric->Evaluate(dataset.unsafe_col(point), other) -
       furthestDescendantDistance, 0.0);
 }
 
@@ -530,7 +540,7 @@
 double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
     const CoverTree<MetricType, RootPointPolicy, StatisticType>* other) const
 {
-  return MetricType::Evaluate(dataset.unsafe_col(point),
+  return metric->Evaluate(dataset.unsafe_col(point),
       other->Dataset().unsafe_col(other->Point())) +
       furthestDescendantDistance + other->FurthestDescendantDistance();
 }
@@ -549,7 +559,7 @@
 double CoverTree<MetricType, RootPointPolicy, StatisticType>::MaxDistance(
     const arma::vec& other) const
 {
-  return MetricType::Evaluate(dataset.unsafe_col(point), other) +
+  return metric->Evaluate(dataset.unsafe_col(point), other) +
       furthestDescendantDistance;
 }
 
@@ -619,14 +629,13 @@
     const size_t pointIndex,
     const arma::Col<size_t>& indices,
     arma::vec& distances,
-    const size_t pointSetSize,
-    MetricType& metric)
+    const size_t pointSetSize)
 {
   // For each point, rebuild the distances.  The indices do not need to be
   // modified.
   for (size_t i = 0; i < pointSetSize; ++i)
   {
-    distances[i] = metric.Evaluate(dataset.unsafe_col(pointIndex),
+    distances[i] = metric->Evaluate(dataset.unsafe_col(pointIndex),
         dataset.unsafe_col(indices[i]));
   }
 }




More information about the mlpack-svn mailing list