[mlpack-git] master: Test RSModel. (e09d6a4)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Wed Nov 4 13:54:16 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/1eb721c663e640d571d8374c67c40ad8a5ea6fb3...e09d6a4869ed02fbf3d9d9b22ab1bf46a116cfdd

>---------------------------------------------------------------

commit e09d6a4869ed02fbf3d9d9b22ab1bf46a116cfdd
Author: ryan <ryan at ratml.org>
Date:   Wed Nov 4 13:54:00 2015 -0500

    Test RSModel.


>---------------------------------------------------------------

e09d6a4869ed02fbf3d9d9b22ab1bf46a116cfdd
 src/mlpack/tests/range_search_test.cpp | 137 +++++++++++++++++++++++++++++++++
 1 file changed, 137 insertions(+)

diff --git a/src/mlpack/tests/range_search_test.cpp b/src/mlpack/tests/range_search_test.cpp
index 9c5ed25..7823ff3 100644
--- a/src/mlpack/tests/range_search_test.cpp
+++ b/src/mlpack/tests/range_search_test.cpp
@@ -7,6 +7,7 @@
 #include <mlpack/core.hpp>
 #include <mlpack/methods/range_search/range_search.hpp>
 #include <mlpack/core/tree/cover_tree.hpp>
+#include <mlpack/methods/range_search/rs_model.hpp>
 #include <boost/test/unit_test.hpp>
 #include "old_boost_test_definitions.hpp"
 
@@ -1241,4 +1242,140 @@ BOOST_AUTO_TEST_CASE(MoveTrainTest)
   }
 }
 
+BOOST_AUTO_TEST_CASE(RSModelTest)
+{
+  math::RandomSeed(std::time(nullptr));
+
+  // Ensure that we can build an RSModel and get correct results.
+  arma::mat queryData = arma::randu<arma::mat>(10, 50);
+  arma::mat referenceData = arma::randu<arma::mat>(10, 200);
+
+  // Build all the possible models.
+  RSModel models[10];
+  models[0] = RSModel(RSModel::TreeTypes::KD_TREE, true);
+  models[1] = RSModel(RSModel::TreeTypes::KD_TREE, false);
+  models[2] = RSModel(RSModel::TreeTypes::COVER_TREE, true);
+  models[3] = RSModel(RSModel::TreeTypes::COVER_TREE, false);
+  models[4] = RSModel(RSModel::TreeTypes::R_TREE, true);
+  models[5] = RSModel(RSModel::TreeTypes::R_TREE, false);
+  models[6] = RSModel(RSModel::TreeTypes::R_STAR_TREE, true);
+  models[7] = RSModel(RSModel::TreeTypes::R_STAR_TREE, false);
+  models[8] = RSModel(RSModel::TreeTypes::BALL_TREE, true);
+  models[9] = RSModel(RSModel::TreeTypes::BALL_TREE, false);
+
+  // Get a baseline once; it does not depend on the search mode.
+  RangeSearch<> rs(referenceData);
+  vector<vector<size_t>> baselineNeighbors;
+  vector<vector<double>> baselineDistances;
+  rs.Search(queryData, math::Range(0.25, 0.75), baselineNeighbors,
+      baselineDistances);
+
+  vector<vector<pair<double, size_t>>> baselineSorted;
+  SortResults(baselineNeighbors, baselineDistances, baselineSorted);
+
+  // Exercise three search modes through BuildModel()'s two boolean flags
+  // (presumably naive and single-tree mode, in that order): j == 0 clears
+  // both, j == 1 sets only the second, j == 2 sets only the first.  The bound
+  // must be 3, not 2, or the j == 2 configuration is never reached.
+  for (size_t j = 0; j < 3; ++j)
+  {
+    for (size_t i = 0; i < 10; ++i)
+    {
+      // We only have std::move() constructors, so make a copy of our data.
+      arma::mat referenceCopy(referenceData);
+      arma::mat queryCopy(queryData);
+      models[i].BuildModel(std::move(referenceCopy), 5, (j == 2), (j == 1));
+
+      vector<vector<size_t>> neighbors;
+      vector<vector<double>> distances;
+
+      models[i].Search(std::move(queryCopy), math::Range(0.25, 0.75), neighbors,
+          distances);
+
+      BOOST_REQUIRE_EQUAL(neighbors.size(), baselineNeighbors.size());
+      BOOST_REQUIRE_EQUAL(distances.size(), baselineDistances.size());
+
+      vector<vector<pair<double, size_t>>> sorted;
+      SortResults(neighbors, distances, sorted);
+
+      // Compare each sorted result list against the baseline.
+      for (size_t k = 0; k < sorted.size(); ++k)
+      {
+        BOOST_REQUIRE_EQUAL(sorted[k].size(), baselineSorted[k].size());
+        for (size_t l = 0; l < sorted[k].size(); ++l)
+        {
+          BOOST_REQUIRE_EQUAL(sorted[k][l].second, baselineSorted[k][l].second);
+          BOOST_REQUIRE_CLOSE(sorted[k][l].first, baselineSorted[k][l].first,
+              1e-5);
+        }
+      }
+    }
+  }
+}
+
+BOOST_AUTO_TEST_CASE(RSModelMonochromaticTest)
+{
+  // Ensure that we can build an RSModel and get correct results.
+  arma::mat referenceData = arma::randu<arma::mat>(10, 200);
+
+  // Build all the possible models.
+  RSModel models[10];
+  models[0] = RSModel(RSModel::TreeTypes::KD_TREE, true);
+  models[1] = RSModel(RSModel::TreeTypes::KD_TREE, false);
+  models[2] = RSModel(RSModel::TreeTypes::COVER_TREE, true);
+  models[3] = RSModel(RSModel::TreeTypes::COVER_TREE, false);
+  models[4] = RSModel(RSModel::TreeTypes::R_TREE, true);
+  models[5] = RSModel(RSModel::TreeTypes::R_TREE, false);
+  models[6] = RSModel(RSModel::TreeTypes::R_STAR_TREE, true);
+  models[7] = RSModel(RSModel::TreeTypes::R_STAR_TREE, false);
+  models[8] = RSModel(RSModel::TreeTypes::BALL_TREE, true);
+  models[9] = RSModel(RSModel::TreeTypes::BALL_TREE, false);
+
+  // Get a baseline once; it does not depend on the search mode.
+  RangeSearch<> rs(referenceData);
+  vector<vector<size_t>> baselineNeighbors;
+  vector<vector<double>> baselineDistances;
+  rs.Search(math::Range(0.25, 0.5), baselineNeighbors, baselineDistances);
+
+  vector<vector<pair<double, size_t>>> baselineSorted;
+  SortResults(baselineNeighbors, baselineDistances, baselineSorted);
+
+  // Exercise three search modes through BuildModel()'s two boolean flags
+  // (presumably naive and single-tree mode, in that order): j == 0 clears
+  // both, j == 1 sets only the second, j == 2 sets only the first.  The bound
+  // must be 3, not 2, or the j == 2 configuration is never reached.
+  for (size_t j = 0; j < 3; ++j)
+  {
+    for (size_t i = 0; i < 10; ++i)
+    {
+      // We only have std::move() constructors, so make a copy of our data.
+      arma::mat referenceCopy(referenceData);
+      models[i].BuildModel(std::move(referenceCopy), 5, (j == 2), (j == 1));
+
+      vector<vector<size_t>> neighbors;
+      vector<vector<double>> distances;
+
+      // No query set is passed: monochromatic search, per the test name.
+      models[i].Search(math::Range(0.25, 0.5), neighbors, distances);
+
+      BOOST_REQUIRE_EQUAL(neighbors.size(), baselineNeighbors.size());
+      BOOST_REQUIRE_EQUAL(distances.size(), baselineDistances.size());
+
+      vector<vector<pair<double, size_t>>> sorted;
+      SortResults(neighbors, distances, sorted);
+
+      for (size_t k = 0; k < sorted.size(); ++k)
+      {
+        BOOST_REQUIRE_EQUAL(sorted[k].size(), baselineSorted[k].size());
+        for (size_t l = 0; l < sorted[k].size(); ++l)
+        {
+          BOOST_REQUIRE_EQUAL(sorted[k][l].second, baselineSorted[k][l].second);
+          BOOST_REQUIRE_CLOSE(sorted[k][l].first, baselineSorted[k][l].first,
+              1e-5);
+        }
+      }
+    }
+  }
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list