[mlpack-svn] r15043 - mlpack/trunk/src/mlpack/tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed May 8 22:22:54 EDT 2013
Author: rcurtin
Date: 2013-05-08 22:22:54 -0400 (Wed, 08 May 2013)
New Revision: 15043
Modified:
mlpack/trunk/src/mlpack/tests/range_search_test.cpp
Log:
Comprehensive test for range search with cover trees.
Modified: mlpack/trunk/src/mlpack/tests/range_search_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/range_search_test.cpp 2013-05-09 02:22:41 UTC (rev 15042)
+++ mlpack/trunk/src/mlpack/tests/range_search_test.cpp 2013-05-09 02:22:54 UTC (rev 15043)
@@ -6,6 +6,7 @@
*/
#include <mlpack/core.hpp>
#include <mlpack/methods/range_search/range_search.hpp>
+#include <mlpack/core/tree/cover_tree.hpp>
#include <boost/test/unit_test.hpp>
#include "old_boost_test_definitions.hpp"
@@ -66,72 +67,69 @@
arma::mat dataMutable = data;
switch (i)
{
- case 0: // Use the dual-tree method.
- rs = new RangeSearch<>(dataMutable, false, false, 1);
+ case 0: // Use the naive method.
+ rs = new RangeSearch<>(dataMutable, true);
break;
case 1: // Use the single-tree method.
rs = new RangeSearch<>(dataMutable, false, true, 1);
break;
- case 2: // Use the naive method.
- rs = new RangeSearch<>(dataMutable, true);
+ case 2: // Use the dual-tree method.
+ rs = new RangeSearch<>(dataMutable, false, false, 1);
break;
}
// Now perform the first calculation. Points within 0.50.
vector<vector<size_t> > neighbors;
vector<vector<double> > distances;
- rs->Search(Range(0.0, 0.50), neighbors, distances);
+ rs->Search(Range(0.0, sqrt(0.5)), neighbors, distances);
- // Now the exhaustive check for correctness. This will be long. We must
- // also remember that the distances returned are squared distances. As a
- // result, distance comparisons are written out as (distance * distance) for
- // readability.
+ // Now the exhaustive check for correctness. This will be long.
vector<vector<pair<double, size_t> > > sortedOutput;
SortResults(neighbors, distances, sortedOutput);
// Neighbors of point 0.
BOOST_REQUIRE(sortedOutput[0].size() == 4);
BOOST_REQUIRE(sortedOutput[0][0].second == 2);
- BOOST_REQUIRE_CLOSE(sortedOutput[0][0].first, (0.10 * 0.10), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[0][0].first, 0.10, 1e-5);
BOOST_REQUIRE(sortedOutput[0][1].second == 5);
- BOOST_REQUIRE_CLOSE(sortedOutput[0][1].first, (0.27 * 0.27), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[0][1].first, 0.27, 1e-5);
BOOST_REQUIRE(sortedOutput[0][2].second == 1);
- BOOST_REQUIRE_CLOSE(sortedOutput[0][2].first, (0.30 * 0.30), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[0][2].first, 0.30, 1e-5);
BOOST_REQUIRE(sortedOutput[0][3].second == 8);
- BOOST_REQUIRE_CLOSE(sortedOutput[0][3].first, (0.40 * 0.40), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[0][3].first, 0.40, 1e-5);
// Neighbors of point 1.
BOOST_REQUIRE(sortedOutput[1].size() == 6);
BOOST_REQUIRE(sortedOutput[1][0].second == 8);
- BOOST_REQUIRE_CLOSE(sortedOutput[1][0].first, (0.10 * 0.10), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[1][0].first, 0.10, 1e-5);
BOOST_REQUIRE(sortedOutput[1][1].second == 2);
- BOOST_REQUIRE_CLOSE(sortedOutput[1][1].first, (0.20 * 0.20), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[1][1].first, 0.20, 1e-5);
BOOST_REQUIRE(sortedOutput[1][2].second == 0);
- BOOST_REQUIRE_CLOSE(sortedOutput[1][2].first, (0.30 * 0.30), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[1][2].first, 0.30, 1e-5);
BOOST_REQUIRE(sortedOutput[1][3].second == 9);
- BOOST_REQUIRE_CLOSE(sortedOutput[1][3].first, (0.55 * 0.55), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[1][3].first, 0.55, 1e-5);
BOOST_REQUIRE(sortedOutput[1][4].second == 5);
- BOOST_REQUIRE_CLOSE(sortedOutput[1][4].first, (0.57 * 0.57), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[1][4].first, 0.57, 1e-5);
BOOST_REQUIRE(sortedOutput[1][5].second == 10);
- BOOST_REQUIRE_CLOSE(sortedOutput[1][5].first, (0.65 * 0.65), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[1][5].first, 0.65, 1e-5);
// Neighbors of point 2.
BOOST_REQUIRE(sortedOutput[2].size() == 4);
BOOST_REQUIRE(sortedOutput[2][0].second == 0);
- BOOST_REQUIRE_CLOSE(sortedOutput[2][0].first, (0.10 * 0.10), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[2][0].first, 0.10, 1e-5);
BOOST_REQUIRE(sortedOutput[2][1].second == 1);
- BOOST_REQUIRE_CLOSE(sortedOutput[2][1].first, (0.20 * 0.20), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[2][1].first, 0.20, 1e-5);
BOOST_REQUIRE(sortedOutput[2][2].second == 8);
- BOOST_REQUIRE_CLOSE(sortedOutput[2][2].first, (0.30 * 0.30), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[2][2].first, 0.30, 1e-5);
BOOST_REQUIRE(sortedOutput[2][3].second == 5);
- BOOST_REQUIRE_CLOSE(sortedOutput[2][3].first, (0.37 * 0.37), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[2][3].first, 0.37, 1e-5);
// Neighbors of point 3.
BOOST_REQUIRE(sortedOutput[3].size() == 2);
BOOST_REQUIRE(sortedOutput[3][0].second == 10);
- BOOST_REQUIRE_CLOSE(sortedOutput[3][0].first, (0.25 * 0.25), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[3][0].first, 0.25, 1e-5);
BOOST_REQUIRE(sortedOutput[3][1].second == 9);
- BOOST_REQUIRE_CLOSE(sortedOutput[3][1].first, (0.35 * 0.35), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[3][1].first, 0.35, 1e-5);
// Neighbors of point 4.
BOOST_REQUIRE(sortedOutput[4].size() == 0);
@@ -139,90 +137,90 @@
// Neighbors of point 5.
BOOST_REQUIRE(sortedOutput[5].size() == 4);
BOOST_REQUIRE(sortedOutput[5][0].second == 0);
- BOOST_REQUIRE_CLOSE(sortedOutput[5][0].first, (0.27 * 0.27), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[5][0].first, 0.27, 1e-5);
BOOST_REQUIRE(sortedOutput[5][1].second == 2);
- BOOST_REQUIRE_CLOSE(sortedOutput[5][1].first, (0.37 * 0.37), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[5][1].first, 0.37, 1e-5);
BOOST_REQUIRE(sortedOutput[5][2].second == 1);
- BOOST_REQUIRE_CLOSE(sortedOutput[5][2].first, (0.57 * 0.57), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[5][2].first, 0.57, 1e-5);
BOOST_REQUIRE(sortedOutput[5][3].second == 8);
- BOOST_REQUIRE_CLOSE(sortedOutput[5][3].first, (0.67 * 0.67), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[5][3].first, 0.67, 1e-5);
// Neighbors of point 6.
BOOST_REQUIRE(sortedOutput[6].size() == 1);
BOOST_REQUIRE(sortedOutput[6][0].second == 7);
- BOOST_REQUIRE_CLOSE(sortedOutput[6][0].first, (0.70 * 0.70), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[6][0].first, 0.70, 1e-5);
// Neighbors of point 7.
BOOST_REQUIRE(sortedOutput[7].size() == 1);
BOOST_REQUIRE(sortedOutput[7][0].second == 6);
- BOOST_REQUIRE_CLOSE(sortedOutput[7][0].first, (0.70 * 0.70), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[7][0].first, 0.70, 1e-5);
// Neighbors of point 8.
BOOST_REQUIRE(sortedOutput[8].size() == 6);
BOOST_REQUIRE(sortedOutput[8][0].second == 1);
- BOOST_REQUIRE_CLOSE(sortedOutput[8][0].first, (0.10 * 0.10), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[8][0].first, 0.10, 1e-5);
BOOST_REQUIRE(sortedOutput[8][1].second == 2);
- BOOST_REQUIRE_CLOSE(sortedOutput[8][1].first, (0.30 * 0.30), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[8][1].first, 0.30, 1e-5);
BOOST_REQUIRE(sortedOutput[8][2].second == 0);
- BOOST_REQUIRE_CLOSE(sortedOutput[8][2].first, (0.40 * 0.40), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[8][2].first, 0.40, 1e-5);
BOOST_REQUIRE(sortedOutput[8][3].second == 9);
- BOOST_REQUIRE_CLOSE(sortedOutput[8][3].first, (0.45 * 0.45), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[8][3].first, 0.45, 1e-5);
BOOST_REQUIRE(sortedOutput[8][4].second == 10);
- BOOST_REQUIRE_CLOSE(sortedOutput[8][4].first, (0.55 * 0.55), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[8][4].first, 0.55, 1e-5);
BOOST_REQUIRE(sortedOutput[8][5].second == 5);
- BOOST_REQUIRE_CLOSE(sortedOutput[8][5].first, (0.67 * 0.67), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[8][5].first, 0.67, 1e-5);
// Neighbors of point 9.
BOOST_REQUIRE(sortedOutput[9].size() == 4);
BOOST_REQUIRE(sortedOutput[9][0].second == 10);
- BOOST_REQUIRE_CLOSE(sortedOutput[9][0].first, (0.10 * 0.10), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[9][0].first, 0.10, 1e-5);
BOOST_REQUIRE(sortedOutput[9][1].second == 3);
- BOOST_REQUIRE_CLOSE(sortedOutput[9][1].first, (0.35 * 0.35), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[9][1].first, 0.35, 1e-5);
BOOST_REQUIRE(sortedOutput[9][2].second == 8);
- BOOST_REQUIRE_CLOSE(sortedOutput[9][2].first, (0.45 * 0.45), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[9][2].first, 0.45, 1e-5);
BOOST_REQUIRE(sortedOutput[9][3].second == 1);
- BOOST_REQUIRE_CLOSE(sortedOutput[9][3].first, (0.55 * 0.55), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[9][3].first, 0.55, 1e-5);
// Neighbors of point 10.
BOOST_REQUIRE(sortedOutput[10].size() == 4);
BOOST_REQUIRE(sortedOutput[10][0].second == 9);
- BOOST_REQUIRE_CLOSE(sortedOutput[10][0].first, (0.10 * 0.10), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[10][0].first, 0.10, 1e-5);
BOOST_REQUIRE(sortedOutput[10][1].second == 3);
- BOOST_REQUIRE_CLOSE(sortedOutput[10][1].first, (0.25 * 0.25), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[10][1].first, 0.25, 1e-5);
BOOST_REQUIRE(sortedOutput[10][2].second == 8);
- BOOST_REQUIRE_CLOSE(sortedOutput[10][2].first, (0.55 * 0.55), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[10][2].first, 0.55, 1e-5);
BOOST_REQUIRE(sortedOutput[10][3].second == 1);
- BOOST_REQUIRE_CLOSE(sortedOutput[10][3].first, (0.65 * 0.65), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[10][3].first, 0.65, 1e-5);
- // Now do it again with a different range: [0.5 1.0].
- rs->Search(Range(0.5, 1.0), neighbors, distances);
+ // Now do it again with a different range: [sqrt(0.5) 1.0].
+ rs->Search(Range(sqrt(0.5), 1.0), neighbors, distances);
SortResults(neighbors, distances, sortedOutput);
// Neighbors of point 0.
BOOST_REQUIRE(sortedOutput[0].size() == 2);
BOOST_REQUIRE(sortedOutput[0][0].second == 9);
- BOOST_REQUIRE_CLOSE(sortedOutput[0][0].first, (0.85 * 0.85), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[0][0].first, 0.85, 1e-5);
BOOST_REQUIRE(sortedOutput[0][1].second == 10);
- BOOST_REQUIRE_CLOSE(sortedOutput[0][1].first, (0.95 * 0.95), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[0][1].first, 0.95, 1e-5);
// Neighbors of point 1.
BOOST_REQUIRE(sortedOutput[1].size() == 1);
BOOST_REQUIRE(sortedOutput[1][0].second == 3);
- BOOST_REQUIRE_CLOSE(sortedOutput[1][0].first, (0.90 * 0.90), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[1][0].first, 0.90, 1e-5);
// Neighbors of point 2.
BOOST_REQUIRE(sortedOutput[2].size() == 2);
BOOST_REQUIRE(sortedOutput[2][0].second == 9);
- BOOST_REQUIRE_CLOSE(sortedOutput[2][0].first, (0.75 * 0.75), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[2][0].first, 0.75, 1e-5);
BOOST_REQUIRE(sortedOutput[2][1].second == 10);
- BOOST_REQUIRE_CLOSE(sortedOutput[2][1].first, (0.85 * 0.85), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[2][1].first, 0.85, 1e-5);
// Neighbors of point 3.
BOOST_REQUIRE(sortedOutput[3].size() == 2);
BOOST_REQUIRE(sortedOutput[3][0].second == 8);
- BOOST_REQUIRE_CLOSE(sortedOutput[3][0].first, (0.80 * 0.80), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[3][0].first, 0.80, 1e-5);
BOOST_REQUIRE(sortedOutput[3][1].second == 1);
- BOOST_REQUIRE_CLOSE(sortedOutput[3][1].first, (0.90 * 0.90), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[3][1].first, 0.90, 1e-5);
// Neighbors of point 4.
BOOST_REQUIRE(sortedOutput[4].size() == 0);
@@ -239,21 +237,21 @@
// Neighbors of point 8.
BOOST_REQUIRE(sortedOutput[8].size() == 1);
BOOST_REQUIRE(sortedOutput[8][0].second == 3);
- BOOST_REQUIRE_CLOSE(sortedOutput[8][0].first, (0.80 * 0.80), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[8][0].first, 0.80, 1e-5);
// Neighbors of point 9.
BOOST_REQUIRE(sortedOutput[9].size() == 2);
BOOST_REQUIRE(sortedOutput[9][0].second == 2);
- BOOST_REQUIRE_CLOSE(sortedOutput[9][0].first, (0.75 * 0.75), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[9][0].first, 0.75, 1e-5);
BOOST_REQUIRE(sortedOutput[9][1].second == 0);
- BOOST_REQUIRE_CLOSE(sortedOutput[9][1].first, (0.85 * 0.85), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[9][1].first, 0.85, 1e-5);
// Neighbors of point 10.
BOOST_REQUIRE(sortedOutput[10].size() == 2);
BOOST_REQUIRE(sortedOutput[10][0].second == 2);
- BOOST_REQUIRE_CLOSE(sortedOutput[10][0].first, (0.85 * 0.85), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[10][0].first, 0.85, 1e-5);
BOOST_REQUIRE(sortedOutput[10][1].second == 0);
- BOOST_REQUIRE_CLOSE(sortedOutput[10][1].first, (0.95 * 0.95), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[10][1].first, 0.95, 1e-5);
// Now do it again with a different range: [1.0 inf].
rs->Search(Range(1.0, numeric_limits<double>::infinity()), neighbors,
@@ -263,159 +261,159 @@
// Neighbors of point 0.
BOOST_REQUIRE(sortedOutput[0].size() == 4);
BOOST_REQUIRE(sortedOutput[0][0].second == 3);
- BOOST_REQUIRE_CLOSE(sortedOutput[0][0].first, (1.20 * 1.20), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[0][0].first, 1.20, 1e-5);
BOOST_REQUIRE(sortedOutput[0][1].second == 7);
- BOOST_REQUIRE_CLOSE(sortedOutput[0][1].first, (1.35 * 1.35), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[0][1].first, 1.35, 1e-5);
BOOST_REQUIRE(sortedOutput[0][2].second == 6);
- BOOST_REQUIRE_CLOSE(sortedOutput[0][2].first, (2.05 * 2.05), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[0][2].first, 2.05, 1e-5);
BOOST_REQUIRE(sortedOutput[0][3].second == 4);
- BOOST_REQUIRE_CLOSE(sortedOutput[0][3].first, (5.00 * 5.00), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[0][3].first, 5.00, 1e-5);
// Neighbors of point 1.
BOOST_REQUIRE(sortedOutput[1].size() == 3);
BOOST_REQUIRE(sortedOutput[1][0].second == 7);
- BOOST_REQUIRE_CLOSE(sortedOutput[1][0].first, (1.65 * 1.65), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[1][0].first, 1.65, 1e-5);
BOOST_REQUIRE(sortedOutput[1][1].second == 6);
- BOOST_REQUIRE_CLOSE(sortedOutput[1][1].first, (2.35 * 2.35), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[1][1].first, 2.35, 1e-5);
BOOST_REQUIRE(sortedOutput[1][2].second == 4);
- BOOST_REQUIRE_CLOSE(sortedOutput[1][2].first, (4.70 * 4.70), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[1][2].first, 4.70, 1e-5);
// Neighbors of point 2.
BOOST_REQUIRE(sortedOutput[2].size() == 4);
BOOST_REQUIRE(sortedOutput[2][0].second == 3);
- BOOST_REQUIRE_CLOSE(sortedOutput[2][0].first, (1.10 * 1.10), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[2][0].first, 1.10, 1e-5);
BOOST_REQUIRE(sortedOutput[2][1].second == 7);
- BOOST_REQUIRE_CLOSE(sortedOutput[2][1].first, (1.45 * 1.45), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[2][1].first, 1.45, 1e-5);
BOOST_REQUIRE(sortedOutput[2][2].second == 6);
- BOOST_REQUIRE_CLOSE(sortedOutput[2][2].first, (2.15 * 2.15), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[2][2].first, 2.15, 1e-5);
BOOST_REQUIRE(sortedOutput[2][3].second == 4);
- BOOST_REQUIRE_CLOSE(sortedOutput[2][3].first, (4.90 * 4.90), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[2][3].first, 4.90, 1e-5);
// Neighbors of point 3.
BOOST_REQUIRE(sortedOutput[3].size() == 6);
BOOST_REQUIRE(sortedOutput[3][0].second == 2);
- BOOST_REQUIRE_CLOSE(sortedOutput[3][0].first, (1.10 * 1.10), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[3][0].first, 1.10, 1e-5);
BOOST_REQUIRE(sortedOutput[3][1].second == 0);
- BOOST_REQUIRE_CLOSE(sortedOutput[3][1].first, (1.20 * 1.20), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[3][1].first, 1.20, 1e-5);
BOOST_REQUIRE(sortedOutput[3][2].second == 5);
- BOOST_REQUIRE_CLOSE(sortedOutput[3][2].first, (1.47 * 1.47), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[3][2].first, 1.47, 1e-5);
BOOST_REQUIRE(sortedOutput[3][3].second == 7);
- BOOST_REQUIRE_CLOSE(sortedOutput[3][3].first, (2.55 * 2.55), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[3][3].first, 2.55, 1e-5);
BOOST_REQUIRE(sortedOutput[3][4].second == 6);
- BOOST_REQUIRE_CLOSE(sortedOutput[3][4].first, (3.25 * 3.25), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[3][4].first, 3.25, 1e-5);
BOOST_REQUIRE(sortedOutput[3][5].second == 4);
- BOOST_REQUIRE_CLOSE(sortedOutput[3][5].first, (3.80 * 3.80), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[3][5].first, 3.80, 1e-5);
// Neighbors of point 4.
BOOST_REQUIRE(sortedOutput[4].size() == 10);
BOOST_REQUIRE(sortedOutput[4][0].second == 3);
- BOOST_REQUIRE_CLOSE(sortedOutput[4][0].first, (3.80 * 3.80), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[4][0].first, 3.80, 1e-5);
BOOST_REQUIRE(sortedOutput[4][1].second == 10);
- BOOST_REQUIRE_CLOSE(sortedOutput[4][1].first, (4.05 * 4.05), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[4][1].first, 4.05, 1e-5);
BOOST_REQUIRE(sortedOutput[4][2].second == 9);
- BOOST_REQUIRE_CLOSE(sortedOutput[4][2].first, (4.15 * 4.15), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[4][2].first, 4.15, 1e-5);
BOOST_REQUIRE(sortedOutput[4][3].second == 8);
- BOOST_REQUIRE_CLOSE(sortedOutput[4][3].first, (4.60 * 4.60), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[4][3].first, 4.60, 1e-5);
BOOST_REQUIRE(sortedOutput[4][4].second == 1);
- BOOST_REQUIRE_CLOSE(sortedOutput[4][4].first, (4.70 * 4.70), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[4][4].first, 4.70, 1e-5);
BOOST_REQUIRE(sortedOutput[4][5].second == 2);
- BOOST_REQUIRE_CLOSE(sortedOutput[4][5].first, (4.90 * 4.90), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[4][5].first, 4.90, 1e-5);
BOOST_REQUIRE(sortedOutput[4][6].second == 0);
- BOOST_REQUIRE_CLOSE(sortedOutput[4][6].first, (5.00 * 5.00), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[4][6].first, 5.00, 1e-5);
BOOST_REQUIRE(sortedOutput[4][7].second == 5);
- BOOST_REQUIRE_CLOSE(sortedOutput[4][7].first, (5.27 * 5.27), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[4][7].first, 5.27, 1e-5);
BOOST_REQUIRE(sortedOutput[4][8].second == 7);
- BOOST_REQUIRE_CLOSE(sortedOutput[4][8].first, (6.35 * 6.35), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[4][8].first, 6.35, 1e-5);
BOOST_REQUIRE(sortedOutput[4][9].second == 6);
- BOOST_REQUIRE_CLOSE(sortedOutput[4][9].first, (7.05 * 7.05), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[4][9].first, 7.05, 1e-5);
// Neighbors of point 5.
BOOST_REQUIRE(sortedOutput[5].size() == 6);
BOOST_REQUIRE(sortedOutput[5][0].second == 7);
- BOOST_REQUIRE_CLOSE(sortedOutput[5][0].first, (1.08 * 1.08), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[5][0].first, 1.08, 1e-5);
BOOST_REQUIRE(sortedOutput[5][1].second == 9);
- BOOST_REQUIRE_CLOSE(sortedOutput[5][1].first, (1.12 * 1.12), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[5][1].first, 1.12, 1e-5);
BOOST_REQUIRE(sortedOutput[5][2].second == 10);
- BOOST_REQUIRE_CLOSE(sortedOutput[5][2].first, (1.22 * 1.22), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[5][2].first, 1.22, 1e-5);
BOOST_REQUIRE(sortedOutput[5][3].second == 3);
- BOOST_REQUIRE_CLOSE(sortedOutput[5][3].first, (1.47 * 1.47), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[5][3].first, 1.47, 1e-5);
BOOST_REQUIRE(sortedOutput[5][4].second == 6);
- BOOST_REQUIRE_CLOSE(sortedOutput[5][4].first, (1.78 * 1.78), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[5][4].first, 1.78, 1e-5);
BOOST_REQUIRE(sortedOutput[5][5].second == 4);
- BOOST_REQUIRE_CLOSE(sortedOutput[5][5].first, (5.27 * 5.27), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[5][5].first, 5.27, 1e-5);
// Neighbors of point 6.
BOOST_REQUIRE(sortedOutput[6].size() == 9);
BOOST_REQUIRE(sortedOutput[6][0].second == 5);
- BOOST_REQUIRE_CLOSE(sortedOutput[6][0].first, (1.78 * 1.78), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[6][0].first, 1.78, 1e-5);
BOOST_REQUIRE(sortedOutput[6][1].second == 0);
- BOOST_REQUIRE_CLOSE(sortedOutput[6][1].first, (2.05 * 2.05), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[6][1].first, 2.05, 1e-5);
BOOST_REQUIRE(sortedOutput[6][2].second == 2);
- BOOST_REQUIRE_CLOSE(sortedOutput[6][2].first, (2.15 * 2.15), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[6][2].first, 2.15, 1e-5);
BOOST_REQUIRE(sortedOutput[6][3].second == 1);
- BOOST_REQUIRE_CLOSE(sortedOutput[6][3].first, (2.35 * 2.35), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[6][3].first, 2.35, 1e-5);
BOOST_REQUIRE(sortedOutput[6][4].second == 8);
- BOOST_REQUIRE_CLOSE(sortedOutput[6][4].first, (2.45 * 2.45), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[6][4].first, 2.45, 1e-5);
BOOST_REQUIRE(sortedOutput[6][5].second == 9);
- BOOST_REQUIRE_CLOSE(sortedOutput[6][5].first, (2.90 * 2.90), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[6][5].first, 2.90, 1e-5);
BOOST_REQUIRE(sortedOutput[6][6].second == 10);
- BOOST_REQUIRE_CLOSE(sortedOutput[6][6].first, (3.00 * 3.00), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[6][6].first, 3.00, 1e-5);
BOOST_REQUIRE(sortedOutput[6][7].second == 3);
- BOOST_REQUIRE_CLOSE(sortedOutput[6][7].first, (3.25 * 3.25), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[6][7].first, 3.25, 1e-5);
BOOST_REQUIRE(sortedOutput[6][8].second == 4);
- BOOST_REQUIRE_CLOSE(sortedOutput[6][8].first, (7.05 * 7.05), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[6][8].first, 7.05, 1e-5);
// Neighbors of point 7.
BOOST_REQUIRE(sortedOutput[7].size() == 9);
BOOST_REQUIRE(sortedOutput[7][0].second == 5);
- BOOST_REQUIRE_CLOSE(sortedOutput[7][0].first, (1.08 * 1.08), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[7][0].first, 1.08, 1e-5);
BOOST_REQUIRE(sortedOutput[7][1].second == 0);
- BOOST_REQUIRE_CLOSE(sortedOutput[7][1].first, (1.35 * 1.35), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[7][1].first, 1.35, 1e-5);
BOOST_REQUIRE(sortedOutput[7][2].second == 2);
- BOOST_REQUIRE_CLOSE(sortedOutput[7][2].first, (1.45 * 1.45), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[7][2].first, 1.45, 1e-5);
BOOST_REQUIRE(sortedOutput[7][3].second == 1);
- BOOST_REQUIRE_CLOSE(sortedOutput[7][3].first, (1.65 * 1.65), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[7][3].first, 1.65, 1e-5);
BOOST_REQUIRE(sortedOutput[7][4].second == 8);
- BOOST_REQUIRE_CLOSE(sortedOutput[7][4].first, (1.75 * 1.75), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[7][4].first, 1.75, 1e-5);
BOOST_REQUIRE(sortedOutput[7][5].second == 9);
- BOOST_REQUIRE_CLOSE(sortedOutput[7][5].first, (2.20 * 2.20), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[7][5].first, 2.20, 1e-5);
BOOST_REQUIRE(sortedOutput[7][6].second == 10);
- BOOST_REQUIRE_CLOSE(sortedOutput[7][6].first, (2.30 * 2.30), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[7][6].first, 2.30, 1e-5);
BOOST_REQUIRE(sortedOutput[7][7].second == 3);
- BOOST_REQUIRE_CLOSE(sortedOutput[7][7].first, (2.55 * 2.55), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[7][7].first, 2.55, 1e-5);
BOOST_REQUIRE(sortedOutput[7][8].second == 4);
- BOOST_REQUIRE_CLOSE(sortedOutput[7][8].first, (6.35 * 6.35), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[7][8].first, 6.35, 1e-5);
// Neighbors of point 8.
BOOST_REQUIRE(sortedOutput[8].size() == 3);
BOOST_REQUIRE(sortedOutput[8][0].second == 7);
- BOOST_REQUIRE_CLOSE(sortedOutput[8][0].first, (1.75 * 1.75), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[8][0].first, 1.75, 1e-5);
BOOST_REQUIRE(sortedOutput[8][1].second == 6);
- BOOST_REQUIRE_CLOSE(sortedOutput[8][1].first, (2.45 * 2.45), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[8][1].first, 2.45, 1e-5);
BOOST_REQUIRE(sortedOutput[8][2].second == 4);
- BOOST_REQUIRE_CLOSE(sortedOutput[8][2].first, (4.60 * 4.60), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[8][2].first, 4.60, 1e-5);
// Neighbors of point 9.
BOOST_REQUIRE(sortedOutput[9].size() == 4);
BOOST_REQUIRE(sortedOutput[9][0].second == 5);
- BOOST_REQUIRE_CLOSE(sortedOutput[9][0].first, (1.12 * 1.12), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[9][0].first, 1.12, 1e-5);
BOOST_REQUIRE(sortedOutput[9][1].second == 7);
- BOOST_REQUIRE_CLOSE(sortedOutput[9][1].first, (2.20 * 2.20), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[9][1].first, 2.20, 1e-5);
BOOST_REQUIRE(sortedOutput[9][2].second == 6);
- BOOST_REQUIRE_CLOSE(sortedOutput[9][2].first, (2.90 * 2.90), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[9][2].first, 2.90, 1e-5);
BOOST_REQUIRE(sortedOutput[9][3].second == 4);
- BOOST_REQUIRE_CLOSE(sortedOutput[9][3].first, (4.15 * 4.15), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[9][3].first, 4.15, 1e-5);
// Neighbors of point 10.
BOOST_REQUIRE(sortedOutput[10].size() == 4);
BOOST_REQUIRE(sortedOutput[10][0].second == 5);
- BOOST_REQUIRE_CLOSE(sortedOutput[10][0].first, (1.22 * 1.22), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[10][0].first, 1.22, 1e-5);
BOOST_REQUIRE(sortedOutput[10][1].second == 7);
- BOOST_REQUIRE_CLOSE(sortedOutput[10][1].first, (2.30 * 2.30), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[10][1].first, 2.30, 1e-5);
BOOST_REQUIRE(sortedOutput[10][2].second == 6);
- BOOST_REQUIRE_CLOSE(sortedOutput[10][2].first, (3.00 * 3.00), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[10][2].first, 3.00, 1e-5);
BOOST_REQUIRE(sortedOutput[10][3].second == 4);
- BOOST_REQUIRE_CLOSE(sortedOutput[10][3].first, (4.05 * 4.05), 1e-5);
+ BOOST_REQUIRE_CLOSE(sortedOutput[10][3].first, 4.05, 1e-5);
// Clean the memory.
delete rs;
@@ -486,7 +484,7 @@
if (!data::Load("test_data_3_1000.csv", dataForTree))
BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
- // Set up matrices to work with (may not be necessary with no ALIAS_MATRIX?).
+ // Set up matrices to work with.
arma::mat dualQuery(dataForTree);
arma::mat naiveQuery(dataForTree);
@@ -569,4 +567,222 @@
}
}
+/**
+ * Ensure that range search with cover trees works by comparing its results,
+ * point by point, against the kd-tree implementation (used as the reference).
+ */
+BOOST_AUTO_TEST_CASE(RangeSearchCoverTreeTest)
+{
+  arma::mat data;
+  data.randu(8, 1000); // 1000 points in 8 dimensions.
+
+  // Set up kd-tree range search; its results serve as the reference.
+  RangeSearch<> kdsearch(data);
+  // Set up cover tree range search.
+  tree::CoverTree<> tree(data);
+  RangeSearch<metric::EuclideanDistance, tree::CoverTree<> >
+      coversearch(&tree, data);
+
+  // Four trials with different ranges.
+  for (size_t r = 0; r < 4; ++r)
+  {
+    Range range;
+    switch (r)
+    {
+      case 0:
+        // Includes zero distance.
+        range = Range(0.0, 0.75);
+        break;
+      case 1:
+        // A bounded range on both sides.
+        range = Range(0.5, 1.5);
+        break;
+      case 2:
+        // A range with no upper bound.
+        range = Range(0.8, numeric_limits<double>::infinity());
+        break;
+      case 3:
+        // No results: points in [0, 1]^8 are at most sqrt(8) ~ 2.83 apart.
+        range = Range(15.6, 15.7);
+        break;
+    }
+
+    // Results for kd-tree search.
+    vector<vector<size_t> > kdNeighbors;
+    vector<vector<double> > kdDistances;
+
+    // Results for cover tree search.
+    vector<vector<size_t> > coverNeighbors;
+    vector<vector<double> > coverDistances;
+
+    // Run the searches.
+    kdsearch.Search(range, kdNeighbors, kdDistances);
+    coversearch.Search(range, coverNeighbors, coverDistances);
+
+    // Sort first; the two trees may report neighbors in different orders.
+    vector<vector<pair<double, size_t> > > kdSorted;
+    vector<vector<pair<double, size_t> > > coverSorted;
+    SortResults(kdNeighbors, kdDistances, kdSorted);
+    SortResults(coverNeighbors, coverDistances, coverSorted);
+
+    // Compare per point; sizes first so the inner loop never reads past end.
+    for (size_t i = 0; i < kdSorted.size(); ++i)
+    {
+      BOOST_REQUIRE_EQUAL(kdSorted[i].size(), coverSorted[i].size());
+      for (size_t j = 0; j < kdSorted[i].size(); ++j)
+      {
+        BOOST_REQUIRE_EQUAL(kdSorted[i][j].second, coverSorted[i][j].second);
+        BOOST_REQUIRE_CLOSE(kdSorted[i][j].first, coverSorted[i][j].first,
+            1e-5);
+      }
+    }
+  }
+}
+
+/**
+ * Ensure cover tree range search matches kd-tree search with two datasets.
+ */
+BOOST_AUTO_TEST_CASE(RangeSearchCoverTreeTwoDatasetsTest)
+{
+  arma::mat data;
+  data.randu(8, 1000); // 1000 points in 8 dimensions.
+  arma::mat queries;
+  queries.randu(8, 350); // 350 points in 8 dimensions.
+
+  // Set up kd-tree range search; its results serve as the reference.
+  RangeSearch<> kdsearch(data, queries);
+  // Set up cover tree range search with separate reference and query trees.
+  tree::CoverTree<> tree(data);
+  tree::CoverTree<> queryTree(queries);
+  RangeSearch<metric::EuclideanDistance, tree::CoverTree<> >
+      coversearch(&tree, &queryTree, data, queries);
+
+  // Four trials with different ranges.
+  for (size_t r = 0; r < 4; ++r)
+  {
+    Range range;
+    switch (r)
+    {
+      case 0:
+        // Includes zero distance.
+        range = Range(0.0, 0.75);
+        break;
+      case 1:
+        // A bounded range on both sides.
+        range = Range(0.5, 1.5);
+        break;
+      case 2:
+        // A range with no upper bound.
+        range = Range(0.8, numeric_limits<double>::infinity());
+        break;
+      case 3:
+        // No results: points in [0, 1]^8 are at most sqrt(8) ~ 2.83 apart.
+        range = Range(15.6, 15.7);
+        break;
+    }
+
+    // Results for kd-tree search.
+    vector<vector<size_t> > kdNeighbors;
+    vector<vector<double> > kdDistances;
+
+    // Results for cover tree search.
+    vector<vector<size_t> > coverNeighbors;
+    vector<vector<double> > coverDistances;
+
+    // Run the searches.
+    kdsearch.Search(range, kdNeighbors, kdDistances);
+    coversearch.Search(range, coverNeighbors, coverDistances);
+
+    // Sort first; the two trees may report neighbors in different orders.
+    vector<vector<pair<double, size_t> > > kdSorted;
+    vector<vector<pair<double, size_t> > > coverSorted;
+    SortResults(kdNeighbors, kdDistances, kdSorted);
+    SortResults(coverNeighbors, coverDistances, coverSorted);
+
+    // Compare per query point; sizes first so indexing below stays in bounds.
+    for (size_t i = 0; i < kdSorted.size(); ++i)
+    {
+      BOOST_REQUIRE_EQUAL(kdSorted[i].size(), coverSorted[i].size());
+
+      for (size_t j = 0; j < kdSorted[i].size(); ++j)
+      {
+        BOOST_REQUIRE_EQUAL(kdSorted[i][j].second, coverSorted[i][j].second);
+        BOOST_REQUIRE_CLOSE(kdSorted[i][j].first, coverSorted[i][j].first,
+            1e-5);
+      }
+    }
+  }
+}
+
+/**
+ * Ensure single-tree cover tree range search matches naive range search.
+ */
+BOOST_AUTO_TEST_CASE(RangeSearchCoverTreeSingleTreeTest)
+{
+  arma::mat data;
+  data.randu(8, 1000); // 1000 points in 8 dimensions.
+
+  // Set up kd-tree RangeSearch in naive (brute-force) mode as the reference.
+  RangeSearch<> kdsearch(data, true);
+  // Set up cover tree range search in single-tree mode.
+  tree::CoverTree<> tree(data);
+  RangeSearch<metric::EuclideanDistance, tree::CoverTree<> >
+      coversearch(&tree, data, true);
+
+  // Four trials with different ranges.
+  for (size_t r = 0; r < 4; ++r)
+  {
+    Range range;
+    switch (r)
+    {
+      case 0:
+        // Includes zero distance.
+        range = Range(0.0, 0.75);
+        break;
+      case 1:
+        // A bounded range on both sides.
+        range = Range(0.5, 1.5);
+        break;
+      case 2:
+        // A range with no upper bound.
+        range = Range(0.8, numeric_limits<double>::infinity());
+        break;
+      case 3:
+        // No results: points in [0, 1]^8 are at most sqrt(8) ~ 2.83 apart.
+        range = Range(15.6, 15.7);
+        break;
+    }
+
+    // Results for the naive reference search.
+    vector<vector<size_t> > kdNeighbors;
+    vector<vector<double> > kdDistances;
+
+    // Results for cover tree search.
+    vector<vector<size_t> > coverNeighbors;
+    vector<vector<double> > coverDistances;
+
+    // Run the searches.
+    kdsearch.Search(range, kdNeighbors, kdDistances);
+    coversearch.Search(range, coverNeighbors, coverDistances);
+
+    // Sort first; the two searches may report neighbors in different orders.
+    vector<vector<pair<double, size_t> > > kdSorted;
+    vector<vector<pair<double, size_t> > > coverSorted;
+    SortResults(kdNeighbors, kdDistances, kdSorted);
+    SortResults(coverNeighbors, coverDistances, coverSorted);
+
+    // Compare per point; sizes first so the inner loop never reads past end.
+    for (size_t i = 0; i < kdSorted.size(); ++i)
+    {
+      BOOST_REQUIRE_EQUAL(kdSorted[i].size(), coverSorted[i].size());
+      for (size_t j = 0; j < kdSorted[i].size(); ++j)
+      {
+        BOOST_REQUIRE_EQUAL(kdSorted[i][j].second, coverSorted[i][j].second);
+        BOOST_REQUIRE_CLOSE(kdSorted[i][j].first, coverSorted[i][j].first,
+            1e-5);
+      }
+    }
+  }
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-svn
mailing list