[mlpack-git] master, mlpack-1.0.x: Tests for ball trees and BallBound<> by Yash, to solve #250. (3563834)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:54:17 EST 2015
Repository : https://github.com/mlpack/mlpack
On branches: master,mlpack-1.0.x
Link : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40
>---------------------------------------------------------------
commit 356383497533b193a2619d59aff04eb0556905de
Author: Yash Vadalia <yashdv at gmail.com>
Date: Fri Jul 25 15:53:13 2014 +0000
Tests for ball trees and BallBound<> by Yash, to solve #250.
>---------------------------------------------------------------
356383497533b193a2619d59aff04eb0556905de
src/mlpack/tests/allkfn_test.cpp | 149 ++++++++++++++++++
src/mlpack/tests/allknn_test.cpp | 73 +++++++++
src/mlpack/tests/allkrann_search_test.cpp | 138 ++++++++++++++++-
src/mlpack/tests/range_search_test.cpp | 248 +++++++++++++++++++++++++++++-
src/mlpack/tests/tree_test.cpp | 80 ++++++++--
5 files changed, 668 insertions(+), 20 deletions(-)
diff --git a/src/mlpack/tests/allkfn_test.cpp b/src/mlpack/tests/allkfn_test.cpp
index 03967f8..921c13a 100644
--- a/src/mlpack/tests/allkfn_test.cpp
+++ b/src/mlpack/tests/allkfn_test.cpp
@@ -5,12 +5,14 @@
*/
#include <mlpack/core.hpp>
#include <mlpack/methods/neighbor_search/neighbor_search.hpp>
+#include <mlpack/core/tree/cover_tree.hpp>
#include <boost/test/unit_test.hpp>
#include "old_boost_test_definitions.hpp"
using namespace mlpack;
using namespace mlpack::neighbor;
using namespace mlpack::tree;
+using namespace mlpack::metric;
using namespace mlpack::bound;
BOOST_AUTO_TEST_SUITE(AllkFNTest);
@@ -439,4 +441,151 @@ BOOST_AUTO_TEST_CASE(SingleTreeVsNaive)
}
}
+/**
+ * Test the cover tree single-tree furthest-neighbors method against the naive
+ * method. This uses only a random reference dataset.
+ *
+ * Errors are produced if the results are not identical.
+ */
+BOOST_AUTO_TEST_CASE(SingleCoverTreeTest)
+{
+ arma::mat data;
+ data.randu(75, 1000); // 75 dimensional, 1000 points.
+
+ arma::mat naiveQuery(data); // For naive AllkFN.
+
+ CoverTree<LMetric<2>, FirstPointIsRoot,
+ NeighborSearchStat<FurthestNeighborSort> > tree = CoverTree<LMetric<2>,
+ FirstPointIsRoot, NeighborSearchStat<FurthestNeighborSort> >(data);
+
+ NeighborSearch<FurthestNeighborSort, LMetric<2>, CoverTree<LMetric<2>,
+ FirstPointIsRoot, NeighborSearchStat<FurthestNeighborSort> > >
+ coverTreeSearch(&tree, data, true);
+
+ AllkFN naive(naiveQuery, true);
+
+ arma::Mat<size_t> coverTreeNeighbors;
+ arma::mat coverTreeDistances;
+ coverTreeSearch.Search(15, coverTreeNeighbors, coverTreeDistances);
+
+ arma::Mat<size_t> naiveNeighbors;
+ arma::mat naiveDistances;
+ naive.Search(15, naiveNeighbors, naiveDistances);
+
+ for (size_t i = 0; i < coverTreeNeighbors.n_elem; ++i)
+ {
+ BOOST_REQUIRE_EQUAL(coverTreeNeighbors[i], naiveNeighbors[i]);
+ BOOST_REQUIRE_CLOSE(coverTreeDistances[i], naiveDistances[i], 1e-5);
+ }
+}
+
+/**
+ * Test the cover tree dual-tree furthest neighbors method against the
+ * kd-tree dual-tree method.
+ */
+BOOST_AUTO_TEST_CASE(DualCoverTreeTest)
+{
+ arma::mat dataset;
+ data::Load("test_data_3_1000.csv", dataset);
+
+ arma::mat kdtreeData(dataset);
+
+ AllkFN tree(kdtreeData);
+
+ arma::Mat<size_t> kdNeighbors;
+ arma::mat kdDistances;
+ tree.Search(5, kdNeighbors, kdDistances);
+
+ typedef CoverTree<LMetric<2, true>, FirstPointIsRoot,
+ NeighborSearchStat<FurthestNeighborSort> > TreeType;
+
+ TreeType referenceTree = TreeType(dataset);
+
+ NeighborSearch<FurthestNeighborSort, LMetric<2, true>,
+ TreeType> coverTreeSearch(&referenceTree, dataset);
+
+ arma::Mat<size_t> coverNeighbors;
+ arma::mat coverDistances;
+ coverTreeSearch.Search(5, coverNeighbors, coverDistances);
+
+ for (size_t i = 0; i < coverNeighbors.n_elem; ++i)
+ {
+ BOOST_REQUIRE_EQUAL(coverNeighbors(i), kdNeighbors(i));
+ BOOST_REQUIRE_CLOSE(coverDistances(i), kdDistances(i), 1e-5);
+ }
+}
+
+/**
+ * Test the ball tree single-tree furthest-neighbors method against the naive
+ * method. This uses only a random reference dataset.
+ *
+ * Errors are produced if the results are not identical.
+ */
+BOOST_AUTO_TEST_CASE(SingleBallTreeTest)
+{
+ arma::mat data;
+ data.randu(75, 1000); // 75 dimensional, 1000 points.
+
+ typedef BinarySpaceTree<BallBound<arma::vec, LMetric<2, true> >,
+ NeighborSearchStat<FurthestNeighborSort> > TreeType;
+ TreeType tree = TreeType(data);
+
+ // BinarySpaceTree modifies data. Use modified data to maintain the
+ // correspondence between points in the dataset for both methods. The order of
+ // query points in both methods should be the same.
+ arma::mat naiveQuery(data); // For naive AllkFN.
+
+ NeighborSearch<FurthestNeighborSort, LMetric<2>, TreeType>
+ ballTreeSearch(&tree, data, true);
+
+ AllkFN naive(naiveQuery, true);
+
+ arma::Mat<size_t> ballTreeNeighbors;
+ arma::mat ballTreeDistances;
+ ballTreeSearch.Search(15, ballTreeNeighbors, ballTreeDistances);
+
+ arma::Mat<size_t> naiveNeighbors;
+ arma::mat naiveDistances;
+ naive.Search(15, naiveNeighbors, naiveDistances);
+
+ for (size_t i = 0; i < ballTreeNeighbors.n_elem; ++i)
+ {
+ BOOST_REQUIRE_EQUAL(ballTreeNeighbors[i], naiveNeighbors[i]);
+ BOOST_REQUIRE_CLOSE(ballTreeDistances[i], naiveDistances[i], 1e-5);
+ }
+}
+
+/**
+ * Test the ball tree dual-tree furthest neighbors method against the
+ * kd-tree dual-tree method.
+ */
+BOOST_AUTO_TEST_CASE(DualBallTreeTest)
+{
+ arma::mat dataset;
+ data::Load("test_data_3_1000.csv", dataset);
+
+ arma::mat kdtreeData(dataset);
+
+ AllkFN tree(kdtreeData);
+
+ arma::Mat<size_t> kdNeighbors;
+ arma::mat kdDistances;
+ tree.Search(5, kdNeighbors, kdDistances);
+
+ NeighborSearch<FurthestNeighborSort, LMetric<2, true>,
+ BinarySpaceTree<BallBound<arma::vec, LMetric<2, true> >,
+ NeighborSearchStat<FurthestNeighborSort> > >
+ ballTreeSearch(dataset);
+
+ arma::Mat<size_t> ballNeighbors;
+ arma::mat ballDistances;
+ ballTreeSearch.Search(5, ballNeighbors, ballDistances);
+
+ for (size_t i = 0; i < ballNeighbors.n_elem; ++i)
+ {
+ BOOST_REQUIRE_EQUAL(ballNeighbors(i), kdNeighbors(i));
+ BOOST_REQUIRE_CLOSE(ballDistances(i), kdDistances(i), 1e-5);
+ }
+}
+
BOOST_AUTO_TEST_SUITE_END();
diff --git a/src/mlpack/tests/allknn_test.cpp b/src/mlpack/tests/allknn_test.cpp
index 4ea444f..e76df2d 100644
--- a/src/mlpack/tests/allknn_test.cpp
+++ b/src/mlpack/tests/allknn_test.cpp
@@ -663,6 +663,79 @@ BOOST_AUTO_TEST_CASE(DualCoverTreeTest)
}
}
+/**
+ * Test the ball tree single-tree nearest-neighbors method against the naive
+ * method. This uses only a random reference dataset.
+ *
+ * Errors are produced if the results are not identical.
+ */
+BOOST_AUTO_TEST_CASE(SingleBallTreeTest)
+{
+ arma::mat data;
+ data.randu(75, 1000); // 75 dimensional, 1000 points.
+
+ typedef BinarySpaceTree<BallBound<arma::vec, LMetric<2, true> >,
+ NeighborSearchStat<NearestNeighborSort> > TreeType;
+ TreeType tree = TreeType(data);
+
+ // BinarySpaceTree modifies data. Use modified data to maintain the
+ // correspondence between points in the dataset for both methods. The order of
+ // query points in both methods should be the same.
+ arma::mat naiveQuery(data); // For naive AllkNN.
+
+ NeighborSearch<NearestNeighborSort, LMetric<2>, TreeType>
+ ballTreeSearch(&tree, data, true);
+
+ AllkNN naive(naiveQuery, true);
+
+ arma::Mat<size_t> ballTreeNeighbors;
+ arma::mat ballTreeDistances;
+ ballTreeSearch.Search(1, ballTreeNeighbors, ballTreeDistances);
+
+ arma::Mat<size_t> naiveNeighbors;
+ arma::mat naiveDistances;
+ naive.Search(1, naiveNeighbors, naiveDistances);
+
+ for (size_t i = 0; i < ballTreeNeighbors.n_elem; ++i)
+ {
+ BOOST_REQUIRE_EQUAL(ballTreeNeighbors[i], naiveNeighbors[i]);
+ BOOST_REQUIRE_CLOSE(ballTreeDistances[i], naiveDistances[i], 1e-5);
+ }
+}
+
+/**
+ * Test the ball tree dual-tree nearest neighbors method against the
+ * kd-tree dual-tree method.
+ */
+BOOST_AUTO_TEST_CASE(DualBallTreeTest)
+{
+ arma::mat dataset;
+ data::Load("test_data_3_1000.csv", dataset);
+
+ arma::mat kdtreeData(dataset);
+
+ AllkNN tree(kdtreeData);
+
+ arma::Mat<size_t> kdNeighbors;
+ arma::mat kdDistances;
+ tree.Search(5, kdNeighbors, kdDistances);
+
+ NeighborSearch<NearestNeighborSort, LMetric<2, true>,
+ BinarySpaceTree<BallBound<arma::vec, LMetric<2, true> >,
+ NeighborSearchStat<NearestNeighborSort> > >
+ ballTreeSearch(dataset);
+
+ arma::Mat<size_t> ballNeighbors;
+ arma::mat ballDistances;
+ ballTreeSearch.Search(5, ballNeighbors, ballDistances);
+
+ for (size_t i = 0; i < ballNeighbors.n_elem; ++i)
+ {
+ BOOST_REQUIRE_EQUAL(ballNeighbors(i), kdNeighbors(i));
+ BOOST_REQUIRE_CLOSE(ballDistances(i), kdDistances(i), 1e-5);
+ }
+}
+
// Make sure sparse nearest neighbors works with kd trees.
BOOST_AUTO_TEST_CASE(SparseAllkNNKDTreeTest)
{
diff --git a/src/mlpack/tests/allkrann_search_test.cpp b/src/mlpack/tests/allkrann_search_test.cpp
index 76da72b..ab86d3d 100644
--- a/src/mlpack/tests/allkrann_search_test.cpp
+++ b/src/mlpack/tests/allkrann_search_test.cpp
@@ -17,6 +17,9 @@
using namespace std;
using namespace mlpack;
using namespace mlpack::neighbor;
+using namespace mlpack::tree;
+using namespace mlpack::metric;
+using namespace mlpack::bound;
BOOST_AUTO_TEST_SUITE(AllkRANNTest);
@@ -344,7 +347,7 @@ BOOST_AUTO_TEST_CASE(SingleCoverTreeTest)
typedef RASearch<NearestNeighborSort, metric::EuclideanDistance, TreeType>
RACoverTreeSearch;
- RACoverTreeSearch tssRann(refData, queryData, true);
+ RACoverTreeSearch tssRann(refData, queryData, false, true);
// The relative ranks for the given query reference pair.
arma::Mat<size_t> qrRanks;
@@ -455,4 +458,137 @@ BOOST_AUTO_TEST_CASE(DualCoverTreeTest)
BOOST_REQUIRE_LT(numQueriesFail, maxNumQueriesFail);
}
+// Test single-tree rank-approximate search with ball trees.
+// This is known to not work right now.
+/*
+BOOST_AUTO_TEST_CASE(SingleBallTreeTest)
+{
+ arma::mat refData;
+ arma::mat queryData;
+
+ data::Load("rann_test_r_3_900.csv", refData, true);
+ data::Load("rann_test_q_3_100.csv", queryData, true);
+
+ // Search for 1 rank-approximate nearest-neighbors in the top 30% of the point
+ // (rank error of 3).
+ arma::Mat<size_t> neighbors;
+ arma::mat distances;
+
+ typedef BinarySpaceTree<BallBound<>, RAQueryStat<NearestNeighborSort> >
+ TreeType;
+ typedef RASearch<NearestNeighborSort, metric::EuclideanDistance, TreeType>
+ RABallTreeSearch;
+
+ RABallTreeSearch tssRann(refData, queryData, false, true);
+
+ // The relative ranks for the given query reference pair.
+ arma::Mat<size_t> qrRanks;
+ data::Load("rann_test_qr_ranks.csv", qrRanks, true, false); // No transpose.
+
+ size_t numRounds = 30;
+ arma::Col<size_t> numSuccessRounds(queryData.n_cols);
+ numSuccessRounds.fill(0);
+
+ // 1% of 900 is 9, so the rank is expected to be less than 10.
+ size_t expectedRankErrorUB = 10;
+
+ for (size_t rounds = 0; rounds < numRounds; rounds++)
+ {
+ tssRann.Search(1, neighbors, distances, 1.0, 0.95, false, false, 5);
+
+ for (size_t i = 0; i < queryData.n_cols; i++)
+ if (qrRanks(i, neighbors(0, i)) < expectedRankErrorUB)
+ numSuccessRounds[i]++;
+
+ neighbors.reset();
+ distances.reset();
+ }
+
+ // Find the 95%-tile threshold so that 95% of the queries should pass this
+ // threshold.
+ size_t threshold = floor(numRounds *
+ (0.95 - (1.96 * sqrt(0.95 * 0.05 / numRounds))));
+ size_t numQueriesFail = 0;
+ for (size_t i = 0; i < queryData.n_cols; i++)
+ if (numSuccessRounds[i] < threshold)
+ numQueriesFail++;
+
+ Log::Warn << "RANN-TSS (ball tree): RANN guarantee fails on "
+ << numQueriesFail << " queries." << endl;
+
+ // Assert that at most 5% of the queries fall out of this threshold.
+ // 5% of 100 queries is 5.
+ size_t maxNumQueriesFail = 6;
+
+ BOOST_REQUIRE_LT(numQueriesFail, maxNumQueriesFail);
+}
+
+// Test dual-tree rank-approximate search with Ball trees.
+BOOST_AUTO_TEST_CASE(DualBallTreeTest)
+{
+ arma::mat refData;
+ arma::mat queryData;
+
+ data::Load("rann_test_r_3_900.csv", refData, true);
+ data::Load("rann_test_q_3_100.csv", queryData, true);
+
+ // Search for 1 rank-approximate nearest-neighbors in the top 30% of the point
+ // (rank error of 3).
+ arma::Mat<size_t> neighbors;
+ arma::mat distances;
+
+ typedef BinarySpaceTree<BallBound<>, RAQueryStat<NearestNeighborSort> >
+ TreeType;
+ typedef RASearch<NearestNeighborSort, metric::EuclideanDistance, TreeType>
+ RABallTreeSearch;
+
+ TreeType refTree(refData);
+ TreeType queryTree(queryData);
+
+ RABallTreeSearch tsdRann(&refTree, &queryTree, refData, queryData, false);
+
+ arma::Mat<size_t> qrRanks;
+ data::Load("rann_test_qr_ranks.csv", qrRanks, true, false); // No transpose.
+
+ size_t numRounds = 1000;
+ arma::Col<size_t> numSuccessRounds(queryData.n_cols);
+ numSuccessRounds.fill(0);
+
+ // 1% of 900 is 9, so the rank is expected to be less than 10.
+ size_t expectedRankErrorUB = 10;
+
+ for (size_t rounds = 0; rounds < numRounds; rounds++)
+ {
+ tsdRann.Search(1, neighbors, distances, 1.0, 0.95, false, false, 5);
+
+ for (size_t i = 0; i < queryData.n_cols; i++)
+ if (qrRanks(i, neighbors(0, i)) < expectedRankErrorUB)
+ numSuccessRounds[i]++;
+
+ neighbors.reset();
+ distances.reset();
+
+ tsdRann.ResetQueryTree();
+ }
+
+ // Find the 95%-tile threshold so that 95% of the queries should pass this
+ // threshold.
+ size_t threshold = floor(numRounds *
+ (0.95 - (1.96 * sqrt(0.95 * 0.05 / numRounds))));
+ size_t numQueriesFail = 0;
+ for (size_t i = 0; i < queryData.n_cols; i++)
+ if (numSuccessRounds[i] < threshold)
+ numQueriesFail++;
+
+ Log::Warn << "RANN-TSD (Ball tree): RANN guarantee fails on "
+ << numQueriesFail << " queries." << endl;
+
+ // Assert that at most 5% of the queries fall out of this threshold.
+ // 5% of 100 queries is 5.
+ size_t maxNumQueriesFail = 6;
+
+ BOOST_REQUIRE_LT(numQueriesFail, maxNumQueriesFail);
+}
+*/
+
BOOST_AUTO_TEST_SUITE_END();
diff --git a/src/mlpack/tests/range_search_test.cpp b/src/mlpack/tests/range_search_test.cpp
index 3af8380..4002d44 100644
--- a/src/mlpack/tests/range_search_test.cpp
+++ b/src/mlpack/tests/range_search_test.cpp
@@ -441,7 +441,7 @@ BOOST_AUTO_TEST_CASE(ExhaustiveSyntheticTest)
}
/**
- * Test the dual-tree nearest-neighbors method with the naive method. This
+ * Test the dual-tree range search method with the naive method. This
* uses both a query and reference dataset.
*
* Errors are produced if the results are not identical.
@@ -490,7 +490,7 @@ BOOST_AUTO_TEST_CASE(DualTreeVsNaive1)
}
/**
- * Test the dual-tree nearest-neighbors method with the naive method. This uses
+ * Test the dual-tree range search method with the naive method. This uses
* only a reference dataset.
*
* Errors are produced if the results are not identical.
@@ -539,7 +539,7 @@ BOOST_AUTO_TEST_CASE(DualTreeVsNaive2)
}
/**
- * Test the single-tree nearest-neighbors method with the naive method. This
+ * Test the single-tree range search method with the naive method. This
* uses only a reference dataset.
*
* Errors are produced if the results are not identical.
@@ -588,8 +588,8 @@ BOOST_AUTO_TEST_CASE(SingleTreeVsNaive)
}
/**
- * Ensure that range search with cover trees works by comparing with the kd-tree
- * implementation.
+ * Ensure that dual tree range search with cover trees works by comparing
+ * with the kd-tree implementation.
*/
BOOST_AUTO_TEST_CASE(CoverTreeTest)
{
@@ -666,7 +666,8 @@ BOOST_AUTO_TEST_CASE(CoverTreeTest)
}
/**
- * Ensure that range search with cover trees works when using two datasets.
+ * Ensure that dual tree range search with cover trees works when using
+ * two datasets.
*/
BOOST_AUTO_TEST_CASE(CoverTreeTwoDatasetsTest)
{
@@ -824,4 +825,239 @@ BOOST_AUTO_TEST_CASE(CoverTreeSingleTreeTest)
}
}
+/**
+ * Ensure that single-tree ball tree range search works.
+ */
+BOOST_AUTO_TEST_CASE(SingleBallTreeTest)
+{
+ arma::mat data;
+ data.randu(8, 1000); // 1000 points in 8 dimensions.
+
+ // Set up ball tree range search.
+ typedef BinarySpaceTree<BallBound<>, RangeSearchStat> TreeType;
+ TreeType tree(data);
+ RangeSearch<metric::EuclideanDistance, TreeType>
+ ballsearch(&tree, data, true);
+
+ // Four trials with different ranges.
+ for (size_t r = 0; r < 4; ++r)
+ {
+ // Set up kd-tree range search.
+ RangeSearch<> kdsearch(data, true);
+
+ Range range;
+ switch (r)
+ {
+ case 0:
+ // Includes zero distance.
+ range = Range(0.0, 0.75);
+ break;
+ case 1:
+ // A bounded range on both sides.
+ range = Range(0.5, 1.5);
+ break;
+ case 2:
+ // A range with no upper bound.
+ range = Range(0.8, DBL_MAX);
+ break;
+ case 3:
+ // A range which should have no results.
+ range = Range(15.6, 15.7);
+ break;
+ }
+
+ // Results for kd-tree search.
+ vector<vector<size_t> > kdNeighbors;
+ vector<vector<double> > kdDistances;
+
+ // Results for ball tree search.
+ vector<vector<size_t> > ballNeighbors;
+ vector<vector<double> > ballDistances;
+
+ // Clean the tree statistics.
+ CleanTree(tree);
+
+ // Run the searches.
+ kdsearch.Search(range, kdNeighbors, kdDistances);
+ ballsearch.Search(range, ballNeighbors, ballDistances);
+
+ // Sort before comparison.
+ vector<vector<pair<double, size_t> > > kdSorted;
+ vector<vector<pair<double, size_t> > > ballSorted;
+ SortResults(kdNeighbors, kdDistances, kdSorted);
+ SortResults(ballNeighbors, ballDistances, ballSorted);
+
+ // Now compare the results.
+ for (size_t i = 0; i < kdSorted.size(); ++i)
+ {
+ for (size_t j = 0; j < kdSorted[i].size(); ++j)
+ {
+ BOOST_REQUIRE_EQUAL(kdSorted[i][j].second, ballSorted[i][j].second);
+ BOOST_REQUIRE_CLOSE(kdSorted[i][j].first, ballSorted[i][j].first,
+ 1e-5);
+ }
+ BOOST_REQUIRE_EQUAL(kdSorted[i].size(), ballSorted[i].size());
+ }
+ }
+}
+
+/**
+ * Ensure that dual tree range search with ball trees works by comparing
+ * with the kd-tree implementation.
+ */
+BOOST_AUTO_TEST_CASE(DualBallTreeTest)
+{
+ arma::mat data;
+ data.randu(8, 1000); // 1000 points in 8 dimensions.
+
+ // Set up ball tree range search.
+ typedef BinarySpaceTree<BallBound<>, RangeSearchStat> TreeType;
+ TreeType tree(data);
+ RangeSearch<metric::EuclideanDistance, TreeType> ballsearch(&tree, data);
+
+ // Four trials with different ranges.
+ for (size_t r = 0; r < 4; ++r)
+ {
+ // Set up kd-tree range search.
+ RangeSearch<> kdsearch(data);
+
+ Range range;
+ switch (r)
+ {
+ case 0:
+ // Includes zero distance.
+ range = Range(0.0, 0.75);
+ break;
+ case 1:
+ // A bounded range on both sides.
+ range = Range(0.5, 1.5);
+ break;
+ case 2:
+ // A range with no upper bound.
+ range = Range(0.8, DBL_MAX);
+ break;
+ case 3:
+ // A range which should have no results.
+ range = Range(15.6, 15.7);
+ break;
+ }
+
+ // Results for kd-tree search.
+ vector<vector<size_t> > kdNeighbors;
+ vector<vector<double> > kdDistances;
+
+ // Results for ball tree search.
+ vector<vector<size_t> > ballNeighbors;
+ vector<vector<double> > ballDistances;
+
+ // Clean the tree statistics.
+ CleanTree(tree);
+
+ // Run the searches.
+ kdsearch.Search(range, kdNeighbors, kdDistances);
+ ballsearch.Search(range, ballNeighbors, ballDistances);
+
+ // Sort before comparison.
+ vector<vector<pair<double, size_t> > > kdSorted;
+ vector<vector<pair<double, size_t> > > ballSorted;
+ SortResults(kdNeighbors, kdDistances, kdSorted);
+ SortResults(ballNeighbors, ballDistances, ballSorted);
+
+ // Now compare the results.
+ for (size_t i = 0; i < kdSorted.size(); ++i)
+ {
+ for (size_t j = 0; j < kdSorted[i].size(); ++j)
+ {
+ BOOST_REQUIRE_EQUAL(kdSorted[i][j].second, ballSorted[i][j].second);
+ BOOST_REQUIRE_CLOSE(kdSorted[i][j].first, ballSorted[i][j].first,
+ 1e-5);
+ }
+ BOOST_REQUIRE_EQUAL(kdSorted[i].size(), ballSorted[i].size());
+ }
+ }
+}
+
+/**
+ * Ensure that dual tree range search with ball trees works when using
+ * two datasets.
+ */
+BOOST_AUTO_TEST_CASE(DualBallTreeTest2)
+{
+ arma::mat data;
+ data.randu(8, 1000); // 1000 points in 8 dimensions.
+
+ arma::mat queries;
+ queries.randu(8, 350); // 350 points in 8 dimensions.
+
+ // Set up ball tree range search.
+ typedef BinarySpaceTree<BallBound<>, RangeSearchStat> TreeType;
+ TreeType tree(data);
+ TreeType queryTree(queries);
+ RangeSearch<metric::EuclideanDistance, TreeType>
+ ballsearch(&tree, &queryTree, data, queries);
+
+ // Four trials with different ranges.
+ for (size_t r = 0; r < 4; ++r)
+ {
+ // Set up kd-tree range search. We don't have an easy way to rebuild the
+ // tree, so we'll just reinstantiate it here each loop time.
+ RangeSearch<> kdsearch(data, queries);
+
+ Range range;
+ switch (r)
+ {
+ case 0:
+ // Includes zero distance.
+ range = Range(0.0, 0.75);
+ break;
+ case 1:
+ // A bounded range on both sides.
+ range = Range(0.85, 1.05);
+ break;
+ case 2:
+ // A range with no upper bound.
+ range = Range(1.35, DBL_MAX);
+ break;
+ case 3:
+ // A range which should have no results.
+ range = Range(15.6, 15.7);
+ break;
+ }
+
+ // Results for kd-tree search.
+ vector<vector<size_t> > kdNeighbors;
+ vector<vector<double> > kdDistances;
+
+ // Results for ball tree search.
+ vector<vector<size_t> > ballNeighbors;
+ vector<vector<double> > ballDistances;
+
+ // Clean the trees.
+ CleanTree(tree);
+ CleanTree(queryTree);
+
+ // Run the searches.
+ ballsearch.Search(range, ballNeighbors, ballDistances);
+ kdsearch.Search(range, kdNeighbors, kdDistances);
+
+ // Sort before comparison.
+ vector<vector<pair<double, size_t> > > kdSorted;
+ vector<vector<pair<double, size_t> > > ballSorted;
+ SortResults(kdNeighbors, kdDistances, kdSorted);
+ SortResults(ballNeighbors, ballDistances, ballSorted);
+
+ // Now compare the results.
+ for (size_t i = 0; i < kdSorted.size(); ++i)
+ {
+ for (size_t j = 0; j < kdSorted[i].size(); ++j)
+ {
+ BOOST_REQUIRE_EQUAL(kdSorted[i][j].second, ballSorted[i][j].second);
+ BOOST_REQUIRE_CLOSE(kdSorted[i][j].first, ballSorted[i][j].first,
+ 1e-5);
+ }
+ BOOST_REQUIRE_EQUAL(kdSorted[i].size(), ballSorted[i].size());
+ }
+ }
+}
+
BOOST_AUTO_TEST_SUITE_END();
diff --git a/src/mlpack/tests/tree_test.cpp b/src/mlpack/tests/tree_test.cpp
index 1971185..eaa6c79 100644
--- a/src/mlpack/tests/tree_test.cpp
+++ b/src/mlpack/tests/tree_test.cpp
@@ -1206,7 +1206,7 @@ BOOST_AUTO_TEST_CASE(ParentDistanceTestWithMapping)
// Forward declaration of methods we need for the next test.
template<typename TreeType, typename MatType>
-bool CheckPointBounds(TreeType* node, const MatType& data);
+bool CheckPointBounds(TreeType& node, const MatType& data);
template<typename TreeType>
void GenerateVectorOfTree(TreeType* node,
@@ -1274,7 +1274,7 @@ BOOST_AUTO_TEST_CASE(KdTreeTest)
}
// Now check that each point is contained inside of all bounds above it.
- CheckPointBounds(&root, dataset);
+ CheckPointBounds(root, dataset);
// Now check that no peers overlap.
std::vector<TreeType*> v;
@@ -1309,20 +1309,74 @@ BOOST_AUTO_TEST_CASE(KdTreeTest)
// Recursively checks that each node contains all points that it claims to have.
template<typename TreeType, typename MatType>
-bool CheckPointBounds(TreeType* node, const MatType& data)
+bool CheckPointBounds(TreeType& node, const MatType& data)
{
- if (node == NULL) // We have passed a leaf node.
- return true;
-
- TreeType* left = node->Left();
- TreeType* right = node->Right();
-
// Check that each point which this tree claims is actually inside the tree.
- for (size_t index = 0; index < node->NumDescendants(); index++)
- if (!node->Bound().Contains(data.col(node->Descendant(index))))
+ for (size_t index = 0; index < node.NumDescendants(); index++)
+ if (!node.Bound().Contains(data.col(node.Descendant(index))))
return false;
- return CheckPointBounds(left, data) && CheckPointBounds(right, data);
+ bool result = true;
+ for (size_t child = 0; child < node.NumChildren(); ++child)
+ result &= CheckPointBounds(node.Child(child), data);
+ return result;
+}
+
+/**
+ * Exhaustive ball tree test based on #125.
+ *
+ * - Generate a random dataset of a random size.
+ * - Build a tree on that dataset.
+ * - Ensure all the permutation indices map back to the correct points.
+ * - Verify that each point is contained inside all of the bounds of its parent
+ * nodes.
+ *
+ * Then, we do that whole process a handful of times.
+ */
+BOOST_AUTO_TEST_CASE(BallTreeTest)
+{
+ typedef BinarySpaceTree<BallBound<> > TreeType;
+
+ size_t maxRuns = 10; // Ten total tests.
+ size_t pointIncrements = 1000; // Range is from 1000 points to 10000.
+
+ // We use the default leaf size of 20.
+ for(size_t run = 0; run < maxRuns; run++)
+ {
+ size_t dimensions = run + 2;
+ size_t maxPoints = (run + 1) * pointIncrements;
+
+ size_t size = maxPoints;
+ arma::mat dataset = arma::mat(dimensions, size);
+ arma::mat datacopy; // Used to test mappings.
+
+ // Mappings for post-sort verification of data.
+ std::vector<size_t> newToOld;
+ std::vector<size_t> oldToNew;
+
+ // Generate data.
+ dataset.randu();
+ datacopy = dataset; // Save a copy.
+
+ // Build the tree itself.
+ TreeType root(dataset, newToOld, oldToNew);
+
+ // Ensure the size of the tree is correct.
+ BOOST_REQUIRE_EQUAL(root.NumDescendants(), size);
+
+ // Check the forward and backward mappings for correctness.
+ for(size_t i = 0; i < size; i++)
+ {
+ for(size_t j = 0; j < dimensions; j++)
+ {
+ BOOST_REQUIRE_EQUAL(dataset(j, i), datacopy(j, newToOld[i]));
+ BOOST_REQUIRE_EQUAL(dataset(j, oldToNew[i]), datacopy(j, i));
+ }
+ }
+
+ // Now check that each point is contained inside of all bounds above it.
+ CheckPointBounds(root, dataset);
+ }
}
template<int t_pow>
@@ -1429,7 +1483,7 @@ BOOST_AUTO_TEST_CASE(ExhaustiveSparseKDTreeTest)
}
// Now check that each point is contained inside of all bounds above it.
- CheckPointBounds(&root, dataset);
+ CheckPointBounds(root, dataset);
// Now check that no peers overlap.
std::vector<TreeType*> v;
More information about the mlpack-git
mailing list