[mlpack-svn] r14143 - mlpack/trunk/src/mlpack/methods/lsh

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Sat Jan 19 00:45:16 EST 2013


Author: pram
Date: 2013-01-19 00:45:15 -0500 (Sat, 19 Jan 2013)
New Revision: 14143

Modified:
   mlpack/trunk/src/mlpack/methods/lsh/lsh_main.cpp
   mlpack/trunk/src/mlpack/methods/lsh/lsh_search.hpp
   mlpack/trunk/src/mlpack/methods/lsh/lsh_search_impl.hpp
Log:
Removed the MetricType template parameter from the LSHSearch class; metric::SquaredEuclideanDistance is now used throughout the class.

Modified: mlpack/trunk/src/mlpack/methods/lsh/lsh_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lsh/lsh_main.cpp	2013-01-19 02:51:50 UTC (rev 14142)
+++ mlpack/trunk/src/mlpack/methods/lsh/lsh_main.cpp	2013-01-19 05:45:15 UTC (rev 14143)
@@ -101,8 +101,8 @@
   }
 
   // Pick up the LSH-specific parameters.
-  const size_t numProj = CLI::GetParam<int>("num_projections");
-  const size_t numTables = CLI::GetParam<int>("num_tables");
+  const size_t numProj = CLI::GetParam<int>("projections");
+  const size_t numTables = CLI::GetParam<int>("tables");
   const double hashWidth = CLI::GetParam<double>("hash_width");
 
   arma::Mat<size_t> neighbors;

Modified: mlpack/trunk/src/mlpack/methods/lsh/lsh_search.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lsh/lsh_search.hpp	2013-01-19 02:51:50 UTC (rev 14142)
+++ mlpack/trunk/src/mlpack/methods/lsh/lsh_search.hpp	2013-01-19 05:45:15 UTC (rev 14143)
@@ -39,10 +39,8 @@
  * of the given queries.
  *
  * @tparam SortPolicy The sort policy for distances; see NearestNeighborSort.
- * @tparam MetricType The metric to use for computation.
  */
-template<typename SortPolicy = NearestNeighborSort,
-         typename MetricType = mlpack::metric::SquaredEuclideanDistance>
+template<typename SortPolicy = NearestNeighborSort>
 class LSHSearch
 {
  public:
@@ -66,7 +64,6 @@
    * @param bucketSize The size of the bucket in the second hash table. This is
    *     the maximum number of points that can be hashed into single bucket.
    *     Default values are already provided here.
-   * @param metric An optional instance of the MetricType class.
    */
   LSHSearch(const arma::mat& referenceSet,
             const arma::mat& querySet,
@@ -74,8 +71,7 @@
             const size_t numTables,
             const double hashWidth = 0.0,
             const size_t secondHashSize = 99901,
-            const size_t bucketSize = 500,
-            const MetricType metric = MetricType());
+            const size_t bucketSize = 500);
 
   /**
    * This function initializes the LSH class. It builds the hash on the
@@ -96,15 +92,13 @@
    * @param bucketSize The size of the bucket in the second hash table. This is
    *     the maximum number of points that can be hashed into single bucket.
    *     Default values are already provided here.
-   * @param metric An optional instance of the MetricType class.
    */
   LSHSearch(const arma::mat& referenceSet,
             const size_t numProj,
             const size_t numTables,
             const double hashWidth = 0.0,
             const size_t secondHashSize = 99901,
-            const size_t bucketSize = 500,
-            const MetricType metric = MetricType());
+            const size_t bucketSize = 500);
 
   /**
    * Compute the nearest neighbors and store the output in the given matrices.
@@ -216,7 +210,7 @@
   const size_t bucketSize;
 
   //! Instantiation of the metric.
-  MetricType metric;
+  metric::SquaredEuclideanDistance metric;
 
   //! The final hash table; should be (< secondHashSize) x bucketSize.
   arma::Mat<size_t> secondHashTable;

Modified: mlpack/trunk/src/mlpack/methods/lsh/lsh_search_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/lsh/lsh_search_impl.hpp	2013-01-19 02:51:50 UTC (rev 14142)
+++ mlpack/trunk/src/mlpack/methods/lsh/lsh_search_impl.hpp	2013-01-19 05:45:15 UTC (rev 14143)
@@ -13,24 +13,22 @@
 namespace neighbor {
 
 // Construct the object.
-template<typename SortPolicy, typename MetricType>
-LSHSearch<SortPolicy, MetricType>::
+template<typename SortPolicy>
+LSHSearch<SortPolicy>::
 LSHSearch(const arma::mat& referenceSet,
           const arma::mat& querySet,
           const size_t numProj,
           const size_t numTables,
           const double hashWidthIn,
           const size_t secondHashSize,
-          const size_t bucketSize,
-          const MetricType metric) :
+          const size_t bucketSize) :
   referenceSet(referenceSet),
   querySet(querySet),
   numProj(numProj),
   numTables(numTables),
   hashWidth(hashWidthIn),
   secondHashSize(secondHashSize),
-  bucketSize(bucketSize),
-  metric(metric)
+  bucketSize(bucketSize)
 {
   if (hashWidth == 0.0) // The user has not provided any value.
   {
@@ -40,33 +38,33 @@
       size_t p1 = (size_t) math::RandInt(referenceSet.n_cols);
       size_t p2 = (size_t) math::RandInt(referenceSet.n_cols);
 
-      hashWidth += MetricType::Evaluate(referenceSet.unsafe_col(p1),
-                                        referenceSet.unsafe_col(p2));
+      hashWidth += std::sqrt(metric.Evaluate(referenceSet.unsafe_col(p1),
+                                             referenceSet.unsafe_col(p2)));
     }
 
     hashWidth /= 25;
   }
 
+  Log::Info << "Hash width chosen as: " << hashWidth << std::endl;
+
   BuildHash();
 }
 
-template<typename SortPolicy, typename MetricType>
-LSHSearch<SortPolicy, MetricType>::
+template<typename SortPolicy>
+LSHSearch<SortPolicy>::
 LSHSearch(const arma::mat& referenceSet,
           const size_t numProj,
           const size_t numTables,
           const double hashWidthIn,
           const size_t secondHashSize,
-          const size_t bucketSize,
-          const MetricType metric) :
+          const size_t bucketSize) :
   referenceSet(referenceSet),
   querySet(referenceSet),
   numProj(numProj),
   numTables(numTables),
   hashWidth(hashWidthIn),
   secondHashSize(secondHashSize),
-  bucketSize(bucketSize),
-  metric(metric)
+  bucketSize(bucketSize)
 {
   if (hashWidth == 0.0) // The user has not provided any value.
   {
@@ -76,18 +74,20 @@
       size_t p1 = (size_t) math::RandInt(referenceSet.n_cols);
       size_t p2 = (size_t) math::RandInt(referenceSet.n_cols);
 
-      hashWidth += MetricType::Evaluate(referenceSet.unsafe_col(p1),
-                                        referenceSet.unsafe_col(p2));
+      hashWidth += std::sqrt(metric.Evaluate(referenceSet.unsafe_col(p1),
+                                             referenceSet.unsafe_col(p2)));
     }
 
     hashWidth /= 25;
   }
 
+  Log::Info << "Hash width chosen as: " << hashWidth << std::endl;
+
   BuildHash();
 }
 
-template<typename SortPolicy, typename MetricType>
-void LSHSearch<SortPolicy, MetricType>::
+template<typename SortPolicy>
+void LSHSearch<SortPolicy>::
 InsertNeighbor(const size_t queryIndex,
                const size_t pos,
                const size_t neighbor,
@@ -110,9 +110,9 @@
   (*neighborPtr)(pos, queryIndex) = neighbor;
 }
 
-template<typename SortPolicy, typename MetricType>
+template<typename SortPolicy>
 inline force_inline
-double LSHSearch<SortPolicy, MetricType>::
+double LSHSearch<SortPolicy>::
 BaseCase(const size_t queryIndex, const size_t referenceIndex)
 {
   // If the datasets are the same, then this search is only using one dataset
@@ -135,8 +135,8 @@
   return distance;
 }
 
-template<typename SortPolicy, typename MetricType>
-void LSHSearch<SortPolicy, MetricType>::
+template<typename SortPolicy>
+void LSHSearch<SortPolicy>::
 ReturnIndicesFromTable(const size_t queryIndex,
                        arma::uvec& referenceIndices,
                        size_t numTablesToSearch)
@@ -199,8 +199,8 @@
 }
 
 
-template<typename SortPolicy, typename MetricType>
-void LSHSearch<SortPolicy, MetricType>::
+template<typename SortPolicy>
+void LSHSearch<SortPolicy>::
 Search(const size_t k,
        arma::Mat<size_t>& resultingNeighbors,
        arma::mat& distances,
@@ -244,8 +244,8 @@
       std::endl;
 }
 
-template<typename SortPolicy, typename MetricType>
-void LSHSearch<SortPolicy, MetricType>::
+template<typename SortPolicy>
+void LSHSearch<SortPolicy>::
 BuildHash()
 {
   // The first level hash for a single table outputs a 'numProj'-dimensional




More information about the mlpack-svn mailing list