[mlpack-svn] r10052 - mlpack/trunk/src/contrib/pram/max_ip_search

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Oct 27 09:48:31 EDT 2011


Author: pram
Date: 2011-10-27 09:48:31 -0400 (Thu, 27 Oct 2011)
New Revision: 10052

Modified:
   mlpack/trunk/src/contrib/pram/max_ip_search/approx_ip_main.cc
   mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.cc
   mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.h
   mlpack/trunk/src/contrib/pram/max_ip_search/approx_tester.cc
   mlpack/trunk/src/contrib/pram/max_ip_search/check_nn_utils.h
Log:
Oct 25th, Rank-approximate fixed for very high k-values

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/approx_ip_main.cc
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/approx_ip_main.cc	2011-10-27 00:35:58 UTC (rev 10051)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/approx_ip_main.cc	2011-10-27 13:48:31 UTC (rev 10052)
@@ -12,7 +12,7 @@
 PROGRAM_INFO("Approx-maximum Inner Product", "This program "
   	     "returns the approx-maximum inner product for a "
   	     "given query over a set of points (references).", 
- 	     "");
+ 	     "approx_maxip");
 
 PARAM_STRING_REQ("r", "The reference set", "");
 PARAM_STRING_REQ("q", "The set of queries", "");
@@ -111,8 +111,15 @@
 	      << rdata.n_cols  << " / " << (float) approx_comp << " = "
 	      <<(float) (rdata.n_cols / approx_comp) << endl;
 
+
     if (CLI::GetParam<string>("rank_file") != "") {
       string rank_file = CLI::GetParam<string>("rank_file");
+      double epsilon = CLI::GetParam<double>("approx_maxip/epsilon");
+      double alpha = CLI::GetParam<double>("approx_maxip/alpha");
+
+      check_nn_utils::check_rank_bound(rank_file, rdata.n_cols,
+				       epsilon, alpha, &apc);
+
       check_nn_utils::compute_error(rank_file,  rdata.n_cols, &apc);
     } else {
       check_nn_utils::compute_error(&rdata, &qdata, &apc);

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.cc
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.cc	2011-10-27 00:35:58 UTC (rev 10051)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.cc	2011-10-27 13:48:31 UTC (rev 10052)
@@ -298,12 +298,16 @@
     do {
       n = (size_t) (beta * (double)(set_size)) +1;
       (*samples)(--set_size) = n;
-    } while (set_size > rank_approx);
+//     } while (set_size > rank_approx);
+    } while (set_size > 0);
     
-    while (set_size > 0) {
-      set_size--;
-      (*samples)(set_size) = std::min(k, set_size + 1);
-    }
+
+    // Maybe we do not do this, this is throwing things off
+    // Let's try removing this
+//     while (set_size > 0) {
+//       set_size--;
+//       (*samples)(set_size) = std::min(k, set_size + 1);
+//     }
   } else {
     while (set_size > 0) {
       set_size--;
@@ -321,6 +325,10 @@
    
   // Check that the pointers are not NULL
   assert(reference_node != NULL);
+  if (!CLI::HasParam("approx_maxip/no_tree"))
+    assert(reference_node->is_leaf() 
+ 	   || is_base(reference_node)
+	   || is_almost_satisfied());
 
   // Obtain the number of samples to be obtained
   size_t set_size
@@ -332,9 +340,14 @@
 //     = min_samples_per_q_ - query_node->stat().samples();
 
   sample_size = std::min(sample_size, query_samples_needed_);
-  assert(sample_size <= sample_limit_);
 
+  if (!CLI::HasParam("approx_maxip/no_tree")) {
+//     printf("Leaf size: %zu, Sample size: %zu\n", 
+// 	   set_size, sample_size); fflush(NULL);
+    assert(sample_size <= sample_limit_);
+  }
 
+
   // Get the query point from the matrix
   arma::vec q = queries_.unsafe_col(query_); 
 
@@ -582,11 +595,44 @@
 }
 
 
+void ApproxMaxIP::CheckPrune(TreeType* ref_node) {
+  // Diagnostic: scan ref_node exhaustively and report points the prune missed.
+  size_t missed_nns = 0;
+  double max_ip = 0.0;
+  // NOTE(review): max_ip starts at 0.0, so ActualQRBound prints 0 if no inner product is positive.
+  // Get the query point from the matrix
+  arma::vec q = queries_.unsafe_col(query_);
+  // Threshold: the value stored in row knns_ - 1 of max_ips_ for this query.
+  double min_ip = max_ips_(knns_ -1, query_); 
+
+  // Compare every reference point in [begin(), end()) against the threshold
+  for (size_t reference_index = ref_node->begin(); 
+       reference_index < ref_node->end(); reference_index++) {
+    // Column of references_ holding this reference point.
+    arma::vec r = references_.unsafe_col(reference_index);
+    // A point strictly above the threshold counts as a missed candidate.
+    double ip_r = arma::dot(q, r);
+    if (ip_r > min_ip)
+      missed_nns++;
+    // Track the largest inner product actually present in the node.
+    if (ip_r > max_ip)
+      max_ip = ip_r;
+    
+  } // for reference_index
+  // Print only when at least one point exceeded the threshold.
+  if (missed_nns > 0) 
+    printf("Prune %zu - Missed candidates: %zu\n"
+	   "QLBound: %lg, QRBound: %lg, ActualQRBound: %lg\n",
+	   number_of_prunes_, missed_nns,
+	   min_ip, MaxNodeIP_(ref_node), max_ip);
+
+}
+
 void ApproxMaxIP::ComputeApproxRecursion_(TreeType* reference_node, 
 					  double upper_bound_ip) {
 
   assert(reference_node != NULL);
-  //assert(upper_bound_ip == MaxNodeIP_(reference_node));
+  //  assert(upper_bound_ip == MaxNodeIP_(reference_node));
 
 
   // check if the query has enough number of samples
@@ -594,16 +640,23 @@
     if (upper_bound_ip < max_ips_(knns_ -1, query_)) { 
       // Pruned by distance
       number_of_prunes_++;
+
+      if (CLI::HasParam("approx_maxip/check_prune"))
+	CheckPrune(reference_node);
+
       query_samples_needed_ 
 	-= sample_sizes_[reference_node->end()
 			 - reference_node->begin() - 1];
 
     } else if (reference_node->is_leaf()) {
-      // base case for the single tree case
+	// base case for the single tree case
       ComputeBaseCase_(reference_node);
       query_samples_needed_
 	-= (reference_node->end() - reference_node->begin());
 
+      // trying to see if this was the issue (DIDN'T WORK)
+      // ComputeApproxBaseCase_(reference_node);
+
     } else if (is_base(reference_node)) {
       // base case for the approximate case
       ComputeApproxBaseCase_(reference_node);
@@ -611,25 +664,26 @@
     } else if (is_almost_satisfied()) {
       // base case for the approximate case
       ComputeApproxBaseCase_(reference_node);
-
-    } else {
+    }  else {
       // Recurse on both as above
       double left_ip = MaxNodeIP_(reference_node->left());
       double right_ip = MaxNodeIP_(reference_node->right());
-
+      
       if (left_ip > right_ip) {
 	ComputeApproxRecursion_(reference_node->left(), 
-				   left_ip);
+				left_ip);
 	ComputeApproxRecursion_(reference_node->right(),
-				   right_ip);
+				right_ip);
       } else {
 	ComputeApproxRecursion_(reference_node->right(),
-				   right_ip);
+				right_ip);
 	ComputeApproxRecursion_(reference_node->left(), 
-				   left_ip);
+				left_ip);
       }
     }    
-  }  
+  } else {
+//     assert(query_samples_needed_ <= 0);
+  }
 } // ComputeApproxRecursion_
 
 
@@ -704,7 +758,7 @@
 	query_node->right()->stat().add_total_points(extra_points_encountered);
 	size_t extra_points_sampled
 	  = query_node->stat().samples()
-	  - std::max(query_node->left()->stat().samples(),
+	  - std::min(query_node->left()->stat().samples(),
 		     query_node->right()->stat().samples());
 	assert(extra_points_sampled >= 0);
 	query_node->left()->stat().add_samples(extra_points_sampled);
@@ -728,6 +782,11 @@
       query_node->stat().set_total_points(
 	  query_node->left()->stat().total_points());
 
+//       printf("%zu: L:%zu, R:%zu\n", query_node->stat().samples(),
+// 	     query_node->left()->stat().samples(),
+// 	     query_node->right()->stat().samples()); fflush(NULL);
+
+
       assert(query_node->stat().samples() <= 
 	     std::min(query_node->left()->stat().samples(),
 		      query_node->right()->stat().samples()));
@@ -786,7 +845,7 @@
       query_node->right()->stat().add_total_points(extra_points_encountered);
       size_t extra_points_sampled
 	= query_node->stat().samples()
-	- std::max(query_node->left()->stat().samples(),
+	- std::min(query_node->left()->stat().samples(),
 		   query_node->right()->stat().samples());
       assert(extra_points_sampled >= 0);
       query_node->left()->stat().add_samples(extra_points_sampled);
@@ -1073,8 +1132,15 @@
     for (query_ = 0; query_ < queries_.n_cols; ++query_) {
       query_samples_needed_ = min_samples_per_q_;
 
-      ComputeApproxRecursion_(reference_tree_, 
-			      MaxNodeIP_(reference_tree_));
+      if (CLI::HasParam("approx_maxip/no_tree")) {
+	// ComputeApproxBaseCase_(reference_tree_);
+	ComputeApproxBaseCase_(reference_tree_->left());
+	ComputeApproxBaseCase_(reference_tree_->right());
+      } else 
+	ComputeApproxRecursion_(reference_tree_, 
+				MaxNodeIP_(reference_tree_));
+
+//       assert(!(query_samples_needed_ > 0));
     }
     CLI::StopTimer("approx_maxip/fast_single");
 

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.h	2011-10-27 00:35:58 UTC (rev 10051)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.h	2011-10-27 13:48:31 UTC (rev 10052)
@@ -45,6 +45,9 @@
 
 PARAM_FLAG("check_prune", "The flag to trigger the "
 	   "checking of the prune.", "approx_maxip");
+
+PARAM_FLAG("no_tree", "The flag to trigger the tree-less "
+	   "rank-approximate search.", "approx_maxip");
 /**
  * Performs maximum-inner-product-search. 
  * This class will build the trees and 
@@ -251,6 +254,9 @@
   void ComputeBaseCase_(CTreeType* query_node,
 			TreeType* reference_node);
 
+
+
+
   /**
    * The recursive function for the approximate computation
    */
@@ -347,6 +353,7 @@
 		       arma::mat* ips);
 
   void CheckPrune(CTreeType* query_node, TreeType* ref_node);
+  void CheckPrune(TreeType* reference_node);
 }; //class ApproxMaxIP
 
 #endif

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/approx_tester.cc
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/approx_tester.cc	2011-10-27 00:35:58 UTC (rev 10051)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/approx_tester.cc	2011-10-27 13:48:31 UTC (rev 10052)
@@ -64,7 +64,7 @@
     max_k = ks(number_of_ks -1);
   }
 
-  arma::vec eps(25);
+  arma::vec eps(100);
   // arma::vec als(10);
   size_t num_eps = 0; //, num_als = 0;
 
@@ -78,12 +78,20 @@
 
   free(pch);
 
-  Log::Info << number_of_ks << " values for k," << endl
+  Log::Warn << number_of_ks << " values for k," << endl
 	    << num_eps << " values for epsilon." << endl;
 
 
   Log::Warn << "Starting loop for Fast Approx-Search." << endl;
 
+  // If you did multiple repetitions, put the loop here 
+  // have the following:
+  // size_t total_reps = CLI::GetParam<int>("reps");
+  // arma::mat res = arma::zeros<arma::mat>(number_of_ks * num_eps, 5);
+  // for (size_t reps = 0; reps < total_reps; reps++) {
+
+
+
   ApproxMaxIP fast_approx;
   vector< arma::Mat<size_t>* > all_solutions;
   vector<double> all_th_speedups;
@@ -147,14 +155,77 @@
   assert(median_ranks.size() == number_of_ks * num_eps);
   assert(avg_precisions.size() == all_th_speedups.size());
 
-  printf("k\te\tp\tmr\tth_sp\n");
-  for (size_t i = 0; i < number_of_ks; i++) {
-    for (size_t j = 0; j < num_eps; j++) {
 
-      printf("%zu\t%lg\t%lg\t%zu\t%lg\n", ks(i), eps(j),
-	     avg_precisions[i * num_eps + j],
-	     median_ranks[i * num_eps + j],
-	     all_th_speedups[i * num_eps + j]); fflush(NULL);
+
+  // If we are performing reps, we have to make sure 
+  // that the check_nn_utils::compute_error() is called
+  // only once (especially for the Yahoo data set)
+
+
+//   for (size_t i = 0; i < number_of_ks; i++) {
+//     for (size_t j = 0; j < num_eps; j++) {
+
+//       res(i * num_eps + j, 0) += (double)  ks(i);
+//       res(i * num_eps + j, 1) += eps(j);
+// 	 res(i * num_eps + j, 2) += avg_precisions[i * num_eps + j];
+// 	 res(i * num_eps + j, 3) += (double) median_ranks[i * num_eps + j];
+// 	 res(i * num_eps + j, 4) += all_th_speedups[i * num_eps + j]);
+//     }
+//   }
+
+
+// } // reps-loop
+//   if (CLI::GetParam<string>("res_file") != "") {
+  
+//     string res_file = GetParam<string>("res_file");
+//     FILE *res_fp = fopen(res_file.c_str(), "w");
+//     for (size_t i = 0; i < res.n_rows; i++) {
+//       for (size_t j = 0; j < res.n_cols; j++) {
+// 	    fprintf(res_fp, "%lg", res(i, j) / (double) total_reps);
+//          if (j == res.n_cols -1)
+//            fprintf(res_fp, "\n");
+//          else
+//            fprintf(res_fp, ",");
+//       }
+//     }
+//   } else {
+
+//     printf("k\te\tp\tmr\tth_sp\n");
+//     for (size_t i = 0; i < res.n_rows; i++) {
+//       for (size_t j = 0; j < res.n_cols; j++) {
+// 	    printf("%lg", res(i, j) / (double) total_reps);
+//          if (j == res.n_cols -1)
+//            printf("\n");
+//          else
+//            printf(",");
+//       }
+//     }
+//   }
+
+
+
+  if (CLI::GetParam<string>("res_file") != "") {
+  
+    string res_file = CLI::GetParam<string>("res_file");
+    FILE *res_fp = fopen(res_file.c_str(), "w");
+    for (size_t i = 0; i < number_of_ks; i++)
+      for (size_t j = 0; j < num_eps; j++)
+	fprintf(res_fp, "%zu,%lg,%lg,%zu,%lg\n", ks(i), eps(j),
+		avg_precisions[i * num_eps + j],
+		median_ranks[i * num_eps + j],
+		all_th_speedups[i * num_eps + j]);
+  
+  } else {
+
+    printf("k\te\tp\tmr\tth_sp\n");
+    for (size_t i = 0; i < number_of_ks; i++) {
+      for (size_t j = 0; j < num_eps; j++) {
+
+	printf("%zu\t%lg\t%lg\t%zu\t%lg\n", ks(i), eps(j),
+	       avg_precisions[i * num_eps + j],
+	       median_ranks[i * num_eps + j],
+	       all_th_speedups[i * num_eps + j]); fflush(NULL);
+      }
     }
   }
 }  // end main
@@ -177,11 +248,13 @@
 	     "");
 // PARAM_STRING("alphas", "The comma-separated list of alphas", "");
 
-// PARAM_INT("reps", "The number of times the rank-approximate"
-//  	  " algorithm is to be repeated for the same setting.",
-//  	  "", 1);
+PARAM_INT("reps", "The number of times the rank-approximate"
+	  " algorithm is to be repeated for the same setting.",
+  	  "", 1);
 
 PARAM_INT("max_k", "The max value of knns to be tried.", "", 1);
 
 PARAM_STRING("rank_file", "The file containing the ranks.",
 	     "", "");
+PARAM_STRING("res_file", "The file where the results are to be written.",
+	     "", "");

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/check_nn_utils.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/check_nn_utils.h	2011-10-27 00:35:58 UTC (rev 10051)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/check_nn_utils.h	2011-10-27 13:48:31 UTC (rev 10052)
@@ -63,6 +63,8 @@
 
     all_ranks_list.set_size(qdata->n_cols * indices->n_rows);
 
+    size_t total_considerable_candidates
+      = indices->n_cols * indices->n_rows;
 
     double perc_done = 10.0;
     double done_sky = 1.0;
@@ -93,7 +95,9 @@
 	    all_correct++;
 	} else {
 	  // if no result found, just penalize the worst rank
-	  all_ranks_list( i * k + j ) = rdata->n_cols;
+//	  all_ranks_list( i * k + j ) = rdata->n_cols;
+	  all_ranks_list( i * k + j ) = -1;
+	  total_considerable_candidates--;
 	}
 
       }
@@ -130,12 +134,27 @@
     Log::Info << "Errors Computed!" << endl;
 
 
+//     double avg_precision = (double) all_correct 
+//       / (double) (k * qdata->n_cols);
+
+//     size_t median_rank = arma::median(all_ranks_list);
+
     double avg_precision = (double) all_correct 
-      / (double) (k * qdata->n_cols);
+      / (double) (total_considerable_candidates);
 
-    size_t median_rank = arma::median(all_ranks_list);
+    arma::uvec tcc = find(all_ranks_list != -1);
 
+    assert(tcc.n_elem == total_considerable_candidates);
 
+    arma::Col<size_t> hemmed_all_ranks_list;
+    hemmed_all_ranks_list.set_size(tcc.n_elem);
+
+    for (size_t i = 0; i < tcc.n_elem; i++) 
+      hemmed_all_ranks_list(i) = all_ranks_list(tcc(i));
+
+    size_t median_rank = arma::median(hemmed_all_ranks_list);
+
+
     Log::Warn << "Avg. Precision@" << k << ": " 
 	      << avg_precision << endl 
 	      << "Median Rank@" << k << ": "
@@ -155,6 +174,9 @@
     all_ranks_list.set_size(indices->n_cols * indices->n_rows);
 
 
+    size_t total_considerable_candidates
+      = indices->n_cols * indices->n_rows;
+
     double perc_done = 10.0;
     double done_sky = 1.0;
         
@@ -190,13 +212,14 @@
 
 
 	if ((*indices)(j, i) != (size_t) -1) {
-	  // assert((*indices)(j, i) != -1);
 	  size_t rank = rank_ind((*indices)(j, i)); 
 	  all_ranks_list( i * k + j ) = rank;
 	  if (rank < k + 1)
 	    all_correct++;
 	} else {
-	  all_ranks_list( i * k + j ) = rdata_size;
+// 	  all_ranks_list( i * k + j ) = rdata_size;
+	  all_ranks_list( i * k + j ) = -1;
+	  total_considerable_candidates--;
 	}	
       }
 
@@ -230,16 +253,33 @@
     Log::Info << "Errors Computed!" << endl;
 
 
+//     double avg_precision = (double) all_correct 
+//       / (double) (k * indices->n_cols);
+
+//     size_t median_rank = arma::median(all_ranks_list);
+
     double avg_precision = (double) all_correct 
-      / (double) (k * indices->n_cols);
+      / (double) (total_considerable_candidates);
 
-    size_t median_rank = arma::median(all_ranks_list);
+    arma::uvec tcc = find(all_ranks_list != -1);
 
+    assert(tcc.n_elem == total_considerable_candidates);
 
+    arma::Col<size_t> hemmed_all_ranks_list;
+    hemmed_all_ranks_list.set_size(tcc.n_elem);
+
+    for (size_t i = 0; i < tcc.n_elem; i++) 
+      hemmed_all_ranks_list(i) = all_ranks_list(tcc(i));
+
+    size_t median_rank = arma::median(hemmed_all_ranks_list);
+
+
     Log::Warn << "Avg. Precision@" << k << ": " 
 	      << avg_precision << endl 
 	      << "Median Rank@" << k << ": "
-	      << median_rank << endl;
+	      << median_rank << endl
+	      << "TCC: " << total_considerable_candidates
+	      << endl;
 
   }
 
@@ -255,17 +295,22 @@
 
     vector< arma::Col<size_t>* > all_ranks_lists;
     vector<size_t> all_corrects;
+    vector<size_t> total_considerable_candidates;
 
     // set up the list first
     for (size_t i = 0; i < solutions.size(); i++) {
 
       arma::Col<size_t>* all_ranks_list = new arma::Col<size_t>();
-      size_t all_correct = 0;
+      //size_t all_correct = 0;
 
       all_ranks_list->set_size(solutions[i]->n_cols * solutions[i]->n_rows);
       all_ranks_lists.push_back(all_ranks_list);
 
-      all_corrects.push_back(all_correct);
+      //all_corrects.push_back(all_correct);
+      all_corrects.push_back(0);
+
+      total_considerable_candidates.push_back(solutions[i]->n_cols 
+					      * solutions[i]->n_rows);
     }
 
 
@@ -327,7 +372,9 @@
 
 	  } else {
 	    // if no result found, just penalize the worst rank
-	    (*all_ranks_lists[ind])( i * k + j ) = rdata_size;
+// 	    (*all_ranks_lists[ind])( i * k + j ) = rdata_size;
+	    (*all_ranks_lists[ind])( i * k + j ) = -1;
+	    total_considerable_candidates[ind]--;
 	  }
 
 	} // top k neighbors
@@ -363,10 +410,25 @@
 
     for (size_t ind = 0; ind < num_solutions; ind++) {
 
+//       precisions->push_back((double) all_corrects[ind]
+// 			    / (double) all_ranks_lists[ind]->n_elem);
+//       median_ranks->push_back(arma::median(*all_ranks_lists[ind]));
       precisions->push_back((double) all_corrects[ind]
-			    / (double) all_ranks_lists[ind]->n_elem);
-      median_ranks->push_back(arma::median(*all_ranks_lists[ind]));
+			    / (double) total_considerable_candidates[ind]);
+
+      arma::uvec tcc = find((*all_ranks_lists[ind]) != -1);
+
+      assert(tcc.n_elem == total_considerable_candidates[ind]);
+
+      arma::Col<size_t> hemmed_all_ranks_list;
+      hemmed_all_ranks_list.set_size(tcc.n_elem);
+
+      for (size_t i = 0; i < tcc.n_elem; i++) 
+	hemmed_all_ranks_list(i) = (*all_ranks_lists[ind])(tcc(i));
+
+      median_ranks->push_back(arma::median(hemmed_all_ranks_list));
     
+      delete(all_ranks_lists[ind]);
     }
 
     return;
@@ -381,17 +443,20 @@
 
     vector< arma::Col<size_t>* > all_ranks_lists;
     vector<size_t> all_corrects;
+    vector<size_t> total_considerable_candidates;
 
     // set up the list first
     for (size_t i = 0; i < solutions.size(); i++) {
 
       arma::Col<size_t>* all_ranks_list = new arma::Col<size_t>();
-      size_t all_correct = 0;
+      //size_t all_correct = 0;
 
       all_ranks_list->set_size(solutions[i]->n_cols * solutions[i]->n_rows);
       all_ranks_lists.push_back(all_ranks_list);
 
-      all_corrects.push_back(all_correct);
+      all_corrects.push_back(0);
+      total_considerable_candidates.push_back(solutions[i]->n_cols
+					      * solutions[i]->n_rows);
     }
 
 
@@ -438,7 +503,9 @@
 	    if (rank < k + 1)
 	      all_corrects[ind]++;
 	  } else {
-	    (*all_ranks_lists[ind])( i * k + j ) = rdata->n_cols;
+// 	    (*all_ranks_lists[ind])( i * k + j ) = rdata->n_cols;
+	    (*all_ranks_lists[ind])( i * k + j ) = -1;
+	    total_considerable_candidates[ind]--;
 	  }
 	} // top k neighbors
       } // all solutions for this query
@@ -475,13 +542,133 @@
 
     for (size_t ind = 0; ind < num_solutions; ind++) {
 
+//       precisions->push_back((double) all_corrects[ind]
+// 			    / (double) all_ranks_lists[ind]->n_elem);
+//       median_ranks->push_back(arma::median(*all_ranks_lists[ind]));
+
       precisions->push_back((double) all_corrects[ind]
-			    / (double) all_ranks_lists[ind]->n_elem);
-      median_ranks->push_back(arma::median(*all_ranks_lists[ind]));
+			    / (double) total_considerable_candidates[ind]);
+
+      arma::uvec tcc = find((*all_ranks_lists[ind]) != -1);
+
+      assert(tcc.n_elem == total_considerable_candidates[ind]);
+
+      arma::Col<size_t> hemmed_all_ranks_list;
+      hemmed_all_ranks_list.set_size(tcc.n_elem);
+
+      for (size_t i = 0; i < tcc.n_elem; i++) 
+	hemmed_all_ranks_list(i) = (*all_ranks_lists[ind])(tcc(i));
+
+      median_ranks->push_back(arma::median(hemmed_all_ranks_list));
+
       delete(all_ranks_lists[ind]);
     }
 
     return;
   }
 
+  /**
+   * Checks the rank-approximation guarantee on the last returned
+   * candidate of every query and logs the fraction that meets it.
+   * Logs a warning and returns early if rank_file cannot be opened
+   * or holds fewer lines than there are queries.
+   *
+   * @param rank_file One line per query (column of indices), parsed as
+   *     rdata_size ','-separated integers indexed by indices' entries.
+   * @param rdata_size Number of values expected on every line.
+   * @param epsilon Allowed rank error as a percentage of rdata_size.
+   * @param alpha Only echoed in the log next to the measured rate.
+   * @param indices One column per query; (size_t) -1 in the last row
+   *     makes the query count as not considerable.
+   */
+  void check_rank_bound(string rank_file, size_t rdata_size,
+			double epsilon, double alpha,
+			arma::Mat<size_t>* indices) {
+    size_t all_correct = 0;
+    size_t k = indices->n_rows;
+    size_t rank_error = (size_t) ( epsilon * (double) rdata_size / 100.0);
+    // Counts the queries that have a candidate in the last row.
+    size_t total_considerable_candidates = indices->n_cols;
+    double perc_done = 10.0;
+    double done_sky = 1.0;
+
+    FILE *rank_fp = fopen(rank_file.c_str(), "r");
+    if (rank_fp == NULL) {
+      // Nothing can be verified without the ranks; bail out rather than
+      // compare against an unfilled buffer and call fclose(NULL).
+      Log::Warn << "Could not open rank file: " << rank_file << endl;
+      return;
+    }
+
+    // One buffer serves all queries; each parsed line rewrites it.
+    arma::Col<size_t> rank_ind;
+    rank_ind.set_size(rdata_size);
+
+    for (size_t i = 0; i < indices->n_cols; i++) {
+      char *line = NULL;
+      size_t len = 0;
+      if (getline(&line, &len, rank_fp) == -1) {
+        free(line);
+        fclose(rank_fp);
+        printf("\n"); fflush(NULL);
+        Log::Warn << "Rank file ended before query " << i << endl;
+        return;
+      }
+
+      size_t rank_index = 0;
+      for (char *pch = strtok(line, ",\n"); pch != NULL;
+           pch = strtok(NULL, ",\n")) {
+        // Never write past the buffer, but keep counting so that the
+        // assert below still catches over-long lines.
+        if (rank_index < rdata_size)
+          rank_ind(rank_index) = atoi(pch);
+        rank_index++;
+      }
+      free(line);
+      assert(rank_index == rdata_size);
+
+      // Only the candidate in the last row (k - 1) is tested.
+      if ((*indices)(k-1, i) != (size_t) -1) {
+        size_t rank = rank_ind((*indices)(k -1, i));
+        if (rank < rank_error +1)
+          all_correct++;
+      } else {
+        total_considerable_candidates--;
+      }
+
+      double pdone = i * 100 / indices->n_cols;
+      if (pdone >= done_sky * perc_done) {
+        if (done_sky > 1) {
+          printf("\b\b\b=%zu%%", (size_t) pdone); fflush(NULL);
+        } else {
+          printf("=%zu%%", (size_t) pdone); fflush(NULL);
+        }
+        done_sky++;
+      }
+    } // query-loop
+
+    fclose(rank_fp);
+    // Finish the progress display at 100%.
+    double pdone = 100;
+    if (pdone >= done_sky * perc_done) {
+      if (done_sky > 1) {
+        printf("\b\b\b=%zu%%", (size_t) pdone); fflush(NULL);
+      } else {
+        printf("=%zu%%", (size_t) pdone); fflush(NULL);
+      }
+    }
+    printf("\n");fflush(NULL);
+
+    Log::Info << "Errors Computed!" << endl;
+    // NOTE(review): the denominator includes queries without a candidate.
+    double avg_precision = (double) all_correct
+      / (double) (indices->n_cols);
+
+    Log::Warn << "Actual Alpha @" << k << ": "
+	      << avg_precision << endl
+	      << "TCC: " << total_considerable_candidates
+	      << endl << "Alpha: " << alpha << endl
+	      << "Rank Error: " << rank_error << endl;
+  }
+
 }; // end check_nn_utils




More information about the mlpack-svn mailing list