[mlpack-git] master, mlpack-1.0.x: Tests for ball trees and BallBound<> by Yash, to solve #250. (3563834)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Mar 5 21:54:17 EST 2015


Repository : https://github.com/mlpack/mlpack

On branches: master,mlpack-1.0.x
Link       : https://github.com/mlpack/mlpack/compare/904762495c039e345beba14c1142fd719b3bd50e...f94823c800ad6f7266995c700b1b630d5ffdcf40

>---------------------------------------------------------------

commit 356383497533b193a2619d59aff04eb0556905de
Author: Yash Vadalia <yashdv at gmail.com>
Date:   Fri Jul 25 15:53:13 2014 +0000

    Tests for ball trees and BallBound<> by Yash, to solve #250.


>---------------------------------------------------------------

356383497533b193a2619d59aff04eb0556905de
 src/mlpack/tests/allkfn_test.cpp          | 149 ++++++++++++++++++
 src/mlpack/tests/allknn_test.cpp          |  73 +++++++++
 src/mlpack/tests/allkrann_search_test.cpp | 138 ++++++++++++++++-
 src/mlpack/tests/range_search_test.cpp    | 248 +++++++++++++++++++++++++++++-
 src/mlpack/tests/tree_test.cpp            |  80 ++++++++--
 5 files changed, 668 insertions(+), 20 deletions(-)

diff --git a/src/mlpack/tests/allkfn_test.cpp b/src/mlpack/tests/allkfn_test.cpp
index 03967f8..921c13a 100644
--- a/src/mlpack/tests/allkfn_test.cpp
+++ b/src/mlpack/tests/allkfn_test.cpp
@@ -5,12 +5,14 @@
  */
 #include <mlpack/core.hpp>
 #include <mlpack/methods/neighbor_search/neighbor_search.hpp>
+#include <mlpack/core/tree/cover_tree.hpp>
 #include <boost/test/unit_test.hpp>
 #include "old_boost_test_definitions.hpp"
 
 using namespace mlpack;
 using namespace mlpack::neighbor;
 using namespace mlpack::tree;
+using namespace mlpack::metric;
 using namespace mlpack::bound;
 
 BOOST_AUTO_TEST_SUITE(AllkFNTest);
@@ -439,4 +441,151 @@ BOOST_AUTO_TEST_CASE(SingleTreeVsNaive)
   }
 }
 
+/**
+ * Test the cover tree single-tree furthest-neighbors method against the naive
+ * method.  This uses only a random reference dataset.
+ *
+ * Errors are produced if the neighbors or distances (within 1e-5%) differ.
+ */
+BOOST_AUTO_TEST_CASE(SingleCoverTreeTest)
+{
+  arma::mat data;
+  data.randu(75, 1000); // 75 dimensional, 1000 points.
+
+  arma::mat naiveQuery(data); // Separate copy for the naive AllkFN search.
+
+  CoverTree<LMetric<2>, FirstPointIsRoot,
+      NeighborSearchStat<FurthestNeighborSort> > tree = CoverTree<LMetric<2>,
+      FirstPointIsRoot, NeighborSearchStat<FurthestNeighborSort> >(data);
+
+  NeighborSearch<FurthestNeighborSort, LMetric<2>, CoverTree<LMetric<2>,
+      FirstPointIsRoot, NeighborSearchStat<FurthestNeighborSort> > >
+      coverTreeSearch(&tree, data, true);
+
+  AllkFN naive(naiveQuery, true);
+
+  arma::Mat<size_t> coverTreeNeighbors;
+  arma::mat coverTreeDistances;
+  coverTreeSearch.Search(15, coverTreeNeighbors, coverTreeDistances);
+
+  arma::Mat<size_t> naiveNeighbors;
+  arma::mat naiveDistances;
+  naive.Search(15, naiveNeighbors, naiveDistances);
+
+  for (size_t i = 0; i < coverTreeNeighbors.n_elem; ++i)
+  {
+    BOOST_REQUIRE_EQUAL(coverTreeNeighbors[i], naiveNeighbors[i]);
+    BOOST_REQUIRE_CLOSE(coverTreeDistances[i], naiveDistances[i], 1e-5);
+  }
+}
+
+/**
+ * Test the cover tree dual-tree furthest neighbors method against the default
+ * AllkFN (kd-tree) search.  Errors are produced if the results differ.
+ */
+BOOST_AUTO_TEST_CASE(DualCoverTreeTest)
+{
+  arma::mat dataset;
+  data::Load("test_data_3_1000.csv", dataset);
+
+  arma::mat kdtreeData(dataset); // Separate copy for the AllkFN search.
+
+  AllkFN tree(kdtreeData);
+
+  arma::Mat<size_t> kdNeighbors;
+  arma::mat kdDistances;
+  tree.Search(5, kdNeighbors, kdDistances);
+
+  typedef CoverTree<LMetric<2, true>, FirstPointIsRoot,
+      NeighborSearchStat<FurthestNeighborSort> > TreeType;
+
+  TreeType referenceTree = TreeType(dataset);
+
+  NeighborSearch<FurthestNeighborSort, LMetric<2, true>,
+      TreeType> coverTreeSearch(&referenceTree, dataset);
+
+  arma::Mat<size_t> coverNeighbors;
+  arma::mat coverDistances;
+  coverTreeSearch.Search(5, coverNeighbors, coverDistances);
+
+  for (size_t i = 0; i < coverNeighbors.n_elem; ++i)
+  {
+    BOOST_REQUIRE_EQUAL(coverNeighbors(i), kdNeighbors(i));
+    BOOST_REQUIRE_CLOSE(coverDistances(i), kdDistances(i), 1e-5);
+  }
+}
+
+/**
+ * Test the ball tree single-tree furthest-neighbors method against the naive
+ * method.  This uses only a random reference dataset.
+ *
+ * Errors are produced if the neighbors or distances (within 1e-5%) differ.
+ */
+BOOST_AUTO_TEST_CASE(SingleBallTreeTest)
+{
+  arma::mat data;
+  data.randu(75, 1000); // 75 dimensional, 1000 points.
+
+  typedef BinarySpaceTree<BallBound<arma::vec, LMetric<2, true> >,
+      NeighborSearchStat<FurthestNeighborSort> > TreeType;
+  TreeType tree = TreeType(data);
+
+  // BinarySpaceTree modifies data.  Copy it only after building the tree so
+  // that both searches index the points in the same order and the results
+  // can be compared element by element.
+  arma::mat naiveQuery(data); // For naive AllkFN.
+
+  NeighborSearch<FurthestNeighborSort, LMetric<2>, TreeType>
+      ballTreeSearch(&tree, data, true);
+
+  AllkFN naive(naiveQuery, true);
+
+  arma::Mat<size_t> ballTreeNeighbors;
+  arma::mat ballTreeDistances;
+  ballTreeSearch.Search(15, ballTreeNeighbors, ballTreeDistances);
+
+  arma::Mat<size_t> naiveNeighbors;
+  arma::mat naiveDistances;
+  naive.Search(15, naiveNeighbors, naiveDistances);
+
+  for (size_t i = 0; i < ballTreeNeighbors.n_elem; ++i)
+  {
+    BOOST_REQUIRE_EQUAL(ballTreeNeighbors[i], naiveNeighbors[i]);
+    BOOST_REQUIRE_CLOSE(ballTreeDistances[i], naiveDistances[i], 1e-5);
+  }
+}
+
+/**
+ * Test the ball tree dual-tree furthest neighbors method against the default
+ * AllkFN (kd-tree) search.  Errors are produced if the results differ.
+ */
+BOOST_AUTO_TEST_CASE(DualBallTreeTest)
+{
+  arma::mat dataset;
+  data::Load("test_data_3_1000.csv", dataset);
+
+  arma::mat kdtreeData(dataset); // Separate copy for the AllkFN search.
+
+  AllkFN tree(kdtreeData);
+
+  arma::Mat<size_t> kdNeighbors;
+  arma::mat kdDistances;
+  tree.Search(5, kdNeighbors, kdDistances);
+
+  NeighborSearch<FurthestNeighborSort, LMetric<2, true>,
+      BinarySpaceTree<BallBound<arma::vec, LMetric<2, true> >,
+      NeighborSearchStat<FurthestNeighborSort> > >
+      ballTreeSearch(dataset);
+
+  arma::Mat<size_t> ballNeighbors;
+  arma::mat ballDistances;
+  ballTreeSearch.Search(5, ballNeighbors, ballDistances);
+
+  for (size_t i = 0; i < ballNeighbors.n_elem; ++i)
+  {
+    BOOST_REQUIRE_EQUAL(ballNeighbors(i), kdNeighbors(i));
+    BOOST_REQUIRE_CLOSE(ballDistances(i), kdDistances(i), 1e-5);
+  }
+}
+
 BOOST_AUTO_TEST_SUITE_END();
diff --git a/src/mlpack/tests/allknn_test.cpp b/src/mlpack/tests/allknn_test.cpp
index 4ea444f..e76df2d 100644
--- a/src/mlpack/tests/allknn_test.cpp
+++ b/src/mlpack/tests/allknn_test.cpp
@@ -663,6 +663,79 @@ BOOST_AUTO_TEST_CASE(DualCoverTreeTest)
   }
 }
 
+/**
+ * Test the ball tree single-tree nearest-neighbors method against the naive
+ * method.  This uses only a random reference dataset.
+ *
+ * Errors are produced if the neighbors or distances (within 1e-5%) differ.
+ */
+BOOST_AUTO_TEST_CASE(SingleBallTreeTest)
+{
+  arma::mat data;
+  data.randu(75, 1000); // 75 dimensional, 1000 points.
+
+  typedef BinarySpaceTree<BallBound<arma::vec, LMetric<2, true> >,
+      NeighborSearchStat<NearestNeighborSort> > TreeType;
+  TreeType tree = TreeType(data);
+
+  // BinarySpaceTree modifies data.  Copy it only after building the tree so
+  // that both searches index the points in the same order and the results
+  // can be compared element by element.
+  arma::mat naiveQuery(data); // For naive AllkNN.
+
+  NeighborSearch<NearestNeighborSort, LMetric<2>, TreeType>
+      ballTreeSearch(&tree, data, true);
+
+  AllkNN naive(naiveQuery, true);
+
+  arma::Mat<size_t> ballTreeNeighbors;
+  arma::mat ballTreeDistances;
+  ballTreeSearch.Search(1, ballTreeNeighbors, ballTreeDistances);
+
+  arma::Mat<size_t> naiveNeighbors;
+  arma::mat naiveDistances;
+  naive.Search(1, naiveNeighbors, naiveDistances);
+
+  for (size_t i = 0; i < ballTreeNeighbors.n_elem; ++i)
+  {
+    BOOST_REQUIRE_EQUAL(ballTreeNeighbors[i], naiveNeighbors[i]);
+    BOOST_REQUIRE_CLOSE(ballTreeDistances[i], naiveDistances[i], 1e-5);
+  }
+}
+
+/**
+ * Test the ball tree dual-tree nearest neighbors method against the default
+ * AllkNN (kd-tree) search.  Errors are produced if the results differ.
+ */
+BOOST_AUTO_TEST_CASE(DualBallTreeTest)
+{
+  arma::mat dataset;
+  data::Load("test_data_3_1000.csv", dataset);
+
+  arma::mat kdtreeData(dataset); // Separate copy for the AllkNN search.
+
+  AllkNN tree(kdtreeData);
+
+  arma::Mat<size_t> kdNeighbors;
+  arma::mat kdDistances;
+  tree.Search(5, kdNeighbors, kdDistances);
+
+  NeighborSearch<NearestNeighborSort, LMetric<2, true>,
+      BinarySpaceTree<BallBound<arma::vec, LMetric<2, true> >,
+      NeighborSearchStat<NearestNeighborSort> > >
+      ballTreeSearch(dataset);
+
+  arma::Mat<size_t> ballNeighbors;
+  arma::mat ballDistances;
+  ballTreeSearch.Search(5, ballNeighbors, ballDistances);
+
+  for (size_t i = 0; i < ballNeighbors.n_elem; ++i)
+  {
+    BOOST_REQUIRE_EQUAL(ballNeighbors(i), kdNeighbors(i));
+    BOOST_REQUIRE_CLOSE(ballDistances(i), kdDistances(i), 1e-5);
+  }
+}
+
 // Make sure sparse nearest neighbors works with kd trees.
 BOOST_AUTO_TEST_CASE(SparseAllkNNKDTreeTest)
 {
diff --git a/src/mlpack/tests/allkrann_search_test.cpp b/src/mlpack/tests/allkrann_search_test.cpp
index 76da72b..ab86d3d 100644
--- a/src/mlpack/tests/allkrann_search_test.cpp
+++ b/src/mlpack/tests/allkrann_search_test.cpp
@@ -17,6 +17,9 @@
 using namespace std;
 using namespace mlpack;
 using namespace mlpack::neighbor;
+using namespace mlpack::tree;
+using namespace mlpack::metric;
+using namespace mlpack::bound;
 
 BOOST_AUTO_TEST_SUITE(AllkRANNTest);
 
@@ -344,7 +347,7 @@ BOOST_AUTO_TEST_CASE(SingleCoverTreeTest)
   typedef RASearch<NearestNeighborSort, metric::EuclideanDistance, TreeType>
       RACoverTreeSearch;
 
-  RACoverTreeSearch tssRann(refData, queryData, true);
+  RACoverTreeSearch tssRann(refData, queryData, false, true);
 
   // The relative ranks for the given query reference pair.
   arma::Mat<size_t> qrRanks;
@@ -455,4 +458,137 @@ BOOST_AUTO_TEST_CASE(DualCoverTreeTest)
   BOOST_REQUIRE_LT(numQueriesFail, maxNumQueriesFail);
 }
 
+// Test single-tree rank-approximate search with ball trees.
+// Known not to work right now; this test and the next one are disabled.
+/*
+BOOST_AUTO_TEST_CASE(SingleBallTreeTest)
+{
+  arma::mat refData;
+  arma::mat queryData;
+
+  data::Load("rann_test_r_3_900.csv", refData, true);
+  data::Load("rann_test_q_3_100.csv", queryData, true);
+
+  // Search for 1 rank-approximate nearest neighbor in the top 1% of the
+  // points (rank error of 9).
+  arma::Mat<size_t> neighbors;
+  arma::mat distances;
+
+  typedef BinarySpaceTree<BallBound<>, RAQueryStat<NearestNeighborSort> >
+      TreeType;
+  typedef RASearch<NearestNeighborSort, metric::EuclideanDistance, TreeType>
+      RABallTreeSearch;
+
+  RABallTreeSearch tssRann(refData, queryData, false, true);
+
+  // The relative ranks for the given query reference pair.
+  arma::Mat<size_t> qrRanks;
+  data::Load("rann_test_qr_ranks.csv", qrRanks, true, false); // No transpose.
+
+  size_t numRounds = 30;
+  arma::Col<size_t> numSuccessRounds(queryData.n_cols);
+  numSuccessRounds.fill(0);
+
+  // 1% of 900 is 9, so the rank is expected to be less than 10.
+  size_t expectedRankErrorUB = 10;
+
+  for (size_t rounds = 0; rounds < numRounds; rounds++)
+  {
+    tssRann.Search(1, neighbors, distances, 1.0, 0.95, false, false, 5);
+
+    for (size_t i = 0; i < queryData.n_cols; i++)
+      if (qrRanks(i, neighbors(0, i)) < expectedRankErrorUB)
+        numSuccessRounds[i]++;
+
+    neighbors.reset();
+    distances.reset();
+  }
+
+  // Find the 95%-tile threshold so that 95% of the queries should pass this
+  // threshold.
+  size_t threshold = floor(numRounds *
+      (0.95 - (1.96 * sqrt(0.95 * 0.05 / numRounds))));
+  size_t numQueriesFail = 0;
+  for (size_t i = 0; i < queryData.n_cols; i++)
+    if (numSuccessRounds[i] < threshold)
+      numQueriesFail++;
+
+  Log::Warn << "RANN-TSS (ball tree): RANN guarantee fails on "
+      << numQueriesFail << " queries." << endl;
+
+  // Assert that at most 5% of the queries fall out of this threshold.
+  // 5% of 100 queries is 5.
+  size_t maxNumQueriesFail = 6;
+
+  BOOST_REQUIRE_LT(numQueriesFail, maxNumQueriesFail);
+}
+
+// Test dual-tree rank-approximate search with Ball trees.
+BOOST_AUTO_TEST_CASE(DualBallTreeTest)
+{
+  arma::mat refData;
+  arma::mat queryData;
+
+  data::Load("rann_test_r_3_900.csv", refData, true);
+  data::Load("rann_test_q_3_100.csv", queryData, true);
+
+  // Search for 1 rank-approximate nearest neighbor in the top 1% of the
+  // points (rank error of 9).
+  arma::Mat<size_t> neighbors;
+  arma::mat distances;
+
+  typedef BinarySpaceTree<BallBound<>, RAQueryStat<NearestNeighborSort> >
+    TreeType;
+  typedef RASearch<NearestNeighborSort, metric::EuclideanDistance, TreeType>
+      RABallTreeSearch;
+
+  TreeType refTree(refData);
+  TreeType queryTree(queryData);
+
+  RABallTreeSearch tsdRann(&refTree, &queryTree, refData, queryData, false);
+
+  arma::Mat<size_t> qrRanks;
+  data::Load("rann_test_qr_ranks.csv", qrRanks, true, false); // No transpose.
+
+  size_t numRounds = 1000;
+  arma::Col<size_t> numSuccessRounds(queryData.n_cols);
+  numSuccessRounds.fill(0);
+
+  // 1% of 900 is 9, so the rank is expected to be less than 10.
+  size_t expectedRankErrorUB = 10;
+
+  for (size_t rounds = 0; rounds < numRounds; rounds++)
+  {
+    tsdRann.Search(1, neighbors, distances, 1.0, 0.95, false, false, 5);
+
+    for (size_t i = 0; i < queryData.n_cols; i++)
+      if (qrRanks(i, neighbors(0, i)) < expectedRankErrorUB)
+        numSuccessRounds[i]++;
+
+    neighbors.reset();
+    distances.reset();
+
+    tsdRann.ResetQueryTree();
+  }
+
+  // Find the 95%-tile threshold so that 95% of the queries should pass this
+  // threshold.
+  size_t threshold = floor(numRounds *
+      (0.95 - (1.96 * sqrt(0.95 * 0.05 / numRounds))));
+  size_t numQueriesFail = 0;
+  for (size_t i = 0; i < queryData.n_cols; i++)
+    if (numSuccessRounds[i] < threshold)
+      numQueriesFail++;
+
+  Log::Warn << "RANN-TSD (Ball tree): RANN guarantee fails on "
+      << numQueriesFail << " queries." << endl;
+
+  // assert that at most 5% of the queries fall out of this threshold
+  // 5% of 100 queries is 5.
+  size_t maxNumQueriesFail = 6;
+
+  BOOST_REQUIRE_LT(numQueriesFail, maxNumQueriesFail);
+}
+*/
+
 BOOST_AUTO_TEST_SUITE_END();
diff --git a/src/mlpack/tests/range_search_test.cpp b/src/mlpack/tests/range_search_test.cpp
index 3af8380..4002d44 100644
--- a/src/mlpack/tests/range_search_test.cpp
+++ b/src/mlpack/tests/range_search_test.cpp
@@ -441,7 +441,7 @@ BOOST_AUTO_TEST_CASE(ExhaustiveSyntheticTest)
 }
 
 /**
- * Test the dual-tree nearest-neighbors method with the naive method.  This
+ * Test the dual-tree range search method with the naive method.  This
  * uses both a query and reference dataset.
  *
  * Errors are produced if the results are not identical.
@@ -490,7 +490,7 @@ BOOST_AUTO_TEST_CASE(DualTreeVsNaive1)
 }
 
 /**
- * Test the dual-tree nearest-neighbors method with the naive method.  This uses
+ * Test the dual-tree range search method with the naive method.  This uses
  * only a reference dataset.
  *
  * Errors are produced if the results are not identical.
@@ -539,7 +539,7 @@ BOOST_AUTO_TEST_CASE(DualTreeVsNaive2)
 }
 
 /**
- * Test the single-tree nearest-neighbors method with the naive method.  This
+ * Test the single-tree range search method with the naive method.  This
  * uses only a reference dataset.
  *
  * Errors are produced if the results are not identical.
@@ -588,8 +588,8 @@ BOOST_AUTO_TEST_CASE(SingleTreeVsNaive)
 }
 
 /**
- * Ensure that range search with cover trees works by comparing with the kd-tree
- * implementation.
+ * Ensure that dual tree range search with cover trees works by comparing 
+ * with the kd-tree implementation.
  */
 BOOST_AUTO_TEST_CASE(CoverTreeTest)
 {
@@ -666,7 +666,8 @@ BOOST_AUTO_TEST_CASE(CoverTreeTest)
 }
 
 /**
- * Ensure that range search with cover trees works when using two datasets.
+ * Ensure that dual tree range search with cover trees works when using
+ * two datasets.
  */
 BOOST_AUTO_TEST_CASE(CoverTreeTwoDatasetsTest)
 {
@@ -824,4 +825,239 @@ BOOST_AUTO_TEST_CASE(CoverTreeSingleTreeTest)
   }
 }
 
+/**
+ * Ensure that single-tree ball tree range search agrees with RangeSearch<>.
+ */
+BOOST_AUTO_TEST_CASE(SingleBallTreeTest)
+{
+  arma::mat data;
+  data.randu(8, 1000); // 1000 points in 8 dimensions.
+
+  // Set up ball tree range search.
+  typedef BinarySpaceTree<BallBound<>, RangeSearchStat> TreeType;
+  TreeType tree(data);
+  RangeSearch<metric::EuclideanDistance, TreeType>
+      ballsearch(&tree, data, true);
+
+  // Four trials with different ranges.
+  for (size_t r = 0; r < 4; ++r)
+  {
+    // Set up kd-tree range search.
+    RangeSearch<> kdsearch(data, true);
+
+    Range range;
+    switch (r)
+    {
+      case 0:
+        // Includes zero distance.
+        range = Range(0.0, 0.75);
+        break;
+      case 1:
+        // A bounded range on both sides.
+        range = Range(0.5, 1.5);
+        break;
+      case 2:
+        // A range with no upper bound.
+        range = Range(0.8, DBL_MAX);
+        break;
+      case 3:
+        // A range which should have no results.
+        range = Range(15.6, 15.7);
+        break;
+    }
+
+    // Results for kd-tree search.
+    vector<vector<size_t> > kdNeighbors;
+    vector<vector<double> > kdDistances;
+
+    // Results for ball tree search.
+    vector<vector<size_t> > ballNeighbors;
+    vector<vector<double> > ballDistances;
+
+    // Clean the tree statistics.
+    CleanTree(tree);
+
+    // Run the searches.
+    kdsearch.Search(range, kdNeighbors, kdDistances);
+    ballsearch.Search(range, ballNeighbors, ballDistances);
+
+    // Sort before comparison.
+    vector<vector<pair<double, size_t> > > kdSorted;
+    vector<vector<pair<double, size_t> > > ballSorted;
+    SortResults(kdNeighbors, kdDistances, kdSorted);
+    SortResults(ballNeighbors, ballDistances, ballSorted);
+
+    // Compare the results; sizes are checked first so indexing stays valid.
+    BOOST_REQUIRE_EQUAL(kdSorted.size(), ballSorted.size());
+    for (size_t i = 0; i < kdSorted.size(); ++i)
+    {
+      BOOST_REQUIRE_EQUAL(kdSorted[i].size(), ballSorted[i].size());
+      for (size_t j = 0; j < kdSorted[i].size(); ++j)
+      {
+        BOOST_REQUIRE_EQUAL(kdSorted[i][j].second, ballSorted[i][j].second);
+        BOOST_REQUIRE_CLOSE(kdSorted[i][j].first, ballSorted[i][j].first, 1e-5);
+      }
+    }
+  }
+}
+
+/**
+ * Ensure that dual tree range search with ball trees works by comparing
+ * with the kd-tree implementation.
+ */
+BOOST_AUTO_TEST_CASE(DualBallTreeTest)
+{
+  arma::mat data;
+  data.randu(8, 1000); // 1000 points in 8 dimensions.
+
+  // Set up ball tree range search.
+  typedef BinarySpaceTree<BallBound<>, RangeSearchStat> TreeType;
+  TreeType tree(data);
+  RangeSearch<metric::EuclideanDistance, TreeType> ballsearch(&tree, data);
+
+  // Four trials with different ranges.
+  for (size_t r = 0; r < 4; ++r)
+  {
+    // Set up kd-tree range search.
+    RangeSearch<> kdsearch(data);
+
+    Range range;
+    switch (r)
+    {
+      case 0:
+        // Includes zero distance.
+        range = Range(0.0, 0.75);
+        break;
+      case 1:
+        // A bounded range on both sides.
+        range = Range(0.5, 1.5);
+        break;
+      case 2:
+        // A range with no upper bound.
+        range = Range(0.8, DBL_MAX);
+        break;
+      case 3:
+        // A range which should have no results.
+        range = Range(15.6, 15.7);
+        break;
+    }
+
+    // Results for kd-tree search.
+    vector<vector<size_t> > kdNeighbors;
+    vector<vector<double> > kdDistances;
+
+    // Results for ball tree search.
+    vector<vector<size_t> > ballNeighbors;
+    vector<vector<double> > ballDistances;
+
+    // Clean the tree statistics.
+    CleanTree(tree);
+
+    // Run the searches.
+    kdsearch.Search(range, kdNeighbors, kdDistances);
+    ballsearch.Search(range, ballNeighbors, ballDistances);
+
+    // Sort before comparison.
+    vector<vector<pair<double, size_t> > > kdSorted;
+    vector<vector<pair<double, size_t> > > ballSorted;
+    SortResults(kdNeighbors, kdDistances, kdSorted);
+    SortResults(ballNeighbors, ballDistances, ballSorted);
+
+    // Compare the results; sizes are checked first so indexing stays valid.
+    BOOST_REQUIRE_EQUAL(kdSorted.size(), ballSorted.size());
+    for (size_t i = 0; i < kdSorted.size(); ++i)
+    {
+      BOOST_REQUIRE_EQUAL(kdSorted[i].size(), ballSorted[i].size());
+      for (size_t j = 0; j < kdSorted[i].size(); ++j)
+      {
+        BOOST_REQUIRE_EQUAL(kdSorted[i][j].second, ballSorted[i][j].second);
+        BOOST_REQUIRE_CLOSE(kdSorted[i][j].first, ballSorted[i][j].first, 1e-5);
+      }
+    }
+  }
+}
+
+/**
+ * Ensure that dual tree range search with ball trees works when using
+ * two datasets.
+ */
+BOOST_AUTO_TEST_CASE(DualBallTreeTest2)
+{
+  arma::mat data;
+  data.randu(8, 1000); // 1000 points in 8 dimensions.
+
+  arma::mat queries;
+  queries.randu(8, 350); // 350 points in 8 dimensions.
+
+  // Set up ball tree range search.
+  typedef BinarySpaceTree<BallBound<>, RangeSearchStat> TreeType;
+  TreeType tree(data);
+  TreeType queryTree(queries);
+  RangeSearch<metric::EuclideanDistance, TreeType>
+      ballsearch(&tree, &queryTree, data, queries);
+
+  // Four trials with different ranges.
+  for (size_t r = 0; r < 4; ++r)
+  {
+    // Set up kd-tree range search.  We don't have an easy way to rebuild the
+    // tree, so we'll just reinstantiate it here each loop time.
+    RangeSearch<> kdsearch(data, queries);
+
+    Range range;
+    switch (r)
+    {
+      case 0:
+        // Includes zero distance.
+        range = Range(0.0, 0.75);
+        break;
+      case 1:
+        // A bounded range on both sides.
+        range = Range(0.85, 1.05);
+        break;
+      case 2:
+        // A range with no upper bound.
+        range = Range(1.35, DBL_MAX);
+        break;
+      case 3:
+        // A range which should have no results.
+        range = Range(15.6, 15.7);
+        break;
+    }
+
+    // Results for kd-tree search.
+    vector<vector<size_t> > kdNeighbors;
+    vector<vector<double> > kdDistances;
+
+    // Results for ball tree search.
+    vector<vector<size_t> > ballNeighbors;
+    vector<vector<double> > ballDistances;
+
+    // Clean the trees.
+    CleanTree(tree);
+    CleanTree(queryTree);
+
+    // Run the searches.
+    ballsearch.Search(range, ballNeighbors, ballDistances);
+    kdsearch.Search(range, kdNeighbors, kdDistances);
+
+    // Sort before comparison.
+    vector<vector<pair<double, size_t> > > kdSorted;
+    vector<vector<pair<double, size_t> > > ballSorted;
+    SortResults(kdNeighbors, kdDistances, kdSorted);
+    SortResults(ballNeighbors, ballDistances, ballSorted);
+
+    // Compare the results; sizes are checked first so indexing stays valid.
+    BOOST_REQUIRE_EQUAL(kdSorted.size(), ballSorted.size());
+    for (size_t i = 0; i < kdSorted.size(); ++i)
+    {
+      BOOST_REQUIRE_EQUAL(kdSorted[i].size(), ballSorted[i].size());
+      for (size_t j = 0; j < kdSorted[i].size(); ++j)
+      {
+        BOOST_REQUIRE_EQUAL(kdSorted[i][j].second, ballSorted[i][j].second);
+        BOOST_REQUIRE_CLOSE(kdSorted[i][j].first, ballSorted[i][j].first, 1e-5);
+      }
+    }
+  }
+}
+
 BOOST_AUTO_TEST_SUITE_END();
diff --git a/src/mlpack/tests/tree_test.cpp b/src/mlpack/tests/tree_test.cpp
index 1971185..eaa6c79 100644
--- a/src/mlpack/tests/tree_test.cpp
+++ b/src/mlpack/tests/tree_test.cpp
@@ -1206,7 +1206,7 @@ BOOST_AUTO_TEST_CASE(ParentDistanceTestWithMapping)
 
 // Forward declaration of methods we need for the next test.
 template<typename TreeType, typename MatType>
-bool CheckPointBounds(TreeType* node, const MatType& data);
+bool CheckPointBounds(TreeType& node, const MatType& data);
 
 template<typename TreeType>
 void GenerateVectorOfTree(TreeType* node,
@@ -1274,7 +1274,7 @@ BOOST_AUTO_TEST_CASE(KdTreeTest)
     }
 
     // Now check that each point is contained inside of all bounds above it.
-    CheckPointBounds(&root, dataset);
+    CheckPointBounds(root, dataset);
 
     // Now check that no peers overlap.
     std::vector<TreeType*> v;
@@ -1309,20 +1309,74 @@ BOOST_AUTO_TEST_CASE(KdTreeTest)
 
 // Recursively checks that each node contains all points that it claims to have.
 template<typename TreeType, typename MatType>
-bool CheckPointBounds(TreeType* node, const MatType& data)
+bool CheckPointBounds(TreeType& node, const MatType& data)
 {
-  if (node == NULL) // We have passed a leaf node.
-    return true;
-
-  TreeType* left = node->Left();
-  TreeType* right = node->Right();
-
   // Check that each point which this tree claims is actually inside the tree.
-  for (size_t index = 0; index < node->NumDescendants(); index++)
-    if (!node->Bound().Contains(data.col(node->Descendant(index))))
+  for (size_t i = 0; i < node.NumDescendants(); ++i)
+    if (!node.Bound().Contains(data.col(node.Descendant(i))))
       return false;
 
-  return CheckPointBounds(left, data) && CheckPointBounds(right, data);
+  for (size_t child = 0; child < node.NumChildren(); ++child)
+    if (!CheckPointBounds(node.Child(child), data))
+      return false; // Stop at the first subtree whose bounds are violated.
+  return true;
+}
+
+/**
+ * Exhaustive ball tree test based on #125.
+ *
+ * - Generate a uniform random dataset of increasing size and dimension.
+ * - Build a tree on that dataset.
+ * - Ensure all the permutation indices map back to the correct points.
+ * - Verify that each point is contained inside all of the bounds of its parent
+ *     nodes.
+ *
+ * Then, we do that whole process ten times.
+ */
+BOOST_AUTO_TEST_CASE(BallTreeTest)
+{
+  typedef BinarySpaceTree<BallBound<> > TreeType;
+
+  size_t maxRuns = 10; // Ten total tests.
+  size_t pointIncrements = 1000; // Range is from 1000 points to 10000.
+
+  // We use the default leaf size of 20.
+  for (size_t run = 0; run < maxRuns; run++)
+  {
+    size_t dimensions = run + 2; // From 2 to 11 dimensions.
+    size_t maxPoints = (run + 1) * pointIncrements;
+
+    size_t size = maxPoints;
+    arma::mat dataset = arma::mat(dimensions, size);
+    arma::mat datacopy; // Used to test mappings.
+
+    // Mappings for post-sort verification of data.
+    std::vector<size_t> newToOld;
+    std::vector<size_t> oldToNew;
+
+    // Generate data.
+    dataset.randu();
+    datacopy = dataset; // Save a copy.
+
+    // Build the tree itself.
+    TreeType root(dataset, newToOld, oldToNew);
+
+    // Ensure the size of the tree is correct.
+    BOOST_REQUIRE_EQUAL(root.NumDescendants(), size);
+
+    // Check the forward and backward mappings for correctness.
+    for (size_t i = 0; i < size; i++)
+    {
+      for (size_t j = 0; j < dimensions; j++)
+      {
+        BOOST_REQUIRE_EQUAL(dataset(j, i), datacopy(j, newToOld[i]));
+        BOOST_REQUIRE_EQUAL(dataset(j, oldToNew[i]), datacopy(j, i));
+      }
+    }
+
+    // Now check that each point is contained inside of all bounds above it.
+    BOOST_REQUIRE(CheckPointBounds(root, dataset));
+  }
 }
 
 template<int t_pow>
@@ -1429,7 +1483,7 @@ BOOST_AUTO_TEST_CASE(ExhaustiveSparseKDTreeTest)
     }
 
     // Now check that each point is contained inside of all bounds above it.
-    CheckPointBounds(&root, dataset);
+    CheckPointBounds(root, dataset);
 
     // Now check that no peers overlap.
     std::vector<TreeType*> v;



More information about the mlpack-git mailing list