[mlpack-git] master: Refactor test. (3977306)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Apr 23 14:42:17 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/3ad38770911c7b9840901f0934bd1a81c2249046...9763c578d44db92496c044bbda812cf1af49b9a8
>---------------------------------------------------------------
commit 3977306e0f6f34e1c175d48f265f308b376e47fb
Author: ryan <ryan at ratml.org>
Date: Thu Apr 23 14:41:54 2015 -0400
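Refactor test. Update allkrann_search_test.cpp for the new RASearch API: search parameters (tau, alpha, sampling flags, single-sample limit) are now passed to the constructor, and the query set or a prebuilt query tree is passed to Search() and ResetQueryTree().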
>---------------------------------------------------------------
3977306e0f6f34e1c175d48f265f308b376e47fb
src/mlpack/tests/allkrann_search_test.cpp | 35 +++++++++++++++++++------------
1 file changed, 22 insertions(+), 13 deletions(-)
diff --git a/src/mlpack/tests/allkrann_search_test.cpp b/src/mlpack/tests/allkrann_search_test.cpp
index 9a76bce..bd6e043 100644
--- a/src/mlpack/tests/allkrann_search_test.cpp
+++ b/src/mlpack/tests/allkrann_search_test.cpp
@@ -35,7 +35,7 @@ BOOST_AUTO_TEST_CASE(NaiveGuaranteeTest)
data::Load("rann_test_r_3_900.csv", refData, true);
data::Load("rann_test_q_3_100.csv", queryData, true);
- RASearch<> rsRann(refData, queryData, true);
+ RASearch<> rsRann(refData, true, false, 1.0);
arma::mat qrRanks;
data::Load("rann_test_qr_ranks.csv", qrRanks, true, false); // No transpose.
@@ -49,7 +49,7 @@ BOOST_AUTO_TEST_CASE(NaiveGuaranteeTest)
for (size_t rounds = 0; rounds < numRounds; rounds++)
{
- rsRann.Search(1, neighbors, distances, 1.0);
+ rsRann.Search(queryData, 1, neighbors, distances);
for (size_t i = 0; i < queryData.n_cols; i++)
if (qrRanks(i, neighbors(0, i)) < expectedRankErrorUB)
@@ -93,7 +93,7 @@ BOOST_AUTO_TEST_CASE(SingleTreeSearch)
arma::Mat<size_t> neighbors;
arma::mat distances;
- RASearch<> tssRann(refData, queryData, false, true);
+ RASearch<> tssRann(refData, false, true, 1.0, 0.95, false, false);
// The relative ranks for the given query reference pair
arma::Mat<size_t> qrRanks;
@@ -108,7 +108,7 @@ BOOST_AUTO_TEST_CASE(SingleTreeSearch)
for (size_t rounds = 0; rounds < numRounds; rounds++)
{
- tssRann.Search(1, neighbors, distances, 1.0, 0.95, false, false);
+ tssRann.Search(queryData, 1, neighbors, distances);
for (size_t i = 0; i < queryData.n_cols; i++)
if (qrRanks(i, neighbors(0, i)) < expectedRankErrorUB)
@@ -152,7 +152,7 @@ BOOST_AUTO_TEST_CASE(DualTreeSearch)
arma::Mat<size_t> neighbors;
arma::mat distances;
- RASearch<> tsdRann(refData, queryData, false, false);
+ RASearch<> tsdRann(refData, false, false, 1.0, 0.95, false, false, 5);
arma::Mat<size_t> qrRanks;
data::Load("rann_test_qr_ranks.csv", qrRanks, true, false); // No transpose.
@@ -164,18 +164,27 @@ BOOST_AUTO_TEST_CASE(DualTreeSearch)
// 1% of 900 is 9, so the rank is expected to be less than 10.
size_t expectedRankErrorUB = 10;
+ // Build query tree by hand.
+ typedef tree::BinarySpaceTree<bound::HRectBound<2, false>,
+ RAQueryStat<NearestNeighborSort>> TreeType;
+ std::vector<size_t> oldFromNewQueries;
+ TreeType queryTree(queryData, oldFromNewQueries);
+
for (size_t rounds = 0; rounds < numRounds; rounds++)
{
- tsdRann.Search(1, neighbors, distances, 1.0, 0.95, false, false, 5);
+ tsdRann.Search(&queryTree, 1, neighbors, distances);
for (size_t i = 0; i < queryData.n_cols; i++)
- if (qrRanks(i, neighbors(0, i)) < expectedRankErrorUB)
+ {
+ const size_t oldIndex = oldFromNewQueries[i];
+ if (qrRanks(oldIndex, neighbors(0, i)) < expectedRankErrorUB)
numSuccessRounds[i]++;
+ }
neighbors.reset();
distances.reset();
- tsdRann.ResetQueryTree();
+ tsdRann.ResetQueryTree(&queryTree);
}
// Find the 95%-tile threshold so that 95% of the queries should pass this
@@ -275,7 +284,7 @@ BOOST_AUTO_TEST_CASE(SingleCoverTreeTest)
typedef RASearch<NearestNeighborSort, metric::EuclideanDistance, TreeType>
RACoverTreeSearch;
- RACoverTreeSearch tssRann(refData, queryData, false, true);
+ RACoverTreeSearch tssRann(refData, false, true, 1.0, 0.95, false, false, 5);
// The relative ranks for the given query reference pair.
arma::Mat<size_t> qrRanks;
@@ -290,7 +299,7 @@ BOOST_AUTO_TEST_CASE(SingleCoverTreeTest)
for (size_t rounds = 0; rounds < numRounds; rounds++)
{
- tssRann.Search(1, neighbors, distances, 1.0, 0.95, false, false, 5);
+ tssRann.Search(queryData, 1, neighbors, distances);
for (size_t i = 0; i < queryData.n_cols; i++)
if (qrRanks(i, neighbors(0, i)) < expectedRankErrorUB)
@@ -341,7 +350,7 @@ BOOST_AUTO_TEST_CASE(DualCoverTreeTest)
TreeType refTree(refData);
TreeType queryTree(queryData);
- RACoverTreeSearch tsdRann(&refTree, &queryTree, refData, queryData, false);
+ RACoverTreeSearch tsdRann(&refTree, false, 1.0, 0.95, false, false, 5);
arma::Mat<size_t> qrRanks;
data::Load("rann_test_qr_ranks.csv", qrRanks, true, false); // No transpose.
@@ -355,7 +364,7 @@ BOOST_AUTO_TEST_CASE(DualCoverTreeTest)
for (size_t rounds = 0; rounds < numRounds; rounds++)
{
- tsdRann.Search(1, neighbors, distances, 1.0, 0.95, false, false, 5);
+ tsdRann.Search(&queryTree, 1, neighbors, distances);
for (size_t i = 0; i < queryData.n_cols; i++)
if (qrRanks(i, neighbors(0, i)) < expectedRankErrorUB)
@@ -364,7 +373,7 @@ BOOST_AUTO_TEST_CASE(DualCoverTreeTest)
neighbors.reset();
distances.reset();
- tsdRann.ResetQueryTree();
+ tsdRann.ResetQueryTree(&queryTree);
}
// Find the 95%-tile threshold so that 95% of the queries should pass this
More information about the mlpack-git
mailing list