[mlpack-git] master: Merges multiprobe LSH (3af80c3)

gitdub at mlpack.org gitdub at mlpack.org
Fri Jul 1 10:48:25 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/34cf8d94f79c9a72ff4199676033b060cd039fcd...425324bf7fb7c86c85d10a909d8a59d4f69b7164

>---------------------------------------------------------------

commit 3af80c339a7a846bb3bdf305131c87eaa939fc01
Merge: 0d38271 bdbc2fe
Author: Yannis Mentekidis <mentekid at gmail.com>
Date:   Fri Jul 1 15:48:25 2016 +0100

    Merges multiprobe LSH


>---------------------------------------------------------------

3af80c339a7a846bb3bdf305131c87eaa939fc01
 HISTORY.md                                         |  16 +-
 README.md                                          |   4 +-
 src/mlpack/core/tree/CMakeLists.txt                |  10 +
 src/mlpack/core/tree/rectangle_tree.hpp            |   6 +
 .../tree/rectangle_tree/discrete_hilbert_value.hpp | 263 ++++++++++++
 .../rectangle_tree/discrete_hilbert_value_impl.hpp | 457 ++++++++++++++++++++
 .../tree/rectangle_tree/dual_tree_traverser.hpp    |   7 +-
 .../rectangle_tree/dual_tree_traverser_impl.hpp    |  24 +-
 .../hilbert_r_tree_auxiliary_information.hpp       | 120 ++++++
 .../hilbert_r_tree_auxiliary_information_impl.hpp  | 169 ++++++++
 .../hilbert_r_tree_descent_heuristic.hpp           |  53 +++
 .../hilbert_r_tree_descent_heuristic_impl.hpp      |  49 +++
 .../tree/rectangle_tree/hilbert_r_tree_split.hpp   |  94 ++++
 .../rectangle_tree/hilbert_r_tree_split_impl.hpp   | 340 +++++++++++++++
 .../rectangle_tree/no_auxiliary_information.hpp    | 117 +++++
 .../r_star_tree_descent_heuristic.hpp              |  16 +-
 .../r_star_tree_descent_heuristic_impl.hpp         |  50 ++-
 .../core/tree/rectangle_tree/r_star_tree_split.hpp |  24 +-
 .../tree/rectangle_tree/r_star_tree_split_impl.hpp | 151 ++++---
 .../rectangle_tree/r_tree_descent_heuristic.hpp    |   4 +-
 .../r_tree_descent_heuristic_impl.hpp              |  32 +-
 .../core/tree/rectangle_tree/r_tree_split.hpp      |  33 +-
 .../core/tree/rectangle_tree/r_tree_split_impl.hpp | 150 +++----
 .../core/tree/rectangle_tree/rectangle_tree.hpp    | 101 ++---
 .../tree/rectangle_tree/rectangle_tree_impl.hpp    | 457 ++++++++++++--------
 .../tree/rectangle_tree/single_tree_traverser.hpp  |   7 +-
 .../rectangle_tree/single_tree_traverser_impl.hpp  |  20 +-
 src/mlpack/core/tree/rectangle_tree/traits.hpp     |   7 +-
 src/mlpack/core/tree/rectangle_tree/typedef.hpp    |  45 +-
 .../x_tree_auxiliary_information.hpp               | 186 ++++++++
 .../core/tree/rectangle_tree/x_tree_split.hpp      |  61 +--
 .../core/tree/rectangle_tree/x_tree_split_impl.hpp | 243 +++++------
 src/mlpack/methods/amf/init_rules/CMakeLists.txt   |   1 +
 src/mlpack/methods/amf/init_rules/given_init.hpp   |  75 ++++
 src/mlpack/methods/lsh/lsh_main.cpp                |  13 +-
 src/mlpack/methods/lsh/lsh_search.hpp              |  86 +++-
 src/mlpack/methods/lsh/lsh_search_impl.hpp         | 437 +++++++++++++++++--
 src/mlpack/methods/neighbor_search/kfn_main.cpp    |  34 +-
 src/mlpack/methods/neighbor_search/knn_main.cpp    |  20 +-
 .../methods/neighbor_search/neighbor_search.hpp    |  15 +
 .../neighbor_search/neighbor_search_impl.hpp       |  32 +-
 .../neighbor_search/neighbor_search_rules.hpp      |   4 +
 .../neighbor_search/neighbor_search_rules_impl.hpp |  14 +-
 src/mlpack/methods/neighbor_search/ns_model.hpp    |  58 ++-
 .../methods/neighbor_search/ns_model_impl.hpp      |  64 ++-
 .../sort_policies/furthest_neighbor_sort.cpp       |   4 +-
 .../sort_policies/furthest_neighbor_sort.hpp       |  17 +
 .../sort_policies/nearest_neighbor_sort.hpp        |  15 +
 .../methods/preprocess/preprocess_split_main.cpp   |   2 +-
 .../methods/range_search/range_search_main.cpp     |   6 +-
 src/mlpack/methods/range_search/rs_model.cpp       |  21 +-
 src/mlpack/methods/range_search/rs_model.hpp       |   5 +-
 src/mlpack/methods/range_search/rs_model_impl.hpp  |  14 +
 src/mlpack/methods/rann/krann_main.cpp             |   6 +-
 src/mlpack/methods/rann/ra_model.hpp               |   5 +-
 src/mlpack/methods/rann/ra_model_impl.hpp          |  56 ++-
 src/mlpack/prereqs.hpp                             |   1 +
 src/mlpack/tests/CMakeLists.txt                    |   2 +
 src/mlpack/tests/activation_functions_test.cpp     |   2 +-
 src/mlpack/tests/ada_delta_test.cpp                |   2 +-
 src/mlpack/tests/adaboost_test.cpp                 |   3 +-
 src/mlpack/tests/adam_test.cpp                     |   2 +-
 src/mlpack/tests/akfn_test.cpp                     | 240 +++++++++++
 src/mlpack/tests/aknn_test.cpp                     | 405 ++++++++++++++++++
 src/mlpack/tests/arma_extend_test.cpp              |   2 +-
 src/mlpack/tests/armadillo_svd_test.cpp            |   2 +-
 src/mlpack/tests/aug_lagrangian_test.cpp           |   2 +-
 src/mlpack/tests/binarize_test.cpp                 |   2 +-
 src/mlpack/tests/cf_test.cpp                       |   2 +-
 src/mlpack/tests/cli_test.cpp                      |   2 +-
 src/mlpack/tests/convolution_test.cpp              |   2 +-
 src/mlpack/tests/convolutional_network_test.cpp    |   2 +-
 src/mlpack/tests/cosine_tree_test.cpp              |   2 +-
 src/mlpack/tests/decision_stump_test.cpp           |   2 +-
 src/mlpack/tests/det_test.cpp                      |   2 +-
 src/mlpack/tests/distribution_test.cpp             |   2 +-
 src/mlpack/tests/emst_test.cpp                     |   2 +-
 src/mlpack/tests/fastmks_test.cpp                  |   2 +-
 src/mlpack/tests/feedforward_network_test.cpp      |   2 +-
 src/mlpack/tests/gmm_test.cpp                      |   2 +-
 src/mlpack/tests/hmm_test.cpp                      |   2 +-
 src/mlpack/tests/hoeffding_tree_test.cpp           |   2 +-
 src/mlpack/tests/ind2sub_test.cpp                  |   2 +-
 src/mlpack/tests/init_rules_test.cpp               |   2 +-
 src/mlpack/tests/kernel_pca_test.cpp               |   2 +-
 src/mlpack/tests/kernel_test.cpp                   |   2 +-
 src/mlpack/tests/kernel_traits_test.cpp            |   2 +-
 src/mlpack/tests/kfn_test.cpp                      |   2 +-
 src/mlpack/tests/kmeans_test.cpp                   |   3 +-
 src/mlpack/tests/knn_test.cpp                      |  18 +-
 src/mlpack/tests/krann_search_test.cpp             |   8 +-
 src/mlpack/tests/lars_test.cpp                     |   2 +-
 src/mlpack/tests/layer_traits_test.cpp             |   2 +-
 src/mlpack/tests/lbfgs_test.cpp                    |   2 +-
 src/mlpack/tests/lin_alg_test.cpp                  |   2 +-
 src/mlpack/tests/linear_regression_test.cpp        |   2 +-
 src/mlpack/tests/load_save_test.cpp                |   2 +-
 src/mlpack/tests/local_coordinate_coding_test.cpp  |   2 +-
 src/mlpack/tests/log_test.cpp                      |   2 +-
 src/mlpack/tests/logistic_regression_test.cpp      |   2 +-
 src/mlpack/tests/lrsdp_test.cpp                    |   2 +-
 src/mlpack/tests/lsh_test.cpp                      | 158 ++++++-
 src/mlpack/tests/lstm_peephole_test.cpp            |   2 +-
 src/mlpack/tests/math_test.cpp                     |   2 +-
 src/mlpack/tests/matrix_completion_test.cpp        |   2 +-
 src/mlpack/tests/maximal_inputs_test.cpp           |   2 +-
 src/mlpack/tests/mean_shift_test.cpp               |   3 +-
 src/mlpack/tests/metric_test.cpp                   |   2 +-
 src/mlpack/tests/minibatch_sgd_test.cpp            |   2 +-
 src/mlpack/tests/mlpack_test.cpp                   |   2 +-
 src/mlpack/tests/nbc_test.cpp                      |   2 +-
 src/mlpack/tests/nca_test.cpp                      |   2 +-
 src/mlpack/tests/network_util_test.cpp             |   2 +-
 src/mlpack/tests/nmf_test.cpp                      |  17 +-
 src/mlpack/tests/nystroem_method_test.cpp          |   2 +-
 src/mlpack/tests/pca_test.cpp                      |   2 +-
 src/mlpack/tests/perceptron_test.cpp               |   2 +-
 src/mlpack/tests/performance_functions_test.cpp    |   2 +-
 src/mlpack/tests/pooling_rules_test.cpp            |   2 +-
 src/mlpack/tests/quic_svd_test.cpp                 |   2 +-
 src/mlpack/tests/radical_test.cpp                  |   2 +-
 src/mlpack/tests/range_search_test.cpp             |  16 +-
 src/mlpack/tests/rectangle_tree_test.cpp           | 472 ++++++++++++++++-----
 src/mlpack/tests/recurrent_network_test.cpp        |   2 +-
 src/mlpack/tests/regularized_svd_test.cpp          |   2 +-
 src/mlpack/tests/rmsprop_test.cpp                  |   2 +-
 src/mlpack/tests/sa_test.cpp                       |   2 +-
 src/mlpack/tests/sdp_primal_dual_test.cpp          |   2 +-
 src/mlpack/tests/serialization.hpp                 |   2 +-
 src/mlpack/tests/serialization_test.cpp            |   2 +-
 src/mlpack/tests/sgd_test.cpp                      |   2 +-
 src/mlpack/tests/softmax_regression_test.cpp       |   2 +-
 src/mlpack/tests/sort_policy_test.cpp              |   2 +-
 src/mlpack/tests/sparse_autoencoder_test.cpp       |   2 +-
 src/mlpack/tests/sparse_coding_test.cpp            |   2 +-
 src/mlpack/tests/split_data_test.cpp               |   2 +-
 src/mlpack/tests/svd_batch_test.cpp                |   2 +-
 src/mlpack/tests/svd_incremental_test.cpp          |   2 +-
 src/mlpack/tests/termination_policy_test.cpp       |   2 +-
 ...d_boost_test_definitions.hpp => test_tools.hpp} |  14 +-
 src/mlpack/tests/tree_test.cpp                     |   2 +-
 src/mlpack/tests/tree_traits_test.cpp              |   2 +-
 src/mlpack/tests/union_find_test.cpp               |   2 +-
 143 files changed, 4771 insertions(+), 1055 deletions(-)

diff --cc src/mlpack/methods/lsh/lsh_search.hpp
index 0882393,44a7b64..811a73d
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@@ -187,9 -205,9 +205,10 @@@ class LSHSearc
    void Search(const size_t k,
                arma::Mat<size_t>& resultingNeighbors,
                arma::mat& distances,
-               const size_t numTablesToSearch = 0);
+               const size_t numTablesToSearch = 0,
+               size_t T = 0);
  
 +
    /**
     * Compute the recall (% of neighbors found) given the neighbors returned by
     * LSHSearch::Search and a "ground truth" set of neighbors.  The recall
diff --cc src/mlpack/methods/lsh/lsh_search_impl.hpp
index dae15a6,ab81088..8d3cdb1
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@@ -351,80 -339,274 +363,343 @@@ void LSHSearch<SortPolicy>::BaseCase(co
  
    // SortDistance() returns (size_t() - 1) if we shouldn't add it.
    if (insertPosition != (size_t() - 1))
 -    InsertNeighbor(distances, neighbors, queryIndex, insertPosition,
 -        referenceIndex, distance);
 +  {  
 +    #pragma omp critical
 +    {
 +      InsertNeighbor(distances, neighbors, queryIndex, insertPosition,
 +          referenceIndex, distance);
 +    }
 +  }
  }
 +*/
  
 +// Base case where the query set is the reference set.  (So, we can't return
 +// ourselves as the nearest neighbor.)
 +template<typename SortPolicy>
 +inline force_inline
 +void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
 +                                     const arma::uvec& referenceIndices,
 +                                     arma::Mat<size_t>& neighbors,
 +                                     arma::mat& distances) const
 +{
 +  for (size_t j = 0; j < referenceIndices.n_elem; ++j)
 +  {
 +    const size_t referenceIndex = referenceIndices[j];
 +    // If the points are the same, skip this point.
 +    if (queryIndex == referenceIndex)
 +      continue;
 +
 +    const double distance = metric::EuclideanDistance::Evaluate(
 +        referenceSet->unsafe_col(queryIndex),
 +        referenceSet->unsafe_col(referenceIndex));
 +
 +    // If this distance is better than any of the current candidates, the
 +    // SortDistance() function will give us the position to insert it into.
 +    arma::vec queryDist = distances.unsafe_col(queryIndex);
 +    arma::Col<size_t> queryIndices = neighbors.unsafe_col(queryIndex);
 +    size_t insertPosition = SortPolicy::SortDistance(queryDist, queryIndices,
 +        distance);
 +
 +    // SortDistance() returns (size_t() - 1) if we shouldn't add it.
 +    if (insertPosition != (size_t() - 1))
 +      InsertNeighbor(distances, neighbors, queryIndex, insertPosition,
 +          referenceIndex, distance);
 +  }
 +}
 +
 +// Base case for bichromatic search.
 +template<typename SortPolicy>
 +inline force_inline
 +void LSHSearch<SortPolicy>::BaseCase(const size_t queryIndex,
 +                                     const arma::uvec& referenceIndices,
 +                                     const arma::mat& querySet,
 +                                     arma::Mat<size_t>& neighbors,
 +                                     arma::mat& distances) const
 +{
 +  for (size_t j = 0; j < referenceIndices.n_elem; ++j)
 +  {
 +    const size_t referenceIndex = referenceIndices[j];
 +    const double distance = metric::EuclideanDistance::Evaluate(
 +        querySet.unsafe_col(queryIndex),
 +        referenceSet->unsafe_col(referenceIndex));
 +
 +    // If this distance is better than any of the current candidates, the
 +    // SortDistance() function will give us the position to insert it into.
 +    arma::vec queryDist = distances.unsafe_col(queryIndex);
 +    arma::Col<size_t> queryIndices = neighbors.unsafe_col(queryIndex);
 +    size_t insertPosition = SortPolicy::SortDistance(queryDist, queryIndices,
 +        distance);
 +
 +    // SortDistance() returns (size_t() - 1) if we shouldn't add it.
 +    if (insertPosition != (size_t() - 1))
 +      InsertNeighbor(distances, neighbors, queryIndex, insertPosition,
 +          referenceIndex, distance);
 +  }
 +}
  template<typename SortPolicy>
+ inline force_inline
+ double LSHSearch<SortPolicy>::PerturbationScore(
+     const std::vector<bool>& A,
+     const arma::vec& scores) const
+ {
+   double score = 0.0;
+   for (size_t i = 0; i < A.size(); ++i)
+     if (A[i])
+       score += scores(i); // add scores of non-zero indices
+   return score;
+ }
+ 
+ template<typename SortPolicy>
+ inline force_inline
+ bool LSHSearch<SortPolicy>::PerturbationShift(std::vector<bool>& A) const
+ {
+   size_t maxPos = 0;
+   for (size_t i = 0; i < A.size(); ++i)
+     if (A[i] == 1) // marked true
+       maxPos=i;
+   
+   if ( maxPos + 1 < A.size()) // otherwise, this is an invalid vector 
+   {
+     A[maxPos] = 0;
+     A[maxPos + 1] = 1;
+     return true; // valid
+   }
+   return false; // invalid
+ }
+ 
+ template<typename SortPolicy>
+ inline force_inline
+ bool LSHSearch<SortPolicy>::PerturbationExpand(std::vector<bool>& A) const
+ {
+   // Find the last '1' in A
+   size_t maxPos = 0;
+   for (size_t i = 0; i < A.size(); ++i)
+     if (A[i]) // marked true
+       maxPos = i;
+ 
+   if (maxPos + 1 < A.size()) // otherwise, this is an invalid vector
+   {
+     A[maxPos + 1] = 1;
+     return true;
+   }
+   return false;
+ }
+ 
+ template<typename SortPolicy>
+ inline force_inline
+ bool LSHSearch<SortPolicy>::PerturbationValid(
+     const std::vector<bool>& A) const
+ {
+   // Use check to mark dimensions we have seen before in A. If a dimension is
+   // seen twice (or more), A is not a valid perturbation.
+   std::vector<bool> check(numProj);
+ 
+   if (A.size() > 2 * numProj)
+     return false; // This should never happen.
+ 
+   // Check that we only see each dimension once. If not, vector is not valid.
+   for (size_t i = 0; i < A.size(); ++i)
+   {
+     // Only check dimensions that were included.
+     if (!A[i])
+       continue;
+ 
+     // If dimension is unseen thus far, mark it as seen.
+     if (check[i % numProj] == false)
+       check[i % numProj] = true;
+     else
+       return false; // If dimension was seen before, set is not valid.
+   }
+   // If we didn't fail, set is valid.
+   return true;
+ }
+ 
+ // Compute additional probing bins for a query
+ template<typename SortPolicy>
+ void LSHSearch<SortPolicy>::GetAdditionalProbingBins(
+     const arma::vec& queryCode,
+     const arma::vec& queryCodeNotFloored,
+     const size_t T,
+     arma::mat& additionalProbingBins) const
+ {
+ 
+   // No additional bins requested. Our work is done.
+   if (T == 0)
+     return;
+ 
+   // Each column of additionalProbingBins is the code of a bin.
+   additionalProbingBins.set_size(numProj, T);
+   
+   // Copy the query's code, then in the end we will  add/subtract according 
+   // to perturbations we calculated.
+   for (size_t c = 0; c < T; ++c)
+     additionalProbingBins.col(c) = queryCode;
+ 
+ 
+   // Calculate query point's projection position.
+   arma::mat projection = queryCodeNotFloored;
+ 
+   // Use projection to calculate query's distance from hash limits.
+   arma::vec limLow = projection - queryCode * hashWidth;
+   arma::vec limHigh = hashWidth - limLow;
+ 
+   // Calculate scores. score = distance^2.
+   arma::vec scores(2 * numProj);
+   scores.rows(0, numProj - 1) = arma::pow(limLow, 2);
+   scores.rows(numProj, (2 * numProj) - 1) = arma::pow(limHigh, 2);
+ 
+   // Actions vector describes what perturbation (-1/+1) corresponds to a score.
+   arma::Col<short int> actions(2 * numProj); // will be [-1 ... 1 ...]
+   actions.rows(0, numProj - 1) = // First numProj rows.
+     -1 * arma::ones< arma::Col<short int> > (numProj); // -1s
+   actions.rows(numProj, (2 * numProj) - 1) = // Last numProj rows.
+     arma::ones< arma::Col<short int> > (numProj); // 1s
+ 
+ 
+   // Acting dimension vector shows which coordinate to transform according to
+   // actions (actions are described by actions vector above).
+   arma::Col<size_t> positions(2 * numProj); // Will be [0 1 2 ... 0 1 2 ...].
+   positions.rows(0, numProj - 1) =
+     arma::linspace< arma::Col<size_t> >(0, numProj - 1, numProj);
+   positions.rows(numProj, 2 * numProj - 1) =
+     arma::linspace< arma::Col<size_t> >(0, numProj - 1, numProj);
+ 
+   // Special case: No need to create heap for 1 or 2 codes.
+   if (T <= 2)
+   {
+     // First, find location of minimum score, generate 1 perturbation vector,
+     // and add its code to additionalProbingBins column 0.
+ 
+     // Find location and value of smallest element of scores vector.
+     double minscore = scores[0];
+     size_t minloc = 0;
+     for (size_t s = 1; s < (2 * numProj); ++s)
+     {
+       if (minscore > scores[s])
+       {
+         minscore = scores[s];
+         minloc = s;
+       }
+     }
+     
+     // Add or subtract 1 to dimension corresponding to minimum score.
+     additionalProbingBins(positions[minloc], 0) += actions[minloc];
+     if (T == 1)
+       return; // Done if asked for only 1 code.
+ 
+     // Now, find location of second smallest score and generate one more vector.
+     // The second perturbation vector still can't consist of more than one
+     // change in the bin codes, because of the way perturbation vectors
+     // are generated: First we create the one with the smallest score (Ao) and
+     // then we either add 1 extra dimension to it (Ae) or shift it by one (As).
+     // Since As contains the second smallest score, and Ae contains both the
+     // smallest and the second smallest, it's obvious that score(Ae) >
+     // score(As). Therefore the second perturbation vector is ALWAYS the vector
+     // containing only the second-lowest scoring perturbation.
+     
+     double minscore2 = scores[0];
+     size_t minloc2 = 0;
+     for (size_t s = 0; s < (2 * numProj); ++s) // here we can't start from 1
+     {
+       if (minscore2 > scores[s] && s != minloc) //second smallest
+       {
+         minscore2 = scores[s];
+         minloc2 = s;
+       }
+     }
+ 
+     // Add or subtract 1 to create second-lowest scoring vector.
+     additionalProbingBins(positions[minloc2], 1) += actions[minloc2];
+     return;
+   }
+ 
+   // General case: more than 2 perturbation vectors require use of minheap.
+ 
+   // Sort everything in increasing order.
+   arma::uvec sortidx = arma::sort_index(scores);
+   scores = scores(sortidx);
+   actions = actions(sortidx);
+   positions = positions(sortidx);
+ 
+ 
+   // Theory:
+   // A probing sequence is a sequence of T probing bins where a query's
+   // neighbors are most likely to be. Likelihood is dependent only on a bin's
+   // score, which is the sum of scores of all dimension-action pairs, so we
+   // need to calculate the T smallest sums of scores that are not conflicting.
+   //
+   // Method:
+   // Store each perturbation set (pair of (dimension, action)) in a
+   // std::vector. Create a minheap of scores, with each node pointing to its
+   // relevant perturbation set. Each perturbation set popped from the minheap
+   // is the next most likely perturbation set.
+   // Transform perturbation set to perturbation vector by setting the
+   // dimensions specified by the set to queryCode+action (action is {-1, 1}).
+ 
+   // Perturbation sets (A) mark with 1 the (score, action, dimension) positions
+   // included in a given perturbation vector. Other spaces are 0.
+   std::vector<bool> Ao(2 * numProj);
+   Ao[0] = 1; // Smallest vector includes only smallest score.
+ 
+   std::vector< std::vector<bool> > perturbationSets;
+   perturbationSets.push_back(Ao); // Storage of perturbation sets.
+ 
+   std::priority_queue<
+     std::pair<double, size_t>,        // contents: pairs of (score, index)
+     std::vector<                      // container: vector of pairs
+       std::pair<double, size_t>
+       >,
+     std::greater< std::pair<double, size_t> > // comparator of pairs
+   > minHeap; // our minheap
+ 
+   // Start by adding the lowest scoring set to the minheap.
+   minHeap.push( std::make_pair(PerturbationScore(Ao, scores), 0) );
+ 
+   // Loop invariant: after pvec iterations, additionalProbingBins contains pvec
+   // valid codes of the lowest-scoring bins (bins most likely to contain
+   // neighbors of the query).
+   for (size_t pvec = 0; pvec < T; ++pvec)
+   {
+     std::vector<bool> Ai;
+     do
+     {
+       // Get the perturbation set corresponding to the minimum score.
+       Ai = perturbationSets[ minHeap.top().second ];
+       minHeap.pop(); // .top() returns, .pop() removes
+ 
+       // Shift operation on Ai (replace max with max+1).
+       std::vector<bool> As = Ai;
+       if (PerturbationShift(As) && PerturbationValid(As))
+         // Don't add invalid sets.
+       {
+         perturbationSets.push_back(As); // add shifted set to sets
+         minHeap.push(
+             std::make_pair(PerturbationScore(As, scores), 
+             perturbationSets.size() - 1));
+       }
+ 
+       // Expand operation on Ai (add max+1 to set).
+       std::vector<bool> Ae = Ai;
+       if (PerturbationExpand(Ae) && PerturbationValid(Ae)) 
+         // Don't add invalid sets.
+       {
+         perturbationSets.push_back(Ae); // add expanded set to sets
+         minHeap.push(
+             std::make_pair(PerturbationScore(Ae, scores),
+             perturbationSets.size() - 1));
+       }
+ 
+     } while (!PerturbationValid(Ai));//Discard invalid perturbations
+ 
+     // Found valid perturbation set Ai. Construct perturbation vector from set.
+     for (size_t pos = 0; pos < Ai.size(); ++pos)
+       // If Ai[pos] is marked, add action to probing vector.
+       additionalProbingBins(positions(pos), pvec) 
+           += Ai[pos] ? actions(pos) : 0;
+   }
+ }
+ 
+ template<typename SortPolicy>
  template<typename VecType>
  void LSHSearch<SortPolicy>::ReturnIndicesFromTable(
      const VecType& queryPoint,
@@@ -517,19 -733,21 +826,22 @@@
  
      // Retrieve candidates.
      size_t start = 0;
 -    for (size_t i = 0; i < numTablesToSearch; ++i) // For all tables
 +
 +    for (long long int i = 0; i < numTablesToSearch; ++i) // For all tables
      {
-       const size_t hashInd = (size_t) hashVec[i]; // Find the query's bucket.
-       const size_t tableRow = bucketRowInHashTable[hashInd];
+       for (size_t p = 0; p < T + 1; ++p)
+       {
+         const size_t hashInd =  hashMat(p, i); // Find the query's bucket.
+         const size_t tableRow = bucketRowInHashTable[hashInd];
  
-       // Store all secondHashTable points in the candidates set.
-       if (tableRow != secondHashSize)
-         for (size_t j = 0; j < bucketContentSize[tableRow]; ++j)
-           refPointsConsideredSmall(start++) = secondHashTable[tableRow][j];
+         if (tableRow < secondHashSize)
+          // Store all secondHashTable points in the candidates set.
+          for (size_t j = 0; j < bucketContentSize[tableRow]; ++j)
+            refPointsConsideredSmall(start++) = secondHashTable[tableRow](j);
+       }
      }
  
 -    // Only keep unique candidates.
 +    // Keep only one copy of each candidate.
      referenceIndices = arma::unique(refPointsConsideredSmall);
      return;
    }
@@@ -655,12 -880,11 +1001,13 @@@ Search(const size_t k
      // Hash every query into every hash table and eventually into the
      // 'secondHashTable' to obtain the neighbor candidates.
      arma::uvec refIndices;
-     ReturnIndicesFromTable(referenceSet->col(i), refIndices, numTablesToSearch);
+     ReturnIndicesFromTable(referenceSet->col(i), refIndices, numTablesToSearch,
+         Teffective);
  
      // An informative book-keeping for the number of neighbor candidates
 -    // returned on average.
 +    // returned on average. Make atomic to avoid race conditions when multiple
 +    // threads are running.
 +    #pragma omp atomic
      avgIndicesReturned += refIndices.n_elem;
  
      // Sequentially go through all the candidates and save the best 'k'
diff --cc src/mlpack/tests/lsh_test.cpp
index c594991,4431f94..8aac238
--- a/src/mlpack/tests/lsh_test.cpp
+++ b/src/mlpack/tests/lsh_test.cpp
@@@ -466,118 -471,146 +471,260 @@@ BOOST_AUTO_TEST_CASE(DeterministicNoMer
    }
  }
  
++
+ /**
+  * Test: Create an LSHSearch object and use an increasing number of probes to
+  * search for points. Require that recall for the same object doesn't decrease
+  * with increasing number of probes. Also require that at least a few times
+  * there's some increase in recall.
+  */
+ BOOST_AUTO_TEST_CASE(MultiprobeTest)
+ {
++  // Test parameters.
+   const double epsilonIncrease = 0.05;
 -  const size_t repetitions = 5; // train five objects
++  const size_t repetitions = 5; // Train five objects.
+ 
+   const size_t probeTrials = 5;
+   const size_t numProbes[probeTrials] = {0, 1, 2, 3, 4};
+ 
 -
 -  /// algorithm parameters
++  // Algorithm parameters.
+   const int k = 4;
+   const int numTables = 16;
+   const int numProj = 3;
+   const double hashWidth = 0;
+   const int secondHashSize = 99901;
+   const int bucketSize = 500;
+ 
+   const string trainSet = "iris_train.csv";
+   const string testSet = "iris_test.csv";
+   arma::mat rdata;
+   arma::mat qdata;
+   data::Load(trainSet, rdata, true);
+   data::Load(testSet, qdata, true);
+   
 -  // Run classic knn on reference set
++  // Run classic knn on reference set.
+   KNN knn(rdata);
+   arma::Mat<size_t> groundTruth;
+   arma::mat groundDistances;
+   knn.Search(qdata, k, groundTruth, groundDistances);
+ 
+   bool foundIncrease = 0;
+ 
+   for (size_t rep = 0; rep < repetitions; ++rep)
+   {
+     // Train a model.
+     LSHSearch<> multiprobeTest(rdata, numProj, numTables, hashWidth,
+         secondHashSize, bucketSize);
+ 
+     double prevRecall = 0;
+     // Search with varying number of probes.
+     for (size_t p = 0; p < probeTrials; ++p)
+     {
+       arma::Mat<size_t> lshNeighbors;
+       arma::mat lshDistances;
+ 
+       multiprobeTest.Search(qdata, k, lshNeighbors, lshDistances, 0,
+           numProbes[p]);
+ 
+       // Compute recall of this run.
+       double recall = LSHSearch<>::ComputeRecall(lshNeighbors, groundTruth);
+       if (p > 0)
+       {
+         // More probes should at the very least not lower recall...
+         BOOST_REQUIRE_GE(recall, prevRecall);
+ 
+         // ... and should ideally increase it a bit.
+         if (recall > prevRecall + epsilonIncrease)
+           foundIncrease = true;
+         prevRecall = recall;
+       }
+     }
+   }
+   BOOST_REQUIRE(foundIncrease);
+ }
+ 
+ /**
+  * Test: This is a deterministic test that verifies multiprobe LSH works
+  * correctly. To do this, we generate two queries, q1 and q2. q1 is hashed
+  * directly under cluster C2, q2 is hashed in C2's center.
+  * We verify that:
+  * 1) q1 should have no neighbors without multiprobe.
+  * 2) q1 should have neighbors only from C2 with 1 additional probe.
+  * 3) q2 should have all neighbors found with 3 additional probes.
+  */
+ BOOST_AUTO_TEST_CASE(MultiprobeDeterministicTest)
+ {
+   // Generate known deterministic clusters of points.
+   const size_t N = 40;
+   arma::mat rdata;
+   GetPointset(N, rdata);
+ 
+   const int k = N / 4;
+   const double hashWidth = 1;
+   const int secondHashSize = 99901;
+   const int bucketSize = 500;
+ 
+   // 1 table, projections on orthonormal plane.
+   arma::cube projections(2, 2, 1);
+   projections(0, 0, 0) = 1;
+   projections(1, 0, 0) = 0;
+   projections(0, 1, 0) = 0;
+   projections(1, 1, 0) = 1;
+ 
+   // Construct LSH object with given tables.
+   LSHSearch<> lshTest(rdata, projections,
+                       hashWidth, secondHashSize, bucketSize);
+ 
+   const arma::mat offsets = lshTest.Offsets();
+ 
+   // Construct q1 so it is hashed directly under C2.
+   arma::mat q1;
+   q1 << 3.9 << arma::endr << 2.99;
+   q1 -= offsets;
+ 
+   // Construct q2 so it is hashed near the center of C2.
+   arma::mat q2;
+   q2 << 3.6 << arma::endr << 3.6;
+   q2 -= offsets;
+ 
+   arma::Mat<size_t> neighbors;
+   arma::mat distances;
+ 
+   // Test that q1 simple search comes up empty.
+   lshTest.Search(q1, k, neighbors, distances);
+   BOOST_REQUIRE(arma::all(neighbors.col(0) == N));
+ 
+   // Test that q1 search with 1 additional probe returns some C2 points.
+   lshTest.Search(q1, k, neighbors, distances, 0, 1);
+   BOOST_REQUIRE(arma::all(
+         neighbors.col(0) == N ||
+         (neighbors.col(0) >= N / 4 && neighbors.col(0) < N / 2)));
+ 
+   // Test that q2 simple search returns some C2 points.
+   lshTest.Search(q2, k, neighbors, distances);
+   BOOST_REQUIRE(arma::all(
+       neighbors.col(0) == N ||
+       (neighbors.col(0) >= N / 4 && neighbors.col(0) < N / 2)));
+ 
+   // Test that q2 with 3 additional probes returns all C2 points.
+   lshTest.Search(q2, k, neighbors, distances, 0, 3);
+   BOOST_REQUIRE(arma::all(
+       neighbors.col(0) >= N / 4 && neighbors.col(0) < N / 2));
+ }
+ 
++
 +/**
 + * Test: This test verifies that parallel query processing returns correct
 + * results for the bichromatic search.
 + */
 +BOOST_AUTO_TEST_CASE(ParallelBichromatic)
 +{
 +  // kNN and LSH parameters (use LSH default parameters).
 +  const int k = 4;
 +  const int numTables = 16;
 +  const int numProj = 3;
 +
 +  // Read iris training and testing data as reference and query sets.
 +  const string trainSet = "iris_train.csv";
 +  const string testSet = "iris_test.csv";
 +  arma::mat rdata;
 +  arma::mat qdata;
 +  data::Load(trainSet, rdata, true);
 +  data::Load(testSet, qdata, true);
 +
 +  // Where to store neighbors and distances
 +  arma::Mat<size_t> sequentialNeighbors;
 +  arma::Mat<size_t> parallelNeighbors;
 +  arma::mat distances;
 +
 +  // Construct an LSH object. By default, it uses the maximum number of threads
 +  LSHSearch<> lshTest(rdata, numProj, numTables); //default parameters
 +  lshTest.Search(qdata, k, parallelNeighbors, distances);
 +
 +  // Now perform same search but with 1 thread
 +  lshTest.MaxThreads(1);
 +  lshTest.Search(qdata, k, sequentialNeighbors, distances);
 +
 +  // Require both have same results
 +  double recall = LSHSearch<>::ComputeRecall(sequentialNeighbors, parallelNeighbors);
 +  BOOST_REQUIRE_EQUAL(recall, 1);
 +}
 +
 +/**
 + * Test: This test verifies that parallel query processing returns correct
 + * results for the monochromatic search.
 + */
 +BOOST_AUTO_TEST_CASE(ParallelMonochromatic)
 +{
 +  // kNN and LSH parameters.
 +  const int k = 4;
 +  const int numTables = 16;
 +  const int numProj = 3;
 +
 +  // Read iris training data as reference and query set.
 +  const string trainSet = "iris_train.csv";
 +  arma::mat rdata;
 +  data::Load(trainSet, rdata, true);
 +
 +  // Where to store neighbors and distances
 +  arma::Mat<size_t> sequentialNeighbors;
 +  arma::Mat<size_t> parallelNeighbors;
 +  arma::mat distances;
 +
 +  // Construct an LSH object, using maximum number of available threads.
 +  LSHSearch<> lshTest(rdata, numProj, numTables);
 +  lshTest.Search(k, parallelNeighbors, distances);
 +
 +  // Now perform same search but with 1 thread.
 +  lshTest.MaxThreads(1);
 +  lshTest.Search(k, sequentialNeighbors, distances);
 +
 +  // Require both have same results.
 +  double recall = LSHSearch<>::ComputeRecall(sequentialNeighbors, parallelNeighbors);
 +  BOOST_REQUIRE_EQUAL(recall, 1);
 +}
 +
 +/**
 + * Test: This test verifies that processing a query in parallel returns the same
 + * results as processing it sequentially.
 + * Requires OMP_NESTED environment variable to be set to TRUE. To set it,
 + * execute:
 + * ```
 + * export OMP_NESTED=TRUE
 + * ```
 + */
 +BOOST_AUTO_TEST_CASE(ParallelSingleQuery)
 +{
 +  // kNN and LSH parameters.
 +  const int k = 4;
 +  const int numTables = 16;
 +  const int numProj = 3;
 +
 +  // Read iris training data as reference and query set.
 +  const string trainSet = "iris_train.csv";
 +  arma::mat rdata;
 +  data::Load(trainSet, rdata, true);
 +
 +  arma::mat qdata = rdata.col(0); // Only 1 query.
 +
 +  // Where to store neighbors and distances.
 +  arma::Mat<size_t> sequentialNeighbors;
 +  arma::Mat<size_t> parallelNeighbors;
 +  arma::mat distances;
 +
 +  // Construct an LSH object. By default, maximum number of threads are used.
 +  LSHSearch<> lshTest(rdata, numProj, numTables);
 +  lshTest.Search(qdata, k, parallelNeighbors, distances);
 +
 +  // Now perform same search but with 1 thread.
 +  lshTest.MaxThreads(1);
 +  lshTest.Search(qdata, k, sequentialNeighbors, distances);
 +
 +  // Require both have same results.
 +  double recall = LSHSearch<>::ComputeRecall(sequentialNeighbors, parallelNeighbors);
 +  BOOST_REQUIRE_EQUAL(recall, 1);
 +}
 +
  BOOST_AUTO_TEST_CASE(LSHTrainTest)
  {
    // This is a not very good test that simply checks that the re-trained LSH




More information about the mlpack-git mailing list