[mlpack-git] master: Fixed various errors in the implementation of the bound and the calculation of addressed. (8c5a97d)

gitdub at mlpack.org gitdub at mlpack.org
Mon Aug 29 12:02:04 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/1797a49c8f76d65814fec4a122d0d2fea01fc2d9...9e5cd0ac9c5cde9ac141bc84e7327bd11e19d42e

>---------------------------------------------------------------

commit 8c5a97dcb1641ae6c98edc70426fb19f5cd7cb79
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date:   Mon Aug 8 20:04:46 2016 +0300

    Fixed various errors in the implementation of the bound and the calculation of addressed.


>---------------------------------------------------------------

8c5a97dcb1641ae6c98edc70426fb19f5cd7cb79
 src/mlpack/core/tree/address.hpp                   | 16 ++++-
 src/mlpack/core/tree/binary_space_tree/traits.hpp  |  6 ++
 src/mlpack/core/tree/binary_space_tree/typedef.hpp | 29 ++++++++-
 .../core/tree/binary_space_tree/ub_tree_split.hpp  | 34 +++++++++--
 .../tree/binary_space_tree/ub_tree_split_impl.hpp  | 29 ++++++++-
 src/mlpack/core/tree/cellbound.hpp                 | 68 ++++++++++++++++++++--
 src/mlpack/core/tree/cellbound_impl.hpp            | 65 +++++++++++++++++----
 src/mlpack/tests/ub_tree_test.cpp                  | 10 +++-
 8 files changed, 231 insertions(+), 26 deletions(-)

diff --git a/src/mlpack/core/tree/address.hpp b/src/mlpack/core/tree/address.hpp
index 276cedb..6f4bdd5 100644
--- a/src/mlpack/core/tree/address.hpp
+++ b/src/mlpack/core/tree/address.hpp
@@ -115,14 +115,21 @@ void AddressToPoint(VecType& point, const AddressType& address)
           (order - 1 - i));
     }
 
-
   for (size_t i = 0; i < rearrangedAddress.n_elem; i++)
   {
     bool sgn = rearrangedAddress(i) & ((AddressElemType) 1 << (order - 1));
 
+    if (!sgn)
+    {
+      rearrangedAddress(i) = ((AddressElemType) 1 << (order - 1)) - 1 -
+          rearrangedAddress(i);
+    }
+
     // Extract the mantissa.
     AddressElemType tmp = (AddressElemType) 1 << numMantBits;
     AddressElemType mantissa = rearrangedAddress(i) & (tmp - 1);
+    if (mantissa == 0)
+      mantissa = 1;
 
     VecElemType normalizedVal = (VecElemType) mantissa / tmp;
 
@@ -136,6 +143,13 @@ void AddressToPoint(VecType& point, const AddressType& address)
     e += std::numeric_limits<VecElemType>::min_exponent;
 
     point(i) = std::ldexp(normalizedVal, e);
+    if (std::isinf(point(i)))
+    {
+      if (point(i) > 0)
+        point(i) = std::numeric_limits<VecElemType>::max();
+      else
+        point(i) = std::numeric_limits<VecElemType>::lowest();
+    }
   }
 }
 
diff --git a/src/mlpack/core/tree/binary_space_tree/traits.hpp b/src/mlpack/core/tree/binary_space_tree/traits.hpp
index cd6512c..9042420 100644
--- a/src/mlpack/core/tree/binary_space_tree/traits.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/traits.hpp
@@ -80,6 +80,12 @@ class TreeTraits<BinarySpaceTree<MetricType, StatisticType, MatType,
   static const bool BinaryTree = true;
 };
 
+/**
+ * This is a specialization of the TreeTraits class to the UBTree tree type.
+ * The only difference from the general BinarySpaceTree is that UBTree can have
+ * overlapping children.
+ * See mlpack/core/tree/tree_traits.hpp for more information.
+ */
 template<typename MetricType,
          typename StatisticType,
          typename MatType,
diff --git a/src/mlpack/core/tree/binary_space_tree/typedef.hpp b/src/mlpack/core/tree/binary_space_tree/typedef.hpp
index 77f986d..203d739 100644
--- a/src/mlpack/core/tree/binary_space_tree/typedef.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/typedef.hpp
@@ -135,7 +135,34 @@ using MeanSplitBallTree = BinarySpaceTree<MetricType,
                                           bound::BallBound,
                                           MeanSplit>;
 
-
+/**
+ * The Universal B-tree. When recursively splitting nodes, the class
+ * calculates addresses of all points and splits each node according to the
+ * median address. Children nodes may overlap since the implementation
+ * of a tighter bound requires a lot of arithmetic operations. In order to get
+ * a tighter bound increase the CellBound::maxNumBounds constant.
+ *
+ * @code
+ * @inproceedings{bayer1997,
+ *   author = {Bayer, Rudolf},
+ *   title = {The Universal B-Tree for Multidimensional Indexing: General
+ *       Concepts},
+ *   booktitle = {Proceedings of the International Conference on Worldwide
+ *       Computing and Its Applications},
+ *   series = {WWCA '97},
+ *   year = {1997},
+ *   isbn = {3-540-63343-X},
+ *   pages = {198--209},
+ *   numpages = {12},
+ *   publisher = {Springer-Verlag},
+ *   address = {London, UK, UK},
+ * }
+ * @endcode
+ *
+ * This template typedef satisfies the TreeType policy API.
+ *
+ * @see @ref trees, BinarySpaceTree, BallTree, MeanSplitKDTree
+ */
 template<typename MetricType, typename StatisticType, typename MatType>
 using UBTree = BinarySpaceTree<MetricType,
                                StatisticType,
diff --git a/src/mlpack/core/tree/binary_space_tree/ub_tree_split.hpp b/src/mlpack/core/tree/binary_space_tree/ub_tree_split.hpp
index 93b2ea1..85f5013 100644
--- a/src/mlpack/core/tree/binary_space_tree/ub_tree_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/ub_tree_split.hpp
@@ -1,5 +1,9 @@
 /**
  * @file ub_tree_split.hpp
+ * @author Mikhail Lozhnikov
+ *
+ * Definition of UBTreeSplit, a class that splits the space according
+ * to the median address of points contained in the node.
  */
 #ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_UB_TREE_SPLIT_HPP
 #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_UB_TREE_SPLIT_HPP
@@ -14,10 +18,23 @@ template<typename BoundType, typename MatType = arma::mat>
 class UBTreeSplit
 {
  public:
+  //! The type of a one-dimensional address.
   typedef typename std::conditional<sizeof(typename MatType::elem_type) * CHAR_BIT <= 32,
                                     uint32_t,
                                     uint64_t>::type AddressElemType;
 
+  /**
+   * Split the node according to the median address of points contained in the
+   * node.
+   *
+   * @param bound The bound used for this node.
+   * @param data The dataset used by the binary space tree.
+   * @param begin Index of the starting point in the dataset that belongs to
+   *    this node.
+   * @param count Number of points in this node.
+   * @param splitCol The index at which the dataset is divided into two parts
+   *    after the rearrangement.
+   */
   bool SplitNode(BoundType& bound,
                  MatType& data,
                  const size_t begin,
@@ -32,14 +49,22 @@ class UBTreeSplit
                  std::vector<size_t>& oldFromNew);
 
  private:
-//  arma::Mat<AddressElemType> addresses;
+  //! This vector contains addresses of all points in the dataset.
   std::vector<std::pair<arma::Col<AddressElemType>, size_t>> addresses;
 
-  template<typename VecType>
-  arma::Col<AddressElemType> CalculateAddress(const VecType& point);
-
+  /**
+   * Calculate addresses for all points in the dataset.
+   *
+   * @param data The dataset used by the binary space tree.
+   */
   void InitializeAddresses(const MatType& data);
 
+  /**
+   * Rearrange the dataset according to the sorted addresses.
+   *
+   * @param data The dataset used by the binary space tree.
+   * @param count Number of points in this node.
+   */
   void PerformSplit(MatType& data,
                        const size_t count);
 
@@ -47,6 +72,7 @@ class UBTreeSplit
                        const size_t count,
                        std::vector<size_t>& oldFromNew);
 
+  //! A comparator for sorting addresses.
   static bool ComparePair(
       const std::pair<arma::Col<AddressElemType>, size_t>& p1,
       const std::pair<arma::Col<AddressElemType>, size_t>& p2)
diff --git a/src/mlpack/core/tree/binary_space_tree/ub_tree_split_impl.hpp b/src/mlpack/core/tree/binary_space_tree/ub_tree_split_impl.hpp
index e0804b4..9d5963e 100644
--- a/src/mlpack/core/tree/binary_space_tree/ub_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/ub_tree_split_impl.hpp
@@ -1,5 +1,9 @@
 /**
  * @file ub_tree_split_impl.hpp
+ * @author Mikhail Lozhnikov
+ *
+ * Implementation of UBTreeSplit, a class that splits a node according
+ * to the median address of points contained in the node.
  */
 #ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_UB_TREE_SPLIT_IMPL_HPP
 #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_UB_TREE_SPLIT_IMPL_HPP
@@ -20,15 +24,25 @@ bool UBTreeSplit<BoundType, MatType>::SplitNode(BoundType& bound,
   constexpr size_t order = sizeof(AddressElemType) * CHAR_BIT;
   if (begin == 0 && count == data.n_cols)
   {
+    // Calculate all addresses.
     InitializeAddresses(data);
 
+    // Probably this is not a good idea. Maybe it is better to get
+    // a number of distinct samples and find the median.
     std::sort(addresses.begin(), addresses.end(), ComparePair);
 
+    // Rearrange dataset.
     PerformSplit(data, count);
   }
 
+  // The bound shouldn't contain too many subrectangles.
+  // In order to minimize the number of hyperrectangles we set last bits
+  // of the last address in the node to 1 and last bits of the first address
+  // in the next node to zero in such a way that the ordering is not
+  // disturbed.
   if (begin + count < data.n_cols)
   {
+    // Omit leading equal bits.
     size_t row = 0;
     arma::Col<AddressElemType>& lo = addresses[begin + count - 1].first;
     const arma::Col<AddressElemType>& hi = addresses[begin + count].first;
@@ -46,6 +60,7 @@ bool UBTreeSplit<BoundType, MatType>::SplitNode(BoundType& bound,
 
     bit++;
 
+    // Replace insignificant bits.
     if (bit == order)
     {
       bit = 0;
@@ -62,8 +77,15 @@ bool UBTreeSplit<BoundType, MatType>::SplitNode(BoundType& bound,
       for (; bit < order; bit++)
         lo[row] |= ((AddressElemType) 1 << (order - 1 - bit)); 
   }
+
+  // The bound shouldn't contain too many subrectangles.
+  // In order to minimize the number of hyperrectangles we set last bits
+  // of the first address in the next node to 0 and last bits of the last
+  // address in the previous node to 1 in such a way that the ordering is not
+  // disturbed.
   if (begin > 0)
   {
+    // Omit leading equal bits.
     size_t row = 0;
     const arma::Col<AddressElemType>& lo = addresses[begin - 1].first;
     arma::Col<AddressElemType>& hi = addresses[begin].first;
@@ -81,6 +103,7 @@ bool UBTreeSplit<BoundType, MatType>::SplitNode(BoundType& bound,
 
     bit++;
 
+    // Replace insignificant bits.
     if (bit == order)
     {
       bit = 0;
@@ -98,12 +121,15 @@ bool UBTreeSplit<BoundType, MatType>::SplitNode(BoundType& bound,
         hi[row] &= ~((AddressElemType) 1 << (order - 1 - bit)); 
   }
 
+  // Set the minimum and the maximum addresses.
   for (size_t k = 0; k < bound.Dim(); k++)
   {
     bound.LoAddress()[k] = addresses[begin].first[k];
     bound.HiAddress()[k] = addresses[begin + count - 1].first[k];
   }
   bound.UpdateAddressBounds();
+
+  // Since the dataset is sorted we can easily get the split column.
   splitCol = begin + count / 2;
   
   return true;
@@ -214,12 +240,11 @@ bool UBTreeSplit<BoundType, MatType>::SplitNode(BoundType& bound,
 template<typename BoundType, typename MatType>
 void UBTreeSplit<BoundType, MatType>::InitializeAddresses(const MatType& data)
 {
-//  addresses.set_size(data.n_rows, data.n_cols);
   addresses.resize(data.n_cols);
 
+  // Calculate all addresses.
   for (size_t i = 0; i < data.n_cols; i++)
   {
-//    address.col(i) = CalculateAddress(data.col(i));
     addresses[i].first.zeros(data.n_rows);
     bound::addr::PointToAddress(addresses[i].first, data.col(i));
     addresses[i].second = i;
diff --git a/src/mlpack/core/tree/cellbound.hpp b/src/mlpack/core/tree/cellbound.hpp
index 4097152..1fa51a4 100644
--- a/src/mlpack/core/tree/cellbound.hpp
+++ b/src/mlpack/core/tree/cellbound.hpp
@@ -1,6 +1,30 @@
 /**
  * @file cellbound.hpp
+ * @author Mikhail Lozhnikov
  *
+ * Definition of the CellBound class. The class describes a bound that consists
+ * of a number of hyperrectangles. These hyperrectangles do not overlap each
+ * other. The bound is limited by an outer hyperrectangle and two addresses,
+ * the lower address and the high address. Thus, the bound contains all points
+ * included between the lower and the high addresses.
+ *
+ * The notion of addresses is described in the following paper.
+ * @code
+ * @inproceedings{bayer1997,
+ *   author = {Bayer, Rudolf},
+ *   title = {The Universal B-Tree for Multidimensional Indexing: General
+ *       Concepts},
+ *   booktitle = {Proceedings of the International Conference on Worldwide
+ *       Computing and Its Applications},
+ *   series = {WWCA '97},
+ *   year = {1997},
+ *   isbn = {3-540-63343-X},
+ *   pages = {198--209},
+ *   numpages = {12},
+ *   publisher = {Springer-Verlag},
+ *   address = {London, UK, UK},
+ * }
+ * @endcode
  */
 #ifndef MLPACK_CORE_TREE_CELLBOUND_HPP
 #define MLPACK_CORE_TREE_CELLBOUND_HPP
@@ -19,6 +43,8 @@ template<typename MetricType = metric::LMetric<2, true>,
 class CellBound
 {
  public:
+  //! Depending on the precision of the tree element type, we may need to use
+  //! uint32_t or uint64_t.
   typedef typename std::conditional<sizeof(ElemType) * CHAR_BIT <= 32,
                                     uint32_t,
                                     uint64_t>::type AddressElemType;
@@ -61,18 +87,22 @@ class CellBound
   const math::RangeType<ElemType>& operator[](const size_t i) const
   { return bounds[i]; }
 
+  //! Modify lower address.
   arma::Col<AddressElemType>& LoAddress() { return loAddress; }
-
+  //! Get lower address.
   const arma::Col<AddressElemType>& LoAddress() const {return loAddress; }
   
+  //! Modify high address.
   arma::Col<AddressElemType>& HiAddress() { return hiAddress; }
-
+  //! Get high address.
   const arma::Col<AddressElemType>& HiAddress() const {return hiAddress; }
 
+  //! Get lower bound of each subrectangle.
   const arma::Mat<ElemType>& LoBound() const { return loBound; }
-
+  //! Get high bound of each subrectangle.
   const arma::Mat<ElemType>& HiBound() const { return hiBound; }
 
+  //! Get the number of subrectangles.
   size_t NumBounds() const { return numBounds; }
 
   //! Get the minimum width of the bound.
@@ -159,6 +189,10 @@ class CellBound
   template<typename VecType>
   bool Contains(const VecType& point) const;
 
+  /**
+   * Calculate the bounds of all subrectangles. You should set the lower and the
+   * high addresses.
+   */
   void UpdateAddressBounds();
 
   /**
@@ -173,24 +207,48 @@ class CellBound
   void Serialize(Archive& ar, const unsigned int version);
 
  private:
+  //! The number of bits in an address element.
   static constexpr size_t order = sizeof(AddressElemType) * CHAR_BIT;
+  //! Maximum number of subrectangles.
   const size_t maxNumBounds = 10;
   //! The dimensionality of the bound.
   size_t dim;
   //! The bounds for each dimension.
   math::RangeType<ElemType>* bounds;
+  //! Lower bounds of subrectangles.
   arma::Mat<ElemType> loBound;
+  //! High bounds of subrectangles.
   arma::Mat<ElemType> hiBound;
+  //! The number of subrectangles.
   size_t numBounds;
-
+  //! The lowest address that the bound may contain.
   arma::Col<AddressElemType> loAddress;
+  //! The highest address that the bound may contain.
   arma::Col<AddressElemType> hiAddress;
-
+  //! The minimal width of the outer rectangle.
   ElemType minWidth;
 
+  /**
+   * Add a subrectangle to the bound.
+   *
+   * @param loCorner The lower corner of the subrectangle that is being added.
+   * @param hiCorner The high corner of the subrectangle that is being added.
+   */
   void AddBound(const arma::Col<ElemType>& loCorner,
                 const arma::Col<ElemType>& hiCorner);
+  /**
+   * Initialize all subrectangles that touch the lower address.
+   *
+   * @param numEqualBits The number of equal leading bits of the lower address
+   * and the high address.
+   */
   void InitHighBound(size_t numEqualBits);
+  /**
+   * Initialize all subrectangles that touch the high address.
+   *
+   * @param numEqualBits The number of equal leading bits of the lower address
+   * and the high address.
+   */
   void InitLowerBound(size_t numEqualBits);
 };
 
diff --git a/src/mlpack/core/tree/cellbound_impl.hpp b/src/mlpack/core/tree/cellbound_impl.hpp
index 93df708..13f02e0 100644
--- a/src/mlpack/core/tree/cellbound_impl.hpp
+++ b/src/mlpack/core/tree/cellbound_impl.hpp
@@ -1,10 +1,9 @@
 /**
  * @file cellbound_impl.hpp
+ * @author Mikhail Lozhnikov
  *
- * Implementation of hyper-rectangle bound policy class.
- * Template parameter Power is the metric to use; use 2 for Euclidean (L2).
- *
- * @experimental
+ * Implementation of the CellBound class. The class describes a bound that
+ * consists of a number of hyperrectangles.
  */
 #ifndef MLPACK_CORE_TREE_CELLBOUND_IMPL_HPP
 #define MLPACK_CORE_TREE_CELLBOUND_IMPL_HPP
@@ -179,14 +178,15 @@ void CellBound<MetricType, ElemType>::AddBound(
   assert(loCorner.n_elem == dim);
   assert(hiCorner.n_elem == dim);
 
+  // If the subrectangle is not contained entirely in the outer rectangle,
+  // we shrink it.
   for (size_t k = 0; k < dim; k++)
   {
-    loBound(k, numBounds) =  loCorner[k] +
-        math::ClampNonNegative(bounds[k].Lo() - loCorner[k]);
+    loBound(k, numBounds) = std::max(loCorner[k], bounds[k].Lo());
 
-    hiBound(k, numBounds) = bounds[k].Hi() -
-        math::ClampNonNegative(bounds[k].Hi() - hiCorner[k]);
+    hiBound(k, numBounds) = std::min(bounds[k].Hi(), hiCorner[k]);
 
+    // This should never happen.
     if (loBound(k, numBounds) > hiBound(k, numBounds))
       return;
   }
@@ -205,26 +205,37 @@ void CellBound<MetricType, ElemType>::InitHighBound(size_t numEqualBits)
 
   assert(tmpHiAddress.n_elem > 0);
 
+  // We have to calculate the number of subrectangles since the maximum number
+  // of hyperrectangles is restricted.
   size_t numCorners = 0;
   for (size_t pos = numEqualBits + 1; pos < order * tmpHiAddress.n_elem; pos++)
   {
     size_t row = pos / order;
     size_t bit = order - 1 - pos % order;
 
+    // This hyperrectangle is not contained entirely in the bound.
+    // So, the number of hyperrectangles should be increased.
     if (tmpHiAddress[row] & ((AddressElemType) 1 << bit))
       numCorners++;
 
+    // We ran out of the limit of hyperrectangles. In that case we enlarge
+    // the last hyperrectangle.
     if (numCorners >= maxNumBounds / 2)
       tmpHiAddress[row] |= ((AddressElemType) 1 << bit);
   }
 
   size_t pos = order * tmpHiAddress.n_elem - 1;
 
+  // Find the last hyperrectangle and add it to the bound.
   for ( ; pos > numEqualBits; pos--)
   {
     size_t row = pos / order;
     size_t bit = order - 1 - pos % order;
 
+    // All last bits after pos of tmpHiAddress are equal to 1 and
+    // All last bits of tmpLoAddress (after pos) are equal to 0.
+    // Thus, tmpHiAddress corresponds to the high corner of the enlarged
+    // rectangle and tmpLoAddress corresponds to the lower corner.
     if (!(tmpHiAddress[row] & ((AddressElemType) 1 << bit)))
     {
       addr::AddressToPoint(loCorner, tmpLoAddress);
@@ -233,9 +244,11 @@ void CellBound<MetricType, ElemType>::InitHighBound(size_t numEqualBits)
       AddBound(loCorner, hiCorner);
       break;
     }
+    // Nullify the bit that corresponds to this step.
     tmpLoAddress[row] &= ~((AddressElemType) 1 << bit);
   }
 
+  // Add the enlarged rectangle if we have not done that.
   if (pos == numEqualBits)
   {
     addr::AddressToPoint(loCorner, tmpLoAddress);
@@ -249,17 +262,22 @@ void CellBound<MetricType, ElemType>::InitHighBound(size_t numEqualBits)
     size_t row = pos / order;
     size_t bit = order - 1 - pos % order;
 
+    // The lower bound should correspond to this step.
     tmpLoAddress[row] &= ~((AddressElemType) 1 << bit);
 
     if (tmpHiAddress[row] & ((AddressElemType) 1 << bit))
     {
+      // This hyperrectangle is contained entirely in the bound and does not
+      // overlap with other hyperrectangles since loAddress is less than
+      // tmpLoAddress and tmpHiAddress is less than the lower addresses
+      // of hyperrectangles that we have added previously.
       tmpHiAddress[row] ^= (AddressElemType) 1 << bit;
       addr::AddressToPoint(loCorner, tmpLoAddress);
       addr::AddressToPoint(hiCorner, tmpHiAddress);
 
       AddBound(loCorner, hiCorner);
     }
-
+    // The high bound should correspond to this step.
     tmpHiAddress[row] |= ((AddressElemType) 1 << bit);
   }
 }
@@ -272,26 +290,37 @@ void CellBound<MetricType, ElemType>::InitLowerBound(size_t numEqualBits)
   arma::Col<ElemType> loCorner(tmpHiAddress.n_elem);
   arma::Col<ElemType> hiCorner(tmpHiAddress.n_elem);
 
+  // We have to calculate the number of subrectangles since the maximum number
+  // of hyperrectangles is restricted.
   size_t numCorners = 0;
   for (size_t pos = numEqualBits + 1; pos < order * tmpHiAddress.n_elem; pos++)
   {
     size_t row = pos / order;
     size_t bit = order - 1 - pos % order;
 
+    // This hyperrectangle is not contained entirely in the bound.
+    // So, the number of hyperrectangles should be increased.
     if (!(tmpLoAddress[row] & ((AddressElemType) 1 << bit)))
       numCorners++;
 
+    // We ran out of the limit of hyperrectangles. In that case we enlarge
+    // the last hyperrectangle.
     if (numCorners >= maxNumBounds / 2)
       tmpLoAddress[row] &= ~((AddressElemType) 1 << bit);
   }
 
   size_t pos = order * tmpHiAddress.n_elem - 1;
 
+  // Find the last hyperrectangle and add it to the bound.
   for ( ; pos > numEqualBits; pos--)
   {
     size_t row = pos / order;
     size_t bit = order - 1 - pos % order;
 
+    // All last bits after pos of tmpHiAddress are equal to 1 and
+    // All last bits of tmpLoAddress (after pos) are equal to 0.
+    // Thus, tmpHiAddress corresponds to the high corner of the enlarged
+    // rectangle and tmpLoAddress corresponds to the lower corner.
     if (tmpLoAddress[row] & ((AddressElemType) 1 << bit))
     {
       addr::AddressToPoint(loCorner, tmpLoAddress);
@@ -300,9 +329,12 @@ void CellBound<MetricType, ElemType>::InitLowerBound(size_t numEqualBits)
       AddBound(loCorner, hiCorner);
       break;
     }
+    // Enlarge the hyperrectangle at this step since it is contained
+    // entirely in the bound.
     tmpHiAddress[row] |= ((AddressElemType) 1 << bit);
   }
 
+  // Add the enlarged rectangle if we have not done that.
   if (pos == numEqualBits)
   {
     addr::AddressToPoint(loCorner, tmpLoAddress);
@@ -316,17 +348,24 @@ void CellBound<MetricType, ElemType>::InitLowerBound(size_t numEqualBits)
     size_t row = pos / order;
     size_t bit = order - 1 - pos % order;
 
+    // The high bound should correspond to this step.
     tmpHiAddress[row] |= ((AddressElemType) 1 << bit);
 
     if (!(tmpLoAddress[row] & ((AddressElemType) 1 << bit)))
     {
+      // This hyperrectangle is contained entirely in the bound and does not
+      // overlap with other hyperrectangles since hiAddress is greater than
+      // tmpHiAddress and tmpLoAddress is greater than the high addresses
+      // of hyperrectangles that we have added previously.
       tmpLoAddress[row] ^= (AddressElemType) 1 << bit;
+
       addr::AddressToPoint(loCorner, tmpLoAddress);
       addr::AddressToPoint(hiCorner, tmpHiAddress);
 
       AddBound(loCorner, hiCorner);
     }
 
+    // The lower bound should correspond to this step.
     tmpLoAddress[row] &= ~((AddressElemType) 1 << bit);
   }
 }
@@ -336,11 +375,14 @@ void CellBound<MetricType, ElemType>::UpdateAddressBounds()
 {
   numBounds = 0;
 
+  // Calculate the number of equal leading bits of the lower address and
+  // the high address.
   size_t row = 0;
   for ( ; row < hiAddress.n_elem; row++)
     if (loAddress[row] != hiAddress[row])
       break;
 
+  // If the high address is equal to the lower address.
   if (row == hiAddress.n_elem)
   {
     for (size_t i = 0; i < dim; i++)
@@ -348,7 +390,6 @@ void CellBound<MetricType, ElemType>::UpdateAddressBounds()
       loBound(i, 0) = bounds[i].Lo();
       hiBound(i, 0) = bounds[i].Hi();
     }
-
     numBounds = 1;
 
     return;
@@ -362,6 +403,7 @@ void CellBound<MetricType, ElemType>::UpdateAddressBounds()
 
   if ((row == hiAddress.n_elem - 1) && (bit == order - 1))
   {
+    // If the addresses differ in the last bit.
     for (size_t i = 0; i < dim; i++)
     {
       loBound(i, 0) = bounds[i].Lo();
@@ -374,7 +416,6 @@ void CellBound<MetricType, ElemType>::UpdateAddressBounds()
   }
 
   size_t numEqualBits = row * order + bit;
-
   InitHighBound(numEqualBits);
   InitLowerBound(numEqualBits);
 
@@ -382,6 +423,7 @@ void CellBound<MetricType, ElemType>::UpdateAddressBounds()
 
   if (numBounds == 0)
   {
+    // I think this should never happen.
     for (size_t i = 0; i < dim; i++)
     {
       loBound(i, 0) = bounds[i].Lo();
@@ -390,7 +432,6 @@ void CellBound<MetricType, ElemType>::UpdateAddressBounds()
 
     numBounds = 1;
   }
-  assert(numBounds > 0);
 }
 
 /**
diff --git a/src/mlpack/tests/ub_tree_test.cpp b/src/mlpack/tests/ub_tree_test.cpp
index 312d440..2d3c28e 100644
--- a/src/mlpack/tests/ub_tree_test.cpp
+++ b/src/mlpack/tests/ub_tree_test.cpp
@@ -23,10 +23,11 @@ BOOST_AUTO_TEST_CASE(AddressTest)
   arma::Mat<ElemType> dataset(8, 1000);
 
   dataset.randu();
-
+  dataset -= 0.5;
   arma::Col<AddressElemType> address(dataset.n_rows);
   arma::Col<ElemType> point(dataset.n_rows);
 
+  // Ensure that this is a one-to-one transform.
   for (size_t i = 0; i < dataset.n_cols; i++)
   {
     addr::PointToAddress(address, dataset.col(i));
@@ -35,6 +36,7 @@ BOOST_AUTO_TEST_CASE(AddressTest)
     for (size_t k = 0; k < dataset.n_rows; k++)
       BOOST_REQUIRE_CLOSE(dataset(k, i), point[k], 1e-13);
   }
+
 }
 
 template<typename TreeType>
@@ -56,6 +58,7 @@ void CheckSplit(const TreeType& tree)
 
   arma::Col<AddressElemType> address(tree.Bound().Dim());
 
+  // Find the highest address of the left node.
   for (size_t i = 0; i < tree.Left()->NumDescendants(); i++)
   {
     addr::PointToAddress(address,
@@ -65,6 +68,7 @@ void CheckSplit(const TreeType& tree)
       hi = address;
   }
 
+  // Find the lowest address of the right node.
   for (size_t i = 0; i < tree.Right()->NumDescendants(); i++)
   {
     addr::PointToAddress(address,
@@ -74,6 +78,7 @@ void CheckSplit(const TreeType& tree)
       lo = address;
   }
 
+  // Addresses in the left node should be less than addresses in the right node.
   BOOST_REQUIRE_LE(addr::CompareAddresses(hi, lo), 0);
 
   CheckSplit(*tree.Left());
@@ -99,11 +104,13 @@ void CheckBound(const TreeType& tree)
   {
     arma::Col<ElemType> point = tree.Dataset().col(tree.Descendant(i));
 
+    // Check that the point is contained in the bound.
     BOOST_REQUIRE_EQUAL(true, tree.Bound().Contains(point));
 
     const arma::Mat<ElemType>& loBound = tree.Bound().LoBound();
     const arma::Mat<ElemType>& hiBound = tree.Bound().HiBound();
 
+    // Ensure that there is a hyperrectangle that contains the point.
     bool success = false;
     for (size_t j = 0; j < tree.Bound().NumBounds(); j++)
     {
@@ -142,6 +149,7 @@ BOOST_AUTO_TEST_CASE(UBTreeBoundTest)
   CheckBound(tree);
 }
 
+// Ensure that MinDistance() and MaxDistance() work correctly.
 template<typename TreeType, typename MetricType>
 void CheckDistance(TreeType& tree, TreeType* node = NULL)
 {




More information about the mlpack-git mailing list