[mlpack-git] master: Add some warnings, and make some minor fixes. (70871c6)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Fri Nov 20 17:33:36 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/962a37fe8374913c435054aa50e12d912bdfa01c...a7d8231fe7526dcfaadae0bf37d67b50d286e45d
>---------------------------------------------------------------
commit 70871c6ceb8ce1d6a80e6566f83ebbfe1e6d5671
Author: Ryan Curtin <ryan at ratml.org>
Date: Fri Nov 20 21:38:25 2015 +0000
Add some warnings, and make some minor fixes.
>---------------------------------------------------------------
70871c6ceb8ce1d6a80e6566f83ebbfe1e6d5671
src/mlpack/methods/lsh/lsh_search_impl.hpp | 27 ++++++++++++++++++++-------
1 file changed, 20 insertions(+), 7 deletions(-)
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index a174de6..eeeb504 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -252,9 +252,22 @@ void LSHSearch<SortPolicy>::Search(const arma::mat& querySet,
{
// Ensure the dimensionality of the query set is correct.
if (querySet.n_rows != referenceSet->n_rows)
- Log::Fatal << "LSHSearch::Search(): dimensionality of query set ("
+ {
+ std::ostringstream oss;
+ oss << "LSHSearch::Search(): dimensionality of query set ("
<< querySet.n_rows << ") is not equal to the dimensionality the model "
<< "was trained on (" << referenceSet->n_rows << ")!" << std::endl;
+ throw std::invalid_argument(oss.str());
+ }
+
+ if (k > referenceSet->n_cols)
+ {
+ std::ostringstream oss;
+ oss << "LSHSearch::Search(): requested " << k << " approximate nearest "
+ << "neighbors, but reference set has " << referenceSet->n_cols
+ << " points!" << std::endl;
+ throw std::invalid_argument(oss.str());
+ }
// Set the size of the neighbor and distance matrices.
resultingNeighbors.set_size(k, querySet.n_cols);
@@ -262,6 +275,10 @@ void LSHSearch<SortPolicy>::Search(const arma::mat& querySet,
distances.fill(SortPolicy::WorstDistance());
resultingNeighbors.fill(referenceSet->n_cols);
+ // If the user asked for 0 nearest neighbors... uh... we're done.
+ if (k == 0)
+ return;
+
size_t avgIndicesReturned = 0;
Timer::Start("computing_neighbors");
@@ -307,10 +324,6 @@ Search(const size_t k,
distances.fill(SortPolicy::WorstDistance());
resultingNeighbors.fill(referenceSet->n_cols);
- // If the user asked for 0 nearest neighbors... uh... we're done.
- if (k == 0)
- return;
-
size_t avgIndicesReturned = 0;
Timer::Start("computing_neighbors");
@@ -398,6 +411,7 @@ void LSHSearch<SortPolicy>::BuildHash()
// Step III: Create each hash table in the first level hash one by one and
// putting them directly into the 'secondHashTable' for memory efficiency.
+ projections.clear(); // Reset projections vector.
for (size_t i = 0; i < numTables; i++)
{
// Step IV: Obtain the 'numProj' projections for each table.
@@ -429,8 +443,7 @@ void LSHSearch<SortPolicy>::BuildHash()
// Step VI: Putting the points in the 'secondHashTable' by hashing the key.
// Now we hash every key, point ID to its corresponding bucket.
- arma::rowvec secondHashVec = secondHashWeights.t()
- * arma::floor(hashMat);
+ arma::rowvec secondHashVec = secondHashWeights.t() * arma::floor(hashMat);
// This gives us the bucket for the corresponding point ID.
for (size_t j = 0; j < secondHashVec.n_elem; j++)
More information about the mlpack-git mailing list