[mlpack-svn] r16708 - mlpack/trunk/src/mlpack/methods/decision_stump
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Jun 24 20:22:10 EDT 2014
Author: rcurtin
Date: Tue Jun 24 20:22:09 2014
New Revision: 16708
Log:
For clarity, store the split values and the per-bin labels in separate objects (split and binLabels) instead of keeping each bin's label in a second column of the split matrix. Labels no longer have to be converted to and from double.
Modified:
mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp
mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
Modified: mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp Tue Jun 24 20:22:09 2014
@@ -64,8 +64,11 @@
//! Stores the class labels for the input data.
arma::Row<size_t> classLabels;
- //! Stores the splitting criterion after training.
- arma::mat split;
+ //! Stores the splitting values after training.
+ arma::vec split;
+
+ //! Stores the labels for each splitting bin.
+ arma::Col<size_t> binLabels;
/**
* Sets up attribute as if it were splitting on it and finds entropy when
Modified: mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp Tue Jun 24 20:22:09 2014
@@ -39,7 +39,7 @@
bucketSize = inpBucketSize;
// Check whether the input labels are not all identical.
- if (!isDistinct<size_t>(classLabels))
+ if (!isDistinct<size_t>(labels))
{
// If the classLabels are all identical, the default class is the only
// class.
@@ -99,31 +99,21 @@
{
for (int i = 0; i < test.n_cols; i++)
{
- int j = 0;
-
+ // Determine which bin the test point falls into.
+ // Assume first that it falls into the first bin, then proceed through the
+ // bins until it is known which bin it falls into.
+ int bin = 0;
const double val = test(splitCol, i);
- while (j < split.n_rows)
+
+ while (bin < split.n_elem - 1)
{
- if (val < split(j, 0) && (!j))
- {
- predictedLabels(i) = split(0, 1);
+ if (val < split(bin + 1))
break;
- }
- else if (val >= split(j, 0))
- {
- if (j == split.n_rows - 1)
- {
- predictedLabels(i) = split(split.n_rows - 1, 1);
- break;
- }
- else if (val < split(j + 1, 0))
- {
- predictedLabels(i) = split(j, 1);
- break;
- }
- }
- j++;
+
+ ++bin;
}
+
+ predictedLabels(i) = binLabels(bin);
}
}
else
@@ -243,7 +233,8 @@
arma::uvec sortedSplitIndexAtt = arma::stable_sort_index(attribute.t());
arma::Row<size_t> sortedLabels(attribute.n_elem);
sortedLabels.fill(0);
- arma::mat tempSplit;
+ arma::vec tempSplit;
+ arma::Row<size_t> tempLabel;
for (i = 0; i < attribute.n_elem; i++)
sortedLabels(i) = classLabels(sortedSplitIndexAtt(i));
@@ -267,8 +258,10 @@
mostFreq = CountMostFreq<double>(subCols);
- tempSplit << sortedSplitAtt(begin)<< mostFreq << arma::endr;
- split = arma::join_cols(split, tempSplit);
+ split.resize(split.n_elem + 1);
+ split(split.n_elem - 1) = sortedSplitAtt(begin);
+ binLabels.resize(binLabels.n_elem + 1);
+ binLabels(binLabels.n_elem - 1) = mostFreq;
i++;
}
@@ -297,8 +290,10 @@
// the bucket of subCols.
mostFreq = CountMostFreq<double>(subCols);
- tempSplit << sortedSplitAtt(begin) << mostFreq << arma::endr;
- split = arma::join_cols(split, tempSplit);
+ split.resize(split.n_elem + 1);
+ split(split.n_elem - 1) = sortedSplitAtt(begin);
+ binLabels.resize(binLabels.n_elem + 1);
+ binLabels(binLabels.n_elem - 1) = mostFreq;
i = end + 1;
count = 0;
@@ -321,9 +316,10 @@
{
for (int i = 1; i < split.n_rows; i++)
{
- if (split(i, 1) == split(i - 1, 1))
+ if (binLabels(i) == binLabels(i - 1))
{
// Remove this row, as it has the same label as the previous bucket.
+ binLabels.shed_row(i);
split.shed_row(i);
// Go back to previous row.
i--;
More information about the mlpack-svn
mailing list