[mlpack-git] master: Refactor test to handle internally-copying trees correctly. (b3d3eb8)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Jul 29 16:42:25 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/f8ceffae0613b350f4d6bdd46c6c8633a40b4897...6ee21879488fe98612a4619b17f8b51e8da5215b

>---------------------------------------------------------------

commit b3d3eb8f0c325afb19e97487b20bf128e70b3275
Author: ryan <ryan at ratml.org>
Date:   Mon Jul 27 00:09:49 2015 -0400

    Refactor test to handle internally-copying trees correctly.


>---------------------------------------------------------------

b3d3eb8f0c325afb19e97487b20bf128e70b3275
 src/mlpack/tests/range_search_test.cpp | 62 +++++++++++++---------------------
 1 file changed, 23 insertions(+), 39 deletions(-)

diff --git a/src/mlpack/tests/range_search_test.cpp b/src/mlpack/tests/range_search_test.cpp
index 63f486e..3110a9f 100644
--- a/src/mlpack/tests/range_search_test.cpp
+++ b/src/mlpack/tests/range_search_test.cpp
@@ -76,10 +76,9 @@ BOOST_AUTO_TEST_CASE(ExhaustiveSyntheticTest)
 
   // We will loop through three times, one for each method of performing the
   // calculation.
-  arma::mat dataMutable = data;
   std::vector<size_t> oldFromNew;
   std::vector<size_t> newFromOld;
-  TreeType* tree = new TreeType(dataMutable, oldFromNew, newFromOld, 1);
+  TreeType* tree = new TreeType(data, oldFromNew, newFromOld, 1);
   for (int i = 0; i < 3; i++)
   {
     RangeSearch<>* rs;
@@ -87,7 +86,7 @@ BOOST_AUTO_TEST_CASE(ExhaustiveSyntheticTest)
     switch (i)
     {
       case 0: // Use the naive method.
-        rs = new RangeSearch<>(dataMutable, true);
+        rs = new RangeSearch<>(tree->Dataset(), true);
         break;
       case 1: // Use the single-tree method.
         rs = new RangeSearch<>(tree, true);
@@ -211,6 +210,8 @@ BOOST_AUTO_TEST_CASE(ExhaustiveSyntheticTest)
     BOOST_REQUIRE_CLOSE(sortedOutput[newFromOld[10]][3].first, 0.65, 1e-5);
 
     // Now do it again with a different range: [sqrt(0.5) 1.0].
+    if (rs->ReferenceTree())
+      CleanTree(*rs->ReferenceTree());
     rs->Search(Range(sqrt(0.5), 1.0), neighbors, distances);
     SortResults(neighbors, distances, sortedOutput);
 
@@ -272,6 +273,8 @@ BOOST_AUTO_TEST_CASE(ExhaustiveSyntheticTest)
     BOOST_REQUIRE_CLOSE(sortedOutput[newFromOld[10]][1].first, 0.95, 1e-5);
 
     // Now do it again with a different range: [1.0 inf].
+    if (rs->ReferenceTree())
+      CleanTree(*rs->ReferenceTree());
     rs->Search(Range(1.0, numeric_limits<double>::infinity()), neighbors,
         distances);
     SortResults(neighbors, distances, sortedOutput);
@@ -597,9 +600,8 @@ BOOST_AUTO_TEST_CASE(CoverTreeTest)
   data.randu(8, 1000); // 1000 points in 8 dimensions.
 
   // Set up cover tree range search.
-  StandardCoverTree<EuclideanDistance, RangeSearchStat, arma::mat> tree(data);
   RangeSearch<EuclideanDistance, arma::mat, StandardCoverTree>
-      coversearch(&tree);
+      coversearch(data);
 
   // Four trials with different ranges.
   for (size_t r = 0; r < 4; ++r)
@@ -637,7 +639,7 @@ BOOST_AUTO_TEST_CASE(CoverTreeTest)
     vector<vector<double>> coverDistances;
 
     // Clean the tree statistics.
-    CleanTree(tree);
+    CleanTree(*coversearch.ReferenceTree());
 
     // Run the searches.
     kdsearch.Search(range, kdNeighbors, kdDistances);
@@ -675,12 +677,8 @@ BOOST_AUTO_TEST_CASE(CoverTreeTwoDatasetsTest)
   queries.randu(8, 350); // 350 points in 8 dimensions.
 
   // Set up cover tree range search.
-  typedef StandardCoverTree<EuclideanDistance, RangeSearchStat, arma::mat>
-      CoverTreeType;
-  CoverTreeType tree(data);
-  CoverTreeType* queryTree = new CoverTreeType(queries);
   RangeSearch<EuclideanDistance, arma::mat, StandardCoverTree>
-      coversearch(&tree);
+      coversearch(data);
 
   // Four trials with different ranges.
   for (size_t r = 0; r < 4; ++r)
@@ -719,12 +717,10 @@ BOOST_AUTO_TEST_CASE(CoverTreeTwoDatasetsTest)
     vector<vector<double>> coverDistances;
 
     // Clean the trees.
-    CleanTree(tree);
-    delete queryTree;
-    queryTree = new CoverTreeType(queries);
+    CleanTree(*coversearch.ReferenceTree());
 
     // Run the searches.
-    coversearch.Search(queryTree, range, coverNeighbors, coverDistances);
+    coversearch.Search(queries, range, coverNeighbors, coverDistances);
     kdsearch.Search(queries, range, kdNeighbors, kdDistances);
 
     // Sort before comparison.
@@ -745,8 +741,6 @@ BOOST_AUTO_TEST_CASE(CoverTreeTwoDatasetsTest)
       BOOST_REQUIRE_EQUAL(kdSorted[i].size(), coverSorted[i].size());
     }
   }
-
-  delete queryTree;
 }
 
 /**
@@ -758,17 +752,14 @@ BOOST_AUTO_TEST_CASE(CoverTreeSingleTreeTest)
   data.randu(8, 1000); // 1000 points in 8 dimensions.
 
   // Set up cover tree range search.
-  typedef StandardCoverTree<EuclideanDistance, RangeSearchStat, arma::mat>
-      CoverTreeType;
-  CoverTreeType tree(data);
   RangeSearch<EuclideanDistance, arma::mat, StandardCoverTree>
-      coversearch(&tree, true);
+      coversearch(data, false, true);
 
   // Four trials with different ranges.
   for (size_t r = 0; r < 4; ++r)
   {
     // Set up kd-tree range search.
-    RangeSearch<> kdsearch(data, true);
+    RangeSearch<> kdsearch(data);
 
     Range range;
     switch (r)
@@ -800,7 +791,7 @@ BOOST_AUTO_TEST_CASE(CoverTreeSingleTreeTest)
     vector<vector<double>> coverDistances;
 
     // Clean the tree statistics.
-    CleanTree(tree);
+    CleanTree(*coversearch.ReferenceTree());
 
     // Run the searches.
     kdsearch.Search(range, kdNeighbors, kdDistances);
@@ -835,15 +826,14 @@ BOOST_AUTO_TEST_CASE(SingleBallTreeTest)
   data.randu(8, 1000); // 1000 points in 8 dimensions.
 
   // Set up ball tree range search.
-  typedef BallTree<EuclideanDistance, RangeSearchStat, arma::mat> TreeType;
-  TreeType tree(data);
-  RangeSearch<EuclideanDistance, arma::mat, BallTree> ballsearch(&tree, true);
+  RangeSearch<EuclideanDistance, arma::mat, BallTree> ballsearch(data, false,
+      true);
 
   // Four trials with different ranges.
   for (size_t r = 0; r < 4; ++r)
   {
     // Set up kd-tree range search.
-    RangeSearch<> kdsearch(data, true);
+    RangeSearch<> kdsearch(data);
 
     Range range;
     switch (r)
@@ -875,7 +865,7 @@ BOOST_AUTO_TEST_CASE(SingleBallTreeTest)
     vector<vector<double>> ballDistances;
 
     // Clean the tree statistics.
-    CleanTree(tree);
+    CleanTree(*ballsearch.ReferenceTree());
 
     // Run the searches.
     kdsearch.Search(range, kdNeighbors, kdDistances);
@@ -911,9 +901,7 @@ BOOST_AUTO_TEST_CASE(DualBallTreeTest)
   data.randu(8, 1000); // 1000 points in 8 dimensions.
 
   // Set up ball tree range search.
-  typedef BallTree<EuclideanDistance, RangeSearchStat, arma::mat> TreeType;
-  TreeType tree(data);
-  RangeSearch<EuclideanDistance, arma::mat, BallTree> ballsearch(&tree);
+  RangeSearch<EuclideanDistance, arma::mat, BallTree> ballsearch(data);
 
   // Four trials with different ranges.
   for (size_t r = 0; r < 4; ++r)
@@ -951,7 +939,7 @@ BOOST_AUTO_TEST_CASE(DualBallTreeTest)
     vector<vector<double>> ballDistances;
 
     // Clean the tree statistics.
-    CleanTree(tree);
+    CleanTree(*ballsearch.ReferenceTree());
 
     // Run the searches.
     kdsearch.Search(range, kdNeighbors, kdDistances);
@@ -990,10 +978,7 @@ BOOST_AUTO_TEST_CASE(DualBallTreeTest2)
   queries.randu(8, 350); // 350 points in 8 dimensions.
 
   // Set up ball tree range search.
-  typedef BallTree<EuclideanDistance, RangeSearchStat, arma::mat> TreeType;
-  TreeType tree(data);
-  TreeType queryTree(queries);
-  RangeSearch<EuclideanDistance, arma::mat, BallTree> ballsearch(&tree);
+  RangeSearch<EuclideanDistance, arma::mat, BallTree> ballsearch(data);
 
   // Four trials with different ranges.
   for (size_t r = 0; r < 4; ++r)
@@ -1032,11 +1017,10 @@ BOOST_AUTO_TEST_CASE(DualBallTreeTest2)
     vector<vector<double>> ballDistances;
 
     // Clean the trees.
-    CleanTree(tree);
-    CleanTree(queryTree);
+    CleanTree(*ballsearch.ReferenceTree());
 
     // Run the searches.
-    ballsearch.Search(&queryTree, range, ballNeighbors, ballDistances);
+    ballsearch.Search(queries, range, ballNeighbors, ballDistances);
     kdsearch.Search(queries, range, kdNeighbors, kdDistances);
 
     // Sort before comparison.



More information about the mlpack-git mailing list