[mlpack-git] master: Refactor test to handle internally-copying trees correctly. (b3d3eb8)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Jul 29 16:42:25 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/f8ceffae0613b350f4d6bdd46c6c8633a40b4897...6ee21879488fe98612a4619b17f8b51e8da5215b
>---------------------------------------------------------------
commit b3d3eb8f0c325afb19e97487b20bf128e70b3275
Author: ryan <ryan at ratml.org>
Date: Mon Jul 27 00:09:49 2015 -0400
Refactor test to handle internally-copying trees correctly.
>---------------------------------------------------------------
b3d3eb8f0c325afb19e97487b20bf128e70b3275
src/mlpack/tests/range_search_test.cpp | 62 +++++++++++++---------------------
1 file changed, 23 insertions(+), 39 deletions(-)
diff --git a/src/mlpack/tests/range_search_test.cpp b/src/mlpack/tests/range_search_test.cpp
index 63f486e..3110a9f 100644
--- a/src/mlpack/tests/range_search_test.cpp
+++ b/src/mlpack/tests/range_search_test.cpp
@@ -76,10 +76,9 @@ BOOST_AUTO_TEST_CASE(ExhaustiveSyntheticTest)
// We will loop through three times, one for each method of performing the
// calculation.
- arma::mat dataMutable = data;
std::vector<size_t> oldFromNew;
std::vector<size_t> newFromOld;
- TreeType* tree = new TreeType(dataMutable, oldFromNew, newFromOld, 1);
+ TreeType* tree = new TreeType(data, oldFromNew, newFromOld, 1);
for (int i = 0; i < 3; i++)
{
RangeSearch<>* rs;
@@ -87,7 +86,7 @@ BOOST_AUTO_TEST_CASE(ExhaustiveSyntheticTest)
switch (i)
{
case 0: // Use the naive method.
- rs = new RangeSearch<>(dataMutable, true);
+ rs = new RangeSearch<>(tree->Dataset(), true);
break;
case 1: // Use the single-tree method.
rs = new RangeSearch<>(tree, true);
@@ -211,6 +210,8 @@ BOOST_AUTO_TEST_CASE(ExhaustiveSyntheticTest)
BOOST_REQUIRE_CLOSE(sortedOutput[newFromOld[10]][3].first, 0.65, 1e-5);
// Now do it again with a different range: [sqrt(0.5) 1.0].
+ if (rs->ReferenceTree())
+ CleanTree(*rs->ReferenceTree());
rs->Search(Range(sqrt(0.5), 1.0), neighbors, distances);
SortResults(neighbors, distances, sortedOutput);
@@ -272,6 +273,8 @@ BOOST_AUTO_TEST_CASE(ExhaustiveSyntheticTest)
BOOST_REQUIRE_CLOSE(sortedOutput[newFromOld[10]][1].first, 0.95, 1e-5);
// Now do it again with a different range: [1.0 inf].
+ if (rs->ReferenceTree())
+ CleanTree(*rs->ReferenceTree());
rs->Search(Range(1.0, numeric_limits<double>::infinity()), neighbors,
distances);
SortResults(neighbors, distances, sortedOutput);
@@ -597,9 +600,8 @@ BOOST_AUTO_TEST_CASE(CoverTreeTest)
data.randu(8, 1000); // 1000 points in 8 dimensions.
// Set up cover tree range search.
- StandardCoverTree<EuclideanDistance, RangeSearchStat, arma::mat> tree(data);
RangeSearch<EuclideanDistance, arma::mat, StandardCoverTree>
- coversearch(&tree);
+ coversearch(data);
// Four trials with different ranges.
for (size_t r = 0; r < 4; ++r)
@@ -637,7 +639,7 @@ BOOST_AUTO_TEST_CASE(CoverTreeTest)
vector<vector<double>> coverDistances;
// Clean the tree statistics.
- CleanTree(tree);
+ CleanTree(*coversearch.ReferenceTree());
// Run the searches.
kdsearch.Search(range, kdNeighbors, kdDistances);
@@ -675,12 +677,8 @@ BOOST_AUTO_TEST_CASE(CoverTreeTwoDatasetsTest)
queries.randu(8, 350); // 350 points in 8 dimensions.
// Set up cover tree range search.
- typedef StandardCoverTree<EuclideanDistance, RangeSearchStat, arma::mat>
- CoverTreeType;
- CoverTreeType tree(data);
- CoverTreeType* queryTree = new CoverTreeType(queries);
RangeSearch<EuclideanDistance, arma::mat, StandardCoverTree>
- coversearch(&tree);
+ coversearch(data);
// Four trials with different ranges.
for (size_t r = 0; r < 4; ++r)
@@ -719,12 +717,10 @@ BOOST_AUTO_TEST_CASE(CoverTreeTwoDatasetsTest)
vector<vector<double>> coverDistances;
// Clean the trees.
- CleanTree(tree);
- delete queryTree;
- queryTree = new CoverTreeType(queries);
+ CleanTree(*coversearch.ReferenceTree());
// Run the searches.
- coversearch.Search(queryTree, range, coverNeighbors, coverDistances);
+ coversearch.Search(queries, range, coverNeighbors, coverDistances);
kdsearch.Search(queries, range, kdNeighbors, kdDistances);
// Sort before comparison.
@@ -745,8 +741,6 @@ BOOST_AUTO_TEST_CASE(CoverTreeTwoDatasetsTest)
BOOST_REQUIRE_EQUAL(kdSorted[i].size(), coverSorted[i].size());
}
}
-
- delete queryTree;
}
/**
@@ -758,17 +752,14 @@ BOOST_AUTO_TEST_CASE(CoverTreeSingleTreeTest)
data.randu(8, 1000); // 1000 points in 8 dimensions.
// Set up cover tree range search.
- typedef StandardCoverTree<EuclideanDistance, RangeSearchStat, arma::mat>
- CoverTreeType;
- CoverTreeType tree(data);
RangeSearch<EuclideanDistance, arma::mat, StandardCoverTree>
- coversearch(&tree, true);
+ coversearch(data, false, true);
// Four trials with different ranges.
for (size_t r = 0; r < 4; ++r)
{
// Set up kd-tree range search.
- RangeSearch<> kdsearch(data, true);
+ RangeSearch<> kdsearch(data);
Range range;
switch (r)
@@ -800,7 +791,7 @@ BOOST_AUTO_TEST_CASE(CoverTreeSingleTreeTest)
vector<vector<double>> coverDistances;
// Clean the tree statistics.
- CleanTree(tree);
+ CleanTree(*coversearch.ReferenceTree());
// Run the searches.
kdsearch.Search(range, kdNeighbors, kdDistances);
@@ -835,15 +826,14 @@ BOOST_AUTO_TEST_CASE(SingleBallTreeTest)
data.randu(8, 1000); // 1000 points in 8 dimensions.
// Set up ball tree range search.
- typedef BallTree<EuclideanDistance, RangeSearchStat, arma::mat> TreeType;
- TreeType tree(data);
- RangeSearch<EuclideanDistance, arma::mat, BallTree> ballsearch(&tree, true);
+ RangeSearch<EuclideanDistance, arma::mat, BallTree> ballsearch(data, false,
+ true);
// Four trials with different ranges.
for (size_t r = 0; r < 4; ++r)
{
// Set up kd-tree range search.
- RangeSearch<> kdsearch(data, true);
+ RangeSearch<> kdsearch(data);
Range range;
switch (r)
@@ -875,7 +865,7 @@ BOOST_AUTO_TEST_CASE(SingleBallTreeTest)
vector<vector<double>> ballDistances;
// Clean the tree statistics.
- CleanTree(tree);
+ CleanTree(*ballsearch.ReferenceTree());
// Run the searches.
kdsearch.Search(range, kdNeighbors, kdDistances);
@@ -911,9 +901,7 @@ BOOST_AUTO_TEST_CASE(DualBallTreeTest)
data.randu(8, 1000); // 1000 points in 8 dimensions.
// Set up ball tree range search.
- typedef BallTree<EuclideanDistance, RangeSearchStat, arma::mat> TreeType;
- TreeType tree(data);
- RangeSearch<EuclideanDistance, arma::mat, BallTree> ballsearch(&tree);
+ RangeSearch<EuclideanDistance, arma::mat, BallTree> ballsearch(data);
// Four trials with different ranges.
for (size_t r = 0; r < 4; ++r)
@@ -951,7 +939,7 @@ BOOST_AUTO_TEST_CASE(DualBallTreeTest)
vector<vector<double>> ballDistances;
// Clean the tree statistics.
- CleanTree(tree);
+ CleanTree(*ballsearch.ReferenceTree());
// Run the searches.
kdsearch.Search(range, kdNeighbors, kdDistances);
@@ -990,10 +978,7 @@ BOOST_AUTO_TEST_CASE(DualBallTreeTest2)
queries.randu(8, 350); // 350 points in 8 dimensions.
// Set up ball tree range search.
- typedef BallTree<EuclideanDistance, RangeSearchStat, arma::mat> TreeType;
- TreeType tree(data);
- TreeType queryTree(queries);
- RangeSearch<EuclideanDistance, arma::mat, BallTree> ballsearch(&tree);
+ RangeSearch<EuclideanDistance, arma::mat, BallTree> ballsearch(data);
// Four trials with different ranges.
for (size_t r = 0; r < 4; ++r)
@@ -1032,11 +1017,10 @@ BOOST_AUTO_TEST_CASE(DualBallTreeTest2)
vector<vector<double>> ballDistances;
// Clean the trees.
- CleanTree(tree);
- CleanTree(queryTree);
+ CleanTree(*ballsearch.ReferenceTree());
// Run the searches.
- ballsearch.Search(&queryTree, range, ballNeighbors, ballDistances);
+ ballsearch.Search(queries, range, ballNeighbors, ballDistances);
kdsearch.Search(queries, range, kdNeighbors, kdDistances);
// Sort before comparison.
More information about the mlpack-git
mailing list