[mlpack-git] master: Add comprehensive tests for QDAFN. (15f4b07)

gitdub at mlpack.org gitdub at mlpack.org
Sun Oct 30 08:32:24 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/31995784e651e1c17c988c79d9f53c9dbad620f8...81fce4edfc8bfb4c26b48ed388f559ec1cee26dd

>---------------------------------------------------------------

commit 15f4b073adc1410d182cc94d5fded590331eff71
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sun Oct 30 21:32:24 2016 +0900

    Add comprehensive tests for QDAFN.
    
    There is a known bug in this commit; I am pushing it so that I can debug it
    on a system that has a working gdb.


>---------------------------------------------------------------

15f4b073adc1410d182cc94d5fded590331eff71
 src/mlpack/methods/approx_kfn/qdafn.hpp      |  16 +++++
 src/mlpack/methods/approx_kfn/qdafn_impl.hpp |   9 +++
 src/mlpack/tests/qdafn_test.cpp              | 101 +++++++++++++++++++++++++++
 3 files changed, 126 insertions(+)

diff --git a/src/mlpack/methods/approx_kfn/qdafn.hpp b/src/mlpack/methods/approx_kfn/qdafn.hpp
index f7949db..6ba8b81 100644
--- a/src/mlpack/methods/approx_kfn/qdafn.hpp
+++ b/src/mlpack/methods/approx_kfn/qdafn.hpp
@@ -50,6 +50,19 @@ class QDAFN
         const size_t m);
 
   /**
+   * Train the QDAFN model on the given reference set, optionally setting new
+   * values for the number of projections/tables (l) and elements per table (m).
+   * NOTE(review): the constructor calls Train(referenceSet) (so l = m = 0), and
+   * l/m shadow the members of the same name; 0 must keep the stored values.
+   * @param referenceSet Reference set to train on.
+   * @param l Number of projections (0 presumably means "keep current l").
+   * @param m Number of elements to store for each projection (0: keep m).
+   */
+  void Train(const MatType& referenceSet,
+             const size_t l = 0,
+             const size_t m = 0);
+
+  /**
    * Search for the k furthest neighbors of the given query set.  (The query set
    * can contain just one point, that is okay.)  The results will be stored in
    * the given neighbors and distances matrices, in the same format as the
@@ -64,6 +77,9 @@ class QDAFN
   template<typename Archive>
   void Serialize(Archive& ar, const unsigned int /* version */);
 
+  //! Return how many projection tables (candidate sets) the model holds.
+  size_t NumProjections() const { return this->candidateSet.size(); }
+
   //! Get the candidate set for the given projection table.
   const MatType& CandidateSet(const size_t t) const { return candidateSet[t]; }
   //! Modify the candidate set for the given projection table.  Careful!
diff --git a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp
index de6c882..475538c 100644
--- a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp
+++ b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp
@@ -28,6 +28,15 @@ QDAFN<MatType>::QDAFN(const MatType& referenceSet,
     l(l),
     m(m)
 {
+  Train(referenceSet);
+}
+
+// Train the object.
+template<typename MatType>
+void QDAFN<MatType>::Train(const MatType& referenceSet,
+                           const size_t l,
+                           const size_t m)
+{
   // Build tables.  This is done by drawing random points from a Gaussian
   // distribution as the vectors we project onto.  The Gaussian should have zero
   // mean and unit variance.
diff --git a/src/mlpack/tests/qdafn_test.cpp b/src/mlpack/tests/qdafn_test.cpp
index ea64b52..332b7c7 100644
--- a/src/mlpack/tests/qdafn_test.cpp
+++ b/src/mlpack/tests/qdafn_test.cpp
@@ -102,4 +102,105 @@ BOOST_AUTO_TEST_CASE(QDAFNUniformSet)
   BOOST_REQUIRE_GE(successes, 700);
 }
 
+// Test re-training: the new l, m, and dimensionality must all take effect.
+BOOST_AUTO_TEST_CASE(RetrainTest)
+{
+  arma::mat dataset = arma::randu<arma::mat>(25, 500);
+  arma::mat newDataset = arma::randu<arma::mat>(15, 600);
+  QDAFN<> qdafn(dataset, 20, 60);
+  qdafn.Train(newDataset, 10, 50);
+  BOOST_REQUIRE_EQUAL(qdafn.NumProjections(), (size_t) 10);
+  for (size_t i = 0; i < qdafn.NumProjections(); ++i)
+  {
+    BOOST_REQUIRE_EQUAL(qdafn.CandidateSet(i).n_rows, (size_t) 15);
+    BOOST_REQUIRE_EQUAL(qdafn.CandidateSet(i).n_cols, (size_t) 50);
+  }
+  // The retrained model must also be searchable with 15-dimensional queries.
+  arma::Mat<size_t> neighbors;
+  arma::mat distances;
+  qdafn.Search(newDataset, 3, neighbors, distances);
+  BOOST_REQUIRE_EQUAL(neighbors.n_cols, newDataset.n_cols);
+}
+
+/**
+ * Test serialization of QDAFN: the XML, text, and binary copies must match.
+ */
+BOOST_AUTO_TEST_CASE(SerializationTest)
+{
+  // Random dataset.  The model must be non-empty or the checks are vacuous.
+  arma::mat dataset = arma::randu<arma::mat>(15, 300);
+  QDAFN<> qdafn(dataset, 10, 50);
+  BOOST_REQUIRE_EQUAL(qdafn.NumProjections(), (size_t) 10);
+
+  arma::mat fakeDataset1 = arma::randu<arma::mat>(10, 200);
+  arma::mat fakeDataset2 = arma::randu<arma::mat>(50, 500);
+  QDAFN<> qdafnXml(fakeDataset1, 5, 10);
+  QDAFN<> qdafnText(6, 50);
+  QDAFN<> qdafnBinary(7, 15);
+  qdafnBinary.Train(fakeDataset2);
+
+  // Serialize the objects.
+  SerializeObjectAll(qdafn, qdafnXml, qdafnText, qdafnBinary);
+
+  // Check that the tables are all the same.
+  BOOST_REQUIRE_EQUAL(qdafnXml.NumProjections(), qdafn.NumProjections());
+  BOOST_REQUIRE_EQUAL(qdafnText.NumProjections(), qdafn.NumProjections());
+  BOOST_REQUIRE_EQUAL(qdafnBinary.NumProjections(), qdafn.NumProjections());
+
+  for (size_t i = 0; i < qdafn.NumProjections(); ++i)
+  {
+    BOOST_REQUIRE_EQUAL(qdafnXml.CandidateSet(i).n_rows,
+        qdafn.CandidateSet(i).n_rows);
+    BOOST_REQUIRE_EQUAL(qdafnText.CandidateSet(i).n_rows,
+        qdafn.CandidateSet(i).n_rows);
+    BOOST_REQUIRE_EQUAL(qdafnBinary.CandidateSet(i).n_rows,
+        qdafn.CandidateSet(i).n_rows);
+
+    BOOST_REQUIRE_EQUAL(qdafnXml.CandidateSet(i).n_cols,
+        qdafn.CandidateSet(i).n_cols);
+    BOOST_REQUIRE_EQUAL(qdafnText.CandidateSet(i).n_cols,
+        qdafn.CandidateSet(i).n_cols);
+    BOOST_REQUIRE_EQUAL(qdafnBinary.CandidateSet(i).n_cols,
+        qdafn.CandidateSet(i).n_cols);
+
+    for (size_t j = 0; j < qdafn.CandidateSet(i).n_elem; ++j)
+    {
+      const double value = qdafn.CandidateSet(i)[j];
+      if (std::abs(value) < 1e-5)
+      {
+        BOOST_REQUIRE_SMALL(qdafnXml.CandidateSet(i)[j], 1e-5);
+        BOOST_REQUIRE_SMALL(qdafnText.CandidateSet(i)[j], 1e-5);
+        BOOST_REQUIRE_SMALL(qdafnBinary.CandidateSet(i)[j], 1e-5);
+      }
+      else
+      {
+        BOOST_REQUIRE_CLOSE(qdafnXml.CandidateSet(i)[j], value, 1e-5);
+        BOOST_REQUIRE_CLOSE(qdafnText.CandidateSet(i)[j], value, 1e-5);
+        BOOST_REQUIRE_CLOSE(qdafnBinary.CandidateSet(i)[j], value, 1e-5);
+      }
+    }
+  }
+}
+
+// Make sure QDAFN can be trained and searched with sparse data.
+BOOST_AUTO_TEST_CASE(SparseTest)
+{
+  arma::sp_mat dataset;
+  dataset.sprandu(200, 1000, 0.3); // 200 x 1000, density 0.3.
+
+  // Build a model over the sparse matrix type (l = 15, m = 50).
+  QDAFN<arma::sp_mat> sparse(dataset, 15, 50);
+
+  // Make sure the results are of the right shape.  It's hard to test anything
+  // more than that because we don't have easy-to-check performance guarantees.
+  arma::Mat<size_t> neighbors;
+  arma::mat distances;
+  sparse.Search(dataset, 3, neighbors, distances);
+
+  BOOST_REQUIRE_EQUAL(neighbors.n_rows, (size_t) 3);
+  BOOST_REQUIRE_EQUAL(neighbors.n_cols, dataset.n_cols);
+  BOOST_REQUIRE_EQUAL(distances.n_rows, (size_t) 3);
+  BOOST_REQUIRE_EQUAL(distances.n_cols, dataset.n_cols);
+}
+
 BOOST_AUTO_TEST_SUITE_END();




More information about the mlpack-git mailing list