[mlpack-git] master: Merge remote-tracking branch 'upstream/master' into r-proj-tree (36f3966)
gitdub at mlpack.org
gitdub at mlpack.org
Mon Aug 15 14:58:11 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/a7794bde8082c691553152393e1e230098f5e920...87776e52cf9ead63fa458118a0cfd2fe46b23466
>---------------------------------------------------------------
commit 36f39668f8d9e3e6a169e4f726bc997eda323e88
Merge: 079cb80 95da0dd
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date: Mon Aug 15 21:58:11 2016 +0300
Merge remote-tracking branch 'upstream/master' into r-proj-tree
>---------------------------------------------------------------
36f39668f8d9e3e6a169e4f726bc997eda323e88
.appveyor.yml | 2 +-
HISTORY.md | 6 +
src/mlpack/core/tree/CMakeLists.txt | 13 +-
src/mlpack/core/tree/binary_space_tree.hpp | 1 +
.../tree/binary_space_tree/binary_space_tree.hpp | 18 +
.../binary_space_tree/binary_space_tree_impl.hpp | 48 +-
src/mlpack/core/tree/binary_space_tree/traits.hpp | 52 +-
src/mlpack/core/tree/binary_space_tree/typedef.hpp | 58 ++
.../tree/binary_space_tree/vantage_point_split.hpp | 116 +++
.../binary_space_tree/vantage_point_split_impl.hpp | 85 ++
src/mlpack/core/tree/cover_tree/traits.hpp | 6 -
src/mlpack/core/tree/hollow_ball_bound.hpp | 14 +-
src/mlpack/core/tree/hollow_ball_bound_impl.hpp | 169 ++--
src/mlpack/core/tree/hrectbound_impl.hpp | 168 +++-
src/mlpack/core/tree/rectangle_tree/traits.hpp | 12 -
src/mlpack/core/tree/tree_traits.hpp | 6 -
src/mlpack/core/tree/vantage_point_tree.hpp | 21 -
.../vantage_point_tree/dual_tree_traverser.hpp | 98 ---
.../dual_tree_traverser_impl.hpp | 237 ------
.../vantage_point_tree/single_tree_traverser.hpp | 63 --
.../single_tree_traverser_impl.hpp | 113 ---
src/mlpack/core/tree/vantage_point_tree/traits.hpp | 66 --
.../core/tree/vantage_point_tree/typedef.hpp | 76 --
.../vantage_point_tree/vantage_point_split.hpp | 157 ----
.../vantage_point_split_impl.hpp | 233 -----
.../tree/vantage_point_tree/vantage_point_tree.hpp | 508 -----------
.../vantage_point_tree/vantage_point_tree_impl.hpp | 937 ---------------------
src/mlpack/methods/lsh/lsh_search.hpp | 3 +-
.../neighbor_search/neighbor_search_rules_impl.hpp | 66 +-
src/mlpack/methods/neighbor_search/ns_model.hpp | 1 -
.../sort_policies/furthest_neighbor_sort.hpp | 25 +
.../sort_policies/nearest_neighbor_sort.hpp | 20 +
.../preprocess/preprocess_describe_main.cpp | 42 +-
.../range_search/range_search_rules_impl.hpp | 42 +-
src/mlpack/methods/range_search/rs_model.hpp | 1 -
src/mlpack/tests/vantage_point_tree_test.cpp | 82 +-
36 files changed, 729 insertions(+), 2836 deletions(-)
diff --cc src/mlpack/core/tree/CMakeLists.txt
index 45b494f,01e6394..127b2c2
--- a/src/mlpack/core/tree/CMakeLists.txt
+++ b/src/mlpack/core/tree/CMakeLists.txt
@@@ -14,12 -14,10 +14,14 @@@ set(SOURCE
binary_space_tree/mean_split_impl.hpp
binary_space_tree/midpoint_split.hpp
binary_space_tree/midpoint_split_impl.hpp
+ binary_space_tree/rp_tree_max_split.hpp
+ binary_space_tree/rp_tree_max_split_impl.hpp
+ binary_space_tree/rp_tree_mean_split.hpp
+ binary_space_tree/rp_tree_mean_split_impl.hpp
binary_space_tree/single_tree_traverser.hpp
binary_space_tree/single_tree_traverser_impl.hpp
+ binary_space_tree/vantage_point_split.hpp
+ binary_space_tree/vantage_point_split_impl.hpp
binary_space_tree/traits.hpp
binary_space_tree/typedef.hpp
bounds.hpp
diff --cc src/mlpack/core/tree/binary_space_tree.hpp
index d0574cd,ba54981..8e3505d
--- a/src/mlpack/core/tree/binary_space_tree.hpp
+++ b/src/mlpack/core/tree/binary_space_tree.hpp
@@@ -11,8 -11,7 +11,9 @@@
#include "bounds.hpp"
#include "binary_space_tree/midpoint_split.hpp"
#include "binary_space_tree/mean_split.hpp"
+ #include "binary_space_tree/vantage_point_split.hpp"
+#include "binary_space_tree/rp_tree_max_split.hpp"
+#include "binary_space_tree/rp_tree_mean_split.hpp"
#include "binary_space_tree/binary_space_tree.hpp"
#include "binary_space_tree/single_tree_traverser.hpp"
#include "binary_space_tree/single_tree_traverser_impl.hpp"
diff --cc src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
index 8ab2a68,f083195..31f8c04
--- a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
@@@ -478,39 -476,22 +478,57 @@@ class BinarySpaceTre
SplitType<BoundType<MetricType>, MatType>& splitter);
/**
+ * Perform the split process according to the information about the
+ * split.
+ *
+ * @param bound The bound used for this node.
+ * @param data The dataset used by the binary space tree.
+ * @param begin Index of the starting point in the dataset that belongs to
+ * this node.
+ * @param count Number of points in this node.
+ * @param splitInfo The information about the split.
+ */
+ size_t PerformSplit(MatType& data,
+ const size_t begin,
+ const size_t count,
+ const typename Split::SplitInfo& splitInfo);
+
+ /**
+ * Perform the split process according to the information about the split and
+ * return a list of changed indices.
+ *
+ * @param bound The bound used for this node.
+ * @param data The dataset used by the binary space tree.
+ * @param begin Index of the starting point in the dataset that belongs to
+ * this node.
+ * @param count Number of points in this node.
+ * @param splitInfo The information about the split.
+ * @param oldFromNew Vector which will be filled with the old positions for
+ * each new point.
+ */
+ size_t PerformSplit(MatType& data,
+ const size_t begin,
+ const size_t count,
+ const typename Split::SplitInfo& splitInfo,
+ std::vector<size_t>& oldFromNew);
++
++ /**
+ * Update the bound of the current node. This method does not take into
+ * account bound-specific properties.
+ *
+ * @param boundToUpdate The bound to update.
+ */
+ template<typename BoundType2>
+ void UpdateBound(BoundType2& boundToUpdate);
+
+ /**
+ * Update the bound of the current node. This method is designed for
+ * HollowBallBound only.
+ *
+ * @param boundToUpdate The bound to update.
+ */
+ void UpdateBound(bound::HollowBallBound<MetricType>& boundToUpdate);
+
protected:
/**
* A default constructor. This is meant to only be used with
diff --cc src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
index 40f0860,9565a95..95403c2
--- a/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
@@@ -762,112 -747,40 +759,153 @@@ template<typename MetricType
template<typename BoundMetricType, typename...> class BoundType,
template<typename SplitBoundType, typename SplitMatType>
class SplitType>
+size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
+ SplitType>::PerformSplit(MatType& data,
+ const size_t begin,
+ const size_t count,
+ const typename Split::SplitInfo& splitInfo)
+{
+ // This method modifies the input dataset. We loop both from the left and
+ // right sides of the points contained in this node. The points less than
+ // splitVal should be on the left side of the matrix, and the points greater
+ // than splitVal should be on the right side of the matrix.
+ size_t left = begin;
+ size_t right = begin + count - 1;
+
+ // First half-iteration of the loop is out here because the termination
+ // condition is in the middle.
+ while (Split::AssignToLeftNode(data.col(left), splitInfo) && (left <= right))
+ left++;
+ while ((!Split::AssignToLeftNode(data.col(right), splitInfo)) &&
+ (left <= right) && (right > 0))
+ right--;
+
+ while (left <= right)
+ {
+ // Swap columns.
+ data.swap_cols(left, right);
+
+ // See how many points on the left are correct. When they are correct,
+ // increase the left counter accordingly. When we encounter one that isn't
+ // correct, stop. We will switch it later.
+ while (Split::AssignToLeftNode(data.col(left), splitInfo) &&
+ (left <= right))
+ left++;
+
+ // Now see how many points on the right are correct. When they are correct,
+ // decrease the right counter accordingly. When we encounter one that isn't
+ // correct, stop. We will switch it with the wrong point we found in the
+ // previous loop.
+ while ((!Split::AssignToLeftNode(data.col(right), splitInfo)) &&
+ (left <= right))
+ right--;
+ }
+
+ Log::Assert(left == right + 1);
+
+ return left;
+}
+
+template<typename MetricType,
+ typename StatisticType,
+ typename MatType,
+ template<typename BoundMetricType, typename...> class BoundType,
+ template<typename SplitBoundType, typename SplitMatType>
+ class SplitType>
+size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
+ SplitType>::PerformSplit(MatType& data,
+ const size_t begin,
+ const size_t count,
+ const typename Split::SplitInfo& splitInfo,
+ std::vector<size_t>& oldFromNew)
+{
+ // This method modifies the input dataset. We loop both from the left and
+ // right sides of the points contained in this node. The points less than
+ // splitVal should be on the left side of the matrix, and the points greater
+ // than splitVal should be on the right side of the matrix.
+ size_t left = begin;
+ size_t right = begin + count - 1;
+
+ // First half-iteration of the loop is out here because the termination
+ // condition is in the middle.
+ while (Split::AssignToLeftNode(data.col(left), splitInfo) && (left <= right))
+ left++;
+ while ((!Split::AssignToLeftNode(data.col(right), splitInfo)) &&
+ (left <= right) && (right > 0))
+ right--;
+
+ while (left <= right)
+ {
+ // Swap columns.
+ data.swap_cols(left, right);
+
+ // Update the indices for what we changed.
+ size_t t = oldFromNew[left];
+ oldFromNew[left] = oldFromNew[right];
+ oldFromNew[right] = t;
+
+ // See how many points on the left are correct. When they are correct,
+ // increase the left counter accordingly. When we encounter one that isn't
+ // correct, stop. We will switch it later.
+ while (Split::AssignToLeftNode(data.col(left), splitInfo) &&
+ (left <= right))
+ left++;
+
+ // Now see how many points on the right are correct. When they are correct,
+ // decrease the right counter accordingly. When we encounter one that isn't
+ // correct, stop. We will switch it with the wrong point we found in the
+ // previous loop.
+ while ((!Split::AssignToLeftNode(data.col(right), splitInfo)) &&
+ (left <= right))
+ right--;
+ }
+
+ Log::Assert(left == right + 1);
+
+ return left;
+}
+
++template<typename MetricType,
++ typename StatisticType,
++ typename MatType,
++ template<typename BoundMetricType, typename...> class BoundType,
++ template<typename SplitBoundType, typename SplitMatType>
++ class SplitType>
++
+ template<typename BoundType2>
+ void BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
+ UpdateBound(BoundType2& boundToUpdate)
+ {
+ if (count > 0)
+ boundToUpdate |= dataset->cols(begin, begin + count - 1);
+ }
+
+ template<typename MetricType,
+ typename StatisticType,
+ typename MatType,
+ template<typename BoundMetricType, typename...> class BoundType,
+ template<typename SplitBoundType, typename SplitMatType>
+ class SplitType>
+ void BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
+ UpdateBound(bound::HollowBallBound<MetricType>& boundToUpdate)
+ {
+ if (!parent)
+ {
+ if (count > 0)
+ boundToUpdate |= dataset->cols(begin, begin + count - 1);
+ return;
+ }
+
+ if (parent->left != NULL && parent->left != this)
+ {
+ boundToUpdate.HollowCenter() = parent->left->bound.Center();
+ boundToUpdate.InnerRadius() = std::numeric_limits<ElemType>::max();
+ }
+
+ if (count > 0)
+ boundToUpdate |= dataset->cols(begin, begin + count - 1);
+ }
+
// Default constructor (private), for boost::serialization.
template<typename MetricType,
typename StatisticType,
diff --cc src/mlpack/core/tree/binary_space_tree/traits.hpp
index 198cb10,9f469e5..ce02ee6
--- a/src/mlpack/core/tree/binary_space_tree/traits.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/traits.hpp
@@@ -42,107 -42,6 +42,99 @@@ class TreeTraits<BinarySpaceTree<Metric
static const bool FirstPointIsCentroid = false;
/**
- * There is no guarantee that the first point of the first sibling is the
- * centroid of other siblings.
- */
- static const bool FirstSiblingFirstPointIsCentroid = false;
-
- /**
+ * The tree has not got duplicated points.
+ */
+ static const bool HasDuplicatedPoints = false;
+
+ /**
+ * Points are not contained at multiple levels of the binary space tree.
+ */
+ static const bool HasSelfChildren = false;
+
+ /**
+ * Points are rearranged during building of the tree.
+ */
+ static const bool RearrangesDataset = true;
+
+ /**
+ * This is always a binary tree.
+ */
+ static const bool BinaryTree = true;
+};
+
++/**
++ * This is a specialization of the TreeTraits class to the max-split random
++ * projection tree. The only difference with general BinarySpaceTree is that the
++ * tree can have overlapping children.
++ */
+template<typename MetricType,
+ typename StatisticType,
+ typename MatType,
+ template<typename BoundMetricType, typename...> class BoundType>
+class TreeTraits<BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
+ RPTreeMaxSplit>>
+{
+ public:
+ /**
+ * Children of a random projection tree node may overlap.
+ */
+ static const bool HasOverlappingChildren = true;
+
+ /**
+ * The tree has not got duplicated points.
+ */
+ static const bool HasDuplicatedPoints = false;
+
+ /**
+ * There is no guarantee that the first point in a node is its centroid.
+ */
+ static const bool FirstPointIsCentroid = false;
+
+ /**
- * There is no guarantee that the first point of the first sibling is the
- * centroid of other siblings.
- */
- static const bool FirstSiblingFirstPointIsCentroid = false;
-
- /**
+ * Points are not contained at multiple levels of the binary space tree.
+ */
+ static const bool HasSelfChildren = false;
+
+ /**
+ * Points are rearranged during building of the tree.
+ */
+ static const bool RearrangesDataset = true;
+
+ /**
+ * This is always a binary tree.
+ */
+ static const bool BinaryTree = true;
+};
+
++/**
++ * This is a specialization of the TreeTraits class to the mean-split random
++ * projection tree. The only difference with general BinarySpaceTree is that the
++ * tree can have overlapping children.
++ */
+template<typename MetricType,
+ typename StatisticType,
+ typename MatType,
+ template<typename BoundMetricType, typename...> class BoundType>
+class TreeTraits<BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
+ RPTreeMeanSplit>>
+{
+ public:
+ /**
+ * Children of a random projection tree node may overlap.
+ */
+ static const bool HasOverlappingChildren = true;
+
+ /**
+ * The tree has not got duplicated points.
+ */
+ static const bool HasDuplicatedPoints = false;
+
+ /**
+ * There is no guarantee that the first point in a node is its centroid.
+ */
+ static const bool FirstPointIsCentroid = false;
+
+ /**
- * There is no guarantee that the first point of the first sibling is the
- * centroid of other siblings.
- */
- static const bool FirstSiblingFirstPointIsCentroid = false;
-
- /**
* Points are not contained at multiple levels of the binary space tree.
*/
static const bool HasSelfChildren = false;
@@@ -176,7 -75,23 +168,29 @@@ class TreeTraits<BinarySpaceTree<Metric
static const bool HasOverlappingChildren = true;
static const bool HasDuplicatedPoints = false;
static const bool FirstPointIsCentroid = false;
- static const bool FirstSiblingFirstPointIsCentroid = false;
+ static const bool HasSelfChildren = false;
+ static const bool RearrangesDataset = true;
+ static const bool BinaryTree = true;
+ };
+
++/**
++ * This is a specialization of the TreeTraits class to an arbitrary tree with
++ * HollowBallBound (currently only the vantage point tree is supported).
++ * The only difference with general BinarySpaceTree is that the tree can have
++ * overlapping children.
++ */
+ template<typename MetricType,
+ typename StatisticType,
+ typename MatType,
+ template<typename SplitBoundType, typename SplitMatType>
+ class SplitType>
+ class TreeTraits<BinarySpaceTree<MetricType, StatisticType, MatType,
+ bound::HollowBallBound, SplitType>>
+ {
+ public:
+ static const bool HasOverlappingChildren = true;
+ static const bool HasDuplicatedPoints = false;
+ static const bool FirstPointIsCentroid = false;
static const bool HasSelfChildren = false;
static const bool RearrangesDataset = true;
static const bool BinaryTree = true;
diff --cc src/mlpack/core/tree/binary_space_tree/typedef.hpp
index cad3af6,2ede96e..753ae15
--- a/src/mlpack/core/tree/binary_space_tree/typedef.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/typedef.hpp
@@@ -136,72 -136,63 +136,130 @@@ using MeanSplitBallTree = BinarySpaceTr
MeanSplit>;
/**
+ * The vantage point tree (which is also called the metric tree. Vantage point
+ * trees and metric trees were invented independently by Yianilos and Uhlmann) is
+ * a kind of the binary space tree. When recursively splitting nodes, the VPTree
+ * class selects the vantage point and splits the node according to the distance
+ * to this point. Thus, points that are closer to the vantage point form the
+ * inner subtree. Other points form the outer subtree. The vantage point is
+ * contained in the first (inner) node.
+ *
+ * This implementation differs from the original algorithms. Namely, vantage
+ * points are not contained in intermediate nodes. The tree has points only in
+ * the leaves of the tree.
+ *
+ * For more information, see the following papers.
+ *
+ * @code
+ * @inproceedings{yianilos1993vptrees,
+ * author = {Yianilos, Peter N.},
+ * title = {Data Structures and Algorithms for Nearest Neighbor Search in
+ * General Metric Spaces},
+ * booktitle = {Proceedings of the Fourth Annual ACM-SIAM Symposium on
+ * Discrete Algorithms},
+ * series = {SODA '93},
+ * year = {1993},
+ * isbn = {0-89871-313-7},
+ * pages = {311--321},
+ * numpages = {11},
+ * publisher = {Society for Industrial and Applied Mathematics},
+ * address = {Philadelphia, PA, USA}
+ * }
+ *
+ * @article{uhlmann1991metrictrees,
+ * author = {Jeffrey K. Uhlmann},
+ * title = {Satisfying general proximity / similarity queries with metric
+ * trees},
+ * journal = {Information Processing Letters},
+ * volume = {40},
+ * number = {4},
+ * pages = {175 - 179},
+ * year = {1991},
+ * }
+ * @endcode
+ *
+ * This template typedef satisfies the TreeType policy API.
+ *
+ * @see @ref trees, BinarySpaceTree, VantagePointTree, VPTree
+ */
+ template<typename BoundType,
+ typename MatType = arma::mat>
+ using VPTreeSplit = VantagePointSplit<BoundType, MatType, 100>;
+
+ template<typename MetricType, typename StatisticType, typename MatType>
+ using VPTree = BinarySpaceTree<MetricType,
+ StatisticType,
+ MatType,
+ bound::HollowBallBound,
+ VPTreeSplit>;
+
++/**
+ * A max-split random projection tree. When recursively splitting nodes, the
+ * MaxSplitRPTree class selects a random hyperplane and splits a node by the
+ * hyperplane. The tree holds points in leaf nodes. In contrast to the k-d tree,
+ * children of a MaxSplitRPTree node may overlap.
+ *
+ * @code
+ * @inproceedings{dasgupta2008,
+ * author = {Dasgupta, Sanjoy and Freund, Yoav},
+ * title = {Random Projection Trees and Low Dimensional Manifolds},
+ * booktitle = {Proceedings of the Fortieth Annual ACM Symposium on Theory of
+ * Computing},
+ * series = {STOC '08},
+ * year = {2008},
+ * pages = {537--546},
+ * numpages = {10},
+ * publisher = {ACM},
+ * address = {New York, NY, USA},
+ * }
+ * @endcode
+ *
+ * This template typedef satisfies the TreeType policy API.
+ *
+ * @see @ref trees, BinarySpaceTree, BallTree, MeanSplitKDTree
+ */
+
+template<typename MetricType, typename StatisticType, typename MatType>
+using MaxSplitRPTree = BinarySpaceTree<MetricType,
+ StatisticType,
+ MatType,
+ bound::HRectBound,
+ RPTreeMaxSplit>;
+
+/**
+ * A mean-split random projection tree. When recursively splitting nodes, the
+ * RPTree class may perform one of two different kinds of split.
+ * Depending on the diameter and the average distance between points, the node
+ * may be split by a random hyperplane or according to the distance from the
+ * mean point. The tree holds points in leaf nodes. In contrast to the k-d tree,
+ * children of an RPTree node may overlap.
+ *
+ * @code
+ * @inproceedings{dasgupta2008,
+ * author = {Dasgupta, Sanjoy and Freund, Yoav},
+ * title = {Random Projection Trees and Low Dimensional Manifolds},
+ * booktitle = {Proceedings of the Fortieth Annual ACM Symposium on Theory of
+ * Computing},
+ * series = {STOC '08},
+ * year = {2008},
+ * pages = {537--546},
+ * numpages = {10},
+ * publisher = {ACM},
+ * address = {New York, NY, USA},
+ * }
+ * @endcode
+ *
+ * This template typedef satisfies the TreeType policy API.
+ *
+ * @see @ref trees, BinarySpaceTree, BallTree, MeanSplitKDTree
+ */
+template<typename MetricType, typename StatisticType, typename MatType>
+using RPTree = BinarySpaceTree<MetricType,
+ StatisticType,
+ MatType,
+ bound::HRectBound,
+ RPTreeMeanSplit>;
+
} // namespace tree
} // namespace mlpack
diff --cc src/mlpack/core/tree/binary_space_tree/vantage_point_split.hpp
index 0000000,0542433..015d2d8
mode 000000,100644..100644
--- a/src/mlpack/core/tree/binary_space_tree/vantage_point_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/vantage_point_split.hpp
@@@ -1,0 -1,157 +1,116 @@@
+ /**
+ * @file vantage_point_split.hpp
+ * @author Mikhail Lozhnikov
+ *
+ * Definition of class VantagePointSplit, a class that splits a vantage point
+ * tree into two parts using the distance to a certain vantage point.
+ */
+ #ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_VANTAGE_POINT_SPLIT_HPP
+ #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_VANTAGE_POINT_SPLIT_HPP
+
+ #include <mlpack/core.hpp>
+
+ namespace mlpack {
+ namespace tree /** Trees and tree-building procedures. */ {
+
+ template<typename BoundType,
+ typename MatType = arma::mat,
+ size_t MaxNumSamples = 100>
+ class VantagePointSplit
+ {
+ public:
+ //! The matrix element type.
+ typedef typename MatType::elem_type ElemType;
+ //! The bounding shape type.
+ typedef typename BoundType::MetricType MetricType;
++ //! A struct that contains information about the split.
++ struct SplitInfo
++ {
++ //! The vantage point.
++ arma::Col<ElemType> vantagePoint;
++ //! The median distance according to which the node will be split.
++ ElemType mu;
++ //! An instance of the MetricType class.
++ const MetricType* metric;
++
++ SplitInfo() :
++ mu(0),
++ metric(NULL)
++ { }
++
++ template<typename VecType>
++ SplitInfo(const MetricType& metric, const VecType& vantagePoint,
++ ElemType mu) :
++ vantagePoint(vantagePoint),
++ mu(mu),
++ metric(&metric)
++ { }
++ };
+
+ /**
+ * Split the node according to the distance to a vantage point.
+ *
- * @param bound The bound used by the tree.
- * @param data The dataset used by the tree.
++ * @param bound The bound used for this node.
++ * @param data The dataset used by the binary space tree.
+ * @param begin Index of the starting point in the dataset that belongs to
+ * this node.
+ * @param count Number of points in this node.
- * @param splitCol The index at which the dataset is divided into two parts
- * after the rearrangement.
++ * @param splitInfo Information about the split. This information contains
++ * the vantage point and the median distance to the vantage point.
+ */
+ static bool SplitNode(const BoundType& bound,
+ MatType& data,
+ const size_t begin,
+ const size_t count,
- size_t& splitCol);
++ SplitInfo& splitInfo);
+
+ /**
- * Split the node according to the distance to a vantage point.
++ * Indicates that a point should be assigned to the left subtree.
++ * This method returns true if a point should be assigned to the left subtree,
++ * i.e., if the distance from the point to the vantage point is less than the
++ * median value. Otherwise it returns false.
+ *
- * @param bound The bound used by the tree.
- * @param data The dataset used by the tree.
- * @param begin Index of the starting point in the dataset that belongs to
- * this node.
- * @param count Number of points in this node.
- * @param splitCol The index at which the dataset is divided into two parts
- * after the rearrangement.
- * @param oldFromNew Vector which will be filled with the old positions for
- * each new point.
++ * @param point The point that is being assigned.
++ * @param splitInfo Information about the split.
+ */
- static bool SplitNode(const BoundType& bound,
- MatType& data,
- const size_t begin,
- const size_t count,
- size_t& splitCol,
- std::vector<size_t>& oldFromNew);
++ template<typename VecType>
++ static bool AssignToLeftNode(const VecType& point,
++ const SplitInfo& splitInfo)
++ {
++ return (splitInfo.metric->Evaluate(splitInfo.vantagePoint, point) <
++ splitInfo.mu);
++ }
++
+ private:
+ /**
+ * Select the best vantage point, i.e., the point with the largest second
+ * moment of the distance from a number of random node points to the vantage
+ * point. Firstly this method selects no more than MaxNumSamples random
+ * points. Then it evaluates each point, i.e., calculates the corresponding
+ * second moment and selects the point with the largest moment. Each random
+ * point belongs to the node.
+ *
+ * @param metric The metric used by the tree.
+ * @param data The dataset used by the tree.
+ * @param begin Index of the starting point in the dataset that belongs to
+ * this node.
+ * @param count Number of points in this node.
+ * @param vantagePoint The index of the vantage point in the dataset.
+ * @param mu The median value of distance from the vantage point to
+ * a number of random points.
+ */
+ static void SelectVantagePoint(const MetricType& metric,
+ const MatType& data,
+ const size_t begin,
+ const size_t count,
+ size_t& vantagePoint,
+ ElemType& mu);
-
- /**
- * This method returns true if a point should be assigned to the left subtree,
- * i.e., if the distance from the point to the vantage point is less then the
- * median value. Otherwise it returns false.
- *
- * @param metric The metric used by the tree.
- * @param data The dataset used by the tree.
- * @param vantagePoint The vantage point.
- * @param point The point that is being assigned.
- * @param mu The median value.
- */
- template<typename VecType>
- static bool AssignToLeftSubtree(const MetricType& metric,
- const MatType& mat,
- const VecType& vantagePoint,
- const size_t point,
- const ElemType mu)
- {
- return (metric.Evaluate(vantagePoint, mat.col(point)) < mu);
- }
-
- /**
- * Perform split according to the median value and the vantage point.
- *
- * @param metric The metric used by the tree.
- * @param data The dataset used by the tree.
- * @param begin Index of the starting point in the dataset that belongs to
- * this node.
- * @param count Number of points in this node.
- * @param vantagePoint The vantage point.
- * @param mu The median value.
- */
- template<typename VecType>
- static size_t PerformSplit(const MetricType& metric,
- MatType& data,
- const size_t begin,
- const size_t count,
- const VecType& vantagePoint,
- const ElemType mu);
-
- /**
- * Perform split according to the median value and the vantage point.
- *
- * @param metric The metric used by the tree.
- * @param data The dataset used by the tree.
- * @param begin Index of the starting point in the dataset that belongs to
- * this node.
- * @param count Number of points in this node.
- * @param vantagePoint The vantage point.
- * @param mu The median value.
- * @param oldFromNew Vector which will be filled with the old positions for
- * each new point.
- */
- template<typename VecType>
- static size_t PerformSplit(const MetricType& metric,
- MatType& data,
- const size_t begin,
- const size_t count,
- const VecType& vantagePoint,
- const ElemType mu,
- std::vector<size_t>& oldFromNew);
+ };
+
+ } // namespace tree
+ } // namespace mlpack
+
+ // Include implementation.
+ #include "vantage_point_split_impl.hpp"
+
+ #endif // MLPACK_CORE_TREE_BINARY_SPACE_TREE_VANTAGE_POINT_SPLIT_HPP
diff --cc src/mlpack/core/tree/binary_space_tree/vantage_point_split_impl.hpp
index 0000000,92e00b9..94ec63b
mode 000000,100644..100644
--- a/src/mlpack/core/tree/binary_space_tree/vantage_point_split_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/vantage_point_split_impl.hpp
@@@ -1,0 -1,233 +1,85 @@@
+ /**
+ * @file vantage_point_split_impl.hpp
+ * @author Mikhail Lozhnikov
+ *
+ * Implementation of class (VantagePointSplit) to split a vantage point
+ * tree according to the median value of the distance to a certain vantage point.
+ */
+ #ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_VANTAGE_POINT_SPLIT_IMPL_HPP
+ #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_VANTAGE_POINT_SPLIT_IMPL_HPP
+
+ #include "vantage_point_split.hpp"
+ #include <mlpack/core/tree/bounds.hpp>
+
+ namespace mlpack {
+ namespace tree {
+
+ template<typename BoundType, typename MatType, size_t MaxNumSamples>
+ bool VantagePointSplit<BoundType, MatType, MaxNumSamples>::
+ SplitNode(const BoundType& bound, MatType& data, const size_t begin,
- const size_t count, size_t& splitCol)
++ const size_t count, SplitInfo& splitInfo)
+ {
+ ElemType mu = 0;
+ size_t vantagePointIndex;
+
+ // Find the best vantage point.
+ SelectVantagePoint(bound.Metric(), data, begin, count, vantagePointIndex, mu);
+
+ // If all points are equal, we can't split.
+ if (mu == 0)
+ return false;
+
- // The first point of the left child is centroid.
- data.swap_cols(begin, vantagePointIndex);
-
- arma::Col<ElemType> vantagePoint = data.col(begin);
- splitCol = PerformSplit(bound.Metric(), data, begin, count, vantagePoint, mu);
-
- assert(splitCol > begin);
- assert(splitCol < begin + count);
- return true;
-}
-
-template<typename BoundType, typename MatType, size_t MaxNumSamples>
-bool VantagePointSplit<BoundType, MatType, MaxNumSamples>::
-SplitNode(const BoundType& bound, MatType& data, const size_t begin,
- const size_t count, size_t& splitCol, std::vector<size_t>& oldFromNew)
-{
- ElemType mu = 0;
- size_t vantagePointIndex;
-
- // Find the best vantage point.
- SelectVantagePoint(bound.Metric(), data, begin, count, vantagePointIndex, mu);
-
- // If all points are equal, we can't split.
- if (mu == 0)
- return false;
-
- // The first point of the left child is centroid.
- data.swap_cols(begin, vantagePointIndex);
- const size_t t = oldFromNew[begin];
- oldFromNew[begin] = oldFromNew[vantagePointIndex];
- oldFromNew[vantagePointIndex] = t;
-
- arma::Col<ElemType> vantagePoint = data.col(begin);
-
- splitCol = PerformSplit(bound.Metric(), data, begin, count, vantagePoint, mu,
- oldFromNew);
++ splitInfo = SplitInfo(bound.Metric(), data.col(vantagePointIndex), mu);
+
+ assert(splitCol > begin);
+ assert(splitCol < begin + count);
+ return true;
+ }
+
+ template<typename BoundType, typename MatType, size_t MaxNumSamples>
-template <typename VecType>
-size_t VantagePointSplit<BoundType, MatType, MaxNumSamples>::PerformSplit(
- const MetricType& metric,
- MatType& data,
- const size_t begin,
- const size_t count,
- const VecType& vantagePoint,
- const ElemType mu)
-{
- // This method modifies the input dataset. We loop both from the left and
- // right sides of the points contained in this node. The points closer to
- // the vantage point should be on the left side of the matrix, and the farther
- // from the vantage point should be on the right side of the matrix.
- size_t left = begin;
- size_t right = begin + count - 1;
-
- // First half-iteration of the loop is out here because the termination
- // condition is in the middle.
- while (AssignToLeftSubtree(metric, data, vantagePoint, left, mu) &&
- (left <= right))
- left++;
-
- while ((!AssignToLeftSubtree(metric, data, vantagePoint, right, mu)) &&
- (left <= right) && (right > 0))
- right--;
-
- while (left <= right)
- {
- // Swap columns.
- data.swap_cols(left, right);
-
- // See how many points on the left are correct. When they are correct,
- // increase the left counter accordingly. When we encounter one that isn't
- // correct, stop. We will switch it later.
- while ((AssignToLeftSubtree(metric, data, vantagePoint, left, mu)) &&
- (left <= right))
- left++;
-
- // Now see how many points on the right are correct. When they are correct,
- // decrease the right counter accordingly. When we encounter one that isn't
- // correct, stop. We will switch it with the wrong point we found in the
- // previous loop.
- while ((!AssignToLeftSubtree(metric, data, vantagePoint, right, mu)) &&
- (left <= right))
- right--;
- }
-
- Log::Assert(left == right + 1);
-
- return left;
-}
-
-template<typename BoundType, typename MatType, size_t MaxNumSamples>
-template<typename VecType>
-size_t VantagePointSplit<BoundType, MatType, MaxNumSamples>::PerformSplit(
- const MetricType& metric,
- MatType& data,
- const size_t begin,
- const size_t count,
- const VecType& vantagePoint,
- const ElemType mu,
- std::vector<size_t>& oldFromNew)
-{
- // This method modifies the input dataset. We loop both from the left and
- // right sides of the points contained in this node. The points closer to
- // the vantage point should be on the left side of the matrix, and the farther
- // from the vantage point should be on the right side of the matrix.
- size_t left = begin;
- size_t right = begin + count - 1;
-
- // First half-iteration of the loop is out here because the termination
- // condition is in the middle.
-
- while (AssignToLeftSubtree(metric, data, vantagePoint, left, mu) &&
- (left <= right))
- left++;
-
- while ((!AssignToLeftSubtree(metric, data, vantagePoint, right, mu)) &&
- (left <= right) && (right > 0))
- right--;
-
- while (left <= right)
- {
- // Swap columns.
- data.swap_cols(left, right);
-
- // Update the indices for what we changed.
- size_t t = oldFromNew[left];
- oldFromNew[left] = oldFromNew[right];
- oldFromNew[right] = t;
-
- // See how many points on the left are correct. When they are correct,
- // increase the left counter accordingly. When we encounter one that isn't
- // correct, stop. We will switch it later.
- while (AssignToLeftSubtree(metric, data, vantagePoint, left, mu) &&
- (left <= right))
- left++;
-
- // Now see how many points on the right are correct. When they are correct,
- // decrease the right counter accordingly. When we encounter one that isn't
- // correct, stop. We will switch it with the wrong point we found in the
- // previous loop.
- while ((!AssignToLeftSubtree(metric, data, vantagePoint, right, mu)) &&
- (left <= right))
- right--;
- }
-
- Log::Assert(left == right + 1);
-
- return left;
-}
-
-template<typename BoundType, typename MatType, size_t MaxNumSamples>
+ void VantagePointSplit<BoundType, MatType, MaxNumSamples>::
+ SelectVantagePoint(const MetricType& metric, const MatType& data,
+ const size_t begin, const size_t count, size_t& vantagePoint, ElemType& mu)
+ {
+ arma::uvec vantagePointCandidates;
+ arma::Col<ElemType> distances(MaxNumSamples);
+
+ // Get no more than min(MaxNumSamples, count) vantage point candidates
+ math::ObtainDistinctSamples(begin, begin + count, MaxNumSamples,
+ vantagePointCandidates);
+
+ ElemType bestSpread = 0;
+
+ arma::uvec samples;
+ // Evaluate each candidate
+ for (size_t i = 0; i < vantagePointCandidates.n_elem; i++)
+ {
+ // Get no more than min(MaxNumSamples, count) random samples
+ math::ObtainDistinctSamples(begin, begin + count, MaxNumSamples, samples);
+
+ // Calculate the second moment of the distance to the vantage point
+ // candidate using these random samples.
+ distances.set_size(samples.n_elem);
+
+ for (size_t j = 0; j < samples.n_elem; j++)
+ distances[j] = metric.Evaluate(data.col(vantagePointCandidates[i]),
+ data.col(samples[j]));
+
+ const ElemType spread = arma::sum(distances % distances) / samples.n_elem;
+
+ if (spread > bestSpread)
+ {
+ bestSpread = spread;
+ vantagePoint = vantagePointCandidates[i];
+ // Calculate the median value of the distance from the vantage point
+ // candidate to these samples.
+ mu = arma::median(distances);
+ }
+ }
+ assert(bestSpread > 0);
+ }
+
+ } // namespace tree
+ } // namespace mlpack
+
+ #endif // MLPACK_CORE_TREE_BINARY_SPACE_TREE_VANTAGE_POINT_SPLIT_IMPL_HPP
More information about the mlpack-git
mailing list