[mlpack-git] master: Add test for RAModel. (does not yet work, still debugging) (b0f81c8)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Sun Dec 13 23:04:14 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/672786f499b8cc48ab1722de4544c1dc03af988a...7467dd42adfc64b8c7e4107f361c3245728e99e5

>---------------------------------------------------------------

commit b0f81c83db9ae5aadf4153273c8a7c93254fd5cb
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Dec 9 09:39:13 2015 -0500

    Add test for RAModel. (does not yet work, still debugging)


>---------------------------------------------------------------

b0f81c83db9ae5aadf4153273c8a7c93254fd5cb
 src/mlpack/tests/allkrann_search_test.cpp | 91 +++++++++++++++++++++++++++++++
 1 file changed, 91 insertions(+)

diff --git a/src/mlpack/tests/allkrann_search_test.cpp b/src/mlpack/tests/allkrann_search_test.cpp
index e116c4c..ccd3bf1 100644
--- a/src/mlpack/tests/allkrann_search_test.cpp
+++ b/src/mlpack/tests/allkrann_search_test.cpp
@@ -13,6 +13,7 @@
 #include "old_boost_test_definitions.hpp"
 
 #include <mlpack/methods/rann/ra_search.hpp>
+#include <mlpack/methods/rann/ra_model.hpp>
 
 using namespace std;
 using namespace mlpack;
@@ -610,4 +611,94 @@ BOOST_AUTO_TEST_CASE(MoveTrainTest)
   BOOST_REQUIRE_EQUAL(distances.n_cols, 300);
 }
 
+/**
+ * Make sure the RAModel class works with every tree type and BuildModel() mode.
+ */
+BOOST_AUTO_TEST_CASE(RAModelTest)
+{
+  // Ensure that we can build an RAModel<NearestNeighborSort> and get correct
+  // results.
+  typedef RAModel<NearestNeighborSort> KNNModel;
+
+  arma::mat queryData, referenceData;
+  data::Load("rann_test_r_3_900.csv", referenceData, true);
+  data::Load("rann_test_q_3_100.csv", queryData, true);
+
+  // Build all the possible models.
+  KNNModel models[8];
+  models[0] = KNNModel(KNNModel::TreeTypes::KD_TREE, true);
+  models[1] = KNNModel(KNNModel::TreeTypes::KD_TREE, false);
+  models[2] = KNNModel(KNNModel::TreeTypes::COVER_TREE, true);
+  models[3] = KNNModel(KNNModel::TreeTypes::COVER_TREE, false);
+  models[4] = KNNModel(KNNModel::TreeTypes::R_TREE, true);
+  models[5] = KNNModel(KNNModel::TreeTypes::R_TREE, false);
+  models[6] = KNNModel(KNNModel::TreeTypes::R_STAR_TREE, true);
+  models[7] = KNNModel(KNNModel::TreeTypes::R_STAR_TREE, false);
+
+  arma::Mat<size_t> qrRanks;
+  data::Load("rann_test_qr_ranks.csv", qrRanks, true, false); // No transpose.
+
+  // Last two BuildModel() flags per mode j; presumably (naive, singleMode).
+  const bool naiveFlags[3] = { false, false, true };
+  const bool singleModeFlags[3] = { false, true, false };
+
+  // 1% of 900 is 9, so the rank is expected to be less than 10.
+  const size_t expectedRankErrorUB = 10;
+  const size_t numRounds = 100;
+
+  // Find the 95%-tile threshold so that 95% of the queries should pass this
+  // threshold.
+  const size_t threshold = floor(numRounds *
+      (0.95 - (1.96 * sqrt(0.95 * 0.05 / numRounds))));
+
+  // Assert that at most 5% of the queries fall out of this threshold.
+  // 5% of 100 queries is 5.
+  const size_t maxNumQueriesFail = 6;
+
+  for (size_t j = 0; j < 3; ++j)
+  {
+    for (size_t i = 0; i < 8; ++i)
+    {
+      // We only have std::move() constructors so make a copy of our data.
+      arma::mat referenceCopy(referenceData);
+      models[i].BuildModel(std::move(referenceCopy), 20, naiveFlags[j],
+          singleModeFlags[j]);
+
+      // Set the search parameters.
+      models[i].Tau() = 1.0;
+      models[i].Alpha() = 0.95;
+      models[i].SampleAtLeaves() = false;
+      models[i].FirstLeafExact() = false;
+      models[i].SingleSampleLimit() = 5;
+
+      arma::Mat<size_t> neighbors;
+      arma::mat distances;
+
+      arma::Col<size_t> numSuccessRounds(queryData.n_cols);
+      numSuccessRounds.fill(0);
+
+      for (size_t round = 0; round < numRounds; round++)
+      {
+        arma::mat queryCopy(queryData);
+        models[i].Search(std::move(queryCopy), 1, neighbors, distances);
+        for (size_t k = 0; k < queryData.n_cols; k++)
+          if (qrRanks(k, neighbors(0, k)) < expectedRankErrorUB)
+            numSuccessRounds[k]++;
+
+        neighbors.reset();
+        distances.reset();
+      }
+
+      size_t numQueriesFail = 0;
+      for (size_t k = 0; k < queryData.n_cols; k++)
+        if (numSuccessRounds[k] < threshold)
+          numQueriesFail++;
+
+      BOOST_REQUIRE_MESSAGE(numQueriesFail < maxNumQueriesFail,
+          "model " << i << ", mode " << j << ": " << numQueriesFail
+          << " queries failed (limit " << maxNumQueriesFail << ")");
+    }
+  }
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list