[mlpack-svn] r10052 - mlpack/trunk/src/contrib/pram/max_ip_search
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Oct 27 09:48:31 EDT 2011
Author: pram
Date: 2011-10-27 09:48:31 -0400 (Thu, 27 Oct 2011)
New Revision: 10052
Modified:
mlpack/trunk/src/contrib/pram/max_ip_search/approx_ip_main.cc
mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.cc
mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.h
mlpack/trunk/src/contrib/pram/max_ip_search/approx_tester.cc
mlpack/trunk/src/contrib/pram/max_ip_search/check_nn_utils.h
Log:
Oct 25th: fixed rank-approximate search for very high values of k
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/approx_ip_main.cc
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/approx_ip_main.cc 2011-10-27 00:35:58 UTC (rev 10051)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/approx_ip_main.cc 2011-10-27 13:48:31 UTC (rev 10052)
@@ -12,7 +12,7 @@
PROGRAM_INFO("Approx-maximum Inner Product", "This program "
"returns the approx-maximum inner product for a "
"given query over a set of points (references).",
- "");
+ "approx_maxip");
PARAM_STRING_REQ("r", "The reference set", "");
PARAM_STRING_REQ("q", "The set of queries", "");
@@ -111,8 +111,15 @@
<< rdata.n_cols << " / " << (float) approx_comp << " = "
<<(float) (rdata.n_cols / approx_comp) << endl;
+
if (CLI::GetParam<string>("rank_file") != "") {
string rank_file = CLI::GetParam<string>("rank_file");
+ double epsilon = CLI::GetParam<double>("approx_maxip/epsilon");
+ double alpha = CLI::GetParam<double>("approx_maxip/alpha");
+
+ check_nn_utils::check_rank_bound(rank_file, rdata.n_cols,
+ epsilon, alpha, &apc);
+
check_nn_utils::compute_error(rank_file, rdata.n_cols, &apc);
} else {
check_nn_utils::compute_error(&rdata, &qdata, &apc);
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.cc
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.cc 2011-10-27 00:35:58 UTC (rev 10051)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.cc 2011-10-27 13:48:31 UTC (rev 10052)
@@ -298,12 +298,16 @@
do {
n = (size_t) (beta * (double)(set_size)) +1;
(*samples)(--set_size) = n;
- } while (set_size > rank_approx);
+// } while (set_size > rank_approx);
+ } while (set_size > 0);
- while (set_size > 0) {
- set_size--;
- (*samples)(set_size) = std::min(k, set_size + 1);
- }
+
+ // Maybe we do not do this, this is throwing things off
+ // Let's try removing this
+// while (set_size > 0) {
+// set_size--;
+// (*samples)(set_size) = std::min(k, set_size + 1);
+// }
} else {
while (set_size > 0) {
set_size--;
@@ -321,6 +325,10 @@
// Check that the pointers are not NULL
assert(reference_node != NULL);
+ if (!CLI::HasParam("approx_maxip/no_tree"))
+ assert(reference_node->is_leaf()
+ || is_base(reference_node)
+ || is_almost_satisfied());
// Obtain the number of samples to be obtained
size_t set_size
@@ -332,9 +340,14 @@
// = min_samples_per_q_ - query_node->stat().samples();
sample_size = std::min(sample_size, query_samples_needed_);
- assert(sample_size <= sample_limit_);
+ if (!CLI::HasParam("approx_maxip/no_tree")) {
+// printf("Leaf size: %zu, Sample size: %zu\n",
+// set_size, sample_size); fflush(NULL);
+ assert(sample_size <= sample_limit_);
+ }
+
// Get the query point from the matrix
arma::vec q = queries_.unsafe_col(query_);
@@ -582,11 +595,44 @@
}
+/**
+ * Debug check for a single-tree prune of `ref_node` for the current query
+ * (query_); called when the approx_maxip/check_prune flag is set.
+ *
+ * Scans every reference point in the node and counts those whose inner
+ * product with the query exceeds the current k-th best inner product
+ * max_ips_(knns_ - 1, query_), i.e. candidates that the prune discarded.
+ * If any were missed, prints the prune number, the miss count, the k-th
+ * best inner product, the node bound MaxNodeIP_(ref_node), and the actual
+ * maximum inner product found in the node.
+ */
+void ApproxMaxIP::CheckPrune(TreeType* ref_node) {
+
+  // Get the query point from the matrix
+  arma::vec q = queries_.unsafe_col(query_);
+
+  // Current k-th best inner product: any point above it was a candidate.
+  const double min_ip = max_ips_(knns_ - 1, query_);
+
+  size_t missed_nns = 0;
+
+  // Actual maximum inner product in the node. It is seeded from the first
+  // point instead of 0.0, which reported 0 for nodes whose inner products
+  // are all negative.
+  double max_ip = 0.0;
+  bool have_max_ip = false;
+
+  for (size_t reference_index = ref_node->begin();
+       reference_index < ref_node->end(); reference_index++) {
+
+    arma::vec r = references_.unsafe_col(reference_index);
+    const double ip_r = arma::dot(q, r);
+
+    if (ip_r > min_ip)
+      missed_nns++;
+
+    if (!have_max_ip || ip_r > max_ip) {
+      max_ip = ip_r;
+      have_max_ip = true;
+    }
+  } // for reference_index
+
+  if (missed_nns > 0)
+    printf("Prune %zu - Missed candidates: %zu\n"
+           "QLBound: %lg, QRBound: %lg, ActualQRBound: %lg\n",
+           number_of_prunes_, missed_nns,
+           min_ip, MaxNodeIP_(ref_node), max_ip);
+
+}
+
void ApproxMaxIP::ComputeApproxRecursion_(TreeType* reference_node,
double upper_bound_ip) {
assert(reference_node != NULL);
- //assert(upper_bound_ip == MaxNodeIP_(reference_node));
+ // assert(upper_bound_ip == MaxNodeIP_(reference_node));
// check if the query has enough number of samples
@@ -594,16 +640,23 @@
if (upper_bound_ip < max_ips_(knns_ -1, query_)) {
// Pruned by distance
number_of_prunes_++;
+
+ if (CLI::HasParam("approx_maxip/check_prune"))
+ CheckPrune(reference_node);
+
query_samples_needed_
-= sample_sizes_[reference_node->end()
- reference_node->begin() - 1];
} else if (reference_node->is_leaf()) {
- // base case for the single tree case
+ // base case for the single tree case
ComputeBaseCase_(reference_node);
query_samples_needed_
-= (reference_node->end() - reference_node->begin());
+ // trying to see if this was the issue (DIDN'T WORK)
+ // ComputeApproxBaseCase_(reference_node);
+
} else if (is_base(reference_node)) {
// base case for the approximate case
ComputeApproxBaseCase_(reference_node);
@@ -611,25 +664,26 @@
} else if (is_almost_satisfied()) {
// base case for the approximate case
ComputeApproxBaseCase_(reference_node);
-
- } else {
+ } else {
// Recurse on both as above
double left_ip = MaxNodeIP_(reference_node->left());
double right_ip = MaxNodeIP_(reference_node->right());
-
+
if (left_ip > right_ip) {
ComputeApproxRecursion_(reference_node->left(),
- left_ip);
+ left_ip);
ComputeApproxRecursion_(reference_node->right(),
- right_ip);
+ right_ip);
} else {
ComputeApproxRecursion_(reference_node->right(),
- right_ip);
+ right_ip);
ComputeApproxRecursion_(reference_node->left(),
- left_ip);
+ left_ip);
}
}
- }
+ } else {
+// assert(query_samples_needed_ <= 0);
+ }
} // ComputeApproxRecursion_
@@ -704,7 +758,7 @@
query_node->right()->stat().add_total_points(extra_points_encountered);
size_t extra_points_sampled
= query_node->stat().samples()
- - std::max(query_node->left()->stat().samples(),
+ - std::min(query_node->left()->stat().samples(),
query_node->right()->stat().samples());
assert(extra_points_sampled >= 0);
query_node->left()->stat().add_samples(extra_points_sampled);
@@ -728,6 +782,11 @@
query_node->stat().set_total_points(
query_node->left()->stat().total_points());
+// printf("%zu: L:%zu, R:%zu\n", query_node->stat().samples(),
+// query_node->left()->stat().samples(),
+// query_node->right()->stat().samples()); fflush(NULL);
+
+
assert(query_node->stat().samples() <=
std::min(query_node->left()->stat().samples(),
query_node->right()->stat().samples()));
@@ -786,7 +845,7 @@
query_node->right()->stat().add_total_points(extra_points_encountered);
size_t extra_points_sampled
= query_node->stat().samples()
- - std::max(query_node->left()->stat().samples(),
+ - std::min(query_node->left()->stat().samples(),
query_node->right()->stat().samples());
assert(extra_points_sampled >= 0);
query_node->left()->stat().add_samples(extra_points_sampled);
@@ -1073,8 +1132,15 @@
for (query_ = 0; query_ < queries_.n_cols; ++query_) {
query_samples_needed_ = min_samples_per_q_;
- ComputeApproxRecursion_(reference_tree_,
- MaxNodeIP_(reference_tree_));
+ if (CLI::HasParam("approx_maxip/no_tree")) {
+ // ComputeApproxBaseCase_(reference_tree_);
+ ComputeApproxBaseCase_(reference_tree_->left());
+ ComputeApproxBaseCase_(reference_tree_->right());
+ } else
+ ComputeApproxRecursion_(reference_tree_,
+ MaxNodeIP_(reference_tree_));
+
+// assert(!(query_samples_needed_ > 0));
}
CLI::StopTimer("approx_maxip/fast_single");
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.h 2011-10-27 00:35:58 UTC (rev 10051)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.h 2011-10-27 13:48:31 UTC (rev 10052)
@@ -45,6 +45,9 @@
PARAM_FLAG("check_prune", "The flag to trigger the "
"checking of the prune.", "approx_maxip");
+
+PARAM_FLAG("no_tree", "The flag to trigger the tree-less "
+ "rank-approximate search.", "approx_maxip");
/**
* Performs maximum-inner-product-search.
* This class will build the trees and
@@ -251,6 +254,9 @@
void ComputeBaseCase_(CTreeType* query_node,
TreeType* reference_node);
+
+
+
/**
* The recursive function for the approximate computation
*/
@@ -347,6 +353,7 @@
arma::mat* ips);
void CheckPrune(CTreeType* query_node, TreeType* ref_node);
+ void CheckPrune(TreeType* reference_node);
}; //class ApproxMaxIP
#endif
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/approx_tester.cc
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/approx_tester.cc 2011-10-27 00:35:58 UTC (rev 10051)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/approx_tester.cc 2011-10-27 13:48:31 UTC (rev 10052)
@@ -64,7 +64,7 @@
max_k = ks(number_of_ks -1);
}
- arma::vec eps(25);
+ arma::vec eps(100);
// arma::vec als(10);
size_t num_eps = 0; //, num_als = 0;
@@ -78,12 +78,20 @@
free(pch);
- Log::Info << number_of_ks << " values for k," << endl
+ Log::Warn << number_of_ks << " values for k," << endl
<< num_eps << " values for epsilon." << endl;
Log::Warn << "Starting loop for Fast Approx-Search." << endl;
+ // If you did multiple repetitions, put the loop here
+ // have the following:
+ // size_t total_reps = CLI::GetParam<int>("reps");
+ // arma::mat res = arma::zeros<arma::mat>(number_of_ks * num_eps, 5);
+ // for (size_t reps = 0; reps < total_reps; reps++) {
+
+
+
ApproxMaxIP fast_approx;
vector< arma::Mat<size_t>* > all_solutions;
vector<double> all_th_speedups;
@@ -147,14 +155,77 @@
assert(median_ranks.size() == number_of_ks * num_eps);
assert(avg_precisions.size() == all_th_speedups.size());
- printf("k\te\tp\tmr\tth_sp\n");
- for (size_t i = 0; i < number_of_ks; i++) {
- for (size_t j = 0; j < num_eps; j++) {
- printf("%zu\t%lg\t%lg\t%zu\t%lg\n", ks(i), eps(j),
- avg_precisions[i * num_eps + j],
- median_ranks[i * num_eps + j],
- all_th_speedups[i * num_eps + j]); fflush(NULL);
+
+ // If we are performing reps, we have to make sure
+ // that the check_nn_utils::compute_error() is called
+ // only once (especially for the Yahoo data set)
+
+
+// for (size_t i = 0; i < number_of_ks; i++) {
+// for (size_t j = 0; j < num_eps; j++) {
+
+// res(i * num_eps + j, 0) += (double) ks(i);
+// res(i * num_eps + j, 1) += eps(j);
+// res(i * num_eps + j, 2) += avg_precisions[i * num_eps + j];
+// res(i * num_eps + j, 3) += (double) median_ranks[i * num_eps + j];
+// res(i * num_eps + j, 4) += all_th_speedups[i * num_eps + j]);
+// }
+// }
+
+
+// } // reps-loop
+// if (CLI::GetParam<string>("res_file") != "") {
+
+// string res_file = GetParam<string>("res_file");
+// FILE *res_fp = fopen(res_file.c_str(), "w");
+// for (size_t i = 0; i < res.n_rows; i++) {
+// for (size_t j = 0; j < res.n_cols; j++) {
+// fprintf(res_fp, "%lg", res(i, j) / (double) total_reps);
+// if (j == res.n_cols -1)
+// fprintf(res_fp, "\n");
+// else
+// fprintf(res_fp, ",");
+// }
+// }
+// } else {
+
+// printf("k\te\tp\tmr\tth_sp\n");
+// for (size_t i = 0; i < res.n_rows; i++) {
+// for (size_t j = 0; j < res.n_cols; j++) {
+// printf("%lg", res(i, j) / (double) total_reps);
+// if (j == res.n_cols -1)
+// printf("\n");
+// else
+// printf(",");
+// }
+// }
+// }
+
+
+
+ if (CLI::GetParam<string>("res_file") != "") {
+
+ string res_file = CLI::GetParam<string>("res_file");
+ FILE *res_fp = fopen(res_file.c_str(), "w");
+ for (size_t i = 0; i < number_of_ks; i++)
+ for (size_t j = 0; j < num_eps; j++)
+ fprintf(res_fp, "%zu,%lg,%lg,%zu,%lg\n", ks(i), eps(j),
+ avg_precisions[i * num_eps + j],
+ median_ranks[i * num_eps + j],
+ all_th_speedups[i * num_eps + j]);
+
+ } else {
+
+ printf("k\te\tp\tmr\tth_sp\n");
+ for (size_t i = 0; i < number_of_ks; i++) {
+ for (size_t j = 0; j < num_eps; j++) {
+
+ printf("%zu\t%lg\t%lg\t%zu\t%lg\n", ks(i), eps(j),
+ avg_precisions[i * num_eps + j],
+ median_ranks[i * num_eps + j],
+ all_th_speedups[i * num_eps + j]); fflush(NULL);
+ }
}
}
} // end main
@@ -177,11 +248,13 @@
"");
// PARAM_STRING("alphas", "The comma-separated list of alphas", "");
-// PARAM_INT("reps", "The number of times the rank-approximate"
-// " algorithm is to be repeated for the same setting.",
-// "", 1);
+PARAM_INT("reps", "The number of times the rank-approximate"
+ " algorithm is to be repeated for the same setting.",
+ "", 1);
PARAM_INT("max_k", "The max value of knns to be tried.", "", 1);
PARAM_STRING("rank_file", "The file containing the ranks.",
"", "");
+PARAM_STRING("res_file", "The file where the results are to be written.",
+ "", "");
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/check_nn_utils.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/check_nn_utils.h 2011-10-27 00:35:58 UTC (rev 10051)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/check_nn_utils.h 2011-10-27 13:48:31 UTC (rev 10052)
@@ -63,6 +63,8 @@
all_ranks_list.set_size(qdata->n_cols * indices->n_rows);
+ size_t total_considerable_candidates
+ = indices->n_cols * indices->n_rows;
double perc_done = 10.0;
double done_sky = 1.0;
@@ -93,7 +95,9 @@
all_correct++;
} else {
// if no result found, just penalize the worst rank
- all_ranks_list( i * k + j ) = rdata->n_cols;
+// all_ranks_list( i * k + j ) = rdata->n_cols;
+ all_ranks_list( i * k + j ) = -1;
+ total_considerable_candidates--;
}
}
@@ -130,12 +134,27 @@
Log::Info << "Errors Computed!" << endl;
+// double avg_precision = (double) all_correct
+// / (double) (k * qdata->n_cols);
+
+// size_t median_rank = arma::median(all_ranks_list);
+
double avg_precision = (double) all_correct
- / (double) (k * qdata->n_cols);
+ / (double) (total_considerable_candidates);
- size_t median_rank = arma::median(all_ranks_list);
+ arma::uvec tcc = find(all_ranks_list != -1);
+ assert(tcc.n_elem == total_considerable_candidates);
+ arma::Col<size_t> hemmed_all_ranks_list;
+ hemmed_all_ranks_list.set_size(tcc.n_elem);
+
+ for (size_t i = 0; i < tcc.n_elem; i++)
+ hemmed_all_ranks_list(i) = all_ranks_list(tcc(i));
+
+ size_t median_rank = arma::median(hemmed_all_ranks_list);
+
+
Log::Warn << "Avg. Precision@" << k << ": "
<< avg_precision << endl
<< "Median Rank@" << k << ": "
@@ -155,6 +174,9 @@
all_ranks_list.set_size(indices->n_cols * indices->n_rows);
+ size_t total_considerable_candidates
+ = indices->n_cols * indices->n_rows;
+
double perc_done = 10.0;
double done_sky = 1.0;
@@ -190,13 +212,14 @@
if ((*indices)(j, i) != (size_t) -1) {
- // assert((*indices)(j, i) != -1);
size_t rank = rank_ind((*indices)(j, i));
all_ranks_list( i * k + j ) = rank;
if (rank < k + 1)
all_correct++;
} else {
- all_ranks_list( i * k + j ) = rdata_size;
+// all_ranks_list( i * k + j ) = rdata_size;
+ all_ranks_list( i * k + j ) = -1;
+ total_considerable_candidates--;
}
}
@@ -230,16 +253,33 @@
Log::Info << "Errors Computed!" << endl;
+// double avg_precision = (double) all_correct
+// / (double) (k * indices->n_cols);
+
+// size_t median_rank = arma::median(all_ranks_list);
+
double avg_precision = (double) all_correct
- / (double) (k * indices->n_cols);
+ / (double) (total_considerable_candidates);
- size_t median_rank = arma::median(all_ranks_list);
+ arma::uvec tcc = find(all_ranks_list != -1);
+ assert(tcc.n_elem == total_considerable_candidates);
+ arma::Col<size_t> hemmed_all_ranks_list;
+ hemmed_all_ranks_list.set_size(tcc.n_elem);
+
+ for (size_t i = 0; i < tcc.n_elem; i++)
+ hemmed_all_ranks_list(i) = all_ranks_list(tcc(i));
+
+ size_t median_rank = arma::median(hemmed_all_ranks_list);
+
+
Log::Warn << "Avg. Precision@" << k << ": "
<< avg_precision << endl
<< "Median Rank@" << k << ": "
- << median_rank << endl;
+ << median_rank << endl
+ << "TCC: " << total_considerable_candidates
+ << endl;
}
@@ -255,17 +295,22 @@
vector< arma::Col<size_t>* > all_ranks_lists;
vector<size_t> all_corrects;
+ vector<size_t> total_considerable_candidates;
// set up the list first
for (size_t i = 0; i < solutions.size(); i++) {
arma::Col<size_t>* all_ranks_list = new arma::Col<size_t>();
- size_t all_correct = 0;
+ //size_t all_correct = 0;
all_ranks_list->set_size(solutions[i]->n_cols * solutions[i]->n_rows);
all_ranks_lists.push_back(all_ranks_list);
- all_corrects.push_back(all_correct);
+ //all_corrects.push_back(all_correct);
+ all_corrects.push_back(0);
+
+ total_considerable_candidates.push_back(solutions[i]->n_cols
+ * solutions[i]->n_rows);
}
@@ -327,7 +372,9 @@
} else {
// if no result found, just penalize the worst rank
- (*all_ranks_lists[ind])( i * k + j ) = rdata_size;
+// (*all_ranks_lists[ind])( i * k + j ) = rdata_size;
+ (*all_ranks_lists[ind])( i * k + j ) = -1;
+ total_considerable_candidates[ind]--;
}
} // top k neighbors
@@ -363,10 +410,25 @@
for (size_t ind = 0; ind < num_solutions; ind++) {
+// precisions->push_back((double) all_corrects[ind]
+// / (double) all_ranks_lists[ind]->n_elem);
+// median_ranks->push_back(arma::median(*all_ranks_lists[ind]));
precisions->push_back((double) all_corrects[ind]
- / (double) all_ranks_lists[ind]->n_elem);
- median_ranks->push_back(arma::median(*all_ranks_lists[ind]));
+ / (double) total_considerable_candidates[ind]);
+
+ arma::uvec tcc = find((*all_ranks_lists[ind]) != -1);
+
+ assert(tcc.n_elem == total_considerable_candidates[ind]);
+
+ arma::Col<size_t> hemmed_all_ranks_list;
+ hemmed_all_ranks_list.set_size(tcc.n_elem);
+
+ for (size_t i = 0; i < tcc.n_elem; i++)
+ hemmed_all_ranks_list(i) = (*all_ranks_lists[ind])(tcc(i));
+
+ median_ranks->push_back(arma::median(hemmed_all_ranks_list));
+ delete(all_ranks_lists[ind]);
}
return;
@@ -381,17 +443,20 @@
vector< arma::Col<size_t>* > all_ranks_lists;
vector<size_t> all_corrects;
+ vector<size_t> total_considerable_candidates;
// set up the list first
for (size_t i = 0; i < solutions.size(); i++) {
arma::Col<size_t>* all_ranks_list = new arma::Col<size_t>();
- size_t all_correct = 0;
+ //size_t all_correct = 0;
all_ranks_list->set_size(solutions[i]->n_cols * solutions[i]->n_rows);
all_ranks_lists.push_back(all_ranks_list);
- all_corrects.push_back(all_correct);
+ all_corrects.push_back(0);
+ total_considerable_candidates.push_back(solutions[i]->n_cols
+ * solutions[i]->n_rows);
}
@@ -438,7 +503,9 @@
if (rank < k + 1)
all_corrects[ind]++;
} else {
- (*all_ranks_lists[ind])( i * k + j ) = rdata->n_cols;
+// (*all_ranks_lists[ind])( i * k + j ) = rdata->n_cols;
+ (*all_ranks_lists[ind])( i * k + j ) = -1;
+ total_considerable_candidates[ind]--;
}
} // top k neighbors
} // all solutions for this query
@@ -475,13 +542,133 @@
for (size_t ind = 0; ind < num_solutions; ind++) {
+// precisions->push_back((double) all_corrects[ind]
+// / (double) all_ranks_lists[ind]->n_elem);
+// median_ranks->push_back(arma::median(*all_ranks_lists[ind]));
+
precisions->push_back((double) all_corrects[ind]
- / (double) all_ranks_lists[ind]->n_elem);
- median_ranks->push_back(arma::median(*all_ranks_lists[ind]));
+ / (double) total_considerable_candidates[ind]);
+
+ arma::uvec tcc = find((*all_ranks_lists[ind]) != -1);
+
+ assert(tcc.n_elem == total_considerable_candidates[ind]);
+
+ arma::Col<size_t> hemmed_all_ranks_list;
+ hemmed_all_ranks_list.set_size(tcc.n_elem);
+
+ for (size_t i = 0; i < tcc.n_elem; i++)
+ hemmed_all_ranks_list(i) = (*all_ranks_lists[ind])(tcc(i));
+
+ median_ranks->push_back(arma::median(hemmed_all_ranks_list));
+
delete(all_ranks_lists[ind]);
}
return;
}
+
+
+
+
+  /**
+   * Checks the rank-approximation bound on the k-th returned candidate.
+   *
+   * For every query (column i of `indices`), the candidate in the last row
+   * (the k-th result) is looked up in that query's rank list. The query
+   * succeeds if that rank is at most rank_error, where rank_error is
+   * epsilon percent of rdata_size. The fraction of queries that succeed is
+   * reported as "Actual Alpha", next to the requested alpha.
+   *
+   * @param rank_file Path to the rank file: one line per query, each with
+   *     rdata_size comma-separated ranks; entry r is the rank of reference
+   *     point r for that query.
+   * @param rdata_size Number of reference points.
+   * @param epsilon Allowed rank error as a percentage of rdata_size.
+   * @param alpha Requested success probability; only printed.
+   * @param indices k x (number of queries) matrix of returned reference
+   *     indices; (size_t) -1 marks a missing result.
+   */
+  inline void check_rank_bound(const string& rank_file, size_t rdata_size,
+                               double epsilon, double alpha,
+                               arma::Mat<size_t>* indices) {
+
+    const size_t k = indices->n_rows;
+    const size_t num_queries = indices->n_cols;
+
+    // (k - 1) below would underflow on an empty result matrix.
+    if (k == 0 || num_queries == 0) {
+      Log::Warn << "check_rank_bound(): no results to check." << endl;
+      return;
+    }
+
+    FILE *rank_fp = fopen(rank_file.c_str(), "r");
+    if (rank_fp == NULL) {
+      // Without the rank file there is nothing to compare against.
+      // Continuing would read an uninitialized rank list and fclose(NULL).
+      Log::Warn << "check_rank_bound(): could not open rank file "
+          << rank_file << endl;
+      return;
+    }
+
+    size_t all_correct = 0;
+    const size_t rank_error
+        = (size_t) (epsilon * (double) rdata_size / 100.0);
+    // Number of queries that actually returned a k-th candidate.
+    size_t total_considerable_candidates = num_queries;
+
+    double perc_done = 10.0;
+    double done_sky = 1.0;
+
+    // Rank list of the current query and the getline() buffer; both are
+    // reused across queries instead of being reallocated for every line.
+    arma::Col<size_t> rank_ind;
+    rank_ind.set_size(rdata_size);
+    char *line = NULL;
+    size_t len = 0;
+
+    // do it with a loop over the queries.
+    for (size_t i = 0; i < num_queries; i++) {
+
+      // Read this query's rank list.
+      if (getline(&line, &len, rank_fp) == -1) {
+        Log::Warn << "check_rank_bound(): rank file ended before query "
+            << i << " of " << num_queries << endl;
+        free(line);
+        fclose(rank_fp);
+        return;
+      }
+
+      size_t rank_index = 0;
+      for (char *pch = strtok(line, ",\n"); pch != NULL;
+           pch = strtok(NULL, ",\n")) {
+        // Never write past rank_ind, even for a malformed (too long) line.
+        // The assert below still flags the mismatch in debug builds.
+        if (rank_index < rdata_size)
+          rank_ind(rank_index) = (size_t) strtoul(pch, NULL, 10);
+        rank_index++;
+      }
+      assert(rank_index == rdata_size);
+
+      // Only the k-th (last) returned candidate is checked.
+      const size_t kth_candidate = (*indices)(k - 1, i);
+      if (kth_candidate != (size_t) -1) {
+        if (rank_ind(kth_candidate) < rank_error + 1)
+          all_correct++;
+      } else {
+        total_considerable_candidates--;
+      }
+
+      double pdone = i * 100 / num_queries;
+
+      if (pdone >= done_sky * perc_done) {
+        if (done_sky > 1) {
+          printf("\b\b\b=%zu%%", (size_t) pdone); fflush(NULL);
+        } else {
+          printf("=%zu%%", (size_t) pdone); fflush(NULL);
+        }
+        done_sky++;
+      }
+    } // query-loop
+
+    free(line);
+    fclose(rank_fp);
+
+    double pdone = 100;
+
+    if (pdone >= done_sky * perc_done) {
+      if (done_sky > 1) {
+        printf("\b\b\b=%zu%%", (size_t) pdone); fflush(NULL);
+      } else {
+        printf("=%zu%%", (size_t) pdone); fflush(NULL);
+      }
+      done_sky++;
+    }
+    printf("\n");fflush(NULL);
+
+    Log::Info << "Errors Computed!" << endl;
+
+    // Queries without a k-th candidate count as failures: they stay in the
+    // denominator. TCC reports how many queries had a k-th candidate.
+    double avg_precision = (double) all_correct / (double) num_queries;
+
+    Log::Warn << "Actual Alpha @" << k << ": "
+              << avg_precision << endl
+              << "TCC: " << total_considerable_candidates
+              << endl << "Alpha: " << alpha << endl
+              << "Rank Error: " << rank_error << endl;
+  }
+
}; // end check_nn_utils
More information about the mlpack-svn
mailing list