[mlpack-git] master: Add Train() method. (b4c2948)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Fri Nov 20 17:33:26 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/962a37fe8374913c435054aa50e12d912bdfa01c...a7d8231fe7526dcfaadae0bf37d67b50d286e45d
>---------------------------------------------------------------
commit b4c2948c7ad9dd3c1da368e7ffc005ead8422400
Author: Ryan Curtin <ryan at ratml.org>
Date: Fri Nov 13 08:22:05 2015 +0000
Add Train() method.
>---------------------------------------------------------------
b4c2948c7ad9dd3c1da368e7ffc005ead8422400
src/mlpack/methods/lsh/lsh_search.hpp | 23 ++++++++++++----
src/mlpack/methods/lsh/lsh_search_impl.hpp | 44 ++++++++++++++++++++++++------
src/mlpack/tests/lsh_test.cpp | 24 ++++++++++++++++
3 files changed, 76 insertions(+), 15 deletions(-)
diff --git a/src/mlpack/methods/lsh/lsh_search.hpp b/src/mlpack/methods/lsh/lsh_search.hpp
index e07559c..2832b87 100644
--- a/src/mlpack/methods/lsh/lsh_search.hpp
+++ b/src/mlpack/methods/lsh/lsh_search.hpp
@@ -77,6 +77,17 @@ class LSHSearch
~LSHSearch();
/**
+ * Train the LSH model on the given dataset, building new hash tables. The
+ * dataset is referenced, not copied, so it must outlive this model.
+ */
+ void Train(const arma::mat& referenceSet,
+ const size_t numProj,
+ const size_t numTables,
+ const double hashWidth = 0.0,
+ const size_t secondHashSize = 99901,
+ const size_t bucketSize = 500);
+
+ /**
* Compute the nearest neighbors of the points in the given query set and
* store the output in the given matrices. The matrices will be set to the
* size of n columns by k rows, where n is the number of points in the query
@@ -225,9 +236,9 @@ class LSHSearch
bool ownsSet;
//! The number of projections.
- const size_t numProj;
+ size_t numProj;
//! The number of hash tables.
- const size_t numTables;
+ size_t numTables;
//! The std::vector containing the projection matrix of each table.
std::vector<arma::mat> projections; // should be [numProj x dims] x numTables
@@ -239,13 +250,13 @@ class LSHSearch
double hashWidth;
//! The big prime representing the size of the second hash.
- const size_t secondHashSize;
+ size_t secondHashSize;
//! The weights of the second hash.
arma::vec secondHashWeights;
//! The bucket size of the second hash.
- const size_t bucketSize;
+ size_t bucketSize;
//! The final hash table; should be (< secondHashSize) x bucketSize.
arma::Mat<size_t> secondHashTable;
@@ -262,8 +273,8 @@ class LSHSearch
size_t distanceEvaluations;
}; // class LSHSearch
-}; // namespace neighbor
-}; // namespace mlpack
+} // namespace neighbor
+} // namespace mlpack
// Include implementation.
#include "lsh_search_impl.hpp"
diff --git a/src/mlpack/methods/lsh/lsh_search_impl.hpp b/src/mlpack/methods/lsh/lsh_search_impl.hpp
index 2463751..41de73b 100644
--- a/src/mlpack/methods/lsh/lsh_search_impl.hpp
+++ b/src/mlpack/methods/lsh/lsh_search_impl.hpp
@@ -21,7 +21,7 @@ LSHSearch(const arma::mat& referenceSet,
const double hashWidthIn,
const size_t secondHashSize,
const size_t bucketSize) :
- referenceSet(&referenceSet),
+ referenceSet(NULL), // This will be set in Train().
ownsSet(false),
numProj(numProj),
numTables(numTables),
@@ -30,6 +30,40 @@ LSHSearch(const arma::mat& referenceSet,
bucketSize(bucketSize),
distanceEvaluations(0)
{
+ // Pass work to training function.
+ Train(referenceSet, numProj, numTables, hashWidthIn, secondHashSize,
+ bucketSize);
+}
+
+// Destructor: free the reference set only when this object owns it.
+template<typename SortPolicy>
+LSHSearch<SortPolicy>::~LSHSearch()
+{
+ if (ownsSet)
+ delete referenceSet;
+}
+
+// Train on a new reference set (borrowed, not copied) with new parameters.
+template<typename SortPolicy>
+void LSHSearch<SortPolicy>::Train(const arma::mat& referenceSet,
+ const size_t numProj,
+ const size_t numTables,
+ const double hashWidthIn,
+ const size_t secondHashSize,
+ const size_t bucketSize)
+{
+ // Set new reference set, releasing the old one if we owned it.
+ if (this->referenceSet && ownsSet)
+ delete this->referenceSet;
+ this->referenceSet = &referenceSet;
+ this->ownsSet = false; // Borrowed set: the destructor must not free it.
+ // Set new parameters.
+ this->numProj = numProj;
+ this->numTables = numTables;
+ this->hashWidth = hashWidthIn;
+ this->secondHashSize = secondHashSize;
+ this->bucketSize = bucketSize;
+
if (hashWidth == 0.0) // The user has not provided any value.
{
// Compute a heuristic hash width from the data.
@@ -50,14 +84,6 @@ LSHSearch(const arma::mat& referenceSet,
BuildHash();
}
-// Destructor.
-template<typename SortPolicy>
-LSHSearch<SortPolicy>::~LSHSearch()
-{
- if (ownsSet)
- delete referenceSet;
-}
-
template<typename SortPolicy>
void LSHSearch<SortPolicy>::InsertNeighbor(arma::mat& distances,
arma::Mat<size_t>& neighbors,
diff --git a/src/mlpack/tests/lsh_test.cpp b/src/mlpack/tests/lsh_test.cpp
index 341da69..feb7ae2 100644
--- a/src/mlpack/tests/lsh_test.cpp
+++ b/src/mlpack/tests/lsh_test.cpp
@@ -108,4 +108,28 @@ BOOST_AUTO_TEST_CASE(LSHSearchTest)
}
}
+BOOST_AUTO_TEST_CASE(LSHTrainTest)
+{
+ // This is not a very good test: it simply checks that the re-trained LSH
+ // model operates on the correct dimensionality and returns the correct number
+ // of results.
+ arma::mat referenceData = arma::randu<arma::mat>(3, 100);
+ arma::mat newReferenceData = arma::randu<arma::mat>(10, 400);
+ arma::mat queryData = arma::randu<arma::mat>(10, 200);
+
+ LSHSearch<> lsh(referenceData, 3, 2, 2.0, 11, 3);
+
+ lsh.Train(newReferenceData, 4, 3, 3.0, 12, 4);
+
+ arma::Mat<size_t> neighbors;
+ arma::mat distances;
+
+ lsh.Search(queryData, 3, neighbors, distances);
+
+ BOOST_REQUIRE_EQUAL(neighbors.n_cols, 200);
+ BOOST_REQUIRE_EQUAL(neighbors.n_rows, 3);
+ BOOST_REQUIRE_EQUAL(distances.n_cols, 200);
+ BOOST_REQUIRE_EQUAL(distances.n_rows, 3);
+}
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git
mailing list