[mlpack-svn] r13308 - mlpack/trunk/src/mlpack/tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Aug 1 16:20:06 EDT 2012
Author: rcurtin
Date: 2012-08-01 16:20:05 -0400 (Wed, 01 Aug 2012)
New Revision: 13308
Modified:
mlpack/trunk/src/mlpack/tests/det_test.cpp
Log:
Adapt the tests to work in log-space and update API usage accordingly (the APIs
have changed).
Modified: mlpack/trunk/src/mlpack/tests/det_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/det_test.cpp 2012-08-01 20:19:43 UTC (rev 13307)
+++ mlpack/trunk/src/mlpack/tests/det_test.cpp 2012-08-01 20:20:05 UTC (rev 13308)
@@ -49,16 +49,15 @@
arma::vec minVals("3 0 1");
DTree<> testDTree(maxVals, minVals, 5);
- double trueNodeError = -1.0 * exp(-log(4.0) - log(7.0) - log(7.0));
+ double trueNodeError = -log(4.0) - log(7.0) - log(7.0);
- BOOST_REQUIRE_CLOSE((double) testDTree.error, trueNodeError, 1e-10);
+ BOOST_REQUIRE_CLOSE((double) testDTree.logNegError, trueNodeError, 1e-10);
testDTree.start = 3;
testDTree.end = 5;
- double nodeError = -std::exp(testDTree.LogNegativeError(5));
- trueNodeError = -1.0 * exp(2 * log(2.0 / 5.0) - log(4.0) - log(7.0) -
- log(7.0));
+ double nodeError = testDTree.LogNegativeError(5);
+ trueNodeError = 2 * log(2.0 / 5.0) - log(4.0) - log(7.0) - log(7.0);
BOOST_REQUIRE_CLOSE(nodeError, trueNodeError, 1e-10);
}
@@ -95,11 +94,10 @@
trueDim = 2;
trueSplit = 5.5;
- trueLeftError = -1.0 * exp(2 * log(2.0 / 5.0) - (log(7.0) + log(4.0) +
- log(4.5)));
- trueRightError = -1.0 * exp(2 * log(3.0 / 5.0) - (log(7.0) + log(4.0) +
- log(2.5)));
+ trueLeftError = 2 * log(2.0 / 5.0) - (log(7.0) + log(4.0) + log(4.5));
+ trueRightError = 2 * log(3.0 / 5.0) - (log(7.0) + log(4.0) + log(2.5));
+ testDTree.logVolume = log(7.0) + log(4.0) + log(7.0);
BOOST_REQUIRE(testDTree.FindSplit(testData, obDim, obSplit, obLeftError,
obRightError, 2, 1));
@@ -153,13 +151,13 @@
double rootError, lError, rError, rlError, rrError;
- rootError = -1.0 * exp(-log(4.0) - log(7.0) - log(7.0));
+ rootError = -log(4.0) - log(7.0) - log(7.0);
- lError = -1.0 * exp(2 * log(2.0 / 5.0) - (log(7.0) + log(4.0) + log(4.5)));
- rError = -1.0 * exp(2 * log(3.0 / 5.0) - (log(7.0) + log(4.0) + log(2.5)));
+ lError = 2 * log(2.0 / 5.0) - (log(7.0) + log(4.0) + log(4.5));
+ rError = 2 * log(3.0 / 5.0) - (log(7.0) + log(4.0) + log(2.5));
- rlError = -1.0 * exp(2 * log(1.0 / 5.0) - (log(0.5) + log(4.0) + log(2.5)));
- rrError = -1.0 * exp(2 * log(2.0 / 5.0) - (log(6.5) + log(4.0) + log(2.5)));
+ rlError = 2 * log(1.0 / 5.0) - (log(0.5) + log(4.0) + log(2.5));
+ rrError = 2 * log(2.0 / 5.0) - (log(6.5) + log(4.0) + log(2.5));
DTree<> testDTree(testData);
double alpha = testDTree.Grow(testData, oTest, false, 2, 1);
@@ -186,16 +184,18 @@
BOOST_REQUIRE_CLOSE(testDTree.Right()->SplitValue(), 0.5, 1e-5);
// Test node errors for every node.
- BOOST_REQUIRE_CLOSE(testDTree.error, rootError, 1e-10);
- BOOST_REQUIRE_CLOSE(testDTree.Left()->error, lError, 1e-10);
- BOOST_REQUIRE_CLOSE(testDTree.Right()->error, rError, 1e-10);
- BOOST_REQUIRE_CLOSE(testDTree.Right()->Left()->error, rlError, 1e-10);
- BOOST_REQUIRE_CLOSE(testDTree.Right()->Right()->error, rrError, 1e-10);
+ BOOST_REQUIRE_CLOSE(testDTree.logNegError, rootError, 1e-10);
+ BOOST_REQUIRE_CLOSE(testDTree.Left()->logNegError, lError, 1e-10);
+ BOOST_REQUIRE_CLOSE(testDTree.Right()->logNegError, rError, 1e-10);
+ BOOST_REQUIRE_CLOSE(testDTree.Right()->Left()->logNegError, rlError, 1e-10);
+ BOOST_REQUIRE_CLOSE(testDTree.Right()->Right()->logNegError, rrError, 1e-10);
// Test alpha.
double rootAlpha, rAlpha;
- rootAlpha = (rootError - (lError + rlError + rrError)) / 2;
- rAlpha = rError - (rlError + rrError);
+ rootAlpha = std::log(-((std::exp(rootError) - (std::exp(lError) +
+ std::exp(rlError) + std::exp(rrError))) / 2));
+ rAlpha = std::log(-(std::exp(rError) - (std::exp(rlError) +
+ std::exp(rrError))));
BOOST_REQUIRE_CLOSE(alpha, min(rootAlpha, rAlpha), 1e-10);
}
@@ -212,15 +212,15 @@
oTest << 0 << 1 << 2 << 3 << 4;
DTree<> testDTree(testData);
double alpha = testDTree.Grow(testData, oTest, false, 2, 1);
- alpha = testDTree.PruneAndUpdate(alpha, false);
+ alpha = testDTree.PruneAndUpdate(alpha, testData.n_cols, false);
BOOST_REQUIRE_CLOSE(alpha, numeric_limits<double>::max(), 1e-10);
BOOST_REQUIRE(testDTree.SubtreeLeaves() == 1);
- double rootError = -1.0 * exp(-log(4.0) - log(7.0) - log(7.0));
+ double rootError = -log(4.0) - log(7.0) - log(7.0);
- BOOST_REQUIRE_CLOSE(testDTree.Error(), rootError, 1e-10);
- BOOST_REQUIRE_CLOSE(testDTree.SubtreeLeavesError(), rootError, 1e-10);
+ BOOST_REQUIRE_CLOSE(testDTree.LogNegError(), rootError, 1e-10);
+ BOOST_REQUIRE_CLOSE(testDTree.SubtreeLeavesLogNegError(), rootError, 1e-10);
BOOST_REQUIRE(testDTree.Left() == NULL);
BOOST_REQUIRE(testDTree.Right() == NULL);
}
@@ -255,7 +255,7 @@
BOOST_REQUIRE_CLOSE(d3, testDTree.ComputeValue(q3), 1e-10);
BOOST_REQUIRE_CLOSE(0.0, testDTree.ComputeValue(q4), 1e-10);
- alpha = testDTree.PruneAndUpdate(alpha, false);
+ alpha = testDTree.PruneAndUpdate(alpha, testData.n_cols, false);
double d = 1.0 / exp(log(4.0) + log(7.0) + log(7.0));
More information about the mlpack-svn
mailing list