[mlpack-git] master: Hilbert R tree fixes. (72f53d6)

gitdub at mlpack.org gitdub at mlpack.org
Mon Jun 27 11:36:02 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/37fda23945b4f998cd5fa6ec011ae345236c8552...479eca0c625cc4255a3b1a354a4788dae10f1b01

>---------------------------------------------------------------

commit 72f53d600fb7b511af6d4e0acb2fecfa4bc17593
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date:   Wed Jun 1 16:34:47 2016 +0300

    Hilbert R tree fixes.


>---------------------------------------------------------------

72f53d600fb7b511af6d4e0acb2fecfa4bc17593
 .../tree/rectangle_tree/discrete_hilbert_value.hpp |  12 ++-
 .../rectangle_tree/discrete_hilbert_value_impl.hpp | 100 ++++++++++++++-------
 .../hilbert_r_tree_auxiliary_information.hpp       |   9 +-
 .../hilbert_r_tree_auxiliary_information_impl.hpp  |  28 +++---
 .../tree/rectangle_tree/hilbert_r_tree_split.hpp   |   2 +-
 .../rectangle_tree/hilbert_r_tree_split_impl.hpp   |  17 ++--
 .../core/tree/rectangle_tree/rectangle_tree.hpp    |   8 +-
 .../tree/rectangle_tree/rectangle_tree_impl.hpp    |  35 ++++----
 .../rectangle_tree/recursive_hilbert_value.hpp     |  14 ++-
 .../recursive_hilbert_value_impl.hpp               |  40 ++++-----
 src/mlpack/core/tree/rectangle_tree/typedef.hpp    |   2 +-
 src/mlpack/tests/rectangle_tree_test.cpp           |   8 +-
 12 files changed, 175 insertions(+), 100 deletions(-)

diff --git a/src/mlpack/core/tree/rectangle_tree/discrete_hilbert_value.hpp b/src/mlpack/core/tree/rectangle_tree/discrete_hilbert_value.hpp
index 6bcee8f..3350100 100644
--- a/src/mlpack/core/tree/rectangle_tree/discrete_hilbert_value.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/discrete_hilbert_value.hpp
@@ -92,7 +92,7 @@ class DiscreteHilbertValue
    * @param tree Not used
    * @param val The number of the point to compare with.
    */
-  template<typename TreeType,typename ElemType>
+  template<typename TreeType>
   int CompareWith(TreeType *tree, const size_t point);
 
   /**
@@ -146,12 +146,16 @@ class DiscreteHilbertValue
   void UpdateLargestValue(TreeType *node);
 
   //! Copy the largest Hilbert value.
-  DiscreteHilbertValue operator = (DiscreteHilbertValue &val);
+  DiscreteHilbertValue operator = (const DiscreteHilbertValue &val);
 
   //! Return the largest Hilbert value
   std::list<arma::Col<uint64_t>>::iterator LargestValue() const
   { return largestValue; }
 
+  //! Modify the largest Hilbert value
+  std::list<arma::Col<uint64_t>>::iterator &LargestValue()
+  { return largestValue; }
+
   //! Modify the local dataset
   std::list<arma::Col<uint64_t>> *LocalDataset() { return localDataset; }
   //! Modify the dataset
@@ -182,6 +186,10 @@ class DiscreteHilbertValue
    */
   static int CompareValues(const arma::Col<uint64_t> &value1,
                            const arma::Col<uint64_t> &value2);
+  /**
+   * Returns true if the node has the largest Hilbert value.
+   */
+  bool HasValue();
 };
 } // namespace tree
 } // namespace mlpack
diff --git a/src/mlpack/core/tree/rectangle_tree/discrete_hilbert_value_impl.hpp b/src/mlpack/core/tree/rectangle_tree/discrete_hilbert_value_impl.hpp
index ce4df87..787ae8c 100644
--- a/src/mlpack/core/tree/rectangle_tree/discrete_hilbert_value_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/discrete_hilbert_value_impl.hpp
@@ -33,31 +33,51 @@ template<typename TreeType>
 DiscreteHilbertValue::DiscreteHilbertValue(const TreeType *tree) :
     dataset(tree->Parent() ?
           tree->Parent()->AuxiliaryInfo().LargestHilbertValue().Dataset() :
-          new arma::Mat<uint64_t>(tree->Dataset()->n_rows,
-                                  tree->MaxLeafSize()+1)),
+          new arma::Mat<uint64_t>(tree->Dataset().n_rows,
+                                  tree->Dataset().n_cols)),
     ownsDataset(!tree->Parent()),
     localDataset(new std::list<arma::Col<uint64_t>>()),
     largestValue(localDataset->end())
 {
+  typedef typename TreeType::ElemType ElemType;
   // Calculate the Hilbert value for all points
   if(!tree->Parent())
   {
-    for(size_t i = 0; i < tree->Dataset()->n_rows; i++)
-      dataset->col(i) = CalculateValue(tree->Dataset()->col(i));
+    for(size_t i = 0; i < tree->Dataset().n_cols; i++)
+      dataset->col(i) =
+          CalculateValue(arma::Col<ElemType>(tree->Dataset().col(i)));
   }
-};
+}
 
 template<typename TreeType>
 DiscreteHilbertValue::DiscreteHilbertValue(const TreeType &other) :
     dataset(other.AuxiliaryInfo().LargestHilbertValue().Dataset()),
-    ownsDataset(!other.Parent()),
-    localDataset(other.AuxiliaryInfo().LargestHilbertValue().LocalDataset()),
+    ownsDataset(false),
+    localDataset(new std::list<arma::Col<uint64_t>>()),
     largestValue(other.AuxiliaryInfo().LargestHilbertValue().LargestValue())
 {
-};
+  if(other.IsLeaf())
+  {
+    // A leaf stores its Hilbert values in its own list, the largest one last;
+    // duplicate the list and point at the last element of the duplicate.
+    std::list<arma::Col<uint64_t>> *otherDataset =
+                     other.AuxiliaryInfo().LargestHilbertValue().LocalDataset();
+    localDataset->assign(otherDataset->begin(), otherDataset->end());
+    largestValue = localDataset->end();
+    if(!localDataset->empty())
+      largestValue--;
+  }
+  else if(!other.AuxiliaryInfo().LargestHilbertValue().HasValue())
+  {
+    // other's "no value" marker is the end() of other's own list; translate
+    // it to ours, otherwise HasValue() would report a nonexistent value.
+    largestValue = localDataset->end();
+  }
+}
 
 template<typename ElemType>
-arma::Col<uint64_t> CalculateValue(const arma::Col<ElemType> &pt)
+arma::Col<uint64_t> DiscreteHilbertValue::
+CalculateValue(const arma::Col<ElemType> &pt)
 {
   arma::Col<uint64_t> res(pt.n_rows);
   constexpr int order = 64;   // The number of bits that we can store
@@ -86,19 +100,19 @@ arma::Col<uint64_t> CalculateValue(const arma::Col<ElemType> &pt)
       normalizedVal /= tmp;
     }
     //  Extract the mantissa
-    uint64_t tmp = 1 << numMantBits;
-    res(i) = std::floor(normalizedVal / numMantBits);
+    uint64_t tmp = (uint64_t)1 << numMantBits;
+    res(i) = std::floor(normalizedVal / tmp);
     //  Add the exponent
-    res(i) |= (e - std::numeric_limits<ElemType>::min_exponent) << numMantBits;
+    res(i) |= ((uint64_t)(e - std::numeric_limits<ElemType>::min_exponent)) << numMantBits;
 
     // Negative values should be inverted
     if(sgn)
-      res(i) = 1 << (order - 1) - 1 - res(i);
+      res(i) = ((uint64_t)1 << (order - 1)) - 1 - res(i);
     else
-      res(i) |= 1 << (order - 1);
+      res(i) |= (uint64_t)1 << (order - 1);
   }
 
-  uint64_t M = 1 << (order - 1);
+  uint64_t M = (uint64_t)1 << (order - 1);
 
   // Since the Hilbert curve is continuous we should permutate and intend
   // coordinate axes depending on the position of the point
@@ -176,14 +190,21 @@ int DiscreteHilbertValue::ComparePoints(const arma::Col<ElemType> &pt1,
 }
 
 template<typename TreeType>
-int DiscreteHilbertValue::CompareValues(TreeType *tree,
+int DiscreteHilbertValue::CompareValues(TreeType *,
                          DiscreteHilbertValue &val1, DiscreteHilbertValue &val2)
 {
+  if(val1.HasValue() && !val2.HasValue())
+    return 1;
+  else if(!val1.HasValue() && val2.HasValue())
+    return -1;
+  else if(!val1.HasValue() && !val2.HasValue())
+    return 0;
+
   return CompareValues(*val1.LargestValue(),*val2.LargestValue());
 }
 
 template<typename TreeType>
-int DiscreteHilbertValue::CompareWith(TreeType *tree, DiscreteHilbertValue &val)
+int DiscreteHilbertValue::CompareWith(TreeType *, DiscreteHilbertValue &val)
 {
   return CompareValues(*largestValue,*val.LargestValue());
 }
@@ -194,13 +215,18 @@ int DiscreteHilbertValue::CompareWith(TreeType *tree,
 {
   arma::Col<uint64_t> val = CalculateValue(pt);
 
+  if(!HasValue())
+    return -1;
+
   return CompareValues(*largestValue,val);
 }
 
-template<typename TreeType,typename ElemType>
-int DiscreteHilbertValue::CompareWith(TreeType *tree,
+template<typename TreeType>
+int DiscreteHilbertValue::CompareWith(TreeType *,
                                       const size_t point)
 {
+  if(!HasValue())
+    return -1;
   return CompareValues(*largestValue,dataset->col(point));
 }
 
@@ -208,7 +234,7 @@ template<typename TreeType>
 size_t DiscreteHilbertValue::InsertPoint(TreeType *node, const size_t point)
 {
   size_t i = 0;
-  std::list<arma::Col<uint64_t>>::iterator it;
+  std::list<arma::Col<uint64_t>>::iterator it = localDataset->end();
 
   if(node->IsLeaf())
   {
@@ -241,7 +267,7 @@ size_t DiscreteHilbertValue::InsertPoint(TreeType *node, const size_t point)
   {
     // We do not update the largest Hilbert value since we do not know the
     // iterator
-    if(*largestValue < dataset->col(point))
+    if(HasValue() && CompareValues(*largestValue,dataset->col(point)) < 0)
       largestValue = localDataset->end();
   }
 
@@ -308,16 +334,29 @@ void DiscreteHilbertValue::Copy(TreeType *dst, TreeType *src)
   DiscreteHilbertValue &srcVal = src->AuxiliaryInfo().LargestHilbertValue();
 
   // Copy the largest Hilbert value and the local dataset
-  dst.LargestValue() = src.LargestValue();
+  dstVal.LargestValue() = srcVal.LargestValue();
 
-  dst.LocalDataset()->clear();
-  std::list<arma::Col<uint64_t>>::iterator it = src.LocalDataset()->begin();
-  for( ; it != src.LocalDataset()->end(); it++)
-    dst.LocalDataset()->push_back(*it);
+  *dstVal.LocalDataset() = *srcVal.LocalDataset();
 
+  if(dst->IsLeaf())
+  {
+    // A leaf keeps its values in its own list, the largest one last, so point
+    // into dst's copy rather than into src's list.  Test the copied list, not
+    // dst->NumPoints(): a caller may update the point count only after Copy()
+    // returns (CondenseTree() sets count after calling Copy()).
+    dstVal.LargestValue() = dstVal.LocalDataset()->end();
+    if(!dstVal.LocalDataset()->empty())
+      dstVal.LargestValue()--;
+  }
+  else if(!srcVal.HasValue())
+  {
+    // src's "no value" marker is the end() of src's own list; translate it to
+    // dst's, otherwise dstVal.HasValue() would report a nonexistent value.
+    dstVal.LargestValue() = dstVal.LocalDataset()->end();
+  }
 }
 
-inline DiscreteHilbertValue DiscreteHilbertValue::operator = (DiscreteHilbertValue &val)
+inline DiscreteHilbertValue DiscreteHilbertValue::operator = (const DiscreteHilbertValue &val)
 {
   // Copy the largest Hilbert value
   largestValue = val.LargestValue();
@@ -340,21 +372,27 @@ void DiscreteHilbertValue::UpdateLargestValue(TreeType *node)
     for(size_t i = 0; i < node->NumPoints(); i++)
       localDataset->push_back(dataset->col(node->Points()[i]));
     largestValue = localDataset->end();
-    localDataset--;
+    largestValue--;
   }
   else
   {
+    if(localDataset->size() > 0)
+      localDataset->clear();
     //  Update the largest Hilbert value;
     if(node->NumChildren() == 0)
       largestValue = localDataset->end();
-    else if(node->Children()[node->NumChildren()-1]->AuxiliaryInfo().LargestHilbertValue().LargestValue() !=
-        node->Children()[node->NumChildren()-1]->AuxiliaryInfo().LargestHilbertValue().LocalDataset()->end())
-      largestValue = node->Children()[node->NumChildren()-1]->AuxiliaryInfo().LargestHilbertValue();
+    else if(node->Children()[node->NumChildren()-1]->AuxiliaryInfo().LargestHilbertValue().HasValue())
+      largestValue = node->Children()[node->NumChildren()-1]->AuxiliaryInfo().LargestHilbertValue().LargestValue();
     else
       largestValue = localDataset->end();
   }
 }
 
+inline bool DiscreteHilbertValue::HasValue()
+{
+  return largestValue != localDataset->end();
+}
+
 } // namespace tree
 } // namespace mlpack
 
diff --git a/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_auxiliary_information.hpp b/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_auxiliary_information.hpp
index 4d3070c..01bbf99 100644
--- a/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_auxiliary_information.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_auxiliary_information.hpp
@@ -31,6 +31,19 @@ class HilbertRTreeAuxiliaryInformation
    */
   HilbertRTreeAuxiliaryInformation(const TreeType &other);
 
+  //! Free memory
+  ~HilbertRTreeAuxiliaryInformation();
+
+  /**
+   * The Hilbert value is held through an owning raw pointer that the
+   * destructor deletes, so a memberwise copy would lead to a double delete.
+   * Copying is therefore disabled.
+   */
+  HilbertRTreeAuxiliaryInformation(
+      const HilbertRTreeAuxiliaryInformation &) = delete;
+  HilbertRTreeAuxiliaryInformation &operator=(
+      const HilbertRTreeAuxiliaryInformation &) = delete;
+
   /**
    * The Hilbert R tree requires to insert points according to their
    * Hilbert value. This method should take care of it.
@@ -92,13 +95,13 @@ class HilbertRTreeAuxiliaryInformation
 
  private:
   //! The largest Hilbert value of a point enclosed by the node.
-  HilbertValue largestHilbertValue;
+  HilbertValue *largestHilbertValue;
 
  public:
   //! Return the largest Hilbert value of a point covered by the node.
-  HilbertValue LargestHilbertValue() const { return largestHilbertValue; }
+  HilbertValue& LargestHilbertValue() const { return *largestHilbertValue; }
   //! Modify the largest Hilbert value of a point covered by the node.
-  HilbertValue& LargestHilbertValue() { return largestHilbertValue; }
+  HilbertValue& LargestHilbertValue() { return *largestHilbertValue; }
 
   /**
    * Serialize the information.
diff --git a/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_auxiliary_information_impl.hpp b/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_auxiliary_information_impl.hpp
index bd190d5..0341d30 100644
--- a/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_auxiliary_information_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_auxiliary_information_impl.hpp
@@ -25,7 +25,7 @@ HilbertRTreeAuxiliaryInformation()
 template<typename TreeType,typename HilbertValue>
 HilbertRTreeAuxiliaryInformation<TreeType,HilbertValue>::
 HilbertRTreeAuxiliaryInformation(const TreeType *node) :
-    largestHilbertValue(node)
+    largestHilbertValue(new HilbertValue(node))
 {
 
 };
@@ -33,12 +33,19 @@ HilbertRTreeAuxiliaryInformation(const TreeType *node) :
 template<typename TreeType,typename HilbertValue>
 HilbertRTreeAuxiliaryInformation<TreeType,HilbertValue>::
 HilbertRTreeAuxiliaryInformation(const TreeType &other) :
-    largestHilbertValue(other)
+    largestHilbertValue(new HilbertValue(other))
 {
 
 };
 
 template<typename TreeType,typename HilbertValue>
+HilbertRTreeAuxiliaryInformation<TreeType,HilbertValue>::
+~HilbertRTreeAuxiliaryInformation()
+{
+  delete largestHilbertValue;
+}
+  
+template<typename TreeType,typename HilbertValue>
 bool HilbertRTreeAuxiliaryInformation<TreeType,HilbertValue>::
 HandlePointInsertion(TreeType *node,const size_t point)
 {
@@ -46,7 +53,7 @@ HandlePointInsertion(TreeType *node,const size_t point)
   {
     // Get the position at which the point should be inserted
     // Update the largest Hilbert value of the node
-    size_t pos = largestHilbertValue.InsertPoint(node,point);
+    size_t pos = largestHilbertValue->InsertPoint(node,point);
 
     // Move points
     for(size_t i = node->NumPoints(); i > pos; i--)
@@ -60,7 +67,7 @@ HandlePointInsertion(TreeType *node,const size_t point)
     node->Count()++;
   }
   else
-    largestHilbertValue.InsertPoint(node,point);  //  Update LHV
+    largestHilbertValue->InsertPoint(node,point);  //  Update LHV
 
   return true;
 }
@@ -90,10 +97,10 @@ HandleNodeInsertion(TreeType *node,TreeType *nodeToInsert,bool insertionLevel)
     nodeToInsert->Parent() = node;
 
     // Update the largest Hilbert value
-    largestHilbertValue.InsertNode(nodeToInsert);
+    largestHilbertValue->InsertNode(nodeToInsert);
   }
   else
-    largestHilbertValue.InsertNode(nodeToInsert); //  Update LHV
+    largestHilbertValue->InsertNode(nodeToInsert); //  Update LHV
 
   return true;
 }
@@ -103,7 +110,7 @@ bool HilbertRTreeAuxiliaryInformation<TreeType,HilbertValue>::
 HandlePointDeletion(TreeType *node,const size_t localIndex)
 {
   // Update the largest Hilbert value
-  largestHilbertValue.DeletePoint(node,localIndex);
+  largestHilbertValue->DeletePoint(node,localIndex);
 
   for(size_t i = localIndex + 1; localIndex < node->NumPoints(); i++)
   {
@@ -119,7 +126,7 @@ bool HilbertRTreeAuxiliaryInformation<TreeType,HilbertValue>::
 HandleNodeRemoval(TreeType *node,const size_t nodeIndex)
 {
   // Update the largest Hilbert value
-  largestHilbertValue.RemoveNode(node,nodeIndex);
+  largestHilbertValue->RemoveNode(node,nodeIndex);
 
   for(size_t i = nodeIndex + 1; nodeIndex < node->NumChildren(); i++)
     node->Children()[i-1] = node->Children()[i];
@@ -139,7 +146,8 @@ UpdateAuxiliaryInfo(TreeType *node)
   if(HilbertValue::CompareValues(largestHilbertValue,
                             child->AuxiliaryInfo().LargestHilbertValue()) < 0)
   {
-    largestHilbertValue = child->AuxiliaryInfo().LargestHilbertValue();
+    largestHilbertValue->Copy(node,child);
+//    largestHilbertValue = child->AuxiliaryInfo().LargestHilbertValue();
     return true;
   }
   return false;
@@ -149,7 +157,7 @@ template<typename TreeType,typename HilbertValue>
 void HilbertRTreeAuxiliaryInformation<TreeType,HilbertValue>::
 Copy(TreeType *dst,TreeType *src)
 {
-  largestHilbertValue.Copy(dst,src);
+  largestHilbertValue->Copy(dst,src);
 }
 
 template<typename TreeType,typename HilbertValue>
diff --git a/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split.hpp b/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split.hpp
index e046a74..f830c13 100644
--- a/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split.hpp
@@ -72,7 +72,7 @@ class HilbertRTreeSplit
    * @param lastSibling The last cooperating sibling.
    */
   template<typename TreeType>
-  static void RedistributePointsEvenly(const TreeType *parent,
+  static void RedistributePointsEvenly(TreeType *parent,
                                 size_t firstSibling,size_t lastSibling);
 
 };
diff --git a/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split_impl.hpp
index 0f0ea4b..2315e88 100644
--- a/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split_impl.hpp
@@ -65,7 +65,7 @@ SplitLeafNode(TreeType *tree,std::vector<bool>& relevels)
                  iTree + splitOrder : parent->NumChildren() - 1);
   firstSibling = (lastSibling > splitOrder ? lastSibling - splitOrder : 0);
 
-  assert(lastSibling - firstSibling == splitOrder);
+  assert(lastSibling - firstSibling <= splitOrder);
   assert(firstSibling >= 0);
   assert(lastSibling < parent->NumChildren());
 
@@ -130,7 +130,7 @@ SplitNonLeafNode(TreeType *tree,std::vector<bool>& relevels)
   firstSibling = (lastSibling > splitOrder ?
                   lastSibling - splitOrder : 0);
 
-  assert(lastSibling - firstSibling == splitOrder);
+  assert(lastSibling - firstSibling <= splitOrder);
   assert(firstSibling >= 0);
   assert(lastSibling < parent->NumChildren());
 
@@ -251,14 +251,13 @@ RedistributeNodesEvenly(const TreeType *parent,
            parent->Children()[i]->MaxNumChildren());
 
     // Fix the largest Hilbert value of the sibling.
-    parent->Children()[i]->AuxiliaryInfo().LargestHilbertValue() = 
-                    children[iChild-1]->AuxiliaryInfo().LargestHilbertValue();
+    parent->Children()[i]->AuxiliaryInfo().LargestHilbertValue().UpdateLargestValue(parent->Children()[i]);
   }
 }
 
 template<typename TreeType>
 void HilbertRTreeSplit::
-RedistributePointsEvenly(const TreeType *parent,
+RedistributePointsEvenly(TreeType *parent,
                          size_t firstSibling, size_t lastSibling)
 {
   size_t numPoints = 0;
@@ -320,6 +319,14 @@ RedistributePointsEvenly(const TreeType *parent,
     // Fix the largest Hilbert value of the sibling.
     parent->Children()[i]->AuxiliaryInfo().LargestHilbertValue().UpdateLargestValue(parent->Children()[i]);
   }
+
+  TreeType *root = parent;
+
+  while(root != NULL)
+  {
+    root->AuxiliaryInfo().LargestHilbertValue().UpdateLargestValue(root);
+    root = root->Parent();
+  }
 }
 
 } // namespace tree
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
index b2fa544..84b728e 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
@@ -94,7 +94,7 @@ class RectangleTree
   //! The local dataset
   MatType* localDataset;
   //! A tree-specific information
-  AuxiliaryInformationType<RectangleTree> auxiliaryInfo;
+  AuxiliaryInformationType<RectangleTree> *auxiliaryInfo;
 
  public:
   //! A single traverser for rectangle type trees.  See
@@ -294,11 +294,11 @@ class RectangleTree
   StatisticType& Stat() { return stat; }
 
   //! Return the auxiliary information object of this node.
-  const AuxiliaryInformationType<RectangleTree>& AuxiliaryInfo() const 
-  { return auxiliaryInfo; }
+  const AuxiliaryInformationType<RectangleTree> &AuxiliaryInfo() const
+  { return *auxiliaryInfo; }
   //! Modify the split object of this node.
   AuxiliaryInformationType<RectangleTree>& AuxiliaryInfo()
-  { return auxiliaryInfo; }
+  { return *auxiliaryInfo; }
 
   //! Return whether or not this node is a leaf (true if it has no children).
   bool IsLeaf() const;
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
index 3b397a8..c9be471 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
@@ -50,7 +50,7 @@ RectangleTree(const MatType& data,
 {
   stat = StatisticType(*this);
 
-  auxiliaryInfo = AuxiliaryInformationType<RectangleTree>(this);
+  auxiliaryInfo = new AuxiliaryInformationType<RectangleTree>(this);
 
   // For now, just insert the points in order.
   RectangleTree* root = this;
@@ -92,7 +92,7 @@ RectangleTree(MatType&& data,
 {
   stat = StatisticType(*this);
 
-  auxiliaryInfo = AuxiliaryInformationType<RectangleTree>(this);
+  auxiliaryInfo = new AuxiliaryInformationType<RectangleTree>(this);
 
   // For now, just insert the points in order.
   RectangleTree* root = this;
@@ -132,7 +132,7 @@ RectangleTree(
                                                   maxLeafSize + 1)))
 {
   stat = StatisticType(*this);
-  auxiliaryInfo = AuxiliaryInformationType<RectangleTree>(this);
+  auxiliaryInfo = new AuxiliaryInformationType<RectangleTree>(this);
 }
 
 /**
@@ -166,7 +166,7 @@ RectangleTree(
     points(other.Points()),
     localDataset(NULL)
 {
-  auxiliaryInfo = AuxiliaryInformationType<RectangleTree>(other);
+  auxiliaryInfo = new AuxiliaryInformationType<RectangleTree>(other);
   if (deepCopy)
   {
     if (numChildren > 0)
@@ -225,6 +225,7 @@ RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
               AuxiliaryInformationType>::
 ~RectangleTree()
 {
+  delete auxiliaryInfo;
   for (size_t i = 0; i < numChildren; i++)
     delete children[i];
 
@@ -297,7 +298,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
   // If this is a leaf node, we stop here and add the point.
   if (numChildren == 0)
   {
-    if(!auxiliaryInfo.HandlePointInsertion(this,point))
+    if(!auxiliaryInfo->HandlePointInsertion(this,point))
     {
       localDataset->col(count) = dataset->col(point);
       points[count++] = point;
@@ -308,7 +309,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
 
   // If it is not a leaf node, we use the DescentHeuristic to choose a child
   // to which we recurse.
-  auxiliaryInfo.HandlePointInsertion(this,point);
+  auxiliaryInfo->HandlePointInsertion(this,point);
   const size_t descentNode = DescentType::ChooseDescentNode(this, point);
   children[descentNode]->InsertPoint(point, lvls);
 }
@@ -335,7 +336,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
   // If this is a leaf node, we stop here and add the point.
   if (numChildren == 0)
   {
-    if(!auxiliaryInfo.HandlePointInsertion(this,point))
+    if(!auxiliaryInfo->HandlePointInsertion(this,point))
     {
       localDataset->col(count) = dataset->col(point);
       points[count++] = point;
@@ -346,7 +347,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
 
   // If it is not a leaf node, we use the DescentHeuristic to choose a child
   // to which we recurse.
-  auxiliaryInfo.HandlePointInsertion(this,point);
+  auxiliaryInfo->HandlePointInsertion(this,point);
   const size_t descentNode = DescentType::ChooseDescentNode(this,point);
   children[descentNode]->InsertPoint(point, relevels);
 }
@@ -375,7 +376,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
   bound |= node->Bound();
   if (level == TreeDepth())
   {
-    if(!auxiliaryInfo.HandleNodeInsertion(this,node,true))
+    if(!auxiliaryInfo->HandleNodeInsertion(this,node,true))
     {
       children[numChildren++] = node;
       node->Parent() = this;
@@ -384,7 +385,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
   }
   else
   {
-    auxiliaryInfo.HandleNodeInsertion(this,node,false);
+    auxiliaryInfo->HandleNodeInsertion(this,node,false);
     const size_t descentNode = DescentType::ChooseDescentNode(this, node);
     children[descentNode]->InsertNode(node, level, relevels);
   }
@@ -420,7 +421,7 @@ bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
     {
       if (points[i] == point)
       {
-        if(!auxiliaryInfo.HandlePointDeletion(this,i))
+        if(!auxiliaryInfo->HandlePointDeletion(this,i))
         {
           localDataset->col(i) = localDataset->col(--count); // Decrement count.
           points[i] = points[count];
@@ -460,7 +461,7 @@ bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
     {
       if (points[i] == point)
       {
-        if(!auxiliaryInfo.HandlePointDeletion(this,i))
+        if(!auxiliaryInfo->HandlePointDeletion(this,i))
         {
           localDataset->col(i) = localDataset->col(--count);
           points[i] = points[count];
@@ -498,7 +499,7 @@ bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
   {
     if (children[i] == node)
     {
-      if(!auxiliaryInfo.HandleNodeRemoval(this,i))
+      if(!auxiliaryInfo->HandleNodeRemoval(this,i))
       {
         children[i] = children[--numChildren]; // Decrement numChildren.
       }
@@ -843,7 +844,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
         if (parent->Children()[j] == this)
         {
           // Decrement numChildren.
-          if(!auxiliaryInfo.HandleNodeRemoval(parent,j))
+          if(!auxiliaryInfo->HandleNodeRemoval(parent,j))
           {
             parent->Children()[j] = parent->Children()[--parent->NumChildren()];
           }
@@ -911,7 +912,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
         localDataset->col(i) = child->LocalDataset().col(i);
       }
 
-      auxiliaryInfo.Copy(this,child);
+      auxiliaryInfo->Copy(this,child);
 
       count = child->Count();
       child->SoftDelete();
@@ -921,11 +922,11 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
 
   // If we didn't delete it, shrink the bound if we need to.
   if (usePoint &&
-      (ShrinkBoundForPoint(point) || auxiliaryInfo.UpdateAuxiliaryInfo(this)) &&
+      (ShrinkBoundForPoint(point) || auxiliaryInfo->UpdateAuxiliaryInfo(this)) &&
       parent != NULL)
     parent->CondenseTree(point, relevels, usePoint);
   else if (!usePoint &&
-           (ShrinkBoundForBound(bound) || auxiliaryInfo.UpdateAuxiliaryInfo(this)) &&
+           (ShrinkBoundForBound(bound) || auxiliaryInfo->UpdateAuxiliaryInfo(this)) &&
            parent != NULL)
     parent->CondenseTree(point, relevels, usePoint);
 }
diff --git a/src/mlpack/core/tree/rectangle_tree/recursive_hilbert_value.hpp b/src/mlpack/core/tree/rectangle_tree/recursive_hilbert_value.hpp
index 2f90ed5..d020461 100644
--- a/src/mlpack/core/tree/rectangle_tree/recursive_hilbert_value.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/recursive_hilbert_value.hpp
@@ -54,14 +54,27 @@ class RecursiveHilbertValue
     //! Indicates that the axis should be inverted
     std::vector<bool> inversion;
+    //! Center of the current hypercube [Lo, Hi]; scratch storage for
+    //! ComparePoints(), kept here so that recursive calls can reuse it.
+    arma::Col<ElemType> center;
+    //! Temporary vector used while computing the center.
+    arma::Col<ElemType> vec;
+    //! Gray-code bits of the first point being compared.
+    std::vector<int> bits;
+    //! Gray-code bits of the second point being compared.
+    std::vector<int> bits2;
     //! Indicates that the result should be inverted
     bool invertResult;
     int recursionLevel;
 
     tagCompareStruct(size_t dim) :
       Lo(dim),
       Hi(dim),
       permutation(dim),
       inversion(dim),
+      center(dim),
+      vec(dim),
+      bits(dim),
+      bits2(dim),
       invertResult(false),
       recursionLevel(0)
     {
@@ -186,7 +195,10 @@ class RecursiveHilbertValue
   void UpdateLargestValue(TreeType *node);
 
   //! Return the largest Hilbert value
-  size_t LargestValue() const { return largestValue; }
+  ptrdiff_t LargestValue() const { return largestValue; }
+
+  //! Modify the largest Hilbert value
+  ptrdiff_t& LargestValue() { return largestValue; }
 
  private:
   //! The largest Hilbert value i.e. the number of the point in the dataset.
diff --git a/src/mlpack/core/tree/rectangle_tree/recursive_hilbert_value_impl.hpp b/src/mlpack/core/tree/rectangle_tree/recursive_hilbert_value_impl.hpp
index b78cd76..e95daaa 100644
--- a/src/mlpack/core/tree/rectangle_tree/recursive_hilbert_value_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/recursive_hilbert_value_impl.hpp
@@ -63,45 +63,43 @@ int RecursiveHilbertValue::ComparePoints(const arma::Col<ElemType> &pt1,
                                          const arma::Col<ElemType> &pt2,
                                          CompareStruct<ElemType> &comp)
 {
-  arma::Col<ElemType> center = comp.Hi * 0.5;
-  arma::Col<ElemType> vec = comp.Lo * 0.5;
-  std::vector<int> bits(pt1.n_rows,0);
-  std::vector<int> bits2(pt1.n_rows,0);
+  comp.center = comp.Hi * 0.5;
+  comp.vec = comp.Lo * 0.5;
 
-  center += vec;
+  comp.center += comp.vec;
 
   // Get bits in order to use the Gray code
   for(size_t i = 0; i < pt1.n_rows; i++)
   {
     size_t j = comp.permutation[i];
-    bits[i] = (pt1(j) > center(j) && !comp.inversion[j]) ||
-       (pt1(j) <= center(j) && !comp.inversion[j]);
+    comp.bits[i] = (pt1(j) > comp.center(j) && !comp.inversion[j]) ||
+       (pt1(j) <= comp.center(j) && comp.inversion[j]);

-    bits2[i] = (pt2(j) > center(j) && !comp.inversion[j]) ||
-       (pt2(j) <= center(j) && !comp.inversion[j]);
+    comp.bits2[i] = (pt2(j) > comp.center(j) && !comp.inversion[j]) ||
+       (pt2(j) <= comp.center(j) && comp.inversion[j]);
   }

   // Gray encode
   for(size_t i = 1; i < pt1.n_rows; i++)
   {
-    bits[i] ^= bits[i-1];
-    bits2[i] ^= bits2[i-1];
+    comp.bits[i] ^= comp.bits[i-1];
+    comp.bits2[i] ^= comp.bits2[i-1];
   }

   if(comp.invertResult)
   {
     for(size_t i = 0; i < pt1.n_rows; i++)
     {
-      bits[i] = !bits[i];
-      bits2[i] = !bits2[i];
+      comp.bits[i] = !comp.bits[i];
+      comp.bits2[i] = !comp.bits2[i];
     }
   }

   for(size_t i = 0; i < pt1.n_rows; i++)
   {
-    if(bits[i] < bits2[i])
+    if(comp.bits[i] < comp.bits2[i])
       return -1;
-    if(bits[i] > bits2[i])
+    if(comp.bits[i] > comp.bits2[i])
       return 1;
   }

@@ -110,7 +108,7 @@ int RecursiveHilbertValue::ComparePoints(const arma::Col<ElemType> &pt1,

   comp.recursionLevel++;

-  if(bits[pt1.n_rows-1])
+  if(comp.bits[pt1.n_rows-1])
     comp.invertResult = !comp.invertResult;

   // Since the Hilbert curve is continuous we should permutate and intend
@@ -119,8 +117,8 @@ int RecursiveHilbertValue::ComparePoints(const arma::Col<ElemType> &pt1,
   {
     size_t j = comp.permutation[i];
     size_t j0 = comp.permutation[0];
-    if((pt1(j) > center(j) && !comp.inversion[j]) ||
-       (pt1(j) <= center(j) && !comp.inversion[j]))
+    if((pt1(j) > comp.center(j) && !comp.inversion[j]) ||
+       (pt1(j) <= comp.center(j) && comp.inversion[j]))
       comp.inversion[j0] = !comp.inversion[j0];
     else
     {
@@ -134,10 +132,10 @@ int RecursiveHilbertValue::ComparePoints(const arma::Col<ElemType> &pt1,
   // Choose an appropriate subhypercube
   for(size_t i = 0; i < pt1.n_rows; i++)
   {
-    if(pt1(i) > center(i))
-      comp.Lo(i) = center(i);
+    if(pt1(i) > comp.center(i))
+      comp.Lo(i) = comp.center(i);
     else
-      comp.Hi(i) = center(i);
+      comp.Hi(i) = comp.center(i);
   }
 
   return ComparePoints(pt1,pt2,comp);
diff --git a/src/mlpack/core/tree/rectangle_tree/typedef.hpp b/src/mlpack/core/tree/rectangle_tree/typedef.hpp
index cb87d07..6622099 100644
--- a/src/mlpack/core/tree/rectangle_tree/typedef.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/typedef.hpp
@@ -115,7 +115,7 @@ using RecursiveHilbertRTree = RectangleTree<MetricType,
 
 template<typename TreeType>
 using DiscreteHilbertRTreeAuxiliaryInformation =
-      HilbertRTreeAuxiliaryInformation<TreeType,RecursiveHilbertValue>;
+      HilbertRTreeAuxiliaryInformation<TreeType,DiscreteHilbertValue>;
 
 template<typename MetricType, typename StatisticType, typename MatType>
 using DiscreteHilbertRTree = RectangleTree<MetricType,
diff --git a/src/mlpack/tests/rectangle_tree_test.cpp b/src/mlpack/tests/rectangle_tree_test.cpp
index 8a86b05..bf390ba 100644
--- a/src/mlpack/tests/rectangle_tree_test.cpp
+++ b/src/mlpack/tests/rectangle_tree_test.cpp
@@ -706,8 +706,8 @@ void CheckHilbertOrdering(TreeType *tree)
     for(size_t i = 0; i < tree->NumPoints() - 1; i++)
       BOOST_REQUIRE_LE(
               tree->AuxiliaryInfo().LargestHilbertValue().ComparePoints(
-                arma::vec(tree->LocalDataset().col(i-1)),
-                arma::vec(tree->LocalDataset().col(i))),
+                arma::vec(tree->LocalDataset().col(i)),
+                arma::vec(tree->LocalDataset().col(i+1))),
               0);
 
     BOOST_REQUIRE_EQUAL(
@@ -721,8 +721,8 @@ void CheckHilbertOrdering(TreeType *tree)
     for(size_t i = 0; i < tree->NumChildren() - 1; i++)
       BOOST_REQUIRE_LE(
             tree->AuxiliaryInfo().LargestHilbertValue().CompareValues(tree,
-                tree->Children()[i-1]->AuxiliaryInfo().LargestHilbertValue(),
-                tree->Children()[i]->AuxiliaryInfo().LargestHilbertValue()),
+                tree->Children()[i]->AuxiliaryInfo().LargestHilbertValue(),
+                tree->Children()[i+1]->AuxiliaryInfo().LargestHilbertValue()),
             0);
 
     BOOST_REQUIRE_EQUAL(




More information about the mlpack-git mailing list