[mlpack-git] master: Adds 2 deterministic LSH tests (8ad5711)
gitdub at mlpack.org
gitdub at mlpack.org
Sat Jun 4 09:30:45 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/29d43319f1a3ace534a95e966be9e903f06b07e1...c726b603bc23c7c304523e60eaba4d496ce48e47
>---------------------------------------------------------------
commit 8ad5711d77865c6a2df5b6a296de8a905f587c94
Author: Yannis Mentekidis <mentekid at gmail.com>
Date: Sat Jun 4 16:30:45 2016 +0300
Adds 2 deterministic LSH tests
>---------------------------------------------------------------
8ad5711d77865c6a2df5b6a296de8a905f587c94
src/mlpack/tests/lsh_test.cpp | 178 ++++++++++++++++++++++++++++++++++++++++++
1 file changed, 178 insertions(+)
diff --git a/src/mlpack/tests/lsh_test.cpp b/src/mlpack/tests/lsh_test.cpp
index d425666..26cf24d 100644
--- a/src/mlpack/tests/lsh_test.cpp
+++ b/src/mlpack/tests/lsh_test.cpp
@@ -15,6 +15,9 @@ using namespace std;
using namespace mlpack;
using namespace mlpack::neighbor;
+/**
+ * Computes recall: the fraction (in [0, 1], not a percentage) of neighbors
+ * that LSH found correctly, relative to the ground-truth neighbors.
+ */
double ComputeRecall(
const arma::Mat<size_t>& lshNeighbors,
const arma::Mat<size_t>& groundTruth)
@@ -26,6 +29,62 @@ double ComputeRecall(
return same / (static_cast<double>(queries * neigh));
}
+/**
+ * Generates a 2-dimensional reference set of N points, split into four equally
+ * sized clusters. Each cluster is drawn uniformly from a unit square, with
+ * points stored one per column:
+ *
+ *  - columns [0, N/4)       (C1): [0, 1] x [3, 4], around (0.5, 3.5)
+ *  - columns [N/4, N/2)     (C2): [3, 4] x [3, 4], around (3.5, 3.5)
+ *  - columns [N/2, 3N/4)    (C3): [0, 1] x [0, 1], around (0.5, 0.5)
+ *  - columns [3N/4, N)      (C4): [3, 4] x [0, 1], around (3.5, 0.5)
+ *
+ * @param N Number of points to generate; must be positive and divisible by 4.
+ * @param rdata Output matrix, resized to 2 x N.
+ */
+void getPointset(const size_t N, arma::mat& rdata)
+{
+  // Every cluster gets exactly N / 4 points; any other N would leave the
+  // column ranges below mismatched with the cluster sizes.
+  BOOST_REQUIRE_MESSAGE(N > 0 && N % 4 == 0,
+      "getPointset(): N must be positive and divisible by 4");
+
+  const size_t d = 2;
+  const size_t clusterSize = N / 4;
+
+  // Create four clusters of points, each uniform in the unit square.
+  arma::mat C1(d, clusterSize, arma::fill::randu);
+  arma::mat C2(d, clusterSize, arma::fill::randu);
+  arma::mat C3(d, clusterSize, arma::fill::randu);
+  arma::mat C4(d, clusterSize, arma::fill::randu);
+
+  // Spread the clusters over the plane; C3 stays in the unit square.
+  C1.row(1) += 3; // Offset (0, 3).
+  C2 += 3;        // Offset (3, 3).
+  C4.row(0) += 3; // Offset (3, 0).
+
+  rdata.set_size(d, N);
+  rdata.cols(0, clusterSize - 1) = C1;
+  rdata.cols(clusterSize, 2 * clusterSize - 1) = C2;
+  rdata.cols(2 * clusterSize, 3 * clusterSize - 1) = C3;
+  rdata.cols(3 * clusterSize, N - 1) = C4;
+}
+
+/**
+ * Fills qdata with two 2-dimensional query points, one per column:
+ *
+ *  - column 0: uniform in [0, 1] x [0, 1], i.e. around (0.5, 0.5), which is
+ *    the region of cluster C3 produced by getPointset();
+ *  - column 1: uniform in [3, 4] x [3, 4], i.e. around (3.5, 3.5), which is
+ *    the region of cluster C2 produced by getPointset().
+ *
+ * @param qdata Output matrix, resized to 2 x 2.
+ */
+void getQueries(arma::mat& qdata)
+{
+  const size_t dims = 2;
+  qdata.set_size(dims, 2);
+
+  // First query: drawn inside cluster 3 (no offset).
+  qdata.col(0) = arma::randu<arma::colvec>(dims);
+
+  // Second query: shifted by (3, 3) so it lands inside cluster 2.
+  qdata.col(1) = arma::randu<arma::colvec>(dims) + 3;
+}
+
BOOST_AUTO_TEST_SUITE(LSHTest);
/**
@@ -302,6 +361,125 @@ BOOST_AUTO_TEST_CASE(RecallTest)
BOOST_REQUIRE_LE(recallChp, recallThreshChp);
}
+/**
+ * Test: This is a deterministic test that projects 2-d points onto a known line
+ * (axis 2, i.e. the second coordinate). The reference set contains 4
+ * well-separated clusters that merge into 2 clusters when projected onto that
+ * axis.
+ *
+ * We create two queries, each one belonging to one cluster (q1 in cluster 3,
+ * located around (0.5, 0.5), and q2 in cluster 2, located around (3.5, 3.5)).
+ * After the projection, q1 should only have neighbors in C3 and C4, and q2
+ * should only have neighbors in C1 and C2.
+ */
+BOOST_AUTO_TEST_CASE(DeterministicMerge)
+{
+  const size_t N = 40; // Must be divisible by 4 to create 4 clusters properly.
+  arma::mat rdata;
+  arma::mat qdata;
+  getPointset(N, rdata);
+  getQueries(qdata);
+
+  const size_t k = N / 2;
+  const double hashWidth = 1;
+  const int secondHashSize = 99901;
+  const int bucketSize = 500;
+
+  // 1 table, with one projection onto axis 2: the vector (0, 1) keeps only
+  // the second coordinate of each point.
+  arma::cube projections(2, 1, 1);
+  projections(0, 0, 0) = 0;
+  projections(1, 0, 0) = 1;
+
+  LSHSearch<> lshTest(rdata, projections,
+      hashWidth, secondHashSize, bucketSize);
+
+  arma::Mat<size_t> neighbors;
+  arma::mat distances;
+  lshTest.Search(qdata, k, neighbors, distances);
+
+  // A neighbor index equal to N means "not found". It is skipped for that
+  // query only, so a missing neighbor of one query never hides a wrong
+  // neighbor of the other query.
+  for (size_t j = 0; j < k; ++j) // For each neighbor.
+  {
+    // Query 1 is in cluster 3, which under this projection is merged with
+    // cluster 4. Clusters 3 and 4 hold points [N/2, N), so only neighbors
+    // among those should be found.
+    if (neighbors(j, 0) != N)
+    {
+      BOOST_REQUIRE_GE(neighbors(j, 0), N / 2);
+    }
+
+    // Query 2 is in cluster 2, which under this projection is merged with
+    // cluster 1. Clusters 1 and 2 hold points [0, N/2), so only neighbors
+    // among those should be found.
+    if (neighbors(j, 1) != N)
+    {
+      BOOST_REQUIRE_LT(neighbors(j, 1), N / 2);
+    }
+  }
+}
+
+
+/**
+ * Test: This is a deterministic test that projects 2-d points onto both
+ * coordinate axes, which preserves the layout of the points in the plane.
+ * The reference set contains 4 well-separated clusters that should not merge.
+ *
+ * We create two queries, each one belonging to one cluster (q1 in cluster 3,
+ * located around (0.5, 0.5), and q2 in cluster 2, located around (3.5, 3.5)).
+ * The test is a success if, after the projection, q1 only has neighbors in C3
+ * and q2 only has neighbors in C2.
+ */
+BOOST_AUTO_TEST_CASE(DeterministicNoMerge)
+{
+  const size_t N = 40; // Must be divisible by 4 to create 4 clusters properly.
+  arma::mat rdata;
+  arma::mat qdata;
+  getPointset(N, rdata);
+  getQueries(qdata);
+
+  const size_t k = N / 2;
+  const double hashWidth = 1;
+  const int secondHashSize = 99901;
+  const int bucketSize = 500;
+
+  // 1 table with two projections: column 0 is (0, 1), selecting axis 2, and
+  // column 1 is (1, 0), selecting axis 1.
+  arma::cube projections(2, 2, 1);
+  projections(0, 0, 0) = 0;
+  projections(1, 0, 0) = 1;
+  projections(0, 1, 0) = 1;
+  projections(1, 1, 0) = 0;
+
+  LSHSearch<> lshTest(rdata, projections,
+      hashWidth, secondHashSize, bucketSize);
+
+  arma::Mat<size_t> neighbors;
+  arma::mat distances;
+  lshTest.Search(qdata, k, neighbors, distances);
+
+  // A neighbor index equal to N means "not found". It is skipped for that
+  // query only, so a missing neighbor of one query never hides a wrong
+  // neighbor of the other query.
+  for (size_t j = 0; j < k; ++j) // For each neighbor.
+  {
+    // Query 1 is in cluster 3, which is points [N/2, 3N/4).
+    if (neighbors(j, 0) != N)
+    {
+      BOOST_REQUIRE_GE(neighbors(j, 0), N / 2);
+      BOOST_REQUIRE_LT(neighbors(j, 0), 3 * N / 4);
+    }
+
+    // Query 2 is in cluster 2, which is points [N/4, N/2).
+    if (neighbors(j, 1) != N)
+    {
+      BOOST_REQUIRE_GE(neighbors(j, 1), N / 4);
+      BOOST_REQUIRE_LT(neighbors(j, 1), N / 2);
+    }
+  }
+}
BOOST_AUTO_TEST_CASE(LSHTrainTest)
{
// This is a not very good test that simply checks that the re-trained LSH
More information about the mlpack-git
mailing list