[mlpack-git] master: Record splitDimension and splitValue. (2af816a)
gitdub at mlpack.org
gitdub at mlpack.org
Thu Aug 18 13:39:28 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/0f4b25acd6aaa14294c044874ba6cc0751712baa...0a19d07bd39e6223991976474bc79671ba8aa0f0
>---------------------------------------------------------------
commit 2af816aed2a0bcf1f07143a6d1dc6fc14694c1dd
Author: MarcosPividori <marcos.pividori at gmail.com>
Date: Tue Jul 26 23:33:45 2016 -0300
Record splitDimension and splitValue.
>---------------------------------------------------------------
2af816aed2a0bcf1f07143a6d1dc6fc14694c1dd
src/mlpack/core/tree/spill_tree/spill_tree.hpp | 10 ++++++++++
src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp | 18 +++++++++++++-----
2 files changed, 23 insertions(+), 5 deletions(-)
diff --git a/src/mlpack/core/tree/spill_tree/spill_tree.hpp b/src/mlpack/core/tree/spill_tree/spill_tree.hpp
index eca4770..270a9b9 100644
--- a/src/mlpack/core/tree/spill_tree/spill_tree.hpp
+++ b/src/mlpack/core/tree/spill_tree/spill_tree.hpp
@@ -86,6 +86,10 @@ class SpillTree
std::vector<size_t>* pointsIndex;
//! Flag to distinguish overlapping nodes from non-overlapping nodes.
bool overlappingNode;
+ //! Dimension considered when splitting.
+ size_t splitDimension;
+ //! Dimension value that determines the decision boundary.
+ double splitValue;
//! The bound object for this node.
bound::HRectBound<MetricType> bound;
//! Any extra data contained in the node.
@@ -231,6 +235,12 @@ class SpillTree
//! Distinguish overlapping nodes from non-overlapping nodes.
bool Overlap() const { return overlappingNode; }
+ //! Dimension considered when splitting.
+ size_t SplitDimension() const { return splitDimension; }
+
+ //! Dimension value that determines the decision boundary.
+ double SplitValue() const { return splitValue; }
+
//! Get the metric that the tree uses.
MetricType Metric() const { return MetricType(); }
diff --git a/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp b/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp
index de35140..2272c21 100644
--- a/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp
+++ b/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp
@@ -29,6 +29,8 @@ SpillTree(
count(0),
pointsIndex(NULL),
overlappingNode(false),
+ splitDimension(0),
+ splitValue(0),
bound(data.n_rows),
parentDistance(0), // Parent distance for the root is 0: it has no parent.
dataset(new MatType(data)) // Copies the dataset.
@@ -62,6 +64,8 @@ SpillTree(
count(0),
pointsIndex(NULL),
overlappingNode(false),
+ splitDimension(0),
+ splitValue(0),
bound(data.n_rows),
parentDistance(0), // Parent distance for the root is 0: it has no parent.
dataset(new MatType(std::move(data)))
@@ -97,6 +101,8 @@ SpillTree(
count(0),
pointsIndex(NULL),
overlappingNode(false),
+ splitDimension(0),
+ splitValue(0),
bound(parent->Dataset().n_rows),
dataset(&parent->Dataset()) // Point to the parent's dataset.
{
@@ -124,6 +130,8 @@ SpillTree(const SpillTree& other) :
count(other.count),
pointsIndex(NULL),
overlappingNode(other.overlappingNode),
+ splitDimension(other.splitDimension),
+ splitValue(other.splitValue),
bound(other.bound),
stat(other.stat),
parentDistance(other.parentDistance),
@@ -186,6 +194,8 @@ SpillTree(SpillTree&& other) :
count(other.count),
overlappingNode(other.overlappingNode),
pointsIndex(other.pointsIndex),
+ splitDimension(other.splitDimension),
+ splitValue(other.splitValue),
bound(std::move(other.bound)),
stat(std::move(other.stat)),
parentDistance(other.parentDistance),
@@ -448,10 +458,8 @@ void SpillTree<MetricType, StatisticType, MatType, SplitType>::
return; // We can't split this.
}
- size_t splitDimension;
- double splitVal;
- const bool split = SplitType<BoundType<MetricType>, MatType>::SplitNode(
- bound, *dataset, points, splitDimension, splitVal);
+ const bool split = SplitType<bound::HRectBound<MetricType>,
+ MatType>::SplitNode(bound, *dataset, points, splitDimension, splitValue);
// The node may not be always split. For instance, if all the points are the
// same, we can't split them.
if (!split)
@@ -465,7 +473,7 @@ void SpillTree<MetricType, StatisticType, MatType, SplitType>::
std::vector<size_t> leftPoints, rightPoints;
size_t overlapIndexLeft, overlapIndexRight;
// Split the node.
- overlappingNode = SplitPoints(splitDimension, splitVal, tau, rho, points,
+ overlappingNode = SplitPoints(splitDimension, splitValue, tau, rho, points,
leftPoints, rightPoints, overlapIndexLeft, overlapIndexRight);
// We don't need the information in points, so lets clean it.
More information about the mlpack-git mailing list