[mlpack-git] master: Add tests for empty DecisionStump constructor and serialization. (48fcbbc)
gitdub at big.cc.gt.atl.ga.us
Mon Nov 30 17:24:25 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/10b9d45b806a3e879b0564d78ccb183ebc7051ba...31c557d9cc7e4da57fd8a246085c19e076d12271
>---------------------------------------------------------------
commit 48fcbbce8016dec1b1bb3990afdc54b8a215b23d
Author: Ryan Curtin <ryan at ratml.org>
Date: Sat Nov 21 01:41:32 2015 +0000
Add tests for empty DecisionStump constructor and serialization.
Also change so that the size isn't automatically set before calling Classify().
>---------------------------------------------------------------
48fcbbce8016dec1b1bb3990afdc54b8a215b23d
src/mlpack/tests/decision_stump_test.cpp | 55 +++++++++++++++++++++++++++++---
src/mlpack/tests/serialization_test.cpp | 38 ++++++++++++++++++++++
2 files changed, 88 insertions(+), 5 deletions(-)
diff --git a/src/mlpack/tests/decision_stump_test.cpp b/src/mlpack/tests/decision_stump_test.cpp
index 4400532..69e279d 100644
--- a/src/mlpack/tests/decision_stump_test.cpp
+++ b/src/mlpack/tests/decision_stump_test.cpp
@@ -41,7 +41,7 @@ BOOST_AUTO_TEST_CASE(OneClass)
DecisionStump<> ds(trainingData, labelsIn.row(0), numClasses, inpBucketSize);
- Row<size_t> predictedLabels(testingData.n_cols);
+ Row<size_t> predictedLabels;
ds.Classify(testingData, predictedLabels);
for (size_t i = 0; i < predictedLabels.size(); i++ )
@@ -106,7 +106,7 @@ BOOST_AUTO_TEST_CASE(PerfectSplitOnZero)
DecisionStump<> ds(trainingData, labelsIn.row(0), numClasses, inpBucketSize);
- Row<size_t> predictedLabels(testingData.n_cols);
+ Row<size_t> predictedLabels;
ds.Classify(testingData, predictedLabels);
BOOST_CHECK_EQUAL(predictedLabels(0, 0), 0);
@@ -137,7 +137,7 @@ BOOST_AUTO_TEST_CASE(BinningTesting)
DecisionStump<> ds(trainingData, labelsIn.row(0), numClasses, inpBucketSize);
- Row<size_t> predictedLabels(testingData.n_cols);
+ Row<size_t> predictedLabels;
ds.Classify(testingData, predictedLabels);
BOOST_CHECK_EQUAL(predictedLabels(0, 0), 0);
@@ -167,7 +167,7 @@ BOOST_AUTO_TEST_CASE(PerfectMultiClassSplit)
DecisionStump<> ds(trainingData, labelsIn.row(0), numClasses, inpBucketSize);
- Row<size_t> predictedLabels(testingData.n_cols);
+ Row<size_t> predictedLabels;
ds.Classify(testingData, predictedLabels);
BOOST_CHECK_EQUAL(predictedLabels(0, 0), 0);
@@ -202,7 +202,7 @@ BOOST_AUTO_TEST_CASE(MultiClassSplit)
DecisionStump<> ds(trainingData, labelsIn.row(0), numClasses, inpBucketSize);
- Row<size_t> predictedLabels(testingData.n_cols);
+ Row<size_t> predictedLabels;
ds.Classify(testingData, predictedLabels);
BOOST_CHECK_EQUAL(predictedLabels(0, 0), 0);
@@ -308,4 +308,49 @@ BOOST_AUTO_TEST_CASE(DimensionSelectionTest)
}
}
+/**
+ * Ensure that the default constructor works and that it classifies things as 0
+ * always.
+ */
+BOOST_AUTO_TEST_CASE(EmptyConstructorTest)
+{
+ DecisionStump<> d;
+
+ arma::mat data = arma::randu<arma::mat>(3, 10);
+ arma::Row<size_t> labels;
+
+ d.Classify(data, labels);
+
+ for (size_t i = 0; i < 10; ++i)
+ BOOST_REQUIRE_EQUAL(labels[i], 0);
+
+ // Now train on another dataset and make sure something kind of makes sense.
+ mat trainingData;
+ trainingData << -7 << -6 << -5 << -4 << -3 << -2 << -1 << 0 << 1
+ << 2 << 3 << 4 << 5 << 6 << 7 << 8 << 9 << 10;
+
+ // No need to normalize labels here.
+ Mat<size_t> labelsIn;
+ labelsIn << 0 << 0 << 0 << 0 << 1 << 1 << 0 << 0
+ << 1 << 1 << 1 << 2 << 1 << 2 << 2 << 2 << 2 << 2;
+
+
+ mat testingData;
+ testingData << -6.1 << -5.9 << -2.1 << -0.7 << 2.5 << 4.7 << 7.2 << 9.1;
+
+ DecisionStump<> ds(trainingData, labelsIn.row(0), 4, 3);
+
+ Row<size_t> predictedLabels(testingData.n_cols);
+ ds.Classify(testingData, predictedLabels);
+
+ BOOST_CHECK_EQUAL(predictedLabels(0, 0), 0);
+ BOOST_CHECK_EQUAL(predictedLabels(0, 1), 0);
+ BOOST_CHECK_EQUAL(predictedLabels(0, 2), 1);
+ BOOST_CHECK_EQUAL(predictedLabels(0, 3), 1);
+ BOOST_CHECK_EQUAL(predictedLabels(0, 4), 1);
+ BOOST_CHECK_EQUAL(predictedLabels(0, 5), 1);
+ BOOST_CHECK_EQUAL(predictedLabels(0, 6), 2);
+ BOOST_CHECK_EQUAL(predictedLabels(0, 7), 2);
+}
+
BOOST_AUTO_TEST_SUITE_END();
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index 679fa0e..f2d3077 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -32,6 +32,7 @@
#include <mlpack/methods/naive_bayes/naive_bayes_classifier.hpp>
#include <mlpack/methods/rann/ra_search.hpp>
#include <mlpack/methods/lsh/lsh_search.hpp>
+#include <mlpack/methods/decision_stump/decision_stump.hpp>
using namespace mlpack;
using namespace mlpack::distribution;
@@ -43,6 +44,7 @@ using namespace mlpack::perceptron;
using namespace mlpack::regression;
using namespace mlpack::naive_bayes;
using namespace mlpack::neighbor;
+using namespace mlpack::decision_stump;
using namespace arma;
using namespace boost;
@@ -1429,4 +1431,40 @@ BOOST_AUTO_TEST_CASE(LSHTest)
textLsh.SecondHashTable(), binaryLsh.SecondHashTable());
}
+// Make sure serialization works for the decision stump.
+BOOST_AUTO_TEST_CASE(DecisionStumpTest)
+{
+ // Generate dataset.
+ arma::mat trainingData = arma::randu<arma::mat>(4, 100);
+ arma::Row<size_t> labels(100);
+ for (size_t i = 0; i < 25; ++i)
+ labels[i] = 0;
+ for (size_t i = 25; i < 50; ++i)
+ labels[i] = 3;
+ for (size_t i = 50; i < 75; ++i)
+ labels[i] = 1;
+ for (size_t i = 75; i < 100; ++i)
+ labels[i] = 2;
+
+ DecisionStump<> ds(trainingData, labels, 4, 3);
+
+ arma::mat otherData = arma::randu<arma::mat>(3, 100);
+ arma::Row<size_t> otherLabels = arma::randu<arma::Row<size_t>>(100);
+ DecisionStump<> xmlDs(otherData, otherLabels, 2, 3);
+
+ DecisionStump<> textDs;
+ DecisionStump<> binaryDs(trainingData, labels, 4, 10);
+
+ SerializeObjectAll(ds, xmlDs, textDs, binaryDs);
+
+ // Make sure that everything is the same about the new decision stumps.
+ BOOST_REQUIRE_EQUAL(ds.SplitAttribute(), xmlDs.SplitAttribute());
+ BOOST_REQUIRE_EQUAL(ds.SplitAttribute(), textDs.SplitAttribute());
+ BOOST_REQUIRE_EQUAL(ds.SplitAttribute(), binaryDs.SplitAttribute());
+
+ CheckMatrices(ds.Split(), xmlDs.Split(), textDs.Split(), binaryDs.Split());
+ CheckMatrices(ds.BinLabels(), xmlDs.BinLabels(), textDs.BinLabels(),
+ binaryDs.BinLabels());
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git mailing list