[mlpack-svn] r12581 - mlpack/trunk/src/mlpack/core/tree

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Apr 30 17:39:22 EDT 2012


Author: rcurtin
Date: 2012-04-30 17:39:21 -0400 (Mon, 30 Apr 2012)
New Revision: 12581

Modified:
   mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp
   mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp
Log:
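Allow arbitrary metrics to be used in the construction of the cover tree.
CoverTree gains a MetricType template parameter (default metric::LMetric<2>),
and ComputeDistances() now calls MetricType::Evaluate() instead of the
hard-coded metric::LMetric<2>::Evaluate().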


Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp	2012-04-30 21:23:00 UTC (rev 12580)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp	2012-04-30 21:39:21 UTC (rev 12581)
@@ -8,6 +8,7 @@
 #define __MLPACK_CORE_TREE_COVER_TREE_HPP
 
 #include <mlpack/core.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
 
 namespace mlpack {
 namespace tree {
@@ -68,7 +69,8 @@
  * }
  * @endcode
  */
-template<typename StatisticType = EmptyStatistic>
+template<typename MetricType = metric::LMetric<2>,
+         typename StatisticType = EmptyStatistic>
 class CoverTree
 {
  public:

Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp	2012-04-30 21:23:00 UTC (rev 12580)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp	2012-04-30 21:39:21 UTC (rev 12581)
@@ -14,9 +14,10 @@
 namespace tree {
 
 // Create the cover tree.
-template<typename StatisticType>
-CoverTree<StatisticType>::CoverTree(const arma::mat& dataset,
-                                    const double expansionConstant) :
+template<typename MetricType, typename StatisticType>
+CoverTree<MetricType, StatisticType>::CoverTree(
+    const arma::mat& dataset,
+    const double expansionConstant) :
     dataset(dataset),
     point(0),
     expansionConstant(expansionConstant)
@@ -116,16 +117,17 @@
   }
 }
 
-template<typename StatisticType>
-CoverTree<StatisticType>::CoverTree(const arma::mat& dataset,
-                                    const double expansionConstant,
-                                    const size_t pointIndex,
-                                    const int scale,
-                                    arma::Col<size_t>& indices,
-                                    arma::vec& distances,
-                                    size_t nearSetSize,
-                                    size_t& farSetSize,
-                                    size_t& usedSetSize) :
+template<typename MetricType, typename StatisticType>
+CoverTree<MetricType, StatisticType>::CoverTree(
+    const arma::mat& dataset,
+    const double expansionConstant,
+    const size_t pointIndex,
+    const int scale,
+    arma::Col<size_t>& indices,
+    arma::vec& distances,
+    size_t nearSetSize,
+    size_t& farSetSize,
+    size_t& usedSetSize) :
     dataset(dataset),
     point(pointIndex),
     scale(scale),
@@ -302,19 +304,20 @@
   ComputeDistances(pointIndex, indices, distances, farSetSize);
 }
 
-template<typename StatisticType>
-CoverTree<StatisticType>::~CoverTree()
+template<typename MetricType, typename StatisticType>
+CoverTree<MetricType, StatisticType>::~CoverTree()
 {
   // Delete each child.
   for (size_t i = 0; i < children.size(); ++i)
     delete children[i];
 }
 
-template<typename StatisticType>
-size_t CoverTree<StatisticType>::SplitNearFar(arma::Col<size_t>& indices,
-                                              arma::vec& distances,
-                                              const double bound,
-                                              const size_t pointSetSize)
+template<typename MetricType, typename StatisticType>
+size_t CoverTree<MetricType, StatisticType>::SplitNearFar(
+    arma::Col<size_t>& indices,
+    arma::vec& distances,
+    const double bound,
+    const size_t pointSetSize)
 {
   // Sanity check; there is no guarantee that this condition will not be true.
   // ...or is there?
@@ -362,8 +365,8 @@
 }
 
 // Returns the maximum distance between points.
-template<typename StatisticType>
-void CoverTree<StatisticType>::ComputeDistances(
+template<typename MetricType, typename StatisticType>
+void CoverTree<MetricType, StatisticType>::ComputeDistances(
     const size_t pointIndex,
     const arma::Col<size_t>& indices,
     arma::vec& distances,
@@ -373,17 +376,18 @@
   // modified.
   for (size_t i = 0; i < pointSetSize; ++i)
   {
-    distances[i] = metric::LMetric<2>::Evaluate(dataset.col(pointIndex),
+    distances[i] = MetricType::Evaluate(dataset.col(pointIndex),
         dataset.col(indices[i]));
   }
 }
 
-template<typename StatisticType>
-size_t CoverTree<StatisticType>::SortPointSet(arma::Col<size_t>& indices,
-                                              arma::vec& distances,
-                                              const size_t childFarSetSize,
-                                              const size_t childUsedSetSize,
-                                              const size_t farSetSize)
+template<typename MetricType, typename StatisticType>
+size_t CoverTree<MetricType, StatisticType>::SortPointSet(
+    arma::Col<size_t>& indices,
+    arma::vec& distances,
+    const size_t childFarSetSize,
+    const size_t childUsedSetSize,
+    const size_t farSetSize)
 {
   // We'll use low-level memcpy calls ourselves, just to ensure it's done
   // quickly and the way we want it to be.  Unfortunately this takes up more




More information about the mlpack-svn mailing list