[mlpack-svn] r13232 - mlpack/trunk/src/mlpack/methods/det
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Sun Jul 15 22:53:18 EDT 2012
Author: rcurtin
Date: 2012-07-15 22:53:18 -0400 (Sun, 15 Jul 2012)
New Revision: 13232
Modified:
mlpack/trunk/src/mlpack/methods/det/dtree.hpp
mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
Log:
Clean up FindSplit() and reduce API to only what is necessary.
Modified: mlpack/trunk/src/mlpack/methods/det/dtree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree.hpp 2012-07-16 00:57:59 UTC (rev 13231)
+++ mlpack/trunk/src/mlpack/methods/det/dtree.hpp 2012-07-16 02:53:18 UTC (rev 13232)
@@ -148,20 +148,17 @@
inline double LogNegativeError(const size_t total_points) const;
bool FindSplit(const arma::mat& data,
- size_t& split_dim,
- size_t& split_ind,
- double& left_error,
- double& right_error,
+ size_t& splitDim,
+ double& splitValue,
+ double& leftError,
+ double& rightError,
const size_t maxLeafSize = 10,
const size_t minLeafSize = 5) const;
- void SplitData_(MatType* data,
- size_t split_dim,
- size_t split_ind,
- arma::Col<size_t>* old_from_new,
- eT* split_val,
- eT* lsplit_val,
- eT* rsplit_val);
+ size_t SplitData(arma::mat& data,
+ const size_t splitDim,
+ const double splitValue,
+ arma::Col<size_t>& oldFromNew) const;
void GetMaxMinVals_(MatType* data,
VecType* max_vals,
Modified: mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp 2012-07-16 00:57:59 UTC (rev 13231)
+++ mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp 2012-07-16 02:53:18 UTC (rev 13232)
@@ -31,7 +31,7 @@
template<typename eT, typename cT>
bool DTree<eT, cT>::FindSplit(const arma::mat& data,
size_t& splitDim,
- size_t& splitIndex,
+ double& splitValue,
double& leftError,
double& rightError,
const size_t maxLeafSize,
@@ -63,9 +63,9 @@
bool dimSplitFound = false;
// Take an error estimate for this dimension.
double minDimError = points / (max - min);
- double dimLeftError = 0.0;
- double dimRightError = 0.0;
- size_t dimSplitIndex = -1;
+ double dimLeftError;
+ double dimRightError;
+ double dimSplitValue;
// Find the log volume of all the other dimensions.
double volumeWithoutDim = 0;
@@ -90,17 +90,13 @@
// there are spikes if this min_leaf_size is enforced here...
for (size_t i = minLeafSize - 1; i < dimVec.n_elem - minLeafSize; ++i)
{
- double split;
- double lsplit = dimVec[i];
- double rsplit = dimVec[i + 1];
-
- if (lsplit == rsplit)
- continue; // We can't split here.
-
// This makes sense for real continuous data. This kinda corrupts the
// data and estimation if the data is ordinal.
- split = (lsplit + rsplit) / 2;
+ const double split = (dimVec[i] + dimVec[i + 1]) / 2.0;
+ if (split == dimVec[i])
+ continue; // We can't split here (two points are the same).
+
// Another way of picking split is using this:
// split = left_split;
if ((split - min > 0.0) && (max - split > 0.0))
@@ -124,14 +120,12 @@
minDimError = negLeftError + negRightError;
dimLeftError = negLeftError;
dimRightError = negRightError;
- dimSplitIndex = i;
+ dimSplitValue = split;
dimSplitFound = true;
}
}
}
- dimVec.clear();
-
double actualMinDimError = std::log(minDimError) - 2 * std::log(data.n_cols)
- volumeWithoutDim;
@@ -141,7 +135,7 @@
// estimate.
minError = actualMinDimError;
splitDim = dim;
- splitIndex = dimSplitIndex;
+ splitValue = dimSplitValue;
leftError = std::log(dimLeftError) - 2 * std::log(data.n_cols) -
volumeWithoutDim;
rightError = std::log(dimRightError) - 2 * std::log(data.n_cols) -
@@ -159,59 +153,37 @@
}
template<typename eT, typename cT>
-void DTree<eT, cT>::SplitData_(MatType* data,
- size_t split_dim,
- size_t split_ind,
- arma::Col<size_t> *old_from_new,
- eT *split_val,
- eT *lsplit_val,
- eT *rsplit_val)
+size_t DTree<eT, cT>::SplitData(arma::mat& data,
+ const size_t splitDim,
+ const double splitValue,
+ arma::Col<size_t>& oldFromNew) const
{
- // Get the values for the split dimension.
- RowVecType dimVec = data->row(split_dim).subvec(start_, end_ - 1);
-
- // Sort the values.
- dimVec = arma::sort(dimVec);
-
- *lsplit_val = dimVec[split_ind];
- *rsplit_val = dimVec[split_ind + 1];
- *split_val = (*lsplit_val + *rsplit_val) / 2 ;
-
- std::vector<bool> left_membership;
- left_membership.reserve(end_ - start_);
-
- for (size_t i = start_; i < end_; ++i)
- {
- if ((*data)(split_dim, i) > *split_val)
- left_membership[i - start_] = false;
- else
- left_membership[i - start_] = true;
- }
-
- size_t left_ind = start_, right_ind = end_ - 1;
+ // Swap all columns such that any columns with value in dimension splitDim
+ // less than or equal to splitValue are on the left side, and all others are
+ // on the right side. A similar sort to this is also performed in
+ // BinarySpaceTree construction (its comments are more detailed).
+ size_t left = start_;
+ size_t right = end_ - 1;
for (;;)
{
- while (left_membership[left_ind - start_] && (left_ind <= right_ind))
- left_ind++;
+ while (data(splitDim, left) <= splitValue)
+ ++left;
+ while (data(splitDim, right) > splitValue)
+ --right;
- while (!left_membership[right_ind - start_] && (left_ind <= right_ind))
- right_ind--;
-
- if (left_ind > right_ind)
+ if (left > right)
break;
+ data.swap_cols(left, right);
- data->swap_cols(left_ind, right_ind);
- bool tmp = left_membership[left_ind - start_];
- left_membership[left_ind - start_] = left_membership[right_ind - start_];
- left_membership[right_ind - start_] = tmp;
-
- size_t t = (*old_from_new)[left_ind];
- (*old_from_new)[left_ind] = (*old_from_new)[right_ind];
- (*old_from_new)[right_ind] = t;
+ // Store the mapping from old to new.
+ const size_t tmp = oldFromNew[left];
+ oldFromNew[left] = oldFromNew[right];
+ oldFromNew[right] = tmp;
}
- assert(left_ind == right_ind + 1);
+ // This now refers to the first index of the "right" side.
+ return left;
}
@@ -371,16 +343,16 @@
if ((size_t) (end_ - start_) > maxLeafSize) {
// Find the split.
- size_t dim, split_ind;
+ size_t dim;
+ double splitValue;
double left_error, right_error;
- if (FindSplit(*data, dim, split_ind, left_error, right_error, maxLeafSize,
+ if (FindSplit(*data, dim, splitValue, left_error, right_error, maxLeafSize,
minLeafSize))
{
// Move the data around for the children to have points in a node lie
// contiguously (to increase efficiency during the training).
- eT split_val, lsplit_val, rsplit_val;
- SplitData_(data, dim, split_ind, old_from_new, &split_val, &lsplit_val,
- &rsplit_val);
+ const size_t splitIndex = SplitData(*data, dim, splitValue,
+ *old_from_new);
// Make max and min vals for the children.
VecType* max_vals_l = new VecType(*max_vals_);
@@ -388,17 +360,17 @@
VecType* min_vals_l = new VecType(*min_vals_);
VecType* min_vals_r = new VecType(*min_vals_);
- (*max_vals_l)[dim] = split_val;
- (*min_vals_r)[dim] = split_val;
+ (*max_vals_l)[dim] = splitValue;
+ (*min_vals_r)[dim] = splitValue;
// Store split dim and split val in the node.
- split_value_ = split_val;
+ split_value_ = splitValue;
split_dim_ = dim;
// Recursively grow the children.
- left_ = new DTree(max_vals_l, min_vals_l, start_, start_ + split_ind + 1,
+ left_ = new DTree(max_vals_l, min_vals_l, start_, splitIndex,
(cT) left_error);
- right_ = new DTree(max_vals_r, min_vals_r, start_ + split_ind + 1, end_,
+ right_ = new DTree(max_vals_r, min_vals_r, splitIndex, end_,
(cT) right_error);
left_g = left_->Grow(data, old_from_new, useVolReg, maxLeafSize,
More information about the mlpack-svn mailing list