[mlpack-git] master: Oops, did not apply changes to other overload. (9d8c552)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Sun May 3 21:37:57 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/076156df78e26ba87012f2b5fbc6d45e84da918b...744b3268b46dbd04fc42b343d992bceda121bc11
>---------------------------------------------------------------
commit 9d8c55236df28ad3a8856f835cb4926e258dee7a
Author: Ryan Curtin <ryan at ratml.org>
Date: Sun May 3 21:37:37 2015 -0400
Oops, did not apply changes to other overload.
>---------------------------------------------------------------
9d8c55236df28ad3a8856f835cb4926e258dee7a
.../tree/binary_space_tree/mean_split_impl.hpp | 54 ++++++++++++++++++++--
1 file changed, 49 insertions(+), 5 deletions(-)
diff --git a/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp b/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp
index 1fe630b..9a6c041 100644
--- a/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp
@@ -70,12 +70,18 @@ bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
splitVal += data(splitDimension, i);
splitVal /= count;
+ Log::Assert(splitVal >= bound[splitDimension].Lo());
+ Log::Assert(splitVal <= bound[splitDimension].Hi());
+
// Perform the actual splitting. This will order the dataset such that points
// with value in dimension splitDimension less than or equal to splitVal are
// on the left of splitCol, and points with value in dimension splitDimension
// greater than splitVal are on the right side of splitCol.
splitCol = PerformSplit(data, begin, count, splitDimension, splitVal);
+ // Increment the tie-breaker counter.
+ ++nextDimension;
+
return true;
}
@@ -91,22 +97,54 @@ bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
double maxWidth = -1;
// Find the split dimension.
+ size_t ties = 0;
for (size_t d = 0; d < data.n_rows; d++)
{
- double width = bound[d].Width();
+ const double width = bound[d].Width();
if (width > maxWidth)
{
maxWidth = width;
splitDimension = d;
}
+ if (width == maxWidth)
+ {
+ // There's a tie. Record that.
+ ++ties;
+ }
+ }
+
+ if (ties > 0)
+ {
+ // Look through a second time, and determine the correct dimension.
+ size_t tieIndex = 0;
+ for (size_t d = 0; d < data.n_rows; ++d)
+ {
+ const double width = bound[d].Width();
+
+ if (width == maxWidth)
+ {
+ if (tieIndex == (nextDimension % ties))
+ {
+ splitDimension = d;
+ break;
+ }
+ ++tieIndex;
+ }
+ }
}
if (maxWidth == 0) // All these points are the same. We can't split.
return false;
- // Split in the middle of that dimension.
- double splitVal = bound[splitDimension].Mid();
+ // Split in the mean of that dimension.
+ double splitVal = 0.0;
+ for (size_t i = begin; i < begin + count; ++i)
+ splitVal += data(splitDimension, i);
+ splitVal /= count;
+
+ Log::Assert(splitVal >= bound[splitDimension].Lo());
+ Log::Assert(splitVal <= bound[splitDimension].Hi());
// Perform the actual splitting. This will order the dataset such that points
// with value in dimension splitDimension less than or equal to splitVal are
@@ -115,6 +153,9 @@ bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
splitCol = PerformSplit(data, begin, count, splitDimension, splitVal,
oldFromNew);
+ // Increment the tie-breaker counter.
+ ++nextDimension;
+
return true;
}
@@ -161,8 +202,8 @@ size_t MeanSplit<BoundType, MatType>::
Log::Assert(left == right + 1);
- if (left >= begin + count)
- Log::Fatal << "Left is count. Bad.\n";
+ Log::Assert(left > begin);
+ Log::Assert(left < begin + count - 1);
return left;
}
@@ -216,6 +257,9 @@ size_t MeanSplit<BoundType, MatType>::
Log::Assert(left == right + 1);
+ Log::Assert(left > begin);
+ Log::Assert(left < begin + count - 1);
+
return left;
}
More information about the mlpack-git mailing list