[mlpack-svn] r13308 - mlpack/trunk/src/mlpack/tests

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Aug 1 16:20:06 EDT 2012


Author: rcurtin
Date: 2012-08-01 16:20:05 -0400 (Wed, 01 Aug 2012)
New Revision: 13308

Modified:
   mlpack/trunk/src/mlpack/tests/det_test.cpp
Log:
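Adapt the tests to work in log-space, and update them for the changed DTree
API (e.g. error -> logNegError, Error() -> LogNegError(), and PruneAndUpdate()
now also takes the number of points).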


Modified: mlpack/trunk/src/mlpack/tests/det_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/det_test.cpp	2012-08-01 20:19:43 UTC (rev 13307)
+++ mlpack/trunk/src/mlpack/tests/det_test.cpp	2012-08-01 20:20:05 UTC (rev 13308)
@@ -49,16 +49,15 @@
   arma::vec minVals("3 0 1");
 
   DTree<> testDTree(maxVals, minVals, 5);
-  double trueNodeError = -1.0 * exp(-log(4.0) - log(7.0) - log(7.0));
+  double trueNodeError = -log(4.0) - log(7.0) - log(7.0);
 
-  BOOST_REQUIRE_CLOSE((double) testDTree.error, trueNodeError, 1e-10);
+  BOOST_REQUIRE_CLOSE((double) testDTree.logNegError, trueNodeError, 1e-10);
 
   testDTree.start = 3;
   testDTree.end = 5;
 
-  double nodeError = -std::exp(testDTree.LogNegativeError(5));
-  trueNodeError = -1.0 * exp(2 * log(2.0 / 5.0) - log(4.0) - log(7.0) -
-      log(7.0));
+  double nodeError = testDTree.LogNegativeError(5);
+  trueNodeError = 2 * log(2.0 / 5.0) - log(4.0) - log(7.0) - log(7.0);
   BOOST_REQUIRE_CLOSE(nodeError, trueNodeError, 1e-10);
 }
 
@@ -95,11 +94,10 @@
 
   trueDim = 2;
   trueSplit = 5.5;
-  trueLeftError = -1.0 * exp(2 * log(2.0 / 5.0) - (log(7.0) + log(4.0) +
-      log(4.5)));
-  trueRightError = -1.0 * exp(2 * log(3.0 / 5.0) - (log(7.0) + log(4.0) +
-      log(2.5)));
+  trueLeftError = 2 * log(2.0 / 5.0) - (log(7.0) + log(4.0) + log(4.5));
+  trueRightError = 2 * log(3.0 / 5.0) - (log(7.0) + log(4.0) + log(2.5));
 
+  testDTree.logVolume = log(7.0) + log(4.0) + log(7.0);
   BOOST_REQUIRE(testDTree.FindSplit(testData, obDim, obSplit, obLeftError,
       obRightError, 2, 1));
 
@@ -153,13 +151,13 @@
 
   double rootError, lError, rError, rlError, rrError;
 
-  rootError = -1.0 * exp(-log(4.0) - log(7.0) - log(7.0));
+  rootError = -log(4.0) - log(7.0) - log(7.0);
 
-  lError = -1.0 * exp(2 * log(2.0 / 5.0) - (log(7.0) + log(4.0) + log(4.5)));
-  rError =  -1.0 * exp(2 * log(3.0 / 5.0) - (log(7.0) + log(4.0) + log(2.5)));
+  lError = 2 * log(2.0 / 5.0) - (log(7.0) + log(4.0) + log(4.5));
+  rError =  2 * log(3.0 / 5.0) - (log(7.0) + log(4.0) + log(2.5));
 
-  rlError = -1.0 * exp(2 * log(1.0 / 5.0) - (log(0.5) + log(4.0) + log(2.5)));
-  rrError = -1.0 * exp(2 * log(2.0 / 5.0) - (log(6.5) + log(4.0) + log(2.5)));
+  rlError = 2 * log(1.0 / 5.0) - (log(0.5) + log(4.0) + log(2.5));
+  rrError = 2 * log(2.0 / 5.0) - (log(6.5) + log(4.0) + log(2.5));
 
   DTree<> testDTree(testData);
   double alpha = testDTree.Grow(testData, oTest, false, 2, 1);
@@ -186,16 +184,18 @@
   BOOST_REQUIRE_CLOSE(testDTree.Right()->SplitValue(), 0.5, 1e-5);
 
   // Test node errors for every node.
-  BOOST_REQUIRE_CLOSE(testDTree.error, rootError, 1e-10);
-  BOOST_REQUIRE_CLOSE(testDTree.Left()->error, lError, 1e-10);
-  BOOST_REQUIRE_CLOSE(testDTree.Right()->error, rError, 1e-10);
-  BOOST_REQUIRE_CLOSE(testDTree.Right()->Left()->error, rlError, 1e-10);
-  BOOST_REQUIRE_CLOSE(testDTree.Right()->Right()->error, rrError, 1e-10);
+  BOOST_REQUIRE_CLOSE(testDTree.logNegError, rootError, 1e-10);
+  BOOST_REQUIRE_CLOSE(testDTree.Left()->logNegError, lError, 1e-10);
+  BOOST_REQUIRE_CLOSE(testDTree.Right()->logNegError, rError, 1e-10);
+  BOOST_REQUIRE_CLOSE(testDTree.Right()->Left()->logNegError, rlError, 1e-10);
+  BOOST_REQUIRE_CLOSE(testDTree.Right()->Right()->logNegError, rrError, 1e-10);
 
   // Test alpha.
   double rootAlpha, rAlpha;
-  rootAlpha = (rootError - (lError + rlError + rrError)) / 2;
-  rAlpha = rError - (rlError + rrError);
+  rootAlpha = std::log(-((std::exp(rootError) - (std::exp(lError) +
+      std::exp(rlError) + std::exp(rrError))) / 2));
+  rAlpha = std::log(-(std::exp(rError) - (std::exp(rlError) +
+      std::exp(rrError))));
 
   BOOST_REQUIRE_CLOSE(alpha, min(rootAlpha, rAlpha), 1e-10);
 }
@@ -212,15 +212,15 @@
   oTest << 0 << 1 << 2 << 3 << 4;
   DTree<> testDTree(testData);
   double alpha = testDTree.Grow(testData, oTest, false, 2, 1);
-  alpha = testDTree.PruneAndUpdate(alpha, false);
+  alpha = testDTree.PruneAndUpdate(alpha, testData.n_cols, false);
 
   BOOST_REQUIRE_CLOSE(alpha, numeric_limits<double>::max(), 1e-10);
   BOOST_REQUIRE(testDTree.SubtreeLeaves() == 1);
 
-  double rootError = -1.0 * exp(-log(4.0) - log(7.0) - log(7.0));
+  double rootError = -log(4.0) - log(7.0) - log(7.0);
 
-  BOOST_REQUIRE_CLOSE(testDTree.Error(), rootError, 1e-10);
-  BOOST_REQUIRE_CLOSE(testDTree.SubtreeLeavesError(), rootError, 1e-10);
+  BOOST_REQUIRE_CLOSE(testDTree.LogNegError(), rootError, 1e-10);
+  BOOST_REQUIRE_CLOSE(testDTree.SubtreeLeavesLogNegError(), rootError, 1e-10);
   BOOST_REQUIRE(testDTree.Left() == NULL);
   BOOST_REQUIRE(testDTree.Right() == NULL);
 }
@@ -255,7 +255,7 @@
   BOOST_REQUIRE_CLOSE(d3, testDTree.ComputeValue(q3), 1e-10);
   BOOST_REQUIRE_CLOSE(0.0, testDTree.ComputeValue(q4), 1e-10);
 
-  alpha = testDTree.PruneAndUpdate(alpha, false);
+  alpha = testDTree.PruneAndUpdate(alpha, testData.n_cols, false);
 
   double d = 1.0 / exp(log(4.0) + log(7.0) + log(7.0));
 




More information about the mlpack-svn mailing list