[mlpack-svn] r12585 - mlpack/trunk/src/mlpack/core/tree

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Apr 30 18:04:00 EDT 2012


Author: rcurtin
Date: 2012-04-30 18:03:59 -0400 (Mon, 30 Apr 2012)
New Revision: 12585

Added:
   mlpack/trunk/src/mlpack/core/tree/first_point_is_root.hpp
Modified:
   mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt
   mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp
   mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp
Log:
Allow template parameter class to specify which point should be chosen as the
root node.  Document template parameters a little bit (but hopefully this will
be redone at some point).


Modified: mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt	2012-04-30 21:41:13 UTC (rev 12584)
+++ mlpack/trunk/src/mlpack/core/tree/CMakeLists.txt	2012-04-30 22:03:59 UTC (rev 12585)
@@ -10,6 +10,7 @@
   bounds.hpp
   cover_tree.hpp
   cover_tree_impl.hpp
+  first_point_is_root.hpp
   hrectbound.hpp
   hrectbound_impl.hpp
   periodichrectbound.hpp

Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp	2012-04-30 21:41:13 UTC (rev 12584)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree.hpp	2012-04-30 22:03:59 UTC (rev 12585)
@@ -9,6 +9,7 @@
 
 #include <mlpack/core.hpp>
 #include <mlpack/core/metrics/lmetric.hpp>
+#include "first_point_is_root.hpp"
 
 namespace mlpack {
 namespace tree {
@@ -68,8 +69,20 @@
  *   year = {2009}
  * }
  * @endcode
+ *
+ * The CoverTree class offers three template parameters; a custom metric type
+ * can be used with MetricType (this class defaults to the L2-squared metric).
+ * The root node's point can be chosen with the RootPointPolicy; by default, the
+ * FirstPointIsRoot policy is used, meaning the first point in the dataset is
+ * used.  The StatisticType policy allows you to define statistics which can be
+ * gathered during the creation of the tree.
+ *
+ * @tparam MetricType Metric type to use during tree construction.
+ * @tparam RootPointPolicy Determines which point to use as the root node.
+ * @tparam StatisticType Statistic to be used during tree creation.
  */
 template<typename MetricType = metric::LMetric<2>,
+         typename RootPointPolicy = FirstPointIsRoot,
          typename StatisticType = EmptyStatistic>
 class CoverTree
 {

Modified: mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp	2012-04-30 21:41:13 UTC (rev 12584)
+++ mlpack/trunk/src/mlpack/core/tree/cover_tree_impl.hpp	2012-04-30 22:03:59 UTC (rev 12585)
@@ -14,21 +14,26 @@
 namespace tree {
 
 // Create the cover tree.
-template<typename MetricType, typename StatisticType>
-CoverTree<MetricType, StatisticType>::CoverTree(
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
     const arma::mat& dataset,
     const double expansionConstant) :
     dataset(dataset),
-    point(0),
+    point(RootPointPolicy::ChooseRoot(dataset)),
     expansionConstant(expansionConstant)
 {
   // Kick off the building.  Create the indices array and the distances array.
   arma::Col<size_t> indices = arma::linspace<arma::Col<size_t> >(1,
       dataset.n_cols - 1, dataset.n_cols - 1);
+  // This is now [1 2 3 4 ... (n_cols - 1)].  We must be sure that our point
+  // does not occur.
+  if (point != 0)
+    indices[point - 1] = 0; // Put 0 back into the set; remove what was there.
+
   arma::vec distances(dataset.n_cols - 1);
 
   // Build the initial distances.
-  ComputeDistances(0 /* default */, indices, distances, dataset.n_cols - 1);
+  ComputeDistances(point, indices, distances, dataset.n_cols - 1);
 
   // Now determine the scale factor of the root node.
   const double maxDistance = max(distances);
@@ -42,7 +47,7 @@
       dataset.n_cols - 1);
   size_t childFarSetSize = (dataset.n_cols - 1) - childNearSetSize;
   size_t childUsedSetSize = 0;
-  children.push_back(new CoverTree(dataset, expansionConstant, 0, scale - 1,
+  children.push_back(new CoverTree(dataset, expansionConstant, point, scale - 1,
       indices, distances, childNearSetSize, childFarSetSize, childUsedSetSize));
 
   size_t nearSetSize = (dataset.n_cols - 1) - childUsedSetSize;
@@ -117,8 +122,8 @@
   }
 }
 
-template<typename MetricType, typename StatisticType>
-CoverTree<MetricType, StatisticType>::CoverTree(
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+CoverTree<MetricType, RootPointPolicy, StatisticType>::CoverTree(
     const arma::mat& dataset,
     const double expansionConstant,
     const size_t pointIndex,
@@ -304,16 +309,16 @@
   ComputeDistances(pointIndex, indices, distances, farSetSize);
 }
 
-template<typename MetricType, typename StatisticType>
-CoverTree<MetricType, StatisticType>::~CoverTree()
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+CoverTree<MetricType, RootPointPolicy, StatisticType>::~CoverTree()
 {
   // Delete each child.
   for (size_t i = 0; i < children.size(); ++i)
     delete children[i];
 }
 
-template<typename MetricType, typename StatisticType>
-size_t CoverTree<MetricType, StatisticType>::SplitNearFar(
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+size_t CoverTree<MetricType, RootPointPolicy, StatisticType>::SplitNearFar(
     arma::Col<size_t>& indices,
     arma::vec& distances,
     const double bound,
@@ -365,8 +370,8 @@
 }
 
 // Returns the maximum distance between points.
-template<typename MetricType, typename StatisticType>
-void CoverTree<MetricType, StatisticType>::ComputeDistances(
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+void CoverTree<MetricType, RootPointPolicy, StatisticType>::ComputeDistances(
     const size_t pointIndex,
     const arma::Col<size_t>& indices,
     arma::vec& distances,
@@ -381,8 +386,8 @@
   }
 }
 
-template<typename MetricType, typename StatisticType>
-size_t CoverTree<MetricType, StatisticType>::SortPointSet(
+template<typename MetricType, typename RootPointPolicy, typename StatisticType>
+size_t CoverTree<MetricType, RootPointPolicy, StatisticType>::SortPointSet(
     arma::Col<size_t>& indices,
     arma::vec& distances,
     const size_t childFarSetSize,

Added: mlpack/trunk/src/mlpack/core/tree/first_point_is_root.hpp
===================================================================
--- mlpack/trunk/src/mlpack/core/tree/first_point_is_root.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/core/tree/first_point_is_root.hpp	2012-04-30 22:03:59 UTC (rev 12585)
@@ -0,0 +1,37 @@
+/**
+ * @file first_point_is_root.hpp
+ * @author Ryan Curtin
+ *
+ * A very simple policy for the cover tree; the first point in the dataset is
+ * chosen as the root of the cover tree.
+ */
+#ifndef __MLPACK_CORE_TREE_FIRST_POINT_IS_ROOT_HPP
+#define __MLPACK_CORE_TREE_FIRST_POINT_IS_ROOT_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace tree {
+
+/**
+ * This class is meant to be used as a choice for the policy class
+ * RootPointPolicy of the CoverTree class.  This policy determines which point
+ * is used for the root node of the cover tree.  This particular implementation
+ * simply chooses the first point in the dataset as the root.  A more complex
+ * implementation might choose, for instance, the point with least maximum
+ * distance to other points (the closest to the "middle").
+ */
+class FirstPointIsRoot
+{
+ public:
+  /**
+   * Return the point to be used as the root point of the cover tree.  This just
+   * returns 0.
+   */
+  static size_t ChooseRoot(const arma::mat& /* dataset */) { return 0; }
+};
+
+}; // namespace tree
+}; // namespace mlpack
+
+#endif // __MLPACK_CORE_TREE_FIRST_POINT_IS_ROOT_HPP




More information about the mlpack-svn mailing list