[mlpack-svn] r16714 - mlpack/trunk/src/mlpack/methods/decision_stump
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Jun 25 15:27:19 EDT 2014
Author: rcurtin
Date: Wed Jun 25 15:27:19 2014
New Revision: 16714
Log:
Remove the classLabels member; the DecisionStump class does not need to hold a copy of the labels, which are now passed to SetupSplitAttribute() and TrainOnAtt() as a parameter.
Modified:
mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp
mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
Modified: mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp Wed Jun 25 15:27:19 2014
@@ -61,9 +61,6 @@
//! Size of bucket while determining splitting criterion.
size_t bucketSize;
- //! Stores the class labels for the input data.
- arma::Row<size_t> classLabels;
-
//! Stores the splitting values after training.
arma::vec split;
@@ -77,7 +74,8 @@
* @param attribute A row from the training data, which might be a
* candidate for the splitting attribute.
*/
- double SetupSplitAttribute(const arma::rowvec& attribute);
+ double SetupSplitAttribute(const arma::rowvec& attribute,
+ const arma::Row<size_t>& labels);
/**
* After having decided the attribute on which to split, train on that
@@ -86,7 +84,8 @@
* @param attribute attribute is the attribute decided by the constructor
* on which we now train the decision stump.
*/
- template <typename rType> void TrainOnAtt(const arma::rowvec& attribute);
+ template <typename rType> void TrainOnAtt(const arma::rowvec& attribute,
+ const arma::Row<size_t>& labels);
/**
* After the "split" matrix has been set up, merge ranges with identical class
Modified: mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp Wed Jun 25 15:27:19 2014
@@ -31,24 +31,19 @@
const size_t classes,
size_t inpBucketSize)
{
- arma::Row<size_t> zLabels(labels.n_elem);
- zLabels.fill(0);
- classLabels = labels + zLabels;
-
numClass = classes;
bucketSize = inpBucketSize;
// Check whether the input labels are not all identical.
if (!isDistinct<size_t>(labels))
{
- // If the classLabels are all identical, the default class is the only
- // class.
+ // If the labels are all identical, the default class is the only class.
oneClass = true;
- defaultClass = classLabels(0);
+ defaultClass = labels(0);
}
else
{
- // If classLabels are not all identical, proceed with training.
+ // If labels are not all identical, proceed with training.
oneClass = false;
int bestAtt = -1;
double entropy;
@@ -56,7 +51,7 @@
// Set the default class to handle attribute values which are not present in
// the training data.
- defaultClass = CountMostFreq<size_t>(classLabels);
+ defaultClass = CountMostFreq<size_t>(labels);
for (int i = 0; i < data.n_rows; i++)
{
@@ -65,7 +60,7 @@
{
// For each attribute with non-identical values, treat it as a potential
// splitting attribute and calculate entropy if split on it.
- entropy = SetupSplitAttribute(data.row(i));
+ entropy = SetupSplitAttribute(data.row(i), labels);
// Find the attribute with the bestEntropy so that the gain is
// maximized.
@@ -79,7 +74,7 @@
splitCol = bestAtt;
// Once the splitting column/attribute has been decided, train on it.
- TrainOnAtt<double>(data.row(splitCol));
+ TrainOnAtt<double>(data.row(splitCol), labels);
}
}
@@ -131,7 +126,9 @@
* the splitting attribute.
*/
template <typename MatType>
-double DecisionStump<MatType>::SetupSplitAttribute(const arma::rowvec& attribute)
+double DecisionStump<MatType>::SetupSplitAttribute(
+ const arma::rowvec& attribute,
+ const arma::Row<size_t>& labels)
{
int i, count, begin, end;
double entropy = 0.0;
@@ -147,7 +144,7 @@
sortedLabels.fill(0);
for (i = 0; i < attribute.n_elem; i++)
- sortedLabels(i) = classLabels(sortedIndexAtt(i));
+ sortedLabels(i) = labels(sortedIndexAtt(i));
arma::rowvec subColLabels;
arma::rowvec subColAtts;
@@ -225,7 +222,8 @@
*/
template <typename MatType>
template <typename rType>
-void DecisionStump<MatType>::TrainOnAtt(const arma::rowvec& attribute)
+void DecisionStump<MatType>::TrainOnAtt(const arma::rowvec& attribute,
+ const arma::Row<size_t>& labels)
{
int i, count, begin, end;
@@ -237,7 +235,7 @@
arma::Row<size_t> tempLabel;
for (i = 0; i < attribute.n_elem; i++)
- sortedLabels(i) = classLabels(sortedSplitIndexAtt(i));
+ sortedLabels(i) = labels(sortedSplitIndexAtt(i));
arma::rowvec subCols;
rType mostFreq;
More information about the mlpack-svn mailing list