[mlpack-git] master: Merges multiprobe LSH (3af80c3)
gitdub at mlpack.org
gitdub at mlpack.org
Fri Jul 1 10:48:25 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/34cf8d94f79c9a72ff4199676033b060cd039fcd...425324bf7fb7c86c85d10a909d8a59d4f69b7164
>---------------------------------------------------------------
commit 3af80c339a7a846bb3bdf305131c87eaa939fc01
Merge: 0d38271 bdbc2fe
Author: Yannis Mentekidis <mentekid at gmail.com>
Date: Fri Jul 1 15:48:25 2016 +0100
Merges multiprobe LSH
>---------------------------------------------------------------
3af80c339a7a846bb3bdf305131c87eaa939fc01
HISTORY.md | 16 +-
README.md | 4 +-
src/mlpack/core/tree/CMakeLists.txt | 10 +
src/mlpack/core/tree/rectangle_tree.hpp | 6 +
.../tree/rectangle_tree/discrete_hilbert_value.hpp | 263 ++++++++++++
.../rectangle_tree/discrete_hilbert_value_impl.hpp | 457 ++++++++++++++++++++
.../tree/rectangle_tree/dual_tree_traverser.hpp | 7 +-
.../rectangle_tree/dual_tree_traverser_impl.hpp | 24 +-
.../hilbert_r_tree_auxiliary_information.hpp | 120 ++++++
.../hilbert_r_tree_auxiliary_information_impl.hpp | 169 ++++++++
.../hilbert_r_tree_descent_heuristic.hpp | 53 +++
.../hilbert_r_tree_descent_heuristic_impl.hpp | 49 +++
.../tree/rectangle_tree/hilbert_r_tree_split.hpp | 94 ++++
.../rectangle_tree/hilbert_r_tree_split_impl.hpp | 340 +++++++++++++++
.../rectangle_tree/no_auxiliary_information.hpp | 117 +++++
.../r_star_tree_descent_heuristic.hpp | 16 +-
.../r_star_tree_descent_heuristic_impl.hpp | 50 ++-
.../core/tree/rectangle_tree/r_star_tree_split.hpp | 24 +-
.../tree/rectangle_tree/r_star_tree_split_impl.hpp | 151 ++++---
.../rectangle_tree/r_tree_descent_heuristic.hpp | 4 +-
.../r_tree_descent_heuristic_impl.hpp | 32 +-
.../core/tree/rectangle_tree/r_tree_split.hpp | 33 +-
.../core/tree/rectangle_tree/r_tree_split_impl.hpp | 150 +++----
.../core/tree/rectangle_tree/rectangle_tree.hpp | 101 ++---
.../tree/rectangle_tree/rectangle_tree_impl.hpp | 457 ++++++++++++--------
.../tree/rectangle_tree/single_tree_traverser.hpp | 7 +-
.../rectangle_tree/single_tree_traverser_impl.hpp | 20 +-
src/mlpack/core/tree/rectangle_tree/traits.hpp | 7 +-
src/mlpack/core/tree/rectangle_tree/typedef.hpp | 45 +-
.../x_tree_auxiliary_information.hpp | 186 ++++++++
.../core/tree/rectangle_tree/x_tree_split.hpp | 61 +--
.../core/tree/rectangle_tree/x_tree_split_impl.hpp | 243 +++++------
src/mlpack/methods/amf/init_rules/CMakeLists.txt | 1 +
src/mlpack/methods/amf/init_rules/given_init.hpp | 75 ++++
src/mlpack/methods/lsh/lsh_main.cpp | 13 +-
src/mlpack/methods/lsh/lsh_search.hpp | 86 +++-
src/mlpack/methods/lsh/lsh_search_impl.hpp | 437 +++++++++++++++++--
src/mlpack/methods/neighbor_search/kfn_main.cpp | 34 +-
src/mlpack/methods/neighbor_search/knn_main.cpp | 20 +-
.../methods/neighbor_search/neighbor_search.hpp | 15 +
.../neighbor_search/neighbor_search_impl.hpp | 32 +-
.../neighbor_search/neighbor_search_rules.hpp | 4 +
.../neighbor_search/neighbor_search_rules_impl.hpp | 14 +-
src/mlpack/methods/neighbor_search/ns_model.hpp | 58 ++-
.../methods/neighbor_search/ns_model_impl.hpp | 64 ++-
.../sort_policies/furthest_neighbor_sort.cpp | 4 +-
.../sort_policies/furthest_neighbor_sort.hpp | 17 +
.../sort_policies/nearest_neighbor_sort.hpp | 15 +
.../methods/preprocess/preprocess_split_main.cpp | 2 +-
.../methods/range_search/range_search_main.cpp | 6 +-
src/mlpack/methods/range_search/rs_model.cpp | 21 +-
src/mlpack/methods/range_search/rs_model.hpp | 5 +-
src/mlpack/methods/range_search/rs_model_impl.hpp | 14 +
src/mlpack/methods/rann/krann_main.cpp | 6 +-
src/mlpack/methods/rann/ra_model.hpp | 5 +-
src/mlpack/methods/rann/ra_model_impl.hpp | 56 ++-
src/mlpack/prereqs.hpp | 1 +
src/mlpack/tests/CMakeLists.txt | 2 +
src/mlpack/tests/activation_functions_test.cpp | 2 +-
src/mlpack/tests/ada_delta_test.cpp | 2 +-
src/mlpack/tests/adaboost_test.cpp | 3 +-
src/mlpack/tests/adam_test.cpp | 2 +-
src/mlpack/tests/akfn_test.cpp | 240 +++++++++++
src/mlpack/tests/aknn_test.cpp | 405 ++++++++++++++++++
src/mlpack/tests/arma_extend_test.cpp | 2 +-
src/mlpack/tests/armadillo_svd_test.cpp | 2 +-
src/mlpack/tests/aug_lagrangian_test.cpp | 2 +-
src/mlpack/tests/binarize_test.cpp | 2 +-
src/mlpack/tests/cf_test.cpp | 2 +-
src/mlpack/tests/cli_test.cpp | 2 +-
src/mlpack/tests/convolution_test.cpp | 2 +-
src/mlpack/tests/convolutional_network_test.cpp | 2 +-
src/mlpack/tests/cosine_tree_test.cpp | 2 +-
src/mlpack/tests/decision_stump_test.cpp | 2 +-
src/mlpack/tests/det_test.cpp | 2 +-
src/mlpack/tests/distribution_test.cpp | 2 +-
src/mlpack/tests/emst_test.cpp | 2 +-
src/mlpack/tests/fastmks_test.cpp | 2 +-
src/mlpack/tests/feedforward_network_test.cpp | 2 +-
src/mlpack/tests/gmm_test.cpp | 2 +-
src/mlpack/tests/hmm_test.cpp | 2 +-
src/mlpack/tests/hoeffding_tree_test.cpp | 2 +-
src/mlpack/tests/ind2sub_test.cpp | 2 +-
src/mlpack/tests/init_rules_test.cpp | 2 +-
src/mlpack/tests/kernel_pca_test.cpp | 2 +-
src/mlpack/tests/kernel_test.cpp | 2 +-
src/mlpack/tests/kernel_traits_test.cpp | 2 +-
src/mlpack/tests/kfn_test.cpp | 2 +-
src/mlpack/tests/kmeans_test.cpp | 3 +-
src/mlpack/tests/knn_test.cpp | 18 +-
src/mlpack/tests/krann_search_test.cpp | 8 +-
src/mlpack/tests/lars_test.cpp | 2 +-
src/mlpack/tests/layer_traits_test.cpp | 2 +-
src/mlpack/tests/lbfgs_test.cpp | 2 +-
src/mlpack/tests/lin_alg_test.cpp | 2 +-
src/mlpack/tests/linear_regression_test.cpp | 2 +-
src/mlpack/tests/load_save_test.cpp | 2 +-
src/mlpack/tests/local_coordinate_coding_test.cpp | 2 +-
src/mlpack/tests/log_test.cpp | 2 +-
src/mlpack/tests/logistic_regression_test.cpp | 2 +-
src/mlpack/tests/lrsdp_test.cpp | 2 +-
src/mlpack/tests/lsh_test.cpp | 158 ++++++-
src/mlpack/tests/lstm_peephole_test.cpp | 2 +-
src/mlpack/tests/math_test.cpp | 2 +-
src/mlpack/tests/matrix_completion_test.cpp | 2 +-
src/mlpack/tests/maximal_inputs_test.cpp | 2 +-
src/mlpack/tests/mean_shift_test.cpp | 3 +-
src/mlpack/tests/metric_test.cpp | 2 +-
src/mlpack/tests/minibatch_sgd_test.cpp | 2 +-
src/mlpack/tests/mlpack_test.cpp | 2 +-
src/mlpack/tests/nbc_test.cpp | 2 +-
src/mlpack/tests/nca_test.cpp | 2 +-
src/mlpack/tests/network_util_test.cpp | 2 +-
src/mlpack/tests/nmf_test.cpp | 17 +-
src/mlpack/tests/nystroem_method_test.cpp | 2 +-
src/mlpack/tests/pca_test.cpp | 2 +-
src/mlpack/tests/perceptron_test.cpp | 2 +-
src/mlpack/tests/performance_functions_test.cpp | 2 +-
src/mlpack/tests/pooling_rules_test.cpp | 2 +-
src/mlpack/tests/quic_svd_test.cpp | 2 +-
src/mlpack/tests/radical_test.cpp | 2 +-
src/mlpack/tests/range_search_test.cpp | 16 +-
src/mlpack/tests/rectangle_tree_test.cpp | 472 ++++++++++++++++-----
src/mlpack/tests/recurrent_network_test.cpp | 2 +-
src/mlpack/tests/regularized_svd_test.cpp | 2 +-
src/mlpack/tests/rmsprop_test.cpp | 2 +-
src/mlpack/tests/sa_test.cpp | 2 +-
src/mlpack/tests/sdp_primal_dual_test.cpp | 2 +-
src/mlpack/tests/serialization.hpp | 2 +-
src/mlpack/tests/serialization_test.cpp | 2 +-
src/mlpack/tests/sgd_test.cpp | 2 +-
src/mlpack/tests/softmax_regression_test.cpp | 2 +-
src/mlpack/tests/sort_policy_test.cpp | 2 +-
src/mlpack/tests/sparse_autoencoder_test.cpp | 2 +-
src/mlpack/tests/sparse_coding_test.cpp | 2 +-
src/mlpack/tests/split_data_test.cpp | 2 +-
src/mlpack/tests/svd_batch_test.cpp | 2 +-
src/mlpack/tests/svd_incremental_test.cpp | 2 +-
src/mlpack/tests/termination_policy_test.cpp | 2 +-
...d_boost_test_definitions.hpp => test_tools.hpp} | 14 +-
src/mlpack/tests/tree_test.cpp | 2 +-
src/mlpack/tests/tree_traits_test.cpp | 2 +-
src/mlpack/tests/union_find_test.cpp | 2 +-
143 files changed, 4771 insertions(+), 1055 deletions(-)
diff --cc src/mlpack/methods/lsh/lsh_search.hpp
index 0882393,44a7b64..811a73d
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@@ -187,9 -205,9 +205,10 @@@ class LSHSearc
void Search(const size_t k,
arma::Mat<size_t>& resultingNeighbors,
arma::mat& distances,
- const size_t numTablesToSearch = 0);
+ const size_t numTablesToSearch = 0,
+ size_t T = 0);
+
/**
* Compute the recall (% of neighbors found) given the neighbors returned by
* LSHSearch::Search and a "ground truth" set of neighbors. The recall
diff --cc src/mlpack/methods/lsh/lsh_search_impl.hpp
index dae15a6,ab81088..8d3cdb1
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@@ -351,80 -339,274 +363,343 @@@ void LSHSearch<SortPolicy>::BaseCase(co
// SortDistance() returns (size_t() - 1) if we shouldn't add it.
if (insertPosition != (size_t() - 1))
- InsertNeighbor(distances, neighbors, queryIndex, insertPosition,
- referenceIndex, distance);
+ {
+ #pragma omp critical
+ {
+ InsertNeighbor(distances, neighbors, queryIndex, insertPosition,
+ referenceIndex, distance);
+ }
+ }
}
+*/
+// Base case where the query set is the reference set. (So, we can't return
+// ourselves as the nearest neighbor.)
+template<typename SortPolicy>
+inline force_inline
+void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
+ const arma::uvec& referenceIndices,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances) const
+{
+ for (size_t j = 0; j < referenceIndices.n_elem; ++j)
+ {
+ const size_t referenceIndex = referenceIndices[j];
+ // If the points are the same, skip this point.
+ if (queryIndex == referenceIndex)
+ continue;
+
+ const double distance = metric::EuclideanDistance::Evaluate(
+ referenceSet->unsafe_col(queryIndex),
+ referenceSet->unsafe_col(referenceIndex));
+
+ // If this distance is better than any of the current candidates, the
+ // SortDistance() function will give us the position to insert it into.
+ arma::vec queryDist = distances.unsafe_col(queryIndex);
+ arma::Col<size_t> queryIndices = neighbors.unsafe_col(queryIndex);
+ size_t insertPosition = SortPolicy::SortDistance(queryDist, queryIndices,
+ distance);
+
+ // SortDistance() returns (size_t() - 1) if we shouldn't add it.
+ if (insertPosition != (size_t() - 1))
+ InsertNeighbor(distances, neighbors, queryIndex, insertPosition,
+ referenceIndex, distance);
+ }
+}
+
+// Base case for bichromatic search.
+template<typename SortPolicy>
+inline force_inline
+void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
+ const arma::uvec& referenceIndices,
+ const arma::mat& querySet,
+ arma::Mat<size_t>& neighbors,
+ arma::mat& distances) const
+{
+ for (size_t j = 0; j < referenceIndices.n_elem; ++j)
+ {
+ const size_t referenceIndex = referenceIndices[j];
+ const double distance = metric::EuclideanDistance::Evaluate(
+ querySet.unsafe_col(queryIndex),
+ referenceSet->unsafe_col(referenceIndex));
+
+ // If this distance is better than any of the current candidates, the
+ // SortDistance() function will give us the position to insert it into.
+ arma::vec queryDist = distances.unsafe_col(queryIndex);
+ arma::Col<size_t> queryIndices = neighbors.unsafe_col(queryIndex);
+ size_t insertPosition = SortPolicy::SortDistance(queryDist, queryIndices,
+ distance);
+
+ // SortDistance() returns (size_t() - 1) if we shouldn't add it.
+ if (insertPosition != (size_t() - 1))
+ InsertNeighbor(distances, neighbors, queryIndex, insertPosition,
+ referenceIndex, distance);
+ }
+}
template<typename SortPolicy>
+ inline force_inline
+ double LSHSearch<SortPolicy>::PerturbationScore(
+ const std::vector<bool>& A,
+ const arma::vec& scores) const
+ {
+ double score = 0.0;
+ for (size_t i = 0; i < A.size(); ++i)
+ if (A[i])
+ score += scores(i); // add scores of non-zero indices
+ return score;
+ }
+
+ template<typename SortPolicy>
+ inline force_inline
+ bool LSHSearch<SortPolicy>::PerturbationShift(std::vector<bool>& A) const
+ {
+ size_t maxPos = 0;
+ for (size_t i = 0; i < A.size(); ++i)
+ if (A[i] == 1) // marked true
+ maxPos=i;
+
+ if ( maxPos + 1 < A.size()) // otherwise, this is an invalid vector
+ {
+ A[maxPos] = 0;
+ A[maxPos + 1] = 1;
+ return true; // valid
+ }
+ return false; // invalid
+ }
+
+ template<typename SortPolicy>
+ inline force_inline
+ bool LSHSearch<SortPolicy>::PerturbationExpand(std::vector<bool>& A) const
+ {
+ // Find the last '1' in A
+ size_t maxPos = 0;
+ for (size_t i = 0; i < A.size(); ++i)
+ if (A[i]) // marked true
+ maxPos = i;
+
+ if (maxPos + 1 < A.size()) // otherwise, this is an invalid vector
+ {
+ A[maxPos + 1] = 1;
+ return true;
+ }
+ return false;
+ }
+
+ template<typename SortPolicy>
+ inline force_inline
+ bool LSHSearch<SortPolicy>::PerturbationValid(
+ const std::vector<bool>& A) const
+ {
+ // Use check to mark dimensions we have seen before in A. If a dimension is
+ // seen twice (or more), A is not a valid perturbation.
+ std::vector<bool> check(numProj);
+
+ if (A.size() > 2 * numProj)
+ return false; // This should never happen.
+
+ // Check that we only see each dimension once. If not, vector is not valid.
+ for (size_t i = 0; i < A.size(); ++i)
+ {
+ // Only check dimensions that were included.
+ if (!A[i])
+ continue;
+
+ // If dimesnion is unseen thus far, mark it as seen.
+ if (check[i % numProj] == false)
+ check[i % numProj] = true;
+ else
+ return false; // If dimension was seen before, set is not valid.
+ }
+ // If we didn't fail, set is valid.
+ return true;
+ }
+
+ // Compute additional probing bins for a query
+ template<typename SortPolicy>
+ void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
+ const arma::vec& queryCode,
+ const arma::vec& queryCodeNotFloored,
+ const size_t T,
+ arma::mat& additionalProbingBins) const
+ {
+
+ // No additional bins requested. Our work is done.
+ if (T == 0)
+ return;
+
+ // Each column of additionalProbingBins is the code of a bin.
+ additionalProbingBins.set_size(numProj, T);
+
+ // Copy the query's code, then in the end we will add/subtract according
+ // to perturbations we calculated.
+ for (size_t c = 0; c < T; ++c)
+ additionalProbingBins.col(c) = queryCode;
+
+
+ // Calculate query point's projection position.
+ arma::mat projection = queryCodeNotFloored;
+
+ // Use projection to calculate query's distance from hash limits.
+ arma::vec limLow = projection - queryCode * hashWidth;
+ arma::vec limHigh = hashWidth - limLow;
+
+ // Calculate scores. score = distance^2.
+ arma::vec scores(2 * numProj);
+ scores.rows(0, numProj - 1) = arma::pow(limLow, 2);
+ scores.rows(numProj, (2 * numProj) - 1) = arma::pow(limHigh, 2);
+
+ // Actions vector describes what perturbation (-1/+1) corresponds to a score.
+ arma::Col<short int> actions(2 * numProj); // will be [-1 ... 1 ...]
+ actions.rows(0, numProj - 1) = // First numProj rows.
+ -1 * arma::ones< arma::Col<short int> > (numProj); // -1s
+ actions.rows(numProj, (2 * numProj) - 1) = // Last numProj rows.
+ arma::ones< arma::Col<short int> > (numProj); // 1s
+
+
+ // Acting dimension vector shows which coordinate to transform according to
+ // actions (actions are described by actions vector above).
+ arma::Col<size_t> positions(2 * numProj); // Will be [0 1 2 ... 0 1 2 ...].
+ positions.rows(0, numProj - 1) =
+ arma::linspace< arma::Col<size_t> >(0, numProj - 1, numProj);
+ positions.rows(numProj, 2 * numProj - 1) =
+ arma::linspace< arma::Col<size_t> >(0, numProj - 1, numProj);
+
+ // Special case: No need to create heap for 1 or 2 codes.
+ if (T <= 2)
+ {
+ // First, find location of minimum score, generate 1 perturbation vector,
+ // and add its code to additionalProbingBins column 0.
+
+ // Find location and value of smallest element of scores vector.
+ double minscore = scores[0];
+ size_t minloc = 0;
+ for (size_t s = 1; s < (2 * numProj); ++s)
+ {
+ if (minscore > scores[s])
+ {
+ minscore = scores[s];
+ minloc = s;
+ }
+ }
+
+ // Add or subtract 1 to dimension corresponding to minimum score.
+ additionalProbingBins(positions[minloc], 0) += actions[minloc];
+ if (T == 1)
+ return; // Done if asked for only 1 code.
+
+ // Now, find location of second smallest score and generate one more vector.
+ // The second perturbation vector still can't comprise of more than one
+ // change in the bin codes, because of the way perturbation vectors
+ // are generated: First we create the one with the smallest score (Ao) and
+ // then we either add 1 extra dimension to it (Ae) or shift it by one (As).
+ // Since As contains the second smallest score, and Ae contains both the
+ // smallest and the second smallest, it's obvious that score(Ae) >
+ // score(As). Therefore the second perturbation vector is ALWAYS the vector
+ // containing only the second-lowest scoring perturbation.
+
+ double minscore2 = scores[0];
+ size_t minloc2 = 0;
+ for (size_t s = 0; s < (2 * numProj); ++s) // here we can't start from 1
+ {
+ if (minscore2 > scores[s] && s != minloc) //second smallest
+ {
+ minscore2 = scores[s];
+ minloc2 = s;
+ }
+ }
+
+ // Add or subtract 1 to create second-lowest scoring vector.
+ additionalProbingBins(positions[minloc2], 1) += actions[minloc2];
+ return;
+ }
+
+ // General case: more than 2 perturbation vectors require use of minheap.
+
+ // Sort everything in increasing order.
+ arma::uvec sortidx = arma::sort_index(scores);
+ scores = scores(sortidx);
+ actions = actions(sortidx);
+ positions = positions(sortidx);
+
+
+ // Theory:
+ // A probing sequence is a sequence of T probing bins where a query's
+ // neighbors are most likely to be. Likelihood is dependent only on a bin's
+ // score, which is the sum of scores of all dimension-action pairs, so we
+ // need to calculate the T smallest sums of scores that are not conflicting.
+ //
+ // Method:
+ // Store each perturbation set (pair of (dimension, action)) in a
+ // std::vector. Create a minheap of scores, with each node pointing to its
+ // relevant perturbation set. Each perturbation set popped from the minheap
+ // is the next most likely perturbation set.
+ // Transform perturbation set to perturbation vector by setting the
+ // dimensions specified by the set to queryCode+action (action is {-1, 1}).
+
+ // Perturbation sets (A) mark with 1 the (score, action, dimension) positions
+ // included in a given perturbation vector. Other spaces are 0.
+ std::vector<bool> Ao(2 * numProj);
+ Ao[0] = 1; // Smallest vector includes only smallest score.
+
+ std::vector< std::vector<bool> > perturbationSets;
+ perturbationSets.push_back(Ao); // Storage of perturbation sets.
+
+ std::priority_queue<
+ std::pair<double, size_t>, // contents: pairs of (score, index)
+ std::vector< // container: vector of pairs
+ std::pair<double, size_t>
+ >,
+ std::greater< std::pair<double, size_t> > // comparator of pairs
+ > minHeap; // our minheap
+
+ // Start by adding the lowest scoring set to the minheap.
+ minHeap.push( std::make_pair(PerturbationScore(Ao, scores), 0) );
+
+ // Loop invariable: after pvec iterations, additionalProbingBins contains pvec
+ // valid codes of the lowest-scoring bins (bins most likely to contain
+ // neighbors of the query).
+ for (size_t pvec = 0; pvec < T; ++pvec)
+ {
+ std::vector<bool> Ai;
+ do
+ {
+ // Get the perturbation set corresponding to the minimum score.
+ Ai = perturbationSets[ minHeap.top().second ];
+ minHeap.pop(); // .top() returns, .pop() removes
+
+ // Shift operation on Ai (replace max with max+1).
+ std::vector<bool> As = Ai;
+ if (PerturbationShift(As) && PerturbationValid(As))
+ // Don't add invalid sets.
+ {
+ perturbationSets.push_back(As); // add shifted set to sets
+ minHeap.push(
+ std::make_pair(PerturbationScore(As, scores),
+ perturbationSets.size() - 1));
+ }
+
+ // Expand operation on Ai (add max+1 to set).
+ std::vector<bool> Ae = Ai;
+ if (PerturbationExpand(Ae) && PerturbationValid(Ae))
+ // Don't add invalid sets.
+ {
+ perturbationSets.push_back(Ae); // add expanded set to sets
+ minHeap.push(
+ std::make_pair(PerturbationScore(Ae, scores),
+ perturbationSets.size() - 1));
+ }
+
+ } while (!PerturbationValid(Ai));//Discard invalid perturbations
+
+ // Found valid perturbation set Ai. Construct perturbation vector from set.
+ for (size_t pos = 0; pos < Ai.size(); ++pos)
+ // If Ai[pos] is marked, add action to probing vector.
+ additionalProbingBins(positions(pos), pvec)
+ += Ai[pos] ? actions(pos) : 0;
+ }
+ }
+
+ template<typename SortPolicy>
template<typename VecType>
void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
const VecType& queryPoint,
@@@ -517,19 -733,21 +826,22 @@@
// Retrieve candidates.
size_t start = 0;
- for (size_t i = 0; i < numTablesToSearch; ++i) // For all tables
+
+ for (long long int i = 0; i < numTablesToSearch; ++i) // For all tables
{
- const size_t hashInd = (size_t) hashVec[i]; // Find the query's bucket.
- const size_t tableRow = bucketRowInHashTable[hashInd];
+ for (size_t p = 0; p < T + 1; ++p)
+ {
+ const size_t hashInd = hashMat(p, i); // Find the query's bucket.
+ const size_t tableRow = bucketRowInHashTable[hashInd];
- // Store all secondHashTable points in the candidates set.
- if (tableRow != secondHashSize)
- for (size_t j = 0; j < bucketContentSize[tableRow]; ++j)
- refPointsConsideredSmall(start++) = secondHashTable[tableRow][j];
+ if (tableRow < secondHashSize)
+ // Store all secondHashTable points in the candidates set.
+ for (size_t j = 0; j < bucketContentSize[tableRow]; ++j)
+ refPointsConsideredSmall(start++) = secondHashTable[tableRow](j);
+ }
}
- // Only keep unique candidates.
+ // Keep only one copy of each candidate.
referenceIndices = arma::unique(refPointsConsideredSmall);
return;
}
@@@ -655,12 -880,11 +1001,13 @@@ Search(const size_t k
// Hash every query into every hash table and eventually into the
// 'secondHashTable' to obtain the neighbor candidates.
arma::uvec refIndices;
- ReturnIndicesFromTable(referenceSet->col(i), refIndices, numTablesToSearch);
+ ReturnIndicesFromTable(referenceSet->col(i), refIndices, numTablesToSearch,
+ Teffective);
// An informative book-keeping for the number of neighbor candidates
- // returned on average.
+ // returned on average. Make atomic to avoid race conditions when multiple
+ // threads are running.
+ #pragma omp atomic
avgIndicesReturned += refIndices.n_elem;
// Sequentially go through all the candidates and save the best 'k'
diff --cc src/mlpack/tests/lsh_test.cpp
index c594991,4431f94..8aac238
--- a/src/mlpack/tests/lsh_test.cpp
+++ b/src/mlpack/tests/lsh_test.cpp
@@@ -466,118 -471,146 +471,260 @@@ BOOST_AUTO_TEST_CASE(DeterministicNoMer
}
}
++
+ /**
+ * Test: Create an LSHSearch object and use an increasing number of probes to
+ * search for points. Require that recall for the same object doesn't decrease
+ * with increasing number of probes. Also require that at least a few times
+ * there's some increase in recall.
+ */
+ BOOST_AUTO_TEST_CASE(MultiprobeTest)
+ {
++ // Test parameters.
+ const double epsilonIncrease = 0.05;
- const size_t repetitions = 5; // train five objects
++ const size_t repetitions = 5; // Train five objects.
+
+ const size_t probeTrials = 5;
+ const size_t numProbes[probeTrials] = {0, 1, 2, 3, 4};
+
-
- /// algorithm parameters
++ // Algorithm parameters.
+ const int k = 4;
+ const int numTables = 16;
+ const int numProj = 3;
+ const double hashWidth = 0;
+ const int secondHashSize = 99901;
+ const int bucketSize = 500;
+
+ const string trainSet = "iris_train.csv";
+ const string testSet = "iris_test.csv";
+ arma::mat rdata;
+ arma::mat qdata;
+ data::Load(trainSet, rdata, true);
+ data::Load(testSet, qdata, true);
+
- // Run classic knn on reference set
++ // Run classic knn on reference set.
+ KNN knn(rdata);
+ arma::Mat<size_t> groundTruth;
+ arma::mat groundDistances;
+ knn.Search(qdata, k, groundTruth, groundDistances);
+
+ bool foundIncrease = 0;
+
+ for (size_t rep = 0; rep < repetitions; ++rep)
+ {
+ // Train a model.
+ LSHSearch<> multiprobeTest(rdata, numProj, numTables, hashWidth,
+ secondHashSize, bucketSize);
+
+ double prevRecall = 0;
+ // Search with varying number of probes.
+ for (size_t p = 0; p < probeTrials; ++p)
+ {
+ arma::Mat<size_t> lshNeighbors;
+ arma::mat lshDistances;
+
+ multiprobeTest.Search(qdata, k, lshNeighbors, lshDistances, 0,
+ numProbes[p]);
+
+ // Compute recall of this run.
+ double recall = LSHSearch<>::ComputeRecall(lshNeighbors, groundTruth);
+ if (p > 0)
+ {
+ // More probes should at the very least not lower recall...
+ BOOST_REQUIRE_GE(recall, prevRecall);
+
+ // ... and should ideally increase it a bit.
+ if (recall > prevRecall + epsilonIncrease)
+ foundIncrease = true;
+ prevRecall = recall;
+ }
+ }
+ }
+ BOOST_REQUIRE(foundIncrease);
+ }
+
+ /**
+ * Test: This is a deterministic test that verifies multiprobe LSH works
+ * correctly. To do this, we generate two queries, q1 and q2. q1 is hashed
+ * directly under cluster C2, q2 is hashed in C2's center.
+ * We verify that:
+ * 1) q1 should have no neighbors without multiprobe.
+ * 2) q1 should have neighbors only from C2 with 1 additional probe.
+ * 3) q2 should have all neighbors found with 3 additional probes.
+ */
+ BOOST_AUTO_TEST_CASE(MultiprobeDeterministicTest)
+ {
+ // Generate known deterministic clusters of points.
+ const size_t N = 40;
+ arma::mat rdata;
+ GetPointset(N, rdata);
+
+ const int k = N / 4;
+ const double hashWidth = 1;
+ const int secondHashSize = 99901;
+ const int bucketSize = 500;
+
+ // 1 table, projections on orthonormal plane.
+ arma::cube projections(2, 2, 1);
+ projections(0, 0, 0) = 1;
+ projections(1, 0, 0) = 0;
+ projections(0, 1, 0) = 0;
+ projections(1, 1, 0) = 1;
+
+ // Construct LSH object with given tables.
+ LSHSearch<> lshTest(rdata, projections,
+ hashWidth, secondHashSize, bucketSize);
+
+ const arma::mat offsets = lshTest.Offsets();
+
+ // Construct q1 so it is hashed directly under C2.
+ arma::mat q1;
+ q1 << 3.9 << arma::endr << 2.99;
+ q1 -= offsets;
+
+ // Construct q2 so it is hashed near the center of C2.
+ arma::mat q2;
+ q2 << 3.6 << arma::endr << 3.6;
+ q2 -= offsets;
+
+ arma::Mat<size_t> neighbors;
+ arma::mat distances;
+
+ // Test that q1 simple search comes up empty.
+ lshTest.Search(q1, k, neighbors, distances);
+ BOOST_REQUIRE(arma::all(neighbors.col(0) == N));
+
+ // Test that q1 search with 1 additional probe returns some C2 points.
+ lshTest.Search(q1, k, neighbors, distances, 0, 1);
+ BOOST_REQUIRE(arma::all(
+ neighbors.col(0) == N ||
+ (neighbors.col(0) >= N / 4 && neighbors.col(0) < N / 2)));
+
+ // Test that q2 simple search returns some C2 points.
+ lshTest.Search(q2, k, neighbors, distances);
+ BOOST_REQUIRE(arma::all(
+ neighbors.col(0) == N ||
+ (neighbors.col(0) >= N / 4 && neighbors.col(0) < N / 2)));
+
+ // Test that q2 with 3 additional probes returns all C2 points.
+ lshTest.Search(q2, k, neighbors, distances, 0, 3);
+ BOOST_REQUIRE(arma::all(
+ neighbors.col(0) >= N / 4 && neighbors.col(0) < N / 2));
+ }
+
++
+/**
+ * Test: This test verifies that parallel query processing returns correct
+ * results for the bichromatic search.
+ */
+BOOST_AUTO_TEST_CASE(ParallelBichromatic)
+{
+ // kNN and LSH parameters (use LSH default parameters).
+ const int k = 4;
+ const int numTables = 16;
+ const int numProj = 3;
+
+ // Read iris training and testing data as reference and query sets.
+ const string trainSet = "iris_train.csv";
+ const string testSet = "iris_test.csv";
+ arma::mat rdata;
+ arma::mat qdata;
+ data::Load(trainSet, rdata, true);
+ data::Load(testSet, qdata, true);
+
+ // Where to store neighbors and distances
+ arma::Mat<size_t> sequentialNeighbors;
+ arma::Mat<size_t> parallelNeighbors;
+ arma::mat distances;
+
+ // Construct an LSH object. By default, it uses the maximum number of threads
+ LSHSearch<> lshTest(rdata, numProj, numTables); //default parameters
+ lshTest.Search(qdata, k, parallelNeighbors, distances);
+
+ // Now perform same search but with 1 thread
+ lshTest.MaxThreads(1);
+ lshTest.Search(qdata, k, sequentialNeighbors, distances);
+
+ // Require both have same results
+ double recall = LSHSearch<>::ComputeRecall(sequentialNeighbors, parallelNeighbors);
+ BOOST_REQUIRE_EQUAL(recall, 1);
+}
+
+/**
+ * Test: This test verifies that parallel query processing returns correct
+ * results for the monochromatic search.
+ */
+BOOST_AUTO_TEST_CASE(ParallelMonochromatic)
+{
+ // kNN and LSH parameters.
+ const int k = 4;
+ const int numTables = 16;
+ const int numProj = 3;
+
+ // Read iris training data as reference and query set.
+ const string trainSet = "iris_train.csv";
+ arma::mat rdata;
+ data::Load(trainSet, rdata, true);
+
+ // Where to store neighbors and distances
+ arma::Mat<size_t> sequentialNeighbors;
+ arma::Mat<size_t> parallelNeighbors;
+ arma::mat distances;
+
+ // Construct an LSH object, using maximum number of available threads.
+ LSHSearch<> lshTest(rdata, numProj, numTables);
+ lshTest.Search(k, parallelNeighbors, distances);
+
+ // Now perform same search but with 1 thread.
+ lshTest.MaxThreads(1);
+ lshTest.Search(k, sequentialNeighbors, distances);
+
+ // Require both have same results.
+ double recall = LSHSearch<>::ComputeRecall(sequentialNeighbors, parallelNeighbors);
+ BOOST_REQUIRE_EQUAL(recall, 1);
+}
+
+/**
+ * Test: This test verifies that processing a query in parallel returns the same
+ * results with processing it sequentially.
+ * Requires OMP_NESTED environment variable to be set to TRUE. To set it,
+ * execute:
+ * ```
+ * export OMP_NESTED=TRUE
+ * ```
+ */
+BOOST_AUTO_TEST_CASE(ParallelSingleQuery)
+{
+ // kNN and LSH parameters.
+ const int k = 4;
+ const int numTables = 16;
+ const int numProj = 3;
+
+ // Read iris training data as reference and query set.
+ const string trainSet = "iris_train.csv";
+ arma::mat rdata;
+ data::Load(trainSet, rdata, true);
+
+ arma::mat qdata = rdata.col(0); // Only 1 query.
+
+ // Where to store neighbors and distances.
+ arma::Mat<size_t> sequentialNeighbors;
+ arma::Mat<size_t> parallelNeighbors;
+ arma::mat distances;
+
+ // Construct an LSH object. By default, maximum number of threads are used.
+ LSHSearch<> lshTest(rdata, numProj, numTables);
+ lshTest.Search(qdata, k, parallelNeighbors, distances);
+
+ // Now perform same search but with 1 thread.
+ lshTest.MaxThreads(1);
+ lshTest.Search(qdata, k, sequentialNeighbors, distances);
+
+ // Require both have same results.
+ double recall = LSHSearch<>::ComputeRecall(sequentialNeighbors, parallelNeighbors);
+ BOOST_REQUIRE_EQUAL(recall, 1);
+}
+
BOOST_AUTO_TEST_CASE(LSHTrainTest)
{
// This is a not very good test that simply checks that the re-trained LSH
More information about the mlpack-git
mailing list