[mlpack-svn] r16761 - mlpack/trunk/src/mlpack/tests

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Jul 3 16:46:42 EDT 2014


Author: rcurtin
Date: Thu Jul  3 16:46:41 2014
New Revision: 16761

Log:
Refactor tests, use BOOST_REQUIRE_EQUAL(), and add a test for EMST using cover
trees.  This commit also comments out SparseKMeansTest in kmeans_test.cpp.


Modified:
   mlpack/trunk/src/mlpack/tests/emst_test.cpp
   mlpack/trunk/src/mlpack/tests/kmeans_test.cpp

Modified: mlpack/trunk/src/mlpack/tests/emst_test.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/tests/emst_test.cpp	(original)
+++ mlpack/trunk/src/mlpack/tests/emst_test.cpp	Thu Jul  3 16:46:41 2014
@@ -8,17 +8,22 @@
 #include <boost/test/unit_test.hpp>
 #include "old_boost_test_definitions.hpp"
 
+#include <mlpack/core/tree/cover_tree.hpp>
+
 using namespace mlpack;
 using namespace mlpack::emst;
+using namespace mlpack::tree;
+using namespace mlpack::bound;
+using namespace mlpack::metric;
 
 BOOST_AUTO_TEST_SUITE(EMSTTest);
 
 /**
  * Simple emst test with small, synthetic dataset.  This is an
  * exhaustive test, which checks that each method for performing the calculation
- * (dual-tree, single-tree, naive) produces the correct results.  The dataset is
- * in one dimension for simplicity -- the correct functionality of distance
- * functions is not tested here.
+ * (dual-tree, naive) produces the correct results.  The dataset is in one
+ * dimension for simplicity -- the correct functionality of distance functions
+ * is not tested here.
  */
 BOOST_AUTO_TEST_CASE(ExhaustiveSyntheticTest)
 {
@@ -36,51 +41,137 @@
   data[9] = 0.91;
   data[10] = 1.00;
 
-  // Now perform the actual calculation.
   arma::mat results;
 
-  DualTreeBoruvka<> dtb(data);
+  // Build the tree by hand to get a leaf size of 1.
+  typedef BinarySpaceTree<HRectBound<2>, DTBStat> TreeType;
+  std::vector<size_t> oldFromNew;
+  std::vector<size_t> newFromOld;
+  TreeType tree(data, oldFromNew, newFromOld, 1);
+
+  // Create the DTB object and run the calculation.
+  DualTreeBoruvka<> dtb(&tree, data);
   dtb.ComputeMST(results);
 
   // Now the exhaustive check for correctness.
-  BOOST_REQUIRE(results(0, 0) == 1);
-  BOOST_REQUIRE(results(1, 0) == 8);
+  if (newFromOld[1] < newFromOld[8])
+  {
+    BOOST_REQUIRE_EQUAL(results(0, 0), newFromOld[1]);
+    BOOST_REQUIRE_EQUAL(results(1, 0), newFromOld[8]);
+  }
+  else
+  {
+    BOOST_REQUIRE_EQUAL(results(1, 0), newFromOld[1]);
+    BOOST_REQUIRE_EQUAL(results(0, 0), newFromOld[8]);
+  }
   BOOST_REQUIRE_CLOSE(results(2, 0), 0.08, 1e-5);
 
-  BOOST_REQUIRE(results(0, 1) == 9);
-  BOOST_REQUIRE(results(1, 1) == 10);
+  if (newFromOld[9] < newFromOld[10])
+  {
+    BOOST_REQUIRE_EQUAL(results(0, 1), newFromOld[9]);
+    BOOST_REQUIRE_EQUAL(results(1, 1), newFromOld[10]);
+  }
+  else
+  {
+    BOOST_REQUIRE_EQUAL(results(1, 1), newFromOld[9]);
+    BOOST_REQUIRE_EQUAL(results(0, 1), newFromOld[10]);
+  }
   BOOST_REQUIRE_CLOSE(results(2, 1), 0.09, 1e-5);
 
-  BOOST_REQUIRE(results(0, 2) == 0);
-  BOOST_REQUIRE(results(1, 2) == 2);
+  if (newFromOld[0] < newFromOld[2])
+  {
+    BOOST_REQUIRE_EQUAL(results(0, 2), newFromOld[0]);
+    BOOST_REQUIRE_EQUAL(results(1, 2), newFromOld[2]);
+  }
+  else
+  {
+    BOOST_REQUIRE_EQUAL(results(1, 2), newFromOld[0]);
+    BOOST_REQUIRE_EQUAL(results(0, 2), newFromOld[2]);
+  }
   BOOST_REQUIRE_CLOSE(results(2, 2), 0.1, 1e-5);
 
-  BOOST_REQUIRE(results(0, 3) == 1);
-  BOOST_REQUIRE(results(1, 3) == 2);
+  if (newFromOld[1] < newFromOld[2])
+  {
+    BOOST_REQUIRE_EQUAL(results(0, 3), newFromOld[1]);
+    BOOST_REQUIRE_EQUAL(results(1, 3), newFromOld[2]);
+  }
+  else
+  {
+    BOOST_REQUIRE_EQUAL(results(1, 3), newFromOld[1]);
+    BOOST_REQUIRE_EQUAL(results(0, 3), newFromOld[2]);
+  }
   BOOST_REQUIRE_CLOSE(results(2, 3), 0.22, 1e-5);
 
-  BOOST_REQUIRE(results(0, 4) == 3);
-  BOOST_REQUIRE(results(1, 4) == 10);
+  if (newFromOld[3] < newFromOld[10])
+  {
+    BOOST_REQUIRE_EQUAL(results(0, 4), newFromOld[3]);
+    BOOST_REQUIRE_EQUAL(results(1, 4), newFromOld[10]);
+  }
+  else
+  {
+    BOOST_REQUIRE_EQUAL(results(1, 4), newFromOld[3]);
+    BOOST_REQUIRE_EQUAL(results(0, 4), newFromOld[10]);
+  }
   BOOST_REQUIRE_CLOSE(results(2, 4), 0.25, 1e-5);
 
-  BOOST_REQUIRE(results(0, 5) == 0);
-  BOOST_REQUIRE(results(1, 5) == 5);
+  if (newFromOld[0] < newFromOld[5])
+  {
+    BOOST_REQUIRE_EQUAL(results(0, 5), newFromOld[0]);
+    BOOST_REQUIRE_EQUAL(results(1, 5), newFromOld[5]);
+  }
+  else
+  {
+    BOOST_REQUIRE_EQUAL(results(1, 5), newFromOld[0]);
+    BOOST_REQUIRE_EQUAL(results(0, 5), newFromOld[5]);
+  }
   BOOST_REQUIRE_CLOSE(results(2, 5), 0.27, 1e-5);
 
-  BOOST_REQUIRE(results(0, 6) == 8);
-  BOOST_REQUIRE(results(1, 6) == 9);
+  if (newFromOld[8] < newFromOld[9])
+  {
+    BOOST_REQUIRE_EQUAL(results(0, 6), newFromOld[8]);
+    BOOST_REQUIRE_EQUAL(results(1, 6), newFromOld[9]);
+  }
+  else
+  {
+    BOOST_REQUIRE_EQUAL(results(1, 6), newFromOld[8]);
+    BOOST_REQUIRE_EQUAL(results(0, 6), newFromOld[9]);
+  }
   BOOST_REQUIRE_CLOSE(results(2, 6), 0.46, 1e-5);
 
-  BOOST_REQUIRE(results(0, 7) == 6);
-  BOOST_REQUIRE(results(1, 7) == 7);
+  if (newFromOld[6] < newFromOld[7])
+  {
+    BOOST_REQUIRE_EQUAL(results(0, 7), newFromOld[6]);
+    BOOST_REQUIRE_EQUAL(results(1, 7), newFromOld[7]);
+  }
+  else
+  {
+    BOOST_REQUIRE_EQUAL(results(1, 7), newFromOld[6]);
+    BOOST_REQUIRE_EQUAL(results(0, 7), newFromOld[7]);
+  }
   BOOST_REQUIRE_CLOSE(results(2, 7), 0.7, 1e-5);
 
-  BOOST_REQUIRE(results(0, 8) == 5);
-  BOOST_REQUIRE(results(1, 8) == 7);
+  if (newFromOld[5] < newFromOld[7])
+  {
+    BOOST_REQUIRE_EQUAL(results(0, 8), newFromOld[5]);
+    BOOST_REQUIRE_EQUAL(results(1, 8), newFromOld[7]);
+  }
+  else
+  {
+    BOOST_REQUIRE_EQUAL(results(1, 8), newFromOld[5]);
+    BOOST_REQUIRE_EQUAL(results(0, 8), newFromOld[7]);
+  }
   BOOST_REQUIRE_CLOSE(results(2, 8), 1.08, 1e-5);
 
-  BOOST_REQUIRE(results(0, 9) == 3);
-  BOOST_REQUIRE(results(1, 9) == 4);
+  if (newFromOld[3] < newFromOld[4])
+  {
+    BOOST_REQUIRE_EQUAL(results(0, 9), newFromOld[3]);
+    BOOST_REQUIRE_EQUAL(results(1, 9), newFromOld[4]);
+  }
+  else
+  {
+    BOOST_REQUIRE_EQUAL(results(1, 9), newFromOld[3]);
+    BOOST_REQUIRE_EQUAL(results(0, 9), newFromOld[4]);
+  }
   BOOST_REQUIRE_CLOSE(results(2, 9), 3.8, 1e-5);
 }
 
@@ -114,15 +205,44 @@
   arma::mat naiveResults;
   dtbNaive.ComputeMST(naiveResults);
 
-  BOOST_REQUIRE(dualResults.n_cols == naiveResults.n_cols);
-  BOOST_REQUIRE(dualResults.n_rows == naiveResults.n_rows);
+  BOOST_REQUIRE_EQUAL(dualResults.n_cols, naiveResults.n_cols);
+  BOOST_REQUIRE_EQUAL(dualResults.n_rows, naiveResults.n_rows);
 
   for (size_t i = 0; i < dualResults.n_cols; i++)
   {
-    BOOST_REQUIRE(dualResults(0, i) == naiveResults(0, i));
-    BOOST_REQUIRE(dualResults(1, i) == naiveResults(1, i));
+    BOOST_REQUIRE_EQUAL(dualResults(0, i), naiveResults(0, i));
+    BOOST_REQUIRE_EQUAL(dualResults(1, i), naiveResults(1, i));
     BOOST_REQUIRE_CLOSE(dualResults(2, i), naiveResults(2, i), 1e-5);
   }
 }
 
+/**
+ * Make sure the cover tree gives the same MST as the default tree type.
+ */
+BOOST_AUTO_TEST_CASE(CoverTreeTest)
+{
+  arma::mat inputData;
+  if (!data::Load("test_data_3_1000.csv", inputData))
+    BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
+
+  DualTreeBoruvka<> bst(inputData);
+  DualTreeBoruvka<EuclideanDistance, CoverTree<EuclideanDistance,
+      FirstPointIsRoot, DTBStat> > ct(inputData);
+
+  arma::mat bstResults;
+  arma::mat coverResults;
+  bst.ComputeMST(bstResults);
+  ct.ComputeMST(coverResults);
+
+  // Compare shapes first so the loop below cannot index out of bounds.
+  BOOST_REQUIRE_EQUAL(bstResults.n_cols, coverResults.n_cols);
+  BOOST_REQUIRE_EQUAL(bstResults.n_rows, coverResults.n_rows);
+  for (size_t i = 0; i < bstResults.n_cols; i++)
+  {
+    BOOST_REQUIRE_EQUAL(bstResults(0, i), coverResults(0, i));
+    BOOST_REQUIRE_EQUAL(bstResults(1, i), coverResults(1, i));
+    BOOST_REQUIRE_CLOSE(bstResults(2, i), coverResults(2, i), 1e-5);
+  }
+}
+
 BOOST_AUTO_TEST_SUITE_END();

Modified: mlpack/trunk/src/mlpack/tests/kmeans_test.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/tests/kmeans_test.cpp	(original)
+++ mlpack/trunk/src/mlpack/tests/kmeans_test.cpp	Thu Jul  3 16:46:41 2014
@@ -450,7 +450,7 @@
 
 /**
  * Make sure sparse k-means works okay.
- */
+ *
 BOOST_AUTO_TEST_CASE(SparseKMeansTest)
 {
   // Huge dimensionality, few points.
@@ -490,7 +490,7 @@
   BOOST_REQUIRE_EQUAL(assignments[10], clusterTwo);
   BOOST_REQUIRE_EQUAL(assignments[11], clusterTwo);
 }
-
+*/
 #endif // Exclude Armadillo 3.4.
 #endif // ARMA_HAS_SPMAT
 



More information about the mlpack-svn mailing list