[mlpack-svn] r16759 - in mlpack/trunk/src/mlpack: methods/decision_stump tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Jul 3 14:39:04 EDT 2014
Author: saxena.udit
Date: Thu Jul 3 14:39:04 2014
New Revision: 16759
Log:
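New test added (CorrectAttributeChosen). Improved entropy calculation: the entropy of each bucket is now computed from its labels alone and weighted by the fraction of points that fall in the bucket.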
Modified:
mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp
mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
mlpack/trunk/src/mlpack/tests/decision_stump_test.cpp
Modified: mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump.hpp Thu Jul 3 14:39:04 2014
@@ -45,12 +45,13 @@
*/
void Classify(const MatType& test, arma::Row<size_t>& predictedLabels);
+
+ //! Index of the attribute on which to split. Kept public so that tests
+ //! can check which attribute the stump chose.
+ int splitCol;
private:
//! Stores the number of classes.
size_t numClass;
- //! Stores the value of the attribute on which to split.
- int splitCol;
//! Size of bucket while determining splitting criterion.
size_t bucketSize;
@@ -109,8 +110,7 @@
* @param labels Corresponding labels of the attribute.
*/
+ // NOTE(review): AttType no longer appears in the parameter list, so it
+ // cannot be deduced; callers must name it explicitly (e.g.
+ // CalculateEntropy<size_t>(...), which binds AttType and deduces
+ // LabelType from the argument). Consider dropping the unused parameter.
template <typename AttType, typename LabelType>
- double CalculateEntropy(arma::subview_row<AttType> attribute,
- arma::subview_row<LabelType> labels);
+ double CalculateEntropy(arma::subview_row<LabelType> labels);
};
}; // namespace decision_stump
Modified: mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/decision_stump/decision_stump_impl.hpp Thu Jul 3 14:39:04 2014
@@ -147,21 +147,23 @@
i = 0;
count = 0;
-
+ double ratioEl;
// This splits the sorted into buckets of size greater than or equal to
// inpBucketSize.
while (i < sortedLabels.n_elem)
{
count++;
- if (i == sortedLabels.n_elem - 1)
+ if (i == sortedLabels.n_elem - 1)
{
// if we're at the end, then don't worry about the bucket size
// just take this as the last bin.
begin = i - count + 1;
end = i;
-
- entropy += CalculateEntropy<double, size_t>(
- sortedAtt.subvec(begin,end),sortedLabels.subvec(begin,end));
+
+ // using ratioEl to calculate the ratio of elements in this split.
+ ratioEl = ((double)(end - begin + 1)/sortedLabels.n_elem);
+
+ entropy += ratioEl * CalculateEntropy<size_t>(sortedLabels.subvec(begin,end));
i++;
}
else if (sortedLabels(i) != sortedLabels(i + 1))
@@ -171,6 +173,8 @@
if (count < bucketSize)
{
// if it is, then take the minimum bucket size anyways
+ // this is where the inpBucketSize comes into use
+ // This makes sure there isn't a bucket for every change in labels.
begin = i - count + 1;
end = begin + bucketSize - 1;
@@ -183,9 +187,9 @@
begin = i - count + 1;
end = i;
}
-
- entropy += CalculateEntropy<double, size_t>(
- sortedAtt.subvec(begin,end),sortedLabels.subvec(begin,end));
+ ratioEl = ((double)(end - begin + 1)/sortedLabels.n_elem);
+
+ entropy +=ratioEl * CalculateEntropy<size_t>(sortedLabels.subvec(begin,end));
i = end + 1;
count = 0;
@@ -269,7 +273,7 @@
// Find the most frequent element in subCols so as to assign a label to
// the bucket of subCols.
- mostFreq = CountMostFreq<double>(subCols);//sortedLabels.subvec(begin, end));
+ mostFreq = CountMostFreq<double>(subCols);
split.resize(split.n_elem + 1);
split(split.n_elem - 1) = sortedSplitAtt(begin);
@@ -372,45 +376,25 @@
*/
template<typename MatType>
template<typename AttType, typename LabelType>
-double DecisionStump<MatType>::CalculateEntropy(arma::subview_row<AttType> attribute,
- arma::subview_row<LabelType> labels)
+double DecisionStump<MatType>::CalculateEntropy(arma::subview_row<LabelType> labels)
{
double entropy = 0.0;
-
- arma::rowvec uniqueAtt = arma::unique(attribute);
- arma::Row<LabelType> uniqueLabel = arma::unique(labels);
- arma::Row<size_t> numElem(uniqueAtt.n_elem);
+
+ // An empty bucket carries no information; returning early also avoids
+ // the 0 / 0 division (and a NaN result) below.
+ if (labels.n_elem == 0)
+ return entropy;
+
+ // Count the points of each class in this bucket. Labels are used directly
+ // as indices, so they must lie in [0, numClass).
+ arma::Row<size_t> numElem(numClass);
numElem.fill(0);
- arma::Mat<size_t> entropyArray(uniqueAtt.n_elem,numClass);
- entropyArray.fill(0);
- // Populate entropyArray and numElem; they are used as helpers to calculate
- // entropy.
- for (int j = 0; j < uniqueAtt.n_elem; j++)
- {
- for (int i = 0; i < attribute.n_elem; i++)
- {
- if (uniqueAtt[j] == attribute[i])
- {
- entropyArray(j, labels(i))++;
- numElem(j)++;
- }
- }
- }
+ for (size_t i = 0; i < labels.n_elem; ++i)
+ numElem(labels(i))++;
- for (int j = 0; j < uniqueAtt.size(); j++)
+
+ // Sum p * log2(p) over the classes, treating 0 * log2(0) as 0. This is the
+ // negative of the Shannon entropy: it is 0 for a pure bucket and becomes
+ // more negative the more mixed the bucket is.
+ for (size_t c = 0; c < numClass; ++c)
{
- const double p1 = ((double) numElem(j) / attribute.n_elem);
-
- for (int i = 0; i < numClass; i++)
- {
- const double p2 = ((double) entropyArray(j, i) / numElem(j));
- const double p3 = (p2 == 0) ? 0 : p2 * log2(p2);
-
- entropy += p1 * p3;
- }
+ const double p = static_cast<double>(numElem(c)) / labels.n_elem;
+
+ entropy += (p == 0) ? 0 : p * log2(p);
}
-
return entropy;
}
Modified: mlpack/trunk/src/mlpack/tests/decision_stump_test.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/tests/decision_stump_test.cpp (original)
+++ mlpack/trunk/src/mlpack/tests/decision_stump_test.cpp Thu Jul 3 14:39:04 2014
@@ -6,7 +6,7 @@
*/
#include <mlpack/core.hpp>
#include <mlpack/methods/decision_stump/decision_stump.hpp>
-
+
#include <boost/test/unit_test.hpp>
#include "old_boost_test_definitions.hpp"
@@ -47,6 +47,36 @@
BOOST_CHECK_EQUAL(predictedLabels(i), 1);
}
+/**
+ * This tests whether the entropy is calculated correctly, by checking that
+ * the stump chooses the expected splitting attribute. With an inpBucketSize
+ * of 4, the correct value of splitCol is 1.
+ */
+BOOST_AUTO_TEST_CASE(CorrectAttributeChosen)
+{
+ const size_t numClasses = 2;
+ const size_t inpBucketSize = 4;
+
+ // Each column of trainingData is one point (there is one label per
+ // column); each of the three rows is an attribute.
+ mat trainingData;
+ trainingData << 0 << 0 << 0 << 0 << 0 << 1 << 1 << 1 << 1
+ << 2 << 2 << 2 << 2 << 2 << endr
+ << 70 << 90 << 85 << 95 << 70 << 90 << 78 << 65 << 75
+ << 80 << 70 << 80 << 80 << 96 << endr
+ << 1 << 1 << 0 << 0 << 0 << 1 << 0 << 1 << 0
+ << 1 << 1 << 0 << 0 << 0 << endr;
+
+ // The labels are already in [0, numClasses), so no normalization is needed.
+ Mat<size_t> labelsIn;
+ labelsIn << 0 << 1 << 1 << 1 << 0 << 0 << 0 << 0
+ << 0 << 1 << 1 << 0 << 0 << 0;
+
+ DecisionStump<> ds(trainingData, labelsIn.row(0), numClasses, inpBucketSize);
+
+ // Only the chosen splitting attribute is checked; no classification is
+ // performed.
+
+ BOOST_CHECK_EQUAL(ds.splitCol,1);
+}
/**
* This tests for the classification:
More information about the mlpack-svn
mailing list