[mlpack-svn] r13268 - mlpack/trunk/src/mlpack/tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Jul 20 14:53:41 EDT 2012
Author: rcurtin
Date: 2012-07-20 14:53:40 -0400 (Fri, 20 Jul 2012)
New Revision: 13268
Modified:
mlpack/trunk/src/mlpack/tests/det_test.cpp
Log:
Refactor test to stop using eT and cT. Don't typedef MatType anymore, and
change variable names to line up with naming guidelines.
Modified: mlpack/trunk/src/mlpack/tests/det_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/det_test.cpp 2012-07-20 18:53:05 UTC (rev 13267)
+++ mlpack/trunk/src/mlpack/tests/det_test.cpp 2012-07-20 18:53:40 UTC (rev 13268)
@@ -23,23 +23,17 @@
BOOST_AUTO_TEST_SUITE(DETTest);
-// Testing functions of the DTree class
+// Tests for the private functions.
-typedef arma::mat MatType;
-typedef arma::vec VecType;
-
-
-// the private functions
-
BOOST_AUTO_TEST_CASE(TestGetMaxMinVals)
{
- MatType test_data(3, 5);
+ arma::mat testData(3, 5);
- test_data << 4 << 5 << 7 << 3 << 5 << arma::endr
- << 5 << 0 << 1 << 7 << 1 << arma::endr
- << 5 << 6 << 7 << 1 << 8 << arma::endr;
+ testData << 4 << 5 << 7 << 3 << 5 << arma::endr
+ << 5 << 0 << 1 << 7 << 1 << arma::endr
+ << 5 << 6 << 7 << 1 << 8 << arma::endr;
- DTree<> tree(test_data);
+ DTree<> tree(testData);
BOOST_REQUIRE_EQUAL(tree.maxVals[0], 7);
BOOST_REQUIRE_EQUAL(tree.minVals[0], 3);
@@ -55,17 +49,17 @@
arma::vec minVals("3 0 1");
DTree<> testDTree(maxVals, minVals, 5);
- double true_node_error = -1.0 * exp(-log(4.0) - log(7.0) - log(7.0));
+ double trueNodeError = -1.0 * exp(-log(4.0) - log(7.0) - log(7.0));
- BOOST_REQUIRE_CLOSE((double) testDTree.error_, true_node_error, 1e-10);
+ BOOST_REQUIRE_CLOSE((double) testDTree.error, trueNodeError, 1e-10);
- testDTree.start_ = 3;
- testDTree.end_ = 5;
+ testDTree.start = 3;
+ testDTree.end = 5;
- double node_error = -std::exp(testDTree.LogNegativeError(5));
- true_node_error = -1.0 * exp(2 * log(2.0 / 5.0) - log(4.0) - log(7.0) -
+ double nodeError = -std::exp(testDTree.LogNegativeError(5));
+ trueNodeError = -1.0 * exp(2 * log(2.0 / 5.0) - log(4.0) - log(7.0) -
log(7.0));
- BOOST_REQUIRE_CLOSE(node_error, true_node_error, 1e-10);
+ BOOST_REQUIRE_CLOSE(nodeError, trueNodeError, 1e-10);
}
BOOST_AUTO_TEST_CASE(TestWithinRange)
@@ -75,183 +69,182 @@
DTree<> testDTree(maxVals, minVals, 5);
- VecType test_query(3);
- test_query << 4.5 << 2.5 << 2;
+ arma::vec testQuery(3);
+ testQuery << 4.5 << 2.5 << 2;
- BOOST_REQUIRE_EQUAL(testDTree.WithinRange(test_query), true);
+ BOOST_REQUIRE_EQUAL(testDTree.WithinRange(testQuery), true);
- test_query << 8.5 << 2.5 << 2;
+ testQuery << 8.5 << 2.5 << 2;
- BOOST_REQUIRE_EQUAL(testDTree.WithinRange(test_query), false);
+ BOOST_REQUIRE_EQUAL(testDTree.WithinRange(testQuery), false);
}
BOOST_AUTO_TEST_CASE(TestFindSplit)
{
- MatType test_data(3,5);
+ arma::mat testData(3,5);
- test_data << 4 << 5 << 7 << 3 << 5 << arma::endr
- << 5 << 0 << 1 << 7 << 1 << arma::endr
- << 5 << 6 << 7 << 1 << 8 << arma::endr;
+ testData << 4 << 5 << 7 << 3 << 5 << arma::endr
+ << 5 << 0 << 1 << 7 << 1 << arma::endr
+ << 5 << 6 << 7 << 1 << 8 << arma::endr;
- DTree<> testDTree(test_data);
+ DTree<> testDTree(testData);
- size_t ob_dim, true_dim;
- double true_left_error, ob_left_error, true_right_error, ob_right_error,
- ob_split, true_split;
+ size_t obDim, trueDim;
+ double trueLeftError, obLeftError, trueRightError, obRightError,
+ obSplit, trueSplit;
- true_dim = 2;
- true_split = 5.5;
- true_left_error = -1.0 * exp(2 * log(2.0 / 5.0) - (log(7.0) + log(4.0) +
+ trueDim = 2;
+ trueSplit = 5.5;
+ trueLeftError = -1.0 * exp(2 * log(2.0 / 5.0) - (log(7.0) + log(4.0) +
log(4.5)));
- true_right_error = -1.0 * exp(2 * log(3.0 / 5.0) - (log(7.0) + log(4.0) +
+ trueRightError = -1.0 * exp(2 * log(3.0 / 5.0) - (log(7.0) + log(4.0) +
log(2.5)));
- BOOST_REQUIRE(testDTree.FindSplit(test_data, ob_dim, ob_split, ob_left_error,
- ob_right_error, 2, 1));
+ BOOST_REQUIRE(testDTree.FindSplit(testData, obDim, obSplit, obLeftError,
+ obRightError, 2, 1));
- BOOST_REQUIRE(true_dim == ob_dim);
- BOOST_REQUIRE_CLOSE(true_split, ob_split, 1e-10);
+ BOOST_REQUIRE(trueDim == obDim);
+ BOOST_REQUIRE_CLOSE(trueSplit, obSplit, 1e-10);
- BOOST_REQUIRE_CLOSE(true_left_error, ob_left_error, 1e-10);
- BOOST_REQUIRE_CLOSE(true_right_error, ob_right_error, 1e-10);
+ BOOST_REQUIRE_CLOSE(trueLeftError, obLeftError, 1e-10);
+ BOOST_REQUIRE_CLOSE(trueRightError, obRightError, 1e-10);
}
BOOST_AUTO_TEST_CASE(TestSplitData)
{
- MatType test_data(3, 5);
+ arma::mat testData(3, 5);
- test_data << 4 << 5 << 7 << 3 << 5 << arma::endr
- << 5 << 0 << 1 << 7 << 1 << arma::endr
- << 5 << 6 << 7 << 1 << 8 << arma::endr;
+ testData << 4 << 5 << 7 << 3 << 5 << arma::endr
+ << 5 << 0 << 1 << 7 << 1 << arma::endr
+ << 5 << 6 << 7 << 1 << 8 << arma::endr;
- DTree<> testDTree(test_data);
+ DTree<> testDTree(testData);
- arma::Col<size_t> o_test(5);
- o_test << 1 << 2 << 3 << 4 << 5;
+ arma::Col<size_t> oTest(5);
+ oTest << 1 << 2 << 3 << 4 << 5;
- size_t split_dim = 2;
- double true_split_val = 5.5;
+ size_t splitDim = 2;
+ double trueSplitVal = 5.5;
- size_t splitInd = testDTree.SplitData(test_data, split_dim, true_split_val,
- o_test);
+ size_t splitInd = testDTree.SplitData(testData, splitDim, trueSplitVal,
+ oTest);
BOOST_REQUIRE_EQUAL(splitInd, 2); // 2 points on left side.
- BOOST_REQUIRE_EQUAL(o_test[0], 1);
- BOOST_REQUIRE_EQUAL(o_test[1], 4);
- BOOST_REQUIRE_EQUAL(o_test[2], 3);
- BOOST_REQUIRE_EQUAL(o_test[3], 2);
- BOOST_REQUIRE_EQUAL(o_test[4], 5);
+ BOOST_REQUIRE_EQUAL(oTest[0], 1);
+ BOOST_REQUIRE_EQUAL(oTest[1], 4);
+ BOOST_REQUIRE_EQUAL(oTest[2], 3);
+ BOOST_REQUIRE_EQUAL(oTest[3], 2);
+ BOOST_REQUIRE_EQUAL(oTest[4], 5);
}
-// the public functions
+// Tests for the public functions.
BOOST_AUTO_TEST_CASE(TestGrow)
{
- MatType test_data(3, 5);
+ arma::mat testData(3, 5);
- test_data << 4 << 5 << 7 << 3 << 5 << arma::endr
- << 5 << 0 << 1 << 7 << 1 << arma::endr
- << 5 << 6 << 7 << 1 << 8 << arma::endr;
+ testData << 4 << 5 << 7 << 3 << 5 << arma::endr
+ << 5 << 0 << 1 << 7 << 1 << arma::endr
+ << 5 << 6 << 7 << 1 << 8 << arma::endr;
- arma::Col<size_t> o_test(5);
- o_test << 0 << 1 << 2 << 3 << 4;
+ arma::Col<size_t> oTest(5);
+ oTest << 0 << 1 << 2 << 3 << 4;
- long double root_error, l_error, r_error, rl_error, rr_error;
+ double rootError, lError, rError, rlError, rrError;
- root_error = -1.0 * exp(-log(4.0) - log(7.0) - log(7.0));
+ rootError = -1.0 * exp(-log(4.0) - log(7.0) - log(7.0));
- l_error = -1.0 * exp(2 * log(2.0 / 5.0) - (log(7.0) + log(4.0) + log(4.5)));
- r_error = -1.0 * exp(2 * log(3.0 / 5.0) - (log(7.0) + log(4.0) + log(2.5)));
+ lError = -1.0 * exp(2 * log(2.0 / 5.0) - (log(7.0) + log(4.0) + log(4.5)));
+ rError = -1.0 * exp(2 * log(3.0 / 5.0) - (log(7.0) + log(4.0) + log(2.5)));
- rl_error = -1.0 * exp(2 * log(1.0 / 5.0) - (log(0.5) + log(4.0) + log(2.5)));
- rr_error = -1.0 * exp(2 * log(2.0 / 5.0) - (log(6.5) + log(4.0) + log(2.5)));
+ rlError = -1.0 * exp(2 * log(1.0 / 5.0) - (log(0.5) + log(4.0) + log(2.5)));
+ rrError = -1.0 * exp(2 * log(2.0 / 5.0) - (log(6.5) + log(4.0) + log(2.5)));
- DTree<> testDTree(test_data);
- double alpha = testDTree.Grow(test_data, o_test, false, 2, 1);
+ DTree<> testDTree(testData);
+ double alpha = testDTree.Grow(testData, oTest, false, 2, 1);
- BOOST_REQUIRE_EQUAL(o_test[0], 0);
- BOOST_REQUIRE_EQUAL(o_test[1], 3);
- BOOST_REQUIRE_EQUAL(o_test[2], 1);
- BOOST_REQUIRE_EQUAL(o_test[3], 2);
- BOOST_REQUIRE_EQUAL(o_test[4], 4);
+ BOOST_REQUIRE_EQUAL(oTest[0], 0);
+ BOOST_REQUIRE_EQUAL(oTest[1], 3);
+ BOOST_REQUIRE_EQUAL(oTest[2], 1);
+ BOOST_REQUIRE_EQUAL(oTest[3], 2);
+ BOOST_REQUIRE_EQUAL(oTest[4], 4);
- // test the structure of the tree
- BOOST_REQUIRE(testDTree.left()->left() == NULL);
- BOOST_REQUIRE(testDTree.left()->right() == NULL);
- BOOST_REQUIRE(testDTree.right()->left()->left() == NULL);
- BOOST_REQUIRE(testDTree.right()->left()->right() == NULL);
- BOOST_REQUIRE(testDTree.right()->right()->left() == NULL);
- BOOST_REQUIRE(testDTree.right()->right()->right() == NULL);
+ // Test the structure of the tree.
+ BOOST_REQUIRE(testDTree.Left()->Left() == NULL);
+ BOOST_REQUIRE(testDTree.Left()->Right() == NULL);
+ BOOST_REQUIRE(testDTree.Right()->Left()->Left() == NULL);
+ BOOST_REQUIRE(testDTree.Right()->Left()->Right() == NULL);
+ BOOST_REQUIRE(testDTree.Right()->Right()->Left() == NULL);
+ BOOST_REQUIRE(testDTree.Right()->Right()->Right() == NULL);
- BOOST_REQUIRE(testDTree.subtree_leaves() == 3);
+ BOOST_REQUIRE(testDTree.SubtreeLeaves() == 3);
- BOOST_REQUIRE(testDTree.split_dim() == 2);
- BOOST_REQUIRE_CLOSE(testDTree.split_value(), 5.5, 1e-5);
- BOOST_REQUIRE(testDTree.right()->split_dim() == 1);
- BOOST_REQUIRE_CLOSE(testDTree.right()->split_value(), 0.5, 1e-5);
+ BOOST_REQUIRE(testDTree.SplitDim() == 2);
+ BOOST_REQUIRE_CLOSE(testDTree.SplitValue(), 5.5, 1e-5);
+ BOOST_REQUIRE(testDTree.Right()->SplitDim() == 1);
+ BOOST_REQUIRE_CLOSE(testDTree.Right()->SplitValue(), 0.5, 1e-5);
- // test node errors for every node
- BOOST_REQUIRE_CLOSE(testDTree.error_, root_error, 1e-10);
- BOOST_REQUIRE_CLOSE(testDTree.left()->error_, l_error, 1e-10);
- BOOST_REQUIRE_CLOSE(testDTree.right()->error_, r_error, 1e-10);
- BOOST_REQUIRE_CLOSE(testDTree.right()->left()->error_, rl_error, 1e-10);
- BOOST_REQUIRE_CLOSE(testDTree.right()->right()->error_, rr_error, 1e-10);
+ // Test node errors for every node.
+ BOOST_REQUIRE_CLOSE(testDTree.error, rootError, 1e-10);
+ BOOST_REQUIRE_CLOSE(testDTree.Left()->error, lError, 1e-10);
+ BOOST_REQUIRE_CLOSE(testDTree.Right()->error, rError, 1e-10);
+ BOOST_REQUIRE_CLOSE(testDTree.Right()->Left()->error, rlError, 1e-10);
+ BOOST_REQUIRE_CLOSE(testDTree.Right()->Right()->error, rrError, 1e-10);
+ // Test alpha.
+ double rootAlpha, rAlpha;
+ rootAlpha = (rootError - (lError + rlError + rrError)) / 2;
+ rAlpha = rError - (rlError + rrError);
- // test alpha
- double root_alpha, r_alpha;
- root_alpha = (root_error - (l_error + rl_error + rr_error)) / 2;
- r_alpha = r_error - (rl_error + rr_error);
-
- BOOST_REQUIRE_CLOSE(alpha, min(root_alpha, r_alpha), 1e-10);
+ BOOST_REQUIRE_CLOSE(alpha, min(rootAlpha, rAlpha), 1e-10);
}
BOOST_AUTO_TEST_CASE(TestPruneAndUpdate)
{
- MatType test_data(3, 5);
+ arma::mat testData(3, 5);
- test_data << 4 << 5 << 7 << 3 << 5 << arma::endr
- << 5 << 0 << 1 << 7 << 1 << arma::endr
- << 5 << 6 << 7 << 1 << 8 << arma::endr;
+ testData << 4 << 5 << 7 << 3 << 5 << arma::endr
+ << 5 << 0 << 1 << 7 << 1 << arma::endr
+ << 5 << 6 << 7 << 1 << 8 << arma::endr;
- arma::Col<size_t> o_test(5);
- o_test << 0 << 1 << 2 << 3 << 4;
- DTree<> testDTree(test_data);
- double alpha = testDTree.Grow(test_data, o_test, false, 2, 1);
+ arma::Col<size_t> oTest(5);
+ oTest << 0 << 1 << 2 << 3 << 4;
+ DTree<> testDTree(testData);
+ double alpha = testDTree.Grow(testData, oTest, false, 2, 1);
alpha = testDTree.PruneAndUpdate(alpha, false);
BOOST_REQUIRE_CLOSE(alpha, numeric_limits<double>::max(), 1e-10);
- BOOST_REQUIRE(testDTree.subtree_leaves() == 1);
+ BOOST_REQUIRE(testDTree.SubtreeLeaves() == 1);
- long double root_error = -1.0 * exp(-log(4.0) - log(7.0) - log(7.0));
+ double rootError = -1.0 * exp(-log(4.0) - log(7.0) - log(7.0));
- BOOST_REQUIRE_CLOSE(testDTree.error(), root_error, 1e-10);
- BOOST_REQUIRE_CLOSE(testDTree.subtree_leaves_error(), root_error, 1e-10);
- BOOST_REQUIRE(testDTree.left() == NULL);
- BOOST_REQUIRE(testDTree.right() == NULL);
+ BOOST_REQUIRE_CLOSE(testDTree.Error(), rootError, 1e-10);
+ BOOST_REQUIRE_CLOSE(testDTree.SubtreeLeavesError(), rootError, 1e-10);
+ BOOST_REQUIRE(testDTree.Left() == NULL);
+ BOOST_REQUIRE(testDTree.Right() == NULL);
}
BOOST_AUTO_TEST_CASE(TestComputeValue)
{
- MatType test_data(3, 5);
+ arma::mat testData(3, 5);
- test_data << 4 << 5 << 7 << 3 << 5 << arma::endr
- << 5 << 0 << 1 << 7 << 1 << arma::endr
- << 5 << 6 << 7 << 1 << 8 << arma::endr;
+ testData << 4 << 5 << 7 << 3 << 5 << arma::endr
+ << 5 << 0 << 1 << 7 << 1 << arma::endr
+ << 5 << 6 << 7 << 1 << 8 << arma::endr;
- VecType q1(3), q2(3), q3(3), q4(3);
+ arma::vec q1(3), q2(3), q3(3), q4(3);
q1 << 4 << 2 << 2;
q2 << 5 << 0.25 << 6;
q3 << 5 << 3 << 7;
q4 << 2 << 3 << 3;
- arma::Col<size_t> o_test(5);
- o_test << 0 << 1 << 2 << 3 << 4;
+ arma::Col<size_t> oTest(5);
+ oTest << 0 << 1 << 2 << 3 << 4;
- DTree<> testDTree(test_data);
- double alpha = testDTree.Grow(test_data, o_test, false, 2, 1);
+ DTree<> testDTree(testData);
+ double alpha = testDTree.Grow(testData, oTest, false, 2, 1);
double d1 = (2.0 / 5.0) / exp(log(4.0) + log(7.0) + log(4.5));
double d2 = (1.0 / 5.0) / exp(log(4.0) + log(0.5) + log(2.5));
@@ -274,37 +267,35 @@
BOOST_AUTO_TEST_CASE(TestVariableImportance)
{
- MatType test_data(3, 5);
+ arma::mat testData(3, 5);
- test_data << 4 << 5 << 7 << 3 << 5 << arma::endr
- << 5 << 0 << 1 << 7 << 1 << arma::endr
- << 5 << 6 << 7 << 1 << 8 << arma::endr;
+ testData << 4 << 5 << 7 << 3 << 5 << arma::endr
+ << 5 << 0 << 1 << 7 << 1 << arma::endr
+ << 5 << 6 << 7 << 1 << 8 << arma::endr;
- long double root_error, l_error, r_error, rl_error, rr_error;
+ double rootError, lError, rError, rlError, rrError;
- root_error = -1.0 * exp(-log(4.0) - log(7.0) - log(7.0));
+ rootError = -1.0 * exp(-log(4.0) - log(7.0) - log(7.0));
- l_error = -1.0 * exp(2 * log(2.0 / 5.0) - (log(7.0) + log(4.0) + log(4.5)));
- r_error = -1.0 * exp(2 * log(3.0 / 5.0) - (log(7.0) + log(4.0) + log(2.5)));
+ lError = -1.0 * exp(2 * log(2.0 / 5.0) - (log(7.0) + log(4.0) + log(4.5)));
+ rError = -1.0 * exp(2 * log(3.0 / 5.0) - (log(7.0) + log(4.0) + log(2.5)));
- rl_error = -1.0 * exp(2 * log(1.0 / 5.0) - (log(0.5) + log(4.0) + log(2.5)));
- rr_error = -1.0 * exp(2 * log(2.0 / 5.0) - (log(6.5) + log(4.0) + log(2.5)));
+ rlError = -1.0 * exp(2 * log(1.0 / 5.0) - (log(0.5) + log(4.0) + log(2.5)));
+ rrError = -1.0 * exp(2 * log(2.0 / 5.0) - (log(6.5) + log(4.0) + log(2.5)));
- arma::Col<size_t> o_test(5);
- o_test << 0 << 1 << 2 << 3 << 4;
+ arma::Col<size_t> oTest(5);
+ oTest << 0 << 1 << 2 << 3 << 4;
- DTree<> testDTree(test_data);
- testDTree.Grow(test_data, o_test, false, 2, 1);
+ DTree<> testDTree(testData);
+ testDTree.Grow(testData, oTest, false, 2, 1);
arma::vec imps;
testDTree.ComputeVariableImportance(imps);
BOOST_REQUIRE_CLOSE((double) 0.0, imps[0], 1e-10);
- BOOST_REQUIRE_CLOSE((double) (r_error - (rl_error + rr_error)), imps[1],
- 1e-10);
- BOOST_REQUIRE_CLOSE((double) (root_error - (l_error + r_error)), imps[2],
- 1e-10);
+ BOOST_REQUIRE_CLOSE((double) (rError - (rlError + rrError)), imps[1], 1e-10);
+ BOOST_REQUIRE_CLOSE((double) (rootError - (lError + rError)), imps[2], 1e-10);
}
/**
@@ -312,26 +303,26 @@
*
BOOST_AUTO_TEST_CASE(TestTagTree)
{
- MatType test_data(3, 5);
+ MatType testData(3, 5);
- test_data << 4 << 5 << 7 << 3 << 5 << arma::endr
+ testData << 4 << 5 << 7 << 3 << 5 << arma::endr
<< 5 << 0 << 1 << 7 << 1 << arma::endr
<< 5 << 6 << 7 << 1 << 8 << arma::endr;
- DTree<>* testDTree = new DTree<>(&test_data);
+ DTree<>* testDTree = new DTree<>(&testData);
delete testDTree;
}
BOOST_AUTO_TEST_CASE(TestFindBucket)
{
- MatType test_data(3, 5);
+ MatType testData(3, 5);
- test_data << 4 << 5 << 7 << 3 << 5 << arma::endr
+ testData << 4 << 5 << 7 << 3 << 5 << arma::endr
<< 5 << 0 << 1 << 7 << 1 << arma::endr
<< 5 << 6 << 7 << 1 << 8 << arma::endr;
- DTree<>* testDTree = new DTree<>(&test_data);
+ DTree<>* testDTree = new DTree<>(&testData);
delete testDTree;
}
More information about the mlpack-svn mailing list