[mlpack-svn] r15043 - mlpack/trunk/src/mlpack/tests

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed May 8 22:22:54 EDT 2013


Author: rcurtin
Date: 2013-05-08 22:22:54 -0400 (Wed, 08 May 2013)
New Revision: 15043

Modified:
   mlpack/trunk/src/mlpack/tests/range_search_test.cpp
Log:
Comprehensive test for range search with cover trees.


Modified: mlpack/trunk/src/mlpack/tests/range_search_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/range_search_test.cpp	2013-05-09 02:22:41 UTC (rev 15042)
+++ mlpack/trunk/src/mlpack/tests/range_search_test.cpp	2013-05-09 02:22:54 UTC (rev 15043)
@@ -6,6 +6,7 @@
  */
 #include <mlpack/core.hpp>
 #include <mlpack/methods/range_search/range_search.hpp>
+#include <mlpack/core/tree/cover_tree.hpp>
 #include <boost/test/unit_test.hpp>
 #include "old_boost_test_definitions.hpp"
 
@@ -66,72 +67,69 @@
     arma::mat dataMutable = data;
     switch (i)
     {
-      case 0: // Use the dual-tree method.
-        rs = new RangeSearch<>(dataMutable, false, false, 1);
+      case 0: // Use the naive method.
+        rs = new RangeSearch<>(dataMutable, true);
         break;
       case 1: // Use the single-tree method.
         rs = new RangeSearch<>(dataMutable, false, true, 1);
         break;
-      case 2: // Use the naive method.
-        rs = new RangeSearch<>(dataMutable, true);
+      case 2: // Use the dual-tree method.
+        rs = new RangeSearch<>(dataMutable, false, false, 1);
         break;
     }
 
     // Now perform the first calculation.  Points within 0.50.
     vector<vector<size_t> > neighbors;
     vector<vector<double> > distances;
-    rs->Search(Range(0.0, 0.50), neighbors, distances);
+    rs->Search(Range(0.0, sqrt(0.5)), neighbors, distances);
 
-    // Now the exhaustive check for correctness.  This will be long.  We must
-    // also remember that the distances returned are squared distances.  As a
-    // result, distance comparisons are written out as (distance * distance) for
-    // readability.
+    // Now the exhaustive check for correctness.  This will be long.
     vector<vector<pair<double, size_t> > > sortedOutput;
     SortResults(neighbors, distances, sortedOutput);
 
     // Neighbors of point 0.
     BOOST_REQUIRE(sortedOutput[0].size() == 4);
     BOOST_REQUIRE(sortedOutput[0][0].second == 2);
-    BOOST_REQUIRE_CLOSE(sortedOutput[0][0].first, (0.10 * 0.10), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[0][0].first, 0.10, 1e-5);
     BOOST_REQUIRE(sortedOutput[0][1].second == 5);
-    BOOST_REQUIRE_CLOSE(sortedOutput[0][1].first, (0.27 * 0.27), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[0][1].first, 0.27, 1e-5);
     BOOST_REQUIRE(sortedOutput[0][2].second == 1);
-    BOOST_REQUIRE_CLOSE(sortedOutput[0][2].first, (0.30 * 0.30), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[0][2].first, 0.30, 1e-5);
     BOOST_REQUIRE(sortedOutput[0][3].second == 8);
-    BOOST_REQUIRE_CLOSE(sortedOutput[0][3].first, (0.40 * 0.40), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[0][3].first, 0.40, 1e-5);
 
     // Neighbors of point 1.
     BOOST_REQUIRE(sortedOutput[1].size() == 6);
     BOOST_REQUIRE(sortedOutput[1][0].second == 8);
-    BOOST_REQUIRE_CLOSE(sortedOutput[1][0].first, (0.10 * 0.10), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[1][0].first, 0.10, 1e-5);
     BOOST_REQUIRE(sortedOutput[1][1].second == 2);
-    BOOST_REQUIRE_CLOSE(sortedOutput[1][1].first, (0.20 * 0.20), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[1][1].first, 0.20, 1e-5);
     BOOST_REQUIRE(sortedOutput[1][2].second == 0);
-    BOOST_REQUIRE_CLOSE(sortedOutput[1][2].first, (0.30 * 0.30), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[1][2].first, 0.30, 1e-5);
     BOOST_REQUIRE(sortedOutput[1][3].second == 9);
-    BOOST_REQUIRE_CLOSE(sortedOutput[1][3].first, (0.55 * 0.55), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[1][3].first, 0.55, 1e-5);
     BOOST_REQUIRE(sortedOutput[1][4].second == 5);
-    BOOST_REQUIRE_CLOSE(sortedOutput[1][4].first, (0.57 * 0.57), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[1][4].first, 0.57, 1e-5);
     BOOST_REQUIRE(sortedOutput[1][5].second == 10);
-    BOOST_REQUIRE_CLOSE(sortedOutput[1][5].first, (0.65 * 0.65), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[1][5].first, 0.65, 1e-5);
 
     // Neighbors of point 2.
     BOOST_REQUIRE(sortedOutput[2].size() == 4);
     BOOST_REQUIRE(sortedOutput[2][0].second == 0);
-    BOOST_REQUIRE_CLOSE(sortedOutput[2][0].first, (0.10 * 0.10), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[2][0].first, 0.10, 1e-5);
     BOOST_REQUIRE(sortedOutput[2][1].second == 1);
-    BOOST_REQUIRE_CLOSE(sortedOutput[2][1].first, (0.20 * 0.20), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[2][1].first, 0.20, 1e-5);
     BOOST_REQUIRE(sortedOutput[2][2].second == 8);
-    BOOST_REQUIRE_CLOSE(sortedOutput[2][2].first, (0.30 * 0.30), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[2][2].first, 0.30, 1e-5);
     BOOST_REQUIRE(sortedOutput[2][3].second == 5);
-    BOOST_REQUIRE_CLOSE(sortedOutput[2][3].first, (0.37 * 0.37), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[2][3].first, 0.37, 1e-5);
 
     // Neighbors of point 3.
     BOOST_REQUIRE(sortedOutput[3].size() == 2);
     BOOST_REQUIRE(sortedOutput[3][0].second == 10);
-    BOOST_REQUIRE_CLOSE(sortedOutput[3][0].first, (0.25 * 0.25), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[3][0].first, 0.25, 1e-5);
     BOOST_REQUIRE(sortedOutput[3][1].second == 9);
-    BOOST_REQUIRE_CLOSE(sortedOutput[3][1].first, (0.35 * 0.35), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[3][1].first, 0.35, 1e-5);
 
     // Neighbors of point 4.
     BOOST_REQUIRE(sortedOutput[4].size() == 0);
@@ -139,90 +137,90 @@
     // Neighbors of point 5.
     BOOST_REQUIRE(sortedOutput[5].size() == 4);
     BOOST_REQUIRE(sortedOutput[5][0].second == 0);
-    BOOST_REQUIRE_CLOSE(sortedOutput[5][0].first, (0.27 * 0.27), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[5][0].first, 0.27, 1e-5);
     BOOST_REQUIRE(sortedOutput[5][1].second == 2);
-    BOOST_REQUIRE_CLOSE(sortedOutput[5][1].first, (0.37 * 0.37), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[5][1].first, 0.37, 1e-5);
     BOOST_REQUIRE(sortedOutput[5][2].second == 1);
-    BOOST_REQUIRE_CLOSE(sortedOutput[5][2].first, (0.57 * 0.57), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[5][2].first, 0.57, 1e-5);
     BOOST_REQUIRE(sortedOutput[5][3].second == 8);
-    BOOST_REQUIRE_CLOSE(sortedOutput[5][3].first, (0.67 * 0.67), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[5][3].first, 0.67, 1e-5);
 
     // Neighbors of point 6.
     BOOST_REQUIRE(sortedOutput[6].size() == 1);
     BOOST_REQUIRE(sortedOutput[6][0].second == 7);
-    BOOST_REQUIRE_CLOSE(sortedOutput[6][0].first, (0.70 * 0.70), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[6][0].first, 0.70, 1e-5);
 
     // Neighbors of point 7.
     BOOST_REQUIRE(sortedOutput[7].size() == 1);
     BOOST_REQUIRE(sortedOutput[7][0].second == 6);
-    BOOST_REQUIRE_CLOSE(sortedOutput[7][0].first, (0.70 * 0.70), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[7][0].first, 0.70, 1e-5);
 
     // Neighbors of point 8.
     BOOST_REQUIRE(sortedOutput[8].size() == 6);
     BOOST_REQUIRE(sortedOutput[8][0].second == 1);
-    BOOST_REQUIRE_CLOSE(sortedOutput[8][0].first, (0.10 * 0.10), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[8][0].first, 0.10, 1e-5);
     BOOST_REQUIRE(sortedOutput[8][1].second == 2);
-    BOOST_REQUIRE_CLOSE(sortedOutput[8][1].first, (0.30 * 0.30), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[8][1].first, 0.30, 1e-5);
     BOOST_REQUIRE(sortedOutput[8][2].second == 0);
-    BOOST_REQUIRE_CLOSE(sortedOutput[8][2].first, (0.40 * 0.40), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[8][2].first, 0.40, 1e-5);
     BOOST_REQUIRE(sortedOutput[8][3].second == 9);
-    BOOST_REQUIRE_CLOSE(sortedOutput[8][3].first, (0.45 * 0.45), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[8][3].first, 0.45, 1e-5);
     BOOST_REQUIRE(sortedOutput[8][4].second == 10);
-    BOOST_REQUIRE_CLOSE(sortedOutput[8][4].first, (0.55 * 0.55), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[8][4].first, 0.55, 1e-5);
     BOOST_REQUIRE(sortedOutput[8][5].second == 5);
-    BOOST_REQUIRE_CLOSE(sortedOutput[8][5].first, (0.67 * 0.67), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[8][5].first, 0.67, 1e-5);
 
     // Neighbors of point 9.
     BOOST_REQUIRE(sortedOutput[9].size() == 4);
     BOOST_REQUIRE(sortedOutput[9][0].second == 10);
-    BOOST_REQUIRE_CLOSE(sortedOutput[9][0].first, (0.10 * 0.10), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[9][0].first, 0.10, 1e-5);
     BOOST_REQUIRE(sortedOutput[9][1].second == 3);
-    BOOST_REQUIRE_CLOSE(sortedOutput[9][1].first, (0.35 * 0.35), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[9][1].first, 0.35, 1e-5);
     BOOST_REQUIRE(sortedOutput[9][2].second == 8);
-    BOOST_REQUIRE_CLOSE(sortedOutput[9][2].first, (0.45 * 0.45), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[9][2].first, 0.45, 1e-5);
     BOOST_REQUIRE(sortedOutput[9][3].second == 1);
-    BOOST_REQUIRE_CLOSE(sortedOutput[9][3].first, (0.55 * 0.55), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[9][3].first, 0.55, 1e-5);
 
     // Neighbors of point 10.
     BOOST_REQUIRE(sortedOutput[10].size() == 4);
     BOOST_REQUIRE(sortedOutput[10][0].second == 9);
-    BOOST_REQUIRE_CLOSE(sortedOutput[10][0].first, (0.10 * 0.10), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[10][0].first, 0.10, 1e-5);
     BOOST_REQUIRE(sortedOutput[10][1].second == 3);
-    BOOST_REQUIRE_CLOSE(sortedOutput[10][1].first, (0.25 * 0.25), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[10][1].first, 0.25, 1e-5);
     BOOST_REQUIRE(sortedOutput[10][2].second == 8);
-    BOOST_REQUIRE_CLOSE(sortedOutput[10][2].first, (0.55 * 0.55), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[10][2].first, 0.55, 1e-5);
     BOOST_REQUIRE(sortedOutput[10][3].second == 1);
-    BOOST_REQUIRE_CLOSE(sortedOutput[10][3].first, (0.65 * 0.65), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[10][3].first, 0.65, 1e-5);
 
-    // Now do it again with a different range: [0.5 1.0].
-    rs->Search(Range(0.5, 1.0), neighbors, distances);
+    // Now do it again with a different range: [sqrt(0.5) 1.0].
+    rs->Search(Range(sqrt(0.5), 1.0), neighbors, distances);
     SortResults(neighbors, distances, sortedOutput);
 
     // Neighbors of point 0.
     BOOST_REQUIRE(sortedOutput[0].size() == 2);
     BOOST_REQUIRE(sortedOutput[0][0].second == 9);
-    BOOST_REQUIRE_CLOSE(sortedOutput[0][0].first, (0.85 * 0.85), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[0][0].first, 0.85, 1e-5);
     BOOST_REQUIRE(sortedOutput[0][1].second == 10);
-    BOOST_REQUIRE_CLOSE(sortedOutput[0][1].first, (0.95 * 0.95), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[0][1].first, 0.95, 1e-5);
 
     // Neighbors of point 1.
     BOOST_REQUIRE(sortedOutput[1].size() == 1);
     BOOST_REQUIRE(sortedOutput[1][0].second == 3);
-    BOOST_REQUIRE_CLOSE(sortedOutput[1][0].first, (0.90 * 0.90), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[1][0].first, 0.90, 1e-5);
 
     // Neighbors of point 2.
     BOOST_REQUIRE(sortedOutput[2].size() == 2);
     BOOST_REQUIRE(sortedOutput[2][0].second == 9);
-    BOOST_REQUIRE_CLOSE(sortedOutput[2][0].first, (0.75 * 0.75), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[2][0].first, 0.75, 1e-5);
     BOOST_REQUIRE(sortedOutput[2][1].second == 10);
-    BOOST_REQUIRE_CLOSE(sortedOutput[2][1].first, (0.85 * 0.85), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[2][1].first, 0.85, 1e-5);
 
     // Neighbors of point 3.
     BOOST_REQUIRE(sortedOutput[3].size() == 2);
     BOOST_REQUIRE(sortedOutput[3][0].second == 8);
-    BOOST_REQUIRE_CLOSE(sortedOutput[3][0].first, (0.80 * 0.80), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[3][0].first, 0.80, 1e-5);
     BOOST_REQUIRE(sortedOutput[3][1].second == 1);
-    BOOST_REQUIRE_CLOSE(sortedOutput[3][1].first, (0.90 * 0.90), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[3][1].first, 0.90, 1e-5);
 
     // Neighbors of point 4.
     BOOST_REQUIRE(sortedOutput[4].size() == 0);
@@ -239,21 +237,21 @@
     // Neighbors of point 8.
     BOOST_REQUIRE(sortedOutput[8].size() == 1);
     BOOST_REQUIRE(sortedOutput[8][0].second == 3);
-    BOOST_REQUIRE_CLOSE(sortedOutput[8][0].first, (0.80 * 0.80), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[8][0].first, 0.80, 1e-5);
 
     // Neighbors of point 9.
     BOOST_REQUIRE(sortedOutput[9].size() == 2);
     BOOST_REQUIRE(sortedOutput[9][0].second == 2);
-    BOOST_REQUIRE_CLOSE(sortedOutput[9][0].first, (0.75 * 0.75), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[9][0].first, 0.75, 1e-5);
     BOOST_REQUIRE(sortedOutput[9][1].second == 0);
-    BOOST_REQUIRE_CLOSE(sortedOutput[9][1].first, (0.85 * 0.85), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[9][1].first, 0.85, 1e-5);
 
     // Neighbors of point 10.
     BOOST_REQUIRE(sortedOutput[10].size() == 2);
     BOOST_REQUIRE(sortedOutput[10][0].second == 2);
-    BOOST_REQUIRE_CLOSE(sortedOutput[10][0].first, (0.85 * 0.85), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[10][0].first, 0.85, 1e-5);
     BOOST_REQUIRE(sortedOutput[10][1].second == 0);
-    BOOST_REQUIRE_CLOSE(sortedOutput[10][1].first, (0.95 * 0.95), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[10][1].first, 0.95, 1e-5);
 
     // Now do it again with a different range: [1.0 inf].
     rs->Search(Range(1.0, numeric_limits<double>::infinity()), neighbors,
@@ -263,159 +261,159 @@
     // Neighbors of point 0.
     BOOST_REQUIRE(sortedOutput[0].size() == 4);
     BOOST_REQUIRE(sortedOutput[0][0].second == 3);
-    BOOST_REQUIRE_CLOSE(sortedOutput[0][0].first, (1.20 * 1.20), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[0][0].first, 1.20, 1e-5);
     BOOST_REQUIRE(sortedOutput[0][1].second == 7);
-    BOOST_REQUIRE_CLOSE(sortedOutput[0][1].first, (1.35 * 1.35), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[0][1].first, 1.35, 1e-5);
     BOOST_REQUIRE(sortedOutput[0][2].second == 6);
-    BOOST_REQUIRE_CLOSE(sortedOutput[0][2].first, (2.05 * 2.05), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[0][2].first, 2.05, 1e-5);
     BOOST_REQUIRE(sortedOutput[0][3].second == 4);
-    BOOST_REQUIRE_CLOSE(sortedOutput[0][3].first, (5.00 * 5.00), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[0][3].first, 5.00, 1e-5);
 
     // Neighbors of point 1.
     BOOST_REQUIRE(sortedOutput[1].size() == 3);
     BOOST_REQUIRE(sortedOutput[1][0].second == 7);
-    BOOST_REQUIRE_CLOSE(sortedOutput[1][0].first, (1.65 * 1.65), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[1][0].first, 1.65, 1e-5);
     BOOST_REQUIRE(sortedOutput[1][1].second == 6);
-    BOOST_REQUIRE_CLOSE(sortedOutput[1][1].first, (2.35 * 2.35), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[1][1].first, 2.35, 1e-5);
     BOOST_REQUIRE(sortedOutput[1][2].second == 4);
-    BOOST_REQUIRE_CLOSE(sortedOutput[1][2].first, (4.70 * 4.70), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[1][2].first, 4.70, 1e-5);
 
     // Neighbors of point 2.
     BOOST_REQUIRE(sortedOutput[2].size() == 4);
     BOOST_REQUIRE(sortedOutput[2][0].second == 3);
-    BOOST_REQUIRE_CLOSE(sortedOutput[2][0].first, (1.10 * 1.10), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[2][0].first, 1.10, 1e-5);
     BOOST_REQUIRE(sortedOutput[2][1].second == 7);
-    BOOST_REQUIRE_CLOSE(sortedOutput[2][1].first, (1.45 * 1.45), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[2][1].first, 1.45, 1e-5);
     BOOST_REQUIRE(sortedOutput[2][2].second == 6);
-    BOOST_REQUIRE_CLOSE(sortedOutput[2][2].first, (2.15 * 2.15), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[2][2].first, 2.15, 1e-5);
     BOOST_REQUIRE(sortedOutput[2][3].second == 4);
-    BOOST_REQUIRE_CLOSE(sortedOutput[2][3].first, (4.90 * 4.90), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[2][3].first, 4.90, 1e-5);
 
     // Neighbors of point 3.
     BOOST_REQUIRE(sortedOutput[3].size() == 6);
     BOOST_REQUIRE(sortedOutput[3][0].second == 2);
-    BOOST_REQUIRE_CLOSE(sortedOutput[3][0].first, (1.10 * 1.10), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[3][0].first, 1.10, 1e-5);
     BOOST_REQUIRE(sortedOutput[3][1].second == 0);
-    BOOST_REQUIRE_CLOSE(sortedOutput[3][1].first, (1.20 * 1.20), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[3][1].first, 1.20, 1e-5);
     BOOST_REQUIRE(sortedOutput[3][2].second == 5);
-    BOOST_REQUIRE_CLOSE(sortedOutput[3][2].first, (1.47 * 1.47), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[3][2].first, 1.47, 1e-5);
     BOOST_REQUIRE(sortedOutput[3][3].second == 7);
-    BOOST_REQUIRE_CLOSE(sortedOutput[3][3].first, (2.55 * 2.55), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[3][3].first, 2.55, 1e-5);
     BOOST_REQUIRE(sortedOutput[3][4].second == 6);
-    BOOST_REQUIRE_CLOSE(sortedOutput[3][4].first, (3.25 * 3.25), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[3][4].first, 3.25, 1e-5);
     BOOST_REQUIRE(sortedOutput[3][5].second == 4);
-    BOOST_REQUIRE_CLOSE(sortedOutput[3][5].first, (3.80 * 3.80), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[3][5].first, 3.80, 1e-5);
 
     // Neighbors of point 4.
     BOOST_REQUIRE(sortedOutput[4].size() == 10);
     BOOST_REQUIRE(sortedOutput[4][0].second == 3);
-    BOOST_REQUIRE_CLOSE(sortedOutput[4][0].first, (3.80 * 3.80), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[4][0].first, 3.80, 1e-5);
     BOOST_REQUIRE(sortedOutput[4][1].second == 10);
-    BOOST_REQUIRE_CLOSE(sortedOutput[4][1].first, (4.05 * 4.05), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[4][1].first, 4.05, 1e-5);
     BOOST_REQUIRE(sortedOutput[4][2].second == 9);
-    BOOST_REQUIRE_CLOSE(sortedOutput[4][2].first, (4.15 * 4.15), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[4][2].first, 4.15, 1e-5);
     BOOST_REQUIRE(sortedOutput[4][3].second == 8);
-    BOOST_REQUIRE_CLOSE(sortedOutput[4][3].first, (4.60 * 4.60), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[4][3].first, 4.60, 1e-5);
     BOOST_REQUIRE(sortedOutput[4][4].second == 1);
-    BOOST_REQUIRE_CLOSE(sortedOutput[4][4].first, (4.70 * 4.70), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[4][4].first, 4.70, 1e-5);
     BOOST_REQUIRE(sortedOutput[4][5].second == 2);
-    BOOST_REQUIRE_CLOSE(sortedOutput[4][5].first, (4.90 * 4.90), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[4][5].first, 4.90, 1e-5);
     BOOST_REQUIRE(sortedOutput[4][6].second == 0);
-    BOOST_REQUIRE_CLOSE(sortedOutput[4][6].first, (5.00 * 5.00), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[4][6].first, 5.00, 1e-5);
     BOOST_REQUIRE(sortedOutput[4][7].second == 5);
-    BOOST_REQUIRE_CLOSE(sortedOutput[4][7].first, (5.27 * 5.27), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[4][7].first, 5.27, 1e-5);
     BOOST_REQUIRE(sortedOutput[4][8].second == 7);
-    BOOST_REQUIRE_CLOSE(sortedOutput[4][8].first, (6.35 * 6.35), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[4][8].first, 6.35, 1e-5);
     BOOST_REQUIRE(sortedOutput[4][9].second == 6);
-    BOOST_REQUIRE_CLOSE(sortedOutput[4][9].first, (7.05 * 7.05), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[4][9].first, 7.05, 1e-5);
 
     // Neighbors of point 5.
     BOOST_REQUIRE(sortedOutput[5].size() == 6);
     BOOST_REQUIRE(sortedOutput[5][0].second == 7);
-    BOOST_REQUIRE_CLOSE(sortedOutput[5][0].first, (1.08 * 1.08), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[5][0].first, 1.08, 1e-5);
     BOOST_REQUIRE(sortedOutput[5][1].second == 9);
-    BOOST_REQUIRE_CLOSE(sortedOutput[5][1].first, (1.12 * 1.12), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[5][1].first, 1.12, 1e-5);
     BOOST_REQUIRE(sortedOutput[5][2].second == 10);
-    BOOST_REQUIRE_CLOSE(sortedOutput[5][2].first, (1.22 * 1.22), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[5][2].first, 1.22, 1e-5);
     BOOST_REQUIRE(sortedOutput[5][3].second == 3);
-    BOOST_REQUIRE_CLOSE(sortedOutput[5][3].first, (1.47 * 1.47), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[5][3].first, 1.47, 1e-5);
     BOOST_REQUIRE(sortedOutput[5][4].second == 6);
-    BOOST_REQUIRE_CLOSE(sortedOutput[5][4].first, (1.78 * 1.78), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[5][4].first, 1.78, 1e-5);
     BOOST_REQUIRE(sortedOutput[5][5].second == 4);
-    BOOST_REQUIRE_CLOSE(sortedOutput[5][5].first, (5.27 * 5.27), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[5][5].first, 5.27, 1e-5);
 
     // Neighbors of point 6.
     BOOST_REQUIRE(sortedOutput[6].size() == 9);
     BOOST_REQUIRE(sortedOutput[6][0].second == 5);
-    BOOST_REQUIRE_CLOSE(sortedOutput[6][0].first, (1.78 * 1.78), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[6][0].first, 1.78, 1e-5);
     BOOST_REQUIRE(sortedOutput[6][1].second == 0);
-    BOOST_REQUIRE_CLOSE(sortedOutput[6][1].first, (2.05 * 2.05), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[6][1].first, 2.05, 1e-5);
     BOOST_REQUIRE(sortedOutput[6][2].second == 2);
-    BOOST_REQUIRE_CLOSE(sortedOutput[6][2].first, (2.15 * 2.15), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[6][2].first, 2.15, 1e-5);
     BOOST_REQUIRE(sortedOutput[6][3].second == 1);
-    BOOST_REQUIRE_CLOSE(sortedOutput[6][3].first, (2.35 * 2.35), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[6][3].first, 2.35, 1e-5);
     BOOST_REQUIRE(sortedOutput[6][4].second == 8);
-    BOOST_REQUIRE_CLOSE(sortedOutput[6][4].first, (2.45 * 2.45), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[6][4].first, 2.45, 1e-5);
     BOOST_REQUIRE(sortedOutput[6][5].second == 9);
-    BOOST_REQUIRE_CLOSE(sortedOutput[6][5].first, (2.90 * 2.90), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[6][5].first, 2.90, 1e-5);
     BOOST_REQUIRE(sortedOutput[6][6].second == 10);
-    BOOST_REQUIRE_CLOSE(sortedOutput[6][6].first, (3.00 * 3.00), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[6][6].first, 3.00, 1e-5);
     BOOST_REQUIRE(sortedOutput[6][7].second == 3);
-    BOOST_REQUIRE_CLOSE(sortedOutput[6][7].first, (3.25 * 3.25), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[6][7].first, 3.25, 1e-5);
     BOOST_REQUIRE(sortedOutput[6][8].second == 4);
-    BOOST_REQUIRE_CLOSE(sortedOutput[6][8].first, (7.05 * 7.05), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[6][8].first, 7.05, 1e-5);
 
     // Neighbors of point 7.
     BOOST_REQUIRE(sortedOutput[7].size() == 9);
     BOOST_REQUIRE(sortedOutput[7][0].second == 5);
-    BOOST_REQUIRE_CLOSE(sortedOutput[7][0].first, (1.08 * 1.08), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[7][0].first, 1.08, 1e-5);
     BOOST_REQUIRE(sortedOutput[7][1].second == 0);
-    BOOST_REQUIRE_CLOSE(sortedOutput[7][1].first, (1.35 * 1.35), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[7][1].first, 1.35, 1e-5);
     BOOST_REQUIRE(sortedOutput[7][2].second == 2);
-    BOOST_REQUIRE_CLOSE(sortedOutput[7][2].first, (1.45 * 1.45), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[7][2].first, 1.45, 1e-5);
     BOOST_REQUIRE(sortedOutput[7][3].second == 1);
-    BOOST_REQUIRE_CLOSE(sortedOutput[7][3].first, (1.65 * 1.65), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[7][3].first, 1.65, 1e-5);
     BOOST_REQUIRE(sortedOutput[7][4].second == 8);
-    BOOST_REQUIRE_CLOSE(sortedOutput[7][4].first, (1.75 * 1.75), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[7][4].first, 1.75, 1e-5);
     BOOST_REQUIRE(sortedOutput[7][5].second == 9);
-    BOOST_REQUIRE_CLOSE(sortedOutput[7][5].first, (2.20 * 2.20), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[7][5].first, 2.20, 1e-5);
     BOOST_REQUIRE(sortedOutput[7][6].second == 10);
-    BOOST_REQUIRE_CLOSE(sortedOutput[7][6].first, (2.30 * 2.30), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[7][6].first, 2.30, 1e-5);
     BOOST_REQUIRE(sortedOutput[7][7].second == 3);
-    BOOST_REQUIRE_CLOSE(sortedOutput[7][7].first, (2.55 * 2.55), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[7][7].first, 2.55, 1e-5);
     BOOST_REQUIRE(sortedOutput[7][8].second == 4);
-    BOOST_REQUIRE_CLOSE(sortedOutput[7][8].first, (6.35 * 6.35), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[7][8].first, 6.35, 1e-5);
 
     // Neighbors of point 8.
     BOOST_REQUIRE(sortedOutput[8].size() == 3);
     BOOST_REQUIRE(sortedOutput[8][0].second == 7);
-    BOOST_REQUIRE_CLOSE(sortedOutput[8][0].first, (1.75 * 1.75), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[8][0].first, 1.75, 1e-5);
     BOOST_REQUIRE(sortedOutput[8][1].second == 6);
-    BOOST_REQUIRE_CLOSE(sortedOutput[8][1].first, (2.45 * 2.45), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[8][1].first, 2.45, 1e-5);
     BOOST_REQUIRE(sortedOutput[8][2].second == 4);
-    BOOST_REQUIRE_CLOSE(sortedOutput[8][2].first, (4.60 * 4.60), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[8][2].first, 4.60, 1e-5);
 
     // Neighbors of point 9.
     BOOST_REQUIRE(sortedOutput[9].size() == 4);
     BOOST_REQUIRE(sortedOutput[9][0].second == 5);
-    BOOST_REQUIRE_CLOSE(sortedOutput[9][0].first, (1.12 * 1.12), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[9][0].first, 1.12, 1e-5);
     BOOST_REQUIRE(sortedOutput[9][1].second == 7);
-    BOOST_REQUIRE_CLOSE(sortedOutput[9][1].first, (2.20 * 2.20), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[9][1].first, 2.20, 1e-5);
     BOOST_REQUIRE(sortedOutput[9][2].second == 6);
-    BOOST_REQUIRE_CLOSE(sortedOutput[9][2].first, (2.90 * 2.90), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[9][2].first, 2.90, 1e-5);
     BOOST_REQUIRE(sortedOutput[9][3].second == 4);
-    BOOST_REQUIRE_CLOSE(sortedOutput[9][3].first, (4.15 * 4.15), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[9][3].first, 4.15, 1e-5);
 
     // Neighbors of point 10.
     BOOST_REQUIRE(sortedOutput[10].size() == 4);
     BOOST_REQUIRE(sortedOutput[10][0].second == 5);
-    BOOST_REQUIRE_CLOSE(sortedOutput[10][0].first, (1.22 * 1.22), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[10][0].first, 1.22, 1e-5);
     BOOST_REQUIRE(sortedOutput[10][1].second == 7);
-    BOOST_REQUIRE_CLOSE(sortedOutput[10][1].first, (2.30 * 2.30), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[10][1].first, 2.30, 1e-5);
     BOOST_REQUIRE(sortedOutput[10][2].second == 6);
-    BOOST_REQUIRE_CLOSE(sortedOutput[10][2].first, (3.00 * 3.00), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[10][2].first, 3.00, 1e-5);
     BOOST_REQUIRE(sortedOutput[10][3].second == 4);
-    BOOST_REQUIRE_CLOSE(sortedOutput[10][3].first, (4.05 * 4.05), 1e-5);
+    BOOST_REQUIRE_CLOSE(sortedOutput[10][3].first, 4.05, 1e-5);
 
     // Clean the memory.
     delete rs;
@@ -486,7 +484,7 @@
   if (!data::Load("test_data_3_1000.csv", dataForTree))
     BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
 
-  // Set up matrices to work with (may not be necessary with no ALIAS_MATRIX?).
+  // Set up matrices to work with.
   arma::mat dualQuery(dataForTree);
   arma::mat naiveQuery(dataForTree);
 
@@ -569,4 +567,222 @@
   }
 }
 
+/**
+ * Ensure that range search with cover trees returns the same neighbors and
+ * distances as the kd-tree implementation, for several different ranges.
+ */
+BOOST_AUTO_TEST_CASE(RangeSearchCoverTreeTest)
+{
+  arma::mat data;
+  data.randu(8, 1000); // 1000 points in 8 dimensions.
+
+  // Set up kd-tree range search.
+  RangeSearch<> kdsearch(data);
+  // Set up cover tree range search.
+  tree::CoverTree<> tree(data);
+  RangeSearch<metric::EuclideanDistance, tree::CoverTree<> >
+      coversearch(&tree, data);
+
+  // Four trials with different ranges.
+  for (size_t r = 0; r < 4; ++r)
+  {
+    Range range;
+    switch (r)
+    {
+      case 0:
+        // Includes zero distance.
+        range = Range(0.0, 0.75);
+        break;
+      case 1:
+        // A bounded range on both sides.
+        range = Range(0.5, 1.5);
+        break;
+      case 2:
+        // A range with no upper bound.
+        range = Range(0.8, numeric_limits<double>::infinity());
+        break;
+      case 3:
+        // No results: points are in [0, 1]^8, so distances are <= sqrt(8).
+        range = Range(15.6, 15.7);
+        break;
+    }
+
+    // Results for the kd-tree and cover tree searches.
+    vector<vector<size_t> > kdNeighbors;
+    vector<vector<double> > kdDistances;
+    vector<vector<size_t> > coverNeighbors;
+    vector<vector<double> > coverDistances;
+
+    // Run the searches.
+    kdsearch.Search(range, kdNeighbors, kdDistances);
+    coversearch.Search(range, coverNeighbors, coverDistances);
+
+    // Sort before comparison.
+    vector<vector<pair<double, size_t> > > kdSorted;
+    vector<vector<pair<double, size_t> > > coverSorted;
+    SortResults(kdNeighbors, kdDistances, kdSorted);
+    SortResults(coverNeighbors, coverDistances, coverSorted);
+
+    // Now compare the results.  Sizes are checked before any element access
+    // so that a mismatch fails the test instead of indexing out of bounds.
+    BOOST_REQUIRE_EQUAL(kdSorted.size(), coverSorted.size());
+    for (size_t i = 0; i < kdSorted.size(); ++i)
+    {
+      BOOST_REQUIRE_EQUAL(kdSorted[i].size(), coverSorted[i].size());
+      for (size_t j = 0; j < kdSorted[i].size(); ++j)
+      {
+        BOOST_REQUIRE_EQUAL(kdSorted[i][j].second, coverSorted[i][j].second);
+        BOOST_REQUIRE_CLOSE(kdSorted[i][j].first, coverSorted[i][j].first,
+            1e-5);
+      }
+    }
+  }
+}
+
+/**
+ * Ensure cover tree range search with two datasets matches kd-tree search.
+ */
+BOOST_AUTO_TEST_CASE(RangeSearchCoverTreeTwoDatasetsTest)
+{
+  arma::mat data;
+  data.randu(8, 1000); // 1000 points in 8 dimensions.
+  arma::mat queries;
+  queries.randu(8, 350); // 350 points in 8 dimensions.
+
+  // Set up kd-tree range search.
+  RangeSearch<> kdsearch(data, queries);
+  // Set up cover tree range search.
+  tree::CoverTree<> tree(data);
+  tree::CoverTree<> queryTree(queries);
+  RangeSearch<metric::EuclideanDistance, tree::CoverTree<> >
+      coversearch(&tree, &queryTree, data, queries);
+
+  // Four trials with different ranges.
+  for (size_t r = 0; r < 4; ++r)
+  {
+    Range range;
+    switch (r)
+    {
+      case 0:
+        // Includes zero distance.
+        range = Range(0.0, 0.75);
+        break;
+      case 1:
+        // A bounded range on both sides.
+        range = Range(0.5, 1.5);
+        break;
+      case 2:
+        // A range with no upper bound.
+        range = Range(0.8, numeric_limits<double>::infinity());
+        break;
+      case 3:
+        // No results: points are in [0, 1]^8, so distances are <= sqrt(8).
+        range = Range(15.6, 15.7);
+        break;
+    }
+
+    // Results for kd-tree search.
+    vector<vector<size_t> > kdNeighbors;
+    vector<vector<double> > kdDistances;
+
+    // Results for cover tree search.
+    vector<vector<size_t> > coverNeighbors;
+    vector<vector<double> > coverDistances;
+
+    // Run the searches.
+    kdsearch.Search(range, kdNeighbors, kdDistances);
+    coversearch.Search(range, coverNeighbors, coverDistances);
+
+    // Sort before comparison.
+    vector<vector<pair<double, size_t> > > kdSorted;
+    vector<vector<pair<double, size_t> > > coverSorted;
+    SortResults(kdNeighbors, kdDistances, kdSorted);
+    SortResults(coverNeighbors, coverDistances, coverSorted);
+
+    // Now compare the results, checking sizes before any element access.
+    BOOST_REQUIRE_EQUAL(kdSorted.size(), coverSorted.size());
+    for (size_t i = 0; i < kdSorted.size(); ++i)
+    {
+      BOOST_REQUIRE_EQUAL(kdSorted[i].size(), coverSorted[i].size());
+      for (size_t j = 0; j < kdSorted[i].size(); ++j)
+      {
+        BOOST_REQUIRE_EQUAL(kdSorted[i][j].second, coverSorted[i][j].second);
+        BOOST_REQUIRE_CLOSE(kdSorted[i][j].first, coverSorted[i][j].first,
+            1e-5);
+      }
+    }
+  }
+}
+
+/**
+ * Ensure single-tree cover tree range search matches naive search results.
+ */
+BOOST_AUTO_TEST_CASE(RangeSearchCoverTreeSingleTreeTest)
+{
+  arma::mat data;
+  data.randu(8, 1000); // 1000 points in 8 dimensions.
+
+  // Set up naive (brute-force) range search as the baseline.
+  RangeSearch<> naivesearch(data, true);
+  // Set up cover tree range search.
+  tree::CoverTree<> tree(data);
+  RangeSearch<metric::EuclideanDistance, tree::CoverTree<> >
+      coversearch(&tree, data, true);
+
+  // Four trials with different ranges.
+  for (size_t r = 0; r < 4; ++r)
+  {
+    Range range;
+    switch (r)
+    {
+      case 0:
+        // Includes zero distance.
+        range = Range(0.0, 0.75);
+        break;
+      case 1:
+        // A bounded range on both sides.
+        range = Range(0.5, 1.5);
+        break;
+      case 2:
+        // A range with no upper bound.
+        range = Range(0.8, numeric_limits<double>::infinity());
+        break;
+      case 3:
+        // No results: points are in [0, 1]^8, so distances are <= sqrt(8).
+        range = Range(15.6, 15.7);
+        break;
+    }
+
+    // Results for the naive and cover tree searches.
+    vector<vector<size_t> > naiveNeighbors;
+    vector<vector<double> > naiveDistances;
+    vector<vector<size_t> > coverNeighbors;
+    vector<vector<double> > coverDistances;
+
+    // Run the searches.
+    naivesearch.Search(range, naiveNeighbors, naiveDistances);
+    coversearch.Search(range, coverNeighbors, coverDistances);
+
+    // Sort before comparison.
+    vector<vector<pair<double, size_t> > > naiveSorted;
+    vector<vector<pair<double, size_t> > > coverSorted;
+    SortResults(naiveNeighbors, naiveDistances, naiveSorted);
+    SortResults(coverNeighbors, coverDistances, coverSorted);
+
+    // Now compare the results.  Sizes are checked before any element access
+    // so that a mismatch fails the test instead of indexing out of bounds.
+    BOOST_REQUIRE_EQUAL(naiveSorted.size(), coverSorted.size());
+    for (size_t i = 0; i < naiveSorted.size(); ++i)
+    {
+      BOOST_REQUIRE_EQUAL(naiveSorted[i].size(), coverSorted[i].size());
+      for (size_t j = 0; j < naiveSorted[i].size(); ++j)
+      {
+        BOOST_REQUIRE_EQUAL(naiveSorted[i][j].second, coverSorted[i][j].second);
+        BOOST_REQUIRE_CLOSE(naiveSorted[i][j].first, coverSorted[i][j].first,
+            1e-5);
+      }
+    }
+  }
+}
+
 BOOST_AUTO_TEST_SUITE_END();




More information about the mlpack-svn mailing list