[mlpack-git] master: Test RSModel. (e09d6a4)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Nov 4 13:54:16 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/1eb721c663e640d571d8374c67c40ad8a5ea6fb3...e09d6a4869ed02fbf3d9d9b22ab1bf46a116cfdd
>---------------------------------------------------------------
commit e09d6a4869ed02fbf3d9d9b22ab1bf46a116cfdd
Author: ryan <ryan at ratml.org>
Date: Wed Nov 4 13:54:00 2015 -0500
Test RSModel.
>---------------------------------------------------------------
e09d6a4869ed02fbf3d9d9b22ab1bf46a116cfdd
src/mlpack/tests/range_search_test.cpp | 137 +++++++++++++++++++++++++++++++++
1 file changed, 137 insertions(+)
diff --git a/src/mlpack/tests/range_search_test.cpp b/src/mlpack/tests/range_search_test.cpp
index 9c5ed25..7823ff3 100644
--- a/src/mlpack/tests/range_search_test.cpp
+++ b/src/mlpack/tests/range_search_test.cpp
@@ -7,6 +7,7 @@
#include <mlpack/core.hpp>
#include <mlpack/methods/range_search/range_search.hpp>
#include <mlpack/core/tree/cover_tree.hpp>
+#include <mlpack/methods/range_search/rs_model.hpp>
#include <boost/test/unit_test.hpp>
#include "old_boost_test_definitions.hpp"
@@ -1241,4 +1242,140 @@ BOOST_AUTO_TEST_CASE(MoveTrainTest)
}
}
+BOOST_AUTO_TEST_CASE(RSModelTest)
+{
+  math::RandomSeed(std::time(NULL));
+
+  // Ensure that we can build an RSModel and get correct results.
+  arma::mat queryData = arma::randu<arma::mat>(10, 50);
+  arma::mat referenceData = arma::randu<arma::mat>(10, 200);
+
+  // Build all the possible models.
+  RSModel models[10];
+  models[0] = RSModel(RSModel::TreeTypes::KD_TREE, true);
+  models[1] = RSModel(RSModel::TreeTypes::KD_TREE, false);
+  models[2] = RSModel(RSModel::TreeTypes::COVER_TREE, true);
+  models[3] = RSModel(RSModel::TreeTypes::COVER_TREE, false);
+  models[4] = RSModel(RSModel::TreeTypes::R_TREE, true);
+  models[5] = RSModel(RSModel::TreeTypes::R_TREE, false);
+  models[6] = RSModel(RSModel::TreeTypes::R_STAR_TREE, true);
+  models[7] = RSModel(RSModel::TreeTypes::R_STAR_TREE, false);
+  models[8] = RSModel(RSModel::TreeTypes::BALL_TREE, true);
+  models[9] = RSModel(RSModel::TreeTypes::BALL_TREE, false);
+
+  for (size_t j = 0; j < 3; ++j) // Exercise all three BuildModel() calls.
+  {
+    // Get a baseline.
+    RangeSearch<> rs(referenceData);
+    vector<vector<size_t>> baselineNeighbors;
+    vector<vector<double>> baselineDistances;
+    rs.Search(queryData, math::Range(0.25, 0.75), baselineNeighbors,
+        baselineDistances);
+
+    vector<vector<pair<double, size_t>>> baselineSorted;
+    SortResults(baselineNeighbors, baselineDistances, baselineSorted);
+
+    for (size_t i = 0; i < 10; ++i)
+    {
+      // We only have std::move() constructors, so make a copy of our data.
+      arma::mat referenceCopy(referenceData);
+      arma::mat queryCopy(queryData);
+      if (j == 0)
+        models[i].BuildModel(std::move(referenceCopy), 5, false, false);
+      else if (j == 1)
+        models[i].BuildModel(std::move(referenceCopy), 5, false, true);
+      else if (j == 2)
+        models[i].BuildModel(std::move(referenceCopy), 5, true, false);
+
+      vector<vector<size_t>> neighbors;
+      vector<vector<double>> distances;
+
+      models[i].Search(std::move(queryCopy), math::Range(0.25, 0.75),
+          neighbors, distances);
+
+      BOOST_REQUIRE_EQUAL(neighbors.size(), baselineNeighbors.size());
+      BOOST_REQUIRE_EQUAL(distances.size(), baselineDistances.size());
+
+      vector<vector<pair<double, size_t>>> sorted;
+      SortResults(neighbors, distances, sorted);
+
+      for (size_t k = 0; k < sorted.size(); ++k)
+      {
+        BOOST_REQUIRE_EQUAL(sorted[k].size(), baselineSorted[k].size());
+        for (size_t l = 0; l < sorted[k].size(); ++l)
+        {
+          BOOST_REQUIRE_EQUAL(sorted[k][l].second, baselineSorted[k][l].second);
+          BOOST_REQUIRE_CLOSE(sorted[k][l].first, baselineSorted[k][l].first,
+              1e-5);
+        }
+      }
+    }
+  }
+}
+
+BOOST_AUTO_TEST_CASE(RSModelMonochromaticTest)
+{
+  // Ensure that we can build an RSModel and get correct results.
+  arma::mat referenceData = arma::randu<arma::mat>(10, 200);
+
+  // Build all the possible models.
+  RSModel models[10];
+  models[0] = RSModel(RSModel::TreeTypes::KD_TREE, true);
+  models[1] = RSModel(RSModel::TreeTypes::KD_TREE, false);
+  models[2] = RSModel(RSModel::TreeTypes::COVER_TREE, true);
+  models[3] = RSModel(RSModel::TreeTypes::COVER_TREE, false);
+  models[4] = RSModel(RSModel::TreeTypes::R_TREE, true);
+  models[5] = RSModel(RSModel::TreeTypes::R_TREE, false);
+  models[6] = RSModel(RSModel::TreeTypes::R_STAR_TREE, true);
+  models[7] = RSModel(RSModel::TreeTypes::R_STAR_TREE, false);
+  models[8] = RSModel(RSModel::TreeTypes::BALL_TREE, true);
+  models[9] = RSModel(RSModel::TreeTypes::BALL_TREE, false);
+
+  for (size_t j = 0; j < 3; ++j) // Exercise all three BuildModel() calls.
+  {
+    // Get a baseline.
+    RangeSearch<> rs(referenceData);
+    vector<vector<size_t>> baselineNeighbors;
+    vector<vector<double>> baselineDistances;
+    rs.Search(math::Range(0.25, 0.5), baselineNeighbors, baselineDistances);
+
+    vector<vector<pair<double, size_t>>> baselineSorted;
+    SortResults(baselineNeighbors, baselineDistances, baselineSorted);
+
+    for (size_t i = 0; i < 10; ++i)
+    {
+      // We only have std::move() constructors, so make a copy of our data.
+      arma::mat referenceCopy(referenceData);
+      if (j == 0)
+        models[i].BuildModel(std::move(referenceCopy), 5, false, false);
+      else if (j == 1)
+        models[i].BuildModel(std::move(referenceCopy), 5, false, true);
+      else if (j == 2)
+        models[i].BuildModel(std::move(referenceCopy), 5, true, false);
+
+      vector<vector<size_t>> neighbors;
+      vector<vector<double>> distances;
+
+      models[i].Search(math::Range(0.25, 0.5), neighbors, distances);
+
+      BOOST_REQUIRE_EQUAL(neighbors.size(), baselineNeighbors.size());
+      BOOST_REQUIRE_EQUAL(distances.size(), baselineDistances.size());
+
+      vector<vector<pair<double, size_t>>> sorted;
+      SortResults(neighbors, distances, sorted);
+
+      for (size_t k = 0; k < sorted.size(); ++k)
+      {
+        BOOST_REQUIRE_EQUAL(sorted[k].size(), baselineSorted[k].size());
+        for (size_t l = 0; l < sorted[k].size(); ++l)
+        {
+          BOOST_REQUIRE_EQUAL(sorted[k][l].second, baselineSorted[k][l].second);
+          BOOST_REQUIRE_CLOSE(sorted[k][l].first, baselineSorted[k][l].first,
+              1e-5);
+        }
+      }
+    }
+  }
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git
mailing list