[mlpack-git] master: Record splitDimension and splitValue. (2af816a)

gitdub at mlpack.org gitdub at mlpack.org
Thu Aug 18 13:39:28 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/0f4b25acd6aaa14294c044874ba6cc0751712baa...0a19d07bd39e6223991976474bc79671ba8aa0f0

>---------------------------------------------------------------

commit 2af816aed2a0bcf1f07143a6d1dc6fc14694c1dd
Author: MarcosPividori <marcos.pividori at gmail.com>
Date:   Tue Jul 26 23:33:45 2016 -0300

    Record splitDimension and splitValue.


>---------------------------------------------------------------

2af816aed2a0bcf1f07143a6d1dc6fc14694c1dd
 src/mlpack/core/tree/spill_tree/spill_tree.hpp      | 10 ++++++++++
 src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp | 18 +++++++++++++-----
 2 files changed, 23 insertions(+), 5 deletions(-)

diff --git a/src/mlpack/core/tree/spill_tree/spill_tree.hpp b/src/mlpack/core/tree/spill_tree/spill_tree.hpp
index eca4770..270a9b9 100644
--- a/src/mlpack/core/tree/spill_tree/spill_tree.hpp
+++ b/src/mlpack/core/tree/spill_tree/spill_tree.hpp
@@ -86,6 +86,10 @@ class SpillTree
   std::vector<size_t>* pointsIndex;
   //! Flag to distinguish overlapping nodes from non-overlapping nodes.
   bool overlappingNode;
+  //! Dimension considered when splitting.
+  size_t splitDimension;
+  //! Dimension value that determines the decision boundary.
+  double splitValue;
   //! The bound object for this node.
   bound::HRectBound<MetricType> bound;
   //! Any extra data contained in the node.
@@ -231,6 +235,12 @@ class SpillTree
   //! Distinguish overlapping nodes from non-overlapping nodes.
   bool Overlap() const { return overlappingNode; }
 
+  //! Dimension considered when splitting.
+  size_t SplitDimension() const { return splitDimension; }
+
+  //! Dimension value that determines the decision boundary.
+  double SplitValue() const { return splitValue; }
+
   //! Get the metric that the tree uses.
   MetricType Metric() const { return MetricType(); }
 
diff --git a/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp b/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp
index de35140..2272c21 100644
--- a/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp
+++ b/src/mlpack/core/tree/spill_tree/spill_tree_impl.hpp
@@ -29,6 +29,8 @@ SpillTree(
     count(0),
     pointsIndex(NULL),
     overlappingNode(false),
+    splitDimension(0),
+    splitValue(0),
     bound(data.n_rows),
     parentDistance(0), // Parent distance for the root is 0: it has no parent.
     dataset(new MatType(data)) // Copies the dataset.
@@ -62,6 +64,8 @@ SpillTree(
     count(0),
     pointsIndex(NULL),
     overlappingNode(false),
+    splitDimension(0),
+    splitValue(0),
     bound(data.n_rows),
     parentDistance(0), // Parent distance for the root is 0: it has no parent.
     dataset(new MatType(std::move(data)))
@@ -97,6 +101,8 @@ SpillTree(
     count(0),
     pointsIndex(NULL),
     overlappingNode(false),
+    splitDimension(0),
+    splitValue(0),
     bound(parent->Dataset().n_rows),
     dataset(&parent->Dataset()) // Point to the parent's dataset.
 {
@@ -124,6 +130,8 @@ SpillTree(const SpillTree& other) :
     count(other.count),
     pointsIndex(NULL),
     overlappingNode(other.overlappingNode),
+    splitDimension(other.splitDimension),
+    splitValue(other.splitValue),
     bound(other.bound),
     stat(other.stat),
     parentDistance(other.parentDistance),
@@ -186,6 +194,8 @@ SpillTree(SpillTree&& other) :
     count(other.count),
     overlappingNode(other.overlappingNode),
     pointsIndex(other.pointsIndex),
+    splitDimension(other.splitDimension),
+    splitValue(other.splitValue),
     bound(std::move(other.bound)),
     stat(std::move(other.stat)),
     parentDistance(other.parentDistance),
@@ -448,10 +458,8 @@ void SpillTree<MetricType, StatisticType, MatType, SplitType>::
     return; // We can't split this.
   }
 
-  size_t splitDimension;
-  double splitVal;
-  const bool split = SplitType<BoundType<MetricType>, MatType>::SplitNode(
-      bound, *dataset, points, splitDimension, splitVal);
+  const bool split = SplitType<bound::HRectBound<MetricType>,
+      MatType>::SplitNode(bound, *dataset, points, splitDimension, splitValue);
   // The node may not be always split. For instance, if all the points are the
   // same, we can't split them.
   if (!split)
@@ -465,7 +473,7 @@ void SpillTree<MetricType, StatisticType, MatType, SplitType>::
   std::vector<size_t> leftPoints, rightPoints;
   size_t overlapIndexLeft, overlapIndexRight;
   // Split the node.
-  overlappingNode = SplitPoints(splitDimension, splitVal, tau, rho, points,
+  overlappingNode = SplitPoints(splitDimension, splitValue, tau, rho, points,
       leftPoints, rightPoints, overlapIndexLeft, overlapIndexRight);
 
   // We don't need the information in points, so lets clean it.




More information about the mlpack-git mailing list