[mlpack-git] master: Improve expansion of node's bound. Do not include the overlapping buffer in the bounds of both nodes. (8b2dc10)
gitdub at mlpack.org
gitdub at mlpack.org
Thu Aug 18 13:39:59 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/0f4b25acd6aaa14294c044874ba6cc0751712baa...0a19d07bd39e6223991976474bc79671ba8aa0f0
>---------------------------------------------------------------
commit 8b2dc102fcf32cc645b6a257be5211117bb14dcf
Author: MarcosPividori <marcos.pividori at gmail.com>
Date: Thu Jul 7 12:40:32 2016 -0300
Improve expansion of node's bound. Do not include the overlapping buffer in the bounds of both nodes.
>---------------------------------------------------------------
8b2dc102fcf32cc645b6a257be5211117bb14dcf
src/mlpack/core/tree/spill_tree/spill_tree.hpp | 16 +++++++++++-
.../core/tree/spill_tree/spill_tree_impl.hpp | 29 +++++++++++++++-------
2 files changed, 35 insertions(+), 10 deletions(-)
diff --git a/src/mlpack/core/tree/spill_tree/spill_tree.hpp b/src/mlpack/core/tree/spill_tree/spill_tree.hpp
index fd7b227..db7f6c0 100644
--- a/src/mlpack/core/tree/spill_tree/spill_tree.hpp
+++ b/src/mlpack/core/tree/spill_tree/spill_tree.hpp
@@ -158,12 +158,14 @@ class SpillTree
*
* @param parent Parent of this node. Its dataset will be modified!
* @param points Vector of indexes of points to be included in this node.
+ * @param overlapIndex Index where the list of overlapping points starts.
* @param tau Overlapping size.
* @param maxLeafSize Size of each leaf in the tree.
* @param rho Balance threshold.
*/
SpillTree(SpillTree* parent,
std::vector<size_t>& points,
+ const size_t overlapIndex,
const double tau,
const size_t maxLeafSize = 20,
const double rho = 0.7);
@@ -361,11 +363,13 @@ class SpillTree
* Splits the current node, assigning its left and right children recursively.
*
* @param points Vector of indexes of points to be included in this node.
+ * @param overlapIndex Index where the list of overlapping points starts.
* @param maxLeafSize Maximum number of points held in a leaf.
* @param tau Overlapping size.
* @param rho Balance threshold.
*/
void SplitNode(std::vector<size_t>& points,
+ const size_t overlapIndex,
const size_t maxLeafSize,
const double tau,
const double rho);
@@ -380,6 +384,14 @@ class SpillTree
* @param points Vector of indexes of points to be included.
* @param leftPoints Indexes of points to be included in left child.
* @param rightPoints Indexes of points to be included in right child.
+ * @param overlapIndexLeft Index in leftPoints where the list of overlapping
+ points starts ( [overlapIndexLeft, leftPoints.size()) represents the
+ indexes of the points from the right node that are included in the left
+ node).
+ * @param overlapIndexRight Index in rightPoints where the list of overlapping
+ points starts ( [overlapIndexRight, rightPoints.size()) represents the
+ indexes of the points from the left node that are included in the right
+ node).
* @return Flag to know if the overlapping buffer was included.
*/
bool SplitPoints(const size_t splitDimension,
@@ -388,7 +400,9 @@ class SpillTree
const double rho,
const std::vector<size_t>& points,
std::vector<size_t>& leftPoints,
- std::vector<size_t>& rightPoints);
+ std::vector<size_t>& rightPoints,
+ size_t& overlapIndexLeft,
+ size_t& overlapIndexRight);
protected:
/**
* A default constructor. This is meant to only be used with
diff --git a/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp b/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp
index 6ce80bd..25c83aa 100644
--- a/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp
+++ b/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp
@@ -43,7 +43,7 @@ SpillTree(
points.push_back(i);
// Do the actual splitting of this node.
- SplitNode(points, maxLeafSize, tau, rho);
+ SplitNode(points, points.size(), maxLeafSize, tau, rho);
// Create the statistic depending on if we are a leaf or not.
stat = StatisticType(*this);
@@ -77,7 +77,7 @@ SpillTree(
points.push_back(i);
// Do the actual splitting of this node.
- SplitNode(points, maxLeafSize, tau, rho);
+ SplitNode(points, points.size(), maxLeafSize, tau, rho);
// Create the statistic depending on if we are a leaf or not.
stat = StatisticType(*this);
@@ -93,6 +93,7 @@ SpillTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
SpillTree(
SpillTree* parent,
std::vector<size_t>& points,
+ const size_t overlapIndex,
const double tau,
const size_t maxLeafSize,
const double rho) :
@@ -106,7 +107,7 @@ SpillTree(
dataset(&parent->Dataset()) // Point to the parent's dataset.
{
// Perform the actual splitting.
- SplitNode(points, maxLeafSize, tau, rho);
+ SplitNode(points, overlapIndex, maxLeafSize, tau, rho);
// Create the statistic depending on if we are a leaf or not.
stat = StatisticType(*this);
@@ -446,12 +447,14 @@ template<typename MetricType,
class SplitType>
void SpillTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
SplitNode(std::vector<size_t>& points,
+ const size_t overlapIndex,
const size_t maxLeafSize,
const double tau,
const double rho)
{
- // We need to expand the bounds of this node properly.
- for(size_t i = 0; i < points->size(); i++)
+ // We need to expand the bounds of this node properly, ignoring overlapping
+ // points (they will be included in the bound of the other node).
+ for(size_t i = 0; i < overlapIndex; i++)
bound |= dataset->cols(points[i], points[i]);
// Calculate the furthest descendant distance.
@@ -481,17 +484,20 @@ void SpillTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
}
std::vector<size_t> leftPoints, rightPoints;
+ size_t overlapIndexLeft, overlapIndexRight;
// Split the node.
overlappingNode = SplitPoints(splitDimension, splitVal, tau, rho, points,
- leftPoints, rightPoints);
+ leftPoints, rightPoints, overlapIndexLeft, overlapIndexRight);
// We don't need the information in points, so lets clean it.
std::vector<size_t>().swap(points);
// Now we will recursively split the children by calling their constructors
// (which perform this splitting process).
- left = new SpillTree(this, leftPoints, maxLeafSize, tau, rho);
- right = new SpillTree(this, rightPoints, maxLeafSize, tau, rho);
+ left = new SpillTree(this, leftPoints, overlapIndexLeft, maxLeafSize, tau,
+ rho);
+ right = new SpillTree(this, rightPoints, overlapIndexRight, maxLeafSize, tau,
+ rho);
// Update count number, to represent the number of descendant points.
count = left->NumDescendants() + right->NumDescendants();
@@ -523,7 +529,9 @@ bool SpillTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
const double rho,
const std::vector<size_t>& points,
std::vector<size_t>& leftPoints,
- std::vector<size_t>& rightPoints)
+ std::vector<size_t>& rightPoints,
+ size_t& overlapIndexLeft,
+ size_t& overlapIndexRight)
{
std::vector<size_t> leftFrontier, rightFrontier;
@@ -550,6 +558,9 @@ bool SpillTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
const double p2 = double (rightPoints.size() + leftFrontier.size()) /
points.size();
+ overlapIndexLeft = leftPoints.size();
+ overlapIndexRight = rightPoints.size();
+
if (p1 <= rho && p2 <= rho)
{
leftPoints.insert(leftPoints.end(), rightFrontier.begin(),
More information about the mlpack-git mailing list