[mlpack-git] master: Fixed various errors in the implementation of the bound and the calculation of addressed. (8c5a97d)
gitdub at mlpack.org
gitdub at mlpack.org
Mon Aug 29 12:02:04 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/1797a49c8f76d65814fec4a122d0d2fea01fc2d9...9e5cd0ac9c5cde9ac141bc84e7327bd11e19d42e
>---------------------------------------------------------------
commit 8c5a97dcb1641ae6c98edc70426fb19f5cd7cb79
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date: Mon Aug 8 20:04:46 2016 +0300
Fixed various errors in the implementation of the bound and the calculation of addressed.
>---------------------------------------------------------------
8c5a97dcb1641ae6c98edc70426fb19f5cd7cb79
src/mlpack/core/tree/address.hpp | 16 ++++-
src/mlpack/core/tree/binary_space_tree/traits.hpp | 6 ++
src/mlpack/core/tree/binary_space_tree/typedef.hpp | 29 ++++++++-
.../core/tree/binary_space_tree/ub_tree_split.hpp | 34 +++++++++--
.../tree/binary_space_tree/ub_tree_split_impl.hpp | 29 ++++++++-
src/mlpack/core/tree/cellbound.hpp | 68 ++++++++++++++++++++--
src/mlpack/core/tree/cellbound_impl.hpp | 65 +++++++++++++++++----
src/mlpack/tests/ub_tree_test.cpp | 10 +++-
8 files changed, 231 insertions(+), 26 deletions(-)
diff --git a/src/mlpack/core/tree/address.hpp b/src/mlpack/core/tree/address.hpp
index 276cedb..6f4bdd5 100644
--- a/src/mlpack/core/tree/address.hpp
+++ b/src/mlpack/core/tree/address.hpp
@@ -115,14 +115,21 @@ void AddressToPoint(VecType& point, const AddressType& address)
(order - 1 - i));
}
-
for (size_t i = 0; i < rearrangedAddress.n_elem; i++)
{
bool sgn = rearrangedAddress(i) & ((AddressElemType) 1 << (order - 1));
+ if (!sgn)
+ {
+ rearrangedAddress(i) = ((AddressElemType) 1 << (order - 1)) - 1 -
+ rearrangedAddress(i);
+ }
+
// Extract the mantissa.
AddressElemType tmp = (AddressElemType) 1 << numMantBits;
AddressElemType mantissa = rearrangedAddress(i) & (tmp - 1);
+ if (mantissa == 0)
+ mantissa = 1;
VecElemType normalizedVal = (VecElemType) mantissa / tmp;
@@ -136,6 +143,13 @@ void AddressToPoint(VecType& point, const AddressType& address)
e += std::numeric_limits<VecElemType>::min_exponent;
point(i) = std::ldexp(normalizedVal, e);
+ if (std::isinf(point(i)))
+ {
+ if (point(i) > 0)
+ point(i) = std::numeric_limits<VecElemType>::max();
+ else
+ point(i) = std::numeric_limits<VecElemType>::lowest();
+ }
}
}
diff --git a/src/mlpack/core/tree/binary_space_tree/traits.hpp b/src/mlpack/core/tree/binary_space_tree/traits.hpp
index cd6512c..9042420 100644
--- a/src/mlpack/core/tree/binary_space_tree/traits.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/traits.hpp
@@ -80,6 +80,12 @@ class TreeTraits<BinarySpaceTree<MetricType, StatisticType, MatType,
static const bool BinaryTree = true;
};
+/**
+ * This is a specialization of the TreeTraits class to the UBTree tree type.
+ * The only difference with general BinarySpaceTree is that UBTree can have
+ * overlapping children.
+ * See mlpack/core/tree/tree_traits.hpp for more information.
+ */
template<typename MetricType,
typename StatisticType,
typename MatType,
diff --git a/src/mlpack/core/tree/binary_space_tree/typedef.hpp b/src/mlpack/core/tree/binary_space_tree/typedef.hpp
index 77f986d..203d739 100644
--- a/src/mlpack/core/tree/binary_space_tree/typedef.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/typedef.hpp
@@ -135,7 +135,34 @@ using MeanSplitBallTree = BinarySpaceTree<MetricType,
bound::BallBound,
MeanSplit>;
-
+/**
+ * The Universal B-tree. When recursively splitting nodes, the class
+ * calculates addresses of all points and splits each node according to the
+ * median address. Children nodes may overlap since the implementation
+ * of a tighter bound requires a lot of arithmetic operations. In order to get
+ * a tighter bound increase the CellBound::maxNumBounds constant.
+ *
+ * @code
+ * @inproceedings{bayer1997,
+ * author = {Bayer, Rudolf},
+ * title = {The Universal B-Tree for Multidimensional Indexing: General
+ * Concepts},
+ * booktitle = {Proceedings of the International Conference on Worldwide
+ * Computing and Its Applications},
+ * series = {WWCA '97},
+ * year = {1997},
+ * isbn = {3-540-63343-X},
+ * pages = {198--209},
+ * numpages = {12},
+ * publisher = {Springer-Verlag},
+ * address = {London, UK, UK},
+ * }
+ * @endcode
+ *
+ * This template typedef satisfies the TreeType policy API.
+ *
+ * @see @ref trees, BinarySpaceTree, BallTree, MeanSplitKDTree
+ */
template<typename MetricType, typename StatisticType, typename MatType>
using UBTree = BinarySpaceTree<MetricType,
StatisticType,
diff --git a/src/mlpack/core/tree/binary_space_tree/ub_tree_split.hpp b/src/mlpack/core/tree/binary_space_tree/ub_tree_split.hpp
index 93b2ea1..85f5013 100644
--- a/src/mlpack/core/tree/binary_space_tree/ub_tree_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/ub_tree_split.hpp
@@ -1,5 +1,9 @@
/**
* @file ub_tree_split.hpp
+ * @author Mikhail Lozhnikov
+ *
+ * Definition of UBTreeSplit, a class that splits the space according
+ * to the median address of points contained in the node.
*/
#ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_UB_TREE_SPLIT_HPP
#define MLPACK_CORE_TREE_BINARY_SPACE_TREE_UB_TREE_SPLIT_HPP
@@ -14,10 +18,23 @@ template<typename BoundType, typename MatType = arma::mat>
class UBTreeSplit
{
public:
+ //! The type of a one-dimensional address.
typedef typename std::conditional<sizeof(typename MatType::elem_type) * CHAR_BIT <= 32,
uint32_t,
uint64_t>::type AddressElemType;
+ /**
+ * Split the node according to the median address of points contained in the
+ * node.
+ *
+ * @param bound The bound used for this node.
+ * @param data The dataset used by the binary space tree.
+ * @param begin Index of the starting point in the dataset that belongs to
+ * this node.
+ * @param count Number of points in this node.
+ * @param splitCol The index at which the dataset is divided into two parts
+ * after the rearrangement.
+ */
bool SplitNode(BoundType& bound,
MatType& data,
const size_t begin,
@@ -32,14 +49,22 @@ class UBTreeSplit
std::vector<size_t>& oldFromNew);
private:
-// arma::Mat<AddressElemType> addresses;
+ //! This vector contains addresses of all points in the dataset.
std::vector<std::pair<arma::Col<AddressElemType>, size_t>> addresses;
- template<typename VecType>
- arma::Col<AddressElemType> CalculateAddress(const VecType& point);
-
+ /**
+ * Calculate addresses for all points in the dataset.
+ *
+ * @param data The dataset used by the binary space tree.
+ */
void InitializeAddresses(const MatType& data);
+ /**
+ * Rearrange the dataset according to the sorted addresses.
+ *
+ * @param data The dataset used by the binary space tree.
+ * @param count Number of points in this node.
+ */
void PerformSplit(MatType& data,
const size_t count);
@@ -47,6 +72,7 @@ class UBTreeSplit
const size_t count,
std::vector<size_t>& oldFromNew);
+ //! A comparator for sorting addresses.
static bool ComparePair(
const std::pair<arma::Col<AddressElemType>, size_t>& p1,
const std::pair<arma::Col<AddressElemType>, size_t>& p2)
diff --git a/src/mlpack/core/tree/binary_space_tree/ub_tree_split_impl.hpp b/src/mlpack/core/tree/binary_space_tree/ub_tree_split_impl.hpp
index e0804b4..9d5963e 100644
--- a/src/mlpack/core/tree/binary_space_tree/ub_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/ub_tree_split_impl.hpp
@@ -1,5 +1,9 @@
/**
* @file ub_tree_split_impl.hpp
+ * @author Mikhail Lozhnikov
+ *
+ * Implementation of UBTreeSplit, a class that splits a node according
+ * to the median address of points contained in the node.
*/
#ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_UB_TREE_SPLIT_IMPL_HPP
#define MLPACK_CORE_TREE_BINARY_SPACE_TREE_UB_TREE_SPLIT_IMPL_HPP
@@ -20,15 +24,25 @@ bool UBTreeSplit<BoundType, MatType>::SplitNode(BoundType& bound,
constexpr size_t order = sizeof(AddressElemType) * CHAR_BIT;
if (begin == 0 && count == data.n_cols)
{
+ // Calculate all addresses.
InitializeAddresses(data);
+ // Probably this is not a good idea. Maybe it is better to get
+ // a number of distinct samples and find the median.
std::sort(addresses.begin(), addresses.end(), ComparePair);
+ // Rearrange dataset.
PerformSplit(data, count);
}
+ // The bound shouldn't contain too many subrectangles.
+ // In order to minimize the number of hyperrectangles we set last bits
+ // of the last address in the node to 1 and last bits of the first address
+ // in the next node to zero in such a way that the ordering is not
+ // disturbed.
if (begin + count < data.n_cols)
{
+ // Omit leading equal bits.
size_t row = 0;
arma::Col<AddressElemType>& lo = addresses[begin + count - 1].first;
const arma::Col<AddressElemType>& hi = addresses[begin + count].first;
@@ -46,6 +60,7 @@ bool UBTreeSplit<BoundType, MatType>::SplitNode(BoundType& bound,
bit++;
+ // Replace insignificant bits.
if (bit == order)
{
bit = 0;
@@ -62,8 +77,15 @@ bool UBTreeSplit<BoundType, MatType>::SplitNode(BoundType& bound,
for (; bit < order; bit++)
lo[row] |= ((AddressElemType) 1 << (order - 1 - bit));
}
+
+ // The bound shouldn't contain too many subrectangles.
+ // In order to minimize the number of hyperrectangles we set last bits
+ // of the first address in the next node to 0 and last bits of the last
+ // address in the previous node to 1 in such a way that the ordering is not
+ // disturbed.
if (begin > 0)
{
+ // Omit leading equal bits.
size_t row = 0;
const arma::Col<AddressElemType>& lo = addresses[begin - 1].first;
arma::Col<AddressElemType>& hi = addresses[begin].first;
@@ -81,6 +103,7 @@ bool UBTreeSplit<BoundType, MatType>::SplitNode(BoundType& bound,
bit++;
+ // Replace insignificant bits.
if (bit == order)
{
bit = 0;
@@ -98,12 +121,15 @@ bool UBTreeSplit<BoundType, MatType>::SplitNode(BoundType& bound,
hi[row] &= ~((AddressElemType) 1 << (order - 1 - bit));
}
+ // Set the minimum and the maximum addresses.
for (size_t k = 0; k < bound.Dim(); k++)
{
bound.LoAddress()[k] = addresses[begin].first[k];
bound.HiAddress()[k] = addresses[begin + count - 1].first[k];
}
bound.UpdateAddressBounds();
+
+ // Since the dataset is sorted we can easily get the split column.
splitCol = begin + count / 2;
return true;
@@ -214,12 +240,11 @@ bool UBTreeSplit<BoundType, MatType>::SplitNode(BoundType& bound,
template<typename BoundType, typename MatType>
void UBTreeSplit<BoundType, MatType>::InitializeAddresses(const MatType& data)
{
-// addresses.set_size(data.n_rows, data.n_cols);
addresses.resize(data.n_cols);
+ // Calculate all addresses.
for (size_t i = 0; i < data.n_cols; i++)
{
-// address.col(i) = CalculateAddress(data.col(i));
addresses[i].first.zeros(data.n_rows);
bound::addr::PointToAddress(addresses[i].first, data.col(i));
addresses[i].second = i;
diff --git a/src/mlpack/core/tree/cellbound.hpp b/src/mlpack/core/tree/cellbound.hpp
index 4097152..1fa51a4 100644
--- a/src/mlpack/core/tree/cellbound.hpp
+++ b/src/mlpack/core/tree/cellbound.hpp
@@ -1,6 +1,30 @@
/**
* @file cellbound.hpp
+ * @author Mikhail Lozhnikov
*
+ * Definition of the CellBound class. The class describes a bound that consists
+ * of a number of hyperrectangles. These hyperrectangles do not overlap each
+ * other. The bound is limited by an outer hyperrectangle and two addresses,
+ * the lower address and the high address. Thus, the bound contains all points
+ * included between the lower and the high addresses.
+ *
+ * The notion of addresses is described in the following paper.
+ * @code
+ * @inproceedings{bayer1997,
+ * author = {Bayer, Rudolf},
+ * title = {The Universal B-Tree for Multidimensional Indexing: General
+ * Concepts},
+ * booktitle = {Proceedings of the International Conference on Worldwide
+ * Computing and Its Applications},
+ * series = {WWCA '97},
+ * year = {1997},
+ * isbn = {3-540-63343-X},
+ * pages = {198--209},
+ * numpages = {12},
+ * publisher = {Springer-Verlag},
+ * address = {London, UK, UK},
+ * }
+ * @endcode
*/
#ifndef MLPACK_CORE_TREE_CELLBOUND_HPP
#define MLPACK_CORE_TREE_CELLBOUND_HPP
@@ -19,6 +43,8 @@ template<typename MetricType = metric::LMetric<2, true>,
class CellBound
{
public:
+ //! Depending on the precision of the tree element type, we may need to use
+ //! uint32_t or uint64_t.
typedef typename std::conditional<sizeof(ElemType) * CHAR_BIT <= 32,
uint32_t,
uint64_t>::type AddressElemType;
@@ -61,18 +87,22 @@ class CellBound
const math::RangeType<ElemType>& operator[](const size_t i) const
{ return bounds[i]; }
+ //! Modify lower address.
arma::Col<AddressElemType>& LoAddress() { return loAddress; }
-
+ //! Get lower address.
const arma::Col<AddressElemType>& LoAddress() const {return loAddress; }
+ //! Modify high address.
arma::Col<AddressElemType>& HiAddress() { return hiAddress; }
-
+ //! Get high address.
const arma::Col<AddressElemType>& HiAddress() const {return hiAddress; }
+ //! Get lower bound of each subrectangle.
const arma::Mat<ElemType>& LoBound() const { return loBound; }
-
+ //! Get high bound of each subrectangle.
const arma::Mat<ElemType>& HiBound() const { return hiBound; }
+ //! Get the number of subrectangles.
size_t NumBounds() const { return numBounds; }
//! Get the minimum width of the bound.
@@ -159,6 +189,10 @@ class CellBound
template<typename VecType>
bool Contains(const VecType& point) const;
+ /**
+ * Calculate the bounds of all subrectangles. You should set the lower and the
+ * high addresses.
+ */
void UpdateAddressBounds();
/**
@@ -173,24 +207,48 @@ class CellBound
void Serialize(Archive& ar, const unsigned int version);
private:
+ //! The precision of the tree element type.
static constexpr size_t order = sizeof(AddressElemType) * CHAR_BIT;
+ //! Maximum number of subrectangles.
const size_t maxNumBounds = 10;
//! The dimensionality of the bound.
size_t dim;
//! The bounds for each dimension.
math::RangeType<ElemType>* bounds;
+ //! Lower bounds of subrectangles.
arma::Mat<ElemType> loBound;
+ //! High bounds of subrectangles.
arma::Mat<ElemType> hiBound;
+ //! The number of subrectangles.
size_t numBounds;
-
+ //! The lowest address that the bound may contain.
arma::Col<AddressElemType> loAddress;
+ //! The highest address that the bound may contain.
arma::Col<AddressElemType> hiAddress;
-
+ //! The minimal width of the outer rectangle.
ElemType minWidth;
+ /**
+ * Add a subrectangle to the bound.
+ *
+ * @param loCorner The lower corner of the subrectangle that is being added.
+ * @param hiCorner The high corner of the subrectangle that is being added.
+ */
void AddBound(const arma::Col<ElemType>& loCorner,
const arma::Col<ElemType>& hiCorner);
+ /**
+ * Initialize all subrectangles that touch the lower address.
+ *
+ * @param numEqualBits The number of equal leading bits of the lower address
+ * and the high address.
+ */
void InitHighBound(size_t numEqualBits);
+ /**
+ * Initialize all subrectangles that touch the high address.
+ *
+ * @param numEqualBits The number of equal leading bits of the lower address
+ * and the high address.
+ */
void InitLowerBound(size_t numEqualBits);
};
diff --git a/src/mlpack/core/tree/cellbound_impl.hpp b/src/mlpack/core/tree/cellbound_impl.hpp
index 93df708..13f02e0 100644
--- a/src/mlpack/core/tree/cellbound_impl.hpp
+++ b/src/mlpack/core/tree/cellbound_impl.hpp
@@ -1,10 +1,9 @@
/**
* @file cellbound_impl.hpp
+ * @author Mikhail Lozhnikov
*
- * Implementation of hyper-rectangle bound policy class.
- * Template parameter Power is the metric to use; use 2 for Euclidean (L2).
- *
- * @experimental
+ * Implementation of the CellBound class. The class describes a bound that
+ * consists of a number of hyperrectangles.
*/
#ifndef MLPACK_CORE_TREE_CELLBOUND_IMPL_HPP
#define MLPACK_CORE_TREE_CELLBOUND_IMPL_HPP
@@ -179,14 +178,15 @@ void CellBound<MetricType, ElemType>::AddBound(
assert(loCorner.n_elem == dim);
assert(hiCorner.n_elem == dim);
+ // If the subrectangle is not contained entirely in the outer rectangle,
+ // we shrink it.
for (size_t k = 0; k < dim; k++)
{
- loBound(k, numBounds) = loCorner[k] +
- math::ClampNonNegative(bounds[k].Lo() - loCorner[k]);
+ loBound(k, numBounds) = std::max(loCorner[k], bounds[k].Lo());
- hiBound(k, numBounds) = bounds[k].Hi() -
- math::ClampNonNegative(bounds[k].Hi() - hiCorner[k]);
+ hiBound(k, numBounds) = std::min(bounds[k].Hi(), hiCorner[k]);
+ // This should never happen.
if (loBound(k, numBounds) > hiBound(k, numBounds))
return;
}
@@ -205,26 +205,37 @@ void CellBound<MetricType, ElemType>::InitHighBound(size_t numEqualBits)
assert(tmpHiAddress.n_elem > 0);
+ // We have to calculate the number of subrectangles since the maximum number
+ // of hyperrectangles is restricted.
size_t numCorners = 0;
for (size_t pos = numEqualBits + 1; pos < order * tmpHiAddress.n_elem; pos++)
{
size_t row = pos / order;
size_t bit = order - 1 - pos % order;
+ // This hyperrectangle is not contained entirely in the bound.
+ // So, the number of hyperrectangles should be increased.
if (tmpHiAddress[row] & ((AddressElemType) 1 << bit))
numCorners++;
+ // We ran out of the limit of hyperrectangles. In that case we enlarge
+ // the last hyperrectangle.
if (numCorners >= maxNumBounds / 2)
tmpHiAddress[row] |= ((AddressElemType) 1 << bit);
}
size_t pos = order * tmpHiAddress.n_elem - 1;
+ // Find the last hyperrectangle and add it to the bound.
for ( ; pos > numEqualBits; pos--)
{
size_t row = pos / order;
size_t bit = order - 1 - pos % order;
+ // All last bits after pos of tmpHiAddress are equal to 1 and
+ // All last bits of tmpLoAddress (after pos) are equal to 0.
+ // Thus, tmpHiAddress corresponds to the high corner of the enlarged
+ // rectangle and tmpLoAddress corresponds to the lower corner.
if (!(tmpHiAddress[row] & ((AddressElemType) 1 << bit)))
{
addr::AddressToPoint(loCorner, tmpLoAddress);
@@ -233,9 +244,11 @@ void CellBound<MetricType, ElemType>::InitHighBound(size_t numEqualBits)
AddBound(loCorner, hiCorner);
break;
}
+ // Nullify the bit that corresponds to this step.
tmpLoAddress[row] &= ~((AddressElemType) 1 << bit);
}
+ // Add the enlarged rectangle if we have not done that.
if (pos == numEqualBits)
{
addr::AddressToPoint(loCorner, tmpLoAddress);
@@ -249,17 +262,22 @@ void CellBound<MetricType, ElemType>::InitHighBound(size_t numEqualBits)
size_t row = pos / order;
size_t bit = order - 1 - pos % order;
+ // The lower bound should correspond to this step.
tmpLoAddress[row] &= ~((AddressElemType) 1 << bit);
if (tmpHiAddress[row] & ((AddressElemType) 1 << bit))
{
+ // This hyperrectangle is contained entirely in the bound and does not
+ // overlap with other hyperrectangles since loAddress is less than
+ // tmpLoAddress and tmpHiAddress is less than the lower addresses
+ // of hyperrectangles that we have added previously.
tmpHiAddress[row] ^= (AddressElemType) 1 << bit;
addr::AddressToPoint(loCorner, tmpLoAddress);
addr::AddressToPoint(hiCorner, tmpHiAddress);
AddBound(loCorner, hiCorner);
}
-
+ // The high bound should correspond to this step.
tmpHiAddress[row] |= ((AddressElemType) 1 << bit);
}
}
@@ -272,26 +290,37 @@ void CellBound<MetricType, ElemType>::InitLowerBound(size_t numEqualBits)
arma::Col<ElemType> loCorner(tmpHiAddress.n_elem);
arma::Col<ElemType> hiCorner(tmpHiAddress.n_elem);
+ // We have to calculate the number of subrectangles since the maximum number
+ // of hyperrectangles is restricted.
size_t numCorners = 0;
for (size_t pos = numEqualBits + 1; pos < order * tmpHiAddress.n_elem; pos++)
{
size_t row = pos / order;
size_t bit = order - 1 - pos % order;
+ // This hyperrectangle is not contained entirely in the bound.
+ // So, the number of hyperrectangles should be increased.
if (!(tmpLoAddress[row] & ((AddressElemType) 1 << bit)))
numCorners++;
+ // We ran out of the limit of hyperrectangles. In that case we enlarge
+ // the last hyperrectangle.
if (numCorners >= maxNumBounds / 2)
tmpLoAddress[row] &= ~((AddressElemType) 1 << bit);
}
size_t pos = order * tmpHiAddress.n_elem - 1;
+ // Find the last hyperrectangle and add it to the bound.
for ( ; pos > numEqualBits; pos--)
{
size_t row = pos / order;
size_t bit = order - 1 - pos % order;
+ // All last bits after pos of tmpHiAddress are equal to 1 and
+ // All last bits of tmpLoAddress (after pos) are equal to 0.
+ // Thus, tmpHiAddress corresponds to the high corner of the enlarged
+ // rectangle and tmpLoAddress corresponds to the lower corner.
if (tmpLoAddress[row] & ((AddressElemType) 1 << bit))
{
addr::AddressToPoint(loCorner, tmpLoAddress);
@@ -300,9 +329,12 @@ void CellBound<MetricType, ElemType>::InitLowerBound(size_t numEqualBits)
AddBound(loCorner, hiCorner);
break;
}
+ // Enlarge the hyperrectangle at this step since it is contained
+ // entirely in the bound.
tmpHiAddress[row] |= ((AddressElemType) 1 << bit);
}
+ // Add the enlarged rectangle if we have not done that.
if (pos == numEqualBits)
{
addr::AddressToPoint(loCorner, tmpLoAddress);
@@ -316,17 +348,24 @@ void CellBound<MetricType, ElemType>::InitLowerBound(size_t numEqualBits)
size_t row = pos / order;
size_t bit = order - 1 - pos % order;
+ // The high bound should correspond to this step.
tmpHiAddress[row] |= ((AddressElemType) 1 << bit);
if (!(tmpLoAddress[row] & ((AddressElemType) 1 << bit)))
{
+ // This hyperrectangle is contained entirely in the bound and does not
+ // overlap with other hyperrectangles since hiAddress is greater than
+ // tmpHiAddress and tmpLoAddress is greater than the high addresses
+ // of hyperrectangles that we have added previously.
tmpLoAddress[row] ^= (AddressElemType) 1 << bit;
+
addr::AddressToPoint(loCorner, tmpLoAddress);
addr::AddressToPoint(hiCorner, tmpHiAddress);
AddBound(loCorner, hiCorner);
}
+ // The lower bound should correspond to this step.
tmpLoAddress[row] &= ~((AddressElemType) 1 << bit);
}
}
@@ -336,11 +375,14 @@ void CellBound<MetricType, ElemType>::UpdateAddressBounds()
{
numBounds = 0;
+ // Calculate the number of equal leading bits of the lower address and
+ // the high address.
size_t row = 0;
for ( ; row < hiAddress.n_elem; row++)
if (loAddress[row] != hiAddress[row])
break;
+ // If the high address is equal to the lower address.
if (row == hiAddress.n_elem)
{
for (size_t i = 0; i < dim; i++)
@@ -348,7 +390,6 @@ void CellBound<MetricType, ElemType>::UpdateAddressBounds()
loBound(i, 0) = bounds[i].Lo();
hiBound(i, 0) = bounds[i].Hi();
}
-
numBounds = 1;
return;
@@ -362,6 +403,7 @@ void CellBound<MetricType, ElemType>::UpdateAddressBounds()
if ((row == hiAddress.n_elem - 1) && (bit == order - 1))
{
+ // If the addresses differ in the last bit.
for (size_t i = 0; i < dim; i++)
{
loBound(i, 0) = bounds[i].Lo();
@@ -374,7 +416,6 @@ void CellBound<MetricType, ElemType>::UpdateAddressBounds()
}
size_t numEqualBits = row * order + bit;
-
InitHighBound(numEqualBits);
InitLowerBound(numEqualBits);
@@ -382,6 +423,7 @@ void CellBound<MetricType, ElemType>::UpdateAddressBounds()
if (numBounds == 0)
{
+ // I think this should never happen.
for (size_t i = 0; i < dim; i++)
{
loBound(i, 0) = bounds[i].Lo();
@@ -390,7 +432,6 @@ void CellBound<MetricType, ElemType>::UpdateAddressBounds()
numBounds = 1;
}
- assert(numBounds > 0);
}
/**
diff --git a/src/mlpack/tests/ub_tree_test.cpp b/src/mlpack/tests/ub_tree_test.cpp
index 312d440..2d3c28e 100644
--- a/src/mlpack/tests/ub_tree_test.cpp
+++ b/src/mlpack/tests/ub_tree_test.cpp
@@ -23,10 +23,11 @@ BOOST_AUTO_TEST_CASE(AddressTest)
arma::Mat<ElemType> dataset(8, 1000);
dataset.randu();
-
+ dataset -= 0.5;
arma::Col<AddressElemType> address(dataset.n_rows);
arma::Col<ElemType> point(dataset.n_rows);
+ // Ensure that this is a one-to-one transform.
for (size_t i = 0; i < dataset.n_cols; i++)
{
addr::PointToAddress(address, dataset.col(i));
@@ -35,6 +36,7 @@ BOOST_AUTO_TEST_CASE(AddressTest)
for (size_t k = 0; k < dataset.n_rows; k++)
BOOST_REQUIRE_CLOSE(dataset(k, i), point[k], 1e-13);
}
+
}
template<typename TreeType>
@@ -56,6 +58,7 @@ void CheckSplit(const TreeType& tree)
arma::Col<AddressElemType> address(tree.Bound().Dim());
+ // Find the highest address of the left node.
for (size_t i = 0; i < tree.Left()->NumDescendants(); i++)
{
addr::PointToAddress(address,
@@ -65,6 +68,7 @@ void CheckSplit(const TreeType& tree)
hi = address;
}
+ // Find the lowest address of the right node.
for (size_t i = 0; i < tree.Right()->NumDescendants(); i++)
{
addr::PointToAddress(address,
@@ -74,6 +78,7 @@ void CheckSplit(const TreeType& tree)
lo = address;
}
+ // Addresses in the left node should be less than addresses in the right node.
BOOST_REQUIRE_LE(addr::CompareAddresses(hi, lo), 0);
CheckSplit(*tree.Left());
@@ -99,11 +104,13 @@ void CheckBound(const TreeType& tree)
{
arma::Col<ElemType> point = tree.Dataset().col(tree.Descendant(i));
+ // Check that the point is contained in the bound.
BOOST_REQUIRE_EQUAL(true, tree.Bound().Contains(point));
const arma::Mat<ElemType>& loBound = tree.Bound().LoBound();
const arma::Mat<ElemType>& hiBound = tree.Bound().HiBound();
+ // Ensure that there is a hyperrectangle that contains the point.
bool success = false;
for (size_t j = 0; j < tree.Bound().NumBounds(); j++)
{
@@ -142,6 +149,7 @@ BOOST_AUTO_TEST_CASE(UBTreeBoundTest)
CheckBound(tree);
}
+// Ensure that MinDistance() and MaxDistance() work correctly.
template<typename TreeType, typename MetricType>
void CheckDistance(TreeType& tree, TreeType* node = NULL)
{
More information about the mlpack-git
mailing list