[mlpack-svn] r14347 - mlpack/trunk/src/mlpack/tests

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Feb 20 15:16:59 EST 2013


Author: rcurtin
Date: 2013-02-20 15:16:58 -0500 (Wed, 20 Feb 2013)
New Revision: 14347

Modified:
   mlpack/trunk/src/mlpack/tests/allkrann_search_test.cpp
Log:
Clean up tests; split into four (not three) tests, remove random seed setting
where appropriate.


Modified: mlpack/trunk/src/mlpack/tests/allkrann_search_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/allkrann_search_test.cpp	2013-02-20 17:55:46 UTC (rev 14346)
+++ mlpack/trunk/src/mlpack/tests/allkrann_search_test.cpp	2013-02-20 20:16:58 UTC (rev 14347)
@@ -1,7 +1,7 @@
 /**
  * @file allkrann_search_test.cpp
  *
- * Unit tests for the 'RASearch' class and consequently the 
+ * Unit tests for the 'RASearch' class and consequently the
  * 'RASearchRules' class
  */
 #include <time.h>
@@ -23,35 +23,35 @@
 
 BOOST_AUTO_TEST_SUITE(AllkRANNTest);
 
-BOOST_AUTO_TEST_CASE(AllkRANNNaiveSearch)
+// Test AllkRANN in naive mode for exact results when the random seeds are set
+// the same.  This may not be the best test; if the implementation of RANN-RS
+// gets random numbers in a different way, then this test might fail.
+BOOST_AUTO_TEST_CASE(AllkRANNNaiveSearchExact)
 {
-  // first testing on a small set.
-
+  // First test on a small set.
   arma::mat rdata(2, 10);
-  rdata << 3 << 2 << 4 << 3 << 5 << 6 << 0 << 8 << 3 << 1 << arma::endr << 
-    0 << 3 << 4 << 7 << 8 << 4 << 1 << 0 << 4 << 3 << arma::endr;
+  rdata << 3 << 2 << 4 << 3 << 5 << 6 << 0 << 8 << 3 << 1 << arma::endr <<
+           0 << 3 << 4 << 7 << 8 << 4 << 1 << 0 << 4 << 3 << arma::endr;
 
   arma::mat qdata(2, 3);
-  qdata << 3 << 2 << 0 << arma::endr << 5 << 3 << 4 << arma::endr;
+  qdata << 3 << 2 << 0 << arma::endr
+        << 5 << 3 << 4 << arma::endr;
 
-
   metric::SquaredEuclideanDistance dMetric;
   double rankApproximation = 30;
   double successProb = 0.95;
 
-  // Search for 1 rank-approximate nearest-neighbors in the top 30% 
-  // of the point (rank error of 3)
+  // Search for 1 rank-approximate nearest neighbor in the top 30% of the
+  // points (rank error of 3).
   arma::Mat<size_t> neighbors;
   arma::mat distances;
 
-
-  // Test naive rank-approximate search
-
-  // Predict what the actual RANN-RS result would be
+  // Test naive rank-approximate search.
+  // Predict what the actual RANN-RS result would be.
   math::RandomSeed(0);
 
-  size_t numSamples = (size_t) ceil( log (1.0 / (1.0 - successProb)) 
-   / log (1.0 / (1.0 - (rankApproximation / 100.0) ) ) );
+  size_t numSamples = (size_t) ceil(log(1.0 / (1.0 - successProb)) /
+      log(1.0 / (1.0 - (rankApproximation / 100.0))));
 
   arma::Mat<size_t> samples(qdata.n_cols, numSamples);
   for (size_t j = 0; j < qdata.n_cols; j++)
@@ -59,68 +59,58 @@
       samples(j, i) = (size_t) math::RandInt(10);
 
   arma::Col<size_t> rann(qdata.n_cols);
-  arma::vec rann_distance(qdata.n_cols);
-  rann_distance.fill(DBL_MAX);
+  arma::vec rannDistances(qdata.n_cols);
+  rannDistances.fill(DBL_MAX);
 
   for (size_t j = 0; j < qdata.n_cols; j++)
   {
     for (size_t i = 0; i < numSamples; i++)
     {
-      double dist = dMetric.Evaluate(qdata.unsafe_col(j), 
+      double dist = dMetric.Evaluate(qdata.unsafe_col(j),
                                      rdata.unsafe_col(samples(j, i)));
-      if (dist < rann_distance[j])
+      if (dist < rannDistances[j])
       {
         rann[j] = samples(j, i);
-        rann_distance[j] = dist;
+        rannDistances[j] = dist;
       }
     }
   }
 
-  // use RANN-RS implementation
+  // Use RANN-RS implementation.
   math::RandomSeed(0);
 
-  arma::mat naive_rdata = rdata;
-  arma::mat naive_qdata = qdata;
+  RASearch<> naive(rdata, qdata, true);
+  naive.Search(1, neighbors, distances, rankApproximation);
 
-  RASearch<> *naive = new RASearch<>(naive_rdata, naive_qdata, true);
-  naive->Search(1, neighbors, distances, rankApproximation);
-
-  delete naive;
-  naive_rdata.reset();
-  naive_qdata.reset();
-
   // Things to check:
-  // 
-  // 1. (implicitly) The minimum number of required samples for 
-  //    guaranteed approximation
+  //
+  // 1. (implicitly) The minimum number of required samples for guaranteed
+  //    approximation.
   // 2. (implicitly) Check the samples obtained.
   // 3. Check the neighbor returned.
-
   for (size_t i = 0; i < qdata.n_cols; i++)
   {
     BOOST_REQUIRE(neighbors(0, i) == rann[i]);
-    BOOST_REQUIRE_CLOSE(distances(0, i), rann_distance[i], 1e-5);
+    BOOST_REQUIRE_CLOSE(distances(0, i), rannDistances[i], 1e-5);
   }
+}
 
-  Log::Warn << "RANN-RS (no tree) works as expected on small set." << endl;
+// Test the correctness and guarantees of AllkRANN when in naive mode.
+BOOST_AUTO_TEST_CASE(AllkRANNNaiveGuaranteeTest)
+{
+  arma::Mat<size_t> neighbors;
+  arma::mat distances;
 
-  neighbors.reset();
-  distances.reset();
-
-  // now test the correctness & guarantees of this algorithm
-  math::RandomSeed(time(NULL));
-    
   arma::mat refData;
   arma::mat queryData;
 
   data::Load("rann_test_r_3_900.csv", refData, true);
   data::Load("rann_test_q_3_100.csv", queryData, true);
 
-  RASearch<> *rann_rs = new RASearch<>(refData, queryData, true);
+  RASearch<> rsRann(refData, queryData, true);
 
   arma::mat qrRanks;
-  data::Load("rann_test_qr_ranks.csv", qrRanks, true);
-  qrRanks = qrRanks.t();
+  data::Load("rann_test_qr_ranks.csv", qrRanks, true, false); // No transpose.
 
   size_t numRounds = 1000;
   arma::Col<size_t> numSuccessRounds(queryData.n_cols);
@@ -131,7 +121,7 @@
 
   for (size_t rounds = 0; rounds < numRounds; rounds++)
   {
-    rann_rs->Search(1, neighbors, distances, 1.0);
+    rsRann.Search(1, neighbors, distances, 1.0);
 
     for (size_t i = 0; i < queryData.n_cols; i++)
       if (qrRanks(i, neighbors(0, i)) < expectedRankErrorUB)
@@ -141,67 +131,56 @@
     distances.reset();
   }
 
-  delete rann_rs;
-
-  // Finding the 95%-tile threshold so that 95% of the queries should 
-  // pass this threshold
-  size_t threshold = floor(numRounds * (0.95 - (1.96 * sqrt(0.95 * 0.05 
-                                                            / numRounds))));
+  // Find the 95%-tile threshold so that 95% of the queries should pass this
+  // threshold.
+  size_t threshold = floor(numRounds *
+      (0.95 - (1.96 * sqrt(0.95 * 0.05 / numRounds))));
   size_t numQueriesFail = 0;
   for (size_t i = 0; i < queryData.n_cols; i++)
     if (numSuccessRounds[i] < threshold)
       numQueriesFail++;
-  
-  Log::Warn << "RANN-RS: RANN guarantee fails on " << numQueriesFail << 
-    " queries." << endl;
 
+  Log::Warn << "RANN-RS: RANN guarantee fails on " << numQueriesFail
+      << " queries." << endl;
+
   // assert that at most 5% of the queries fall out of this threshold
   // 5% of 100 queries is 5.
   size_t maxNumQueriesFail = 6;
 
   BOOST_REQUIRE(numQueriesFail < maxNumQueriesFail);
-  Log::Warn << "RANN-RS (no tree) guarantees desired rank-approximation." <<
-    endl;
 }
 
-
+// Test single-tree rank-approximate search (harder to test because of
+// the randomness involved).
 BOOST_AUTO_TEST_CASE(AllkRANNSingleTreeSearch)
 {
-  // Test single-tree rank-approximate search (harder to test because of 
-  // the randomness involved)
-
-  // Checking the correctness & guarantees of the algorithm
-  math::RandomSeed(time(NULL));
-
   arma::mat refData;
   arma::mat queryData;
 
   data::Load("rann_test_r_3_900.csv", refData, true);
   data::Load("rann_test_q_3_100.csv", queryData, true);
 
-  // Search for 1 rank-approximate nearest-neighbors in the top 30% 
-  // of the point (rank error of 3)
+  // Search for 1 rank-approximate nearest neighbor in the top 1% of the
+  // points (900 reference points, so a rank error of 9).
   arma::Mat<size_t> neighbors;
   arma::mat distances;
 
+  RASearch<> tssRann(refData, queryData, false, true, 5);
 
-  RASearch<> *rann_tss = new RASearch<>(refData, queryData, false, true, 5);
-
   // The relative ranks for the given query reference pair
   arma::Mat<size_t> qrRanks;
-  data::Load("rann_test_qr_ranks.csv", qrRanks, true);
-  qrRanks = qrRanks.t();
+  data::Load("rann_test_qr_ranks.csv", qrRanks, true, false); // No transpose.
 
   size_t numRounds = 1000;
   arma::Col<size_t> numSuccessRounds(queryData.n_cols);
   numSuccessRounds.fill(0);
 
-  // 1% of 900 is 9, so the rank is expected to be less than 10
+  // 1% of 900 is 9, so the rank is expected to be less than 10.
   size_t expectedRankErrorUB = 10;
 
   for (size_t rounds = 0; rounds < numRounds; rounds++)
   {
-    rann_tss->Search(1, neighbors, distances, 1.0, 0.95, false, false, 5);
+    tssRann.Search(1, neighbors, distances, 1.0, 0.95, false, false, 5);
 
     for (size_t i = 0; i < queryData.n_cols; i++)
       if (qrRanks(i, neighbors(0, i)) < expectedRankErrorUB)
@@ -211,65 +190,55 @@
     distances.reset();
   }
 
-  delete rann_tss;
-
-  // Finding the 95%-tile threshold so that 95% of the queries should 
-  // pass this threshold
-  size_t threshold = floor(numRounds * (0.95 - (1.96 * sqrt(0.95 * 0.05 
-                                                            / numRounds))));
+  // Find the 95%-tile threshold so that 95% of the queries should pass this
+  // threshold.
+  size_t threshold = floor(numRounds *
+      (0.95 - (1.96 * sqrt(0.95 * 0.05 / numRounds))));
   size_t numQueriesFail = 0;
   for (size_t i = 0; i < queryData.n_cols; i++)
     if (numSuccessRounds[i] < threshold)
       numQueriesFail++;
 
-  Log::Warn << "RANN-TSS: RANN guarantee fails on " << numQueriesFail << 
-    " queries." << endl;
+  Log::Warn << "RANN-TSS: RANN guarantee fails on " << numQueriesFail
+      << " queries." << endl;
 
-  // assert that at most 5% of the queries fall out of this threshold
+  // Assert that at most 5% of the queries fall out of this threshold.
   // 5% of 100 queries is 5.
   size_t maxNumQueriesFail = 6;
 
   BOOST_REQUIRE(numQueriesFail < maxNumQueriesFail);
-  Log::Warn << "RANN-TSS (single tree) guarantees desired " << 
-    "rank-approximation." << endl;
 }
 
+// Test dual-tree rank-approximate search (harder to test because of the
+// randomness involved).
 BOOST_AUTO_TEST_CASE(AllkRANNDualTreeSearch)
 {
-  // Test dual-tree rank-approximate search (harder to test because of 
-  // the randomness involved)
-  // Test dual-tree rank-approximate search
-
-  // Checking the correctness & guarantees of the algorithm
-  math::RandomSeed(time(NULL));
-
   arma::mat refData;
   arma::mat queryData;
 
   data::Load("rann_test_r_3_900.csv", refData, true);
   data::Load("rann_test_q_3_100.csv", queryData, true);
 
-  // Search for 1 rank-approximate nearest-neighbors in the top 30% 
-  // of the point (rank error of 3)
+  // Search for 1 rank-approximate nearest neighbor in the top 1% of the
+  // points (900 reference points, so a rank error of 9).
   arma::Mat<size_t> neighbors;
   arma::mat distances;
 
-  RASearch<> *rann_tsd = new RASearch<>(refData, queryData, false, false, 5);
+  RASearch<> tsdRann(refData, queryData, false, false, 5);
 
   arma::Mat<size_t> qrRanks;
-  data::Load("rann_test_qr_ranks.csv", qrRanks, true);
-  qrRanks = qrRanks.t();
+  data::Load("rann_test_qr_ranks.csv", qrRanks, true, false); // No transpose.
 
   size_t numRounds = 1000;
   arma::Col<size_t> numSuccessRounds(queryData.n_cols);
   numSuccessRounds.fill(0);
 
-  // 1% of 900 is 9, so the rank is expected to be less than 10
+  // 1% of 900 is 9, so the rank is expected to be less than 10.
   size_t expectedRankErrorUB = 10;
 
   for (size_t rounds = 0; rounds < numRounds; rounds++)
   {
-    rann_tsd->Search(1, neighbors, distances, 1.0, 0.95, false, false, 5);
+    tsdRann.Search(1, neighbors, distances, 1.0, 0.95, false, false, 5);
 
     for (size_t i = 0; i < queryData.n_cols; i++)
       if (qrRanks(i, neighbors(0, i)) < expectedRankErrorUB)
@@ -278,30 +247,26 @@
     neighbors.reset();
     distances.reset();
 
-    rann_tsd->ResetQueryTree();
+    tsdRann.ResetQueryTree();
   }
 
-  delete rann_tsd;
-
-  // Finding the 95%-tile threshold so that 95% of the queries should 
-  // pass this threshold
-  size_t threshold = floor(numRounds * (0.95 - (1.96 * sqrt(0.95 * 0.05 
-                                                              / numRounds))));
+  // Find the 95%-tile threshold so that 95% of the queries should pass this
+  // threshold.
+  size_t threshold = floor(numRounds *
+      (0.95 - (1.96 * sqrt(0.95 * 0.05 / numRounds))));
   size_t numQueriesFail = 0;
   for (size_t i = 0; i < queryData.n_cols; i++)
     if (numSuccessRounds[i] < threshold)
       numQueriesFail++;
 
-  Log::Warn << "RANN-TSD: RANN guarantee fails on " << numQueriesFail << 
-    " queries." << endl;
+  Log::Warn << "RANN-TSD: RANN guarantee fails on " << numQueriesFail
+      << " queries." << endl;
 
   // assert that at most 5% of the queries fall out of this threshold
   // 5% of 100 queries is 5.
   size_t maxNumQueriesFail = 6;
 
   BOOST_REQUIRE(numQueriesFail < maxNumQueriesFail);
-  Log::Warn << "RANN-TSD (dual tree) guarantees desired " << 
-    "rank-approximation." << endl;
 }
 
 BOOST_AUTO_TEST_SUITE_END();




More information about the mlpack-svn mailing list