[mlpack-svn] r10187 - mlpack/trunk/src/contrib/pram/max_ip_search
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Nov 8 13:07:59 EST 2011
Author: pram
Date: 2011-11-08 13:07:58 -0500 (Tue, 08 Nov 2011)
New Revision: 10187
Added:
mlpack/trunk/src/contrib/pram/max_ip_search/new_pairwise_dists.cc
Log:
New pairwise distance computer
Added: mlpack/trunk/src/contrib/pram/max_ip_search/new_pairwise_dists.cc
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/new_pairwise_dists.cc (rev 0)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/new_pairwise_dists.cc 2011-11-08 18:07:58 UTC (rev 10187)
@@ -0,0 +1,97 @@
+#include <armadillo>
+#include <string>
+#include <vector>
+
+#include <mlpack/core.h>
+#include <mlpack/core/kernels/lmetric.hpp>
+
+using namespace mlpack;
+using namespace std;
+
+// TODO: add params for large-scale or small-scale computation.
+PROGRAM_INFO("Distance Computer", "This program computes the "
+    "complete distance list for the given queries and "
+    "references.", "");
+
+PARAM_STRING_REQ("r", "The reference set", "");
+PARAM_STRING_REQ("q", "The set of queries", "");
+PARAM_STRING_REQ("dist_file", "The file where the distance "
+    "list will be written.", "");
+PARAM_INT_REQ("num_q", "The number of queries to be used for "
+    "this computation.", "");
+
+// Writes the Euclidean distance from each of the first num_q queries to every
+// reference point into dist_file, one value per line, grouped by query. Returns
+// 1 if an input fails to load, num_q is out of range, or dist_file cannot open.
+int main (int argc, char *argv[]) {
+  CLI::ParseCommandLine(argc, argv);
+  arma::mat rdata, qdata;
+  const string rfile = CLI::GetParam<string>("r");
+  const string qfile = CLI::GetParam<string>("q");
+
+  Log::Warn << "Loading files..." << endl;
+  if (rdata.load(rfile.c_str()) == false) {
+    Log::Fatal << "Reference file "<< rfile << " not found." << endl;
+    return 1; // in case Log::Fatal does not terminate the program
+  }
+  if (qdata.load(qfile.c_str()) == false) {
+    Log::Fatal << "Query file " << qfile << " not found." << endl;
+    return 1; // in case Log::Fatal does not terminate the program
+  }
+
+  // The loops below take one point per column, so transpose what was loaded.
+  rdata = arma::trans(rdata);
+  qdata = arma::trans(qdata);
+
+  Log::Warn << "File loaded..." << endl;
+  Log::Warn << "R(" << rdata.n_rows << ", " << rdata.n_cols
+            << "), Q(" << qdata.n_rows << ", " << qdata.n_cols
+            << ")" << endl;
+
+  // Check num_q while it is still signed: a negative value would wrap to a huge
+  // size_t, and a value above the query count would index past qdata's columns.
+  const int num_q_param = CLI::GetParam<int>("num_q");
+  if (num_q_param < 0 || (size_t) num_q_param > qdata.n_cols) {
+    Log::Fatal << "num_q (" << num_q_param << ") must be between 0 and "
+               << qdata.n_cols << "." << endl;
+    return 1; // in case Log::Fatal does not terminate the program
+  }
+  const size_t num_q = (size_t) num_q_param;
+
+  const string dist_file = CLI::GetParam<string>("dist_file");
+  FILE *pfile = fopen(dist_file.c_str(), "w");
+  if (pfile == NULL) {
+    Log::Fatal << "Cannot open " << dist_file << " for writing." << endl;
+    return 1; // never hand a null stream to fprintf/fclose
+  }
+
+  mlpack::kernel::SquaredEuclideanDistance dist_kernel
+      = mlpack::kernel::SquaredEuclideanDistance();
+
+  const double perc_done = 10.0; // progress is reported in 10% steps
+  double done_sky = 1.0;         // number of the next step to report
+
+  for (size_t i = 0; i < num_q; i++) {
+    arma::vec q = qdata.unsafe_col(i);
+    for (size_t j = 0; j < rdata.n_cols; j++) {
+      arma::vec r = rdata.unsafe_col(j);
+      // Square root of the squared-distance kernel value, one per line.
+      fprintf(pfile, "%lg\n", sqrt(dist_kernel.Evaluate(q, r)));
+    }
+
+    // Whole-number percentage (integer division) derived from the query index.
+    const double pdone = i * 100 / num_q;
+    if (pdone >= done_sky * perc_done) {
+      // After the first report, back up over the old "NN%" so the '=' bar grows.
+      printf(done_sky > 1 ? "\b\b\b=%zu%%" : "=%zu%%", (size_t) pdone);
+      fflush(NULL);
+      done_sky++;
+    }
+  } // query-loop
+
+  // done_sky never exceeds 10 inside the loop, so the 100% report always prints.
+  printf(done_sky > 1 ? "\b\b\b=%zu%%\n" : "=%zu%%\n", (size_t) 100);
+  fflush(NULL);
+  fclose(pfile);
+  Log::Info << "Distances computed!" << endl;
+} // end main
More information about the mlpack-svn
mailing list