[mlpack-git] master: Adds 2 deterministic LSH tests (8ad5711)

gitdub at mlpack.org gitdub at mlpack.org
Sat Jun 4 09:30:45 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/29d43319f1a3ace534a95e966be9e903f06b07e1...c726b603bc23c7c304523e60eaba4d496ce48e47

>---------------------------------------------------------------

commit 8ad5711d77865c6a2df5b6a296de8a905f587c94
Author: Yannis Mentekidis <mentekid at gmail.com>
Date:   Sat Jun 4 16:30:45 2016 +0300

    Adds 2 deterministic LSH tests


>---------------------------------------------------------------

8ad5711d77865c6a2df5b6a296de8a905f587c94
 src/mlpack/tests/lsh_test.cpp | 178 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 178 insertions(+)

diff --git a/src/mlpack/tests/lsh_test.cpp b/src/mlpack/tests/lsh_test.cpp
index d425666..26cf24d 100644
--- a/src/mlpack/tests/lsh_test.cpp
+++ b/src/mlpack/tests/lsh_test.cpp
@@ -15,6 +15,9 @@ using namespace std;
 using namespace mlpack;
 using namespace mlpack::neighbor;
 
+/**
+ * Computes recall (the fraction of neighbors found correctly).
+ */
 double ComputeRecall(
     const arma::Mat<size_t>& lshNeighbors,
     const arma::Mat<size_t>& groundTruth)
@@ -26,6 +29,62 @@ double ComputeRecall(
   return same / (static_cast<double>(queries * neigh));
 }
 
+/**
+ * Fills rdata (2 x N) with clusters C1..C4 in that column order, centered near
+ * (0.5, 3.5), (3.5, 3.5), (0.5, 0.5), (3.5, 0.5). N must be divisible by 4.
+ */
+void getPointset(const size_t N, arma::mat& rdata)
+{
+  const size_t d = 2;
+  // Create four clusters of N/4 random (arma::fill::randu) points each.
+  arma::mat C1(d, N/4, arma::fill::randu);
+  arma::mat C2(d, N/4, arma::fill::randu);
+  arma::mat C3(d, N/4, arma::fill::randu);
+  arma::mat C4(d, N/4, arma::fill::randu);
+
+  arma::colvec offset1;
+  offset1<<0<<arma::endr<<3<<arma::endr;
+  arma::colvec offset2;
+  offset2<<3<<arma::endr<<3<<arma::endr;
+  arma::colvec offset4;
+  offset4<<3<<arma::endr<<0<<arma::endr;
+  // Spread points in plane: C1 += (0, 3), C2 += (3, 3), C4 += (3, 0); C3 stays.
+  for (size_t p = 0; p < N/4; ++p)
+  {
+    C1.col(p)+=offset1;
+    C2.col(p)+=offset2;
+    C4.col(p)+=offset4;
+  }
+
+  rdata.set_size(d, N);
+  rdata.cols(0, N/4-1) = C1;
+  rdata.cols(N/4, N/2-1) = C2;
+  rdata.cols(N/2, 3*N/4-1) = C3;
+  rdata.cols(3*N/4, N-1) = C4;
+}
+
+/**
+ * Generates two queries (columns of qdata) around (0.5, 0.5) and (3.5, 3.5).
+ */
+void getQueries(arma::mat& qdata)
+{
+  const size_t d = 2;
+  // Generate two random queries, each placed inside one of the clusters.
+
+  // Query 1 gets no offset, so it falls inside cluster 3.
+  arma::colvec q1, q2;
+  q1.randu(d, 1);
+
+  // Shift query 2 by 3 along both axes so it falls inside cluster 2.
+  q2.randu(d, 1);
+  q2.row(0)+=3;
+  q2.row(1)+=3;
+
+  qdata.set_size(d, 2);
+  qdata.col(0) = q1;
+  qdata.col(1) = q2;
+}
+
 BOOST_AUTO_TEST_SUITE(LSHTest);
 
 /**
@@ -302,6 +361,125 @@ BOOST_AUTO_TEST_CASE(RecallTest)
   BOOST_REQUIRE_LE(recallChp, recallThreshChp);
 }
 
+/**
+ * Test: This is a deterministic test that projects 2-d points to a known line
+ * (axis 2). The reference set contains 4 well-separated clusters that will
+ * merge into 2 clusters when projected on that axis.
+ *
+ * We create two queries, each one belonging in one cluster (q1 in cluster 3
+ * located around (0, 0) and q2 in cluster 2 located around (3, 3)). After the
+ * projection, q1 should have neighbors in C3 and C4 and q2 in C1 and C2.
+ */
+BOOST_AUTO_TEST_CASE(DeterministicMerge)
+{
+  const size_t N = 40; //must be divisible by 4 to create 4 clusters properly
+  arma::mat rdata;
+  arma::mat qdata;
+  getPointset(N, rdata);
+  getQueries(qdata);
+
+
+  const int k = N/2;
+  const double hashWidth = 1;
+  const int secondHashSize = 99901;
+  const int bucketSize = 500;
+
+  //1 table, with one projection onto axis 2: the projection vector is (0, 1)
+  arma::cube projections(2, 1, 1);
+  projections(0, 0, 0) = 0;
+  projections(1, 0, 0) = 1;
+
+  LSHSearch<> lshTest(rdata, projections, 
+                      hashWidth, secondHashSize, bucketSize);
+
+  arma::Mat<size_t> neighbors;
+  arma::mat distances;
+  lshTest.Search(qdata, k, neighbors, distances);
+
+  // test both queries; NOTE(review): j is size_t but k is int (sign compare)
+  size_t q;
+  for (size_t j = 0; j < k; ++j) //for each neighbor
+  {
+    q = 0;
+    if (neighbors(j, 0) == N || neighbors(j, 1) == N) //either missing: skip both
+      continue;
+
+    //query 1 is in cluster 3, which under this projection was merged with
+    //cluster 4. Clusters 3 and 4 have points 20:39, so only neighbors among
+    //those should be found
+    q = 0;
+    BOOST_REQUIRE(neighbors(j, q) >= N/2);
+  
+    //query 2 is in cluster 2, which under this projection was merged with
+    //cluster 1. Clusters 1 and 2 have points 0:19, so only neighbors among
+    //those should be found
+    q = 1;
+    BOOST_REQUIRE(neighbors(j, q) < N/2);
+
+  }
+}
+
+
+/**
+ * Test: This is a deterministic test that projects 2-d points to the plane.
+ * The reference set contains 4 well-separated clusters that should not merge.
+ *
+ * We create two queries, each one belonging in one cluster (q1 in cluster 3
+ * located around (0, 0) and q2 in cluster 2 located around (3, 3)). The test
+ * succeeds if, after the projection, q1 has neighbors only in C3 and q2 only
+ * in C2.
+ */
+BOOST_AUTO_TEST_CASE(DeterministicNoMerge)
+{
+  const size_t N = 40;
+  arma::mat rdata;
+  arma::mat qdata;
+  getPointset(N, rdata);
+  getQueries(qdata);
+
+
+  const int k = N/2;
+  const double hashWidth = 1;
+  const int secondHashSize = 99901;
+  const int bucketSize = 500;
+
+  //1 table, with two projections: vector (0, 1) and vector (1, 0)
+  arma::cube projections(2, 2, 1);
+  projections(0, 0, 0) = 0;
+  projections(1, 0, 0) = 1;
+  projections(0, 1, 0) = 1;
+  projections(1, 1, 0) = 0;
+
+  LSHSearch<> lshTest(rdata, projections, 
+                      hashWidth, secondHashSize, bucketSize);
+
+  arma::Mat<size_t> neighbors;
+  arma::mat distances;
+  lshTest.Search(qdata, k, neighbors, distances);
+
+  // test both queries; NOTE(review): j is size_t but k is int (sign compare)
+  size_t q;
+  for (size_t j = 0; j < k; ++j) //for each neighbor
+  {
+  
+    //j-th neighbor not found for either query: skip the checks for both
+    if (neighbors(j, 0) == N || neighbors(j, 1) == N)
+      continue;
+
+    q = 0;
+    //query 1 is in cluster 3, which is points 20:29 (columns N/2 to 3*N/4-1)
+    BOOST_REQUIRE(
+        neighbors(j, q) >= N/2 && neighbors(j, q) < 3*N/4
+        );
+
+    q = 1;
+    //query 2 is in cluster 2, which is points 10:19 (columns N/4 to N/2-1)
+    BOOST_REQUIRE(
+        neighbors(j, q) >= N/4 && neighbors(j, q) < N/2
+        );
+  }
+
+}
 BOOST_AUTO_TEST_CASE(LSHTrainTest)
 {
   // This is a not very good test that simply checks that the re-trained LSH




More information about the mlpack-git mailing list