[mlpack-git] master: Include overlapping points in each child's bounding box. (a4ac0b2)
gitdub at mlpack.org
gitdub at mlpack.org
Thu Aug 18 13:39:59 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/0f4b25acd6aaa14294c044874ba6cc0751712baa...0a19d07bd39e6223991976474bc79671ba8aa0f0
>---------------------------------------------------------------
commit a4ac0b2b67ceac6f348cc878da150eeaded4cc0f
Author: MarcosPividori <marcos.pividori at gmail.com>
Date: Mon Aug 1 17:10:36 2016 -0300
Include overlapping points in each child's bounding box.
>---------------------------------------------------------------
a4ac0b2b67ceac6f348cc878da150eeaded4cc0f
src/mlpack/core/tree/spill_tree/spill_tree.hpp | 16 +-----------
.../core/tree/spill_tree/spill_tree_impl.hpp | 29 +++++++---------------
2 files changed, 10 insertions(+), 35 deletions(-)
diff --git a/src/mlpack/core/tree/spill_tree/spill_tree.hpp b/src/mlpack/core/tree/spill_tree/spill_tree.hpp
index c48fbd5..29208ce 100644
--- a/src/mlpack/core/tree/spill_tree/spill_tree.hpp
+++ b/src/mlpack/core/tree/spill_tree/spill_tree.hpp
@@ -156,14 +156,12 @@ class SpillTree
*
* @param parent Parent of this node. Its dataset will be modified!
* @param points Vector of indexes of points to be included in this node.
- * @param overlapIndex Index where the list of overlapping points starts.
* @param tau Overlapping size.
* @param maxLeafSize Size of each leaf in the tree.
* @param rho Balance threshold.
*/
SpillTree(SpillTree* parent,
std::vector<size_t>& points,
- const size_t overlapIndex,
const double tau = 0,
const size_t maxLeafSize = 20,
const double rho = 0.7);
@@ -368,13 +366,11 @@ class SpillTree
* Splits the current node, assigning its left and right children recursively.
*
* @param points Vector of indexes of points to be included in this node.
- * @param overlapIndex Index where the list of overlapping points starts.
* @param maxLeafSize Maximum number of points held in a leaf.
* @param tau Overlapping size.
* @param rho Balance threshold.
*/
void SplitNode(std::vector<size_t>& points,
- const size_t overlapIndex,
const size_t maxLeafSize,
const double tau,
const double rho);
@@ -389,14 +385,6 @@ class SpillTree
* @param points Vector of indexes of points to be included.
* @param leftPoints Indexes of points to be included in left child.
* @param rightPoints Indexes of points to be included in right child.
- * @param overlapIndexLeft Index in leftPoints where the list of overlapping
- points starts ( [overlapIndexLeft, leftPoints.size()) represents the
- indexes of the points from the right node that are included in the left
- node).
- * @param overlapIndexRight Index in rightPoints where the list of overlapping
- points starts ( [overlapIndexRight, rightPoints.size()) represents the
- indexes of the points from the left node that are included in the right
- node).
* @return Flag to know if the overlapping buffer was included.
*/
bool SplitPoints(const size_t splitDimension,
@@ -405,9 +393,7 @@ class SpillTree
const double rho,
const std::vector<size_t>& points,
std::vector<size_t>& leftPoints,
- std::vector<size_t>& rightPoints,
- size_t& overlapIndexLeft,
- size_t& overlapIndexRight);
+ std::vector<size_t>& rightPoints);
protected:
/**
* A default constructor. This is meant to only be used with
diff --git a/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp b/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp
index e2f0f93..b8c7598 100644
--- a/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp
+++ b/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp
@@ -42,7 +42,7 @@ SpillTree(
points.push_back(i);
// Do the actual splitting of this node.
- SplitNode(points, points.size(), maxLeafSize, tau, rho);
+ SplitNode(points, maxLeafSize, tau, rho);
// Create the statistic depending on if we are a leaf or not.
stat = StatisticType(*this);
@@ -78,7 +78,7 @@ SpillTree(
points.push_back(i);
// Do the actual splitting of this node.
- SplitNode(points, points.size(), maxLeafSize, tau, rho);
+ SplitNode(points, maxLeafSize, tau, rho);
// Create the statistic depending on if we are a leaf or not.
stat = StatisticType(*this);
@@ -93,7 +93,6 @@ SpillTree<MetricType, StatisticType, MatType, SplitType>::
SpillTree(
SpillTree* parent,
std::vector<size_t>& points,
- const size_t overlapIndex,
const double tau,
const size_t maxLeafSize,
const double rho) :
@@ -110,7 +109,7 @@ SpillTree(
localDataset(false)
{
// Perform the actual splitting.
- SplitNode(points, overlapIndex, maxLeafSize, tau, rho);
+ SplitNode(points, maxLeafSize, tau, rho);
// Create the statistic depending on if we are a leaf or not.
stat = StatisticType(*this);
@@ -448,14 +447,12 @@ template<typename MetricType,
class SplitType>
void SpillTree<MetricType, StatisticType, MatType, SplitType>::
SplitNode(std::vector<size_t>& points,
- const size_t overlapIndex,
const size_t maxLeafSize,
const double tau,
const double rho)
{
- // We need to expand the bounds of this node properly, ignoring overlapping
- // points (they will be included in the bound of the other node).
- for (size_t i = 0; i < overlapIndex; i++)
+ // We need to expand the bounds of this node properly.
+ for (size_t i = 0; i < points.size(); i++)
bound |= dataset->cols(points[i], points[i]);
// Calculate the furthest descendant distance.
@@ -483,20 +480,17 @@ void SpillTree<MetricType, StatisticType, MatType, SplitType>::
}
std::vector<size_t> leftPoints, rightPoints;
- size_t overlapIndexLeft, overlapIndexRight;
// Split the node.
overlappingNode = SplitPoints(splitDimension, splitValue, tau, rho, points,
- leftPoints, rightPoints, overlapIndexLeft, overlapIndexRight);
+ leftPoints, rightPoints);
// We don't need the information in points, so lets clean it.
std::vector<size_t>().swap(points);
// Now we will recursively split the children by calling their constructors
// (which perform this splitting process).
- left = new SpillTree(this, leftPoints, overlapIndexLeft, tau, maxLeafSize,
- rho);
- right = new SpillTree(this, rightPoints, overlapIndexRight, tau, maxLeafSize,
- rho);
+ left = new SpillTree(this, leftPoints, tau, maxLeafSize, rho);
+ right = new SpillTree(this, rightPoints, tau, maxLeafSize, rho);
// Update count number, to represent the number of descendant points.
count = left->NumDescendants() + right->NumDescendants();
@@ -527,9 +521,7 @@ bool SpillTree<MetricType, StatisticType, MatType, SplitType>::
const double rho,
const std::vector<size_t>& points,
std::vector<size_t>& leftPoints,
- std::vector<size_t>& rightPoints,
- size_t& overlapIndexLeft,
- size_t& overlapIndexRight)
+ std::vector<size_t>& rightPoints)
{
std::vector<size_t> leftFrontier, rightFrontier;
@@ -556,9 +548,6 @@ bool SpillTree<MetricType, StatisticType, MatType, SplitType>::
const double p2 = double (rightPoints.size() + leftFrontier.size()) /
points.size();
- overlapIndexLeft = leftPoints.size();
- overlapIndexRight = rightPoints.size();
-
if ((p1 <= rho || rightFrontier.empty()) && (p2 <= rho || leftFrontier.empty()))
{
leftPoints.insert(leftPoints.end(), rightFrontier.begin(),
More information about the mlpack-git mailing list