[mlpack-git] master: Add tests for empty DecisionStump constructor and serialization. (48fcbbc)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Mon Nov 30 17:24:25 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/10b9d45b806a3e879b0564d78ccb183ebc7051ba...31c557d9cc7e4da57fd8a246085c19e076d12271

>---------------------------------------------------------------

commit 48fcbbce8016dec1b1bb3990afdc54b8a215b23d
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sat Nov 21 01:41:32 2015 +0000

    Add tests for empty DecisionStump constructor and serialization.
    
    Also change the tests so that the output labels are no longer pre-sized before calling Classify(); Classify() must size them itself.


>---------------------------------------------------------------

48fcbbce8016dec1b1bb3990afdc54b8a215b23d
 src/mlpack/tests/decision_stump_test.cpp | 55 +++++++++++++++++++++++++++++---
 src/mlpack/tests/serialization_test.cpp  | 38 ++++++++++++++++++++++
 2 files changed, 88 insertions(+), 5 deletions(-)

diff --git a/src/mlpack/tests/decision_stump_test.cpp b/src/mlpack/tests/decision_stump_test.cpp
index 4400532..69e279d 100644
--- a/src/mlpack/tests/decision_stump_test.cpp
+++ b/src/mlpack/tests/decision_stump_test.cpp
@@ -41,7 +41,7 @@ BOOST_AUTO_TEST_CASE(OneClass)
 
   DecisionStump<> ds(trainingData, labelsIn.row(0), numClasses, inpBucketSize);
 
-  Row<size_t> predictedLabels(testingData.n_cols);
+  Row<size_t> predictedLabels;
   ds.Classify(testingData, predictedLabels);
 
   for (size_t i = 0; i < predictedLabels.size(); i++ )
@@ -106,7 +106,7 @@ BOOST_AUTO_TEST_CASE(PerfectSplitOnZero)
 
   DecisionStump<> ds(trainingData, labelsIn.row(0), numClasses, inpBucketSize);
 
-  Row<size_t> predictedLabels(testingData.n_cols);
+  Row<size_t> predictedLabels;
   ds.Classify(testingData, predictedLabels);
 
   BOOST_CHECK_EQUAL(predictedLabels(0, 0), 0);
@@ -137,7 +137,7 @@ BOOST_AUTO_TEST_CASE(BinningTesting)
 
   DecisionStump<> ds(trainingData, labelsIn.row(0), numClasses, inpBucketSize);
 
-  Row<size_t> predictedLabels(testingData.n_cols);
+  Row<size_t> predictedLabels;
   ds.Classify(testingData, predictedLabels);
 
   BOOST_CHECK_EQUAL(predictedLabels(0, 0), 0);
@@ -167,7 +167,7 @@ BOOST_AUTO_TEST_CASE(PerfectMultiClassSplit)
 
   DecisionStump<> ds(trainingData, labelsIn.row(0), numClasses, inpBucketSize);
 
-  Row<size_t> predictedLabels(testingData.n_cols);
+  Row<size_t> predictedLabels;
   ds.Classify(testingData, predictedLabels);
 
   BOOST_CHECK_EQUAL(predictedLabels(0, 0), 0);
@@ -202,7 +202,7 @@ BOOST_AUTO_TEST_CASE(MultiClassSplit)
 
   DecisionStump<> ds(trainingData, labelsIn.row(0), numClasses, inpBucketSize);
 
-  Row<size_t> predictedLabels(testingData.n_cols);
+  Row<size_t> predictedLabels;
   ds.Classify(testingData, predictedLabels);
 
   BOOST_CHECK_EQUAL(predictedLabels(0, 0), 0);
@@ -308,4 +308,49 @@ BOOST_AUTO_TEST_CASE(DimensionSelectionTest)
   }
 }
 
+/**
+ * Ensure that the default constructor works, that it always classifies points
+ * as 0, and that Classify() sizes the output labels itself.
+ */
+BOOST_AUTO_TEST_CASE(EmptyConstructorTest)
+{
+  DecisionStump<> d;
+
+  arma::mat data = arma::randu<arma::mat>(3, 10);
+  arma::Row<size_t> labels;
+
+  d.Classify(data, labels);
+
+  BOOST_REQUIRE_EQUAL(labels.n_elem, 10);
+  for (size_t i = 0; i < labels.n_elem; ++i)
+    BOOST_REQUIRE_EQUAL(labels[i], 0);
+
+  // Now train on another dataset and make sure something kind of makes sense.
+  mat trainingData;
+  trainingData << -7 << -6 << -5 << -4 << -3 << -2 << -1 << 0 << 1
+               << 2  << 3  << 4  << 5  << 6  << 7  << 8  << 9 << 10;
+
+  // No need to normalize labels here.
+  Mat<size_t> labelsIn;
+  labelsIn << 0 << 0 << 0 << 0 << 1 << 1 << 0 << 0
+           << 1 << 1 << 1 << 2 << 1 << 2 << 2 << 2 << 2 << 2;
+
+  mat testingData;
+  testingData << -6.1 << -5.9 << -2.1 << -0.7 << 2.5 << 4.7 << 7.2 << 9.1;
+
+  DecisionStump<> ds(trainingData, labelsIn.row(0), 4, 3);
+
+  Row<size_t> predictedLabels;
+  ds.Classify(testingData, predictedLabels);
+
+  BOOST_CHECK_EQUAL(predictedLabels(0, 0), 0);
+  BOOST_CHECK_EQUAL(predictedLabels(0, 1), 0);
+  BOOST_CHECK_EQUAL(predictedLabels(0, 2), 1);
+  BOOST_CHECK_EQUAL(predictedLabels(0, 3), 1);
+  BOOST_CHECK_EQUAL(predictedLabels(0, 4), 1);
+  BOOST_CHECK_EQUAL(predictedLabels(0, 5), 1);
+  BOOST_CHECK_EQUAL(predictedLabels(0, 6), 2);
+  BOOST_CHECK_EQUAL(predictedLabels(0, 7), 2);
+}
+
 BOOST_AUTO_TEST_SUITE_END();
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index 679fa0e..f2d3077 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -32,6 +32,7 @@
 #include <mlpack/methods/naive_bayes/naive_bayes_classifier.hpp>
 #include <mlpack/methods/rann/ra_search.hpp>
 #include <mlpack/methods/lsh/lsh_search.hpp>
+#include <mlpack/methods/decision_stump/decision_stump.hpp>
 
 using namespace mlpack;
 using namespace mlpack::distribution;
@@ -43,6 +44,7 @@ using namespace mlpack::perceptron;
 using namespace mlpack::regression;
 using namespace mlpack::naive_bayes;
 using namespace mlpack::neighbor;
+using namespace mlpack::decision_stump;
 
 using namespace arma;
 using namespace boost;
@@ -1429,4 +1431,40 @@ BOOST_AUTO_TEST_CASE(LSHTest)
       textLsh.SecondHashTable(), binaryLsh.SecondHashTable());
 }
 
+// Make sure serialization works for the decision stump.
+BOOST_AUTO_TEST_CASE(DecisionStumpTest)
+{
+  // Generate dataset.  The labels form four contiguous blocks of 25 points,
+  // labeled 0, 3, 1, and 2 in that order.
+  arma::mat trainingData = arma::randu<arma::mat>(4, 100);
+  arma::Row<size_t> labels(100);
+  for (size_t i = 0; i < 100; ++i)
+    labels[i] = (i < 25) ? 0 : (i < 50) ? 3 : (i < 75) ? 1 : 2;
+
+  DecisionStump<> ds(trainingData, labels, 4, 3);
+
+  // randu() draws from [0, 1), so on a Row<size_t> it would truncate every
+  // label to 0; build two genuine classes explicitly instead.
+  arma::mat otherData = arma::randu<arma::mat>(3, 100);
+  arma::Row<size_t> otherLabels(100);
+  for (size_t i = 0; i < 100; ++i)
+    otherLabels[i] = i % 2;
+  DecisionStump<> xmlDs(otherData, otherLabels, 2, 3);
+
+  // textDs starts empty; binaryDs is trained with a different bucket size.
+  DecisionStump<> textDs;
+  DecisionStump<> binaryDs(trainingData, labels, 4, 10);
+
+  SerializeObjectAll(ds, xmlDs, textDs, binaryDs);
+
+  // Make sure that everything is the same about the new decision stumps.
+  BOOST_REQUIRE_EQUAL(ds.SplitAttribute(), xmlDs.SplitAttribute());
+  BOOST_REQUIRE_EQUAL(ds.SplitAttribute(), textDs.SplitAttribute());
+  BOOST_REQUIRE_EQUAL(ds.SplitAttribute(), binaryDs.SplitAttribute());
+
+  CheckMatrices(ds.Split(), xmlDs.Split(), textDs.Split(), binaryDs.Split());
+  CheckMatrices(ds.BinLabels(), xmlDs.BinLabels(), textDs.BinLabels(),
+      binaryDs.BinLabels());
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list