[mlpack-git] master: Add test for RAModel. (does not yet work, still debugging) (c7b4744)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Dec 9 14:37:22 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/cec4ac427536cbd9738a33e0c6facabeeadd31b0...4a39d474593067343b4972d4a5217bcfae84ca5d
>---------------------------------------------------------------
commit c7b4744e0a8d8c0d7cd544c04eacff983d63462e
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Dec 9 09:39:13 2015 -0500
Add test for RAModel. (does not yet work, still debugging)
>---------------------------------------------------------------
c7b4744e0a8d8c0d7cd544c04eacff983d63462e
src/mlpack/tests/allkrann_search_test.cpp | 91 +++++++++++++++++++++++++++++++
1 file changed, 91 insertions(+)
diff --git a/src/mlpack/tests/allkrann_search_test.cpp b/src/mlpack/tests/allkrann_search_test.cpp
index e116c4c..ccd3bf1 100644
--- a/src/mlpack/tests/allkrann_search_test.cpp
+++ b/src/mlpack/tests/allkrann_search_test.cpp
@@ -13,6 +13,7 @@
#include "old_boost_test_definitions.hpp"
#include <mlpack/methods/rann/ra_search.hpp>
+#include <mlpack/methods/rann/ra_model.hpp>
using namespace std;
using namespace mlpack;
@@ -610,4 +611,94 @@ BOOST_AUTO_TEST_CASE(MoveTrainTest)
BOOST_REQUIRE_EQUAL(distances.n_cols, 300);
}
+/**
+ * Make sure the RAModel class works: every model must pass the rank check.
+ */
+BOOST_AUTO_TEST_CASE(RAModelTest)
+{
+ // Ensure that we can build an RAModel<NearestNeighborSort> and get correct
+ // results.
+ typedef RAModel<NearestNeighborSort> KNNModel;
+
+ arma::mat queryData, referenceData;
+ data::Load("rann_test_r_3_900.csv", referenceData, true);
+ data::Load("rann_test_q_3_100.csv", queryData, true);
+
+ // One model per (tree type, bool constructor argument) pair: 4 x 2 = 8.
+ std::cout << "build models\n";
+ KNNModel models[8];
+ models[0] = KNNModel(KNNModel::TreeTypes::KD_TREE, true);
+ models[1] = KNNModel(KNNModel::TreeTypes::KD_TREE, false);
+ models[2] = KNNModel(KNNModel::TreeTypes::COVER_TREE, true);
+ models[3] = KNNModel(KNNModel::TreeTypes::COVER_TREE, false);
+ models[4] = KNNModel(KNNModel::TreeTypes::R_TREE, true);
+ models[5] = KNNModel(KNNModel::TreeTypes::R_TREE, false);
+ models[6] = KNNModel(KNNModel::TreeTypes::R_STAR_TREE, true);
+ models[7] = KNNModel(KNNModel::TreeTypes::R_STAR_TREE, false);
+
+ arma::Mat<size_t> qrRanks;
+ data::Load("rann_test_qr_ranks.csv", qrRanks, true, false); // No transpose.
+
+ for (size_t j = 2; j + 1 > 0; --j) // Counts 2, 1, 0; stops when j wraps.
+ {
+ for (size_t i = 7; i + 1 > 0; --i) // Counts 7..0; stops when i wraps.
+ {
+ // BuildModel() is called with std::move(), so give it a copy of the data.
+ std::cout << "build model " << i << " " << j << ".\n";
+ arma::mat referenceCopy(referenceData);
+ if (j == 0)
+ models[i].BuildModel(std::move(referenceCopy), 20, false, false);
+ if (j == 1)
+ models[i].BuildModel(std::move(referenceCopy), 20, false, true);
+ if (j == 2)
+ models[i].BuildModel(std::move(referenceCopy), 20, true, false);
+ std::cout << "built\n";
+
+ // Set the search parameters.
+ models[i].Tau() = 1.0;
+ models[i].Alpha() = 0.95;
+ models[i].SampleAtLeaves() = false;
+ models[i].FirstLeafExact() = false;
+ models[i].SingleSampleLimit() = 5;
+
+ arma::Mat<size_t> neighbors;
+ arma::mat distances;
+
+ arma::Col<size_t> numSuccessRounds(queryData.n_cols);
+ numSuccessRounds.fill(0);
+
+ // 1% of 900 is 9, so the rank is expected to be less than 10.
+ size_t expectedRankErrorUB = 10;
+
+ size_t numRounds = 100;
+ for (size_t round = 0; round < numRounds; round++)
+ {
+ arma::mat queryCopy(queryData);
+ models[i].Search(std::move(queryCopy), 1, neighbors, distances);
+ for (size_t k = 0; k < queryData.n_cols; k++)
+ if (qrRanks(k, neighbors(0, k)) < expectedRankErrorUB)
+ numSuccessRounds[k]++;
+
+ neighbors.reset();
+ distances.reset();
+ }
+
+ // Per-query pass mark: numRounds * (0.95 - 1.96 * sqrt(0.95 * 0.05 / numRounds)),
+ // which is floor(90.7) = 90 for numRounds = 100.
+ size_t threshold = floor(numRounds *
+ (0.95 - (1.96 * sqrt(0.95 * 0.05 / numRounds))));
+ size_t numQueriesFail = 0;
+ for (size_t k = 0; k < queryData.n_cols; k++)
+ if (numSuccessRounds[k] < threshold)
+ numQueriesFail++;
+
+ // Require that at most 5% of the queries fall below this threshold;
+ // 5% of 100 queries is 5, hence the strict upper bound of 6.
+ size_t maxNumQueriesFail = 6;
+
+ BOOST_REQUIRE_LT(numQueriesFail, maxNumQueriesFail);
+ }
+ }
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git
mailing list