[mlpack-git] master: Optimize the UB tree bound. (f17843f)

gitdub at mlpack.org gitdub at mlpack.org
Tue Aug 23 08:23:47 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/1797a49c8f76d65814fec4a122d0d2fea01fc2d9...9e5cd0ac9c5cde9ac141bc84e7327bd11e19d42e

>---------------------------------------------------------------

commit f17843fc8e8ef2b1c4b04d572c521575f20a1f3c
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date:   Tue Aug 23 15:23:47 2016 +0300

    Optimize the UB tree bound.


>---------------------------------------------------------------

f17843fc8e8ef2b1c4b04d572c521575f20a1f3c
 .../tree/binary_space_tree/ub_tree_split_impl.hpp  |  2 +-
 src/mlpack/core/tree/cellbound.hpp                 | 18 ++++--
 src/mlpack/core/tree/cellbound_impl.hpp            | 64 +++++++++++++++-------
 src/mlpack/tests/ub_tree_test.cpp                  | 24 +++++---
 4 files changed, 75 insertions(+), 33 deletions(-)

diff --git a/src/mlpack/core/tree/binary_space_tree/ub_tree_split_impl.hpp b/src/mlpack/core/tree/binary_space_tree/ub_tree_split_impl.hpp
index dedf894..c3bc7e1 100644
--- a/src/mlpack/core/tree/binary_space_tree/ub_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/ub_tree_split_impl.hpp
@@ -132,7 +132,7 @@ bool UBTreeSplit<BoundType, MatType>::SplitNode(BoundType& bound,
     bound.LoAddress()[k] = addresses[begin].first[k];
     bound.HiAddress()[k] = addresses[begin + count - 1].first[k];
   }
-  bound.UpdateAddressBounds();
+  bound.UpdateAddressBounds(data.cols(begin, begin + count - 1));
 
   return true;
 }
diff --git a/src/mlpack/core/tree/cellbound.hpp b/src/mlpack/core/tree/cellbound.hpp
index cd6a7da..a653c6e 100644
--- a/src/mlpack/core/tree/cellbound.hpp
+++ b/src/mlpack/core/tree/cellbound.hpp
@@ -219,8 +219,11 @@ class CellBound
   /**
    * Calculate the bounds of all subrectangles. You should set the lower and the
    * high addresses.
+   *
+   * @param data Points that are contained in the node.
    */
-  void UpdateAddressBounds();
+  template<typename MatType>
+  void UpdateAddressBounds(const MatType& data);
 
   /**
    * Returns the diameter of the hyperrectangle (that is, the longest diagonal).
@@ -260,23 +263,30 @@ class CellBound
    *
    * @param loCorner The lower corner of the subrectangle that is being added.
    * @param hiCorner The high corner of the subrectangle that is being added.
+   * @param data Points that are contained in the node.
    */
+  template<typename MatType>
   void AddBound(const arma::Col<ElemType>& loCorner,
-                const arma::Col<ElemType>& hiCorner);
+                const arma::Col<ElemType>& hiCorner,
+                const MatType& data);
   /**
    * Initialize all subrectangles that touches the lower address.
    *
    * @param numEqualBits The number of equal leading bits of the lower address
    * and the high address.
+   * @param data Points that are contained in the node.
    */
-  void InitHighBound(size_t numEqualBits);
+  template<typename MatType>
+  void InitHighBound(size_t numEqualBits, const MatType& data);
   /**
    * Initialize all subrectangles that touches the high address.
    *
    * @param numEqualBits The number of equal leading bits of the lower address
    * and the high address.
+   * @param data Points that are contained in the node.
    */
-  void InitLowerBound(size_t numEqualBits);
+  template<typename MatType>
+  void InitLowerBound(size_t numEqualBits, const MatType& data);
 };
 
 // A specialization of BoundTraits for this class.
diff --git a/src/mlpack/core/tree/cellbound_impl.hpp b/src/mlpack/core/tree/cellbound_impl.hpp
index 150e2a2..ebb9a34 100644
--- a/src/mlpack/core/tree/cellbound_impl.hpp
+++ b/src/mlpack/core/tree/cellbound_impl.hpp
@@ -169,34 +169,55 @@ inline void CellBound<MetricType, ElemType>::Center(
 }
 
 template<typename MetricType, typename ElemType>
+template<typename MatType>
 void CellBound<MetricType, ElemType>::AddBound(
     const arma::Col<ElemType>& loCorner,
-    const arma::Col<ElemType>& hiCorner)
+    const arma::Col<ElemType>& hiCorner,
+    const MatType& data)
 {
   assert(numBounds < loBound.n_cols);
   assert(loBound.n_rows == dim);
   assert(loCorner.n_elem == dim);
   assert(hiCorner.n_elem == dim);
 
-  // If the subrectangle is not contained entirely in the outer rectangle,
-  // we shrink it.
   for (size_t k = 0; k < dim; k++)
   {
-    loBound(k, numBounds) = std::max(loCorner[k], bounds[k].Lo());
+    loBound(k, numBounds) = std::numeric_limits<ElemType>::max();
+    hiBound(k, numBounds) = std::numeric_limits<ElemType>::lowest();
+  }
 
-    hiBound(k, numBounds) = std::min(bounds[k].Hi(), hiCorner[k]);
+  for (size_t i = 0; i < data.n_cols; i++)
+  {
+    size_t k = 0;
+    // Check if the point is contained in the hyperrectangle.
+    for (k = 0; k < dim; k++)
+      if (data(k, i) < loCorner[k] || data(k, i) > hiCorner[k])
+        break;
 
-    // This should never happen.
-    if (loBound(k, numBounds) > hiBound(k, numBounds))
-      return;
+    if (k < dim)
+      continue; // The point is not contained in the hyperrectangle.
+
+    // Shrink the bound so that it tightly encloses this point.
+    for (k = 0; k < dim; k++)
+    {
+      loBound(k, numBounds) = std::min(loBound(k, numBounds), data(k, i));
+
+      hiBound(k, numBounds) = std::max(hiBound(k, numBounds), data(k, i));
+    }
   }
 
+  for (size_t k = 0; k < dim; k++)
+    if (loBound(k, numBounds) > hiBound(k, numBounds))
+      return; // The hyperrectangle does not contain points.
+
   numBounds++;
 }
 
 
 template<typename MetricType, typename ElemType>
-void CellBound<MetricType, ElemType>::InitHighBound(size_t numEqualBits)
+template<typename MatType>
+void CellBound<MetricType, ElemType>::InitHighBound(size_t numEqualBits,
+                                                    const MatType& data)
 {
   arma::Col<AddressElemType> tmpHiAddress(hiAddress);
   arma::Col<AddressElemType> tmpLoAddress(hiAddress);
@@ -241,7 +262,7 @@ void CellBound<MetricType, ElemType>::InitHighBound(size_t numEqualBits)
       addr::AddressToPoint(loCorner, tmpLoAddress);
       addr::AddressToPoint(hiCorner, tmpHiAddress);
 
-      AddBound(loCorner, hiCorner);
+      AddBound(loCorner, hiCorner, data);
       break;
     }
     // Nullify the bit that corresponds to this step.
@@ -254,7 +275,7 @@ void CellBound<MetricType, ElemType>::InitHighBound(size_t numEqualBits)
     addr::AddressToPoint(loCorner, tmpLoAddress);
     addr::AddressToPoint(hiCorner, tmpHiAddress);
 
-    AddBound(loCorner, hiCorner);
+    AddBound(loCorner, hiCorner, data);
   }
 
   for ( ; pos > numEqualBits; pos--)
@@ -275,7 +296,7 @@ void CellBound<MetricType, ElemType>::InitHighBound(size_t numEqualBits)
       addr::AddressToPoint(loCorner, tmpLoAddress);
       addr::AddressToPoint(hiCorner, tmpHiAddress);
 
-      AddBound(loCorner, hiCorner);
+      AddBound(loCorner, hiCorner, data);
     }
     // The high bound should correspond to this step.
     tmpHiAddress[row] |= ((AddressElemType) 1 << bit);
@@ -283,7 +304,9 @@ void CellBound<MetricType, ElemType>::InitHighBound(size_t numEqualBits)
 }
 
 template<typename MetricType, typename ElemType>
-void CellBound<MetricType, ElemType>::InitLowerBound(size_t numEqualBits)
+template<typename MatType>
+void CellBound<MetricType, ElemType>::InitLowerBound(size_t numEqualBits,
+                                                     const MatType& data)
 {
   arma::Col<AddressElemType> tmpHiAddress(loAddress);
   arma::Col<AddressElemType> tmpLoAddress(loAddress);
@@ -326,7 +349,7 @@ void CellBound<MetricType, ElemType>::InitLowerBound(size_t numEqualBits)
       addr::AddressToPoint(loCorner, tmpLoAddress);
       addr::AddressToPoint(hiCorner, tmpHiAddress);
 
-      AddBound(loCorner, hiCorner);
+      AddBound(loCorner, hiCorner, data);
       break;
     }
     // Enlarge the hyperrectangle at this step since it is contained
@@ -340,7 +363,7 @@ void CellBound<MetricType, ElemType>::InitLowerBound(size_t numEqualBits)
     addr::AddressToPoint(loCorner, tmpLoAddress);
     addr::AddressToPoint(hiCorner, tmpHiAddress);
 
-    AddBound(loCorner, hiCorner);
+    AddBound(loCorner, hiCorner, data);
   }
 
   for ( ; pos > numEqualBits; pos--)
@@ -362,7 +385,7 @@ void CellBound<MetricType, ElemType>::InitLowerBound(size_t numEqualBits)
       addr::AddressToPoint(loCorner, tmpLoAddress);
       addr::AddressToPoint(hiCorner, tmpHiAddress);
 
-      AddBound(loCorner, hiCorner);
+      AddBound(loCorner, hiCorner, data);
     }
 
     // The lower bound should correspond to this step.
@@ -371,7 +394,8 @@ void CellBound<MetricType, ElemType>::InitLowerBound(size_t numEqualBits)
 }
 
 template<typename MetricType, typename ElemType>
-void CellBound<MetricType, ElemType>::UpdateAddressBounds()
+template<typename MatType>
+void CellBound<MetricType, ElemType>::UpdateAddressBounds(const MatType& data)
 {
   numBounds = 0;
 
@@ -416,8 +440,8 @@ void CellBound<MetricType, ElemType>::UpdateAddressBounds()
   }
 
   size_t numEqualBits = row * order + bit;
-  InitHighBound(numEqualBits);
-  InitLowerBound(numEqualBits);
+  InitHighBound(numEqualBits, data);
+  InitLowerBound(numEqualBits, data);
 
   assert(numBounds <= maxNumBounds);
 
@@ -892,7 +916,7 @@ inline CellBound<MetricType, ElemType>& CellBound<MetricType, ElemType>::operato
       loBound(i, 0) = bounds[i].Lo();
       hiBound(i, 0) = bounds[i].Hi();
     }
-    numBounds = 0;
+    numBounds = 1;
   }
   return *this;
 }
diff --git a/src/mlpack/tests/ub_tree_test.cpp b/src/mlpack/tests/ub_tree_test.cpp
index 2d3c28e..84bffed 100644
--- a/src/mlpack/tests/ub_tree_test.cpp
+++ b/src/mlpack/tests/ub_tree_test.cpp
@@ -180,13 +180,17 @@ void CheckDistance(TreeType& tree, TreeType* node = NULL)
           minDist = dist;
       }
 
-      BOOST_REQUIRE_LE(tree.Bound().MinDistance(point), minDist);
-      BOOST_REQUIRE_LE(maxDist, tree.Bound().MaxDistance(point));
+      BOOST_REQUIRE_LE(tree.Bound().MinDistance(point), minDist *
+          (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
+      BOOST_REQUIRE_LE(maxDist, tree.Bound().MaxDistance(point) *
+          (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
 
       math::RangeType<ElemType> r = tree.Bound().RangeDistance(point);
 
-      BOOST_REQUIRE_LE(r.Lo(), minDist);
-      BOOST_REQUIRE_LE(maxDist, r.Hi());
+      BOOST_REQUIRE_LE(r.Lo(), minDist *
+          (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
+      BOOST_REQUIRE_LE(maxDist, r.Hi() *
+          (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
     }
       
     if (!tree.IsLeaf())
@@ -214,13 +218,17 @@ void CheckDistance(TreeType& tree, TreeType* node = NULL)
             minDist = dist;
         }
 
-      BOOST_REQUIRE_LE(tree.Bound().MinDistance(node->Bound()), minDist);
-      BOOST_REQUIRE_LE(maxDist, tree.Bound().MaxDistance(node->Bound()));
+      BOOST_REQUIRE_LE(tree.Bound().MinDistance(node->Bound()), minDist *
+          (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
+      BOOST_REQUIRE_LE(maxDist, tree.Bound().MaxDistance(node->Bound()) *
+          (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
 
       math::RangeType<ElemType> r = tree.Bound().RangeDistance(node->Bound());
 
-      BOOST_REQUIRE_LE(r.Lo(), minDist);
-      BOOST_REQUIRE_LE(maxDist, r.Hi());
+      BOOST_REQUIRE_LE(r.Lo(), minDist *
+          (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
+      BOOST_REQUIRE_LE(maxDist, r.Hi() *
+          (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
     }
     if (!node->IsLeaf())
     {




More information about the mlpack-git mailing list