[mlpack-svn] r16021 - mlpack/trunk/src/mlpack/tests

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Nov 13 11:42:51 EST 2013


Author: rcurtin
Date: Wed Nov 13 11:42:51 2013
New Revision: 16021

Log:
Add test for LogisticRegressionFunction::Evaluate().


Added:
   mlpack/trunk/src/mlpack/tests/logistic_regression_test.cpp
Modified:
   mlpack/trunk/src/mlpack/tests/CMakeLists.txt

Modified: mlpack/trunk/src/mlpack/tests/CMakeLists.txt
==============================================================================
--- mlpack/trunk/src/mlpack/tests/CMakeLists.txt	(original)
+++ mlpack/trunk/src/mlpack/tests/CMakeLists.txt	Wed Nov 13 11:42:51 2013
@@ -24,6 +24,7 @@
   linear_regression_test.cpp
   load_save_test.cpp
   local_coordinate_coding_test.cpp
+  logistic_regression_test.cpp
   lrsdp_test.cpp
   lsh_test.cpp
   math_test.cpp

Added: mlpack/trunk/src/mlpack/tests/logistic_regression_test.cpp
==============================================================================
--- (empty file)
+++ mlpack/trunk/src/mlpack/tests/logistic_regression_test.cpp	Wed Nov 13 11:42:51 2013
@@ -0,0 +1,41 @@
+/**
+ * @file logistic_regression_test.cpp
+ * @author Ryan Curtin
+ *
+ * Tests for LogisticRegressionFunction and LogisticRegression.
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/methods/logistic_regression/logistic_regression.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace mlpack;
+using namespace mlpack::regression;
+
+BOOST_AUTO_TEST_SUITE(LogisticRegressionTest);
+
+/**
+ * Test LogisticRegressionFunction::Evaluate() on a simple set of points.
+ */
+BOOST_AUTO_TEST_CASE(LogisticRegressionFunctionEvaluate)
+{
+  // Very simple fake dataset: one point per column, responses 1, 1, 0.
+  arma::mat data("1 1 1;" // Fake row for intercept.
+                 "1 2 3;"
+                 "1 2 3");
+  arma::vec responses("1 1 0");
+
+  LogisticRegressionFunction lrf(data, responses, 0.0 /* no regularization */);
+
+  // These were hand-calculated using Octave (CLOSE tolerance is in percent).
+  BOOST_REQUIRE_CLOSE(lrf.Evaluate(arma::vec("1 1 1")), 7.0562141665, 1e-5);
+  BOOST_REQUIRE_CLOSE(lrf.Evaluate(arma::vec("0 0 0")), 2.0794415417, 1e-5);
+  BOOST_REQUIRE_CLOSE(lrf.Evaluate(arma::vec("-1 -1 -1")), 8.0562141665, 1e-5);
+  // Objective should be ~0 here, so check absolute size, not relative error.
+  BOOST_REQUIRE_SMALL(lrf.Evaluate(arma::vec("200 -40 -40")), 1e-5);
+  BOOST_REQUIRE_SMALL(lrf.Evaluate(arma::vec("200 -80 0")), 1e-5);
+  BOOST_REQUIRE_SMALL(lrf.Evaluate(arma::vec("200 -100 20")), 1e-5);
+}
+
+BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-svn mailing list