[mlpack-svn] r15920 - mlpack/trunk/src/mlpack/tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Oct 4 00:30:32 EDT 2013
Author: rcurtin
Date: Fri Oct 4 00:30:31 2013
New Revision: 15920
Log:
Test RANN with cover trees.
Modified:
mlpack/trunk/src/mlpack/tests/allkrann_search_test.cpp
Modified: mlpack/trunk/src/mlpack/tests/allkrann_search_test.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/tests/allkrann_search_test.cpp (original)
+++ mlpack/trunk/src/mlpack/tests/allkrann_search_test.cpp Fri Oct 4 00:30:31 2013
@@ -7,6 +7,7 @@
#include <time.h>
#include <mlpack/core.hpp>
#include <mlpack/core/metrics/lmetric.hpp>
+#include <mlpack/core/tree/cover_tree.hpp>
#include <boost/test/unit_test.hpp>
#include "old_boost_test_definitions.hpp"
@@ -22,7 +23,7 @@
// Test AllkRANN in naive mode for exact results when the random seeds are set
// the same. This may not be the best test; if the implementation of RANN-RS
// gets random numbers in a different way, then this test might fail.
-BOOST_AUTO_TEST_CASE(AllkRANNNaiveSearchExact)
+BOOST_AUTO_TEST_CASE(NaiveSearchExact)
{
// First test on a small set.
arma::mat rdata(2, 10);
@@ -86,13 +87,13 @@
// 3. Check the neighbor returned.
for (size_t i = 0; i < qdata.n_cols; i++)
{
- BOOST_REQUIRE(neighbors(0, i) == rann[i]);
+ BOOST_REQUIRE_EQUAL(neighbors(0, i), rann[i]);
BOOST_REQUIRE_CLOSE(distances(0, i), rannDistances[i], 1e-5);
}
}
// Test the correctness and guarantees of AllkRANN when in naive mode.
-BOOST_AUTO_TEST_CASE(AllkRANNNaiveGuaranteeTest)
+BOOST_AUTO_TEST_CASE(NaiveGuaranteeTest)
{
arma::Mat<size_t> neighbors;
arma::mat distances;
@@ -143,12 +144,12 @@
// 5% of 100 queries is 5.
size_t maxNumQueriesFail = 6;
- BOOST_REQUIRE(numQueriesFail < maxNumQueriesFail);
+ BOOST_REQUIRE_LT(numQueriesFail, maxNumQueriesFail);
}
// Test single-tree rank-approximate search (harder to test because of
// the randomness involved).
-BOOST_AUTO_TEST_CASE(AllkRANNSingleTreeSearch)
+BOOST_AUTO_TEST_CASE(SingleTreeSearch)
{
arma::mat refData;
arma::mat queryData;
@@ -202,12 +203,12 @@
// 5% of 100 queries is 5.
size_t maxNumQueriesFail = 6;
- BOOST_REQUIRE(numQueriesFail < maxNumQueriesFail);
+ BOOST_REQUIRE_LT(numQueriesFail, maxNumQueriesFail);
}
// Test dual-tree rank-approximate search (harder to test because of the
// randomness involved).
-BOOST_AUTO_TEST_CASE(AllkRANNDualTreeSearch)
+BOOST_AUTO_TEST_CASE(DualTreeSearch)
{
arma::mat refData;
arma::mat queryData;
@@ -262,12 +263,12 @@
// 5% of 100 queries is 5.
size_t maxNumQueriesFail = 6;
- BOOST_REQUIRE(numQueriesFail < maxNumQueriesFail);
+ BOOST_REQUIRE_LT(numQueriesFail, maxNumQueriesFail);
}
// Test rank-approximate search with just a single dataset. These tests just
// ensure that the method runs okay.
-BOOST_AUTO_TEST_CASE(AllkRANNSingleDatasetNaiveSearch)
+BOOST_AUTO_TEST_CASE(SingleDatasetNaiveSearch)
{
arma::mat dataset(5, 2500);
dataset.randn();
@@ -287,7 +288,7 @@
// Test rank-approximate search with just a single dataset in single-tree mode.
// These tests just ensure that the method runs okay.
-BOOST_AUTO_TEST_CASE(AllkRANNSingleDatasetSingleSearch)
+BOOST_AUTO_TEST_CASE(SingleDatasetSingleSearch)
{
arma::mat dataset(5, 2500);
dataset.randn();
@@ -307,7 +308,7 @@
// Test rank-approximate search with just a single dataset in dual-tree mode.
// These tests just ensure that the method runs okay.
-BOOST_AUTO_TEST_CASE(AllkRANNSingleDatasetSearch)
+BOOST_AUTO_TEST_CASE(SingleDatasetSearch)
{
arma::mat dataset(5, 2500);
dataset.randn();
@@ -324,4 +325,135 @@
BOOST_REQUIRE_EQUAL(distances.n_cols, 2500);
}
+// Test single-tree rank-approximate search (RASearch) built on cover trees.
+BOOST_AUTO_TEST_CASE(SingleCoverTreeTest)
+{
+ arma::mat refData;
+ arma::mat queryData;
+
+ data::Load("rann_test_r_3_900.csv", refData, true);
+ data::Load("rann_test_q_3_100.csv", queryData, true);
+
+ // Search for 1 rank-approximate nearest neighbor per query point; presumably
+ // in the top 1% (the 1.0 passed to Search() below) -- TODO confirm.
+ arma::Mat<size_t> neighbors;
+ arma::mat distances;
+
+ typedef tree::CoverTree<metric::EuclideanDistance, tree::FirstPointIsRoot,
+ RAQueryStat<NearestNeighborSort> > TreeType;
+ typedef RASearch<NearestNeighborSort, metric::EuclideanDistance, TreeType>
+ RACoverTreeSearch;
+
+ TreeType refTree(refData);
+ RACoverTreeSearch tssRann(&refTree, NULL, refData, queryData, true);
+
+ // Relative ranks per pair; indexed below as qrRanks(query, reference).
+ arma::Mat<size_t> qrRanks;
+ data::Load("rann_test_qr_ranks.csv", qrRanks, true, false); // No transpose.
+
+ size_t numRounds = 1000;
+ arma::Col<size_t> numSuccessRounds(queryData.n_cols);
+ numSuccessRounds.fill(0);
+
+ // 1% of 900 is 9, so the rank is expected to be less than 10.
+ size_t expectedRankErrorUB = 10;
+
+ for (size_t rounds = 0; rounds < numRounds; rounds++)
+ {
+ tssRann.Search(1, neighbors, distances, 1.0, 0.95, false, false, 5);
+
+ for (size_t i = 0; i < queryData.n_cols; i++)
+ if (qrRanks(i, neighbors(0, i)) < expectedRankErrorUB)
+ numSuccessRounds[i]++;
+
+ neighbors.reset();
+ distances.reset();
+ }
+
+ // Per-query success threshold: numRounds * (0.95 - 1.96 * sqrt(0.95 * 0.05 /
+ // numRounds)), floored; a query with fewer successful rounds is a failure.
+ size_t threshold = floor(numRounds *
+ (0.95 - (1.96 * sqrt(0.95 * 0.05 / numRounds))));
+ size_t numQueriesFail = 0;
+ for (size_t i = 0; i < queryData.n_cols; i++)
+ if (numSuccessRounds[i] < threshold)
+ numQueriesFail++;
+
+ Log::Warn << "RANN-TSS (cover tree): RANN guarantee fails on "
+ << numQueriesFail << " queries." << endl;
+
+ // Assert that at most 5% of the queries fall out of this threshold.
+ // 5% of 100 queries is 5, hence the strict upper bound of 6 below.
+ size_t maxNumQueriesFail = 6;
+
+ BOOST_REQUIRE_LT(numQueriesFail, maxNumQueriesFail);
+}
+
+// Test dual-tree rank-approximate search (RASearch) built on cover trees.
+BOOST_AUTO_TEST_CASE(DualCoverTreeTest)
+{
+ arma::mat refData;
+ arma::mat queryData;
+
+ data::Load("rann_test_r_3_900.csv", refData, true);
+ data::Load("rann_test_q_3_100.csv", queryData, true);
+
+ // Search for 1 rank-approximate nearest neighbor per query point; presumably
+ // in the top 1% (the 1.0 passed to Search() below) -- TODO confirm.
+ arma::Mat<size_t> neighbors;
+ arma::mat distances;
+
+ typedef tree::CoverTree<metric::EuclideanDistance, tree::FirstPointIsRoot,
+ RAQueryStat<NearestNeighborSort> > TreeType;
+ typedef RASearch<NearestNeighborSort, metric::EuclideanDistance, TreeType>
+ RACoverTreeSearch;
+
+ TreeType refTree(refData);
+ TreeType queryTree(queryData);
+
+ RACoverTreeSearch tsdRann(&refTree, &queryTree, refData, queryData, false);
+
+ arma::Mat<size_t> qrRanks;
+ data::Load("rann_test_qr_ranks.csv", qrRanks, true, false); // No transpose.
+
+ size_t numRounds = 1000;
+ arma::Col<size_t> numSuccessRounds(queryData.n_cols);
+ numSuccessRounds.fill(0);
+
+ // 1% of 900 is 9, so the rank is expected to be less than 10.
+ size_t expectedRankErrorUB = 10;
+
+ for (size_t rounds = 0; rounds < numRounds; rounds++)
+ {
+ tsdRann.Search(1, neighbors, distances, 1.0, 0.95, false, false, 5);
+
+ for (size_t i = 0; i < queryData.n_cols; i++)
+ if (qrRanks(i, neighbors(0, i)) < expectedRankErrorUB)
+ numSuccessRounds[i]++;
+
+ neighbors.reset();
+ distances.reset();
+
+ tsdRann.ResetQueryTree(); // Presumably resets query-tree state per round.
+ }
+
+ // Per-query success threshold: numRounds * (0.95 - 1.96 * sqrt(0.95 * 0.05 /
+ // numRounds)), floored; a query with fewer successful rounds is a failure.
+ size_t threshold = floor(numRounds *
+ (0.95 - (1.96 * sqrt(0.95 * 0.05 / numRounds))));
+ size_t numQueriesFail = 0;
+ for (size_t i = 0; i < queryData.n_cols; i++)
+ if (numSuccessRounds[i] < threshold)
+ numQueriesFail++;
+
+ Log::Warn << "RANN-TSD (cover tree): RANN guarantee fails on "
+ << numQueriesFail << " queries." << endl;
+
+ // Assert that at most 5% of the queries fall out of this threshold.
+ // 5% of 100 queries is 5, hence the strict upper bound of 6 below.
+ size_t maxNumQueriesFail = 6;
+
+ BOOST_REQUIRE_LT(numQueriesFail, maxNumQueriesFail);
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-svn
mailing list