[mlpack-svn] r10187 - mlpack/trunk/src/contrib/pram/max_ip_search

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Nov 8 13:07:59 EST 2011


Author: pram
Date: 2011-11-08 13:07:58 -0500 (Tue, 08 Nov 2011)
New Revision: 10187

Added:
   mlpack/trunk/src/contrib/pram/max_ip_search/new_pairwise_dists.cc
Log:
New pairwise distance computer

Added: mlpack/trunk/src/contrib/pram/max_ip_search/new_pairwise_dists.cc
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/new_pairwise_dists.cc	                        (rev 0)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/new_pairwise_dists.cc	2011-11-08 18:07:58 UTC (rev 10187)
@@ -0,0 +1,97 @@
+#include <cmath>
+#include <cstdio>
+#include <string>
+#include <vector>
+
+#include <armadillo>
+#include <mlpack/core.h>
+#include <mlpack/core/kernels/lmetric.hpp>
+
+using namespace mlpack;
+using namespace std;
+
+// TODO: add parameters to choose between a large-scale and a small-scale
+// computation.
+PROGRAM_INFO("Distance Computer", "This program computes the "
+             "complete distance list for the given queries and "
+             "references.", "");
+
+// Input data sets.
+PARAM_STRING_REQ("r", "The reference set", "");
+PARAM_STRING_REQ("q", "The set of queries", "");
+// Output: main() writes one distance per line to this file.
+PARAM_STRING_REQ("dist_file", "The file where the computed distances "
+                 "will be written.", "");
+// Only the first num_q queries of the query set are processed.
+PARAM_INT_REQ("num_q", "The number of queries to be used for "
+              "this computation.", "");
+
+// Prints a textual progress marker of the form "=NN%" once pdone reaches the
+// next multiple of perc_done (tracked by done_sky, which starts at 1).  Every
+// marker after the first is preceded by three backspaces so the previous
+// percentage is overwritten; done_sky is advanced after each marker.
+static void ReportProgress(double pdone, double perc_done, double& done_sky) {
+  if (pdone < done_sky * perc_done)
+    return;
+
+  if (done_sky > 1) {
+    printf("\b\b\b=%zu%%", static_cast<size_t>(pdone));
+  } else {
+    printf("=%zu%%", static_cast<size_t>(pdone));
+  }
+  fflush(NULL);
+  done_sky++;
+}
+
+// Loads a reference set and a query set, then writes the Euclidean distance
+// between each of the first num_q queries and every reference point to
+// dist_file, one value per line, iterating references fastest.
+// Returns non-zero on invalid input or I/O failure (in case Log::Fatal does
+// not already terminate the program).
+int main (int argc, char *argv[]) {
+
+  CLI::ParseCommandLine(argc, argv);
+
+  arma::mat rdata, qdata;
+  string rfile = CLI::GetParam<string>("r");
+  string qfile = CLI::GetParam<string>("q");
+
+  Log::Warn << "Loading files..." << endl;
+  if (rdata.load(rfile.c_str()) == false) {
+    Log::Fatal << "Reference file "<< rfile << " not found." << endl;
+    return 1;
+  }
+
+  if (qdata.load(qfile.c_str()) == false) {
+    Log::Fatal << "Query file " << qfile << " not found." << endl;
+    return 1;
+  }
+
+  // Transpose so that each point is a column; the loops below address
+  // points with unsafe_col().
+  rdata = arma::trans(rdata);
+  qdata = arma::trans(qdata);
+
+  Log::Warn << "File loaded..." << endl;
+
+  Log::Warn << "R(" << rdata.n_rows << ", " << rdata.n_cols
+            << "), Q(" << qdata.n_rows << ", " << qdata.n_cols
+            << ")" << endl;
+
+  // Queries and references must have the same dimensionality, otherwise the
+  // distance kernel would compare vectors of different lengths.
+  if (qdata.n_rows != rdata.n_rows) {
+    Log::Fatal << "Query dimension (" << qdata.n_rows << ") does not match "
+               << "reference dimension (" << rdata.n_rows << ")." << endl;
+    return 1;
+  }
+
+  // Validate num_q before using it as a column index: a negative value would
+  // wrap around when converted to size_t, and a value above the number of
+  // queries would make unsafe_col() read past the end of qdata.
+  const int num_q_param = CLI::GetParam<int>("num_q");
+  if (num_q_param < 0 || static_cast<size_t>(num_q_param) > qdata.n_cols) {
+    Log::Fatal << "num_q must be between 0 and " << qdata.n_cols
+               << "; got " << num_q_param << "." << endl;
+    return 1;
+  }
+  const size_t num_q = static_cast<size_t>(num_q_param);
+
+  string dist_file = CLI::GetParam<string>("dist_file");
+  FILE *pfile = fopen(dist_file.c_str(), "w");
+  if (pfile == NULL) {
+    Log::Fatal << "Could not open " << dist_file << " for writing." << endl;
+    return 1;
+  }
+
+  mlpack::kernel::SquaredEuclideanDistance dist_kernel
+    = mlpack::kernel::SquaredEuclideanDistance();
+
+  // A progress marker is printed every perc_done percent of the queries.
+  const double perc_done = 10.0;
+  double done_sky = 1.0;
+
+  for (size_t i = 0; i < num_q; i++) {
+
+    arma::vec q = qdata.unsafe_col(i);
+
+    for (size_t j = 0; j < rdata.n_cols; j++) {
+
+      arma::vec r = rdata.unsafe_col(j);
+      // The kernel yields the squared distance; take the root.
+      fprintf(pfile, "%lg\n", sqrt(dist_kernel.Evaluate(q, r)));
+    }
+
+    // Integer division on purpose: the percentage is truncated.
+    ReportProgress(i * 100 / num_q, perc_done, done_sky);
+  } // query-loop
+
+  ReportProgress(100, perc_done, done_sky);
+  printf("\n"); fflush(NULL);
+
+  // Capture the stream error state before fclose() releases the handle, so
+  // a failed write is not silently reported as success.
+  const bool write_failed = (ferror(pfile) != 0);
+  if (fclose(pfile) != 0 || write_failed) {
+    Log::Fatal << "Failed to write distances to " << dist_file << "." << endl;
+    return 1;
+  }
+
+  Log::Info << "Distances computed!" << endl;
+} // end main




More information about the mlpack-svn mailing list