[mlpack-git] master: Refactor DecisionStump significantly. (cca2005)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Mon Nov 30 17:24:11 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/10b9d45b806a3e879b0564d78ccb183ebc7051ba...31c557d9cc7e4da57fd8a246085c19e076d12271
>---------------------------------------------------------------
commit cca2005449af19053365a595d60089dac2a487a0
Author: Ryan Curtin <ryan at ratml.org>
Date: Sat Nov 21 01:17:02 2015 +0000
Refactor DecisionStump significantly.
Remove unnecessary template definitions, avoid copies, clean up API, and add a standalone public Train() function.
>---------------------------------------------------------------
cca2005449af19053365a595d60089dac2a487a0
.../methods/decision_stump/decision_stump.hpp | 100 ++++++------
.../methods/decision_stump/decision_stump_impl.hpp | 167 ++++++++++-----------
2 files changed, 139 insertions(+), 128 deletions(-)
diff --git a/src/mlpack/methods/decision_stump/decision_stump.hpp b/src/mlpack/methods/decision_stump/decision_stump.hpp
index 61eea06..10099b9 100644
--- a/src/mlpack/methods/decision_stump/decision_stump.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump.hpp
@@ -25,7 +25,7 @@ namespace decision_stump {
*
* @tparam MatType Type of matrix that is being used (sparse or dense).
*/
-template <typename MatType = arma::mat>
+template<typename MatType = arma::mat>
class DecisionStump
{
public:
@@ -36,22 +36,12 @@ class DecisionStump
* @param data Input, training data.
* @param labels Labels of training data.
* @param classes Number of distinct classes in labels.
- * @param inpBucketSize Minimum size of bucket when splitting.
+ * @param bucketSize Minimum size of bucket when splitting.
*/
DecisionStump(const MatType& data,
const arma::Row<size_t>& labels,
const size_t classes,
- size_t inpBucketSize);
-
- /**
- * Classification function. After training, classify test, and put the
- * predicted classes in predictedLabels.
- *
- * @param test Testing data or data to classify.
- * @param predictedLabels Vector to store the predicted classes after
- * classifying test data.
- */
- void Classify(const MatType& test, arma::Row<size_t>& predictedLabels);
+ const size_t bucketSize);
/**
* Alternate constructor which copies parameters bucketSize and numClass from
@@ -61,19 +51,43 @@ class DecisionStump
* @param other The other initiated Decision Stump object from
* which we copy the values.
* @param data The data on which to train this object on.
- * @param D Weight vector to use while training. For boosting purposes.
* @param labels The labels of data.
- * @param isWeight Whether we need to run a weighted Decision Stump.
+ * @param weights Weight vector to use while training. For boosting purposes.
*/
DecisionStump(const DecisionStump<>& other,
const MatType& data,
const arma::Row<size_t>& labels,
const arma::rowvec& weights);
+ /**
+ * Train the decision stump on the given data. This completely overwrites any
+ * previous training data, so after training the stump may be completely
+ * different.
+ *
+ * @param data Dataset to train on.
+ * @param labels Labels for each point in the dataset.
+ * @param classes Number of classes in the dataset.
+ * @param bucketSize Minimum size of bucket when splitting.
+ */
+ void Train(const MatType& data,
+ const arma::Row<size_t>& labels,
+ const size_t classes,
+ const size_t bucketSize);
+
+ /**
+ * Classification function. After training, classify test, and put the
+ * predicted classes in predictedLabels.
+ *
+ * @param test Testing data or data to classify.
+ * @param predictedLabels Vector to store the predicted classes after
+ * classifying test data.
+ */
+ void Classify(const MatType& test, arma::Row<size_t>& predictedLabels);
+
//! Access the splitting attribute.
- int SplitAttribute() const { return splitAttribute; }
+ size_t SplitAttribute() const { return splitAttribute; }
//! Modify the splitting attribute (be careful!).
- int& SplitAttribute() { return splitAttribute; }
+ size_t& SplitAttribute() { return splitAttribute; }
//! Access the splitting values.
const arma::vec& Split() const { return split; }
@@ -86,18 +100,15 @@ class DecisionStump
arma::Col<size_t>& BinLabels() { return binLabels; }
private:
- //! Stores the number of classes.
- size_t numClasses;
-
- //! Stores the value of the attribute on which to split.
- int splitAttribute;
-
- //! Size of bucket while determining splitting criterion.
+ //! The number of classes (we must store this for boosting).
+ size_t classes;
+ //! The minimum number of points in a bucket.
size_t bucketSize;
+ //! Stores the value of the attribute on which to split.
+ size_t splitAttribute;
//! Stores the splitting values after training.
arma::vec split;
-
//! Stores the labels for each splitting bin.
arma::Col<size_t> binLabels;
@@ -107,9 +118,9 @@ class DecisionStump
*
* @param attribute A row from the training data, which might be a
* candidate for the splitting attribute.
- * @param isWeight Whether we need to run a weighted Decision Stump.
+ * @param UseWeights Whether we need to run a weighted Decision Stump.
*/
- template <bool isWeight>
+ template<bool UseWeights>
double SetupSplitAttribute(const arma::rowvec& attribute,
const arma::Row<size_t>& labels,
const arma::rowvec& weightD);
@@ -121,8 +132,8 @@ class DecisionStump
* @param attribute attribute is the attribute decided by the constructor
* on which we now train the decision stump.
*/
- template <typename rType> void TrainOnAtt(const arma::rowvec& attribute,
- const arma::Row<size_t>& labels);
+ void TrainOnAtt(const arma::rowvec& attribute,
+ const arma::Row<size_t>& labels);
/**
* After the "split" matrix has been set up, merge ranges with identical class
@@ -136,42 +147,43 @@ class DecisionStump
* @param subCols The vector in which to find the most frequently
* occurring element.
*/
- template <typename rType> rType CountMostFreq(const arma::Row<rType>&
- subCols);
+ template<typename VecType>
+ double CountMostFreq(const VecType& subCols);
/**
* Returns 1 if all the values of featureRow are not same.
*
* @param featureRow The attribute which is checked for identical values.
*/
- template <typename rType> int IsDistinct(const arma::Row<rType>& featureRow);
+ template<typename VecType>
+ int IsDistinct(const VecType& featureRow);
/**
* Calculate the entropy of the given attribute.
*
- * @param attribute The attribute of which we calculate the entropy.
* @param labels Corresponding labels of the attribute.
- * @param isWeight Whether we need to run a weighted Decision Stump.
+ * @param classes Number of classes.
+ * @param weights Weights for this set of labels.
*/
- template <typename LabelType, bool isWeight>
- double CalculateEntropy(arma::subview_row<LabelType> labels, int begin,
- const arma::rowvec& tempD);
+ template<bool UseWeights, typename VecType, typename WeightVecType>
+ double CalculateEntropy(const VecType& labels,
+ const WeightVecType& weights);
/**
* Train the decision stump on the given data and labels.
*
* @param data Dataset to train on.
* @param labels Labels for dataset.
- * @param isWeight Whether we need to run a weighted Decision Stump.
+ * @param UseWeights Whether we need to run a weighted Decision Stump.
*/
- template <bool isWeight>
- void Train(const MatType& data, const arma::Row<size_t>& labels,
- const arma::rowvec& weightD);
-
+ template<bool UseWeights>
+ void Train(const MatType& data,
+ const arma::Row<size_t>& labels,
+ const arma::rowvec& weights);
};
-}; // namespace decision_stump
-}; // namespace mlpack
+} // namespace decision_stump
+} // namespace mlpack
#include "decision_stump_impl.hpp"
diff --git a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
index 4b5d987..c052063 100644
--- a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
@@ -20,20 +20,35 @@ namespace decision_stump {
* @param data Input, training data.
* @param labels Labels of data.
* @param classes Number of distinct classes in labels.
- * @param inpBucketSize Minimum size of bucket when splitting.
+ * @param bucketSize Minimum size of bucket when splitting.
*/
template<typename MatType>
DecisionStump<MatType>::DecisionStump(const MatType& data,
const arma::Row<size_t>& labels,
const size_t classes,
- size_t inpBucketSize)
+ const size_t bucketSize) :
+ classes(classes),
+ bucketSize(bucketSize)
{
- numClasses = classes;
- bucketSize = inpBucketSize;
+ arma::rowvec weights;
+ Train<false>(data, labels, weights);
+}
- arma::rowvec weightD;
+/**
+ * Train on the given data and labels.
+ */
+template<typename MatType>
+void DecisionStump<MatType>::Train(const MatType& data,
+ const arma::Row<size_t>& labels,
+ const size_t classes,
+ const size_t bucketSize)
+{
+ this->classes = classes;
+ this->bucketSize = bucketSize;
- Train<false>(data, labels, weightD);
+ // Pass to unweighted training function.
+ arma::rowvec weights;
+ Train<false>(data, labels, weights);
}
/**
@@ -41,36 +56,38 @@ DecisionStump<MatType>::DecisionStump(const MatType& data,
*
* @param data Dataset to train on.
* @param labels Labels for dataset.
- * @param isWeight Whether we need to run a weighted Decision Stump.
+ * @param UseWeights Whether we need to run a weighted Decision Stump.
*/
template<typename MatType>
-template <bool isWeight>
-void DecisionStump<MatType>::Train(const MatType& data, const arma::Row<size_t>& labels,
- const arma::rowvec& weightD)
+template<bool UseWeights>
+void DecisionStump<MatType>::Train(const MatType& data,
+ const arma::Row<size_t>& labels,
+ const arma::rowvec& weights)
{
+ this->classes = classes;
+ this->bucketSize = bucketSize;
+
// If classLabels are not all identical, proceed with training.
- int bestAtt = 0;
+ size_t bestAtt = 0;
double entropy;
- const double rootEntropy = CalculateEntropy<size_t, isWeight>(
- labels.subvec(0, labels.n_elem - 1), 0, weightD);
+ const double rootEntropy = CalculateEntropy<UseWeights>(labels, weights);
double gain, bestGain = 0.0;
for (size_t i = 0; i < data.n_rows; i++)
{
// Go through each attribute of the data.
- if (IsDistinct<double>(data.row(i)))
+ if (IsDistinct(data.row(i)))
{
// For each attribute with non-identical values, treat it as a potential
// splitting attribute and calculate entropy if split on it.
- entropy = SetupSplitAttribute<isWeight>(data.row(i), labels, weightD);
+ entropy = SetupSplitAttribute<UseWeights>(data.row(i), labels, weights);
gain = rootEntropy - entropy;
// Find the attribute with the best entropy so that the gain is
// maximized.
- // if (entropy < bestEntropy)
- // Instead of the above rule, we are maximizing gain, which was
- // what is returned from SetupSplitAttribute.
+ // We are maximizing gain, which is what is returned from
+ // SetupSplitAttribute().
if (gain < bestGain)
{
bestAtt = i;
@@ -81,7 +98,7 @@ void DecisionStump<MatType>::Train(const MatType& data, const arma::Row<size_t>&
splitAttribute = bestAtt;
// Once the splitting column/attribute has been decided, train on it.
- TrainOnAtt<double>(data.row(splitAttribute), labels);
+ TrainOnAtt(data.row(splitAttribute), labels);
}
/**
@@ -126,20 +143,16 @@ void DecisionStump<MatType>::Classify(const MatType& test,
* @param data The data on which to train this object on.
* @param D Weight vector to use while training. For boosting purposes.
* @param labels The labels of data.
- * @param isWeight Whether we need to run a weighted Decision Stump.
+ * @param UseWeights Whether we need to run a weighted Decision Stump.
*/
-template <typename MatType>
+template<typename MatType>
DecisionStump<MatType>::DecisionStump(const DecisionStump<>& other,
const MatType& data,
const arma::Row<size_t>& labels,
- const arma::rowvec& weights)
+ const arma::rowvec& weights) :
+ classes(other.classes),
+ bucketSize(other.bucketSize)
{
- numClasses = other.numClasses;
- bucketSize = other.bucketSize;
-
- // weightD = weights;
- // tempD = weightD;
-
Train<true>(data, labels, weights);
}
@@ -149,14 +162,14 @@ DecisionStump<MatType>::DecisionStump(const DecisionStump<>& other,
*
* @param attribute A row from the training data, which might be a candidate for
* the splitting attribute.
- * @param isWeight Whether we need to run a weighted Decision Stump.
+ * @param UseWeights Whether we need to run a weighted Decision Stump.
*/
-template <typename MatType>
-template <bool isWeight>
+template<typename MatType>
+template<bool UseWeights>
double DecisionStump<MatType>::SetupSplitAttribute(
const arma::rowvec& attribute,
const arma::Row<size_t>& labels,
- const arma::rowvec& weightD)
+ const arma::rowvec& weights)
{
size_t i, count, begin, end;
double entropy = 0.0;
@@ -169,23 +182,22 @@ double DecisionStump<MatType>::SetupSplitAttribute(
arma::uvec sortedIndexAtt = arma::stable_sort_index(attribute.t());
arma::Row<size_t> sortedLabels(attribute.n_elem);
- sortedLabels.fill(0);
-
- arma::rowvec tempD = arma::rowvec(weightD.n_cols);
+ arma::rowvec sortedWeights(attribute.n_elem);
for (i = 0; i < attribute.n_elem; i++)
{
sortedLabels(i) = labels(sortedIndexAtt(i));
- if(isWeight)
- tempD(i) = weightD(sortedIndexAtt(i));
+ // Apply weights if necessary.
+ if (UseWeights)
+ sortedWeights(i) = weights(sortedIndexAtt(i));
}
i = 0;
count = 0;
- // This splits the sorted into buckets of size greater than or equal to
- // inpBucketSize.
+ // This splits the sorted data into buckets of size greater than or equal to
+ // bucketSize.
while (i < sortedLabels.n_elem)
{
count++;
@@ -199,8 +211,8 @@ double DecisionStump<MatType>::SetupSplitAttribute(
// Use ratioEl to calculate the ratio of elements in this split.
const double ratioEl = ((double) (end - begin + 1) / sortedLabels.n_elem);
- entropy += ratioEl * CalculateEntropy<size_t, isWeight>(
- sortedLabels.subvec(begin, end), begin, tempD);
+ entropy += ratioEl * CalculateEntropy<UseWeights>(
+ sortedLabels.subvec(begin, end), sortedWeights.subvec(begin, end));
i++;
}
else if (sortedLabels(i) != sortedLabels(i + 1))
@@ -226,8 +238,8 @@ double DecisionStump<MatType>::SetupSplitAttribute(
}
const double ratioEl = ((double) (end - begin + 1) / sortedLabels.n_elem);
- entropy += ratioEl * CalculateEntropy<size_t, isWeight>(
- sortedLabels.subvec(begin, end), begin, tempD);
+ entropy += ratioEl * CalculateEntropy<UseWeights>(
+ sortedLabels.subvec(begin, end), sortedWeights.subvec(begin, end));
i = end + 1;
count = 0;
@@ -245,8 +257,7 @@ double DecisionStump<MatType>::SetupSplitAttribute(
* @param attribute Attribute is the attribute decided by the constructor on
* which we now train the decision stump.
*/
-template <typename MatType>
-template <typename rType>
+template<typename MatType>
void DecisionStump<MatType>::TrainOnAtt(const arma::rowvec& attribute,
const arma::Row<size_t>& labels)
{
@@ -256,14 +267,12 @@ void DecisionStump<MatType>::TrainOnAtt(const arma::rowvec& attribute,
arma::uvec sortedSplitIndexAtt = arma::stable_sort_index(attribute.t());
arma::Row<size_t> sortedLabels(attribute.n_elem);
sortedLabels.fill(0);
- arma::vec tempSplit;
- arma::Row<size_t> tempLabel;
for (i = 0; i < attribute.n_elem; i++)
sortedLabels(i) = labels(sortedSplitIndexAtt(i));
arma::rowvec subCols;
- rType mostFreq;
+ double mostFreq;
i = 0;
count = 0;
while (i < sortedLabels.n_elem)
@@ -274,12 +283,7 @@ void DecisionStump<MatType>::TrainOnAtt(const arma::rowvec& attribute,
begin = i - count + 1;
end = i;
- arma::rowvec zSubCols((sortedLabels.cols(begin, end)).n_elem);
- zSubCols.fill(0.0);
-
- subCols = sortedLabels.cols(begin, end) + zSubCols;
-
- mostFreq = CountMostFreq<double>(subCols);
+ mostFreq = CountMostFreq(sortedLabels.cols(begin, end));
split.resize(split.n_elem + 1);
split(split.n_elem - 1) = sortedSplitAtt(begin);
@@ -304,14 +308,10 @@ void DecisionStump<MatType>::TrainOnAtt(const arma::rowvec& attribute,
begin = i - count + 1;
end = i;
}
- arma::rowvec zSubCols((sortedLabels.cols(begin, end)).n_elem);
- zSubCols.fill(0.0);
-
- subCols = sortedLabels.cols(begin, end) + zSubCols;
// Find the most frequent element in subCols so as to assign a label to
// the bucket of subCols.
- mostFreq = CountMostFreq<double>(subCols);
+ mostFreq = CountMostFreq(sortedLabels.cols(begin, end));
split.resize(split.n_elem + 1);
split(split.n_elem - 1) = sortedSplitAtt(begin);
@@ -334,7 +334,7 @@ void DecisionStump<MatType>::TrainOnAtt(const arma::rowvec& attribute,
* After the "split" matrix has been set up, merge ranges with identical class
* labels.
*/
-template <typename MatType>
+template<typename MatType>
void DecisionStump<MatType>::MergeRanges()
{
for (size_t i = 1; i < split.n_rows; i++)
@@ -350,13 +350,13 @@ void DecisionStump<MatType>::MergeRanges()
}
}
-template <typename MatType>
-template <typename rType>
-rType DecisionStump<MatType>::CountMostFreq(const arma::Row<rType>& subCols)
+template<typename MatType>
+template<typename VecType>
+double DecisionStump<MatType>::CountMostFreq(const VecType& subCols)
{
// We'll create a map of elements and the number of times that each element is
// seen.
- std::map<rType, size_t> countMap;
+ std::map<double, size_t> countMap;
for (size_t i = 0; i < subCols.n_elem; ++i)
{
@@ -367,8 +367,8 @@ rType DecisionStump<MatType>::CountMostFreq(const arma::Row<rType>& subCols)
}
// Now find the maximum value.
- typename std::map<rType, size_t>::iterator it = countMap.begin();
- rType mostFreq = it->first;
+ typename std::map<double, size_t>::iterator it = countMap.begin();
+ double mostFreq = it->first;
size_t mostFreqCount = it->second;
while (it != countMap.end())
{
@@ -385,15 +385,15 @@ rType DecisionStump<MatType>::CountMostFreq(const arma::Row<rType>& subCols)
}
/**
- * Returns 1 if all the values of featureRow are not same.
+ * Returns 1 if all the values of featureRow are not the same.
*
* @param featureRow The attribute which is checked for identical values.
*/
-template <typename MatType>
-template <typename rType>
-int DecisionStump<MatType>::IsDistinct(const arma::Row<rType>& featureRow)
+template<typename MatType>
+template<typename VecType>
+int DecisionStump<MatType>::IsDistinct(const VecType& featureRow)
{
- rType val = featureRow(0);
+ typename VecType::elem_type val = featureRow(0);
for (size_t i = 1; i < featureRow.n_elem; ++i)
if (val != featureRow(i))
return 1;
@@ -405,34 +405,33 @@ int DecisionStump<MatType>::IsDistinct(const arma::Row<rType>& featureRow)
*
* @param attribute The attribute for which we calculate the entropy.
* @param labels Corresponding labels of the attribute.
- * @param isWeight Whether we need to run a weighted Decision Stump.
+ * @param UseWeights Whether we need to run a weighted Decision Stump.
*/
template<typename MatType>
-template<typename LabelType, bool isWeight>
+template<bool UseWeights, typename VecType, typename WeightVecType>
double DecisionStump<MatType>::CalculateEntropy(
- arma::subview_row<LabelType> labels,
- int begin, const arma::rowvec& tempD)
+ const VecType& labels,
+ const WeightVecType& weights)
{
double entropy = 0.0;
size_t j;
- arma::Row<size_t> numElem(numClasses);
+ arma::Row<size_t> numElem(classes);
numElem.fill(0);
// Variable to accumulate the weight in this subview_row.
double accWeight = 0.0;
// Populate numElem; they are used as helpers to calculate entropy.
- if (isWeight)
+ if (UseWeights)
{
for (j = 0; j < labels.n_elem; j++)
{
- numElem(labels(j)) += tempD(j + begin);
- accWeight += tempD(j + begin);
+ numElem(labels(j)) += weights(j);
+ accWeight += weights(j);
}
- // numElem(labels(j))++;
- for (j = 0; j < numClasses; j++)
+ for (j = 0; j < classes; j++)
{
const double p1 = ((double) numElem(j) / accWeight);
@@ -447,7 +446,7 @@ double DecisionStump<MatType>::CalculateEntropy(
for (j = 0; j < labels.n_elem; j++)
numElem(labels(j))++;
- for (j = 0; j < numClasses; j++)
+ for (j = 0; j < classes; j++)
{
const double p1 = ((double) numElem(j) / labels.n_elem);
@@ -461,7 +460,7 @@ double DecisionStump<MatType>::CalculateEntropy(
return entropy / std::log(2.0);
}
-}; // namespace decision_stump
-}; // namespace mlpack
+} // namespace decision_stump
+} // namespace mlpack
#endif
More information about the mlpack-git mailing list