[mlpack-svn] r14143 - mlpack/trunk/src/mlpack/methods/lsh
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Sat Jan 19 00:45:16 EST 2013
Author: pram
Date: 2013-01-19 00:45:15 -0500 (Sat, 19 Jan 2013)
New Revision: 14143
Modified:
mlpack/trunk/src/mlpack/methods/lsh/lsh_main.cpp
mlpack/trunk/src/mlpack/methods/lsh/lsh_search.hpp
mlpack/trunk/src/mlpack/methods/lsh/lsh_search_impl.hpp
Log:
Removed the MetricType template parameter from the LSHSearch class; metric::SquaredEuclideanDistance is now used throughout the class.
Modified: mlpack/trunk/src/mlpack/methods/lsh/lsh_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lsh/lsh_main.cpp 2013-01-19 02:51:50 UTC (rev 14142)
+++ mlpack/trunk/src/mlpack/methods/lsh/lsh_main.cpp 2013-01-19 05:45:15 UTC (rev 14143)
@@ -101,8 +101,8 @@
}
// Pick up the LSH-specific parameters.
- const size_t numProj = CLI::GetParam<int>("num_projections");
- const size_t numTables = CLI::GetParam<int>("num_tables");
+ const size_t numProj = CLI::GetParam<int>("projections");
+ const size_t numTables = CLI::GetParam<int>("tables");
const double hashWidth = CLI::GetParam<double>("hash_width");
arma::Mat<size_t> neighbors;
Modified: mlpack/trunk/src/mlpack/methods/lsh/lsh_search.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lsh/lsh_search.hpp 2013-01-19 02:51:50 UTC (rev 14142)
+++ mlpack/trunk/src/mlpack/methods/lsh/lsh_search.hpp 2013-01-19 05:45:15 UTC (rev 14143)
@@ -39,10 +39,8 @@
* of the given queries.
*
* @tparam SortPolicy The sort policy for distances; see NearestNeighborSort.
- * @tparam MetricType The metric to use for computation.
*/
-template<typename SortPolicy = NearestNeighborSort,
- typename MetricType = mlpack::metric::SquaredEuclideanDistance>
+template<typename SortPolicy = NearestNeighborSort>
class LSHSearch
{
public:
@@ -66,7 +64,6 @@
* @param bucketSize The size of the bucket in the second hash table. This is
* the maximum number of points that can be hashed into single bucket.
* Default values are already provided here.
- * @param metric An optional instance of the MetricType class.
*/
LSHSearch(const arma::mat& referenceSet,
const arma::mat& querySet,
@@ -74,8 +71,7 @@
const size_t numTables,
const double hashWidth = 0.0,
const size_t secondHashSize = 99901,
- const size_t bucketSize = 500,
- const MetricType metric = MetricType());
+ const size_t bucketSize = 500);
/**
* This function initializes the LSH class. It builds the hash on the
@@ -96,15 +92,13 @@
* @param bucketSize The size of the bucket in the second hash table. This is
* the maximum number of points that can be hashed into single bucket.
* Default values are already provided here.
- * @param metric An optional instance of the MetricType class.
*/
LSHSearch(const arma::mat& referenceSet,
const size_t numProj,
const size_t numTables,
const double hashWidth = 0.0,
const size_t secondHashSize = 99901,
- const size_t bucketSize = 500,
- const MetricType metric = MetricType());
+ const size_t bucketSize = 500);
/**
* Compute the nearest neighbors and store the output in the given matrices.
@@ -216,7 +210,7 @@
const size_t bucketSize;
//! Instantiation of the metric.
- MetricType metric;
+ metric::SquaredEuclideanDistance metric;
//! The final hash table; should be (< secondHashSize) x bucketSize.
arma::Mat<size_t> secondHashTable;
Modified: mlpack/trunk/src/mlpack/methods/lsh/lsh_search_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lsh/lsh_search_impl.hpp 2013-01-19 02:51:50 UTC (rev 14142)
+++ mlpack/trunk/src/mlpack/methods/lsh/lsh_search_impl.hpp 2013-01-19 05:45:15 UTC (rev 14143)
@@ -13,24 +13,22 @@
namespace neighbor {
// Construct the object.
-template<typename SortPolicy, typename MetricType>
-LSHSearch<SortPolicy, MetricType>::
+template<typename SortPolicy>
+LSHSearch<SortPolicy>::
LSHSearch(const arma::mat& referenceSet,
const arma::mat& querySet,
const size_t numProj,
const size_t numTables,
const double hashWidthIn,
const size_t secondHashSize,
- const size_t bucketSize,
- const MetricType metric) :
+ const size_t bucketSize) :
referenceSet(referenceSet),
querySet(querySet),
numProj(numProj),
numTables(numTables),
hashWidth(hashWidthIn),
secondHashSize(secondHashSize),
- bucketSize(bucketSize),
- metric(metric)
+ bucketSize(bucketSize)
{
if (hashWidth == 0.0) // The user has not provided any value.
{
@@ -40,33 +38,33 @@
size_t p1 = (size_t) math::RandInt(referenceSet.n_cols);
size_t p2 = (size_t) math::RandInt(referenceSet.n_cols);
- hashWidth += MetricType::Evaluate(referenceSet.unsafe_col(p1),
- referenceSet.unsafe_col(p2));
+ hashWidth += std::sqrt(metric.Evaluate(referenceSet.unsafe_col(p1),
+ referenceSet.unsafe_col(p2)));
}
hashWidth /= 25;
}
+ Log::Info << "Hash width chosen as: " << hashWidth << std::endl;
+
BuildHash();
}
-template<typename SortPolicy, typename MetricType>
-LSHSearch<SortPolicy, MetricType>::
+template<typename SortPolicy>
+LSHSearch<SortPolicy>::
LSHSearch(const arma::mat& referenceSet,
const size_t numProj,
const size_t numTables,
const double hashWidthIn,
const size_t secondHashSize,
- const size_t bucketSize,
- const MetricType metric) :
+ const size_t bucketSize) :
referenceSet(referenceSet),
querySet(referenceSet),
numProj(numProj),
numTables(numTables),
hashWidth(hashWidthIn),
secondHashSize(secondHashSize),
- bucketSize(bucketSize),
- metric(metric)
+ bucketSize(bucketSize)
{
if (hashWidth == 0.0) // The user has not provided any value.
{
@@ -76,18 +74,20 @@
size_t p1 = (size_t) math::RandInt(referenceSet.n_cols);
size_t p2 = (size_t) math::RandInt(referenceSet.n_cols);
- hashWidth += MetricType::Evaluate(referenceSet.unsafe_col(p1),
- referenceSet.unsafe_col(p2));
+ hashWidth += std::sqrt(metric.Evaluate(referenceSet.unsafe_col(p1),
+ referenceSet.unsafe_col(p2)));
}
hashWidth /= 25;
}
+ Log::Info << "Hash width chosen as: " << hashWidth << std::endl;
+
BuildHash();
}
-template<typename SortPolicy, typename MetricType>
-void LSHSearch<SortPolicy, MetricType>::
+template<typename SortPolicy>
+void LSHSearch<SortPolicy>::
InsertNeighbor(const size_t queryIndex,
const size_t pos,
const size_t neighbor,
@@ -110,9 +110,9 @@
(*neighborPtr)(pos, queryIndex) = neighbor;
}
-template<typename SortPolicy, typename MetricType>
+template<typename SortPolicy>
inline force_inline
-double LSHSearch<SortPolicy, MetricType>::
+double LSHSearch<SortPolicy>::
BaseCase(const size_t queryIndex, const size_t referenceIndex)
{
// If the datasets are the same, then this search is only using one dataset
@@ -135,8 +135,8 @@
return distance;
}
-template<typename SortPolicy, typename MetricType>
-void LSHSearch<SortPolicy, MetricType>::
+template<typename SortPolicy>
+void LSHSearch<SortPolicy>::
ReturnIndicesFromTable(const size_t queryIndex,
arma::uvec& referenceIndices,
size_t numTablesToSearch)
@@ -199,8 +199,8 @@
}
-template<typename SortPolicy, typename MetricType>
-void LSHSearch<SortPolicy, MetricType>::
+template<typename SortPolicy>
+void LSHSearch<SortPolicy>::
Search(const size_t k,
arma::Mat<size_t>& resultingNeighbors,
arma::mat& distances,
@@ -244,8 +244,8 @@
std::endl;
}
-template<typename SortPolicy, typename MetricType>
-void LSHSearch<SortPolicy, MetricType>::
+template<typename SortPolicy>
+void LSHSearch<SortPolicy>::
BuildHash()
{
// The first level hash for a single table outputs a 'numProj'-dimensional
More information about the mlpack-svn
mailing list