[mlpack-svn] r13268 - mlpack/trunk/src/mlpack/tests

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Jul 20 14:53:41 EDT 2012


Author: rcurtin
Date: 2012-07-20 14:53:40 -0400 (Fri, 20 Jul 2012)
New Revision: 13268

Modified:
   mlpack/trunk/src/mlpack/tests/det_test.cpp
Log:
Refactor test to stop using eT and cT.  Don't typedef MatType anymore, and
change variable names to match the naming guidelines.


Modified: mlpack/trunk/src/mlpack/tests/det_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/det_test.cpp	2012-07-20 18:53:05 UTC (rev 13267)
+++ mlpack/trunk/src/mlpack/tests/det_test.cpp	2012-07-20 18:53:40 UTC (rev 13268)
@@ -23,23 +23,17 @@
 
 BOOST_AUTO_TEST_SUITE(DETTest);
 
-// Testing functions of the DTree class
+// Tests for the private functions.
 
-typedef arma::mat MatType;
-typedef arma::vec VecType;
-
-
-// the private functions
-
 BOOST_AUTO_TEST_CASE(TestGetMaxMinVals)
 {
-  MatType test_data(3, 5);
+  arma::mat testData(3, 5);
 
-  test_data << 4 << 5 << 7 << 3 << 5 << arma::endr
-            << 5 << 0 << 1 << 7 << 1 << arma::endr
-            << 5 << 6 << 7 << 1 << 8 << arma::endr;
+  testData << 4 << 5 << 7 << 3 << 5 << arma::endr
+           << 5 << 0 << 1 << 7 << 1 << arma::endr
+           << 5 << 6 << 7 << 1 << 8 << arma::endr;
 
-  DTree<> tree(test_data);
+  DTree<> tree(testData);
 
   BOOST_REQUIRE_EQUAL(tree.maxVals[0], 7);
   BOOST_REQUIRE_EQUAL(tree.minVals[0], 3);
@@ -55,17 +49,17 @@
   arma::vec minVals("3 0 1");
 
   DTree<> testDTree(maxVals, minVals, 5);
-  double true_node_error = -1.0 * exp(-log(4.0) - log(7.0) - log(7.0));
+  double trueNodeError = -1.0 * exp(-log(4.0) - log(7.0) - log(7.0));
 
-  BOOST_REQUIRE_CLOSE((double) testDTree.error_, true_node_error, 1e-10);
+  BOOST_REQUIRE_CLOSE((double) testDTree.error, trueNodeError, 1e-10);
 
-  testDTree.start_ = 3;
-  testDTree.end_ = 5;
+  testDTree.start = 3;
+  testDTree.end = 5;
 
-  double node_error = -std::exp(testDTree.LogNegativeError(5));
-  true_node_error = -1.0 * exp(2 * log(2.0 / 5.0) - log(4.0) - log(7.0) -
+  double nodeError = -std::exp(testDTree.LogNegativeError(5));
+  trueNodeError = -1.0 * exp(2 * log(2.0 / 5.0) - log(4.0) - log(7.0) -
       log(7.0));
-  BOOST_REQUIRE_CLOSE(node_error, true_node_error, 1e-10);
+  BOOST_REQUIRE_CLOSE(nodeError, trueNodeError, 1e-10);
 }
 
 BOOST_AUTO_TEST_CASE(TestWithinRange)
@@ -75,183 +69,182 @@
 
   DTree<> testDTree(maxVals, minVals, 5);
 
-  VecType test_query(3);
-  test_query << 4.5 << 2.5 << 2;
+  arma::vec testQuery(3);
+  testQuery << 4.5 << 2.5 << 2;
 
-  BOOST_REQUIRE_EQUAL(testDTree.WithinRange(test_query), true);
+  BOOST_REQUIRE_EQUAL(testDTree.WithinRange(testQuery), true);
 
-  test_query << 8.5 << 2.5 << 2;
+  testQuery << 8.5 << 2.5 << 2;
 
-  BOOST_REQUIRE_EQUAL(testDTree.WithinRange(test_query), false);
+  BOOST_REQUIRE_EQUAL(testDTree.WithinRange(testQuery), false);
 }
 
 BOOST_AUTO_TEST_CASE(TestFindSplit)
 {
-  MatType test_data(3,5);
+  arma::mat testData(3,5);
 
-  test_data << 4 << 5 << 7 << 3 << 5 << arma::endr
-            << 5 << 0 << 1 << 7 << 1 << arma::endr
-            << 5 << 6 << 7 << 1 << 8 << arma::endr;
+  testData << 4 << 5 << 7 << 3 << 5 << arma::endr
+           << 5 << 0 << 1 << 7 << 1 << arma::endr
+           << 5 << 6 << 7 << 1 << 8 << arma::endr;
 
-  DTree<> testDTree(test_data);
+  DTree<> testDTree(testData);
 
-  size_t ob_dim, true_dim;
-  double true_left_error, ob_left_error, true_right_error, ob_right_error,
-      ob_split, true_split;
+  size_t obDim, trueDim;
+  double trueLeftError, obLeftError, trueRightError, obRightError,
+      obSplit, trueSplit;
 
-  true_dim = 2;
-  true_split = 5.5;
-  true_left_error = -1.0 * exp(2 * log(2.0 / 5.0) - (log(7.0) + log(4.0) +
+  trueDim = 2;
+  trueSplit = 5.5;
+  trueLeftError = -1.0 * exp(2 * log(2.0 / 5.0) - (log(7.0) + log(4.0) +
       log(4.5)));
-  true_right_error = -1.0 * exp(2 * log(3.0 / 5.0) - (log(7.0) + log(4.0) +
+  trueRightError = -1.0 * exp(2 * log(3.0 / 5.0) - (log(7.0) + log(4.0) +
       log(2.5)));
 
-  BOOST_REQUIRE(testDTree.FindSplit(test_data, ob_dim, ob_split, ob_left_error,
-      ob_right_error, 2, 1));
+  BOOST_REQUIRE(testDTree.FindSplit(testData, obDim, obSplit, obLeftError,
+      obRightError, 2, 1));
 
-  BOOST_REQUIRE(true_dim == ob_dim);
-  BOOST_REQUIRE_CLOSE(true_split, ob_split, 1e-10);
+  BOOST_REQUIRE(trueDim == obDim);
+  BOOST_REQUIRE_CLOSE(trueSplit, obSplit, 1e-10);
 
-  BOOST_REQUIRE_CLOSE(true_left_error, ob_left_error, 1e-10);
-  BOOST_REQUIRE_CLOSE(true_right_error, ob_right_error, 1e-10);
+  BOOST_REQUIRE_CLOSE(trueLeftError, obLeftError, 1e-10);
+  BOOST_REQUIRE_CLOSE(trueRightError, obRightError, 1e-10);
 }
 
 BOOST_AUTO_TEST_CASE(TestSplitData)
 {
-  MatType test_data(3, 5);
+  arma::mat testData(3, 5);
 
-  test_data << 4 << 5 << 7 << 3 << 5 << arma::endr
-            << 5 << 0 << 1 << 7 << 1 << arma::endr
-            << 5 << 6 << 7 << 1 << 8 << arma::endr;
+  testData << 4 << 5 << 7 << 3 << 5 << arma::endr
+           << 5 << 0 << 1 << 7 << 1 << arma::endr
+           << 5 << 6 << 7 << 1 << 8 << arma::endr;
 
-  DTree<> testDTree(test_data);
+  DTree<> testDTree(testData);
 
-  arma::Col<size_t> o_test(5);
-  o_test << 1 << 2 << 3 << 4 << 5;
+  arma::Col<size_t> oTest(5);
+  oTest << 1 << 2 << 3 << 4 << 5;
 
-  size_t split_dim = 2;
-  double true_split_val = 5.5;
+  size_t splitDim = 2;
+  double trueSplitVal = 5.5;
 
-  size_t splitInd = testDTree.SplitData(test_data, split_dim, true_split_val,
-      o_test);
+  size_t splitInd = testDTree.SplitData(testData, splitDim, trueSplitVal,
+      oTest);
 
   BOOST_REQUIRE_EQUAL(splitInd, 2); // 2 points on left side.
 
-  BOOST_REQUIRE_EQUAL(o_test[0], 1);
-  BOOST_REQUIRE_EQUAL(o_test[1], 4);
-  BOOST_REQUIRE_EQUAL(o_test[2], 3);
-  BOOST_REQUIRE_EQUAL(o_test[3], 2);
-  BOOST_REQUIRE_EQUAL(o_test[4], 5);
+  BOOST_REQUIRE_EQUAL(oTest[0], 1);
+  BOOST_REQUIRE_EQUAL(oTest[1], 4);
+  BOOST_REQUIRE_EQUAL(oTest[2], 3);
+  BOOST_REQUIRE_EQUAL(oTest[3], 2);
+  BOOST_REQUIRE_EQUAL(oTest[4], 5);
 }
 
-// the public functions
+// Tests for the public functions.
 
 BOOST_AUTO_TEST_CASE(TestGrow)
 {
-  MatType test_data(3, 5);
+  arma::mat testData(3, 5);
 
-  test_data << 4 << 5 << 7 << 3 << 5 << arma::endr
-            << 5 << 0 << 1 << 7 << 1 << arma::endr
-            << 5 << 6 << 7 << 1 << 8 << arma::endr;
+  testData << 4 << 5 << 7 << 3 << 5 << arma::endr
+           << 5 << 0 << 1 << 7 << 1 << arma::endr
+           << 5 << 6 << 7 << 1 << 8 << arma::endr;
 
-  arma::Col<size_t> o_test(5);
-  o_test << 0 << 1 << 2 << 3 << 4;
+  arma::Col<size_t> oTest(5);
+  oTest << 0 << 1 << 2 << 3 << 4;
 
-  long double root_error, l_error, r_error, rl_error, rr_error;
+  double rootError, lError, rError, rlError, rrError;
 
-  root_error = -1.0 * exp(-log(4.0) - log(7.0) - log(7.0));
+  rootError = -1.0 * exp(-log(4.0) - log(7.0) - log(7.0));
 
-  l_error = -1.0 * exp(2 * log(2.0 / 5.0) - (log(7.0) + log(4.0) + log(4.5)));
-  r_error =  -1.0 * exp(2 * log(3.0 / 5.0) - (log(7.0) + log(4.0) + log(2.5)));
+  lError = -1.0 * exp(2 * log(2.0 / 5.0) - (log(7.0) + log(4.0) + log(4.5)));
+  rError =  -1.0 * exp(2 * log(3.0 / 5.0) - (log(7.0) + log(4.0) + log(2.5)));
 
-  rl_error = -1.0 * exp(2 * log(1.0 / 5.0) - (log(0.5) + log(4.0) + log(2.5)));
-  rr_error = -1.0 * exp(2 * log(2.0 / 5.0) - (log(6.5) + log(4.0) + log(2.5)));
+  rlError = -1.0 * exp(2 * log(1.0 / 5.0) - (log(0.5) + log(4.0) + log(2.5)));
+  rrError = -1.0 * exp(2 * log(2.0 / 5.0) - (log(6.5) + log(4.0) + log(2.5)));
 
-  DTree<> testDTree(test_data);
-  double alpha = testDTree.Grow(test_data, o_test, false, 2, 1);
+  DTree<> testDTree(testData);
+  double alpha = testDTree.Grow(testData, oTest, false, 2, 1);
 
-  BOOST_REQUIRE_EQUAL(o_test[0], 0);
-  BOOST_REQUIRE_EQUAL(o_test[1], 3);
-  BOOST_REQUIRE_EQUAL(o_test[2], 1);
-  BOOST_REQUIRE_EQUAL(o_test[3], 2);
-  BOOST_REQUIRE_EQUAL(o_test[4], 4);
+  BOOST_REQUIRE_EQUAL(oTest[0], 0);
+  BOOST_REQUIRE_EQUAL(oTest[1], 3);
+  BOOST_REQUIRE_EQUAL(oTest[2], 1);
+  BOOST_REQUIRE_EQUAL(oTest[3], 2);
+  BOOST_REQUIRE_EQUAL(oTest[4], 4);
 
-  // test the structure of the tree
-  BOOST_REQUIRE(testDTree.left()->left() == NULL);
-  BOOST_REQUIRE(testDTree.left()->right() == NULL);
-  BOOST_REQUIRE(testDTree.right()->left()->left() == NULL);
-  BOOST_REQUIRE(testDTree.right()->left()->right() == NULL);
-  BOOST_REQUIRE(testDTree.right()->right()->left() == NULL);
-  BOOST_REQUIRE(testDTree.right()->right()->right() == NULL);
+  // Test the structure of the tree.
+  BOOST_REQUIRE(testDTree.Left()->Left() == NULL);
+  BOOST_REQUIRE(testDTree.Left()->Right() == NULL);
+  BOOST_REQUIRE(testDTree.Right()->Left()->Left() == NULL);
+  BOOST_REQUIRE(testDTree.Right()->Left()->Right() == NULL);
+  BOOST_REQUIRE(testDTree.Right()->Right()->Left() == NULL);
+  BOOST_REQUIRE(testDTree.Right()->Right()->Right() == NULL);
 
-  BOOST_REQUIRE(testDTree.subtree_leaves() == 3);
+  BOOST_REQUIRE(testDTree.SubtreeLeaves() == 3);
 
-  BOOST_REQUIRE(testDTree.split_dim() == 2);
-  BOOST_REQUIRE_CLOSE(testDTree.split_value(), 5.5, 1e-5);
-  BOOST_REQUIRE(testDTree.right()->split_dim() == 1);
-  BOOST_REQUIRE_CLOSE(testDTree.right()->split_value(), 0.5, 1e-5);
+  BOOST_REQUIRE(testDTree.SplitDim() == 2);
+  BOOST_REQUIRE_CLOSE(testDTree.SplitValue(), 5.5, 1e-5);
+  BOOST_REQUIRE(testDTree.Right()->SplitDim() == 1);
+  BOOST_REQUIRE_CLOSE(testDTree.Right()->SplitValue(), 0.5, 1e-5);
 
-  // test node errors for every node
-  BOOST_REQUIRE_CLOSE(testDTree.error_, root_error, 1e-10);
-  BOOST_REQUIRE_CLOSE(testDTree.left()->error_, l_error, 1e-10);
-  BOOST_REQUIRE_CLOSE(testDTree.right()->error_, r_error, 1e-10);
-  BOOST_REQUIRE_CLOSE(testDTree.right()->left()->error_, rl_error, 1e-10);
-  BOOST_REQUIRE_CLOSE(testDTree.right()->right()->error_, rr_error, 1e-10);
+  // Test node errors for every node.
+  BOOST_REQUIRE_CLOSE(testDTree.error, rootError, 1e-10);
+  BOOST_REQUIRE_CLOSE(testDTree.Left()->error, lError, 1e-10);
+  BOOST_REQUIRE_CLOSE(testDTree.Right()->error, rError, 1e-10);
+  BOOST_REQUIRE_CLOSE(testDTree.Right()->Left()->error, rlError, 1e-10);
+  BOOST_REQUIRE_CLOSE(testDTree.Right()->Right()->error, rrError, 1e-10);
 
+  // Test alpha.
+  double rootAlpha, rAlpha;
+  rootAlpha = (rootError - (lError + rlError + rrError)) / 2;
+  rAlpha = rError - (rlError + rrError);
 
-  // test alpha
-  double root_alpha, r_alpha;
-  root_alpha = (root_error - (l_error + rl_error + rr_error)) / 2;
-  r_alpha = r_error - (rl_error + rr_error);
-
-  BOOST_REQUIRE_CLOSE(alpha, min(root_alpha, r_alpha), 1e-10);
+  BOOST_REQUIRE_CLOSE(alpha, min(rootAlpha, rAlpha), 1e-10);
 }
 
 BOOST_AUTO_TEST_CASE(TestPruneAndUpdate)
 {
-  MatType test_data(3, 5);
+  arma::mat testData(3, 5);
 
-  test_data << 4 << 5 << 7 << 3 << 5 << arma::endr
-            << 5 << 0 << 1 << 7 << 1 << arma::endr
-            << 5 << 6 << 7 << 1 << 8 << arma::endr;
+  testData << 4 << 5 << 7 << 3 << 5 << arma::endr
+           << 5 << 0 << 1 << 7 << 1 << arma::endr
+           << 5 << 6 << 7 << 1 << 8 << arma::endr;
 
-  arma::Col<size_t> o_test(5);
-  o_test << 0 << 1 << 2 << 3 << 4;
-  DTree<> testDTree(test_data);
-  double alpha = testDTree.Grow(test_data, o_test, false, 2, 1);
+  arma::Col<size_t> oTest(5);
+  oTest << 0 << 1 << 2 << 3 << 4;
+  DTree<> testDTree(testData);
+  double alpha = testDTree.Grow(testData, oTest, false, 2, 1);
   alpha = testDTree.PruneAndUpdate(alpha, false);
 
   BOOST_REQUIRE_CLOSE(alpha, numeric_limits<double>::max(), 1e-10);
-  BOOST_REQUIRE(testDTree.subtree_leaves() == 1);
+  BOOST_REQUIRE(testDTree.SubtreeLeaves() == 1);
 
-  long double root_error = -1.0 * exp(-log(4.0) - log(7.0) - log(7.0));
+  double rootError = -1.0 * exp(-log(4.0) - log(7.0) - log(7.0));
 
-  BOOST_REQUIRE_CLOSE(testDTree.error(), root_error, 1e-10);
-  BOOST_REQUIRE_CLOSE(testDTree.subtree_leaves_error(), root_error, 1e-10);
-  BOOST_REQUIRE(testDTree.left() == NULL);
-  BOOST_REQUIRE(testDTree.right() == NULL);
+  BOOST_REQUIRE_CLOSE(testDTree.Error(), rootError, 1e-10);
+  BOOST_REQUIRE_CLOSE(testDTree.SubtreeLeavesError(), rootError, 1e-10);
+  BOOST_REQUIRE(testDTree.Left() == NULL);
+  BOOST_REQUIRE(testDTree.Right() == NULL);
 }
 
 BOOST_AUTO_TEST_CASE(TestComputeValue)
 {
-  MatType test_data(3, 5);
+  arma::mat testData(3, 5);
 
-  test_data << 4 << 5 << 7 << 3 << 5 << arma::endr
-            << 5 << 0 << 1 << 7 << 1 << arma::endr
-            << 5 << 6 << 7 << 1 << 8 << arma::endr;
+  testData << 4 << 5 << 7 << 3 << 5 << arma::endr
+           << 5 << 0 << 1 << 7 << 1 << arma::endr
+           << 5 << 6 << 7 << 1 << 8 << arma::endr;
 
-  VecType q1(3), q2(3), q3(3), q4(3);
+  arma::vec q1(3), q2(3), q3(3), q4(3);
 
   q1 << 4 << 2 << 2;
   q2 << 5 << 0.25 << 6;
   q3 << 5 << 3 << 7;
   q4 << 2 << 3 << 3;
 
-  arma::Col<size_t> o_test(5);
-  o_test << 0 << 1 << 2 << 3 << 4;
+  arma::Col<size_t> oTest(5);
+  oTest << 0 << 1 << 2 << 3 << 4;
 
-  DTree<> testDTree(test_data);
-  double alpha = testDTree.Grow(test_data, o_test, false, 2, 1);
+  DTree<> testDTree(testData);
+  double alpha = testDTree.Grow(testData, oTest, false, 2, 1);
 
   double d1 = (2.0 / 5.0) / exp(log(4.0) + log(7.0) + log(4.5));
   double d2 = (1.0 / 5.0) / exp(log(4.0) + log(0.5) + log(2.5));
@@ -274,37 +267,35 @@
 
 BOOST_AUTO_TEST_CASE(TestVariableImportance)
 {
-  MatType test_data(3, 5);
+  arma::mat testData(3, 5);
 
-  test_data << 4 << 5 << 7 << 3 << 5 << arma::endr
-            << 5 << 0 << 1 << 7 << 1 << arma::endr
-            << 5 << 6 << 7 << 1 << 8 << arma::endr;
+  testData << 4 << 5 << 7 << 3 << 5 << arma::endr
+           << 5 << 0 << 1 << 7 << 1 << arma::endr
+           << 5 << 6 << 7 << 1 << 8 << arma::endr;
 
-  long double root_error, l_error, r_error, rl_error, rr_error;
+  double rootError, lError, rError, rlError, rrError;
 
-  root_error = -1.0 * exp(-log(4.0) - log(7.0) - log(7.0));
+  rootError = -1.0 * exp(-log(4.0) - log(7.0) - log(7.0));
 
-  l_error = -1.0 * exp(2 * log(2.0 / 5.0) - (log(7.0) + log(4.0) + log(4.5)));
-  r_error =  -1.0 * exp(2 * log(3.0 / 5.0) - (log(7.0) + log(4.0) + log(2.5)));
+  lError = -1.0 * exp(2 * log(2.0 / 5.0) - (log(7.0) + log(4.0) + log(4.5)));
+  rError =  -1.0 * exp(2 * log(3.0 / 5.0) - (log(7.0) + log(4.0) + log(2.5)));
 
-  rl_error = -1.0 * exp(2 * log(1.0 / 5.0) - (log(0.5) + log(4.0) + log(2.5)));
-  rr_error = -1.0 * exp(2 * log(2.0 / 5.0) - (log(6.5) + log(4.0) + log(2.5)));
+  rlError = -1.0 * exp(2 * log(1.0 / 5.0) - (log(0.5) + log(4.0) + log(2.5)));
+  rrError = -1.0 * exp(2 * log(2.0 / 5.0) - (log(6.5) + log(4.0) + log(2.5)));
 
-  arma::Col<size_t> o_test(5);
-  o_test << 0 << 1 << 2 << 3 << 4;
+  arma::Col<size_t> oTest(5);
+  oTest << 0 << 1 << 2 << 3 << 4;
 
-  DTree<> testDTree(test_data);
-  testDTree.Grow(test_data, o_test, false, 2, 1);
+  DTree<> testDTree(testData);
+  testDTree.Grow(testData, oTest, false, 2, 1);
 
   arma::vec imps;
 
   testDTree.ComputeVariableImportance(imps);
 
   BOOST_REQUIRE_CLOSE((double) 0.0, imps[0], 1e-10);
-  BOOST_REQUIRE_CLOSE((double) (r_error - (rl_error + rr_error)), imps[1],
-      1e-10);
-  BOOST_REQUIRE_CLOSE((double) (root_error - (l_error + r_error)), imps[2],
-      1e-10);
+  BOOST_REQUIRE_CLOSE((double) (rError - (rlError + rrError)), imps[1], 1e-10);
+  BOOST_REQUIRE_CLOSE((double) (rootError - (lError + rError)), imps[2], 1e-10);
 }
 
 /**
@@ -312,26 +303,26 @@
  *
 BOOST_AUTO_TEST_CASE(TestTagTree)
 {
-  MatType test_data(3, 5);
+  MatType testData(3, 5);
 
-  test_data << 4 << 5 << 7 << 3 << 5 << arma::endr
+  testData << 4 << 5 << 7 << 3 << 5 << arma::endr
             << 5 << 0 << 1 << 7 << 1 << arma::endr
             << 5 << 6 << 7 << 1 << 8 << arma::endr;
 
-  DTree<>* testDTree = new DTree<>(&test_data);
+  DTree<>* testDTree = new DTree<>(&testData);
 
   delete testDTree;
 }
 
 BOOST_AUTO_TEST_CASE(TestFindBucket)
 {
-  MatType test_data(3, 5);
+  MatType testData(3, 5);
 
-  test_data << 4 << 5 << 7 << 3 << 5 << arma::endr
+  testData << 4 << 5 << 7 << 3 << 5 << arma::endr
             << 5 << 0 << 1 << 7 << 1 << arma::endr
             << 5 << 6 << 7 << 1 << 8 << arma::endr;
 
-  DTree<>* testDTree = new DTree<>(&test_data);
+  DTree<>* testDTree = new DTree<>(&testData);
 
   delete testDTree;
 }




More information about the mlpack-svn mailing list