[mlpack-svn] r13212 - mlpack/trunk/src/mlpack/tests

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Jul 11 18:34:10 EDT 2012


Author: rcurtin
Date: 2012-07-11 18:34:09 -0400 (Wed, 11 Jul 2012)
New Revision: 13212

Modified:
   mlpack/trunk/src/mlpack/tests/det_test.cpp
Log:
Fix many formatting issues in the DETTest tests.  Stop casting to and from float/long
double, to preserve the accuracy of the calculations (which are now done in double precision).


Modified: mlpack/trunk/src/mlpack/tests/det_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/det_test.cpp	2012-07-11 22:33:31 UTC (rev 13211)
+++ mlpack/trunk/src/mlpack/tests/det_test.cpp	2012-07-11 22:34:09 UTC (rev 13212)
@@ -25,8 +25,8 @@
 
 // Testing functions of the DTree class
 
-typedef arma::Mat<float> MatType;
-typedef arma::Col<float> VecType;
+typedef arma::mat MatType;
+typedef arma::vec VecType;
 
 
 // the private functions
@@ -65,20 +65,20 @@
   *min_vals << 3 << 0 << 1;
 
   DTree<>* testDTree = new DTree<>(max_vals, min_vals, 5);
-  long double true_node_error = -1.0 * exp(-(long double) log((float) 4.0)
-					   - (long double) log((float) 7.0)
-					   - (long double) log((float) 7.0));
+  double true_node_error = -1.0 * exp(-(double) log((double) 4.0)
+					   - (double) log((double) 7.0)
+					   - (double) log((double) 7.0));
 
   BOOST_REQUIRE_CLOSE(testDTree->error_, true_node_error, 1e-10);
 
   testDTree->start_ = 3;
   testDTree->end_ = 5;
 
-  long double node_error = -std::exp(testDTree->LogNegativeError(5));
-  true_node_error = -1.0 * exp(2 * log((long double) 2 / (long double) 5)
-			       -(long double) log((float) 4.0)
-			       - (long double) log((float) 7.0)
-			       - (long double) log((float) 7.0));
+  double node_error = -std::exp(testDTree->LogNegativeError(5));
+  true_node_error = -1.0 * exp(2 * log((double) 2 / (double) 5)
+			       -(double) log((double) 4.0)
+			       - (double) log((double) 7.0)
+			       - (double) log((double) 7.0));
   BOOST_REQUIRE_CLOSE(node_error, true_node_error, 1e-10);
 
   delete testDTree;
@@ -117,24 +117,17 @@
   DTree<>* testDTree = new DTree<>(&test_data);
 
   size_t ob_dim, true_dim, ob_ind, true_ind;
-  long double true_left_error, ob_left_error,
-    true_right_error, ob_right_error;
+  long double true_left_error, ob_left_error, true_right_error, ob_right_error;
 
   true_dim = 2;
   true_ind = 1;
-  true_left_error = -1.0 * exp(2 * log((long double) 2
-				       / (long double) 5)
-			       - ((long double) log((float) 7)
-				  + (long double) log((float) 4)
-				  + (long double) log((float) 4.5)));
-  true_right_error =  -1.0 * exp(2 * log((long double) 3
-					 / (long double) 5)
-				 - ((long double) log((float) 7)
-				    + (long double) log((float) 4)
-				    + (long double) log((float) 2.5)));
+  true_left_error = -1.0 * exp(2 * log(2.0 / 5.0) - (log(7.0) + log(4.0) +
+      log(4.5)));
+  true_right_error = -1.0 * exp(2 * log(3.0 / 5.0) - (log(7.0) + log(4.0) +
+      log(2.5)));
 
   BOOST_REQUIRE(testDTree->FindSplit_
-		(&test_data, &ob_dim, &ob_ind, &ob_left_error,
+		(test_data, &ob_dim, &ob_ind, &ob_left_error,
 		 &ob_right_error, 2, 1));
 
   BOOST_REQUIRE(true_dim == ob_dim);
@@ -160,7 +153,7 @@
   o_test << 1 << 2 << 3 << 4 << 5;
 
   size_t split_dim = 2, split_ind = 1;
-  float true_split_val, ob_split_val, true_lsplit_val, ob_lsplit_val,
+  double true_split_val, ob_split_val, true_lsplit_val, ob_lsplit_val,
     true_rsplit_val, ob_rsplit_val;
 
   true_lsplit_val = 5;
@@ -195,38 +188,18 @@
   arma::Col<size_t> o_test(5);
   o_test << 0 << 1 << 2 << 3 << 4;
 
-  long double root_error, l_error, r_error, rl_error, rr_error;
+  double root_error, l_error, r_error, rl_error, rr_error;
 
-  root_error = -1.0 * exp(-(long double) log((float) 4.0)
-			  - (long double) log((float) 7.0)
-			  - (long double) log((float) 7.0));
+  root_error = -1.0 * exp(-log(4.0) - log(7.0) - log(7.0));
 
-  l_error = -1.0 * exp(2 * log((long double) 2
-			       / (long double) 5)
-		       - ((long double) log((float) 7)
-			  + (long double) log((float) 4)
-			  + (long double) log((float) 4.5)));
-  r_error =  -1.0 * exp(2 * log((long double) 3
-				/ (long double) 5)
-			- ((long double) log((float) 7)
-			   + (long double) log((float) 4)
-			   + (long double) log((float) 2.5)));
+  l_error = -1.0 * exp(2 * log(2.0 / 5.0) - (log(7.0) + log(4.0) + log(4.5)));
+  r_error =  -1.0 * exp(2 * log(3.0 / 5.0) - (log(7.0) + log(4.0) + log(2.5)));
 
-  rl_error = -1.0 * exp(2 * log((long double) 1
-				/ (long double) 5)
-			- ((long double) log((float) 0.5)
-			   + (long double) log((float) 4)
-			   + (long double) log((float) 2.5)));
+  rl_error = -1.0 * exp(2 * log(1.0 / 5.0) - (log(0.5) + log(4.0) + log(2.5)));
+  rr_error = -1.0 * exp(2 * log(2.0 / 5.0) - (log(6.5) + log(4.0) + log(2.5)));
 
-  rr_error = -1.0 * exp(2 * log((long double) 2
-				/ (long double) 5)
-			- ((long double) log((float) 6.5)
-			   + (long double) log((float) 4)
-			   + (long double) log((float) 2.5)));
-
   DTree<>* testDTree = new DTree<>(&test_data);
-  long double alpha = testDTree->Grow(&test_data, &o_test,
-				      false, 2, 1);
+  long double alpha = testDTree->Grow(&test_data, &o_test, false, 2, 1);
 
   BOOST_REQUIRE(o_test[0] == 0 && o_test[1] == 3
 		&& o_test[2] == 1 && o_test[3] == 2
@@ -284,9 +257,7 @@
   BOOST_REQUIRE_CLOSE(alpha, numeric_limits<long double>::max(), 1e-10);
   BOOST_REQUIRE(testDTree->subtree_leaves() == 1);
 
-  long double root_error = -1.0 * exp(-(long double) log((float) 4.0)
-				      - (long double) log((float) 7.0)
-				      - (long double) log((float) 7.0));
+  long double root_error = -1.0 * exp(-log(4.0) - log(7.0) - log(7.0));
 
   BOOST_REQUIRE_CLOSE(testDTree->error(), root_error, 1e-10);
   BOOST_REQUIRE_CLOSE(testDTree->subtree_leaves_error(), root_error, 1e-10);
@@ -318,19 +289,11 @@
   long double alpha = testDTree->Grow(&test_data, &o_test,
 				      false, 2, 1);
 
-  long double d1, d2, d3;
-  d1 = ((long double) 2 / (long double) 5)
-    / exp((long double) log((float) 4) + (long double) log((float) 7)
-	  + (long double) log((float) 4.5));
+  double d1, d2, d3;
+  d1 = (2.0 / 5.0) / exp(log(4.0) + log(7.0) + log(4.5));
+  d2 = (1.0 / 5.0) / exp(log(4.0) + log(0.5) + log(2.5));
+  d3 = (2.0 / 5.0) / exp(log(4.0) + log(6.5) + log(2.5));
 
-  d2 = ((long double) 1 / (long double) 5)
-    / exp((long double) log((float) 4) + (long double) log((float) 0.5)
-	  + (long double) log((float) 2.5));
-
-  d3 = ((long double) 2 / (long double) 5)
-    / exp((long double) log((float) 4) + (long double) log((float) 6.5)
-	  + (long double) log((float) 2.5));
-
   BOOST_REQUIRE_CLOSE(d1, testDTree->ComputeValue(&q1), 1e-10);
   BOOST_REQUIRE_CLOSE(d2, testDTree->ComputeValue(&q2), 1e-10);
   BOOST_REQUIRE_CLOSE(d3, testDTree->ComputeValue(&q3), 1e-10);
@@ -338,9 +301,7 @@
 
   alpha = testDTree->PruneAndUpdate(alpha, false);
 
-  long double d = 1.0
-    / exp((long double) log((float) 4) + (long double) log((float) 7)
-	  + (long double) log((float) 7));
+  long double d = 1.0 / exp(log(4.0) + log(7.0) + log(7.0));
 
   BOOST_REQUIRE_CLOSE(d, testDTree->ComputeValue(&q1), 1e-10);
   BOOST_REQUIRE_CLOSE(d, testDTree->ComputeValue(&q2), 1e-10);
@@ -360,33 +321,14 @@
 
   long double root_error, l_error, r_error, rl_error, rr_error;
 
-  root_error = -1.0 * exp(-(long double) log((float) 4.0)
-			  - (long double) log((float) 7.0)
-			  - (long double) log((float) 7.0));
+  root_error = -1.0 * exp(-log(4.0) - log(7.0) - log(7.0));
 
-  l_error = -1.0 * exp(2 * log((long double) 2
-			       / (long double) 5)
-		       - ((long double) log((float) 7)
-			  + (long double) log((float) 4)
-			  + (long double) log((float) 4.5)));
-  r_error =  -1.0 * exp(2 * log((long double) 3
-				/ (long double) 5)
-			- ((long double) log((float) 7)
-			   + (long double) log((float) 4)
-			   + (long double) log((float) 2.5)));
+  l_error = -1.0 * exp(2 * log(2.0 / 5.0) - (log(7.0) + log(4.0) + log(4.5)));
+  r_error =  -1.0 * exp(2 * log(3.0 / 5.0) - (log(7.0) + log(4.0) + log(2.5)));
 
-  rl_error = -1.0 * exp(2 * log((long double) 1
-				/ (long double) 5)
-			- ((long double) log((float) 0.5)
-			   + (long double) log((float) 4)
-			   + (long double) log((float) 2.5)));
+  rl_error = -1.0 * exp(2 * log(1.0 / 5.0) - (log(0.5) + log(4.0) + log(2.5)));
+  rr_error = -1.0 * exp(2 * log(2.0 / 5.0) - (log(6.5) + log(4.0) + log(2.5)));
 
-  rr_error = -1.0 * exp(2 * log((long double) 2
-				/ (long double) 5)
-			- ((long double) log((float) 6.5)
-			   + (long double) log((float) 4)
-			   + (long double) log((float) 2.5)));
-
   arma::Col<size_t> o_test(5);
   o_test << 0 << 1 << 2 << 3 << 4;
 




More information about the mlpack-svn mailing list