[mlpack-git] master: Add midpoint split. This often produces better trees than mean split... (07a099f)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu May 7 15:10:57 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/237ab40b5a32dd626a387e8a109771307fe59153...148bfca48cccba1647ffcfacd03e578a493b7265

>---------------------------------------------------------------

commit 07a099f2b60289704ea85bfd427973e516fc4d52
Author: Ryan Curtin <ryan at ratml.org>
Date:   Thu May 7 14:28:30 2015 -0400

    Add midpoint split.  This often produces better trees than mean split...


>---------------------------------------------------------------

07a099f2b60289704ea85bfd427973e516fc4d52
 .../{mean_split.hpp => midpoint_split.hpp}         | 13 ++--
 ...mean_split_impl.hpp => midpoint_split_impl.hpp} | 79 ++++++++++------------
 2 files changed, 41 insertions(+), 51 deletions(-)

diff --git a/src/mlpack/core/tree/binary_space_tree/mean_split.hpp b/src/mlpack/core/tree/binary_space_tree/midpoint_split.hpp
similarity index 91%
copy from src/mlpack/core/tree/binary_space_tree/mean_split.hpp
copy to src/mlpack/core/tree/binary_space_tree/midpoint_split.hpp
index 3573f11..3aabf9a 100644
--- a/src/mlpack/core/tree/binary_space_tree/mean_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/midpoint_split.hpp
@@ -1,10 +1,11 @@
 /**
- * @file mean_split.hpp
+ * @file midpoint_split.hpp
  * @author Yash Vadalia
  * @author Ryan Curtin
  *
- * Definition of MeanSplit, a class that splits a binary space partitioning tree
- * node into two parts using the mean of the values in a certain dimension.
+ * Definition of MidpointSplit, a class that splits a binary space partitioning
+ * tree node into two parts using the midpoint of the values in a certain
+ * dimension.  The dimension to split on is the dimension with maximum width.
  */
 #ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_MEAN_SPLIT_HPP
 #define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_MEAN_SPLIT_HPP
@@ -17,10 +18,10 @@ namespace tree /** Trees and tree-building procedures. */ {
 /**
  * A binary space partitioning tree node is split into its left and right child.
  * The split is done in the dimension that has the maximum width. The points are
- * divided into two parts based on the mean in this dimension.
+ * divided into two parts based on the midpoint in this dimension.
  */
 template<typename BoundType, typename MatType = arma::mat>
-class MeanSplit
+class MidpointSplit
 {
  public:
   /**
@@ -111,6 +112,6 @@ class MeanSplit
 } // namespace mlpack
 
 // Include implementation.
-#include "mean_split_impl.hpp"
+#include "midpoint_split_impl.hpp"
 
 #endif
diff --git a/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp b/src/mlpack/core/tree/binary_space_tree/midpoint_split_impl.hpp
similarity index 76%
copy from src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp
copy to src/mlpack/core/tree/binary_space_tree/midpoint_split_impl.hpp
index be6ba14..ba18bee 100644
--- a/src/mlpack/core/tree/binary_space_tree/mean_split_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/midpoint_split_impl.hpp
@@ -1,24 +1,25 @@
 /**
- * @file mean_split_impl.hpp
+ * @file midpoint_split_impl.hpp
  * @author Yash Vadalia
  * @author Ryan Curtin
  *
- * Implementation of class(MeanSplit) to split a binary space partition tree.
+ * Implementation of class (MidpointSplit) to split a binary space partition
+ * tree.
  */
-#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_MEAN_SPLIT_IMPL_HPP
-#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_MEAN_SPLIT_IMPL_HPP
+#ifndef __MLPACK_CORE_TREE_BINARY_SPACE_TREE_MIDPOINT_SPLIT_IMPL_HPP
+#define __MLPACK_CORE_TREE_BINARY_SPACE_TREE_MIDPOINT_SPLIT_IMPL_HPP
 
-#include "mean_split.hpp"
+#include "midpoint_split.hpp"
 
 namespace mlpack {
 namespace tree {
 
 template<typename BoundType, typename MatType>
-bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
-                                              MatType& data,
-                                              const size_t begin,
-                                              const size_t count,
-                                              size_t& splitCol)
+bool MidpointSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
+                                                  MatType& data,
+                                                  const size_t begin,
+                                                  const size_t count,
+                                                  size_t& splitCol)
 {
   size_t splitDimension = data.n_rows; // Indicate invalid.
   double maxWidth = -1;
@@ -72,14 +73,8 @@ bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
   if (maxWidth == 0) // All these points are the same.  We can't split.
     return false;
 
-  // Split in the mean of that dimension.
-  double splitVal = 0.0;
-  for (size_t i = begin; i < begin + count; ++i)
-    splitVal += data(splitDimension, i);
-  splitVal /= count;
-
-  Log::Assert(splitVal >= bound[splitDimension].Lo());
-  Log::Assert(splitVal <= bound[splitDimension].Hi());
+  // Split in the midpoint of that dimension.
+  double splitVal = bound[splitDimension].Mid();
 
   // Perform the actual splitting.  This will order the dataset such that points
   // with value in dimension splitDimension less than or equal to splitVal are
@@ -91,12 +86,12 @@ bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
 }
 
 template<typename BoundType, typename MatType>
-bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
-                                              MatType& data,
-                                              const size_t begin,
-                                              const size_t count,
-                                              size_t& splitCol,
-                                              std::vector<size_t>& oldFromNew)
+bool MidpointSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
+                                                  MatType& data,
+                                                  const size_t begin,
+                                                  const size_t count,
+                                                  size_t& splitCol,
+                                                  std::vector<size_t>& oldFromNew)
 {
   size_t splitDimension = data.n_rows; // Indicate invalid.
   double maxWidth = -1;
@@ -149,14 +144,8 @@ bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
   if (maxWidth == 0) // All these points are the same.  We can't split.
     return false;
 
-  // Split in the mean of that dimension.
-  double splitVal = 0.0;
-  for (size_t i = begin; i < begin + count; ++i)
-    splitVal += data(splitDimension, i);
-  splitVal /= count;
-
-  Log::Assert(splitVal >= bound[splitDimension].Lo());
-  Log::Assert(splitVal <= bound[splitDimension].Hi());
+  // Split in the midpoint of that dimension.
+  double splitVal = bound[splitDimension].Mid();
 
   // Perform the actual splitting.  This will order the dataset such that points
   // with value in dimension splitDimension less than or equal to splitVal are
@@ -169,12 +158,12 @@ bool MeanSplit<BoundType, MatType>::SplitNode(const BoundType& bound,
 }
 
 template<typename BoundType, typename MatType>
-size_t MeanSplit<BoundType, MatType>::
-    PerformSplit(MatType& data,
-                 const size_t begin,
-                 const size_t count,
-                 const size_t splitDimension,
-                 const double splitVal)
+size_t MidpointSplit<BoundType, MatType>::PerformSplit(
+    MatType& data,
+    const size_t begin,
+    const size_t count,
+    const size_t splitDimension,
+    const double splitVal)
 {
   // This method modifies the input dataset.  We loop both from the left and
   // right sides of the points contained in this node.  The points less than
@@ -215,13 +204,13 @@ size_t MeanSplit<BoundType, MatType>::
 }
 
 template<typename BoundType, typename MatType>
-size_t MeanSplit<BoundType, MatType>::
-    PerformSplit(MatType& data,
-                 const size_t begin,
-                 const size_t count,
-                 const size_t splitDimension,
-                 const double splitVal,
-                 std::vector<size_t>& oldFromNew)
+size_t MidpointSplit<BoundType, MatType>::PerformSplit(
+    MatType& data,
+    const size_t begin,
+    const size_t count,
+    const size_t splitDimension,
+    const double splitVal,
+    std::vector<size_t>& oldFromNew)
 {
   // This method modifies the input dataset.  We loop both from the left and
   // right sides of the points contained in this node.  The points less than



More information about the mlpack-git mailing list