[mlpack-svn] r12585 - mlpack/trunk/src/mlpack/core/tree
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Apr 30 18:04:00 EDT 2012
Author: rcurtin
Date: 2012-04-30 18:03:59 -0400 (Mon, 30 Apr 2012)
New Revision: 12585
Added:
mlpack/trunk/src/mlpack/core/tree/first_point_is_root.hpp
Modified:
mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt
mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp
mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp
Log:
Allow template parameter class to specify which point should be chosen as the
root node. Document template parameters a little bit (but hopefully this will
be redone at some point).
Modified: mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt 2012-04-30 21:41:13 UTC (rev 12584)
+++ mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt 2012-04-30 22:03:59 UTC (rev 12585)
@@ -10,6 +10,7 @@
bounds.hpp
cover_tree.hpp
cover_tree_impl.hpp
+ first_point_is_root.hpp
hrectbound.hpp
hrectbound_impl.hpp
periodichrectbound.hpp
Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp 2012-04-30 21:41:13 UTC (rev 12584)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp 2012-04-30 22:03:59 UTC (rev 12585)
@@ -9,6 +9,7 @@
#include <mlpack/core.hpp>
#include <mlpack/core/metrics/lmetric.hpp>
+#include "first_point_is_root.hpp"
namespace mlpack {
namespace tree {
@@ -68,8 +69,20 @@
* year = {2009}
* }
* @endcode
+ *
+ * The CoverTree class offers three template parameters; a custom metric type
+ * can be used with MetricType (this class defaults to the L2-squared metric).
+ * The root node's point can be chosen with the RootPointPolicy; by default, the
+ * FirstPointIsRoot policy is used, meaning the first point in the dataset is
+ * used. The StatisticType policy allows you to define statistics which can be
+ * gathered during the creation of the tree.
+ *
+ * @tparam MetricType Metric type to use during tree construction.
+ * @tparam RootPointPolicy Determines which point to use as the root node.
+ * @tparam StatisticType Statistic to be used during tree creation.
*/
template<typename MetricType = metric::LMetric<2>,
+ typename RootPointPolicy = FirstPointIsRoot,
typename StatisticType = EmptyStatistic>
class CoverTree
{
Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp 2012-04-30 21:41:13 UTC (rev 12584)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp 2012-04-30 22:03:59 UTC (rev 12585)
@@ -14,21 +14,26 @@
namespace tree {
// Create the cover tree.
-template<typename MetricType, typename StatisticType>
-CoverTree<MetricType, StatisticType>::CoverTree(
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
const arma::mat& dataset,
const double expansionConstant) :
dataset(dataset),
- point(0),
+ point(RootPointPolicy::ChooseRoot(dataset)),
expansionConstant(expansionConstant)
{
// Kick off the building. Create the indices array and the distances array.
arma::Col<size_t> indices = arma::linspace<arma::Col<size_t> >(1,
dataset.n_cols - 1, dataset.n_cols - 1);
+ // This is now [1 2 3 ... (n_cols - 1)]. We must be sure that our point does
+ // not occur.
+ if (point != 0)
+ indices[point - 1] = 0; // Put 0 back into the set; remove what was there.
+
arma::vec distances(dataset.n_cols - 1);
// Build the initial distances.
- ComputeDistances(0 /* default */, indices, distances, dataset.n_cols - 1);
+ ComputeDistances(point, indices, distances, dataset.n_cols - 1);
// Now determine the scale factor of the root node.
const double maxDistance = max(distances);
@@ -42,7 +47,7 @@
dataset.n_cols - 1);
size_t childFarSetSize = (dataset.n_cols - 1) - childNearSetSize;
size_t childUsedSetSize = 0;
- children.push_back(new CoverTree(dataset, expansionConstant, 0, scale - 1,
+ children.push_back(new CoverTree(dataset, expansionConstant, point, scale - 1,
indices, distances, childNearSetSize, childFarSetSize, childUsedSetSize));
size_t nearSetSize = (dataset.n_cols - 1) - childUsedSetSize;
@@ -117,8 +122,8 @@
}
}
-template<typename MetricType, typename StatisticType>
-CoverTree<MetricType, StatisticType>::CoverTree(
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
const arma::mat& dataset,
const double expansionConstant,
const size_t pointIndex,
@@ -304,16 +309,16 @@
ComputeDistances(pointIndex, indices, distances, farSetSize);
}
-template<typename MetricType, typename StatisticType>
-CoverTree<MetricType, StatisticType>::~CoverTree()
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+CoverTree<MetricType, RootPointPolicy, StatisticType>::~CoverTree()
{
// Delete each child.
for (size_t i = 0; i < children.size(); ++i)
delete children[i];
}
-template<typename MetricType, typename StatisticType>
-size_t CoverTree<MetricType, StatisticType>::SplitNearFar(
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+size_t CoverTree<MetricType, RootPointPolicy, StatisticType>::SplitNearFar(
arma::Col<size_t>& indices,
arma::vec& distances,
const double bound,
@@ -365,8 +370,8 @@
}
// Returns the maximum distance between points.
-template<typename MetricType, typename StatisticType>
-void CoverTree<MetricType, StatisticType>::ComputeDistances(
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+void CoverTree<MetricType, RootPointPolicy, StatisticType>::ComputeDistances(
const size_t pointIndex,
const arma::Col<size_t>& indices,
arma::vec& distances,
@@ -381,8 +386,8 @@
}
}
-template<typename MetricType, typename StatisticType>
-size_t CoverTree<MetricType, StatisticType>::SortPointSet(
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+size_t CoverTree<MetricType, RootPointPolicy, StatisticType>::SortPointSet(
arma::Col<size_t>& indices,
arma::vec& distances,
const size_t childFarSetSize,
Added: mlpack/trunk/src/mlpack/core/tree/first_point_is_root.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/first_point_is_root.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/core/tree/first_point_is_root.hpp 2012-04-30 22:03:59 UTC (rev 12585)
@@ -0,0 +1,37 @@
+/**
+ * @file first_point_is_root.hpp
+ * @author Ryan Curtin
+ *
+ * A very simple policy for the cover tree; the first point in the dataset is
+ * chosen as the root of the cover tree.
+ */
+#ifndef __MLPACK_CORE_TREE_FIRST_POINT_IS_ROOT_HPP
+#define __MLPACK_CORE_TREE_FIRST_POINT_IS_ROOT_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * This class is meant to be used as a choice for the policy class
+ * RootPointPolicy of the CoverTree class. This policy determines which point
+ * is used for the root node of the cover tree. This particular implementation
+ * simply chooses the first point in the dataset as the root. A more complex
+ * implementation might choose, for instance, the point with least maximum
+ * distance to other points (the closest to the "middle").
+ */
+class FirstPointIsRoot
+{
+ public:
+ /**
+ * Return the index of the point to use as the root of the cover tree. This
+ * policy ignores the dataset and always returns 0 (the first point).
+ */
+ static size_t ChooseRoot(const arma::mat& /* dataset */) { return 0; }
+};
+
+}; // namespace tree
+}; // namespace mlpack
+
+#endif // __MLPACK_CORE_TREE_FIRST_POINT_IS_ROOT_HPP
More information about the mlpack-svn
mailing list