[mlpack-svn] r12943 - mlpack/trunk/src/mlpack/core/tree

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Jun 4 19:10:37 EDT 2012


Author: rcurtin
Date: 2012-06-04 19:10:37 -0400 (Mon, 04 Jun 2012)
New Revision: 12943

Modified:
   mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp
   mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp
Log:
Clean up the cover tree implementation a little bit.  Allow instantiated metrics
(useful for hyperbolic tangent or Gaussian kernels, or any other kernels with
parameters).


Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp	2012-06-04 21:00:04 UTC (rev 12942)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp	2012-06-04 23:10:37 UTC (rev 12943)
@@ -103,7 +103,8 @@
    *      building (default 2.0).
    */
   CoverTree(const arma::mat& dataset,
-            const double expansionConstant = 2.0);
+            const double expansionConstant = 2.0,
+            MetricType* metric = NULL);
 
   /**
    * Construct a child cover tree node.  This constructor is not meant to be
@@ -144,7 +145,8 @@
             arma::vec& distances,
             size_t nearSetSize,
             size_t& farSetSize,
-            size_t& usedSetSize);
+            size_t& usedSetSize,
+            MetricType& metric = NULL);
 
   /**
    * Delete this cover tree node and its children.
@@ -266,9 +268,6 @@
   //! The instantiated statistic.
   StatisticType stat;
 
-  //! The instantiated metric.  Either the user passes it in, or we build it.
-  MetricType* metric;
-
   //! Distance to the parent.
   double parentDistance;
 
@@ -289,7 +288,8 @@
   void ComputeDistances(const size_t pointIndex,
                         const arma::Col<size_t>& indices,
                         arma::vec& distances,
-                        const size_t pointSetSize);
+                        const size_t pointSetSize,
+                        MetricType& metric);
   /**
    * Split the given indices and distances into a near and a far set, returning
    * the number of points in the near set.  The distances must already be

Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp	2012-06-04 21:00:04 UTC (rev 12942)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp	2012-06-04 23:10:37 UTC (rev 12943)
@@ -17,13 +17,22 @@
 template<typename MetricType, typename RootPointPolicy, typename StatisticType>
 CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
     const arma::mat& dataset,
-    const double expansionConstant) :
+    const double expansionConstant,
+    MetricType* metric) :
     dataset(dataset),
     point(RootPointPolicy::ChooseRoot(dataset)),
     expansionConstant(expansionConstant),
     parentDistance(0),
     furthestDescendantDistance(0)
 {
+  // If we need to create a metric, do that.  We'll just do it on the heap.
+  bool localMetric = false;
+  if (metric == NULL)
+  {
+    localMetric = true; // So we know we need to free it.
+    metric = new MetricType();
+  }
+
   // Kick off the building.  Create the indices array and the distances array.
   arma::Col<size_t> indices = arma::linspace<arma::Col<size_t> >(1,
       dataset.n_cols - 1, dataset.n_cols - 1);
@@ -35,7 +44,7 @@
   arma::vec distances(dataset.n_cols - 1);
 
   // Build the initial distances.
-  ComputeDistances(point, indices, distances, dataset.n_cols - 1);
+  ComputeDistances(point, indices, distances, dataset.n_cols - 1, *metric);
 
   // Now determine the scale factor of the root node.
   const double maxDistance = max(distances);
@@ -50,7 +59,8 @@
   size_t childFarSetSize = (dataset.n_cols - 1) - childNearSetSize;
   size_t usedSetSize = 0;
   children.push_back(new CoverTree(dataset, expansionConstant, point, scale - 1,
-      0, indices, distances, childNearSetSize, childFarSetSize, usedSetSize));
+      0, indices, distances, childNearSetSize, childFarSetSize, usedSetSize,
+      *metric));
 
   furthestDescendantDistance = children[0]->FurthestDescendantDistance();
 
@@ -80,7 +90,7 @@
   {
     // We want to select the furthest point in the near set as the next child.
     size_t newPointIndex = nearSetSize - 1;
-    
+
     // Swap to front if necessary.
     if (newPointIndex != 0)
     {
@@ -106,7 +116,7 @@
       size_t childFarSetSize = 0;
       children.push_back(new CoverTree(dataset, expansionConstant,
           indices[0], scale - 1, distances[0], indices, distances,
-          childNearSetSize, childFarSetSize, usedSetSize));
+          childNearSetSize, childFarSetSize, usedSetSize, *metric));
 
       // And we're done.
       break;
@@ -122,7 +132,7 @@
 
     // Build distances for the child.
     ComputeDistances(indices[0], childIndices, childDistances,
-        nearSetSize - 1);
+        nearSetSize - 1, *metric);
     childDistances(nearSetSize - 1) = 0;
 
     // Split into near and far sets for this point.
@@ -134,7 +144,7 @@
     childFarSetSize = ((nearSetSize - 1) - childNearSetSize);
     children.push_back(new CoverTree(dataset, expansionConstant, indices[0],
         scale - 1, distances[0], childIndices, childDistances, childNearSetSize,
-        childFarSetSize, childUsedSetSize));
+        childFarSetSize, childUsedSetSize, *metric));
 
     // If we created an implicit node, take its self-child instead (this could
     // happen multiple times).
@@ -163,13 +173,16 @@
     MoveToUsedSet(indices, distances, nearSetSize, farSetSize, usedSetSize,
         childIndices, childFarSetSize, childUsedSetSize);
   }
-  
+
   // Calculate furthest descendant.
   for (size_t i = 0; i < usedSetSize; ++i)
     if (distances[i] > furthestDescendantDistance)
       furthestDescendantDistance = distances[i];
 
   Log::Assert(furthestDescendantDistance <= pow(expansionConstant, scale + 1));
+
+  if (localMetric)
+    delete metric;
 }
 
 template<typename MetricType, typename RootPointPolicy, typename StatisticType>
@@ -183,7 +196,8 @@
     arma::vec& distances,
     size_t nearSetSize,
     size_t& farSetSize,
-    size_t& usedSetSize) :
+    size_t& usedSetSize,
+    MetricType& metric) :
     dataset(dataset),
     point(pointIndex),
     scale(scale),
@@ -193,8 +207,11 @@
 {
   // If the size of the near set is 0, this is a leaf.
   if (nearSetSize == 0)
+  {
+    this->scale = INT_MIN;
     return;
-    
+  }
+
   // Determine the next scale level.  This should be the first level where there
   // are any points in the far set.  So, if we know the maximum distance in the
   // distances array, this will be the largest i such that
@@ -203,21 +220,22 @@
   // implicit node.  If the maximum distance is 0, every point in the near set
   // will be created as a leaf, and a child to this node.  We also do not need
   // to change the furthestChildDistance or furthestDescendantDistance.
-  const double maxDistance = max(distances.rows(0, nearSetSize + farSetSize - 1));
+  const double maxDistance = max(distances.rows(0,
+      nearSetSize + farSetSize - 1));
   if (maxDistance == 0)
   {
     // Make the self child at the lowest possible level.
     // This should not modify farSetSize or usedSetSize.
     size_t tempSize = 0;
     children.push_back(new CoverTree(dataset, expansionConstant, pointIndex,
-        INT_MIN, 0, indices, distances, 0, tempSize, usedSetSize));
+        INT_MIN, 0, indices, distances, 0, tempSize, usedSetSize, metric));
 
     // Every point in the near set should be a leaf.
     for (size_t i = 0; i < nearSetSize; ++i)
     {
       // farSetSize and usedSetSize will not be modified.
       children.push_back(new CoverTree(dataset, expansionConstant, indices[i],
-          INT_MIN, 0, indices, distances, 0, tempSize, usedSetSize));
+          INT_MIN, 0, indices, distances, 0, tempSize, usedSetSize, metric));
       usedSetSize++;
     }
 
@@ -247,8 +265,8 @@
   size_t childUsedSetSize = 0;
   children.push_back(new CoverTree(dataset, expansionConstant, pointIndex,
       nextScale, 0, indices, distances, childNearSetSize, childFarSetSize,
-      childUsedSetSize));
-  
+      childUsedSetSize, metric));
+
   // The self-child can't modify the furthestChildDistance away from 0, but it
   // can modify the furthestDescendantDistance.
   furthestDescendantDistance = children[0]->FurthestDescendantDistance();
@@ -279,7 +297,7 @@
   // is what we are trying to make.
   SortPointSet(indices, distances, childFarSetSize, childUsedSetSize,
       farSetSize);
-  
+
   // Update size of near set and used set.
   nearSetSize -= childUsedSetSize;
   usedSetSize += childUsedSetSize;
@@ -315,7 +333,7 @@
       size_t childNearSetSize = 0;
       children.push_back(new CoverTree(dataset, expansionConstant,
           indices[0], nextScale, distances[0], indices, distances,
-          childNearSetSize, farSetSize, usedSetSize));
+          childNearSetSize, farSetSize, usedSetSize, metric));
 
       // Because the far set size is 0, we don't have to do any swapping to
       // move the point into the used set.
@@ -335,7 +353,7 @@
 
     // Build distances for the child.
     ComputeDistances(indices[0], childIndices, childDistances, nearSetSize
-        + farSetSize - 1);
+        + farSetSize - 1, metric);
 
     // Split into near and far sets for this point.
     childNearSetSize = SplitNearFar(childIndices, childDistances, bound,
@@ -343,7 +361,7 @@
     childFarSetSize = PruneFarSet(childIndices, childDistances,
         expansionConstant * bound, childNearSetSize,
         (nearSetSize + farSetSize - 1));
-    
+
     // Now that we know the near and far set sizes, we can put the used point
     // (the self point) in the correct place; now, when we call
     // MoveToUsedSet(), it will move the self-point correctly.  The distance
@@ -355,7 +373,7 @@
     childUsedSetSize = 1; // Mark self point as used.
     children.push_back(new CoverTree(dataset, expansionConstant, indices[0],
         nextScale, distances[0], childIndices, childDistances, childNearSetSize,
-        childFarSetSize, childUsedSetSize));
+        childFarSetSize, childUsedSetSize, metric));
 
     // If we created an implicit node, take its self-child instead (this could
     // happen multiple times).
@@ -383,9 +401,10 @@
     MoveToUsedSet(indices, distances, nearSetSize, farSetSize, usedSetSize,
         childIndices, childFarSetSize, childUsedSetSize);
   }
-    
+
   // Calculate furthest descendant.
-  for (size_t i = (nearSetSize + farSetSize); i < (nearSetSize + farSetSize + usedSetSize); ++i)
+  for (size_t i = (nearSetSize + farSetSize); i < (nearSetSize + farSetSize +
+      usedSetSize); ++i)
     if (distances[i] > furthestDescendantDistance)
       furthestDescendantDistance = distances[i];
 
@@ -495,13 +514,14 @@
     const size_t pointIndex,
     const arma::Col<size_t>& indices,
     arma::vec& distances,
-    const size_t pointSetSize)
+    const size_t pointSetSize,
+    MetricType& metric)
 {
   // For each point, rebuild the distances.  The indices do not need to be
   // modified.
   for (size_t i = 0; i < pointSetSize; ++i)
   {
-    distances[i] = MetricType::Evaluate(dataset.unsafe_col(pointIndex),
+    distances[i] = metric.Evaluate(dataset.unsafe_col(pointIndex),
         dataset.unsafe_col(indices[i]));
   }
 }
@@ -561,7 +581,7 @@
 
   delete[] indicesBuffer;
   delete[] distancesBuffer;
-  
+
   // This returns the complete size of the far set.
   return (childFarSetSize + farSetSize);
 }
@@ -674,7 +694,7 @@
       if (childIndices[childFarSetSize + j] == indices[i + nearSetSize])
       {
         // We have found a point to swap.
-        
+
         // Perform the swap.
         size_t tempIndex = indices[nearSetSize + farSetSize - 1];
         double tempDist = distances[nearSetSize + farSetSize - 1];
@@ -741,7 +761,7 @@
 
   // The far set size is the left pointer, with the near set size subtracted
   // from it.
-  return (left - nearSetSize); 
+  return (left - nearSetSize);
 }
 
 }; // namespace tree




More information about the mlpack-svn mailing list