[mlpack-git] master: Make sure the tree doesn't split on the same dimension over and over again when there's a tie. (77893ce)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Sun May 3 19:34:30 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/174d2de995a3fe343cd92d158730f3afa03e622d...076156df78e26ba87012f2b5fbc6d45e84da918b
>---------------------------------------------------------------
commit 77893ce4a53079bc9ef3e7ebb0b669b8527b9c86
Author: Ryan Curtin <ryan at ratml.org>
Date: Sun May 3 19:12:00 2015 -0400
Make sure the tree doesn't split on the same dimension over and over again when there's a tie.
>---------------------------------------------------------------
77893ce4a53079bc9ef3e7ebb0b669b8527b9c86
.../core/tree/binary_space_tree/mean_split.hpp | 54 +++++++++++++---------
.../tree/binary_space_tree/mean_split_impl.hpp | 29 ++++++++++++
2 files changed, 61 insertions(+), 22 deletions(-)
diff --git a/src/mlpack/core/tree/binary_space_tree/mean_split.hpp b/src/mlpack/core/tree/binary_space_tree/mean_split.hpp
index 219cc8c..f5df981 100644
--- a/src/mlpack/core/tree/binary_space_tree/mean_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/mean_split.hpp
@@ -24,6 +24,13 @@ class MeanSplit
{
public:
/**
+ * Instantiate a MeanSplit object. This will initialize the nextDimension
+ * object, which will determine which dimension to use next, in the instance
+ * of a tie.
+ */
+ MeanSplit() : nextDimension(0) { /* Nothing to do. */ }
+
+ /**
* Split the node according to the mean value in the dimension with maximum
* width.
*
@@ -37,11 +44,11 @@ class MeanSplit
* @param splitCol The index at which the dataset is divided into two parts
* after the rearrangement.
*/
- static bool SplitNode(const BoundType& bound,
- MatType& data,
- const size_t begin,
- const size_t count,
- size_t& splitCol);
+ bool SplitNode(const BoundType& bound,
+ MatType& data,
+ const size_t begin,
+ const size_t count,
+ size_t& splitCol);
/**
* Split the node according to the mean value in the dimension with maximum
@@ -59,12 +66,12 @@ class MeanSplit
* @param oldFromNew Vector which will be filled with the old positions for
* each new point.
*/
- static bool SplitNode(const BoundType& bound,
- MatType& data,
- const size_t begin,
- const size_t count,
- size_t& splitCol,
- std::vector<size_t>& oldFromNew);
+ bool SplitNode(const BoundType& bound,
+ MatType& data,
+ const size_t begin,
+ const size_t count,
+ size_t& splitCol,
+ std::vector<size_t>& oldFromNew);
private:
/**
@@ -79,11 +86,11 @@ class MeanSplit
* @param splitVal The split in dimension splitDimension is based on this
* value.
*/
- static size_t PerformSplit(MatType& data,
- const size_t begin,
- const size_t count,
- const size_t splitDimension,
- const double splitVal);
+ size_t PerformSplit(MatType& data,
+ const size_t begin,
+ const size_t count,
+ const size_t splitDimension,
+ const double splitVal);
/**
* Reorder the dataset into two parts such that they lie on either side of
@@ -99,12 +106,15 @@ class MeanSplit
* @param oldFromNew Vector which will be filled with the old positions for
* each new point.
*/
- static size_t PerformSplit(MatType& data,
- const size_t begin,
- const size_t count,
- const size_t splitDimension,
- const double splitVal,
- std::vector<size_t>& oldFromNew);
+ size_t PerformSplit(MatType& data,
+ const size_t begin,
+ const size_t count,
+ const size_t splitDimension,
+ const double splitVal,
+ std::vector<size_t>& oldFromNew);
+
+ //! Tracks the next dimension to use in case of a tie.
+ size_t nextDimension;
};
}; // namespace tree
diff --git a/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp b/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp
index f9d07fe..73edf6f 100644
--- a/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp
@@ -24,6 +24,7 @@ bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
double maxWidth = -1;
// Find the split dimension.
+ size_t ties = 0;
for (size_t d = 0; d < data.n_rows; d++)
{
double width = bound[d].Width();
@@ -33,6 +34,31 @@ bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
maxWidth = width;
splitDimension = d;
}
+ if (width == maxWidth)
+ {
+ // There's a tie. Record that.
+ ++ties;
+ }
+ }
+
+ if (ties > 0)
+ {
+ // Look through a second time, and determine the correct dimension.
+ size_t tieIndex = 0;
+ for (size_t d = 0; d < data.n_rows; ++d)
+ {
+ const double width = bound[d].Width();
+
+ if (width == maxWidth)
+ {
+ if (tieIndex == (nextDimension % ties))
+ {
+ splitDimension = d;
+ break;
+ }
+ ++tieIndex;
+ }
+ }
}
if (maxWidth == 0) // All these points are the same. We can't split.
@@ -132,6 +158,9 @@ size_t MeanSplit<BoundType, MatType>::
Log::Assert(left == right + 1);
+ if (left >= begin + count)
+ Log::Fatal << "Left is count. Bad.\n";
+
return left;
}
More information about the mlpack-git mailing list