[mlpack-git] master: Hilbert R tree fixes. (72f53d6)
gitdub at mlpack.org
gitdub at mlpack.org
Mon Jun 27 11:36:02 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/37fda23945b4f998cd5fa6ec011ae345236c8552...479eca0c625cc4255a3b1a354a4788dae10f1b01
>---------------------------------------------------------------
commit 72f53d600fb7b511af6d4e0acb2fecfa4bc17593
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date: Wed Jun 1 16:34:47 2016 +0300
Hilbert R tree fixes.
>---------------------------------------------------------------
72f53d600fb7b511af6d4e0acb2fecfa4bc17593
.../tree/rectangle_tree/discrete_hilbert_value.hpp | 12 ++-
.../rectangle_tree/discrete_hilbert_value_impl.hpp | 100 ++++++++++++++-------
.../hilbert_r_tree_auxiliary_information.hpp | 9 +-
.../hilbert_r_tree_auxiliary_information_impl.hpp | 28 +++---
.../tree/rectangle_tree/hilbert_r_tree_split.hpp | 2 +-
.../rectangle_tree/hilbert_r_tree_split_impl.hpp | 17 ++--
.../core/tree/rectangle_tree/rectangle_tree.hpp | 8 +-
.../tree/rectangle_tree/rectangle_tree_impl.hpp | 35 ++++----
.../rectangle_tree/recursive_hilbert_value.hpp | 14 ++-
.../recursive_hilbert_value_impl.hpp | 40 ++++-----
src/mlpack/core/tree/rectangle_tree/typedef.hpp | 2 +-
src/mlpack/tests/rectangle_tree_test.cpp | 8 +-
12 files changed, 175 insertions(+), 100 deletions(-)
diff --git a/src/mlpack/core/tree/rectangle_tree/discrete_hilbert_value.hpp b/src/mlpack/core/tree/rectangle_tree/discrete_hilbert_value.hpp
index 6bcee8f..3350100 100644
--- a/src/mlpack/core/tree/rectangle_tree/discrete_hilbert_value.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/discrete_hilbert_value.hpp
@@ -92,7 +92,7 @@ class DiscreteHilbertValue
* @param tree Not used
* @param val The number of the point to compare with.
*/
- template<typename TreeType,typename ElemType>
+ template<typename TreeType>
int CompareWith(TreeType *tree, const size_t point);
/**
@@ -146,12 +146,16 @@ class DiscreteHilbertValue
void UpdateLargestValue(TreeType *node);
//! Copy the largest Hilbert value.
- DiscreteHilbertValue operator = (DiscreteHilbertValue &val);
+ DiscreteHilbertValue operator = (const DiscreteHilbertValue &val);
//! Return the largest Hilbert value
std::list<arma::Col<uint64_t>>::iterator LargestValue() const
{ return largestValue; }
+ //! Modify the largest Hilbert value
+ std::list<arma::Col<uint64_t>>::iterator &LargestValue()
+ { return largestValue; }
+
//! Modify the local dataset
std::list<arma::Col<uint64_t>> *LocalDataset() { return localDataset; }
//! Modify the dataset
@@ -182,6 +186,10 @@ class DiscreteHilbertValue
*/
static int CompareValues(const arma::Col<uint64_t> &value1,
const arma::Col<uint64_t> &value2);
+ /**
+ * Returns true if the node has the largest Hilbert value.
+ */
+ bool HasValue();
};
} // namespace tree
} // namespace mlpack
diff --git a/src/mlpack/core/tree/rectangle_tree/discrete_hilbert_value_impl.hpp b/src/mlpack/core/tree/rectangle_tree/discrete_hilbert_value_impl.hpp
index ce4df87..787ae8c 100644
--- a/src/mlpack/core/tree/rectangle_tree/discrete_hilbert_value_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/discrete_hilbert_value_impl.hpp
@@ -33,31 +33,45 @@ template<typename TreeType>
DiscreteHilbertValue::DiscreteHilbertValue(const TreeType *tree) :
dataset(tree->Parent() ?
tree->Parent()->AuxiliaryInfo().LargestHilbertValue().Dataset() :
- new arma::Mat<uint64_t>(tree->Dataset()->n_rows,
- tree->MaxLeafSize()+1)),
+ new arma::Mat<uint64_t>(tree->Dataset().n_rows,
+ tree->Dataset().n_cols)),
ownsDataset(!tree->Parent()),
localDataset(new std::list<arma::Col<uint64_t>>()),
largestValue(localDataset->end())
{
+ typedef typename TreeType::ElemType ElemType;
// Calculate the Hilbert value for all points
if(!tree->Parent())
{
- for(size_t i = 0; i < tree->Dataset()->n_rows; i++)
- dataset->col(i) = CalculateValue(tree->Dataset()->col(i));
+ for(size_t i = 0; i < tree->Dataset().n_cols; i++)
+ dataset->col(i) = CalculateValue((arma::Col<ElemType>)tree->Dataset().col(i));
}
-};
+}
template<typename TreeType>
DiscreteHilbertValue::DiscreteHilbertValue(const TreeType &other) :
dataset(other.AuxiliaryInfo().LargestHilbertValue().Dataset()),
- ownsDataset(!other.Parent()),
- localDataset(other.AuxiliaryInfo().LargestHilbertValue().LocalDataset()),
+ ownsDataset(false),
+ localDataset(new std::list<arma::Col<uint64_t>>()),
largestValue(other.AuxiliaryInfo().LargestHilbertValue().LargestValue())
{
-};
+ if(other.IsLeaf())
+ {
+ std::list<arma::Col<uint64_t>> *otherDataset =
+ other.AuxiliaryInfo().LargestHilbertValue().LocalDataset();
+ for(std::list<arma::Col<uint64_t>>::iterator it = otherDataset->begin(); it != otherDataset->end(); it++)
+ {
+ localDataset->push_back(*it);
+ }
+ largestValue = localDataset->end();
+ if(otherDataset->size() > 0)
+ largestValue--;
+ }
+}
template<typename ElemType>
-arma::Col<uint64_t> CalculateValue(const arma::Col<ElemType> &pt)
+arma::Col<uint64_t> DiscreteHilbertValue::
+CalculateValue(const arma::Col<ElemType> &pt)
{
arma::Col<uint64_t> res(pt.n_rows);
constexpr int order = 64; // The number of bits that we can store
@@ -86,19 +100,19 @@ arma::Col<uint64_t> CalculateValue(const arma::Col<ElemType> &pt)
normalizedVal /= tmp;
}
// Extract the mantissa
- uint64_t tmp = 1 << numMantBits;
- res(i) = std::floor(normalizedVal / numMantBits);
+ uint64_t tmp = (uint64_t)1 << numMantBits;
+ res(i) = std::floor(normalizedVal / tmp);
// Add the exponent
- res(i) |= (e - std::numeric_limits<ElemType>::min_exponent) << numMantBits;
+ res(i) |= ((uint64_t)(e - std::numeric_limits<ElemType>::min_exponent)) << numMantBits;
// Negative values should be inverted
if(sgn)
- res(i) = 1 << (order - 1) - 1 - res(i);
+ res(i) = ((uint64_t)1 << (order - 1)) - 1 - res(i);
else
- res(i) |= 1 << (order - 1);
+ res(i) |= (uint64_t)1 << (order - 1);
}
- uint64_t M = 1 << (order - 1);
+ uint64_t M = (uint64_t)1 << (order - 1);
// Since the Hilbert curve is continuous we should permutate and intend
// coordinate axes depending on the position of the point
@@ -176,14 +190,21 @@ int DiscreteHilbertValue::ComparePoints(const arma::Col<ElemType> &pt1,
}
template<typename TreeType>
-int DiscreteHilbertValue::CompareValues(TreeType *tree,
+int DiscreteHilbertValue::CompareValues(TreeType *,
DiscreteHilbertValue &val1, DiscreteHilbertValue &val2)
{
+ if(val1.HasValue() && !val2.HasValue())
+ return 1;
+ else if(!val1.HasValue() && val2.HasValue())
+ return -1;
+ else if(!val1.HasValue() && !val2.HasValue())
+ return 0;
+
return CompareValues(*val1.LargestValue(),*val2.LargestValue());
}
template<typename TreeType>
-int DiscreteHilbertValue::CompareWith(TreeType *tree, DiscreteHilbertValue &val)
+int DiscreteHilbertValue::CompareWith(TreeType *, DiscreteHilbertValue &val)
{
return CompareValues(*largestValue,*val.LargestValue());
}
@@ -194,13 +215,18 @@ int DiscreteHilbertValue::CompareWith(TreeType *tree,
{
arma::Col<uint64_t> val = CalculateValue(pt);
+ if(!HasValue())
+ return -1;
+
return CompareValues(*largestValue,val);
}
-template<typename TreeType,typename ElemType>
-int DiscreteHilbertValue::CompareWith(TreeType *tree,
+template<typename TreeType>
+int DiscreteHilbertValue::CompareWith(TreeType *,
const size_t point)
{
+ if(!HasValue())
+ return -1;
return CompareValues(*largestValue,dataset->col(point));
}
@@ -208,7 +234,7 @@ template<typename TreeType>
size_t DiscreteHilbertValue::InsertPoint(TreeType *node, const size_t point)
{
size_t i = 0;
- std::list<arma::Col<uint64_t>>::iterator it;
+ std::list<arma::Col<uint64_t>>::iterator it = localDataset->end();
if(node->IsLeaf())
{
@@ -241,7 +267,7 @@ size_t DiscreteHilbertValue::InsertPoint(TreeType *node, const size_t point)
{
// We do not update the largest Hilbert value since we do not know the
// iterator
- if(*largestValue < dataset->col(point))
+ if(CompareValues(*largestValue,dataset->col(point)) < 0)
largestValue = localDataset->end();
}
@@ -308,16 +334,22 @@ void DiscreteHilbertValue::Copy(TreeType *dst, TreeType *src)
DiscreteHilbertValue &srcVal = src->AuxiliaryInfo().LargestHilbertValue();
// Copy the largest Hilbert value and the local dataset
- dst.LargestValue() = src.LargestValue();
+ dstVal.LargestValue() = srcVal.LargestValue();
- dst.LocalDataset()->clear();
- std::list<arma::Col<uint64_t>>::iterator it = src.LocalDataset()->begin();
- for( ; it != src.LocalDataset()->end(); it++)
- dst.LocalDataset()->push_back(*it);
+ dstVal.LocalDataset()->clear();
+ std::list<arma::Col<uint64_t>>::iterator it = srcVal.LocalDataset()->begin();
+ for( ; it != srcVal.LocalDataset()->end(); it++)
+ dstVal.LocalDataset()->push_back(*it);
+ if(dst->IsLeaf())
+ {
+ dstVal.LargestValue() = dstVal.LocalDataset()->end();
+ if(dst->NumPoints() > 0)
+ dstVal.LargestValue()--;
+ }
}
-inline DiscreteHilbertValue DiscreteHilbertValue::operator = (DiscreteHilbertValue &val)
+inline DiscreteHilbertValue DiscreteHilbertValue::operator = (const DiscreteHilbertValue &val)
{
// Copy the largest Hilbert value
largestValue = val.LargestValue();
@@ -340,21 +372,27 @@ void DiscreteHilbertValue::UpdateLargestValue(TreeType *node)
for(size_t i = 0; i < node->NumPoints(); i++)
localDataset->push_back(dataset->col(node->Points()[i]));
largestValue = localDataset->end();
- localDataset--;
+ largestValue--;
}
else
{
+ if(localDataset->size() > 0)
+ localDataset->clear();
// Update the largest Hilbert value;
if(node->NumChildren() == 0)
largestValue = localDataset->end();
- else if(node->Children()[node->NumChildren()-1]->AuxiliaryInfo().LargestHilbertValue().LargestValue() !=
- node->Children()[node->NumChildren()-1]->AuxiliaryInfo().LargestHilbertValue().LocalDataset()->end())
- largestValue = node->Children()[node->NumChildren()-1]->AuxiliaryInfo().LargestHilbertValue();
+ else if(node->Children()[node->NumChildren()-1]->AuxiliaryInfo().LargestHilbertValue().HasValue())
+ largestValue = node->Children()[node->NumChildren()-1]->AuxiliaryInfo().LargestHilbertValue().LargestValue();
else
largestValue = localDataset->end();
}
}
+inline bool DiscreteHilbertValue::HasValue()
+{
+ return largestValue != localDataset->end();
+}
+
} // namespace tree
} // namespace mlpack
diff --git a/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_auxiliary_information.hpp b/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_auxiliary_information.hpp
index 4d3070c..01bbf99 100644
--- a/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_auxiliary_information.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_auxiliary_information.hpp
@@ -31,6 +31,9 @@ class HilbertRTreeAuxiliaryInformation
*/
HilbertRTreeAuxiliaryInformation(const TreeType &other);
+ //! Free memory
+ ~HilbertRTreeAuxiliaryInformation();
+
/**
* The Hilbert R tree requires to insert points according to their
* Hilbert value. This method should take care of it.
@@ -92,13 +95,13 @@ class HilbertRTreeAuxiliaryInformation
private:
//! The largest Hilbert value of a point enclosed by the node.
- HilbertValue largestHilbertValue;
+ HilbertValue *largestHilbertValue;
public:
//! Return the largest Hilbert value of a point covered by the node.
- HilbertValue LargestHilbertValue() const { return largestHilbertValue; }
+ HilbertValue& LargestHilbertValue() const { return *largestHilbertValue; }
//! Modify the largest Hilbert value of a point covered by the node.
- HilbertValue& LargestHilbertValue() { return largestHilbertValue; }
+ HilbertValue& LargestHilbertValue() { return *largestHilbertValue; }
/**
* Serialize the information.
diff --git a/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_auxiliary_information_impl.hpp b/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_auxiliary_information_impl.hpp
index bd190d5..0341d30 100644
--- a/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_auxiliary_information_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_auxiliary_information_impl.hpp
@@ -25,7 +25,7 @@ HilbertRTreeAuxiliaryInformation()
template<typename TreeType,typename HilbertValue>
HilbertRTreeAuxiliaryInformation<TreeType,HilbertValue>::
HilbertRTreeAuxiliaryInformation(const TreeType *node) :
- largestHilbertValue(node)
+ largestHilbertValue(new HilbertValue(node))
{
};
@@ -33,12 +33,19 @@ HilbertRTreeAuxiliaryInformation(const TreeType *node) :
template<typename TreeType,typename HilbertValue>
HilbertRTreeAuxiliaryInformation<TreeType,HilbertValue>::
HilbertRTreeAuxiliaryInformation(const TreeType &other) :
- largestHilbertValue(other)
+ largestHilbertValue(new HilbertValue(other))
{
};
template<typename TreeType,typename HilbertValue>
+HilbertRTreeAuxiliaryInformation<TreeType,HilbertValue>::
+~HilbertRTreeAuxiliaryInformation()
+{
+ delete largestHilbertValue;
+}
+
+template<typename TreeType,typename HilbertValue>
bool HilbertRTreeAuxiliaryInformation<TreeType,HilbertValue>::
HandlePointInsertion(TreeType *node,const size_t point)
{
@@ -46,7 +53,7 @@ HandlePointInsertion(TreeType *node,const size_t point)
{
// Get the position at which the point should be inserted
// Update the largest Hilbert value of the node
- size_t pos = largestHilbertValue.InsertPoint(node,point);
+ size_t pos = largestHilbertValue->InsertPoint(node,point);
// Move points
for(size_t i = node->NumPoints(); i > pos; i--)
@@ -60,7 +67,7 @@ HandlePointInsertion(TreeType *node,const size_t point)
node->Count()++;
}
else
- largestHilbertValue.InsertPoint(node,point); // Update LHV
+ largestHilbertValue->InsertPoint(node,point); // Update LHV
return true;
}
@@ -90,10 +97,10 @@ HandleNodeInsertion(TreeType *node,TreeType *nodeToInsert,bool insertionLevel)
nodeToInsert->Parent() = node;
// Update the largest Hilbert value
- largestHilbertValue.InsertNode(nodeToInsert);
+ largestHilbertValue->InsertNode(nodeToInsert);
}
else
- largestHilbertValue.InsertNode(nodeToInsert); // Update LHV
+ largestHilbertValue->InsertNode(nodeToInsert); // Update LHV
return true;
}
@@ -103,7 +110,7 @@ bool HilbertRTreeAuxiliaryInformation<TreeType,HilbertValue>::
HandlePointDeletion(TreeType *node,const size_t localIndex)
{
// Update the largest Hilbert value
- largestHilbertValue.DeletePoint(node,localIndex);
+ largestHilbertValue->DeletePoint(node,localIndex);
for(size_t i = localIndex + 1; localIndex < node->NumPoints(); i++)
{
@@ -119,7 +126,7 @@ bool HilbertRTreeAuxiliaryInformation<TreeType,HilbertValue>::
HandleNodeRemoval(TreeType *node,const size_t nodeIndex)
{
// Update the largest Hilbert value
- largestHilbertValue.RemoveNode(node,nodeIndex);
+ largestHilbertValue->RemoveNode(node,nodeIndex);
for(size_t i = nodeIndex + 1; nodeIndex < node->NumChildren(); i++)
node->Children()[i-1] = node->Children()[i];
@@ -139,7 +146,8 @@ UpdateAuxiliaryInfo(TreeType *node)
if(HilbertValue::CompareValues(largestHilbertValue,
child->AuxiliaryInfo().LargestHilbertValue()) < 0)
{
- largestHilbertValue = child->AuxiliaryInfo().LargestHilbertValue();
+ largestHilbertValue->Copy(node,child);
+// largestHilbertValue = child->AuxiliaryInfo().LargestHilbertValue();
return true;
}
return false;
@@ -149,7 +157,7 @@ template<typename TreeType,typename HilbertValue>
void HilbertRTreeAuxiliaryInformation<TreeType,HilbertValue>::
Copy(TreeType *dst,TreeType *src)
{
- largestHilbertValue.Copy(dst,src);
+ largestHilbertValue->Copy(dst,src);
}
template<typename TreeType,typename HilbertValue>
diff --git a/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split.hpp b/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split.hpp
index e046a74..f830c13 100644
--- a/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split.hpp
@@ -72,7 +72,7 @@ class HilbertRTreeSplit
* @param lastSibling The last cooperating sibling.
*/
template<typename TreeType>
- static void RedistributePointsEvenly(const TreeType *parent,
+ static void RedistributePointsEvenly(TreeType *parent,
size_t firstSibling,size_t lastSibling);
};
diff --git a/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split_impl.hpp
index 0f0ea4b..2315e88 100644
--- a/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/hilbert_r_tree_split_impl.hpp
@@ -65,7 +65,7 @@ SplitLeafNode(TreeType *tree,std::vector<bool>& relevels)
iTree + splitOrder : parent->NumChildren() - 1);
firstSibling = (lastSibling > splitOrder ? lastSibling - splitOrder : 0);
- assert(lastSibling - firstSibling == splitOrder);
+ assert(lastSibling - firstSibling <= splitOrder);
assert(firstSibling >= 0);
assert(lastSibling < parent->NumChildren());
@@ -130,7 +130,7 @@ SplitNonLeafNode(TreeType *tree,std::vector<bool>& relevels)
firstSibling = (lastSibling > splitOrder ?
lastSibling - splitOrder : 0);
- assert(lastSibling - firstSibling == splitOrder);
+ assert(lastSibling - firstSibling <= splitOrder);
assert(firstSibling >= 0);
assert(lastSibling < parent->NumChildren());
@@ -251,14 +251,13 @@ RedistributeNodesEvenly(const TreeType *parent,
parent->Children()[i]->MaxNumChildren());
// Fix the largest Hilbert value of the sibling.
- parent->Children()[i]->AuxiliaryInfo().LargestHilbertValue() =
- children[iChild-1]->AuxiliaryInfo().LargestHilbertValue();
+ parent->Children()[i]->AuxiliaryInfo().LargestHilbertValue().UpdateLargestValue(parent->Children()[i]);
}
}
template<typename TreeType>
void HilbertRTreeSplit::
-RedistributePointsEvenly(const TreeType *parent,
+RedistributePointsEvenly(TreeType *parent,
size_t firstSibling, size_t lastSibling)
{
size_t numPoints = 0;
@@ -320,6 +319,14 @@ RedistributePointsEvenly(const TreeType *parent,
// Fix the largest Hilbert value of the sibling.
parent->Children()[i]->AuxiliaryInfo().LargestHilbertValue().UpdateLargestValue(parent->Children()[i]);
}
+
+ TreeType *root = parent;
+
+ while(root != NULL)
+ {
+ root->AuxiliaryInfo().LargestHilbertValue().UpdateLargestValue(root);
+ root = root->Parent();
+ }
}
} // namespace tree
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
index b2fa544..84b728e 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
@@ -94,7 +94,7 @@ class RectangleTree
//! The local dataset
MatType* localDataset;
//! A tree-specific information
- AuxiliaryInformationType<RectangleTree> auxiliaryInfo;
+ AuxiliaryInformationType<RectangleTree> *auxiliaryInfo;
public:
//! A single traverser for rectangle type trees. See
@@ -294,11 +294,11 @@ class RectangleTree
StatisticType& Stat() { return stat; }
//! Return the auxiliary information object of this node.
- const AuxiliaryInformationType<RectangleTree>& AuxiliaryInfo() const
- { return auxiliaryInfo; }
+ const AuxiliaryInformationType<RectangleTree> &AuxiliaryInfo() const
+ { return *auxiliaryInfo; }
//! Modify the split object of this node.
AuxiliaryInformationType<RectangleTree>& AuxiliaryInfo()
- { return auxiliaryInfo; }
+ { return *auxiliaryInfo; }
//! Return whether or not this node is a leaf (true if it has no children).
bool IsLeaf() const;
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
index 3b397a8..c9be471 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
@@ -50,7 +50,7 @@ RectangleTree(const MatType& data,
{
stat = StatisticType(*this);
- auxiliaryInfo = AuxiliaryInformationType<RectangleTree>(this);
+ auxiliaryInfo = new AuxiliaryInformationType<RectangleTree>(this);
// For now, just insert the points in order.
RectangleTree* root = this;
@@ -92,7 +92,7 @@ RectangleTree(MatType&& data,
{
stat = StatisticType(*this);
- auxiliaryInfo = AuxiliaryInformationType<RectangleTree>(this);
+ auxiliaryInfo = new AuxiliaryInformationType<RectangleTree>(this);
// For now, just insert the points in order.
RectangleTree* root = this;
@@ -132,7 +132,7 @@ RectangleTree(
maxLeafSize + 1)))
{
stat = StatisticType(*this);
- auxiliaryInfo = AuxiliaryInformationType<RectangleTree>(this);
+ auxiliaryInfo = new AuxiliaryInformationType<RectangleTree>(this);
}
/**
@@ -166,7 +166,7 @@ RectangleTree(
points(other.Points()),
localDataset(NULL)
{
- auxiliaryInfo = AuxiliaryInformationType<RectangleTree>(other);
+ auxiliaryInfo = new AuxiliaryInformationType<RectangleTree>(other);
if (deepCopy)
{
if (numChildren > 0)
@@ -225,6 +225,7 @@ RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
AuxiliaryInformationType>::
~RectangleTree()
{
+ delete auxiliaryInfo;
for (size_t i = 0; i < numChildren; i++)
delete children[i];
@@ -297,7 +298,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
// If this is a leaf node, we stop here and add the point.
if (numChildren == 0)
{
- if(!auxiliaryInfo.HandlePointInsertion(this,point))
+ if(!auxiliaryInfo->HandlePointInsertion(this,point))
{
localDataset->col(count) = dataset->col(point);
points[count++] = point;
@@ -308,7 +309,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
// If it is not a leaf node, we use the DescentHeuristic to choose a child
// to which we recurse.
- auxiliaryInfo.HandlePointInsertion(this,point);
+ auxiliaryInfo->HandlePointInsertion(this,point);
const size_t descentNode = DescentType::ChooseDescentNode(this, point);
children[descentNode]->InsertPoint(point, lvls);
}
@@ -335,7 +336,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
// If this is a leaf node, we stop here and add the point.
if (numChildren == 0)
{
- if(!auxiliaryInfo.HandlePointInsertion(this,point))
+ if(!auxiliaryInfo->HandlePointInsertion(this,point))
{
localDataset->col(count) = dataset->col(point);
points[count++] = point;
@@ -346,7 +347,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
// If it is not a leaf node, we use the DescentHeuristic to choose a child
// to which we recurse.
- auxiliaryInfo.HandlePointInsertion(this,point);
+ auxiliaryInfo->HandlePointInsertion(this,point);
const size_t descentNode = DescentType::ChooseDescentNode(this,point);
children[descentNode]->InsertPoint(point, relevels);
}
@@ -375,7 +376,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
bound |= node->Bound();
if (level == TreeDepth())
{
- if(!auxiliaryInfo.HandleNodeInsertion(this,node,true))
+ if(!auxiliaryInfo->HandleNodeInsertion(this,node,true))
{
children[numChildren++] = node;
node->Parent() = this;
@@ -384,7 +385,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
}
else
{
- auxiliaryInfo.HandleNodeInsertion(this,node,false);
+ auxiliaryInfo->HandleNodeInsertion(this,node,false);
const size_t descentNode = DescentType::ChooseDescentNode(this, node);
children[descentNode]->InsertNode(node, level, relevels);
}
@@ -420,7 +421,7 @@ bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
{
if (points[i] == point)
{
- if(!auxiliaryInfo.HandlePointDeletion(this,i))
+ if(!auxiliaryInfo->HandlePointDeletion(this,i))
{
localDataset->col(i) = localDataset->col(--count); // Decrement count.
points[i] = points[count];
@@ -460,7 +461,7 @@ bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
{
if (points[i] == point)
{
- if(!auxiliaryInfo.HandlePointDeletion(this,i))
+ if(!auxiliaryInfo->HandlePointDeletion(this,i))
{
localDataset->col(i) = localDataset->col(--count);
points[i] = points[count];
@@ -498,7 +499,7 @@ bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
{
if (children[i] == node)
{
- if(!auxiliaryInfo.HandleNodeRemoval(this,i))
+ if(!auxiliaryInfo->HandleNodeRemoval(this,i))
{
children[i] = children[--numChildren]; // Decrement numChildren.
}
@@ -843,7 +844,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
if (parent->Children()[j] == this)
{
// Decrement numChildren.
- if(!auxiliaryInfo.HandleNodeRemoval(parent,j))
+ if(!auxiliaryInfo->HandleNodeRemoval(parent,j))
{
parent->Children()[j] = parent->Children()[--parent->NumChildren()];
}
@@ -911,7 +912,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
localDataset->col(i) = child->LocalDataset().col(i);
}
- auxiliaryInfo.Copy(this,child);
+ auxiliaryInfo->Copy(this,child);
count = child->Count();
child->SoftDelete();
@@ -921,11 +922,11 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
// If we didn't delete it, shrink the bound if we need to.
if (usePoint &&
- (ShrinkBoundForPoint(point) || auxiliaryInfo.UpdateAuxiliaryInfo(this)) &&
+ (ShrinkBoundForPoint(point) || auxiliaryInfo->UpdateAuxiliaryInfo(this)) &&
parent != NULL)
parent->CondenseTree(point, relevels, usePoint);
else if (!usePoint &&
- (ShrinkBoundForBound(bound) || auxiliaryInfo.UpdateAuxiliaryInfo(this)) &&
+ (ShrinkBoundForBound(bound) || auxiliaryInfo->UpdateAuxiliaryInfo(this)) &&
parent != NULL)
parent->CondenseTree(point, relevels, usePoint);
}
diff --git a/src/mlpack/core/tree/rectangle_tree/recursive_hilbert_value.hpp b/src/mlpack/core/tree/rectangle_tree/recursive_hilbert_value.hpp
index 2f90ed5..d020461 100644
--- a/src/mlpack/core/tree/rectangle_tree/recursive_hilbert_value.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/recursive_hilbert_value.hpp
@@ -54,14 +54,23 @@ class RecursiveHilbertValue
//! Indicates that the axis should be inverted
std::vector<bool> inversion;
//! Indicates that the result should be inverted
+ arma::Col<ElemType> center;
+ arma::Col<ElemType> vec;
+ std::vector<int> bits;
+ std::vector<int> bits2;
bool invertResult;
int recursionLevel;
+
tagCompareStruct(size_t dim) :
Lo(dim),
Hi(dim),
permutation(dim),
inversion(dim),
+ center(dim),
+ vec(dim),
+ bits(dim),
+ bits2(dim),
invertResult(false),
recursionLevel(0)
{
@@ -186,7 +195,10 @@ class RecursiveHilbertValue
void UpdateLargestValue(TreeType *node);
//! Return the largest Hilbert value
- size_t LargestValue() const { return largestValue; }
+ ptrdiff_t LargestValue() const { return largestValue; }
+
+ //! Modify the largest Hilbert value
+ ptrdiff_t& LargestValue() { return largestValue; }
private:
//! The largest Hilbert value i.e. the number of the point in the dataset.
diff --git a/src/mlpack/core/tree/rectangle_tree/recursive_hilbert_value_impl.hpp b/src/mlpack/core/tree/rectangle_tree/recursive_hilbert_value_impl.hpp
index b78cd76..e95daaa 100644
--- a/src/mlpack/core/tree/rectangle_tree/recursive_hilbert_value_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/recursive_hilbert_value_impl.hpp
@@ -63,45 +63,43 @@ int RecursiveHilbertValue::ComparePoints(const arma::Col<ElemType> &pt1,
const arma::Col<ElemType> &pt2,
CompareStruct<ElemType> &comp)
{
- arma::Col<ElemType> center = comp.Hi * 0.5;
- arma::Col<ElemType> vec = comp.Lo * 0.5;
- std::vector<int> bits(pt1.n_rows,0);
- std::vector<int> bits2(pt1.n_rows,0);
+ comp.center = comp.Hi * 0.5;
+ comp.vec = comp.Lo * 0.5;
- center += vec;
+ comp.center += comp.vec;
// Get bits in order to use the Gray code
for(size_t i = 0; i < pt1.n_rows; i++)
{
size_t j = comp.permutation[i];
- bits[i] = (pt1(j) > center(j) && !comp.inversion[j]) ||
- (pt1(j) <= center(j) && !comp.inversion[j]);
+ comp.bits[i] = (pt1(j) > comp.center(j) && !comp.inversion[j]) ||
+ (pt1(j) <= comp.center(j) && !comp.inversion[j]);
- bits2[i] = (pt2(j) > center(j) && !comp.inversion[j]) ||
- (pt2(j) <= center(j) && !comp.inversion[j]);
+ comp.bits2[i] = (pt2(j) > comp.center(j) && !comp.inversion[j]) ||
+ (pt2(j) <= comp.center(j) && !comp.inversion[j]);
}
// Gray encode
for(size_t i = 1; i < pt1.n_rows; i++)
{
- bits[i] ^= bits[i-1];
- bits2[i] ^= bits2[i-1];
+ comp.bits[i] ^= comp.bits[i-1];
+ comp.bits2[i] ^= comp.bits2[i-1];
}
if(comp.invertResult)
{
for(size_t i = 0; i < pt1.n_rows; i++)
{
- bits[i] = !bits[i];
- bits2[i] = !bits2[i];
+ comp.bits[i] = !comp.bits[i];
+ comp.bits2[i] = !comp.bits2[i];
}
}
for(size_t i = 0; i < pt1.n_rows; i++)
{
- if(bits[i] < bits2[i])
+ if(comp.bits[i] < comp.bits2[i])
return -1;
- if(bits[i] > bits2[i])
+ if(comp.bits[i] > comp.bits2[i])
return 1;
}
@@ -110,7 +108,7 @@ int RecursiveHilbertValue::ComparePoints(const arma::Col<ElemType> &pt1,
comp.recursionLevel++;
- if(bits[pt1.n_rows-1])
+ if(comp.bits[pt1.n_rows-1])
comp.invertResult = !comp.invertResult;
// Since the Hilbert curve is continuous we should permutate and intend
@@ -119,8 +117,8 @@ int RecursiveHilbertValue::ComparePoints(const arma::Col<ElemType> &pt1,
{
size_t j = comp.permutation[i];
size_t j0 = comp.permutation[0];
- if((pt1(j) > center(j) && !comp.inversion[j]) ||
- (pt1(j) <= center(j) && !comp.inversion[j]))
+ if((pt1(j) > comp.center(j) && !comp.inversion[j]) ||
+ (pt1(j) <= comp.center(j) && !comp.inversion[j]))
comp.inversion[j0] = !comp.inversion[j0];
else
{
@@ -134,10 +132,10 @@ int RecursiveHilbertValue::ComparePoints(const arma::Col<ElemType> &pt1,
// Choose an appropriate subhypercube
for(size_t i = 0; i < pt1.n_rows; i++)
{
- if(pt1(i) > center(i))
- comp.Lo(i) = center(i);
+ if(pt1(i) > comp.center(i))
+ comp.Lo(i) = comp.center(i);
else
- comp.Hi(i) = center(i);
+ comp.Hi(i) = comp.center(i);
}
return ComparePoints(pt1,pt2,comp);
diff --git a/src/mlpack/core/tree/rectangle_tree/typedef.hpp b/src/mlpack/core/tree/rectangle_tree/typedef.hpp
index cb87d07..6622099 100644
--- a/src/mlpack/core/tree/rectangle_tree/typedef.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/typedef.hpp
@@ -115,7 +115,7 @@ using RecursiveHilbertRTree = RectangleTree<MetricType,
template<typename TreeType>
using DiscreteHilbertRTreeAuxiliaryInformation =
- HilbertRTreeAuxiliaryInformation<TreeType,RecursiveHilbertValue>;
+ HilbertRTreeAuxiliaryInformation<TreeType,DiscreteHilbertValue>;
template<typename MetricType, typename StatisticType, typename MatType>
using DiscreteHilbertRTree = RectangleTree<MetricType,
diff --git a/src/mlpack/tests/rectangle_tree_test.cpp b/src/mlpack/tests/rectangle_tree_test.cpp
index 8a86b05..bf390ba 100644
--- a/src/mlpack/tests/rectangle_tree_test.cpp
+++ b/src/mlpack/tests/rectangle_tree_test.cpp
@@ -706,8 +706,8 @@ void CheckHilbertOrdering(TreeType *tree)
for(size_t i = 0; i < tree->NumPoints() - 1; i++)
BOOST_REQUIRE_LE(
tree->AuxiliaryInfo().LargestHilbertValue().ComparePoints(
- arma::vec(tree->LocalDataset().col(i-1)),
- arma::vec(tree->LocalDataset().col(i))),
+ arma::vec(tree->LocalDataset().col(i)),
+ arma::vec(tree->LocalDataset().col(i+1))),
0);
BOOST_REQUIRE_EQUAL(
@@ -721,8 +721,8 @@ void CheckHilbertOrdering(TreeType *tree)
for(size_t i = 0; i < tree->NumChildren() - 1; i++)
BOOST_REQUIRE_LE(
tree->AuxiliaryInfo().LargestHilbertValue().CompareValues(tree,
- tree->Children()[i-1]->AuxiliaryInfo().LargestHilbertValue(),
- tree->Children()[i]->AuxiliaryInfo().LargestHilbertValue()),
+ tree->Children()[i]->AuxiliaryInfo().LargestHilbertValue(),
+ tree->Children()[i+1]->AuxiliaryInfo().LargestHilbertValue()),
0);
BOOST_REQUIRE_EQUAL(
More information about the mlpack-git mailing list