[mlpack-git] master: Add Train() method. (b4c2948)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Nov 20 17:33:26 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/962a37fe8374913c435054aa50e12d912bdfa01c...a7d8231fe7526dcfaadae0bf37d67b50d286e45d

>---------------------------------------------------------------

commit b4c2948c7ad9dd3c1da368e7ffc005ead8422400
Author: Ryan Curtin <ryan at ratml.org>
Date:   Fri Nov 13 08:22:05 2015 +0000

    Add Train() method.


>---------------------------------------------------------------

b4c2948c7ad9dd3c1da368e7ffc005ead8422400
 src/mlpack/methods/lsh/lsh_search.hpp      | 23 ++++++++++++----
 src/mlpack/methods/lsh/lsh_search_impl.hpp | 44 ++++++++++++++++++++++++------
 src/mlpack/tests/lsh_test.cpp              | 24 ++++++++++++++++
 3 files changed, 76 insertions(+), 15 deletions(-)

diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index e07559c..2832b87 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -77,6 +77,17 @@ class LSHSearch
   ~LSHSearch();
 
   /**
+   * Train the LSH model on the given dataset, replacing the current parameters
+   * and building new hash tables; hashWidth == 0.0 selects a heuristic width.
+   */
+  void Train(const arma::mat& referenceSet,
+             const size_t numProj,
+             const size_t numTables,
+             const double hashWidth = 0.0,
+             const size_t secondHashSize = 99901,
+             const size_t bucketSize = 500);
+
+  /**
    * Compute the nearest neighbors of the points in the given query set and
    * store the output in the given matrices.  The matrices will be set to the
    * size of n columns by k rows, where n is the number of points in the query
@@ -225,9 +236,9 @@ class LSHSearch
   bool ownsSet;
 
   //! The number of projections.
-  const size_t numProj;
+  size_t numProj;
   //! The number of hash tables.
-  const size_t numTables;
+  size_t numTables;
 
   //! The std::vector containing the projection matrix of each table.
   std::vector<arma::mat> projections; // should be [numProj x dims] x numTables
@@ -239,13 +250,13 @@ class LSHSearch
   double hashWidth;
 
   //! The big prime representing the size of the second hash.
-  const size_t secondHashSize;
+  size_t secondHashSize;
 
   //! The weights of the second hash.
   arma::vec secondHashWeights;
 
   //! The bucket size of the second hash.
-  const size_t bucketSize;
+  size_t bucketSize;
 
   //! The final hash table; should be (< secondHashSize) x bucketSize.
   arma::Mat<size_t> secondHashTable;
@@ -262,8 +273,8 @@ class LSHSearch
   size_t distanceEvaluations;
 }; // class LSHSearch
 
-}; // namespace neighbor
-}; // namespace mlpack
+} // namespace neighbor
+} // namespace mlpack
 
 // Include implementation.
 #include "lsh_search_impl.hpp"
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index 2463751..41de73b 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -21,7 +21,7 @@ LSHSearch(const arma::mat& referenceSet,
           const double hashWidthIn,
           const size_t secondHashSize,
           const size_t bucketSize) :
-  referenceSet(&referenceSet),
+  referenceSet(NULL), // This will be set in Train().
   ownsSet(false),
   numProj(numProj),
   numTables(numTables),
@@ -30,6 +30,40 @@ LSHSearch(const arma::mat& referenceSet,
   bucketSize(bucketSize),
   distanceEvaluations(0)
 {
+  // Pass work to training function.
+  Train(referenceSet, numProj, numTables, hashWidthIn, secondHashSize,
+      bucketSize);
+}
+
+// Destructor: releases the reference set only when this object owns it.
+template<typename SortPolicy>
+LSHSearch<SortPolicy>::~LSHSearch()
+{
+  if (ownsSet)
+    delete referenceSet;
+}
+
+// Train on a new reference set.
+template<typename SortPolicy>
+void LSHSearch<SortPolicy>::Train(const arma::mat& referenceSet,
+                                  const size_t numProj,
+                                  const size_t numTables,
+                                  const double hashWidthIn,
+                                  const size_t secondHashSize,
+                                  const size_t bucketSize)
+{
+  // Set new reference set.
+  if (this->referenceSet && ownsSet)
+    delete this->referenceSet;
+  this->referenceSet = &referenceSet;
+
+  // Set new parameters.
+  this->numProj = numProj;
+  this->numTables = numTables;
+  this->hashWidth = hashWidthIn;
+  this->secondHashSize = secondHashSize;
+  this->bucketSize = bucketSize;
+
   if (hashWidth == 0.0) // The user has not provided any value.
   {
     // Compute a heuristic hash width from the data.
@@ -50,14 +84,6 @@ LSHSearch(const arma::mat& referenceSet,
   BuildHash();
 }
 
-// Destructor.
-template<typename SortPolicy>
-LSHSearch<SortPolicy>::~LSHSearch()
-{
-  if (ownsSet)
-    delete referenceSet;
-}
-
 template<typename SortPolicy>
 void LSHSearch<SortPolicy>::InsertNeighbor(arma::mat& distances,
                                            arma::Mat<size_t>& neighbors,
diff --git a/src/mlpack/tests/lsh_test.cpp b/src/mlpack/tests/lsh_test.cpp
index 341da69..feb7ae2 100644
--- a/src/mlpack/tests/lsh_test.cpp
+++ b/src/mlpack/tests/lsh_test.cpp
@@ -108,4 +108,28 @@ BOOST_AUTO_TEST_CASE(LSHSearchTest)
   }
 }
 
+BOOST_AUTO_TEST_CASE(LSHTrainTest)
+{
+  // This is not a very thorough test: it only checks that the re-trained LSH
+  // model operates on the correct dimensionality and returns the correct number
+  // of results.
+  arma::mat referenceData = arma::randu<arma::mat>(3, 100);
+  arma::mat newReferenceData = arma::randu<arma::mat>(10, 400);
+  arma::mat queryData = arma::randu<arma::mat>(10, 200);
+
+  LSHSearch<> lsh(referenceData, 3, 2, 2.0, 11, 3);
+
+  lsh.Train(newReferenceData, 4, 3, 3.0, 12, 4);
+
+  arma::Mat<size_t> neighbors;
+  arma::mat distances;
+
+  lsh.Search(queryData, 3, neighbors, distances);
+
+  BOOST_REQUIRE_EQUAL(neighbors.n_cols, 200);
+  BOOST_REQUIRE_EQUAL(neighbors.n_rows, 3);
+  BOOST_REQUIRE_EQUAL(distances.n_cols, 200);
+  BOOST_REQUIRE_EQUAL(distances.n_rows, 3);
+}
+
 BOOST_AUTO_TEST_SUITE_END();



More information about the mlpack-git mailing list