[mlpack-svn] r12581 - mlpack/trunk/src/mlpack/core/tree
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Apr 30 17:39:22 EDT 2012
Author: rcurtin
Date: 2012-04-30 17:39:21 -0400 (Mon, 30 Apr 2012)
New Revision: 12581
Modified:
mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp
mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp
Log:
Allow arbitrary metrics to be used in the construction of the cover tree.
Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp 2012-04-30 21:23:00 UTC (rev 12580)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp 2012-04-30 21:39:21 UTC (rev 12581)
@@ -8,6 +8,7 @@
#define __MLPACK_CORE_TREE_COVER_TREE_HPP
#include <mlpack/core.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
namespace mlpack {
namespace tree {
@@ -68,7 +69,8 @@
* }
* @endcode
*/
-template<typename StatisticType = EmptyStatistic>
+template<typename MetricType = metric::LMetric<2>,
+ typename StatisticType = EmptyStatistic>
class CoverTree
{
public:
Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp 2012-04-30 21:23:00 UTC (rev 12580)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp 2012-04-30 21:39:21 UTC (rev 12581)
@@ -14,9 +14,10 @@
namespace tree {
// Create the cover tree.
-template<typename StatisticType>
-CoverTree<StatisticType>::CoverTree(const arma::mat& dataset,
- const double expansionConstant) :
+template<typename MetricType, typename StatisticType>
+CoverTree<MetricType, StatisticType>::CoverTree(
+ const arma::mat& dataset,
+ const double expansionConstant) :
dataset(dataset),
point(0),
expansionConstant(expansionConstant)
@@ -116,16 +117,17 @@
}
}
-template<typename StatisticType>
-CoverTree<StatisticType>::CoverTree(const arma::mat& dataset,
- const double expansionConstant,
- const size_t pointIndex,
- const int scale,
- arma::Col<size_t>& indices,
- arma::vec& distances,
- size_t nearSetSize,
- size_t& farSetSize,
- size_t& usedSetSize) :
+template<typename MetricType, typename StatisticType>
+CoverTree<MetricType, StatisticType>::CoverTree(
+ const arma::mat& dataset,
+ const double expansionConstant,
+ const size_t pointIndex,
+ const int scale,
+ arma::Col<size_t>& indices,
+ arma::vec& distances,
+ size_t nearSetSize,
+ size_t& farSetSize,
+ size_t& usedSetSize) :
dataset(dataset),
point(pointIndex),
scale(scale),
@@ -302,19 +304,20 @@
ComputeDistances(pointIndex, indices, distances, farSetSize);
}
-template<typename StatisticType>
-CoverTree<StatisticType>::~CoverTree()
+template<typename MetricType, typename StatisticType>
+CoverTree<MetricType, StatisticType>::~CoverTree()
{
// Delete each child.
for (size_t i = 0; i < children.size(); ++i)
delete children[i];
}
-template<typename StatisticType>
-size_t CoverTree<StatisticType>::SplitNearFar(arma::Col<size_t>& indices,
- arma::vec& distances,
- const double bound,
- const size_t pointSetSize)
+template<typename MetricType, typename StatisticType>
+size_t CoverTree<MetricType, StatisticType>::SplitNearFar(
+ arma::Col<size_t>& indices,
+ arma::vec& distances,
+ const double bound,
+ const size_t pointSetSize)
{
// Sanity check; there is no guarantee that this condition will not be true.
// ...or is there?
@@ -362,8 +365,8 @@
}
// Returns the maximum distance between points.
-template<typename StatisticType>
-void CoverTree<StatisticType>::ComputeDistances(
+template<typename MetricType, typename StatisticType>
+void CoverTree<MetricType, StatisticType>::ComputeDistances(
const size_t pointIndex,
const arma::Col<size_t>& indices,
arma::vec& distances,
@@ -373,17 +376,18 @@
// modified.
for (size_t i = 0; i < pointSetSize; ++i)
{
- distances[i] = metric::LMetric<2>::Evaluate(dataset.col(pointIndex),
+ distances[i] = MetricType::Evaluate(dataset.col(pointIndex),
dataset.col(indices[i]));
}
}
-template<typename StatisticType>
-size_t CoverTree<StatisticType>::SortPointSet(arma::Col<size_t>& indices,
- arma::vec& distances,
- const size_t childFarSetSize,
- const size_t childUsedSetSize,
- const size_t farSetSize)
+template<typename MetricType, typename StatisticType>
+size_t CoverTree<MetricType, StatisticType>::SortPointSet(
+ arma::Col<size_t>& indices,
+ arma::vec& distances,
+ const size_t childFarSetSize,
+ const size_t childUsedSetSize,
+ const size_t farSetSize)
{
// We'll use low-level memcpy calls ourselves, just to ensure it's done
// quickly and the way we want it to be. Unfortunately this takes up more
More information about the mlpack-svn mailing list