[mlpack-git] master: Add a helper class for knn model saving. (9ed5c23)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Mon Oct 19 16:04:25 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/09cd0d67f2fdae252a8ab85324e71dbb4dfe0010...fecf1194c123ced12d56e7daad761c7b9aaac262

>---------------------------------------------------------------

commit 9ed5c236edcc626b44b2dcf3490a637effb10a9f
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sun Oct 18 08:09:59 2015 -0400

    Add a helper class for knn model saving.


>---------------------------------------------------------------

9ed5c236edcc626b44b2dcf3490a637effb10a9f
 src/mlpack/methods/neighbor_search/CMakeLists.txt  |   2 +
 src/mlpack/methods/neighbor_search/ns_model.hpp    | 138 +++++++
 .../methods/neighbor_search/ns_model_impl.hpp      | 401 +++++++++++++++++++++
 src/mlpack/tests/allknn_test.cpp                   | 120 ++++++
 4 files changed, 661 insertions(+)

diff --git a/src/mlpack/methods/neighbor_search/CMakeLists.txt b/src/mlpack/methods/neighbor_search/CMakeLists.txt
index ddf3a8c..2f8d275 100644
--- a/src/mlpack/methods/neighbor_search/CMakeLists.txt
+++ b/src/mlpack/methods/neighbor_search/CMakeLists.txt
@@ -6,6 +6,8 @@ set(SOURCES
   neighbor_search_rules.hpp
   neighbor_search_rules_impl.hpp
   neighbor_search_stat.hpp
+  ns_model.hpp
+  ns_model_impl.hpp
   ns_traversal_info.hpp
   sort_policies/nearest_neighbor_sort.hpp
   sort_policies/nearest_neighbor_sort.cpp
diff --git a/src/mlpack/methods/neighbor_search/ns_model.hpp b/src/mlpack/methods/neighbor_search/ns_model.hpp
new file mode 100644
index 0000000..e9e8042
--- /dev/null
+++ b/src/mlpack/methods/neighbor_search/ns_model.hpp
@@ -0,0 +1,138 @@
+/**
+ * @file ns_model.hpp
+ * @author Ryan Curtin
+ *
+ * This is a model for nearest or furthest neighbor search.  It is useful in
+ * that it provides an easy way to serialize a model, abstracts away the
+ * different types of trees, and also reflects the NeighborSearch API and
+ * automatically directs to the right tree type.
+ */
+#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_HPP
+#define __MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_HPP
+
+#include <mlpack/core/tree/binary_space_tree.hpp>
+#include <mlpack/core/tree/cover_tree.hpp>
+#include <mlpack/core/tree/rectangle_tree.hpp>
+
+#include "neighbor_search.hpp"
+
+namespace mlpack {
+namespace neighbor {
+
+//! Name under which NSModel::Serialize() stores its search object.  This
+//! generic form is used for any sort policy without a specialization below.
+//! NOTE(review): Serialize() binds value to a const std::string&; before
+//! C++17 a static constexpr array used that way also needs a namespace-scope
+//! definition -- confirm that this links on the supported compilers.
+template<typename SortPolicy>
+struct NSModelName
+{
+  static constexpr char value[22] = "neighbor_search_model";
+};
+
+//! Name used when the sort policy is NearestNeighborSort.
+template<>
+struct NSModelName<NearestNeighborSort>
+{
+  static constexpr char value[30] = "nearest_neighbor_search_model";
+};
+
+//! Name used when the sort policy is FurthestNeighborSort.
+template<>
+struct NSModelName<FurthestNeighborSort>
+{
+  static constexpr char value[31] = "furthest_neighbor_search_model";
+};
+
+template<typename SortPolicy>
+class NSModel
+{
+ public:
+  enum TreeTypes
+  {
+    KD_TREE,
+    COVER_TREE,
+    R_TREE,
+    R_STAR_TREE
+  };
+
+ private:
+  int treeType; // Selects the tree; compared against the TreeTypes values.
+  size_t leafSize; // Leaf size Search() uses for the kd-tree query tree.
+
+  // Random basis: when randomBasis is set, data is premultiplied by q.
+  bool randomBasis;
+  arma::mat q;
+
+  // Mapping filled in when BuildModel() builds a kd-tree; used for unmapping.
+  std::vector<size_t> oldFromNewReferences;
+
+  template<template<typename TreeMetricType,
+                    typename TreeStatType,
+                    typename TreeMatType> class TreeType>
+  using NSType = NeighborSearch<SortPolicy,
+                                metric::EuclideanDistance,
+                                arma::mat,
+                                TreeType>;
+
+  // Only one of these pointers will be non-NULL.
+  NSType<tree::KDTree>* kdTreeNS;
+  NSType<tree::StandardCoverTree>* coverTreeNS;
+  NSType<tree::RTree>* rTreeNS;
+  NSType<tree::RStarTree>* rStarTreeNS;
+
+  // This pointer is only non-NULL if we are using kd-trees and we built the
+  // tree ourselves (which only happens if BuildModel() is called).
+  typename NSType<tree::KDTree>::Tree* kdTree;
+
+ public:
+  /**
+   * Initialize the NSModel with the given type and whether or not a random
+   * basis should be used.
+   */
+  NSModel(int treeType = TreeTypes::KD_TREE, bool randomBasis = false);
+
+  //! Clean memory, if necessary.
+  ~NSModel();
+
+  //! Serialize the kNN model.
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */);
+
+  //! Expose the dataset.  NOTE(review): no definition in ns_model_impl.hpp.
+  const arma::mat& Dataset() const;
+
+  //! Expose singleMode; throws std::runtime_error if no search object exists.
+  bool SingleMode() const;
+  bool& SingleMode();
+
+  bool Naive() const; // Same lookup as SingleMode(), for the naive flag.
+  bool& Naive();
+
+  size_t LeafSize() const { return leafSize; }
+  size_t& LeafSize() { return leafSize; }
+
+  int TreeType() const { return treeType; }
+  int& TreeType() { return treeType; }
+
+  bool RandomBasis() const { return randomBasis; }
+  bool& RandomBasis() { return randomBasis; }
+
+  //! Build the search object for treeType; referenceSet may be modified.
+  void BuildModel(arma::mat& referenceSet,
+                  const size_t leafSize,
+                  const bool naive,
+                  const bool singleMode);
+
+  //! Perform neighbor search.  The query set will be reordered.
+  void Search(arma::mat& querySet,
+              const size_t k,
+              arma::Mat<size_t>& neighbors,
+              arma::mat& distances);
+
+  //! Perform neighbor search with the reference set as the query set.
+  void Search(const size_t k,
+              arma::Mat<size_t>& neighbors,
+              arma::mat& distances);
+};
+
+} // namespace neighbor
+} // namespace mlpack
+
+// Include implementation.
+#include "ns_model_impl.hpp"
+
+#endif
diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
new file mode 100644
index 0000000..be0df8e
--- /dev/null
+++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
@@ -0,0 +1,401 @@
+/**
+ * @file ns_model_impl.hpp
+ * @author Ryan Curtin
+ *
+ * This is a model for nearest or furthest neighbor search.  It is useful in
+ * that it provides an easy way to serialize a model, abstracts away the
+ * different types of trees, and also reflects the NeighborSearch API and
+ * automatically directs to the right tree type.
+ */
+#ifndef __MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_IMPL_HPP
+#define __MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "ns_model.hpp"
+
+namespace mlpack {
+namespace neighbor {
+
+/**
+ * Initialize the NSModel with the given type and whether or not a random
+ * basis should be used.  No search object exists until BuildModel() is called
+ * or a model is loaded through Serialize().
+ *
+ * @param treeType Tree to use; one of the TreeTypes values.
+ * @param randomBasis Whether data should be mapped onto a random basis.
+ */
+template<typename SortPolicy>
+NSModel<SortPolicy>::NSModel(int treeType, bool randomBasis) :
+    treeType(treeType),
+    // Search() reads leafSize when it builds a kd-tree on the query set, so
+    // it must never be left indeterminate.  20 is the leaf size that the
+    // tests accompanying this class pass to BuildModel().
+    leafSize(20),
+    randomBasis(randomBasis),
+    kdTreeNS(NULL),
+    coverTreeNS(NULL),
+    rTreeNS(NULL),
+    rStarTreeNS(NULL),
+    kdTree(NULL)
+{
+  // Nothing to do.
+}
+
+//! Release the kd-tree and whichever search object this model holds.
+template<typename SortPolicy>
+NSModel<SortPolicy>::~NSModel()
+{
+  // Deleting a null pointer does nothing, so no guards are needed.
+  delete kdTree;
+
+  delete kdTreeNS;
+  delete coverTreeNS;
+  delete rTreeNS;
+  delete rStarTreeNS;
+}
+
+/**
+ * Serialize the model: its settings, the random basis, the kd-tree point
+ * mapping, and the one search object that matches treeType.
+ *
+ * When loading, anything this object currently holds is released first.
+ *
+ * @param ar Archive to read from or write to.
+ */
+template<typename SortPolicy>
+  template<typename Archive>
+void NSModel<SortPolicy>::Serialize(Archive& ar,
+                                    const unsigned int /* version */)
+{
+  ar & data::CreateNVP(treeType, "treeType");
+  // Search() builds kd-tree query trees with leafSize, so a loaded model needs
+  // the value that the model was saved with.
+  ar & data::CreateNVP(leafSize, "leafSize");
+  ar & data::CreateNVP(randomBasis, "randomBasis");
+  ar & data::CreateNVP(q, "q");
+  ar & data::CreateNVP(oldFromNewReferences, "oldFromNewReferences");
+
+  // This should never happen, but just in case, be clean with memory.
+  if (Archive::is_loading::value)
+  {
+    // Deleting a null pointer does nothing, so no guards are needed.
+    delete kdTree;
+
+    delete kdTreeNS;
+    delete coverTreeNS;
+    delete rTreeNS;
+    delete rStarTreeNS;
+
+    // Set all the pointers to NULL.
+    kdTree = NULL;
+
+    kdTreeNS = NULL;
+    coverTreeNS = NULL;
+    rTreeNS = NULL;
+    rStarTreeNS = NULL;
+  }
+
+  // We'll only need to serialize one of the kNN objects, based on the type.
+  const std::string& name = NSModelName<SortPolicy>::value;
+  switch (treeType)
+  {
+    case KD_TREE:
+      ar & data::CreateNVP(kdTreeNS, name);
+      break;
+    case COVER_TREE:
+      ar & data::CreateNVP(coverTreeNS, name);
+      break;
+    case R_TREE:
+      ar & data::CreateNVP(rTreeNS, name);
+      break;
+    case R_STAR_TREE:
+      ar & data::CreateNVP(rStarTreeNS, name);
+      break;
+  }
+}
+
+//! Report the single-tree setting of whichever search object exists.
+template<typename SortPolicy>
+bool NSModel<SortPolicy>::SingleMode() const
+{
+  if (kdTreeNS != NULL)
+    return kdTreeNS->SingleMode();
+  if (coverTreeNS != NULL)
+    return coverTreeNS->SingleMode();
+  if (rTreeNS != NULL)
+    return rTreeNS->SingleMode();
+  if (rStarTreeNS != NULL)
+    return rStarTreeNS->SingleMode();
+
+  // None of the four search objects has been created.
+  throw std::runtime_error("no neighbor search model initialized");
+}
+
+//! Mutable access to the single-tree setting of whichever search object
+//! exists.
+template<typename SortPolicy>
+bool& NSModel<SortPolicy>::SingleMode()
+{
+  if (kdTreeNS != NULL)
+    return kdTreeNS->SingleMode();
+  if (coverTreeNS != NULL)
+    return coverTreeNS->SingleMode();
+  if (rTreeNS != NULL)
+    return rTreeNS->SingleMode();
+  if (rStarTreeNS != NULL)
+    return rStarTreeNS->SingleMode();
+
+  // None of the four search objects has been created.
+  throw std::runtime_error("no neighbor search model initialized");
+}
+
+//! Report the naive (brute-force) setting of whichever search object exists.
+template<typename SortPolicy>
+bool NSModel<SortPolicy>::Naive() const
+{
+  if (kdTreeNS != NULL)
+    return kdTreeNS->Naive();
+  if (coverTreeNS != NULL)
+    return coverTreeNS->Naive();
+  if (rTreeNS != NULL)
+    return rTreeNS->Naive();
+  if (rStarTreeNS != NULL)
+    return rStarTreeNS->Naive();
+
+  // None of the four search objects has been created.
+  throw std::runtime_error("no neighbor search model initialized");
+}
+
+//! Mutable access to the naive (brute-force) setting of whichever search
+//! object exists.
+template<typename SortPolicy>
+bool& NSModel<SortPolicy>::Naive()
+{
+  if (kdTreeNS != NULL)
+    return kdTreeNS->Naive();
+  if (coverTreeNS != NULL)
+    return coverTreeNS->Naive();
+  if (rTreeNS != NULL)
+    return rTreeNS->Naive();
+  if (rStarTreeNS != NULL)
+    return rStarTreeNS->Naive();
+
+  // None of the four search objects has been created.
+  throw std::runtime_error("no neighbor search model initialized");
+}
+
+/**
+ * Build the search object selected by treeType on the given reference set,
+ * replacing any model that was built earlier.
+ *
+ * @param referenceSet Reference points; overwritten with q * referenceSet
+ *     when a random basis is in use.
+ * @param leafSize Leaf size for the kd-tree; remembered so that Search() can
+ *     build query trees with the same value.
+ * @param naive If true, brute-force search is used (and, for kd-trees, no
+ *     tree is built here).
+ * @param singleMode If true, single-tree search is used.
+ */
+template<typename SortPolicy>
+void NSModel<SortPolicy>::BuildModel(arma::mat& referenceSet,
+                                     const size_t leafSize,
+                                     const bool naive,
+                                     const bool singleMode)
+{
+  // The parameter shadows the member of the same name; Search() reads the
+  // member when it builds a kd-tree on the query set.
+  this->leafSize = leafSize;
+
+  // Initialize random basis if necessary.
+  if (randomBasis)
+  {
+    Log::Info << "Creating random basis..." << std::endl;
+    while (true)
+    {
+      // [Q, R] = qr(randn(d, d));
+      // Q = Q * diag(sign(diag(R)));
+      arma::mat r;
+      if (arma::qr(q, r, arma::randn<arma::mat>(referenceSet.n_rows,
+              referenceSet.n_rows)))
+      {
+        arma::vec rDiag(r.n_rows);
+        for (size_t i = 0; i < rDiag.n_elem; ++i)
+        {
+          if (r(i, i) < 0)
+            rDiag(i) = -1;
+          else if (r(i, i) > 0)
+            rDiag(i) = 1;
+          else
+            rDiag(i) = 0;
+        }
+
+        q *= arma::diagmat(rDiag);
+
+        // Check if the determinant is positive.
+        if (arma::det(q) >= 0)
+          break;
+      }
+    }
+  }
+
+  // Clean memory, if necessary.  The kd-tree search object may refer to
+  // kdTree, so the search objects are released first.  Every pointer is reset
+  // afterwards: only one of them is assigned below, and a pointer left
+  // dangling (for instance when a naive model replaces a tree-based one, or
+  // when treeType was changed) would be deleted a second time later on.
+  delete kdTreeNS;
+  delete coverTreeNS;
+  delete rTreeNS;
+  delete rStarTreeNS;
+  delete kdTree;
+
+  kdTreeNS = NULL;
+  coverTreeNS = NULL;
+  rTreeNS = NULL;
+  rStarTreeNS = NULL;
+  kdTree = NULL;
+
+  // A mapping left over from an earlier kd-tree would make Search() unmap
+  // results that were never mapped.
+  oldFromNewReferences.clear();
+
+  // Do we need to modify the reference set?
+  if (randomBasis)
+    referenceSet = q * referenceSet;
+
+  if (!naive)
+  {
+    Timer::Start("tree_building");
+    Log::Info << "Building reference tree..." << std::endl;
+  }
+
+  switch (treeType)
+  {
+    case KD_TREE:
+      // If necessary, build the kd-tree.
+      if (naive)
+      {
+        kdTreeNS = new NSType<tree::KDTree>(referenceSet, naive, singleMode);
+      }
+      else
+      {
+        kdTree = new typename NSType<tree::KDTree>::Tree(referenceSet,
+            oldFromNewReferences, leafSize);
+        kdTreeNS = new NSType<tree::KDTree>(kdTree, singleMode);
+      }
+
+      break;
+    // The remaining cases use the same (data, naive, singleMode) argument
+    // order as the naive kd-tree case above, so that naive is honoured and
+    // singleMode is not taken for the naive flag.
+    case COVER_TREE:
+      // If necessary, build the cover tree.
+      coverTreeNS = new NSType<tree::StandardCoverTree>(referenceSet, naive,
+          singleMode);
+      break;
+    case R_TREE:
+      // If necessary, build the R tree.
+      rTreeNS = new NSType<tree::RTree>(referenceSet, naive, singleMode);
+      break;
+    case R_STAR_TREE:
+      // If necessary, build the R* tree.
+      rStarTreeNS = new NSType<tree::RStarTree>(referenceSet, naive,
+          singleMode);
+      break;
+  }
+
+  if (!naive)
+  {
+    Timer::Stop("tree_building");
+    Log::Info << "Tree built." << std::endl;
+  }
+}
+
+//! Perform neighbor search on querySet (overwritten if randomBasis is set).
+template<typename SortPolicy>
+void NSModel<SortPolicy>::Search(arma::mat& querySet,
+                                 const size_t k,
+                                 arma::Mat<size_t>& neighbors,
+                                 arma::mat& distances)
+{
+  // Map the query set with the same matrix q applied to the reference set.
+  if (randomBasis)
+    querySet = q * querySet;
+
+  Log::Info << "Searching for " << k << " nearest neighbors with ";
+  if (!Naive() && !SingleMode()) // Both throw if no search object exists.
+    Log::Info << "dual-tree search..." << std::endl;
+  else if (!Naive())
+    Log::Info << "single-tree search..." << std::endl;
+  else
+    Log::Info << "brute-force (naive) search..." << std::endl;
+
+  switch (treeType)
+  {
+    case KD_TREE:
+      if (!kdTreeNS->Naive() && !kdTreeNS->SingleMode())
+      {
+        // Dual-tree: build a kd-tree on the queries with the member leafSize.
+        Timer::Start("tree_building");
+        Log::Info << "Building query tree..." << std::endl;
+        std::vector<size_t> oldFromNewQueries;
+        typename NSType<tree::KDTree>::Tree queryTree(querySet,
+            oldFromNewQueries, leafSize);
+        Log::Info << "Tree built." << std::endl;
+        Timer::Stop("tree_building");
+
+        arma::Mat<size_t> neighborsOut;
+        arma::mat distancesOut;
+        kdTreeNS->Search(&queryTree, k, neighborsOut, distancesOut);
+
+        // Unmap the results using both the reference and the query mapping.
+        Unmap(neighborsOut, distancesOut, oldFromNewReferences,
+            oldFromNewQueries, neighbors, distances);
+      }
+      else if (kdTreeNS->SingleMode() && !kdTreeNS->Naive())
+      {
+        // Search without building a second tree.
+        arma::Mat<size_t> neighborsOut;
+        arma::mat distancesOut;
+        kdTreeNS->Search(querySet, k, neighborsOut, distancesOut);
+
+        Unmap(neighborsOut, distancesOut, oldFromNewReferences, neighbors,
+            distances);
+      }
+      else
+      {
+        // Naive mode search.  No unmapping will be necessary... unless a tree
+        // has been built.
+        if (oldFromNewReferences.size() == 0)
+        {
+          kdTreeNS->Search(querySet, k, neighbors, distances);
+        }
+        else
+        {
+          arma::Mat<size_t> neighborsOut;
+          arma::mat distancesOut;
+          kdTreeNS->Search(querySet, k, neighborsOut, distancesOut);
+
+          Unmap(neighborsOut, distancesOut, oldFromNewReferences, neighbors,
+              distances);
+        }
+      }
+      break;
+    case COVER_TREE:
+      // No mapping necessary.
+      coverTreeNS->Search(querySet, k, neighbors, distances);
+      break;
+    case R_TREE:
+      // No mapping necessary.
+      rTreeNS->Search(querySet, k, neighbors, distances);
+      break;
+    case R_STAR_TREE:
+      // No mapping necessary.
+      rStarTreeNS->Search(querySet, k, neighbors, distances);
+      break;
+  }
+}
+
+//! Perform neighbor search using the reference set as the query set.
+template<typename SortPolicy>
+void NSModel<SortPolicy>::Search(const size_t k,
+                                 arma::Mat<size_t>& neighbors,
+                                 arma::mat& distances)
+{
+  Log::Info << "Searching for " << k << " nearest neighbors with ";
+  if (!Naive() && !SingleMode()) // Both throw if no search object exists.
+    Log::Info << "dual-tree search..." << std::endl;
+  else if (!Naive())
+    Log::Info << "single-tree search..." << std::endl;
+  else
+    Log::Info << "brute-force (naive) search..." << std::endl;
+
+  switch (treeType)
+  {
+    case KD_TREE:
+      // If in dual-tree or single-tree mode, we'll have to do unmapping.  We
+      // also must do unmapping in naive mode, if a tree has been built on the
+      // data.
+      if (oldFromNewReferences.size() > 0) // Mapping has occurred.
+      {
+        arma::Mat<size_t> neighborsOut;
+        arma::mat distancesOut;
+        kdTreeNS->Search(k, neighborsOut, distancesOut);
+        Unmap(neighborsOut, distancesOut, oldFromNewReferences,
+            oldFromNewReferences, neighbors, distances);
+      }
+      else
+      {
+        kdTreeNS->Search(k, neighbors, distances);
+      }
+      break;
+    case COVER_TREE:
+      coverTreeNS->Search(k, neighbors, distances);
+      break;
+    case R_TREE:
+      rTreeNS->Search(k, neighbors, distances);
+      break;
+    case R_STAR_TREE:
+      rStarTreeNS->Search(k, neighbors, distances);
+      break;
+  }
+}
+
+} // namespace neighbor
+} // namespace mlpack
+
+#endif
diff --git a/src/mlpack/tests/allknn_test.cpp b/src/mlpack/tests/allknn_test.cpp
index 8818920..cf22d5e 100644
--- a/src/mlpack/tests/allknn_test.cpp
+++ b/src/mlpack/tests/allknn_test.cpp
@@ -6,6 +6,7 @@
 #include <mlpack/core.hpp>
 #include <mlpack/methods/neighbor_search/neighbor_search.hpp>
 #include <mlpack/methods/neighbor_search/unmap.hpp>
+#include <mlpack/methods/neighbor_search/ns_model.hpp>
 #include <mlpack/core/tree/cover_tree.hpp>
 #include <mlpack/core/tree/example_tree.hpp>
 #include <boost/test/unit_test.hpp>
@@ -900,4 +901,123 @@ BOOST_AUTO_TEST_CASE(SparseAllkNNCoverTreeTest)
 }
 */
 
+BOOST_AUTO_TEST_CASE(KNNModelTest)
+{
+  // Ensure that we can build an NSModel<NearestNeighborSort> and get correct
+  // results.
+  typedef NSModel<NearestNeighborSort> KNNModel;
+
+  arma::mat queryData = arma::randu<arma::mat>(10, 100);
+  arma::mat referenceData = arma::randu<arma::mat>(10, 500);
+
+  // One model per tree type, each with and without a random basis.
+  KNNModel models[8];
+  models[0] = KNNModel(KNNModel::TreeTypes::KD_TREE, true);
+  models[1] = KNNModel(KNNModel::TreeTypes::KD_TREE, false);
+  models[2] = KNNModel(KNNModel::TreeTypes::COVER_TREE, true);
+  models[3] = KNNModel(KNNModel::TreeTypes::COVER_TREE, false);
+  models[4] = KNNModel(KNNModel::TreeTypes::R_TREE, true);
+  models[5] = KNNModel(KNNModel::TreeTypes::R_TREE, false);
+  models[6] = KNNModel(KNNModel::TreeTypes::R_STAR_TREE, true);
+  models[7] = KNNModel(KNNModel::TreeTypes::R_STAR_TREE, false);
+
+  for (size_t j = 0; j < 2; ++j) // NOTE(review): j == 2 (naive) never runs.
+  {
+    // Get a baseline.
+    AllkNN knn(referenceData);
+    arma::Mat<size_t> baselineNeighbors;
+    arma::mat baselineDistances;
+    knn.Search(queryData, 3, baselineNeighbors, baselineDistances);
+
+    for (size_t i = 0; i < 8; ++i)
+    {
+      if (j == 0)
+        models[i].BuildModel(referenceData, 20, false, false);
+      if (j == 1)
+        models[i].BuildModel(referenceData, 20, false, true);
+      if (j == 2) // Unreachable while the outer loop stops at j < 2.
+        models[i].BuildModel(referenceData, 20, true, false);
+
+      arma::Mat<size_t> neighbors;
+      arma::mat distances;
+
+      models[i].Search(queryData, 3, neighbors, distances);
+
+      BOOST_REQUIRE_EQUAL(neighbors.n_rows, baselineNeighbors.n_rows);
+      BOOST_REQUIRE_EQUAL(neighbors.n_cols, baselineNeighbors.n_cols);
+      BOOST_REQUIRE_EQUAL(neighbors.n_elem, baselineNeighbors.n_elem);
+      BOOST_REQUIRE_EQUAL(distances.n_rows, baselineDistances.n_rows);
+      BOOST_REQUIRE_EQUAL(distances.n_cols, baselineDistances.n_cols);
+      BOOST_REQUIRE_EQUAL(distances.n_elem, baselineDistances.n_elem);
+      for (size_t i = 0; i < distances.n_elem; ++i) // Shadows the outer i.
+      {
+        BOOST_REQUIRE_EQUAL(neighbors[i], baselineNeighbors[i]);
+        if (std::abs(baselineDistances[i]) < 1e-5)
+          BOOST_REQUIRE_SMALL(distances[i], 1e-5);
+        else
+          BOOST_REQUIRE_CLOSE(distances[i], baselineDistances[i], 1e-5);
+      }
+    }
+  }
+}
+
+BOOST_AUTO_TEST_CASE(KNNModelMonochromaticTest)
+{
+  // Ensure that we can build an NSModel<NearestNeighborSort> and get correct
+  // results, in the case where the reference set is the same as the query set.
+  typedef NSModel<NearestNeighborSort> KNNModel;
+
+  arma::mat referenceData = arma::randu<arma::mat>(10, 500);
+
+  // One model per tree type, each with and without a random basis.
+  KNNModel models[8];
+  models[0] = KNNModel(KNNModel::TreeTypes::KD_TREE, true);
+  models[1] = KNNModel(KNNModel::TreeTypes::KD_TREE, false);
+  models[2] = KNNModel(KNNModel::TreeTypes::COVER_TREE, true);
+  models[3] = KNNModel(KNNModel::TreeTypes::COVER_TREE, false);
+  models[4] = KNNModel(KNNModel::TreeTypes::R_TREE, true);
+  models[5] = KNNModel(KNNModel::TreeTypes::R_TREE, false);
+  models[6] = KNNModel(KNNModel::TreeTypes::R_STAR_TREE, true);
+  models[7] = KNNModel(KNNModel::TreeTypes::R_STAR_TREE, false);
+
+  for (size_t j = 0; j < 2; ++j) // NOTE(review): j == 2 (naive) never runs.
+  {
+    // Get a baseline.
+    AllkNN knn(referenceData);
+    arma::Mat<size_t> baselineNeighbors;
+    arma::mat baselineDistances;
+    knn.Search(3, baselineNeighbors, baselineDistances);
+
+    for (size_t i = 0; i < 8; ++i)
+    {
+      if (j == 0)
+        models[i].BuildModel(referenceData, 20, false, false);
+      if (j == 1)
+        models[i].BuildModel(referenceData, 20, false, true);
+      if (j == 2) // Unreachable while the outer loop stops at j < 2.
+        models[i].BuildModel(referenceData, 20, true, false);
+
+      arma::Mat<size_t> neighbors;
+      arma::mat distances;
+
+      models[i].Search(3, neighbors, distances);
+
+      BOOST_REQUIRE_EQUAL(neighbors.n_rows, baselineNeighbors.n_rows);
+      BOOST_REQUIRE_EQUAL(neighbors.n_cols, baselineNeighbors.n_cols);
+      BOOST_REQUIRE_EQUAL(neighbors.n_elem, baselineNeighbors.n_elem);
+      BOOST_REQUIRE_EQUAL(distances.n_rows, baselineDistances.n_rows);
+      BOOST_REQUIRE_EQUAL(distances.n_cols, baselineDistances.n_cols);
+      BOOST_REQUIRE_EQUAL(distances.n_elem, baselineDistances.n_elem);
+      for (size_t i = 0; i < distances.n_elem; ++i) // Shadows the outer i.
+      {
+        BOOST_REQUIRE_EQUAL(neighbors[i], baselineNeighbors[i]);
+        if (std::abs(baselineDistances[i]) < 1e-5)
+          BOOST_REQUIRE_SMALL(distances[i], 1e-5);
+        else
+          BOOST_REQUIRE_CLOSE(distances[i], baselineDistances[i], 1e-5);
+      }
+    }
+  }
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list