[mlpack-git] master: Add some warnings, and make some minor fixes. (70871c6)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Nov 20 17:33:36 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/962a37fe8374913c435054aa50e12d912bdfa01c...a7d8231fe7526dcfaadae0bf37d67b50d286e45d

>---------------------------------------------------------------

commit 70871c6ceb8ce1d6a80e6566f83ebbfe1e6d5671
Author: Ryan Curtin <ryan at ratml.org>
Date:   Fri Nov 20 21:38:25 2015 +0000

    Add some warnings, and make some minor fixes.


>---------------------------------------------------------------

70871c6ceb8ce1d6a80e6566f83ebbfe1e6d5671
 src/mlpack/methods/lsh/lsh_search_impl.hpp | 27 ++++++++++++++++++++-------
 1 file changed, 20 insertions(+), 7 deletions(-)

diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index a174de6..eeeb504 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -252,9 +252,22 @@ void LSHSearch<SortPolicy>::Search(const arma::mat& querySet,
 {
   // Ensure the dimensionality of the query set is correct.
   if (querySet.n_rows != referenceSet->n_rows)
-    Log::Fatal << "LSHSearch::Search(): dimensionality of query set ("
+  {
+    std::ostringstream oss;
+    oss << "LSHSearch::Search(): dimensionality of query set ("
         << querySet.n_rows << ") is not equal to the dimensionality the model "
         << "was trained on (" << referenceSet->n_rows << ")!" << std::endl;
+    throw std::invalid_argument(oss.str());
+  }
+
+  if (k > referenceSet->n_cols)
+  {
+    std::ostringstream oss;
+    oss << "LSHSearch::Search(): requested " << k << " approximate nearest "
+        << "neighbors, but reference set has " << referenceSet->n_cols
+        << " points!" << std::endl;
+    throw std::invalid_argument(oss.str());
+  }
 
   // Set the size of the neighbor and distance matrices.
   resultingNeighbors.set_size(k, querySet.n_cols);
@@ -262,6 +275,10 @@ void LSHSearch<SortPolicy>::Search(const arma::mat& querySet,
   distances.fill(SortPolicy::WorstDistance());
   resultingNeighbors.fill(referenceSet->n_cols);
 
+  // If the user asked for 0 nearest neighbors... uh... we're done.
+  if (k == 0)
+    return;
+
   size_t avgIndicesReturned = 0;
 
   Timer::Start("computing_neighbors");
@@ -307,10 +324,6 @@ Search(const size_t k,
   distances.fill(SortPolicy::WorstDistance());
   resultingNeighbors.fill(referenceSet->n_cols);
 
-  // If the user asked for 0 nearest neighbors... uh... we're done.
-  if (k == 0)
-    return;
-
   size_t avgIndicesReturned = 0;
 
   Timer::Start("computing_neighbors");
@@ -398,6 +411,7 @@ void LSHSearch<SortPolicy>::BuildHash()
 
   // Step III: Create each hash table in the first level hash one by one and
   // putting them directly into the 'secondHashTable' for memory efficiency.
+  projections.clear(); // Reset projections vector.
   for (size_t i = 0; i < numTables; i++)
   {
     // Step IV: Obtain the 'numProj' projections for each table.
@@ -429,8 +443,7 @@ void LSHSearch<SortPolicy>::BuildHash()
 
     // Step VI: Putting the points in the 'secondHashTable' by hashing the key.
     // Now we hash every key, point ID to its corresponding bucket.
-    arma::rowvec secondHashVec = secondHashWeights.t()
-      * arma::floor(hashMat);
+    arma::rowvec secondHashVec = secondHashWeights.t() * arma::floor(hashMat);
 
     // This gives us the bucket for the corresponding point ID.
     for (size_t j = 0; j < secondHashVec.n_elem; j++)



More information about the mlpack-git mailing list