[mlpack-git] master: Add a helper class for knn model saving. (9ed5c23)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Mon Oct 19 16:04:25 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/09cd0d67f2fdae252a8ab85324e71dbb4dfe0010...fecf1194c123ced12d56e7daad761c7b9aaac262
>---------------------------------------------------------------
commit 9ed5c236edcc626b44b2dcf3490a637effb10a9f
Author: Ryan Curtin <ryan at ratml.org>
Date: Sun Oct 18 08:09:59 2015 -0400
Add a helper class for knn model saving.
>---------------------------------------------------------------
9ed5c236edcc626b44b2dcf3490a637effb10a9f
src/mlpack/methods/neighbor_search/CMakeLists.txt | 2 +
src/mlpack/methods/neighbor_search/ns_model.hpp | 138 +++++++
.../methods/neighbor_search/ns_model_impl.hpp | 401 +++++++++++++++++++++
src/mlpack/tests/allknn_test.cpp | 120 ++++++
4 files changed, 661 insertions(+)
diff --git a/src/mlpack/methods/neighbor_search/CMakeLists.txt b/src/mlpack/methods/neighbor_search/CMakeLists.txt
index ddf3a8c..2f8d275 100644
--- a/src/mlpack/methods/neighbor_search/CMakeLists.txt
+++ b/src/mlpack/methods/neighbor_search/CMakeLists.txt
@@ -6,6 +6,8 @@ set(SOURCES
neighbor_search_rules.hpp
neighbor_search_rules_impl.hpp
neighbor_search_stat.hpp
+ ns_model.hpp
+ ns_model_impl.hpp
ns_traversal_info.hpp
sort_policies/nearest_neighbor_sort.hpp
sort_policies/nearest_neighbor_sort.cpp
diff --git a/src/mlpack/methods/neighbor_search/ns_model.hpp b/src/mlpack/methods/neighbor_search/ns_model.hpp
new file mode 100644
index 0000000..e9e8042
--- /dev/null
+++ b/src/mlpack/methods/neighbor_search/ns_model.hpp
@@ -0,0 +1,138 @@
+/**
+ * @file ns_model.hpp
+ * @author Ryan Curtin
+ *
+ * This is a model for nearest or furthest neighbor search. It is useful in
+ * that it provides an easy way to serialize a model, abstracts away the
+ * different types of trees, and also reflects the NeighborSearch API and
+ * automatically directs to the right tree type.
+ */
+#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_HPP
+#define __MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_HPP
+
+#include <mlpack/core/tree/binary_space_tree.hpp>
+#include <mlpack/core/tree/cover_tree.hpp>
+#include <mlpack/core/tree/rectangle_tree.hpp>
+
+#include "neighbor_search.hpp"
+
+namespace mlpack {
+namespace neighbor {
+
+/**
+ * Compile-time lookup of the name under which NSModel::Serialize() stores the
+ * underlying NeighborSearch object.  The primary template supplies the generic
+ * name used for any sort policy that has no explicit specialization.
+ */
+template<typename SortPolicy>
+struct NSModelName
+{
+  static constexpr char value[22] = "neighbor_search_model";
+};
+
+//! Name used when the sort policy is NearestNeighborSort.
+template<>
+struct NSModelName<NearestNeighborSort>
+{
+  static constexpr char value[30] = "nearest_neighbor_search_model";
+};
+
+//! Name used when the sort policy is FurthestNeighborSort.
+template<>
+struct NSModelName<FurthestNeighborSort>
+{
+  static constexpr char value[31] = "furthest_neighbor_search_model";
+};
+
+/**
+ * Wrapper around NeighborSearch that picks the tree type at runtime, can be
+ * serialized, and optionally projects the data onto a random basis.
+ *
+ * NOTE(review): the class owns raw pointers that the destructor deletes, but
+ * no copy constructor or copy assignment is declared here, so copying a model
+ * that has already been built would share those pointers -- confirm that
+ * callers only copy models before BuildModel() is called.
+ *
+ * @tparam SortPolicy Sort policy forwarded to NeighborSearch.
+ */
+template<typename SortPolicy>
+class NSModel
+{
+ public:
+ //! The tree types this model can dispatch to.
+ enum TreeTypes
+ {
+ KD_TREE,
+ COVER_TREE,
+ R_TREE,
+ R_STAR_TREE
+ };
+
+ private:
+ //! Which tree is in use; holds one of the TreeTypes values.
+ int treeType;
+ //! Leaf size passed to the kd-tree query tree that Search() builds.
+ size_t leafSize;
+
+ // For random projections.
+ //! Whether reference and query sets are multiplied by q before use.
+ bool randomBasis;
+ //! The random basis itself; computed in BuildModel() when randomBasis is set.
+ arma::mat q;
+
+ // Mappings, in case they are necessary.
+ //! Filled in when the kd-tree is built; used to unmap kd-tree search results.
+ std::vector<size_t> oldFromNewReferences;
+
+ //! Shorthand for the NeighborSearch type specialized on a given tree type.
+ template<template<typename TreeMetricType,
+ typename TreeStatType,
+ typename TreeMatType> class TreeType>
+ using NSType = NeighborSearch<SortPolicy,
+ metric::EuclideanDistance,
+ arma::mat,
+ TreeType>;
+
+ // At most one of these pointers is intended to be non-NULL; all of them are
+ // NULL until a model is built or loaded.
+ NSType<tree::KDTree>* kdTreeNS;
+ NSType<tree::StandardCoverTree>* coverTreeNS;
+ NSType<tree::RTree>* rTreeNS;
+ NSType<tree::RStarTree>* rStarTreeNS;
+
+ // This pointer is only non-NULL if we are using kd-trees and we built the
+ // tree ourselves (which only happens if BuildModel() is called).
+ typename NSType<tree::KDTree>::Tree* kdTree;
+
+ public:
+ /**
+ * Initialize the NSModel with the given type and whether or not a random
+ * basis should be used.
+ */
+ NSModel(int treeType = TreeTypes::KD_TREE, bool randomBasis = false);
+
+ //! Clean memory, if necessary.
+ ~NSModel();
+
+ //! Serialize the kNN model.
+ template<typename Archive>
+ void Serialize(Archive& ar, const unsigned int /* version */);
+
+ //! Expose the dataset.
+ const arma::mat& Dataset() const;
+
+ //! Expose singleMode.  Throws std::runtime_error if no model exists yet.
+ bool SingleMode() const;
+ //! Modify singleMode.  Throws std::runtime_error if no model exists yet.
+ bool& SingleMode();
+
+ //! Expose naive.  Throws std::runtime_error if no model exists yet.
+ bool Naive() const;
+ //! Modify naive.  Throws std::runtime_error if no model exists yet.
+ bool& Naive();
+
+ //! Get the leaf size.
+ size_t LeafSize() const { return leafSize; }
+ //! Modify the leaf size.
+ size_t& LeafSize() { return leafSize; }
+
+ //! Get the tree type (one of the TreeTypes values).
+ int TreeType() const { return treeType; }
+ //! Modify the tree type.
+ int& TreeType() { return treeType; }
+
+ //! Get whether a random basis is used.
+ bool RandomBasis() const { return randomBasis; }
+ //! Modify whether a random basis is used.
+ bool& RandomBasis() { return randomBasis; }
+
+ //! Build the reference tree.  If a random basis is in use, referenceSet is
+ //! overwritten with q * referenceSet.
+ void BuildModel(arma::mat& referenceSet,
+ const size_t leafSize,
+ const bool naive,
+ const bool singleMode);
+
+ //! Perform neighbor search. The query set will be reordered.  If a random
+ //! basis is in use, querySet is overwritten with q * querySet.
+ void Search(arma::mat& querySet,
+ const size_t k,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances);
+
+ //! Perform neighbor search without a separate query set.
+ void Search(const size_t k,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances);
+};
+
+} // namespace neighbor
+} // namespace mlpack
+
+// Include implementation.
+#include "ns_model_impl.hpp"
+
+#endif
diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
new file mode 100644
index 0000000..be0df8e
--- /dev/null
+++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
@@ -0,0 +1,401 @@
+/**
+ * @file ns_model_impl.hpp
+ * @author Ryan Curtin
+ *
+ * This is a model for nearest or furthest neighbor search. It is useful in
+ * that it provides an easy way to serialize a model, abstracts away the
+ * different types of trees, and also reflects the NeighborSearch API and
+ * automatically directs to the right tree type.
+ */
+#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_IMPL_HPP
+#define __MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "ns_model.hpp"
+
+namespace mlpack {
+namespace neighbor {
+
+/**
+ * Initialize the NSModel with the given type and whether or not a random
+ * basis should be used.  No search object is allocated here; every model
+ * pointer starts out NULL.
+ *
+ * @param treeType Tree type to use; one of the NSModel::TreeTypes values.
+ * @param randomBasis Whether BuildModel() and Search() should multiply the
+ *     data by a random basis.
+ */
+template<typename SortPolicy>
+NSModel<SortPolicy>::NSModel(int treeType, bool randomBasis) :
+    treeType(treeType),
+    // This member used to be left uninitialized even though LeafSize() exposes
+    // it and Search() passes it to the kd-tree query tree constructor.  20 is
+    // the leaf size used by the accompanying tests; callers can change it
+    // through LeafSize().
+    leafSize(20),
+    randomBasis(randomBasis),
+    kdTreeNS(NULL),
+    coverTreeNS(NULL),
+    rTreeNS(NULL),
+    rStarTreeNS(NULL),
+    kdTree(NULL)
+{
+  // Nothing to do.
+}
+
+//! Release the kd-tree (if this model built one) and the search object.
+template<typename SortPolicy>
+NSModel<SortPolicy>::~NSModel()
+{
+  // Deleting a null pointer is a no-op, so no guards are required.
+  delete kdTree;
+
+  delete kdTreeNS;
+  delete coverTreeNS;
+  delete rTreeNS;
+  delete rStarTreeNS;
+}
+
+/**
+ * Serialize the kNN model: the configuration, the random basis, the kd-tree
+ * index mapping, and the single search object selected by treeType.
+ *
+ * @param ar Archive to save to or load from.
+ */
+template<typename SortPolicy>
+template<typename Archive>
+void NSModel<SortPolicy>::Serialize(Archive& ar,
+                                    const unsigned int /* version */)
+{
+  ar & data::CreateNVP(treeType, "treeType");
+  // Search() builds kd-tree query trees with leafSize, so the value has to
+  // survive a save/load round trip instead of being left at whatever the
+  // loading object happened to hold.
+  ar & data::CreateNVP(leafSize, "leafSize");
+  ar & data::CreateNVP(randomBasis, "randomBasis");
+  ar & data::CreateNVP(q, "q");
+  ar & data::CreateNVP(oldFromNewReferences, "oldFromNewReferences");
+
+  // This should never happen, but just in case, be clean with memory.
+  if (Archive::is_loading::value)
+  {
+    // Deleting a null pointer is harmless, so no checks are needed.
+    delete kdTree;
+
+    delete kdTreeNS;
+    delete coverTreeNS;
+    delete rTreeNS;
+    delete rStarTreeNS;
+
+    // Set all the pointers to NULL so that only the one loaded below is set.
+    kdTree = NULL;
+
+    kdTreeNS = NULL;
+    coverTreeNS = NULL;
+    rTreeNS = NULL;
+    rStarTreeNS = NULL;
+  }
+
+  // We'll only need to serialize one of the kNN objects, based on the type.
+  const std::string& name = NSModelName<SortPolicy>::value;
+  switch (treeType)
+  {
+    case KD_TREE:
+      ar & data::CreateNVP(kdTreeNS, name);
+      break;
+    case COVER_TREE:
+      ar & data::CreateNVP(coverTreeNS, name);
+      break;
+    case R_TREE:
+      ar & data::CreateNVP(rTreeNS, name);
+      break;
+    case R_STAR_TREE:
+      ar & data::CreateNVP(rStarTreeNS, name);
+      break;
+  }
+}
+
+//! Get the singleMode flag of whichever search object currently exists.
+//! Throws std::runtime_error if none has been built or loaded.
+template<typename SortPolicy>
+bool NSModel<SortPolicy>::SingleMode() const
+{
+  // Check the candidates in a fixed order and forward to the first one found.
+  if (kdTreeNS != NULL)
+    return kdTreeNS->SingleMode();
+  if (coverTreeNS != NULL)
+    return coverTreeNS->SingleMode();
+  if (rTreeNS != NULL)
+    return rTreeNS->SingleMode();
+  if (rStarTreeNS != NULL)
+    return rStarTreeNS->SingleMode();
+
+  throw std::runtime_error("no neighbor search model initialized");
+}
+
+//! Modify the singleMode flag of whichever search object currently exists.
+//! Throws std::runtime_error if none has been built or loaded.
+template<typename SortPolicy>
+bool& NSModel<SortPolicy>::SingleMode()
+{
+  // The returned reference refers to the flag inside the search object.
+  if (kdTreeNS != NULL)
+    return kdTreeNS->SingleMode();
+  if (coverTreeNS != NULL)
+    return coverTreeNS->SingleMode();
+  if (rTreeNS != NULL)
+    return rTreeNS->SingleMode();
+  if (rStarTreeNS != NULL)
+    return rStarTreeNS->SingleMode();
+
+  throw std::runtime_error("no neighbor search model initialized");
+}
+
+//! Get the naive flag of whichever search object currently exists.
+//! Throws std::runtime_error if none has been built or loaded.
+template<typename SortPolicy>
+bool NSModel<SortPolicy>::Naive() const
+{
+  // Check the candidates in a fixed order and forward to the first one found.
+  if (kdTreeNS != NULL)
+    return kdTreeNS->Naive();
+  if (coverTreeNS != NULL)
+    return coverTreeNS->Naive();
+  if (rTreeNS != NULL)
+    return rTreeNS->Naive();
+  if (rStarTreeNS != NULL)
+    return rStarTreeNS->Naive();
+
+  throw std::runtime_error("no neighbor search model initialized");
+}
+
+//! Modify the naive flag of whichever search object currently exists.
+//! Throws std::runtime_error if none has been built or loaded.
+template<typename SortPolicy>
+bool& NSModel<SortPolicy>::Naive()
+{
+  // The returned reference refers to the flag inside the search object.
+  if (kdTreeNS != NULL)
+    return kdTreeNS->Naive();
+  if (coverTreeNS != NULL)
+    return coverTreeNS->Naive();
+  if (rTreeNS != NULL)
+    return rTreeNS->Naive();
+  if (rStarTreeNS != NULL)
+    return rStarTreeNS->Naive();
+
+  throw std::runtime_error("no neighbor search model initialized");
+}
+
+/**
+ * Build the search object for the current tree type (and, for kd-trees in
+ * non-naive mode, the reference tree), replacing any model built earlier.
+ *
+ * @param referenceSet Reference points.  Overwritten with q * referenceSet
+ *     when a random basis is in use.
+ * @param leafSize Leaf size for the kd-tree; stored so that Search() uses the
+ *     same value for query trees.
+ * @param naive Whether the search object is created in naive mode.
+ * @param singleMode Whether the search object is created in single-tree mode.
+ */
+template<typename SortPolicy>
+void NSModel<SortPolicy>::BuildModel(arma::mat& referenceSet,
+                                     const size_t leafSize,
+                                     const bool naive,
+                                     const bool singleMode)
+{
+  // The parameter shadows the member of the same name.  Store it so that
+  // LeafSize() reports it and Search() builds query trees with the same leaf
+  // size as the reference tree.
+  this->leafSize = leafSize;
+
+  // Initialize random basis if necessary.
+  if (randomBasis)
+  {
+    Log::Info << "Creating random basis..." << std::endl;
+    while (true)
+    {
+      // [Q, R] = qr(randn(d, d));
+      // Q = Q * diag(sign(diag(R)));
+      arma::mat r;
+      if (arma::qr(q, r, arma::randn<arma::mat>(referenceSet.n_rows,
+          referenceSet.n_rows)))
+      {
+        arma::vec rDiag(r.n_rows);
+        for (size_t i = 0; i < rDiag.n_elem; ++i)
+        {
+          if (r(i, i) < 0)
+            rDiag(i) = -1;
+          else if (r(i, i) > 0)
+            rDiag(i) = 1;
+          else
+            rDiag(i) = 0;
+        }
+
+        q *= arma::diagmat(rDiag);
+
+        // Check if the determinant is positive; otherwise draw again.
+        if (arma::det(q) >= 0)
+          break;
+      }
+    }
+  }
+
+  // Clean memory, if necessary.  The pointers must be reset afterwards: only
+  // the pointer matching the current treeType is reassigned below (and kdTree
+  // only when a tree is actually built), so a stale value would be deleted a
+  // second time by a later BuildModel() call or by the destructor.
+  delete kdTree;
+
+  delete kdTreeNS;
+  delete coverTreeNS;
+  delete rTreeNS;
+  delete rStarTreeNS;
+
+  kdTree = NULL;
+
+  kdTreeNS = NULL;
+  coverTreeNS = NULL;
+  rTreeNS = NULL;
+  rStarTreeNS = NULL;
+
+  // Drop the mapping of any previously built kd-tree.  Search() treats a
+  // non-empty mapping as "results must be unmapped", which would be wrong for
+  // a model rebuilt without a tree.
+  oldFromNewReferences.clear();
+
+  // Do we need to modify the reference set?
+  if (randomBasis)
+    referenceSet = q * referenceSet;
+
+  if (!naive)
+  {
+    Timer::Start("tree_building");
+    Log::Info << "Building reference tree..." << std::endl;
+  }
+
+  switch (treeType)
+  {
+    case KD_TREE:
+      // If necessary, build the kd-tree.
+      if (naive)
+      {
+        kdTreeNS = new NSType<tree::KDTree>(referenceSet, naive, singleMode);
+      }
+      else
+      {
+        kdTree = new typename NSType<tree::KDTree>::Tree(referenceSet,
+            oldFromNewReferences, leafSize);
+        kdTreeNS = new NSType<tree::KDTree>(kdTree, singleMode);
+      }
+
+      break;
+    // For the remaining tree types the search object builds its own tree.
+    // Both flags are passed explicitly: with only (referenceSet, singleMode),
+    // singleMode landed in the naive position and naive was ignored.
+    case COVER_TREE:
+      coverTreeNS = new NSType<tree::StandardCoverTree>(referenceSet, naive,
+          singleMode);
+      break;
+    case R_TREE:
+      rTreeNS = new NSType<tree::RTree>(referenceSet, naive, singleMode);
+      break;
+    case R_STAR_TREE:
+      rStarTreeNS = new NSType<tree::RStarTree>(referenceSet, naive,
+          singleMode);
+      break;
+  }
+
+  if (!naive)
+  {
+    Timer::Stop("tree_building");
+    Log::Info << "Tree built." << std::endl;
+  }
+}
+
+/**
+ * Perform neighbor search for the points in the given query set.
+ *
+ * @param querySet Query points; overwritten with q * querySet when a random
+ *     basis is in use.
+ * @param k Number of neighbors to search for.
+ * @param neighbors Matrix that receives the neighbor indices.
+ * @param distances Matrix that receives the neighbor distances.
+ */
+template<typename SortPolicy>
+void NSModel<SortPolicy>::Search(arma::mat& querySet,
+                                 const size_t k,
+                                 arma::Mat<size_t>& neighbors,
+                                 arma::mat& distances)
+{
+  // Apply the same random basis that BuildModel() applied to the references.
+  if (randomBasis)
+    querySet = q * querySet;
+
+  // Report which search strategy is about to run.
+  Log::Info << "Searching for " << k << " nearest neighbors with ";
+  if (Naive())
+    Log::Info << "brute-force (naive) search..." << std::endl;
+  else if (SingleMode())
+    Log::Info << "single-tree search..." << std::endl;
+  else
+    Log::Info << "dual-tree search..." << std::endl;
+
+  if (treeType == KD_TREE)
+  {
+    const bool naiveMode = kdTreeNS->Naive();
+    const bool singleTreeMode = kdTreeNS->SingleMode();
+
+    if (!naiveMode && !singleTreeMode)
+    {
+      // Dual-tree search: build a second tree on the query points.
+      Timer::Start("tree_building");
+      Log::Info << "Building query tree..." << std::endl;
+      std::vector<size_t> oldFromNewQueries;
+      typename NSType<tree::KDTree>::Tree queryTree(querySet,
+          oldFromNewQueries, leafSize);
+      Log::Info << "Tree built." << std::endl;
+      Timer::Stop("tree_building");
+
+      arma::Mat<size_t> rawNeighbors;
+      arma::mat rawDistances;
+      kdTreeNS->Search(&queryTree, k, rawNeighbors, rawDistances);
+
+      // Unmap the results with both the reference and the query mapping.
+      Unmap(rawNeighbors, rawDistances, oldFromNewReferences,
+          oldFromNewQueries, neighbors, distances);
+    }
+    else if (naiveMode && oldFromNewReferences.empty())
+    {
+      // Naive search with no reference mapping recorded: the results can be
+      // written straight into the output matrices.
+      kdTreeNS->Search(querySet, k, neighbors, distances);
+    }
+    else
+    {
+      // Single-tree search, or naive search on a model that has a reference
+      // mapping: search into temporaries, then unmap the reference indices.
+      arma::Mat<size_t> rawNeighbors;
+      arma::mat rawDistances;
+      kdTreeNS->Search(querySet, k, rawNeighbors, rawDistances);
+
+      Unmap(rawNeighbors, rawDistances, oldFromNewReferences, neighbors,
+          distances);
+    }
+  }
+  else if (treeType == COVER_TREE)
+  {
+    // No mapping necessary.
+    coverTreeNS->Search(querySet, k, neighbors, distances);
+  }
+  else if (treeType == R_TREE)
+  {
+    // No mapping necessary.
+    rTreeNS->Search(querySet, k, neighbors, distances);
+  }
+  else if (treeType == R_STAR_TREE)
+  {
+    // No mapping necessary.
+    rStarTreeNS->Search(querySet, k, neighbors, distances);
+  }
+}
+
+/**
+ * Perform neighbor search without a separate query set.
+ *
+ * @param k Number of neighbors to search for.
+ * @param neighbors Matrix that receives the neighbor indices.
+ * @param distances Matrix that receives the neighbor distances.
+ */
+template<typename SortPolicy>
+void NSModel<SortPolicy>::Search(const size_t k,
+                                 arma::Mat<size_t>& neighbors,
+                                 arma::mat& distances)
+{
+  // Report which search strategy is about to run.
+  Log::Info << "Searching for " << k << " nearest neighbors with ";
+  if (Naive())
+    Log::Info << "brute-force (naive) search..." << std::endl;
+  else if (SingleMode())
+    Log::Info << "single-tree search..." << std::endl;
+  else
+    Log::Info << "dual-tree search..." << std::endl;
+
+  if (treeType == KD_TREE)
+  {
+    if (oldFromNewReferences.empty())
+    {
+      // No mapping has occurred, so the results need no post-processing.
+      kdTreeNS->Search(k, neighbors, distances);
+    }
+    else
+    {
+      // A mapping has occurred (dual-tree or single-tree mode, or naive mode
+      // on a model for which a tree was built), so the results are unmapped;
+      // the reference mapping is passed for both mapping arguments.
+      arma::Mat<size_t> rawNeighbors;
+      arma::mat rawDistances;
+      kdTreeNS->Search(k, rawNeighbors, rawDistances);
+      Unmap(rawNeighbors, rawDistances, oldFromNewReferences,
+          oldFromNewReferences, neighbors, distances);
+    }
+  }
+  else if (treeType == COVER_TREE)
+  {
+    coverTreeNS->Search(k, neighbors, distances);
+  }
+  else if (treeType == R_TREE)
+  {
+    rTreeNS->Search(k, neighbors, distances);
+  }
+  else if (treeType == R_STAR_TREE)
+  {
+    rStarTreeNS->Search(k, neighbors, distances);
+  }
+}
+
+} // namespace neighbor
+} // namespace mlpack
+
+#endif
diff --git a/src/mlpack/tests/allknn_test.cpp b/src/mlpack/tests/allknn_test.cpp
index 8818920..cf22d5e 100644
--- a/src/mlpack/tests/allknn_test.cpp
+++ b/src/mlpack/tests/allknn_test.cpp
@@ -6,6 +6,7 @@
#include <mlpack/core.hpp>
#include <mlpack/methods/neighbor_search/neighbor_search.hpp>
#include <mlpack/methods/neighbor_search/unmap.hpp>
+#include <mlpack/methods/neighbor_search/ns_model.hpp>
#include <mlpack/core/tree/cover_tree.hpp>
#include <mlpack/core/tree/example_tree.hpp>
#include <boost/test/unit_test.hpp>
@@ -900,4 +901,123 @@ BOOST_AUTO_TEST_CASE(SparseAllkNNCoverTreeTest)
}
*/
+BOOST_AUTO_TEST_CASE(KNNModelTest)
+{
+  // Ensure that we can build an NSModel<NearestNeighborSort> and get correct
+  // results.
+  typedef NSModel<NearestNeighborSort> KNNModel;
+
+  arma::mat queryData = arma::randu<arma::mat>(10, 100);
+  arma::mat referenceData = arma::randu<arma::mat>(10, 500);
+
+  // Build all the possible models: every tree type, with and without a random
+  // basis.
+  KNNModel models[8];
+  models[0] = KNNModel(KNNModel::TreeTypes::KD_TREE, true);
+  models[1] = KNNModel(KNNModel::TreeTypes::KD_TREE, false);
+  models[2] = KNNModel(KNNModel::TreeTypes::COVER_TREE, true);
+  models[3] = KNNModel(KNNModel::TreeTypes::COVER_TREE, false);
+  models[4] = KNNModel(KNNModel::TreeTypes::R_TREE, true);
+  models[5] = KNNModel(KNNModel::TreeTypes::R_TREE, false);
+  models[6] = KNNModel(KNNModel::TreeTypes::R_STAR_TREE, true);
+  models[7] = KNNModel(KNNModel::TreeTypes::R_STAR_TREE, false);
+
+  // Pass 0 is dual-tree, pass 1 is single-tree, and pass 2 is naive search.
+  // The bound has to be 3; with j < 2 the naive configuration never ran.
+  for (size_t j = 0; j < 3; ++j)
+  {
+    // Get a baseline.
+    AllkNN knn(referenceData);
+    arma::Mat<size_t> baselineNeighbors;
+    arma::mat baselineDistances;
+    knn.Search(queryData, 3, baselineNeighbors, baselineDistances);
+
+    for (size_t i = 0; i < 8; ++i)
+    {
+      if (j == 0)
+        models[i].BuildModel(referenceData, 20, false, false);
+      if (j == 1)
+        models[i].BuildModel(referenceData, 20, false, true);
+      if (j == 2)
+        models[i].BuildModel(referenceData, 20, true, false);
+
+      arma::Mat<size_t> neighbors;
+      arma::mat distances;
+
+      models[i].Search(queryData, 3, neighbors, distances);
+
+      BOOST_REQUIRE_EQUAL(neighbors.n_rows, baselineNeighbors.n_rows);
+      BOOST_REQUIRE_EQUAL(neighbors.n_cols, baselineNeighbors.n_cols);
+      BOOST_REQUIRE_EQUAL(neighbors.n_elem, baselineNeighbors.n_elem);
+      BOOST_REQUIRE_EQUAL(distances.n_rows, baselineDistances.n_rows);
+      BOOST_REQUIRE_EQUAL(distances.n_cols, baselineDistances.n_cols);
+      BOOST_REQUIRE_EQUAL(distances.n_elem, baselineDistances.n_elem);
+      // A distinct index name avoids shadowing the model index i.
+      for (size_t e = 0; e < distances.n_elem; ++e)
+      {
+        BOOST_REQUIRE_EQUAL(neighbors[e], baselineNeighbors[e]);
+        if (std::abs(baselineDistances[e]) < 1e-5)
+          BOOST_REQUIRE_SMALL(distances[e], 1e-5);
+        else
+          BOOST_REQUIRE_CLOSE(distances[e], baselineDistances[e], 1e-5);
+      }
+    }
+  }
+}
+
+BOOST_AUTO_TEST_CASE(KNNModelMonochromaticTest)
+{
+  // Ensure that we can build an NSModel<NearestNeighborSort> and get correct
+  // results, in the case where the reference set is the same as the query set.
+  typedef NSModel<NearestNeighborSort> KNNModel;
+
+  arma::mat referenceData = arma::randu<arma::mat>(10, 500);
+
+  // Build all the possible models: every tree type, with and without a random
+  // basis.
+  KNNModel models[8];
+  models[0] = KNNModel(KNNModel::TreeTypes::KD_TREE, true);
+  models[1] = KNNModel(KNNModel::TreeTypes::KD_TREE, false);
+  models[2] = KNNModel(KNNModel::TreeTypes::COVER_TREE, true);
+  models[3] = KNNModel(KNNModel::TreeTypes::COVER_TREE, false);
+  models[4] = KNNModel(KNNModel::TreeTypes::R_TREE, true);
+  models[5] = KNNModel(KNNModel::TreeTypes::R_TREE, false);
+  models[6] = KNNModel(KNNModel::TreeTypes::R_STAR_TREE, true);
+  models[7] = KNNModel(KNNModel::TreeTypes::R_STAR_TREE, false);
+
+  // Pass 0 is dual-tree, pass 1 is single-tree, and pass 2 is naive search.
+  // The bound has to be 3; with j < 2 the naive configuration never ran.
+  for (size_t j = 0; j < 3; ++j)
+  {
+    // Get a baseline.
+    AllkNN knn(referenceData);
+    arma::Mat<size_t> baselineNeighbors;
+    arma::mat baselineDistances;
+    knn.Search(3, baselineNeighbors, baselineDistances);
+
+    for (size_t i = 0; i < 8; ++i)
+    {
+      if (j == 0)
+        models[i].BuildModel(referenceData, 20, false, false);
+      if (j == 1)
+        models[i].BuildModel(referenceData, 20, false, true);
+      if (j == 2)
+        models[i].BuildModel(referenceData, 20, true, false);
+
+      arma::Mat<size_t> neighbors;
+      arma::mat distances;
+
+      models[i].Search(3, neighbors, distances);
+
+      BOOST_REQUIRE_EQUAL(neighbors.n_rows, baselineNeighbors.n_rows);
+      BOOST_REQUIRE_EQUAL(neighbors.n_cols, baselineNeighbors.n_cols);
+      BOOST_REQUIRE_EQUAL(neighbors.n_elem, baselineNeighbors.n_elem);
+      BOOST_REQUIRE_EQUAL(distances.n_rows, baselineDistances.n_rows);
+      BOOST_REQUIRE_EQUAL(distances.n_cols, baselineDistances.n_cols);
+      BOOST_REQUIRE_EQUAL(distances.n_elem, baselineDistances.n_elem);
+      // A distinct index name avoids shadowing the model index i.
+      for (size_t e = 0; e < distances.n_elem; ++e)
+      {
+        BOOST_REQUIRE_EQUAL(neighbors[e], baselineNeighbors[e]);
+        if (std::abs(baselineDistances[e]) < 1e-5)
+          BOOST_REQUIRE_SMALL(distances[e], 1e-5);
+        else
+          BOOST_REQUIRE_CLOSE(distances[e], baselineDistances[e], 1e-5);
+      }
+    }
+  }
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git
mailing list