[mlpack-git] master: Removed an overload of PerformSplit() for the UB tree. Added a function that implements the default binary split behaviour. Other minor fixes. (9bcd066)
gitdub at mlpack.org
gitdub at mlpack.org
Fri Aug 26 15:24:08 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/1797a49c8f76d65814fec4a122d0d2fea01fc2d9...9e5cd0ac9c5cde9ac141bc84e7327bd11e19d42e
>---------------------------------------------------------------
commit 9bcd066e7a32021fd36778398703bf87552e5c31
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date: Fri Aug 26 22:24:08 2016 +0300
Removed an overload of PerformSplit() for the UB tree. Added a function that implements the default binary split behaviour. Other minor fixes.
>---------------------------------------------------------------
9bcd066e7a32021fd36778398703bf87552e5c31
src/mlpack/core/tree/CMakeLists.txt | 1 +
src/mlpack/core/tree/address.hpp | 30 ++--
.../tree/binary_space_tree/binary_space_tree.hpp | 80 -----------
.../binary_space_tree/binary_space_tree_impl.hpp | 156 +--------------------
.../core/tree/binary_space_tree/mean_split.hpp | 51 ++++++-
.../core/tree/binary_space_tree/midpoint_split.hpp | 51 ++++++-
.../tree/binary_space_tree/rp_tree_max_split.hpp | 50 ++++++-
.../tree/binary_space_tree/rp_tree_mean_split.hpp | 50 ++++++-
.../core/tree/binary_space_tree/ub_tree_split.hpp | 10 +-
.../tree/binary_space_tree/ub_tree_split_impl.hpp | 10 +-
.../tree/binary_space_tree/vantage_point_split.hpp | 44 ++++++
src/mlpack/core/tree/cellbound.hpp | 7 +-
src/mlpack/core/tree/cellbound_impl.hpp | 2 +-
13 files changed, 254 insertions(+), 288 deletions(-)
diff --git a/src/mlpack/core/tree/CMakeLists.txt b/src/mlpack/core/tree/CMakeLists.txt
index 3b05381..0e3c998 100644
--- a/src/mlpack/core/tree/CMakeLists.txt
+++ b/src/mlpack/core/tree/CMakeLists.txt
@@ -15,6 +15,7 @@ set(SOURCES
binary_space_tree/mean_split_impl.hpp
binary_space_tree/midpoint_split.hpp
binary_space_tree/midpoint_split_impl.hpp
+ binary_space_tree/perform_split.hpp
binary_space_tree/rp_tree_max_split.hpp
binary_space_tree/rp_tree_max_split_impl.hpp
binary_space_tree/rp_tree_mean_split.hpp
diff --git a/src/mlpack/core/tree/address.hpp b/src/mlpack/core/tree/address.hpp
index 4d4520c..72202e3 100644
--- a/src/mlpack/core/tree/address.hpp
+++ b/src/mlpack/core/tree/address.hpp
@@ -37,6 +37,16 @@ namespace addr {
* variables should be equal-sized and the type of the address should correspond
* to the type of the vector.
*
+ * The function maps each floating point coordinate to an equal-sized unsigned
+ * integer datatype in such a way that the transform preserves the ordering
+ * (i.e. lower coordinates correspond to lower integers). Thus, the mapping
+ * stores the exponent and the mantissa of each floating point value
+ * sequentially, with the exponent stored before the mantissa. In the
+ * case of negative numbers the resulting integer value should be inverted.
+ * In the multi-dimensional case, after we transform the representation, we
+ * have to interleave the bits of the new representation across all the elements
+ * in the address vector.
+ *
* @param address The resulting address.
* @param point The point that is being translated to the address.
*/
@@ -122,6 +132,8 @@ void PointToAddress(AddressType& address, const VecType& point)
* variables should be equal-sized and the type of the address should correspond
* to the type of the vector.
*
+ * This function performs the inverse of the transform done by PointToAddress().
+ *
* @param address An address to translate.
* @param point The point that corresponds to the address.
*/
@@ -201,9 +213,9 @@ void AddressToPoint(VecType& point, const AddressType& address)
template<typename AddressType1, typename AddressType2>
int CompareAddresses(const AddressType1& addr1, const AddressType2& addr2)
{
- static_assert(sizeof(typename AddressType1::elem_type) ==
- sizeof(typename AddressType2::elem_type), "We aren't able to compare "
- "adresses of distinct sizes");
+ static_assert(std::is_same<typename AddressType1::elem_type,
+ typename AddressType2::elem_type>::value == true, "We aren't able to "
+ "compare adresses of distinct types");
assert(addr1.n_elem == addr2.n_elem);
@@ -225,13 +237,13 @@ template<typename AddressType1, typename AddressType2, typename AddressType3>
bool Contains(const AddressType1& address, const AddressType2& loBound,
const AddressType3& hiBound)
{
- static_assert(sizeof(typename AddressType1::elem_type) ==
- sizeof(typename AddressType2::elem_type), "We aren't able to compare "
- "adresses of distinct sizes");
+ static_assert(std::is_same<typename AddressType1::elem_type,
+ typename AddressType2::elem_type>::value == true, "We aren't able to "
+ "compare adresses of distinct types");
- static_assert(sizeof(typename AddressType1::elem_type) ==
- sizeof(typename AddressType3::elem_type), "We aren't able to compare "
- "adresses of distinct sizes");
+ static_assert(std::is_same<typename AddressType1::elem_type,
+ typename AddressType3::elem_type>::value == true, "We aren't able to "
+ "compare adresses of distinct types");
assert(address.n_elem == loBound.n_elem);
assert(address.n_elem == hiBound.n_elem);
diff --git a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
index bcbbbac..b024ee5 100644
--- a/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/binary_space_tree.hpp
@@ -478,86 +478,6 @@ class BinarySpaceTree
SplitType<BoundType<MetricType>, MatType>& splitter);
/**
- * Perform the split process according to the information about the
- * split.
- *
- * @param bound The bound used for this node.
- * @param data The dataset used by the binary space tree.
- * @param begin Index of the starting point in the dataset that belongs to
- * this node.
- * @param count Number of points in this node.
- * @param splitInfo The information about the split.
- */
- template<typename SplitInfo>
- size_t PerformSplit(MatType& data,
- const size_t begin,
- const size_t count,
- const SplitInfo& splitInfo);
-
- /**
- * Perform the split process according to the information about the split and
- * return a list of changed indices.
- *
- * @param bound The bound used for this node.
- * @param data The dataset used by the binary space tree.
- * @param begin Index of the starting point in the dataset that belongs to
- * this node.
- * @param count Number of points in this node.
- * @param splitInfo The information about the split.
- * @param oldFromNew Vector which will be filled with the old positions for
- * each new point.
- */
- template<typename SplitInfo>
- size_t PerformSplit(MatType& data,
- const size_t begin,
- const size_t count,
- const SplitInfo& splitInfo,
- std::vector<size_t>& oldFromNew);
-
- /**
- * An overload for the universal B tree. For the first time the function
- * rearranges the whole dataset. Next time the function only returns the split
- * column.
- *
- * @param bound The bound used for this node.
- * @param data The dataset used by the binary space tree.
- * @param begin Index of the starting point in the dataset that belongs to
- * this node.
- * @param count Number of points in this node.
- * @param splitInfo The information about the split.
- * @param oldFromNew Vector which will be filled with the old positions for
- * each new point.
- */
- size_t PerformSplit(
- MatType& data,
- const size_t begin,
- const size_t count,
- const typename UBTreeSplit<BoundType<MetricType>,
- MatType>::SplitInfo& splitInfo);
-
- /**
- * An overload for the universal B tree. For the first time the function
- * rearranges the whole dataset. Next time the function only returns the split
- * column.
- *
- * @param bound The bound used for this node.
- * @param data The dataset used by the binary space tree.
- * @param begin Index of the starting point in the dataset that belongs to
- * this node.
- * @param count Number of points in this node.
- * @param splitInfo The information about the split.
- * @param oldFromNew Vector which will be filled with the old positions for
- * each new point.
- */
- size_t PerformSplit(
- MatType& data,
- const size_t begin,
- const size_t count,
- const typename UBTreeSplit<BoundType<MetricType>,
- MatType>::SplitInfo& splitInfo,
- std::vector<size_t>& oldFromNew);
-
- /**
* Update the bound of the current node. This method does not take into
* account bound-specific properties.
*
diff --git a/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp b/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
index f2d96e7..7301a97 100644
--- a/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/binary_space_tree_impl.hpp
@@ -667,7 +667,7 @@ void BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
// Perform the actual splitting. This will order the dataset such that
// points that belong to the left subtree are on the left of splitCol, and
// points from the right subtree are on the right side of splitCol.
- splitCol = PerformSplit(*dataset, begin, count, splitInfo);
+ splitCol = splitter.PerformSplit(*dataset, begin, count, splitInfo);
assert(splitCol > begin);
assert(splitCol < begin + count);
@@ -733,7 +733,8 @@ SplitNode(std::vector<size_t>& oldFromNew,
// Perform the actual splitting. This will order the dataset such that
// points that belong to the left subtree are on the left of splitCol, and
// points from the right subtree are on the right side of splitCol.
- splitCol = PerformSplit(*dataset, begin, count, splitInfo, oldFromNew);
+ splitCol = splitter.PerformSplit(*dataset, begin, count, splitInfo,
+ oldFromNew);
assert(splitCol > begin);
assert(splitCol < begin + count);
@@ -765,157 +766,6 @@ template<typename MetricType,
template<typename BoundMetricType, typename...> class BoundType,
template<typename SplitBoundType, typename SplitMatType>
class SplitType>
-template<typename SplitInfo>
-size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
- SplitType>::PerformSplit(MatType& data,
- const size_t begin,
- const size_t count,
- const SplitInfo& splitInfo)
-{
- // This method modifies the input dataset. We loop both from the left and
- // right sides of the points contained in this node. The points less than
- // splitVal should be on the left side of the matrix, and the points greater
- // than splitVal should be on the right side of the matrix.
- size_t left = begin;
- size_t right = begin + count - 1;
-
- // First half-iteration of the loop is out here because the termination
- // condition is in the middle.
- while (Split::AssignToLeftNode(data.col(left), splitInfo) && (left <= right))
- left++;
- while ((!Split::AssignToLeftNode(data.col(right), splitInfo)) &&
- (left <= right) && (right > 0))
- right--;
-
- while (left <= right)
- {
- // Swap columns.
- data.swap_cols(left, right);
-
- // See how many points on the left are correct. When they are correct,
- // increase the left counter accordingly. When we encounter one that isn't
- // correct, stop. We will switch it later.
- while (Split::AssignToLeftNode(data.col(left), splitInfo) &&
- (left <= right))
- left++;
-
- // Now see how many points on the right are correct. When they are correct,
- // decrease the right counter accordingly. When we encounter one that isn't
- // correct, stop. We will switch it with the wrong point we found in the
- // previous loop.
- while ((!Split::AssignToLeftNode(data.col(right), splitInfo)) &&
- (left <= right))
- right--;
- }
-
- Log::Assert(left == right + 1);
-
- return left;
-}
-
-template<typename MetricType,
- typename StatisticType,
- typename MatType,
- template<typename BoundMetricType, typename...> class BoundType,
- template<typename SplitBoundType, typename SplitMatType>
- class SplitType>
-template<typename SplitInfo>
-size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
- SplitType>::PerformSplit(MatType& data,
- const size_t begin,
- const size_t count,
- const SplitInfo& splitInfo,
- std::vector<size_t>& oldFromNew)
-{
- // This method modifies the input dataset. We loop both from the left and
- // right sides of the points contained in this node. The points less than
- // splitVal should be on the left side of the matrix, and the points greater
- // than splitVal should be on the right side of the matrix.
- size_t left = begin;
- size_t right = begin + count - 1;
-
- // First half-iteration of the loop is out here because the termination
- // condition is in the middle.
- while (Split::AssignToLeftNode(data.col(left), splitInfo) && (left <= right))
- left++;
- while ((!Split::AssignToLeftNode(data.col(right), splitInfo)) &&
- (left <= right) && (right > 0))
- right--;
-
- while (left <= right)
- {
- // Swap columns.
- data.swap_cols(left, right);
-
- // Update the indices for what we changed.
- size_t t = oldFromNew[left];
- oldFromNew[left] = oldFromNew[right];
- oldFromNew[right] = t;
-
- // See how many points on the left are correct. When they are correct,
- // increase the left counter accordingly. When we encounter one that isn't
- // correct, stop. We will switch it later.
- while (Split::AssignToLeftNode(data.col(left), splitInfo) &&
- (left <= right))
- left++;
-
- // Now see how many points on the right are correct. When they are correct,
- // decrease the right counter accordingly. When we encounter one that isn't
- // correct, stop. We will switch it with the wrong point we found in the
- // previous loop.
- while ((!Split::AssignToLeftNode(data.col(right), splitInfo)) &&
- (left <= right))
- right--;
- }
-
- Log::Assert(left == right + 1);
-
- return left;
-}
-
-template<typename MetricType,
- typename StatisticType,
- typename MatType,
- template<typename BoundMetricType, typename...> class BoundType,
- template<typename SplitBoundType, typename SplitMatType>
- class SplitType>
-size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
- SplitType>::PerformSplit(
- MatType& data,
- const size_t begin,
- const size_t count,
- const typename UBTreeSplit<BoundType<MetricType>,
- MatType>::SplitInfo& splitInfo)
-{
- return SplitType<BoundType<MetricType>, MatType>::PerformSplit(data, begin,
- count, splitInfo);
-}
-
-template<typename MetricType,
- typename StatisticType,
- typename MatType,
- template<typename BoundMetricType, typename...> class BoundType,
- template<typename SplitBoundType, typename SplitMatType>
- class SplitType>
-size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
- SplitType>::PerformSplit(
- MatType& data,
- const size_t begin,
- const size_t count,
- const typename UBTreeSplit<BoundType<MetricType>,
- MatType>::SplitInfo& splitInfo,
- std::vector<size_t>& oldFromNew)
-{
- return SplitType<BoundType<MetricType>, MatType>::PerformSplit(data, begin,
- count, splitInfo, oldFromNew);
-}
-
-template<typename MetricType,
- typename StatisticType,
- typename MatType,
- template<typename BoundMetricType, typename...> class BoundType,
- template<typename SplitBoundType, typename SplitMatType>
- class SplitType>
template<typename BoundType2>
void BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
UpdateBound(BoundType2& boundToUpdate)
diff --git a/src/mlpack/core/tree/binary_space_tree/mean_split.hpp b/src/mlpack/core/tree/binary_space_tree/mean_split.hpp
index d3979d8..4c07627 100644
--- a/src/mlpack/core/tree/binary_space_tree/mean_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/mean_split.hpp
@@ -10,6 +10,7 @@
#define MLPACK_CORE_TREE_BINARY_SPACE_TREE_MEAN_SPLIT_HPP
#include <mlpack/core.hpp>
+#include "perform_split.hpp"
namespace mlpack {
namespace tree /** Trees and tree-building procedures. */ {
@@ -23,13 +24,6 @@ template<typename BoundType, typename MatType = arma::mat>
class MeanSplit
{
public:
- /**
- * Indicates that this class does not perform the actual splitting i.e.
- * it does not reorder the dataset. If this variable is false, the class
- * finds the partition and reorders the dataset.
- */
- static constexpr bool NeedRearrangeDataset = true;
-
//! An information about the partition.
struct SplitInfo
{
@@ -59,6 +53,49 @@ class MeanSplit
SplitInfo& splitInfo);
/**
+ * Perform the split process according to the information about the
+ * split.
+ *
+ * @param bound The bound used for this node.
+ * @param data The dataset used by the binary space tree.
+ * @param begin Index of the starting point in the dataset that belongs to
+ * this node.
+ * @param count Number of points in this node.
+ * @param splitInfo The information about the split.
+ */
+ static size_t PerformSplit(MatType& data,
+ const size_t begin,
+ const size_t count,
+ const SplitInfo& splitInfo)
+ {
+ return split::PerformSplit<MatType, MeanSplit>(data, begin, count,
+ splitInfo);
+ }
+
+ /**
+ * Perform the split process according to the information about the split and
+ * return the list of changed indices.
+ *
+ * @param bound The bound used for this node.
+ * @param data The dataset used by the binary space tree.
+ * @param begin Index of the starting point in the dataset that belongs to
+ * this node.
+ * @param count Number of points in this node.
+ * @param splitInfo The information about the split.
+ * @param oldFromNew Vector which will be filled with the old positions for
+ * each new point.
+ */
+ static size_t PerformSplit(MatType& data,
+ const size_t begin,
+ const size_t count,
+ const SplitInfo& splitInfo,
+ std::vector<size_t>& oldFromNew)
+ {
+ return split::PerformSplit<MatType, MeanSplit>(data, begin, count,
+ splitInfo, oldFromNew);
+ }
+
+ /**
* Indicates that a point should be assigned to the left subtree.
*
* @param point The point that is being assigned.
diff --git a/src/mlpack/core/tree/binary_space_tree/midpoint_split.hpp b/src/mlpack/core/tree/binary_space_tree/midpoint_split.hpp
index 522d298..66c3043 100644
--- a/src/mlpack/core/tree/binary_space_tree/midpoint_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/midpoint_split.hpp
@@ -11,6 +11,7 @@
#define MLPACK_CORE_TREE_BINARY_SPACE_TREE_MIDPOINT_SPLIT_HPP
#include <mlpack/core.hpp>
+#include "perform_split.hpp"
namespace mlpack {
namespace tree /** Trees and tree-building procedures. */ {
@@ -24,13 +25,6 @@ template<typename BoundType, typename MatType = arma::mat>
class MidpointSplit
{
public:
- /**
- * Indicates that this class does not perform the actual splitting i.e.
- * it does not reorder the dataset. If this variable is false, the class
- * finds the partition and reorders the dataset.
- */
- static constexpr bool NeedRearrangeDataset = true;
-
//! A struct that contains an information about the split.
struct SplitInfo
{
@@ -59,6 +53,49 @@ class MidpointSplit
SplitInfo& splitInfo);
/**
+ * Perform the split process according to the information about the
+ * split.
+ *
+ * @param bound The bound used for this node.
+ * @param data The dataset used by the binary space tree.
+ * @param begin Index of the starting point in the dataset that belongs to
+ * this node.
+ * @param count Number of points in this node.
+ * @param splitInfo The information about the split.
+ */
+ static size_t PerformSplit(MatType& data,
+ const size_t begin,
+ const size_t count,
+ const SplitInfo& splitInfo)
+ {
+ return split::PerformSplit<MatType, MidpointSplit>(data, begin, count,
+ splitInfo);
+ }
+
+ /**
+ * Perform the split process according to the information about the split and
+ * return the list of changed indices.
+ *
+ * @param bound The bound used for this node.
+ * @param data The dataset used by the binary space tree.
+ * @param begin Index of the starting point in the dataset that belongs to
+ * this node.
+ * @param count Number of points in this node.
+ * @param splitInfo The information about the split.
+ * @param oldFromNew Vector which will be filled with the old positions for
+ * each new point.
+ */
+ static size_t PerformSplit(MatType& data,
+ const size_t begin,
+ const size_t count,
+ const SplitInfo& splitInfo,
+ std::vector<size_t>& oldFromNew)
+ {
+ return split::PerformSplit<MatType, MidpointSplit>(data, begin, count,
+ splitInfo, oldFromNew);
+ }
+
+ /**
* Indicates that a point should be assigned to the left subtree.
*
* @param point The point that is being assigned.
diff --git a/src/mlpack/core/tree/binary_space_tree/rp_tree_max_split.hpp b/src/mlpack/core/tree/binary_space_tree/rp_tree_max_split.hpp
index f49e2a2..d556cb5 100644
--- a/src/mlpack/core/tree/binary_space_tree/rp_tree_max_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/rp_tree_max_split.hpp
@@ -9,6 +9,7 @@
#define MLPACK_CORE_TREE_BINARY_SPACE_TREE_RP_TREE_MAX_SPLIT_HPP
#include <mlpack/core.hpp>
+#include "perform_split.hpp"
namespace mlpack {
namespace tree /** Trees and tree-building procedures. */ {
@@ -28,12 +29,6 @@ class RPTreeMaxSplit
public:
//! The element type held by the matrix type.
typedef typename MatType::elem_type ElemType;
- /**
- * Indicates that this class does not perform the actual splitting i.e.
- * it does not reorder the dataset. If this variable is false, the class
- * finds the partition and reorders the dataset.
- */
- static constexpr bool NeedRearrangeDataset = true;
//! An information about the partition.
struct SplitInfo
{
@@ -61,6 +56,49 @@ class RPTreeMaxSplit
SplitInfo& splitInfo);
/**
+ * Perform the split process according to the information about the
+ * split.
+ *
+ * @param bound The bound used for this node.
+ * @param data The dataset used by the binary space tree.
+ * @param begin Index of the starting point in the dataset that belongs to
+ * this node.
+ * @param count Number of points in this node.
+ * @param splitInfo The information about the split.
+ */
+ static size_t PerformSplit(MatType& data,
+ const size_t begin,
+ const size_t count,
+ const SplitInfo& splitInfo)
+ {
+ return split::PerformSplit<MatType, RPTreeMaxSplit>(data, begin, count,
+ splitInfo);
+ }
+
+ /**
+ * Perform the split process according to the information about the split and
+ * return the list of changed indices.
+ *
+ * @param bound The bound used for this node.
+ * @param data The dataset used by the binary space tree.
+ * @param begin Index of the starting point in the dataset that belongs to
+ * this node.
+ * @param count Number of points in this node.
+ * @param splitInfo The information about the split.
+ * @param oldFromNew Vector which will be filled with the old positions for
+ * each new point.
+ */
+ static size_t PerformSplit(MatType& data,
+ const size_t begin,
+ const size_t count,
+ const SplitInfo& splitInfo,
+ std::vector<size_t>& oldFromNew)
+ {
+ return split::PerformSplit<MatType, RPTreeMaxSplit>(data, begin, count,
+ splitInfo, oldFromNew);
+ }
+
+ /**
* Indicates that a point should be assigned to the left subtree.
*
* @param point The point that is being assigned.
diff --git a/src/mlpack/core/tree/binary_space_tree/rp_tree_mean_split.hpp b/src/mlpack/core/tree/binary_space_tree/rp_tree_mean_split.hpp
index f54645c..64bb10b 100644
--- a/src/mlpack/core/tree/binary_space_tree/rp_tree_mean_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/rp_tree_mean_split.hpp
@@ -10,6 +10,7 @@
#include <mlpack/core.hpp>
#include "rp_tree_max_split.hpp"
+#include "perform_split.hpp"
namespace mlpack {
namespace tree /** Trees and tree-building procedures. */ {
@@ -28,12 +29,6 @@ class RPTreeMeanSplit
public:
//! The element type held by the matrix type.
typedef typename MatType::elem_type ElemType;
- /**
- * Indicates that this class does not perform the actual splitting i.e.
- * it does not reorder the dataset. If this variable is false, the class
- * finds the partition and reorders the dataset.
- */
- static constexpr bool NeedRearrangeDataset = true;
//! An information about the partition.
struct SplitInfo
{
@@ -67,6 +62,49 @@ class RPTreeMeanSplit
SplitInfo& splitInfo);
/**
+ * Perform the split process according to the information about the
+ * split.
+ *
+ * @param bound The bound used for this node.
+ * @param data The dataset used by the binary space tree.
+ * @param begin Index of the starting point in the dataset that belongs to
+ * this node.
+ * @param count Number of points in this node.
+ * @param splitInfo The information about the split.
+ */
+ static size_t PerformSplit(MatType& data,
+ const size_t begin,
+ const size_t count,
+ const SplitInfo& splitInfo)
+ {
+ return split::PerformSplit<MatType, RPTreeMeanSplit>(data, begin, count,
+ splitInfo);
+ }
+
+ /**
+ * Perform the split process according to the information about the split and
+ * return the list of changed indices.
+ *
+ * @param bound The bound used for this node.
+ * @param data The dataset used by the binary space tree.
+ * @param begin Index of the starting point in the dataset that belongs to
+ * this node.
+ * @param count Number of points in this node.
+ * @param splitInfo The information about the split.
+ * @param oldFromNew Vector which will be filled with the old positions for
+ * each new point.
+ */
+ static size_t PerformSplit(MatType& data,
+ const size_t begin,
+ const size_t count,
+ const SplitInfo& splitInfo,
+ std::vector<size_t>& oldFromNew)
+ {
+ return split::PerformSplit<MatType, RPTreeMeanSplit>(data, begin, count,
+ splitInfo, oldFromNew);
+ }
+
+ /**
* Indicates that a point should be assigned to the left subtree.
*
* @param point The point that is being assigned.
diff --git a/src/mlpack/core/tree/binary_space_tree/ub_tree_split.hpp b/src/mlpack/core/tree/binary_space_tree/ub_tree_split.hpp
index f7291ca..c0f71b7 100644
--- a/src/mlpack/core/tree/binary_space_tree/ub_tree_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/ub_tree_split.hpp
@@ -23,13 +23,6 @@ class UBTreeSplit
uint32_t,
uint64_t>::type AddressElemType;
- /**
- * This class performs the actual splitting i.e. it reorders the dataset.
- * This variable should be equal to true if the class does not perform the
- * actual splitting.
- */
- static constexpr bool NeedRearrangeDataset = false;
-
//! An information about the partition.
struct SplitInfo
{
@@ -69,7 +62,8 @@ class UBTreeSplit
const SplitInfo& splitInfo);
/**
- * Rearrange the dataset according to the addresses.
+ * Rearrange the dataset according to the addresses and return the list
+ * of changed indices.
*
* @param data The dataset used by the binary space tree.
* @param begin Index of the starting point in the dataset that belongs to
diff --git a/src/mlpack/core/tree/binary_space_tree/ub_tree_split_impl.hpp b/src/mlpack/core/tree/binary_space_tree/ub_tree_split_impl.hpp
index c3bc7e1..51beba8 100644
--- a/src/mlpack/core/tree/binary_space_tree/ub_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/ub_tree_split_impl.hpp
@@ -204,18 +204,14 @@ size_t UBTreeSplit<BoundType, MatType>::PerformSplit(
if (splitInfo.addresses)
{
std::vector<size_t> newFromOld(data.n_cols);
- std::vector<size_t> newToOld(data.n_cols);
for (size_t i = 0; i < splitInfo.addresses->size(); i++)
- {
newFromOld[i] = i;
- newToOld[i] = i;
- }
for (size_t i = 0; i < splitInfo.addresses->size(); i++)
{
size_t index = (*splitInfo.addresses)[i].second;
- size_t oldI = newToOld[i];
+ size_t oldI = oldFromNew[i];
size_t newIndex = newFromOld[index];
data.swap_cols(i, newFromOld[index]);
@@ -224,10 +220,6 @@ size_t UBTreeSplit<BoundType, MatType>::PerformSplit(
newFromOld[index] = i;
newFromOld[oldI] = tmp;
- tmp = newToOld[i];
- newToOld[i] = newToOld[newIndex];
- newToOld[newIndex] = tmp;
-
tmp = oldFromNew[i];
oldFromNew[i] = oldFromNew[newIndex];
oldFromNew[newIndex] = tmp;
diff --git a/src/mlpack/core/tree/binary_space_tree/vantage_point_split.hpp b/src/mlpack/core/tree/binary_space_tree/vantage_point_split.hpp
index 015d2d8..ef85ec7 100644
--- a/src/mlpack/core/tree/binary_space_tree/vantage_point_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/vantage_point_split.hpp
@@ -9,6 +9,7 @@
#define MLPACK_CORE_TREE_BINARY_SPACE_TREE_VANTAGE_POINT_SPLIT_HPP
#include <mlpack/core.hpp>
+#include "perform_split.hpp"
namespace mlpack {
namespace tree /** Trees and tree-building procedures. */ {
@@ -65,6 +66,49 @@ class VantagePointSplit
SplitInfo& splitInfo);
/**
+ * Perform the split process according to the information about the
+ * split.
+ *
+ * @param bound The bound used for this node.
+ * @param data The dataset used by the binary space tree.
+ * @param begin Index of the starting point in the dataset that belongs to
+ * this node.
+ * @param count Number of points in this node.
+ * @param splitInfo The information about the split.
+ */
+ static size_t PerformSplit(MatType& data,
+ const size_t begin,
+ const size_t count,
+ const SplitInfo& splitInfo)
+ {
+ return split::PerformSplit<MatType, VantagePointSplit>(data, begin, count,
+ splitInfo);
+ }
+
+ /**
+ * Perform the split process according to the information about the split and
+ * return the list of changed indices.
+ *
+ * @param bound The bound used for this node.
+ * @param data The dataset used by the binary space tree.
+ * @param begin Index of the starting point in the dataset that belongs to
+ * this node.
+ * @param count Number of points in this node.
+ * @param splitInfo The information about the split.
+ * @param oldFromNew Vector which will be filled with the old positions for
+ * each new point.
+ */
+ static size_t PerformSplit(MatType& data,
+ const size_t begin,
+ const size_t count,
+ const SplitInfo& splitInfo,
+ std::vector<size_t>& oldFromNew)
+ {
+ return split::PerformSplit<MatType, VantagePointSplit>(data, begin, count,
+ splitInfo, oldFromNew);
+ }
+
+ /**
* Indicates that a point should be assigned to the left subtree.
* This method returns true if a point should be assigned to the left subtree,
* i.e., if the distance from the point to the vantage point is less then the
diff --git a/src/mlpack/core/tree/cellbound.hpp b/src/mlpack/core/tree/cellbound.hpp
index a653c6e..8841f8b 100644
--- a/src/mlpack/core/tree/cellbound.hpp
+++ b/src/mlpack/core/tree/cellbound.hpp
@@ -270,7 +270,8 @@ class CellBound
const arma::Col<ElemType>& hiCorner,
const MatType& data);
/**
- * Initialize all subrectangles that touches the lower address.
+ * Initialize all subrectangles that touch the lower address. This function
+ * should be called before InitLowerBound().
*
* @param numEqualBits The number of equal leading bits of the lower address
* and the high address.
@@ -278,8 +279,10 @@ class CellBound
*/
template<typename MatType>
void InitHighBound(size_t numEqualBits, const MatType& data);
+
/**
- * Initialize all subrectangles that touches the high address.
+ * Initialize all subrectangles that touch the high address. This function
+ * should be called after InitHighBound().
*
* @param numEqualBits The number of equal leading bits of the lower address
* and the high address.
diff --git a/src/mlpack/core/tree/cellbound_impl.hpp b/src/mlpack/core/tree/cellbound_impl.hpp
index ebb9a34..ba18379 100644
--- a/src/mlpack/core/tree/cellbound_impl.hpp
+++ b/src/mlpack/core/tree/cellbound_impl.hpp
@@ -328,7 +328,7 @@ void CellBound<MetricType, ElemType>::InitLowerBound(size_t numEqualBits,
// We ran out of the limit of hyperrectangles. In that case we enlare
// the last hyperrectangle.
- if (numCorners >= maxNumBounds / 2)
+ if (numCorners >= maxNumBounds - numBounds)
tmpLoAddress[row] &= ~((AddressElemType) 1 << bit);
}
More information about the mlpack-git
mailing list