[mlpack-git] master: Add midpoint split. This often produces better trees than mean split... (07a099f)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu May 7 15:10:57 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/237ab40b5a32dd626a387e8a109771307fe59153...148bfca48cccba1647ffcfacd03e578a493b7265
>---------------------------------------------------------------
commit 07a099f2b60289704ea85bfd427973e516fc4d52
Author: Ryan Curtin <ryan at ratml.org>
Date: Thu May 7 14:28:30 2015 -0400
Add midpoint split. This often produces better trees than mean split...
>---------------------------------------------------------------
07a099f2b60289704ea85bfd427973e516fc4d52
.../{mean_split.hpp => midpoint_split.hpp} | 13 ++--
...mean_split_impl.hpp => midpoint_split_impl.hpp} | 79 ++++++++++------------
2 files changed, 41 insertions(+), 51 deletions(-)
diff --git a/src/mlpack/core/tree/binary_space_tree/mean_split.hpp b/src/mlpack/core/tree/binary_space_tree/midpoint_split.hpp
similarity index 91%
copy from src/mlpack/core/tree/binary_space_tree/mean_split.hpp
copy to src/mlpack/core/tree/binary_space_tree/midpoint_split.hpp
index 3573f11..3aabf9a 100644
--- a/src/mlpack/core/tree/binary_space_tree/mean_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/midpoint_split.hpp
@@ -1,10 +1,11 @@
/**
- * @file mean_split.hpp
+ * @file midpoint_split.hpp
* @author Yash Vadalia
* @author Ryan Curtin
*
- * Definition of MeanSplit, a class that splits a binary space partitioning tree
- * node into two parts using the mean of the values in a certain dimension.
+ * Definition of MidpointSplit, a class that splits a binary space partitioning
+ * tree node into two parts using the midpoint of the values in a certain
+ * dimension. The dimension to split on is the dimension with maximum width.
*/
#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_MEAN_SPLIT_HPP
#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_MEAN_SPLIT_HPP
@@ -17,10 +18,10 @@ namespace tree /** Trees and tree-building procedures. */ {
/**
* A binary space partitioning tree node is split into its left and right child.
* The split is done in the dimension that has the maximum width. The points are
- * divided into two parts based on the mean in this dimension.
+ * divided into two parts based on the midpoint in this dimension.
*/
template<typename BoundType, typename MatType = arma::mat>
-class MeanSplit
+class MidpointSplit
{
public:
/**
@@ -111,6 +112,6 @@ class MeanSplit
} // namespace mlpack
// Include implementation.
-#include "mean_split_impl.hpp"
+#include "midpoint_split_impl.hpp"
#endif
diff --git a/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp b/src/mlpack/core/tree/binary_space_tree/midpoint_split_impl.hpp
similarity index 76%
copy from src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp
copy to src/mlpack/core/tree/binary_space_tree/midpoint_split_impl.hpp
index be6ba14..ba18bee 100644
--- a/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/midpoint_split_impl.hpp
@@ -1,24 +1,25 @@
/**
- * @file mean_split_impl.hpp
+ * @file midpoint_split_impl.hpp
* @author Yash Vadalia
* @author Ryan Curtin
*
- * Implementation of class(MeanSplit) to split a binary space partition tree.
+ * Implementation of MidpointSplit, a class that splits a binary space
+ * partitioning tree node.
*/
-#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_MEAN_SPLIT_IMPL_HPP
-#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_MEAN_SPLIT_IMPL_HPP
+#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_MIDPOINT_SPLIT_IMPL_HPP
+#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_MIDPOINT_SPLIT_IMPL_HPP
-#include "mean_split.hpp"
+#include "midpoint_split.hpp"
namespace mlpack {
namespace tree {
template<typename BoundType, typename MatType>
-bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
- MatType& data,
- const size_t begin,
- const size_t count,
- size_t& splitCol)
+bool MidpointSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
+ MatType& data,
+ const size_t begin,
+ const size_t count,
+ size_t& splitCol)
{
size_t splitDimension = data.n_rows; // Indicate invalid.
double maxWidth = -1;
@@ -72,14 +73,8 @@ bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
if (maxWidth == 0) // All these points are the same. We can't split.
return false;
- // Split in the mean of that dimension.
- double splitVal = 0.0;
- for (size_t i = begin; i < begin + count; ++i)
- splitVal += data(splitDimension, i);
- splitVal /= count;
-
- Log::Assert(splitVal >= bound[splitDimension].Lo());
- Log::Assert(splitVal <= bound[splitDimension].Hi());
+ // Split in the midpoint of that dimension.
+ double splitVal = bound[splitDimension].Mid();
// Perform the actual splitting. This will order the dataset such that points
// with value in dimension splitDimension less than or equal to splitVal are
@@ -91,12 +86,12 @@ bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
}
template<typename BoundType, typename MatType>
-bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
- MatType& data,
- const size_t begin,
- const size_t count,
- size_t& splitCol,
- std::vector<size_t>& oldFromNew)
+bool MidpointSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
+ MatType& data,
+ const size_t begin,
+ const size_t count,
+ size_t& splitCol,
+ std::vector<size_t>& oldFromNew)
{
size_t splitDimension = data.n_rows; // Indicate invalid.
double maxWidth = -1;
@@ -149,14 +144,8 @@ bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
if (maxWidth == 0) // All these points are the same. We can't split.
return false;
- // Split in the mean of that dimension.
- double splitVal = 0.0;
- for (size_t i = begin; i < begin + count; ++i)
- splitVal += data(splitDimension, i);
- splitVal /= count;
-
- Log::Assert(splitVal >= bound[splitDimension].Lo());
- Log::Assert(splitVal <= bound[splitDimension].Hi());
+ // Split in the midpoint of that dimension.
+ double splitVal = bound[splitDimension].Mid();
// Perform the actual splitting. This will order the dataset such that points
// with value in dimension splitDimension less than or equal to splitVal are
@@ -169,12 +158,12 @@ bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
}
template<typename BoundType, typename MatType>
-size_t MeanSplit<BoundType, MatType>::
- PerformSplit(MatType& data,
- const size_t begin,
- const size_t count,
- const size_t splitDimension,
- const double splitVal)
+size_t MidpointSplit<BoundType, MatType>::PerformSplit(
+ MatType& data,
+ const size_t begin,
+ const size_t count,
+ const size_t splitDimension,
+ const double splitVal)
{
// This method modifies the input dataset. We loop both from the left and
// right sides of the points contained in this node. The points less than
@@ -215,13 +204,13 @@ size_t MeanSplit<BoundType, MatType>::
}
template<typename BoundType, typename MatType>
-size_t MeanSplit<BoundType, MatType>::
- PerformSplit(MatType& data,
- const size_t begin,
- const size_t count,
- const size_t splitDimension,
- const double splitVal,
- std::vector<size_t>& oldFromNew)
+size_t MidpointSplit<BoundType, MatType>::PerformSplit(
+ MatType& data,
+ const size_t begin,
+ const size_t count,
+ const size_t splitDimension,
+ const double splitVal,
+ std::vector<size_t>& oldFromNew)
{
// This method modifies the input dataset. We loop both from the left and
// right sides of the points contained in this node. The points less than
More information about the mlpack-git
mailing list