[mlpack-svn] r15920 - mlpack/trunk/src/mlpack/tests

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Oct 4 00:30:32 EDT 2013


Author: rcurtin
Date: Fri Oct  4 00:30:31 2013
New Revision: 15920

Log:
Test RANN with cover trees.


Modified:
   mlpack/trunk/src/mlpack/tests/allkrann_search_test.cpp

Modified: mlpack/trunk/src/mlpack/tests/allkrann_search_test.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/tests/allkrann_search_test.cpp	(original)
+++ mlpack/trunk/src/mlpack/tests/allkrann_search_test.cpp	Fri Oct  4 00:30:31 2013
@@ -7,6 +7,7 @@
 #include <time.h>
 #include <mlpack/core.hpp>
 #include <mlpack/core/metrics/lmetric.hpp>
+#include <mlpack/core/tree/cover_tree.hpp>
 
 #include <boost/test/unit_test.hpp>
 #include "old_boost_test_definitions.hpp"
@@ -22,7 +23,7 @@
 // Test AllkRANN in naive mode for exact results when the random seeds are set
 // the same.  This may not be the best test; if the implementation of RANN-RS
 // gets random numbers in a different way, then this test might fail.
-BOOST_AUTO_TEST_CASE(AllkRANNNaiveSearchExact)
+BOOST_AUTO_TEST_CASE(NaiveSearchExact)
 {
   // First test on a small set.
   arma::mat rdata(2, 10);
@@ -86,13 +87,13 @@
   // 3. Check the neighbor returned.
   for (size_t i = 0; i < qdata.n_cols; i++)
   {
-    BOOST_REQUIRE(neighbors(0, i) == rann[i]);
+    BOOST_REQUIRE_EQUAL(neighbors(0, i), rann[i]);
     BOOST_REQUIRE_CLOSE(distances(0, i), rannDistances[i], 1e-5);
   }
 }
 
 // Test the correctness and guarantees of AllkRANN when in naive mode.
-BOOST_AUTO_TEST_CASE(AllkRANNNaiveGuaranteeTest)
+BOOST_AUTO_TEST_CASE(NaiveGuaranteeTest)
 {
   arma::Mat<size_t> neighbors;
   arma::mat distances;
@@ -143,12 +144,12 @@
   // 5% of 100 queries is 5.
   size_t maxNumQueriesFail = 6;
 
-  BOOST_REQUIRE(numQueriesFail < maxNumQueriesFail);
+  BOOST_REQUIRE_LT(numQueriesFail, maxNumQueriesFail);
 }
 
 // Test single-tree rank-approximate search (harder to test because of
 // the randomness involved).
-BOOST_AUTO_TEST_CASE(AllkRANNSingleTreeSearch)
+BOOST_AUTO_TEST_CASE(SingleTreeSearch)
 {
   arma::mat refData;
   arma::mat queryData;
@@ -202,12 +203,12 @@
   // 5% of 100 queries is 5.
   size_t maxNumQueriesFail = 6;
 
-  BOOST_REQUIRE(numQueriesFail < maxNumQueriesFail);
+  BOOST_REQUIRE_LT(numQueriesFail, maxNumQueriesFail);
 }
 
 // Test dual-tree rank-approximate search (harder to test because of the
 // randomness involved).
-BOOST_AUTO_TEST_CASE(AllkRANNDualTreeSearch)
+BOOST_AUTO_TEST_CASE(DualTreeSearch)
 {
   arma::mat refData;
   arma::mat queryData;
@@ -262,12 +263,12 @@
   // 5% of 100 queries is 5.
   size_t maxNumQueriesFail = 6;
 
-  BOOST_REQUIRE(numQueriesFail < maxNumQueriesFail);
+  BOOST_REQUIRE_LT(numQueriesFail, maxNumQueriesFail);
 }
 
 // Test rank-approximate search with just a single dataset.  These tests just
 // ensure that the method runs okay.
-BOOST_AUTO_TEST_CASE(AllkRANNSingleDatasetNaiveSearch)
+BOOST_AUTO_TEST_CASE(SingleDatasetNaiveSearch)
 {
   arma::mat dataset(5, 2500);
   dataset.randn();
@@ -287,7 +288,7 @@
 
 // Test rank-approximate search with just a single dataset in single-tree mode.
 // These tests just ensure that the method runs okay.
-BOOST_AUTO_TEST_CASE(AllkRANNSingleDatasetSingleSearch)
+BOOST_AUTO_TEST_CASE(SingleDatasetSingleSearch)
 {
   arma::mat dataset(5, 2500);
   dataset.randn();
@@ -307,7 +308,7 @@
 
 // Test rank-approximate search with just a single dataset in dual-tree mode.
 // These tests just ensure that the method runs okay.
-BOOST_AUTO_TEST_CASE(AllkRANNSingleDatasetSearch)
+BOOST_AUTO_TEST_CASE(SingleDatasetSearch)
 {
   arma::mat dataset(5, 2500);
   dataset.randn();
@@ -324,4 +325,135 @@
   BOOST_REQUIRE_EQUAL(distances.n_cols, 2500);
 }
 
+// Test single-tree rank-approximate search (RANN-TSS) with cover trees.
+BOOST_AUTO_TEST_CASE(SingleCoverTreeTest)
+{
+  arma::mat refData;
+  arma::mat queryData;
+
+  data::Load("rann_test_r_3_900.csv", refData, true);
+  data::Load("rann_test_q_3_100.csv", queryData, true);
+
+  // Search for 1 rank-approximate nearest neighbor within the top 1% of the
+  // 900 reference points (rank error below 10), repeated over many rounds.
+  arma::Mat<size_t> neighbors;
+  arma::mat distances;
+
+  typedef tree::CoverTree<metric::EuclideanDistance, tree::FirstPointIsRoot,
+      RAQueryStat<NearestNeighborSort> > TreeType;
+  typedef RASearch<NearestNeighborSort, metric::EuclideanDistance, TreeType>
+      RACoverTreeSearch;
+
+  TreeType refTree(refData);
+  RACoverTreeSearch tssRann(&refTree, NULL, refData, queryData, true);
+
+  // Precomputed ranks: qrRanks(i, j) is the rank of reference j for query i.
+  arma::Mat<size_t> qrRanks;
+  data::Load("rann_test_qr_ranks.csv", qrRanks, true, false); // No transpose.
+
+  size_t numRounds = 1000; // Repeat the randomized search this many times.
+  arma::Col<size_t> numSuccessRounds(queryData.n_cols); // Successes per query.
+  numSuccessRounds.fill(0);
+
+  // 1% of 900 is 9, so the rank is expected to be less than 10.
+  size_t expectedRankErrorUB = 10;
+
+  for (size_t rounds = 0; rounds < numRounds; rounds++)
+  {
+    tssRann.Search(1, neighbors, distances, 1.0, 0.95, false, false, 5);
+
+    for (size_t i = 0; i < queryData.n_cols; i++)
+      if (qrRanks(i, neighbors(0, i)) < expectedRankErrorUB) // Within rank.
+        numSuccessRounds[i]++;
+
+    neighbors.reset(); // Start each round with empty outputs.
+    distances.reset();
+  }
+
+  // Success-count threshold: numRounds times (0.95 minus 1.96 standard errors
+  // of a proportion), so a query with the 95% guarantee rarely falls below.
+  size_t threshold = floor(numRounds *
+      (0.95 - (1.96 * sqrt(0.95 * 0.05 / numRounds))));
+  size_t numQueriesFail = 0;
+  for (size_t i = 0; i < queryData.n_cols; i++)
+    if (numSuccessRounds[i] < threshold) // Guarantee missed for this query.
+      numQueriesFail++;
+
+  Log::Warn << "RANN-TSS (cover tree): RANN guarantee fails on "
+      << numQueriesFail << " queries." << endl;
+
+  // Assert that at most 5% of the queries fall out of this threshold.
+  // 5% of 100 queries is 5.
+  size_t maxNumQueriesFail = 6; // Strict bound: allows up to 5 failures.
+
+  BOOST_REQUIRE_LT(numQueriesFail, maxNumQueriesFail);
+}
+
+// Test dual-tree rank-approximate search (RANN-TSD) with cover trees.
+BOOST_AUTO_TEST_CASE(DualCoverTreeTest)
+{
+  arma::mat refData;
+  arma::mat queryData;
+
+  data::Load("rann_test_r_3_900.csv", refData, true);
+  data::Load("rann_test_q_3_100.csv", queryData, true);
+
+  // Search for 1 rank-approximate nearest neighbor within the top 1% of the
+  // 900 reference points (rank error below 10), repeated over many rounds.
+  arma::Mat<size_t> neighbors;
+  arma::mat distances;
+
+  typedef tree::CoverTree<metric::EuclideanDistance, tree::FirstPointIsRoot,
+      RAQueryStat<NearestNeighborSort> > TreeType;
+  typedef RASearch<NearestNeighborSort, metric::EuclideanDistance, TreeType>
+      RACoverTreeSearch;
+
+  TreeType refTree(refData);
+  TreeType queryTree(queryData); // Separate query tree for dual-tree search.
+
+  RACoverTreeSearch tsdRann(&refTree, &queryTree, refData, queryData, false);
+
+  arma::Mat<size_t> qrRanks; // qrRanks(i, j): rank of reference j for query i.
+  data::Load("rann_test_qr_ranks.csv", qrRanks, true, false); // No transpose.
+
+  size_t numRounds = 1000; // Repeat the randomized search this many times.
+  arma::Col<size_t> numSuccessRounds(queryData.n_cols); // Successes per query.
+  numSuccessRounds.fill(0);
+
+  // 1% of 900 is 9, so the rank is expected to be less than 10.
+  size_t expectedRankErrorUB = 10;
+
+  for (size_t rounds = 0; rounds < numRounds; rounds++)
+  {
+    tsdRann.Search(1, neighbors, distances, 1.0, 0.95, false, false, 5);
+
+    for (size_t i = 0; i < queryData.n_cols; i++)
+      if (qrRanks(i, neighbors(0, i)) < expectedRankErrorUB) // Within rank.
+        numSuccessRounds[i]++;
+
+    neighbors.reset(); // Start each round with empty outputs.
+    distances.reset();
+
+    tsdRann.ResetQueryTree(); // Presumably resets query stats per round.
+  }
+
+  // Success-count threshold: numRounds times (0.95 minus 1.96 standard errors
+  // of a proportion), so a query with the 95% guarantee rarely falls below.
+  size_t threshold = floor(numRounds *
+      (0.95 - (1.96 * sqrt(0.95 * 0.05 / numRounds))));
+  size_t numQueriesFail = 0;
+  for (size_t i = 0; i < queryData.n_cols; i++)
+    if (numSuccessRounds[i] < threshold) // Guarantee missed for this query.
+      numQueriesFail++;
+
+  Log::Warn << "RANN-TSD (cover tree): RANN guarantee fails on "
+      << numQueriesFail << " queries." << endl;
+
+  // Assert that at most 5% of the queries fall out of this threshold.
+  // 5% of 100 queries is 5.
+  size_t maxNumQueriesFail = 6; // Strict bound: allows up to 5 failures.
+
+  BOOST_REQUIRE_LT(numQueriesFail, maxNumQueriesFail);
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-svn mailing list