[mlpack-git] master: Add DrusillaSelect implementation. (5e0db4c)

gitdub at mlpack.org gitdub at mlpack.org
Mon Oct 24 03:27:10 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/31995784e651e1c17c988c79d9f53c9dbad620f8...81fce4edfc8bfb4c26b48ed388f559ec1cee26dd

>---------------------------------------------------------------

commit 5e0db4c90f96fe58ca31a2723a1a4e686043950d
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Oct 24 16:27:10 2016 +0900

    Add DrusillaSelect implementation.


>---------------------------------------------------------------

5e0db4c90f96fe58ca31a2723a1a4e686043950d
 src/mlpack/methods/CMakeLists.txt                  |   1 +
 .../methods/approx_kfn/.drusilla_select.hpp.swp    | Bin 0 -> 16384 bytes
 .../approx_kfn/.drusilla_select_impl.hpp.swo       | Bin 0 -> 45056 bytes
 .../approx_kfn/.drusilla_select_impl.hpp.swp       | Bin 0 -> 20480 bytes
 .../methods/{lsh => approx_kfn}/CMakeLists.txt     |   8 +-
 src/mlpack/methods/approx_kfn/drusilla_select.hpp  | 125 ++++++++++++
 .../methods/approx_kfn/drusilla_select_impl.hpp    | 210 +++++++++++++++++++++
 .../methods/approx_kfn/drusilla_select_main.cpp    | 100 ++++++++++
 src/mlpack/tests/CMakeLists.txt                    |   1 +
 src/mlpack/tests/drusilla_select_test.cpp          | 145 ++++++++++++++
 10 files changed, 586 insertions(+), 4 deletions(-)

diff --git a/src/mlpack/methods/CMakeLists.txt b/src/mlpack/methods/CMakeLists.txt
index dbbd231..f292e97 100644
--- a/src/mlpack/methods/CMakeLists.txt
+++ b/src/mlpack/methods/CMakeLists.txt
@@ -18,6 +18,7 @@ endmacro ()
 set(DIRS
   preprocess
   adaboost
+  approx_kfn
   amf
   ann
   cf
diff --git a/src/mlpack/methods/approx_kfn/.drusilla_select.hpp.swp b/src/mlpack/methods/approx_kfn/.drusilla_select.hpp.swp
new file mode 100644
index 0000000..ae44b28
Binary files /dev/null and b/src/mlpack/methods/approx_kfn/.drusilla_select.hpp.swp differ
diff --git a/src/mlpack/methods/approx_kfn/.drusilla_select_impl.hpp.swo b/src/mlpack/methods/approx_kfn/.drusilla_select_impl.hpp.swo
new file mode 100644
index 0000000..b2bbbba
Binary files /dev/null and b/src/mlpack/methods/approx_kfn/.drusilla_select_impl.hpp.swo differ
diff --git a/src/mlpack/methods/approx_kfn/.drusilla_select_impl.hpp.swp b/src/mlpack/methods/approx_kfn/.drusilla_select_impl.hpp.swp
new file mode 100644
index 0000000..9d5090f
Binary files /dev/null and b/src/mlpack/methods/approx_kfn/.drusilla_select_impl.hpp.swp differ
diff --git a/src/mlpack/methods/lsh/CMakeLists.txt b/src/mlpack/methods/approx_kfn/CMakeLists.txt
similarity index 82%
copy from src/mlpack/methods/lsh/CMakeLists.txt
copy to src/mlpack/methods/approx_kfn/CMakeLists.txt
index 3540e04..0e907d6 100644
--- a/src/mlpack/methods/lsh/CMakeLists.txt
+++ b/src/mlpack/methods/approx_kfn/CMakeLists.txt
@@ -1,9 +1,9 @@
 # Define the files we need to compile.
 # Anything not in this list will not be compiled into mlpack.
 set(SOURCES
-  # LSH-search class
-  lsh_search.hpp
-  lsh_search_impl.hpp
+  # DrusillaSelect sources.
+  drusilla_select.hpp
+  drusilla_select_impl.hpp
 )
 
 # Add directory name to sources.
@@ -17,4 +17,4 @@ set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
 
 # The code to compute the approximate neighbor for the given query and reference
 # sets with p-stable LSH.
-add_cli_executable(lsh)
+add_cli_executable(drusilla_select)
diff --git a/src/mlpack/methods/approx_kfn/drusilla_select.hpp b/src/mlpack/methods/approx_kfn/drusilla_select.hpp
new file mode 100644
index 0000000..38b90ab
--- /dev/null
+++ b/src/mlpack/methods/approx_kfn/drusilla_select.hpp
@@ -0,0 +1,125 @@
+/**
+ * @file drusilla_select.hpp
+ * @author Ryan Curtin
+ *
+ * An implementation of the approximate furthest neighbor algorithm specified in
+ * the following paper:
+ *
+ * @code
+ * @incollection{curtin2016fast,
+ *   title={Fast approximate furthest neighbors with data-dependent candidate
+ *          selection},
+ *   author={Curtin, R.R. and Gardner, A.B.},
+ *   booktitle={Similarity Search and Applications},
+ *   pages={221--235},
+ *   year={2016},
+ *   publisher={Springer}
+ * }
+ * @endcode
+ *
+ * This algorithm, called DrusillaSelect, constructs a candidate set of points
+ * to query to find an approximate furthest neighbor.  The strange name is a
+ * result of the algorithm being named after a cat.  The cat in question may be
+ * viewed at http://www.ratml.org/misc_img/drusilla_fence.png.
+ */
+#ifndef MLPACK_METHODS_APPROX_KFN_DRUSILLA_SELECT_HPP
+#define MLPACK_METHODS_APPROX_KFN_DRUSILLA_SELECT_HPP
+
+#include <mlpack/core.hpp>
+
+namespace mlpack {
+namespace neighbor {
+
+/**
+ * DrusillaSelect: approximate furthest neighbor search.  Train() selects a
+ * candidate set of l*m reference points (m points for each of l projections);
+ * Search() then compares each query point only against those candidates.
+ *
+ * @tparam MatType Type of matrix used to hold the data.
+ */
+template<typename MatType = arma::mat>
+class DrusillaSelect
+{
+ public:
+  /**
+   * Construct the DrusillaSelect object with the given reference set (this is
+   * the set that will be searched).  The resulting set of candidate points that
+   * will be searched at query time will have size l*m.
+   *
+   * @param referenceSet Set of reference data.
+   * @param l Number of projections.
+   * @param m Number of elements to store for each projection.
+   * @throws std::invalid_argument if l or m is 0, or if l*m exceeds the number
+   *     of points in referenceSet.
+   */
+  DrusillaSelect(const MatType& referenceSet,
+                 const size_t l,
+                 const size_t m);
+
+  /**
+   * Construct the DrusillaSelect object with no given reference set.  Be sure
+   * to call Train() before calling Search()!
+   *
+   * @param l Number of projections.
+   * @param m Number of elements to store for each projection.
+   * @throws std::invalid_argument if l or m is 0.
+   */
+  DrusillaSelect(const size_t l, const size_t m);
+
+  /**
+   * Build the set of candidate points on the given reference set.  If l or m
+   * is left unspecified (0), the current value (from the constructor or from
+   * the previous call to Train()) is used instead.
+   *
+   * @param referenceSet Set to extract candidate points from.
+   * @param l Number of projections.
+   * @param m Number of elements to store for each projection.
+   * @throws std::invalid_argument if l*m exceeds the number of points in
+   *     referenceSet.
+   */
+  void Train(const MatType& referenceSet,
+             const size_t l = 0,
+             const size_t m = 0);
+
+  /**
+   * Search for the k furthest neighbors of the given query set.  (The query set
+   * can contain just one point: that is okay.)  The results will be stored in
+   * the given neighbors and distances matrices, in the same format as the
+   * NeighborSearch and LSHSearch classes.  That is, each column in the
+   * neighbors and distances matrices will refer to a single query point, and
+   * the k'th row in that column will refer to the k'th candidate neighbor or
+   * distance for that query point.  Only points in the candidate set are
+   * considered.
+   *
+   * @param querySet Set of query points to search.
+   * @param k Number of furthest neighbors to search for.
+   * @param neighbors Matrix to store resulting neighbors in.
+   * @param distances Matrix to store resulting distances in.
+   * @throws std::runtime_error if the candidate set is empty (Train() has not
+   *     been called).
+   * @throws std::invalid_argument if k is larger than the candidate set (l*m).
+   */
+  void Search(const MatType& querySet,
+              const size_t k,
+              arma::Mat<size_t>& neighbors,
+              arma::mat& distances);
+
+  /**
+   * Serialize the model: the candidate set, the candidates' indices, l, and m.
+   */
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */);
+
+  //! Access the candidate set.
+  const MatType& CandidateSet() const { return candidateSet; }
+  //! Modify the candidate set.  Be careful!
+  MatType& CandidateSet() { return candidateSet; }
+
+  //! Access the indices of points in the candidate set.
+  const arma::Col<size_t>& CandidateIndices() const { return candidateIndices; }
+  //! Modify the indices of points in the candidate set.  Be careful!
+  arma::Col<size_t>& CandidateIndices() { return candidateIndices; }
+
+ private:
+  //! The candidate set: the selected reference points, one per column.
+  MatType candidateSet;
+  //! For each column of candidateSet, the index of that point in the
+  //! reference set it was selected from (used by Search() to map results back).
+  arma::Col<size_t> candidateIndices;
+
+  //! The number of projections.
+  size_t l;
+  //! The number of points in each projection.
+  size_t m;
+};
+
+} // namespace neighbor
+} // namespace mlpack
+
+// Include implementation.
+#include "drusilla_select_impl.hpp"
+
+#endif
diff --git a/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp b/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp
new file mode 100644
index 0000000..a84b304
--- /dev/null
+++ b/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp
@@ -0,0 +1,210 @@
+/**
+ * @file drusilla_select_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of DrusillaSelect class methods.
+ */
+#ifndef MLPACK_METHODS_APPROX_KFN_DRUSILLA_SELECT_IMPL_HPP
+#define MLPACK_METHODS_APPROX_KFN_DRUSILLA_SELECT_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "drusilla_select.hpp"
+
+#include <queue>
+#include <mlpack/methods/neighbor_search/neighbor_search_rules.hpp>
+#include <mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp>
+#include <mlpack/core/tree/binary_space_tree.hpp>
+#include <algorithm>
+
+namespace mlpack {
+namespace neighbor {
+
+// Constructor: record l and m, size the candidate set, and immediately build
+// the candidate set from referenceSet.  Throws if l or m is zero (l is checked
+// first).
+template<typename MatType>
+DrusillaSelect<MatType>::DrusillaSelect(const MatType& referenceSet,
+                                        const size_t l,
+                                        const size_t m) :
+    candidateSet(referenceSet.n_rows, l * m),
+    l(l),
+    m(m)
+{
+  if (l == 0 || m == 0)
+  {
+    throw std::invalid_argument((l == 0) ?
+        "DrusillaSelect::DrusillaSelect(): invalid value of l; must be "
+        "greater than 0!" :
+        "DrusillaSelect::DrusillaSelect(): invalid value of m; must be "
+        "greater than 0!");
+  }
+
+  // Select the candidates right away.
+  Train(referenceSet, l, m);
+}
+
+// Constructor without a reference set: only l and m are stored, and Train()
+// must be called before Search().  Throws if l or m is zero (l is checked
+// first).
+template<typename MatType>
+DrusillaSelect<MatType>::DrusillaSelect(const size_t l, const size_t m) :
+    l(l),
+    m(m)
+{
+  if (l == 0 || m == 0)
+  {
+    throw std::invalid_argument((l == 0) ?
+        "DrusillaSelect::DrusillaSelect(): invalid value of l; must be "
+        "greater than 0!" :
+        "DrusillaSelect::DrusillaSelect(): invalid value of m; must be "
+        "greater than 0!");
+  }
+}
+
+// Train the model: select l*m candidate points from referenceSet.  For each of
+// the l projections, the direction is given by the eligible point with the
+// largest norm after centering, and the m not-yet-selected points with the
+// highest score |offset| - distortion along that direction are added to the
+// candidate set.  Points whose angle to that direction is below pi/8 become
+// ineligible to define later directions.
+//
+// Throws std::invalid_argument if l*m exceeds the number of reference points.
+template<typename MatType>
+void DrusillaSelect<MatType>::Train(
+    const MatType& referenceSet,
+    const size_t lIn,
+    const size_t mIn)
+{
+  // Did the user specify a new size?  If so, use it.
+  if (lIn > 0)
+    l = lIn;
+  if (mIn > 0)
+    m = mIn;
+
+  if ((l * m) > referenceSet.n_cols)
+    throw std::invalid_argument("DrusillaSelect::Train(): l and m are too "
+        "large!  Choose smaller values.  l*m must not exceed the number of "
+        "points in the dataset.");
+
+  // l and m may differ from a previous call, so (re)size the model.
+  candidateSet.set_size(referenceSet.n_rows, l * m);
+  candidateIndices.set_size(l * m);
+
+  // Center the data.  norms[j] is the norm of centered point j while it may
+  // still define a projection, 0 once it is too close in angle to an earlier
+  // projection, and -1 once it has been added to the candidate set.
+  const arma::vec dataMean = arma::mean(referenceSet, 1);
+  const arma::mat refCopy = referenceSet.each_col() - dataMean;
+  arma::vec norms(refCopy.n_cols);
+  for (size_t j = 0; j < refCopy.n_cols; ++j)
+    norms[j] = arma::norm(refCopy.col(j));
+
+  typedef std::pair<double, size_t> Candidate;
+  std::vector<Candidate> scores;
+  std::vector<size_t> closePoints;
+  scores.reserve(refCopy.n_cols);
+
+  // Find the top m points for each of the l projections...
+  for (size_t i = 0; i < l; ++i)
+  {
+    // The projection direction is the eligible point furthest from the mean.
+    arma::uword maxIndex = 0;
+    norms.max(maxIndex);
+
+    arma::vec line = refCopy.col(maxIndex);
+    const double lineNorm = arma::norm(line);
+    if (lineNorm > 0.0)
+      line /= lineNorm; // Guard against a point lying exactly on the mean.
+
+    // Score every point that is not already a candidate.  Points that may no
+    // longer define a projection (norm 0) get a neutral score of 0.
+    scores.clear();
+    closePoints.clear();
+    for (size_t j = 0; j < refCopy.n_cols; ++j)
+    {
+      if (norms[j] < 0.0)
+        continue; // Already selected; never select a point twice.
+
+      if (norms[j] == 0.0)
+      {
+        scores.push_back(Candidate(0.0, j));
+        continue;
+      }
+
+      const double offset = arma::dot(refCopy.col(j), line);
+      const double distortion = arma::norm(refCopy.col(j) - offset * line);
+      scores.push_back(Candidate(std::abs(offset) - distortion, j));
+
+      // atan2() handles a zero offset without dividing by zero.
+      if (std::atan2(distortion, std::abs(offset)) < (arma::datum::pi / 8.0))
+        closePoints.push_back(j);
+    }
+
+    // Keep the m highest scores.  At least m points are unselected here, since
+    // only i*m points have been selected so far and l*m <= n_cols.
+    std::partial_sort(scores.begin(), scores.begin() + m, scores.end(),
+        [](const Candidate& a, const Candidate& b)
+        { return a.first > b.first; });
+
+    for (size_t j = 0; j < m; ++j)
+    {
+      const size_t index = scores[j].second;
+      candidateSet.col(i * m + j) = referenceSet.col(index);
+      candidateIndices[i * m + j] = index;
+
+      // Mark the point as selected so it is never chosen again.
+      norms[index] = -1.0;
+    }
+
+    // Points close in angle to this projection may not define later ones.
+    for (size_t j = 0; j < closePoints.size(); ++j)
+      if (norms[closePoints[j]] > 0.0)
+        norms[closePoints[j]] = 0.0;
+  }
+}
+
+// Search: brute-force furthest neighbor search over the candidate set only.
+// Column q of neighbors/distances holds the k results for query point q, and
+// neighbor indices are mapped back into the original reference set.
+template<typename MatType>
+void DrusillaSelect<MatType>::Search(const MatType& querySet,
+                                     const size_t k,
+                                     arma::Mat<size_t>& neighbors,
+                                     arma::mat& distances)
+{
+  if (candidateSet.n_cols == 0)
+    throw std::runtime_error("DrusillaSelect::Search(): candidate set not "
+        "initialized!  Call Train() first.");
+
+  // Compare against the actual candidate set size rather than l*m, so the
+  // check stays correct even if l and m were modified without retraining.
+  if (k > candidateSet.n_cols)
+    throw std::invalid_argument("DrusillaSelect::Search(): requested k is "
+        "greater than number of points in candidate set!  Increase l or m.");
+
+  // We'll use the NeighborSearchRules class to perform our brute-force search.
+  // No tree is ever built: BaseCase() is called directly on every
+  // (query, candidate) pair, and the KDTree type only fills the rules'
+  // TreeType template parameter.  The rules take the reference set (here, the
+  // candidate set) first and the query set second.
+  metric::EuclideanDistance metric;
+  NeighborSearchRules<FurthestNeighborSort, metric::EuclideanDistance,
+      tree::KDTree<metric::EuclideanDistance, tree::EmptyStatistic, arma::mat>>
+      rules(candidateSet, querySet, k, metric, 0, false);
+
+  for (size_t q = 0; q < querySet.n_cols; ++q)
+    for (size_t r = 0; r < candidateSet.n_cols; ++r)
+      rules.BaseCase(q, r);
+
+  // Extract the k best candidates for each query point.
+  rules.GetResults(neighbors, distances);
+
+  // Map the neighbors back to their original indices in the reference set.
+  // Any index outside the candidate set (e.g. an unfilled result slot) is
+  // left untouched instead of being used to index out of bounds.
+  for (size_t i = 0; i < neighbors.n_elem; ++i)
+    if (neighbors[i] < candidateIndices.n_elem)
+      neighbors[i] = candidateIndices[neighbors[i]];
+}
+
+//! Serialize the model.  Fields are written/read in a fixed order: candidate
+//! set, candidate indices, l, then m.
+template<typename MatType>
+template<typename Archive>
+void DrusillaSelect<MatType>::Serialize(Archive& ar,
+                                        const unsigned int /* version */)
+{
+  ar & data::CreateNVP(candidateSet, "candidateSet");
+  ar & data::CreateNVP(candidateIndices, "candidateIndices");
+  ar & data::CreateNVP(l, "l");
+  ar & data::CreateNVP(m, "m");
+}
+
+} // namespace neighbor
+} // namespace mlpack
+
+#endif
diff --git a/src/mlpack/methods/approx_kfn/drusilla_select_main.cpp b/src/mlpack/methods/approx_kfn/drusilla_select_main.cpp
new file mode 100644
index 0000000..9e55ec7
--- /dev/null
+++ b/src/mlpack/methods/approx_kfn/drusilla_select_main.cpp
@@ -0,0 +1,100 @@
+/**
+ * @file drusilla_select_main.cpp
+ * @author Ryan Curtin
+ *
+ * Command-line program for the DrusillaSelect approximate furthest neighbor
+ * algorithm.
+ */
+#include <mlpack/core.hpp>
+#include "drusilla_select.hpp"
+
+using namespace mlpack;
+using namespace mlpack::neighbor;
+using namespace std;
+
+PROGRAM_INFO("Approximate furthest neighbor search with DrusillaSelect",
+    "This program implements the algorithm from the SISAP 2016 paper titled "
+    "'Fast Approximate Furthest Neighbors with Data-Dependent Candidate "
+    "Selection' by R.R. Curtin and A.B. Gardner.  Specify a reference set "
+    "(set to search in) with --reference_file, specify a query set (set to "
+    "search for) with --query_file, and specify algorithm parameters with "
+    "--num_tables (the number of projections) and --num_projections (the "
+    "number of points kept for each projection), or don't, and defaults will "
+    "be used.  The product of those two parameters may not exceed the number "
+    "of reference points.  Also specify the number of points to search for "
+    "with --k.  Each of those options has short names too; see the detailed "
+    "parameter documentation below."
+    "\n\n"
+    "Results for each query point are stored in the files specified by "
+    "--neighbors_file and --distances_file.  This is in the same format as the "
+    "mlpack KFN and KNN programs: each row holds the k distances or neighbor "
+    "indices for each query point.");
+
+PARAM_STRING_REQ("reference_file", "File containing reference points.", "r");
+PARAM_STRING_REQ("query_file", "File containing query points.", "q");
+
+PARAM_INT_REQ("k", "Number of furthest neighbors to search for.", "k");
+
+PARAM_INT("num_tables", "Number of projections to use (l).", "t", 10);
+PARAM_INT("num_projections", "Number of points to keep for each projection "
+    "(m).", "p", 30);
+
+PARAM_STRING("neighbors_file", "File to save furthest neighbor indices to.",
+    "n", "");
+PARAM_STRING("distances_file", "File to save furthest neighbor distances to.",
+    "d", "");
+
+PARAM_FLAG("calculate_error", "If set, calculate the average distance error "
+    "(requires --exact_distances_file).", "e");
+PARAM_STRING("exact_distances_file", "File containing exact distances", "x", "");
+
+int main(int argc, char** argv)
+{
+  CLI::ParseCommandLine(argc, argv);
+
+  const string referenceFile = CLI::GetParam<string>("reference_file");
+  const string queryFile = CLI::GetParam<string>("query_file");
+  const int kIn = CLI::GetParam<int>("k");
+  const int numTablesIn = CLI::GetParam<int>("num_tables");
+  const int numProjectionsIn = CLI::GetParam<int>("num_projections");
+
+  // Validate before converting: a negative int would wrap to a huge size_t.
+  if (kIn <= 0)
+    Log::Fatal << "--k must be positive." << endl;
+  if (numTablesIn <= 0)
+    Log::Fatal << "--num_tables must be positive." << endl;
+  if (numProjectionsIn <= 0)
+    Log::Fatal << "--num_projections must be positive." << endl;
+  if (CLI::HasParam("calculate_error") &&
+      !CLI::HasParam("exact_distances_file"))
+    Log::Fatal << "--exact_distances_file must be specified when "
+        << "--calculate_error is given." << endl;
+
+  const size_t k = static_cast<size_t>(kIn);
+  const size_t numTables = static_cast<size_t>(numTablesIn);
+  const size_t numProjections = static_cast<size_t>(numProjectionsIn);
+
+  // Load the data.
+  arma::mat referenceData, queryData;
+  data::Load(referenceFile, referenceData, true);
+  data::Load(queryFile, queryData, true);
+
+  // Construct the object.
+  Timer::Start("drusilla_select_construct");
+  DrusillaSelect<> ds(referenceData, numTables, numProjections);
+  Timer::Stop("drusilla_select_construct");
+
+  // Do the search.
+  arma::Mat<size_t> neighbors;
+  arma::mat distances;
+  Timer::Start("drusilla_select_search");
+  ds.Search(queryData, k, neighbors, distances);
+  Timer::Stop("drusilla_select_search");
+
+  if (CLI::HasParam("calculate_error"))
+  {
+    // Presumably one exact furthest-neighbor distance per query point (e.g.
+    // the distances output of mlpack_kfn with k = 1); the count is checked.
+    arma::mat trueDistances;
+    data::Load(CLI::GetParam<string>("exact_distances_file"), trueDistances,
+        true);
+
+    if (trueDistances.n_elem != distances.n_cols)
+      Log::Fatal << "Number of exact distances (" << trueDistances.n_elem
+          << ") does not match the number of query points ("
+          << distances.n_cols << ")!" << endl;
+
+    // Ratio of exact to approximate furthest neighbor distance, per query.
+    const arma::rowvec ratios = arma::vectorise(trueDistances, 1) /
+        distances.row(0);
+    const double averageError = arma::mean(ratios);
+    const double minError = ratios.min();
+    const double maxError = ratios.max();
+
+    Log::Info << "Average error: " << averageError << "." << endl;
+    Log::Info << "Maximum error: " << maxError << "." << endl;
+    Log::Info << "Minimum error: " << minError << "." << endl;
+  }
+
+  // Save the results.
+  if (CLI::HasParam("neighbors_file"))
+    data::Save(CLI::GetParam<string>("neighbors_file"), neighbors);
+  if (CLI::HasParam("distances_file"))
+    data::Save(CLI::GetParam<string>("distances_file"), distances);
+}
diff --git a/src/mlpack/tests/CMakeLists.txt b/src/mlpack/tests/CMakeLists.txt
index 9ad4092..a93f7bf 100644
--- a/src/mlpack/tests/CMakeLists.txt
+++ b/src/mlpack/tests/CMakeLists.txt
@@ -18,6 +18,7 @@ add_executable(mlpack_test
   decision_stump_test.cpp
   det_test.cpp
   distribution_test.cpp
+  drusilla_select_test.cpp
   emst_test.cpp
   fastmks_test.cpp
   feedforward_network_test.cpp
diff --git a/src/mlpack/tests/drusilla_select_test.cpp b/src/mlpack/tests/drusilla_select_test.cpp
new file mode 100644
index 0000000..504fd62
--- /dev/null
+++ b/src/mlpack/tests/drusilla_select_test.cpp
@@ -0,0 +1,145 @@
+/**
+ * @file drusilla_select_test.cpp
+ * @author Ryan Curtin
+ *
+ * Test for DrusillaSelect.
+ */
+#include <mlpack/methods/approx_kfn/drusilla_select.hpp>
+#include <mlpack/methods/neighbor_search/neighbor_search.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "test_tools.hpp"
+#include "serialization.hpp"
+
+using namespace mlpack;
+using namespace mlpack::neighbor;
+
+BOOST_AUTO_TEST_SUITE(DrusillaSelectTest);
+
+// If we have a dataset with an extreme outlier, then every point (except that
+// one) should end up with that point as the furthest neighbor candidate.
+BOOST_AUTO_TEST_CASE(DrusillaSelectExtremeOutlierTest)
+{
+  arma::mat dataset = arma::randu<arma::mat>(5, 100);
+  dataset.col(99) += 100; // Make last column (index 99) very large.
+
+  // Construct with some reasonable parameters.
+  DrusillaSelect<> ds(dataset, 5, 5);
+
+  // Query with every point except the extreme point (columns 0 through 98).
+  arma::mat distances;
+  arma::Mat<size_t> neighbors;
+  ds.Search(dataset.cols(0, 98), 1, neighbors, distances);
+
+  BOOST_REQUIRE_EQUAL(neighbors.n_cols, 99);
+  BOOST_REQUIRE_EQUAL(neighbors.n_rows, 1);
+  BOOST_REQUIRE_EQUAL(distances.n_cols, 99);
+  BOOST_REQUIRE_EQUAL(distances.n_rows, 1);
+
+  for (size_t i = 0; i < 99; ++i)
+    BOOST_REQUIRE_EQUAL(neighbors[i], 99);
+}
+
+// If we use only one projection with the number of points equal to what is in
+// the dataset, we should end up with the exact result.
+BOOST_AUTO_TEST_CASE(DrusillaSelectExhaustiveExactTest)
+{
+  arma::mat dataset = arma::randu<arma::mat>(5, 100);
+
+  // Construct with one projection (l = 1) and 100 points in that projection
+  // (m = 100), so every point of the dataset is a candidate.
+  DrusillaSelect<> ds(dataset, 1, 100);
+
+  arma::mat distances, distancesTrue;
+  arma::Mat<size_t> neighbors, neighborsTrue;
+
+  ds.Search(dataset, 5, neighbors, distances);
+
+  AllkFN kfn(dataset);
+  kfn.Search(dataset, 5, neighborsTrue, distancesTrue);
+
+  BOOST_REQUIRE_EQUAL(neighborsTrue.n_cols, neighbors.n_cols);
+  BOOST_REQUIRE_EQUAL(neighborsTrue.n_rows, neighbors.n_rows);
+  BOOST_REQUIRE_EQUAL(distancesTrue.n_cols, distances.n_cols);
+  BOOST_REQUIRE_EQUAL(distancesTrue.n_rows, distances.n_rows);
+
+  for (size_t i = 0; i < distances.n_elem; ++i)
+  {
+    BOOST_REQUIRE_EQUAL(neighbors[i], neighborsTrue[i]);
+    BOOST_REQUIRE_CLOSE(distances[i], distancesTrue[i], 1e-5);
+  }
+}
+
+// Test that we can call Train() after calling the constructor, and that
+// retraining with new l and m replaces the old candidate set.
+BOOST_AUTO_TEST_CASE(RetrainTest)
+{
+  arma::mat firstDataset = arma::randu<arma::mat>(3, 10);
+  arma::mat dataset = arma::randu<arma::mat>(3, 200);
+
+  DrusillaSelect<> ds(firstDataset, 3, 3);
+  // Train() takes a const reference, so dataset is left intact for Search().
+  ds.Train(dataset, 2, 2);
+
+  // The candidate set must now hold l * m = 4 points.
+  BOOST_REQUIRE_EQUAL(ds.CandidateSet().n_cols, 4);
+  BOOST_REQUIRE_EQUAL(ds.CandidateIndices().n_elem, 4);
+
+  arma::mat distances;
+  arma::Mat<size_t> neighbors;
+  ds.Search(dataset, 1, neighbors, distances);
+
+  BOOST_REQUIRE_EQUAL(neighbors.n_cols, 200);
+  BOOST_REQUIRE_EQUAL(neighbors.n_rows, 1);
+  BOOST_REQUIRE_EQUAL(distances.n_cols, 200);
+  BOOST_REQUIRE_EQUAL(distances.n_rows, 1);
+}
+
+// Test serialization.
+BOOST_AUTO_TEST_CASE(SerializationTest)
+{
+  // Create a random dataset.
+  arma::mat dataset = arma::randu<arma::mat>(3, 100);
+
+  DrusillaSelect<> ds(dataset, 3, 3);
+
+  // The fake models must be trainable (l * m may not exceed the number of
+  // points); their state is overwritten by the serialization below.
+  arma::mat fakeDataset1 = arma::randu<arma::mat>(2, 15);
+  arma::mat fakeDataset2 = arma::randu<arma::mat>(10, 30);
+  DrusillaSelect<> dsXml(fakeDataset1, 5, 3);
+  DrusillaSelect<> dsText(2, 2);
+  DrusillaSelect<> dsBinary(5, 6);
+  dsBinary.Train(fakeDataset2);
+
+  // Now do the serialization.
+  SerializeObjectAll(ds, dsXml, dsText, dsBinary);
+
+  // Now do a search and make sure all the results are the same.
+  arma::Mat<size_t> neighbors, neighborsXml, neighborsText, neighborsBinary;
+  arma::mat distances, distancesXml, distancesText, distancesBinary;
+
+  ds.Search(dataset, 3, neighbors, distances);
+  dsXml.Search(dataset, 3, neighborsXml, distancesXml);
+  dsText.Search(dataset, 3, neighborsText, distancesText);
+  dsBinary.Search(dataset, 3, neighborsBinary, distancesBinary);
+
+  BOOST_REQUIRE_EQUAL(neighbors.n_rows, neighborsXml.n_rows);
+  BOOST_REQUIRE_EQUAL(neighbors.n_cols, neighborsXml.n_cols);
+  BOOST_REQUIRE_EQUAL(neighbors.n_rows, neighborsText.n_rows);
+  BOOST_REQUIRE_EQUAL(neighbors.n_cols, neighborsText.n_cols);
+  BOOST_REQUIRE_EQUAL(neighbors.n_rows, neighborsBinary.n_rows);
+  BOOST_REQUIRE_EQUAL(neighbors.n_cols, neighborsBinary.n_cols);
+
+  BOOST_REQUIRE_EQUAL(distances.n_rows, distancesXml.n_rows);
+  BOOST_REQUIRE_EQUAL(distances.n_cols, distancesXml.n_cols);
+  BOOST_REQUIRE_EQUAL(distances.n_rows, distancesText.n_rows);
+  BOOST_REQUIRE_EQUAL(distances.n_cols, distancesText.n_cols);
+  BOOST_REQUIRE_EQUAL(distances.n_rows, distancesBinary.n_rows);
+  BOOST_REQUIRE_EQUAL(distances.n_cols, distancesBinary.n_cols);
+
+  for (size_t i = 0; i < neighbors.n_elem; ++i)
+  {
+    BOOST_REQUIRE_EQUAL(neighbors[i], neighborsXml[i]);
+    BOOST_REQUIRE_EQUAL(neighbors[i], neighborsText[i]);
+    BOOST_REQUIRE_EQUAL(neighbors[i], neighborsBinary[i]);
+
+    BOOST_REQUIRE_CLOSE(distances[i], distancesXml[i], 1e-5);
+    BOOST_REQUIRE_CLOSE(distances[i], distancesText[i], 1e-5);
+    BOOST_REQUIRE_CLOSE(distances[i], distancesBinary[i], 1e-5);
+  }
+}
+
+BOOST_AUTO_TEST_SUITE_END();




More information about the mlpack-git mailing list