[mlpack-git] master: Optimize the UB tree bound. (f17843f)
gitdub at mlpack.org
gitdub at mlpack.org
Tue Aug 23 08:23:47 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/1797a49c8f76d65814fec4a122d0d2fea01fc2d9...9e5cd0ac9c5cde9ac141bc84e7327bd11e19d42e
>---------------------------------------------------------------
commit f17843fc8e8ef2b1c4b04d572c521575f20a1f3c
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date: Tue Aug 23 15:23:47 2016 +0300
Optimize the UB tree bound.
>---------------------------------------------------------------
f17843fc8e8ef2b1c4b04d572c521575f20a1f3c
.../tree/binary_space_tree/ub_tree_split_impl.hpp | 2 +-
src/mlpack/core/tree/cellbound.hpp | 18 ++++--
src/mlpack/core/tree/cellbound_impl.hpp | 64 +++++++++++++++-------
src/mlpack/tests/ub_tree_test.cpp | 24 +++++---
4 files changed, 75 insertions(+), 33 deletions(-)
diff --git a/src/mlpack/core/tree/binary_space_tree/ub_tree_split_impl.hpp b/src/mlpack/core/tree/binary_space_tree/ub_tree_split_impl.hpp
index dedf894..c3bc7e1 100644
--- a/src/mlpack/core/tree/binary_space_tree/ub_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/ub_tree_split_impl.hpp
@@ -132,7 +132,7 @@ bool UBTreeSplit<BoundType, MatType>::SplitNode(BoundType& bound,
bound.LoAddress()[k] = addresses[begin].first[k];
bound.HiAddress()[k] = addresses[begin + count - 1].first[k];
}
- bound.UpdateAddressBounds();
+ bound.UpdateAddressBounds(data.cols(begin, begin + count - 1));
return true;
}
diff --git a/src/mlpack/core/tree/cellbound.hpp b/src/mlpack/core/tree/cellbound.hpp
index cd6a7da..a653c6e 100644
--- a/src/mlpack/core/tree/cellbound.hpp
+++ b/src/mlpack/core/tree/cellbound.hpp
@@ -219,8 +219,11 @@ class CellBound
/**
* Calculate the bounds of all subrectangles. You should set the lower and the
* high addresses.
+ *
+ * @param data Points that are contained in the node.
*/
- void UpdateAddressBounds();
+ template<typename MatType>
+ void UpdateAddressBounds(const MatType& data);
/**
* Returns the diameter of the hyperrectangle (that is, the longest diagonal).
@@ -260,23 +263,30 @@ class CellBound
*
* @param loCorner The lower corner of the subrectangle that is being added.
* @param hiCorner The high corner of the subrectangle that is being added.
+ * @param data Points that are contained in the node.
*/
+ template<typename MatType>
void AddBound(const arma::Col<ElemType>& loCorner,
- const arma::Col<ElemType>& hiCorner);
+ const arma::Col<ElemType>& hiCorner,
+ const MatType& data);
/**
* Initialize all subrectangles that touch the lower address.
*
* @param numEqualBits The number of equal leading bits of the lower address
* and the high address.
+ * @param data Points that are contained in the node.
*/
- void InitHighBound(size_t numEqualBits);
+ template<typename MatType>
+ void InitHighBound(size_t numEqualBits, const MatType& data);
/**
* Initialize all subrectangles that touch the high address.
*
* @param numEqualBits The number of equal leading bits of the lower address
* and the high address.
+ * @param data Points that are contained in the node.
*/
- void InitLowerBound(size_t numEqualBits);
+ template<typename MatType>
+ void InitLowerBound(size_t numEqualBits, const MatType& data);
};
// A specialization of BoundTraits for this class.
diff --git a/src/mlpack/core/tree/cellbound_impl.hpp b/src/mlpack/core/tree/cellbound_impl.hpp
index 150e2a2..ebb9a34 100644
--- a/src/mlpack/core/tree/cellbound_impl.hpp
+++ b/src/mlpack/core/tree/cellbound_impl.hpp
@@ -169,34 +169,55 @@ inline void CellBound<MetricType, ElemType>::Center(
}
template<typename MetricType, typename ElemType>
+template<typename MatType>
void CellBound<MetricType, ElemType>::AddBound(
const arma::Col<ElemType>& loCorner,
- const arma::Col<ElemType>& hiCorner)
+ const arma::Col<ElemType>& hiCorner,
+ const MatType& data)
{
assert(numBounds < loBound.n_cols);
assert(loBound.n_rows == dim);
assert(loCorner.n_elem == dim);
assert(hiCorner.n_elem == dim);
- // If the subrectangle is not contained entirely in the outer rectangle,
- // we shrink it.
for (size_t k = 0; k < dim; k++)
{
- loBound(k, numBounds) = std::max(loCorner[k], bounds[k].Lo());
+ loBound(k, numBounds) = std::numeric_limits<ElemType>::max();
+ hiBound(k, numBounds) = std::numeric_limits<ElemType>::lowest();
+ }
- hiBound(k, numBounds) = std::min(bounds[k].Hi(), hiCorner[k]);
+ for (size_t i = 0; i < data.n_cols; i++)
+ {
+ size_t k = 0;
+ // Check if the point is contained in the hyperrectangle.
+ for (k = 0; k < dim; k++)
+ if (data(k, i) < loCorner[k] || data(k, i) > hiCorner[k])
+ break;
- // This should never happen.
- if (loBound(k, numBounds) > hiBound(k, numBounds))
- return;
+ if (k < dim)
+ continue; // The point is not contained in the hyperrectangle.
+
+ // Shrink the bound.
+ for (k = 0; k < dim; k++)
+ {
+ loBound(k, numBounds) = std::min(loBound(k, numBounds), data(k, i));
+
+ hiBound(k, numBounds) = std::max(hiBound(k, numBounds), data(k, i));
+ }
}
+ for (size_t k = 0; k < dim; k++)
+ if (loBound(k, numBounds) > hiBound(k, numBounds))
+ return; // The hyperrectangle does not contain points.
+
numBounds++;
}
template<typename MetricType, typename ElemType>
-void CellBound<MetricType, ElemType>::InitHighBound(size_t numEqualBits)
+template<typename MatType>
+void CellBound<MetricType, ElemType>::InitHighBound(size_t numEqualBits,
+ const MatType& data)
{
arma::Col<AddressElemType> tmpHiAddress(hiAddress);
arma::Col<AddressElemType> tmpLoAddress(hiAddress);
@@ -241,7 +262,7 @@ void CellBound<MetricType, ElemType>::InitHighBound(size_t numEqualBits)
addr::AddressToPoint(loCorner, tmpLoAddress);
addr::AddressToPoint(hiCorner, tmpHiAddress);
- AddBound(loCorner, hiCorner);
+ AddBound(loCorner, hiCorner, data);
break;
}
// Nullify the bit that corresponds to this step.
@@ -254,7 +275,7 @@ void CellBound<MetricType, ElemType>::InitHighBound(size_t numEqualBits)
addr::AddressToPoint(loCorner, tmpLoAddress);
addr::AddressToPoint(hiCorner, tmpHiAddress);
- AddBound(loCorner, hiCorner);
+ AddBound(loCorner, hiCorner, data);
}
for ( ; pos > numEqualBits; pos--)
@@ -275,7 +296,7 @@ void CellBound<MetricType, ElemType>::InitHighBound(size_t numEqualBits)
addr::AddressToPoint(loCorner, tmpLoAddress);
addr::AddressToPoint(hiCorner, tmpHiAddress);
- AddBound(loCorner, hiCorner);
+ AddBound(loCorner, hiCorner, data);
}
// The high bound should correspond to this step.
tmpHiAddress[row] |= ((AddressElemType) 1 << bit);
@@ -283,7 +304,9 @@ void CellBound<MetricType, ElemType>::InitHighBound(size_t numEqualBits)
}
template<typename MetricType, typename ElemType>
-void CellBound<MetricType, ElemType>::InitLowerBound(size_t numEqualBits)
+template<typename MatType>
+void CellBound<MetricType, ElemType>::InitLowerBound(size_t numEqualBits,
+ const MatType& data)
{
arma::Col<AddressElemType> tmpHiAddress(loAddress);
arma::Col<AddressElemType> tmpLoAddress(loAddress);
@@ -326,7 +349,7 @@ void CellBound<MetricType, ElemType>::InitLowerBound(size_t numEqualBits)
addr::AddressToPoint(loCorner, tmpLoAddress);
addr::AddressToPoint(hiCorner, tmpHiAddress);
- AddBound(loCorner, hiCorner);
+ AddBound(loCorner, hiCorner, data);
break;
}
// Enlarge the hyperrectangle at this step since it is contained
@@ -340,7 +363,7 @@ void CellBound<MetricType, ElemType>::InitLowerBound(size_t numEqualBits)
addr::AddressToPoint(loCorner, tmpLoAddress);
addr::AddressToPoint(hiCorner, tmpHiAddress);
- AddBound(loCorner, hiCorner);
+ AddBound(loCorner, hiCorner, data);
}
for ( ; pos > numEqualBits; pos--)
@@ -362,7 +385,7 @@ void CellBound<MetricType, ElemType>::InitLowerBound(size_t numEqualBits)
addr::AddressToPoint(loCorner, tmpLoAddress);
addr::AddressToPoint(hiCorner, tmpHiAddress);
- AddBound(loCorner, hiCorner);
+ AddBound(loCorner, hiCorner, data);
}
// The lower bound should correspond to this step.
@@ -371,7 +394,8 @@ void CellBound<MetricType, ElemType>::InitLowerBound(size_t numEqualBits)
}
template<typename MetricType, typename ElemType>
-void CellBound<MetricType, ElemType>::UpdateAddressBounds()
+template<typename MatType>
+void CellBound<MetricType, ElemType>::UpdateAddressBounds(const MatType& data)
{
numBounds = 0;
@@ -416,8 +440,8 @@ void CellBound<MetricType, ElemType>::UpdateAddressBounds()
}
size_t numEqualBits = row * order + bit;
- InitHighBound(numEqualBits);
- InitLowerBound(numEqualBits);
+ InitHighBound(numEqualBits, data);
+ InitLowerBound(numEqualBits, data);
assert(numBounds <= maxNumBounds);
@@ -892,7 +916,7 @@ inline CellBound<MetricType, ElemType>& CellBound<MetricType, ElemType>::operato
loBound(i, 0) = bounds[i].Lo();
hiBound(i, 0) = bounds[i].Hi();
}
- numBounds = 0;
+ numBounds = 1;
}
return *this;
}
diff --git a/src/mlpack/tests/ub_tree_test.cpp b/src/mlpack/tests/ub_tree_test.cpp
index 2d3c28e..84bffed 100644
--- a/src/mlpack/tests/ub_tree_test.cpp
+++ b/src/mlpack/tests/ub_tree_test.cpp
@@ -180,13 +180,17 @@ void CheckDistance(TreeType& tree, TreeType* node = NULL)
minDist = dist;
}
- BOOST_REQUIRE_LE(tree.Bound().MinDistance(point), minDist);
- BOOST_REQUIRE_LE(maxDist, tree.Bound().MaxDistance(point));
+ BOOST_REQUIRE_LE(tree.Bound().MinDistance(point), minDist *
+ (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
+ BOOST_REQUIRE_LE(maxDist, tree.Bound().MaxDistance(point) *
+ (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
math::RangeType<ElemType> r = tree.Bound().RangeDistance(point);
- BOOST_REQUIRE_LE(r.Lo(), minDist);
- BOOST_REQUIRE_LE(maxDist, r.Hi());
+ BOOST_REQUIRE_LE(r.Lo(), minDist *
+ (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
+ BOOST_REQUIRE_LE(maxDist, r.Hi() *
+ (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
}
if (!tree.IsLeaf())
@@ -214,13 +218,17 @@ void CheckDistance(TreeType& tree, TreeType* node = NULL)
minDist = dist;
}
- BOOST_REQUIRE_LE(tree.Bound().MinDistance(node->Bound()), minDist);
- BOOST_REQUIRE_LE(maxDist, tree.Bound().MaxDistance(node->Bound()));
+ BOOST_REQUIRE_LE(tree.Bound().MinDistance(node->Bound()), minDist *
+ (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
+ BOOST_REQUIRE_LE(maxDist, tree.Bound().MaxDistance(node->Bound()) *
+ (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
math::RangeType<ElemType> r = tree.Bound().RangeDistance(node->Bound());
- BOOST_REQUIRE_LE(r.Lo(), minDist);
- BOOST_REQUIRE_LE(maxDist, r.Hi());
+ BOOST_REQUIRE_LE(r.Lo(), minDist *
+ (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
+ BOOST_REQUIRE_LE(maxDist, r.Hi() *
+ (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
}
if (!node->IsLeaf())
{
More information about the mlpack-git
mailing list