[mlpack-git] master: Refactor test. (3977306)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Thu Apr 23 14:42:17 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/3ad38770911c7b9840901f0934bd1a81c2249046...9763c578d44db92496c044bbda812cf1af49b9a8

>---------------------------------------------------------------

commit 3977306e0f6f34e1c175d48f265f308b376e47fb
Author: ryan <ryan at ratml.org>
Date:   Thu Apr 23 14:41:54 2015 -0400

    Refactor test.


>---------------------------------------------------------------

3977306e0f6f34e1c175d48f265f308b376e47fb
 src/mlpack/tests/allkrann_search_test.cpp | 35 +++++++++++++++++++------------
 1 file changed, 22 insertions(+), 13 deletions(-)

diff --git a/src/mlpack/tests/allkrann_search_test.cpp b/src/mlpack/tests/allkrann_search_test.cpp
index 9a76bce..bd6e043 100644
--- a/src/mlpack/tests/allkrann_search_test.cpp
+++ b/src/mlpack/tests/allkrann_search_test.cpp
@@ -35,7 +35,7 @@ BOOST_AUTO_TEST_CASE(NaiveGuaranteeTest)
   data::Load("rann_test_r_3_900.csv", refData, true);
   data::Load("rann_test_q_3_100.csv", queryData, true);
 
-  RASearch<> rsRann(refData, queryData, true);
+  RASearch<> rsRann(refData, true, false, 1.0);
 
   arma::mat qrRanks;
   data::Load("rann_test_qr_ranks.csv", qrRanks, true, false); // No transpose.
@@ -49,7 +49,7 @@ BOOST_AUTO_TEST_CASE(NaiveGuaranteeTest)
 
   for (size_t rounds = 0; rounds < numRounds; rounds++)
   {
-    rsRann.Search(1, neighbors, distances, 1.0);
+    rsRann.Search(queryData, 1, neighbors, distances);
 
     for (size_t i = 0; i < queryData.n_cols; i++)
       if (qrRanks(i, neighbors(0, i)) < expectedRankErrorUB)
@@ -93,7 +93,7 @@ BOOST_AUTO_TEST_CASE(SingleTreeSearch)
   arma::Mat<size_t> neighbors;
   arma::mat distances;
 
-  RASearch<> tssRann(refData, queryData, false, true);
+  RASearch<> tssRann(refData, false, true, 1.0, 0.95, false, false);
 
   // The relative ranks for the given query reference pair
   arma::Mat<size_t> qrRanks;
@@ -108,7 +108,7 @@ BOOST_AUTO_TEST_CASE(SingleTreeSearch)
 
   for (size_t rounds = 0; rounds < numRounds; rounds++)
   {
-    tssRann.Search(1, neighbors, distances, 1.0, 0.95, false, false);
+    tssRann.Search(queryData, 1, neighbors, distances);
 
     for (size_t i = 0; i < queryData.n_cols; i++)
       if (qrRanks(i, neighbors(0, i)) < expectedRankErrorUB)
@@ -152,7 +152,7 @@ BOOST_AUTO_TEST_CASE(DualTreeSearch)
   arma::Mat<size_t> neighbors;
   arma::mat distances;
 
-  RASearch<> tsdRann(refData, queryData, false, false);
+  RASearch<> tsdRann(refData, false, false, 1.0, 0.95, false, false, 5);
 
   arma::Mat<size_t> qrRanks;
   data::Load("rann_test_qr_ranks.csv", qrRanks, true, false); // No transpose.
@@ -164,18 +164,27 @@ BOOST_AUTO_TEST_CASE(DualTreeSearch)
   // 1% of 900 is 9, so the rank is expected to be less than 10.
   size_t expectedRankErrorUB = 10;
 
+  // Build query tree by hand.
+  typedef tree::BinarySpaceTree<bound::HRectBound<2, false>,
+      RAQueryStat<NearestNeighborSort>> TreeType;
+  std::vector<size_t> oldFromNewQueries;
+  TreeType queryTree(queryData, oldFromNewQueries);
+
   for (size_t rounds = 0; rounds < numRounds; rounds++)
   {
-    tsdRann.Search(1, neighbors, distances, 1.0, 0.95, false, false, 5);
+    tsdRann.Search(&queryTree, 1, neighbors, distances);
 
     for (size_t i = 0; i < queryData.n_cols; i++)
-      if (qrRanks(i, neighbors(0, i)) < expectedRankErrorUB)
+    {
+      const size_t oldIndex = oldFromNewQueries[i];
+      if (qrRanks(oldIndex, neighbors(0, i)) < expectedRankErrorUB)
         numSuccessRounds[i]++;
+    }
 
     neighbors.reset();
     distances.reset();
 
-    tsdRann.ResetQueryTree();
+    tsdRann.ResetQueryTree(&queryTree);
   }
 
   // Find the 95%-tile threshold so that 95% of the queries should pass this
@@ -275,7 +284,7 @@ BOOST_AUTO_TEST_CASE(SingleCoverTreeTest)
   typedef RASearch<NearestNeighborSort, metric::EuclideanDistance, TreeType>
       RACoverTreeSearch;
 
-  RACoverTreeSearch tssRann(refData, queryData, false, true);
+  RACoverTreeSearch tssRann(refData, false, true, 1.0, 0.95, false, false, 5);
 
   // The relative ranks for the given query reference pair.
   arma::Mat<size_t> qrRanks;
@@ -290,7 +299,7 @@ BOOST_AUTO_TEST_CASE(SingleCoverTreeTest)
 
   for (size_t rounds = 0; rounds < numRounds; rounds++)
   {
-    tssRann.Search(1, neighbors, distances, 1.0, 0.95, false, false, 5);
+    tssRann.Search(queryData, 1, neighbors, distances);
 
     for (size_t i = 0; i < queryData.n_cols; i++)
       if (qrRanks(i, neighbors(0, i)) < expectedRankErrorUB)
@@ -341,7 +350,7 @@ BOOST_AUTO_TEST_CASE(DualCoverTreeTest)
   TreeType refTree(refData);
   TreeType queryTree(queryData);
 
-  RACoverTreeSearch tsdRann(&refTree, &queryTree, refData, queryData, false);
+  RACoverTreeSearch tsdRann(&refTree, false, 1.0, 0.95, false, false, 5);
 
   arma::Mat<size_t> qrRanks;
   data::Load("rann_test_qr_ranks.csv", qrRanks, true, false); // No transpose.
@@ -355,7 +364,7 @@ BOOST_AUTO_TEST_CASE(DualCoverTreeTest)
 
   for (size_t rounds = 0; rounds < numRounds; rounds++)
   {
-    tsdRann.Search(1, neighbors, distances, 1.0, 0.95, false, false, 5);
+    tsdRann.Search(&queryTree, 1, neighbors, distances);
 
     for (size_t i = 0; i < queryData.n_cols; i++)
       if (qrRanks(i, neighbors(0, i)) < expectedRankErrorUB)
@@ -364,7 +373,7 @@ BOOST_AUTO_TEST_CASE(DualCoverTreeTest)
     neighbors.reset();
     distances.reset();
 
-    tsdRann.ResetQueryTree();
+    tsdRann.ResetQueryTree(&queryTree);
   }
 
   // Find the 95%-tile threshold so that 95% of the queries should pass this



More information about the mlpack-git mailing list