[mlpack-svn] r16793 - mlpack/trunk/src/mlpack/tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Jul 9 12:50:00 EDT 2014
Author: rcurtin
Date: Wed Jul 9 12:49:59 2014
New Revision: 16793
Log:
Another test to make sure the correct splitting attribute is used.
Modified:
mlpack/trunk/src/mlpack/tests/decision_stump_test.cpp
Modified: mlpack/trunk/src/mlpack/tests/decision_stump_test.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/tests/decision_stump_test.cpp (original)
+++ mlpack/trunk/src/mlpack/tests/decision_stump_test.cpp Wed Jul 9 12:49:59 2014
@@ -13,6 +13,7 @@
using namespace mlpack;
using namespace mlpack::decision_stump;
using namespace arma;
+using namespace mlpack::distribution;
BOOST_AUTO_TEST_SUITE(DecisionStumpTest);
@@ -47,24 +48,28 @@
BOOST_CHECK_EQUAL(predictedLabels(i), 1);
}
-/*
-This tests whether the entropy is being correctly calculated by
-checking the correct value of the splitting column value.
-This test is for an inpBucketSize of 4 and the correct value of
-the splitCol is 1.
-*/
+
+/**
+ * This tests whether the entropy is being correctly calculated by checking that
+ * the correct splitting attribute is chosen. This test uses an inpBucketSize of
+ * 4, and the correct splitting attribute is 0.
+ */
BOOST_AUTO_TEST_CASE(CorrectAttributeChosen)
{
const size_t numClasses = 2;
const size_t inpBucketSize = 4;
+ // This dataset comes from Chapter 6 of the book "Data Mining: Concepts,
+ // Models, Methods, and Algorithms" (2nd Edition) by Mehmed Kantardzic. It is
+ // found on page 176 (and a description of the correct splitting attribute is
+ // given below that).
mat trainingData;
- trainingData << 0 << 0 << 0 << 0 << 0 << 1 << 1 << 1 << 1
- << 2 << 2 << 2 << 2 << 2 << endr
+ trainingData << 0 << 0 << 0 << 0 << 0 << 1 << 1 << 1 << 1
+ << 2 << 2 << 2 << 2 << 2 << endr
<< 70 << 90 << 85 << 95 << 70 << 90 << 78 << 65 << 75
- << 80 << 70 << 80 << 80 << 96 << endr
- << 1 << 1 << 0 << 0 << 0 << 1 << 0 << 1 << 0
- << 1 << 1 << 0 << 0 << 0 << endr;
+ << 80 << 70 << 80 << 80 << 96 << endr
+ << 1 << 1 << 0 << 0 << 0 << 1 << 0 << 1 << 0
+ << 1 << 1 << 0 << 0 << 0 << endr;
// No need to normalize labels here.
Mat<size_t> labelsIn;
@@ -74,8 +79,7 @@
DecisionStump<> ds(trainingData, labelsIn.row(0), numClasses, inpBucketSize);
// Only need to check the value of the splitting column, no need of classification.
-
- BOOST_CHECK_EQUAL(ds.splitCol,1);
+ // The correct split for this dataset is attribute 0 (the first row of
+ // trainingData), per the book example cited at the top of this test.
+ BOOST_CHECK_EQUAL(ds.SplitAttribute(), 0);
}
/**
@@ -210,4 +214,97 @@
BOOST_CHECK_EQUAL(predictedLabels(0, 7), 2);
}
+/**
+ * This tests that the decision stump can learn a good split on a dataset with
+ * four dimensions that have progressing levels of separation.
+ */
+BOOST_AUTO_TEST_CASE(DimensionSelectionTest)
+{
+ const size_t numClasses = 2;
+ const size_t inpBucketSize = 25;
+
+ // Each column of the dataset is one point: the first 2500 points belong to
+ // class 0 and the last 2500 to class 1. In every dimension the two classes
+ // are drawn from Gaussians with the same covariance ("1") but different
+ // means, so the dimensions differ only in how far apart the class means are.
+ arma::mat dataset(4, 5000);
+
+ // Dimension 1: class means at -5 and 5. This is the most separable
+ // dimension.
+ GaussianDistribution g1("-5", "1");
+ GaussianDistribution g2("5", "1");
+
+ for (size_t i = 0; i < 2500; ++i)
+ {
+ arma::vec tmp = g1.Random();
+ dataset(1, i) = tmp[0];
+ }
+ for (size_t i = 2500; i < 5000; ++i)
+ {
+ arma::vec tmp = g2.Random();
+ dataset(1, i) = tmp[0];
+ }
+
+ // Dimension 3: class means at -3 and 3 (separable, but less so than
+ // dimension 1).
+ g1 = GaussianDistribution("-3", "1");
+ g2 = GaussianDistribution("3", "1");
+
+ for (size_t i = 0; i < 2500; ++i)
+ {
+ arma::vec tmp = g1.Random();
+ dataset(3, i) = tmp[0];
+ }
+ for (size_t i = 2500; i < 5000; ++i)
+ {
+ arma::vec tmp = g2.Random();
+ dataset(3, i) = tmp[0];
+ }
+
+ // Dimension 0: class means at -1 and 1 (the classes overlap heavily).
+ g1 = GaussianDistribution("-1", "1");
+ g2 = GaussianDistribution("1", "1");
+
+ for (size_t i = 0; i < 2500; ++i)
+ {
+ arma::vec tmp = g1.Random();
+ dataset(0, i) = tmp[0];
+ }
+ for (size_t i = 2500; i < 5000; ++i)
+ {
+ arma::vec tmp = g2.Random();
+ dataset(0, i) = tmp[0];
+ }
+
+ // Dimension 2: both classes have mean 0, so it is not separable at all.
+ g1 = GaussianDistribution("0", "1");
+ g2 = GaussianDistribution("0", "1");
+
+ for (size_t i = 0; i < 2500; ++i)
+ {
+ arma::vec tmp = g1.Random();
+ dataset(2, i) = tmp[0];
+ }
+ for (size_t i = 2500; i < 5000; ++i)
+ {
+ arma::vec tmp = g2.Random();
+ dataset(2, i) = tmp[0];
+ }
+
+ // Generate the labels: 0 for the first 2500 points, 1 for the rest.
+ arma::Row<size_t> labels(5000);
+ for (size_t i = 0; i < 2500; ++i)
+ labels[i] = 0;
+ for (size_t i = 2500; i < 5000; ++i)
+ labels[i] = 1;
+
+ // Now create a decision stump.
+ DecisionStump<> ds(dataset, labels, numClasses, inpBucketSize);
+
+ // Make sure it split on the dimension that is most separable.
+ BOOST_REQUIRE_EQUAL(ds.SplitAttribute(), 1);
+
+ // Make sure every bin whose split value is at or below -3 classifies as
+ // label 0, and every bin whose split value is at or above 3 classifies as
+ // label 1. Bins with split values in (-3, 3) are not checked: what happens
+ // there isn't that big a deal.
+ for (size_t i = 0; i < ds.Split().n_elem; ++i)
+ {
+ if (ds.Split()[i] <= -3.0)
+ BOOST_REQUIRE_EQUAL(ds.BinLabels()[i], 0);
+ else if (ds.Split()[i] >= 3.0)
+ BOOST_REQUIRE_EQUAL(ds.BinLabels()[i], 1);
+ }
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-svn
mailing list