[mlpack-svn] r10186 - mlpack/trunk/src/contrib/pram/max_ip_search
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Nov 8 13:07:02 EST 2011
Author: pram
Date: 2011-11-08 13:07:01 -0500 (Tue, 08 Nov 2011)
New Revision: 10186
Modified:
mlpack/trunk/src/contrib/pram/max_ip_search/CMakeLists.txt
mlpack/trunk/src/contrib/pram/max_ip_search/check_nn_utils.h
mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.cc
mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.h
mlpack/trunk/src/contrib/pram/max_ip_search/max_ip_main.cc
mlpack/trunk/src/contrib/pram/max_ip_search/max_ip_tester.cc
Log:
MaxIP search commit after submission; still need to clean up the code, add approxMaxIP, and fix the dual tree.
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/CMakeLists.txt 2011-11-08 17:37:25 UTC (rev 10185)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/CMakeLists.txt 2011-11-08 18:07:01 UTC (rev 10186)
@@ -112,4 +112,13 @@
target_link_libraries(rank_inverter
mlpack
mlpack_contrib
+)
+
+add_executable(new_pair_dist
+ EXCLUDE_FROM_ALL
+ new_pairwise_dists.cc
+)
+target_link_libraries(new_pair_dist
+ mlpack
+ mlpack_contrib
)
\ No newline at end of file
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/check_nn_utils.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/check_nn_utils.h 2011-11-08 17:37:25 UTC (rev 10185)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/check_nn_utils.h 2011-11-08 18:07:01 UTC (rev 10186)
@@ -236,8 +236,6 @@
} // query-loop
- fclose(rank_fp);
-
double pdone = 100;
if (pdone >= done_sky * perc_done) {
@@ -253,6 +251,7 @@
Log::Info << "Errors Computed!" << endl;
+
// double avg_precision = (double) all_correct
// / (double) (k * indices->n_cols);
@@ -281,6 +280,8 @@
<< "TCC: " << total_considerable_candidates
<< endl;
+ fclose(rank_fp);
+
}
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.cc
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.cc 2011-11-08 17:37:25 UTC (rev 10185)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.cc 2011-11-08 18:07:01 UTC (rev 10186)
@@ -20,7 +20,7 @@
arma::vec centroid = reference_node->bound().center();
// +1: Can be cached in the reference tree
- double c_norm = reference_node->stat().dist_to_origin();
+ double c_norm = reference_node->stat().center_norm();
assert(arma::norm(q, 2) == query_norms_(query_));
@@ -96,7 +96,7 @@
arma::vec centroid = reference_node->bound().center();
- double c_norm = reference_node->stat().dist_to_origin();
+ double c_norm = reference_node->stat().center_norm();
double rad = std::sqrt(reference_node->bound().radius());
double max_cos_qp = 1.0;
@@ -186,7 +186,7 @@
arma::vec centroid = reference_node->bound().center();
- double c_norm = reference_node->stat().dist_to_origin();
+ double c_norm = reference_node->stat().center_norm();
double rad = std::sqrt(reference_node->bound().radius());
if (mlpack::CLI::HasParam("maxip/alt_angle_prune")) {
@@ -240,6 +240,65 @@
}
+// PLAIN OLD METRIC TREE -- returns actual IP (not |p| cos pq)
+double MaxIP::MaxNodeIP_(TreeTypeA* query_node,
+ TreeType* reference_node) {
+ // Node-to-node score; the dual-tree recursion receives it as upper_bound_ip.
+ // counting the split decisions
+ split_decisions_++;
+ // Query node: centre, sqrt of bound().radius() (presumably a squared radius -- confirm), centre norm from the stat.
+ arma::vec q_centroid = query_node->bound().center();
+ double q_rad = std::sqrt(query_node->bound().radius());
+ double q_norm = query_node->stat().center_norm();
+
+ arma::vec centroid = reference_node->bound().center();
+ double c_norm = reference_node->stat().center_norm();
+ double rad = std::sqrt(reference_node->bound().radius());
+
+ // +1 (presumably tallies one inner-product evaluation -- confirm)
+ double q_dot_r = arma::dot(q_centroid, centroid);
+ double cos_phi = q_dot_r / (c_norm * q_norm);
+
+ double min_cos = 1.0;
+ // min_cos only changes when cos_phi is below both nodes' cosine_origin() and cos_w_qr > cos_phi.
+ if (cos_phi < query_node->stat().cosine_origin()
+ && cos_phi < reference_node->stat().cosine_origin()) {
+
+ double cos_w_qr = query_node->stat().cosine_origin()
+ * reference_node->stat().cosine_origin()
+ - query_node->stat().sine_origin() * reference_node->stat().sine_origin();
+
+ if (cos_w_qr > cos_phi) {
+ // balls do not intersect
+ // in that case, compute cos(phi - w_q - w_r)
+
+ double sin_phi = std::sqrt(1 - cos_phi * cos_phi);
+ double sin_w_qr = std::sqrt(1 - cos_w_qr * cos_w_qr);
+
+ min_cos = cos_phi * cos_w_qr + sin_phi * sin_w_qr;
+ }
+ // else they do intersect so min_cos = 1.0
+ } // otherwise min_cos keeps its initial 1.0
+ // Positive min_cos: return the combined expression below; otherwise return 0.0.
+ if (min_cos > 0.0) {
+ // there is some potential
+ // NOTE(review): k divides by q_rad * c_norm; a zero radius or zero centre norm makes it undefined -- confirm nodes cannot be that degenerate
+ double k = (rad * q_norm) / (q_rad * c_norm);
+
+ double term3 = (k + cos_phi) / std::sqrt(k * k + 2 * k * cos_phi + 1);
+ double term4 = (q_rad * c_norm + rad * q_norm * cos_phi) / (k + cos_phi)
+ + q_norm * rad;
+
+ return (q_dot_r + rad * q_rad + term3 * term4);
+
+ } else {
+
+ return 0.0;
+
+ }
+}
+
+
void MaxIP::ComputeBaseCase_(TreeType* reference_node) {
assert(reference_node != NULL);
@@ -431,7 +490,48 @@
} // ComputeBaseCase_
+// PLAIN OLD METRIC TREE: base case for a query node against a reference leaf
+void MaxIP::ComputeBaseCase_(TreeTypeA* query_node,
+ TreeType* reference_node) {
+ // Check that the pointers are not NULL
+ assert(reference_node != NULL);
+ assert(reference_node->is_leaf());
+ assert(query_node != NULL);
+
+ // query node may not be a leaf
+ // assert(query_node->is_leaf());
+
+ // Used to find the query node's new lower bound
+ double query_worst_ip = DBL_MAX;
+ bool new_bound = false;
+
+ // Iterating over the queries individually
+ for (query_ = query_node->begin();
+ query_ < query_node->end(); query_++) {
+ // query_ is not declared here (presumably a class member read by the single-tree overloads -- confirm)
+ // checking if this node has potential
+ double query_to_node_max_ip = MaxNodeIP_(reference_node);
+
+ if (query_to_node_max_ip > max_ips_(knns_ -1, query_))
+ // this node has potential
+ ComputeBaseCase_(reference_node);
+ // Re-read row knns_ - 1 for this query (the call above presumably updates max_ips_ -- confirm)
+ double min_ip = max_ips_(knns_ -1, query_);
+
+ if (query_worst_ip > min_ip) {
+ query_worst_ip = min_ip;
+ new_bound = true;
+ }
+ } // for query_
+
+ // Store the smallest row-(knns_ - 1) value seen as the query node's bound
+ if (new_bound)
+ query_node->stat().set_bound(query_worst_ip);
+
+} // ComputeBaseCase_
+
+
// CONE TREE
void MaxIP::CheckPrune(CTreeTypeA* query_node, TreeType* ref_node) {
@@ -526,7 +626,53 @@
}
+// PLAIN OLD METRIC TREE: brute-force check of a prune over every query/reference pair
+void MaxIP::CheckPrune(TreeTypeA* query_node, TreeType* ref_node) {
+ size_t missed_nns = 0;
+ double max_ip = 0.0;
+ double min_ip = DBL_MAX;
+ // missed_nns counts (query, reference) pairs whose inner product exceeds max_ips_(knns_ - 1, query)
+ // Iterating over the queries individually
+ for (query_ = query_node->begin();
+ query_ < query_node->end(); query_++) {
+
+ // Get the query point from the matrix
+ arma::vec q = queries_.unsafe_col(query_);
+
+ double ip1 = max_ips_(knns_ -1, query_);
+ if (min_ip > ip1)
+ min_ip = ip1;
+
+ // We'll do the same for the references
+ for (size_t reference_index = ref_node->begin();
+ reference_index < ref_node->end(); reference_index++) {
+
+ arma::vec r = references_.unsafe_col(reference_index);
+
+ double ip = arma::dot(q, r);
+ if (ip > max_ips_(knns_-1, query_))
+ missed_nns++;
+
+ if (ip > max_ip)
+ max_ip = ip;
+
+ } // for reference_index
+ } // for query_
+ // NOTE(review): a failed check prints a report and then calls exit(0), i.e. a success status -- confirm intended
+ if (missed_nns > 0 || query_node->stat().bound() != min_ip) {
+ printf("Prune %zu - Missed candidates: %zu\n"
+ "QLBound: %lg, ActualQLBound: %lg\n"
+ "QRBound: %lg, ActualQRBound: %lg\n",
+ number_of_prunes_, missed_nns,
+ query_node->stat().bound(), min_ip,
+ MaxNodeIP_(query_node, ref_node), max_ip);
+ exit(0);
+ }
+
+}
+
+
// CONE TREE
void MaxIP::ComputeNeighborsRecursion_(CTreeTypeA* query_node,
TreeType* reference_node,
@@ -549,24 +695,57 @@
ComputeBaseCase_(query_node, reference_node);
} else if (query_node->is_leaf()) {
// Only query is a leaf
+
+ if (CLI::HasParam("maxip/alt_dual_traversal")) {
+ // Trying to do single-tree on the leaves
+
+ // Used to find the query node's new lower bound
+ double query_worst_p_cos_pq = DBL_MAX;
+ bool new_bound = false;
+
+ // Iterating over the queries individually
+ for (query_ = query_node->begin();
+ query_ < query_node->end(); query_++) {
+
+ // checking if this node has potential
+ double query_to_node_max_ip = MaxNodeIP_(reference_node);
+
+ if (query_to_node_max_ip > max_ips_(knns_ -1, query_))
+ // this node has potential
+ ComputeNeighborsRecursion_(reference_node, query_to_node_max_ip);
+
+ double p_cos_pq = max_ips_(knns_ -1, query_)
+ / query_norms_(query_);
+
+ if (query_worst_p_cos_pq > p_cos_pq) {
+ query_worst_p_cos_pq = p_cos_pq;
+ new_bound = true;
+ }
+ } // for query_
+
+ // Update the lower bound for the query_node
+ if (new_bound)
+ query_node->stat().set_bound(query_worst_p_cos_pq);
+ } else {
+
+ // We'll order the computation by distance
+ double left_p_cos_pq = MaxNodeIP_(query_node,
+ reference_node->left());
+ double right_p_cos_pq = MaxNodeIP_(query_node,
+ reference_node->right());
- // We'll order the computation by distance
- double left_p_cos_pq = MaxNodeIP_(query_node,
- reference_node->left());
- double right_p_cos_pq = MaxNodeIP_(query_node,
- reference_node->right());
-
- if (left_p_cos_pq > right_p_cos_pq) {
- ComputeNeighborsRecursion_(query_node, reference_node->left(),
- left_p_cos_pq);
- ComputeNeighborsRecursion_(query_node, reference_node->right(),
- right_p_cos_pq);
- } else {
- ComputeNeighborsRecursion_(query_node, reference_node->right(),
- right_p_cos_pq);
- ComputeNeighborsRecursion_(query_node, reference_node->left(),
- left_p_cos_pq);
- }
+ if (left_p_cos_pq > right_p_cos_pq) {
+ ComputeNeighborsRecursion_(query_node, reference_node->left(),
+ left_p_cos_pq);
+ ComputeNeighborsRecursion_(query_node, reference_node->right(),
+ right_p_cos_pq);
+ } else {
+ ComputeNeighborsRecursion_(query_node, reference_node->right(),
+ right_p_cos_pq);
+ ComputeNeighborsRecursion_(query_node, reference_node->left(),
+ left_p_cos_pq);
+ }
+ } // alt-traversal
} else if (reference_node->is_leaf()) {
// Only reference is a leaf
double left_p_cos_pq
@@ -585,104 +764,52 @@
query_node->right()->stat().bound()));
} else {
- if (CLI::HasParam("maxip/alt_dual_traversal")) {
- // try something new
- // if query_node->radius() > reference_node->stat().cosine_origin()
- // traverse down the reference tree
- // else
- // traverse down the query tree
-
- if (query_node->bound().radius()
- > reference_node->stat().cosine_origin()) {
-
- // go down the reference tree
- double left_p_cos_pq = MaxNodeIP_(query_node,
- reference_node->left());
- double right_p_cos_pq = MaxNodeIP_(query_node,
- reference_node->right());
+ // Recurse on both as above
+ double left_p_cos_pq = MaxNodeIP_(query_node->left(),
+ reference_node->left());
+ double right_p_cos_pq = MaxNodeIP_(query_node->left(),
+ reference_node->right());
- if (left_p_cos_pq > right_p_cos_pq) {
- ComputeNeighborsRecursion_(query_node,
- reference_node->left(),
- left_p_cos_pq);
- ComputeNeighborsRecursion_(query_node,
- reference_node->right(),
- right_p_cos_pq);
- } else {
- ComputeNeighborsRecursion_(query_node,
- reference_node->right(),
- right_p_cos_pq);
- ComputeNeighborsRecursion_(query_node,
- reference_node->left(),
- left_p_cos_pq);
- }
- } else {
-
- // go down the query tree
- double left_p_cos_pq = MaxNodeIP_(query_node->left(),
- reference_node);
- double right_p_cos_pq = MaxNodeIP_(query_node->right(),
- reference_node);
-
- ComputeNeighborsRecursion_(query_node->left(),
- reference_node,
- left_p_cos_pq);
- ComputeNeighborsRecursion_(query_node->right(),
- reference_node,
- right_p_cos_pq);
+ if (left_p_cos_pq > right_p_cos_pq) {
+ ComputeNeighborsRecursion_(query_node->left(),
+ reference_node->left(),
+ left_p_cos_pq);
+ ComputeNeighborsRecursion_(query_node->left(),
+ reference_node->right(),
+ right_p_cos_pq);
+ } else {
+ ComputeNeighborsRecursion_(query_node->left(),
+ reference_node->right(),
+ right_p_cos_pq);
+ ComputeNeighborsRecursion_(query_node->left(),
+ reference_node->left(),
+ left_p_cos_pq);
+ }
- // Update the upper bound as above
- query_node->stat().set_bound(std::min(query_node->left()->stat().bound(),
- query_node->right()->stat().bound()));
- }
- } else {
- // Recurse on both as above
- double left_p_cos_pq = MaxNodeIP_(query_node->left(),
- reference_node->left());
- double right_p_cos_pq = MaxNodeIP_(query_node->left(),
- reference_node->right());
+ left_p_cos_pq = MaxNodeIP_(query_node->right(),
+ reference_node->left());
+ right_p_cos_pq = MaxNodeIP_(query_node->right(),
+ reference_node->right());
- if (left_p_cos_pq > right_p_cos_pq) {
- ComputeNeighborsRecursion_(query_node->left(),
- reference_node->left(),
- left_p_cos_pq);
- ComputeNeighborsRecursion_(query_node->left(),
- reference_node->right(),
- right_p_cos_pq);
- } else {
- ComputeNeighborsRecursion_(query_node->left(),
- reference_node->right(),
- right_p_cos_pq);
- ComputeNeighborsRecursion_(query_node->left(),
- reference_node->left(),
- left_p_cos_pq);
- }
-
- left_p_cos_pq = MaxNodeIP_(query_node->right(),
- reference_node->left());
- right_p_cos_pq = MaxNodeIP_(query_node->right(),
- reference_node->right());
+ if (left_p_cos_pq > right_p_cos_pq) {
+ ComputeNeighborsRecursion_(query_node->right(),
+ reference_node->left(),
+ left_p_cos_pq);
+ ComputeNeighborsRecursion_(query_node->right(),
+ reference_node->right(),
+ right_p_cos_pq);
+ } else {
+ ComputeNeighborsRecursion_(query_node->right(),
+ reference_node->right(),
+ right_p_cos_pq);
+ ComputeNeighborsRecursion_(query_node->right(),
+ reference_node->left(),
+ left_p_cos_pq);
+ }
- if (left_p_cos_pq > right_p_cos_pq) {
- ComputeNeighborsRecursion_(query_node->right(),
- reference_node->left(),
- left_p_cos_pq);
- ComputeNeighborsRecursion_(query_node->right(),
- reference_node->right(),
- right_p_cos_pq);
- } else {
- ComputeNeighborsRecursion_(query_node->right(),
- reference_node->right(),
- right_p_cos_pq);
- ComputeNeighborsRecursion_(query_node->right(),
- reference_node->left(),
- left_p_cos_pq);
- }
-
- // Update the upper bound as above
- query_node->stat().set_bound(std::min(query_node->left()->stat().bound(),
- query_node->right()->stat().bound()));
- } // alt-traversal
+ // Update the upper bound as above
+ query_node->stat().set_bound(std::min(query_node->left()->stat().bound(),
+ query_node->right()->stat().bound()));
} // All cases of dual-tree traversal
} // ComputeNeighborsRecursion_
@@ -830,7 +957,147 @@
} // ComputeNeighborsRecursion_
+// PLAIN OLD METRIC TREE: dual-tree recursion over a query node and a reference node
+void MaxIP::ComputeNeighborsRecursion_(TreeTypeA* query_node,
+ TreeType* reference_node,
+ double upper_bound_ip) {
+ assert(query_node != NULL);
+ assert(reference_node != NULL);
+ //assert(upper_bound_ip == MaxNodeIP_(query_node, reference_node));
+ // Prune when upper_bound_ip is below the query node's bound; otherwise descend according to leaf status.
+ if (upper_bound_ip < query_node->stat().bound()) {
+ // Pruned: add the node's end() - begin() to the prune counter
+ number_of_prunes_ += query_node->end() - query_node->begin();
+
+ if (CLI::HasParam("maxip/check_prune"))
+ CheckPrune(query_node, reference_node);
+ }
+ // node->is_leaf() works as one would expect
+ else if (query_node->is_leaf() && reference_node->is_leaf()) {
+ // Base Case
+ ComputeBaseCase_(query_node, reference_node);
+ } else if (query_node->is_leaf()) {
+ // Only query is a leaf
+
+ if (CLI::HasParam("maxip/alt_dual_traversal")) {
+ // Trying to do single-tree on the leaves
+
+ // Used to find the query node's new lower bound
+ double query_worst_ip = DBL_MAX;
+ bool new_bound = false;
+
+ // Iterating over the queries individually
+ for (query_ = query_node->begin();
+ query_ < query_node->end(); query_++) {
+
+ // checking if this node has potential
+ double query_to_node_max_ip = MaxNodeIP_(reference_node);
+
+ if (query_to_node_max_ip > max_ips_(knns_ -1, query_))
+ // this node has potential
+ ComputeNeighborsRecursion_(reference_node, query_to_node_max_ip);
+
+ double ip = max_ips_(knns_ -1, query_);
+
+ if (query_worst_ip > ip) {
+ query_worst_ip = ip;
+ new_bound = true;
+ }
+ } // for query_
+
+ // Update the lower bound for the query_node
+ if (new_bound)
+ query_node->stat().set_bound(query_worst_ip);
+ } else {
+
+ // Visit the reference child with the larger MaxNodeIP_ value first
+ double left_ip = MaxNodeIP_(query_node,
+ reference_node->left());
+ double right_ip = MaxNodeIP_(query_node,
+ reference_node->right());
+
+ if (left_ip > right_ip) {
+ ComputeNeighborsRecursion_(query_node, reference_node->left(),
+ left_ip);
+ ComputeNeighborsRecursion_(query_node, reference_node->right(),
+ right_ip);
+ } else {
+ ComputeNeighborsRecursion_(query_node, reference_node->right(),
+ right_ip);
+ ComputeNeighborsRecursion_(query_node, reference_node->left(),
+ left_ip);
+ }
+ } // alt-traversal
+ } else if (reference_node->is_leaf()) {
+ // Only reference is a leaf
+ double left_ip
+ = MaxNodeIP_(query_node->left(), reference_node);
+ double right_ip
+ = MaxNodeIP_(query_node->right(), reference_node);
+
+ ComputeNeighborsRecursion_(query_node->left(), reference_node,
+ left_ip);
+ ComputeNeighborsRecursion_(query_node->right(), reference_node,
+ right_ip);
+
+ // Refresh this node's bound with the smaller of the two children's
+ // bounds
+ query_node->stat().set_bound(std::min(query_node->left()->stat().bound(),
+ query_node->right()->stat().bound()));
+ } else {
+
+ // Neither node is a leaf: run each query child against both reference children
+ double left_ip = MaxNodeIP_(query_node->left(),
+ reference_node->left());
+ double right_ip = MaxNodeIP_(query_node->left(),
+ reference_node->right());
+
+ if (left_ip > right_ip) {
+ ComputeNeighborsRecursion_(query_node->left(),
+ reference_node->left(),
+ left_ip);
+ ComputeNeighborsRecursion_(query_node->left(),
+ reference_node->right(),
+ right_ip);
+ } else {
+ ComputeNeighborsRecursion_(query_node->left(),
+ reference_node->right(),
+ right_ip);
+ ComputeNeighborsRecursion_(query_node->left(),
+ reference_node->left(),
+ left_ip);
+ }
+
+ left_ip = MaxNodeIP_(query_node->right(),
+ reference_node->left());
+ right_ip = MaxNodeIP_(query_node->right(),
+ reference_node->right());
+
+ if (left_ip > right_ip) {
+ ComputeNeighborsRecursion_(query_node->right(),
+ reference_node->left(),
+ left_ip);
+ ComputeNeighborsRecursion_(query_node->right(),
+ reference_node->right(),
+ right_ip);
+ } else {
+ ComputeNeighborsRecursion_(query_node->right(),
+ reference_node->right(),
+ right_ip);
+ ComputeNeighborsRecursion_(query_node->right(),
+ reference_node->left(),
+ left_ip);
+ }
+
+ // Refresh this node's bound with the smaller of the children's bounds
+ query_node->stat().set_bound(std::min(query_node->left()->stat().bound(),
+ query_node->right()->stat().bound()));
+ } // All cases of dual-tree traversal
+} // ComputeNeighborsRecursion_
+
+
+
void MaxIP::Init(const arma::mat& queries_in,
const arma::mat& references_in) {
@@ -877,22 +1144,41 @@
if (mlpack::CLI::HasParam("maxip/dual_tree")) {
+ size_t qleaf_size;
+
+ if (mlpack::CLI::GetParam<double>("maxip/qleaf_size") != 0.0)
+
+ qleaf_size = (size_t) (mlpack::CLI::GetParam<double>("maxip/qleaf_size")
+ * (double) queries_.n_cols) / 100.0;
+ else
+ qleaf_size = leaf_size_ - 1;
+
if (mlpack::CLI::HasParam("maxip/alt_tree")) {
// using cosine tree
query_tree_B_
= proximity::MakeGenCosineTree<CTreeTypeB>(queries_,
- leaf_size_,
+ qleaf_size + 1,
&old_from_new_queries_,
NULL);
set_norms_in_cones_(query_tree_B_);
+ } else if (mlpack::CLI::HasParam("maxip/alt_tree2")) {
+
+ // using plain old Metric tree
+ query_tree_
+ = proximity::MakeGenMetricTree<TreeTypeA>(queries_,
+ qleaf_size + 1,
+ &old_from_new_queries_,
+ NULL);
+ set_angles_in_balls_(query_tree_);
+
} else {
// using cone tree
query_tree_A_
= proximity::MakeGenConeTree<CTreeTypeA>(queries_,
- leaf_size_,
+ qleaf_size + 1,
&old_from_new_queries_,
NULL);
set_norms_in_cones_(query_tree_A_);
@@ -986,12 +1272,14 @@
reset_tree_(reference_tree_);
+ } else if (mlpack::CLI::HasParam("maxip/alt_tree2")) {
+ if (query_tree_ != NULL)
+ reset_tree_(query_tree_);
+
} else {
- if (query_tree_A_ != NULL)
- reset_tree_(query_tree_A_);
+ if (query_tree_A_ != NULL)
+ reset_tree_(query_tree_A_);
}
-
-
} // WarmInit
void MaxIP::reset_tree_(CTreeTypeA* tree) {
@@ -1019,7 +1307,19 @@
return;
} // reset_tree_
+void MaxIP::reset_tree_(TreeTypeA* tree) {
+ assert(tree != NULL);
+ tree->stat().set_bound(0.0);
+ if (!tree->is_leaf()) {
+ reset_tree_(tree->left());
+ reset_tree_(tree->right());
+ }
+ // bound is now 0.0 on this node and, via the recursion above, on its descendants
+ return;
+} // reset_tree_
+
+
void MaxIP::reset_tree_(TreeType* tree) {
assert(tree != NULL);
tree->stat().reset();
@@ -1040,7 +1340,7 @@
double c_norm = arma::norm(tree->bound().center(), 2);
double rad = std::sqrt(tree->bound().radius());
- tree->stat().set_dist_to_origin(c_norm);
+ tree->stat().set_center_norm(c_norm);
if (rad <= c_norm)
tree->stat().set_angles(rad / c_norm, (size_t) 1);
else
@@ -1055,7 +1355,30 @@
return;
} // set_angles_in_balls_
+void MaxIP::set_angles_in_balls_(TreeTypeA* tree) {
+ assert(tree != NULL);
+ // Fills center_norm and the origin angles on this node's stat, then recurses into both children.
+ // set up node stats
+ double c_norm = arma::norm(tree->bound().center(), 2);
+ double rad = std::sqrt(tree->bound().radius());
+ // set_angles(value, 1) passes a sine and set_angles(value, 0) a cosine (per QueryStat::set_angles).
+ tree->stat().set_center_norm(c_norm);
+ if (rad <= c_norm)
+ tree->stat().set_angles(rad / c_norm, (size_t) 1);
+ else
+ tree->stat().set_angles(-1.0, (size_t) 0);
+ // NOTE(review): c_norm == 0 with rad == 0 takes the first branch and evaluates 0 / 0 -- confirm this cannot occur
+ // traverse down the children
+ if (!tree->is_leaf()) {
+ set_angles_in_balls_(tree->left());
+ set_angles_in_balls_(tree->right());
+ }
+
+ return;
+} // set_angles_in_balls_
+
+
void MaxIP::set_norms_in_cones_(CTreeTypeA* tree) {
assert(tree != NULL);
@@ -1089,8 +1412,6 @@
} // set_norms_in_cones_
-
-
double MaxIP::ComputeNeighbors(arma::Mat<size_t>* resulting_neighbors,
arma::mat* ips) {
@@ -1106,6 +1427,9 @@
if (mlpack::CLI::HasParam("maxip/alt_tree"))
ComputeNeighborsRecursion_(query_tree_B_, reference_tree_,
MaxNodeIP_(query_tree_B_, reference_tree_));
+ else if (mlpack::CLI::HasParam("maxip/alt_tree2"))
+ ComputeNeighborsRecursion_(query_tree_, reference_tree_,
+ MaxNodeIP_(query_tree_, reference_tree_));
else
ComputeNeighborsRecursion_(query_tree_A_, reference_tree_,
MaxNodeIP_(query_tree_A_, reference_tree_));
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.h 2011-11-08 17:37:25 UTC (rev 10185)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.h 2011-11-08 18:07:01 UTC (rev 10186)
@@ -35,10 +35,14 @@
private:
double bound_;
double center_norm_;
+ double cosine_origin_;
+ double sine_origin_;
public:
double bound() { return bound_; }
double center_norm() { return center_norm_; }
+ double cosine_origin() { return cosine_origin_; }
+ double sine_origin() { return sine_origin_; }
void set_bound(double bound) {
bound_ = bound;
@@ -48,9 +52,21 @@
center_norm_ = val;
}
+ void set_angles(double val, size_t type = 0) {
+ if (type == 0) { // type 0: val is the cosine; the sine is set to sqrt(1 - val^2)
+ cosine_origin_ = val;
+ sine_origin_ = std::sqrt(1 - val * val);
+ } else { // any other type: val is the sine; the cosine is set to sqrt(1 - val^2)
+ sine_origin_ = val;
+ cosine_origin_ = std::sqrt(1 - val * val);
+ }
+ }
+
QueryStat() {
bound_ = 0.0;
center_norm_ = 0.0;
+ cosine_origin_ = 0.0;
+ sine_origin_ = 0.0;
}
~QueryStat() {}
@@ -58,27 +74,31 @@
void Init(const arma::mat& data, size_t begin, size_t count) {
bound_ = 0.0;
center_norm_ = 0.0;
+ cosine_origin_ = 0.0;
+ sine_origin_ = 0.0;
}
void Init(const arma::mat& data, size_t begin, size_t count,
QueryStat& left_stat, QueryStat& right_stat) {
bound_ = 0.0;
center_norm_ = 0.0;
+ cosine_origin_ = 0.0;
+ sine_origin_ = 0.0;
}
}; // QueryStat
class RefStat {
private:
+ double center_norm_;
double cosine_origin_;
double sine_origin_;
- double dist_to_origin_;
bool has_cos_phi_;
double cos_phi_;
public:
double cosine_origin() { return cosine_origin_; }
double sine_origin() { return sine_origin_; }
- double dist_to_origin() { return dist_to_origin_; }
+ double center_norm() { return center_norm_; }
bool has_cos_phi() { return has_cos_phi_; }
double cos_phi() { return cos_phi_; }
@@ -93,8 +113,8 @@
}
}
- void set_dist_to_origin(double val) {
- dist_to_origin_ = val;
+ void set_center_norm(double val) {
+ center_norm_ = val;
}
void set_cos_phi(double cos_phi) {
@@ -111,7 +131,7 @@
RefStat() {
cosine_origin_ = 0.0;
sine_origin_ = 0.0;
- dist_to_origin_ = 0.0;
+ center_norm_ = 0.0;
has_cos_phi_ = false;
cos_phi_ = -1.0;
}
@@ -120,7 +140,7 @@
void Init(const arma::mat& data, size_t begin, size_t count) {
cosine_origin_ = 0.0;
sine_origin_ = 0.0;
- dist_to_origin_ = 0.0;
+ center_norm_ = 0.0;
has_cos_phi_ = false;
cos_phi_ = -1.0;
}
@@ -129,7 +149,7 @@
RefStat& left_stat, RefStat& right_stat) {
cosine_origin_ = 0.0;
sine_origin_ = 0.0;
- dist_to_origin_ = 0.0;
+ center_norm_ = 0.0;
has_cos_phi_ = false;
cos_phi_ = -1.0;
}
@@ -141,6 +161,7 @@
typedef GeneralBinarySpaceTree<bound::DBallBound<>, arma::mat, RefStat> TreeType;
typedef GeneralBinarySpaceTree<DConeBound<>, arma::mat, QueryStat> CTreeTypeA;
typedef GeneralBinarySpaceTree<DCosineBound<>, arma::mat, QueryStat> CTreeTypeB;
+ typedef GeneralBinarySpaceTree<bound::DBallBound<>, arma::mat, QueryStat> TreeTypeA;
/////////////////////////////// Members ////////////////////////////
@@ -158,6 +179,8 @@
CTreeTypeA* query_tree_A_;
CTreeTypeB* query_tree_B_;
+ TreeTypeA* query_tree_;
+
// The total number of prunes.
size_t number_of_prunes_;
size_t ball_has_origin_;
@@ -196,6 +219,7 @@
reference_tree_ = NULL;
query_tree_A_ = NULL;
query_tree_B_ = NULL;
+ query_tree_ = NULL;
}
/**
@@ -212,6 +236,9 @@
if (query_tree_B_ != NULL)
delete query_tree_B_;
+ if (query_tree_ != NULL)
+ delete query_tree_;
+
}
/////////////////////////// Helper Functions //////////////////////
@@ -231,6 +258,7 @@
*/
double MaxNodeIP_(CTreeTypeA *query_node, TreeType* reference_node);
double MaxNodeIP_(CTreeTypeB *query_node, TreeType* reference_node);
+ double MaxNodeIP_(TreeTypeA *query_node, TreeType* reference_node);
/**
* Performs exhaustive computation at the leaves.
@@ -242,6 +270,7 @@
*/
void ComputeBaseCase_(CTreeTypeA* query_node, TreeType* reference_node);
void ComputeBaseCase_(CTreeTypeB* query_node, TreeType* reference_node);
+ void ComputeBaseCase_(TreeTypeA* query_node, TreeType* reference_node);
/**
* The recursive function
@@ -260,10 +289,18 @@
TreeType* reference_node,
double upper_bound_ip);
+ void ComputeNeighborsRecursion_(TreeTypeA* query_node,
+ TreeType* reference_node,
+ double upper_bound_ip);
+
void reset_tree_(CTreeTypeA *tree);
void reset_tree_(CTreeTypeB *tree);
+ void reset_tree_(TreeTypeA *tree);
void reset_tree_(TreeType *tree);
+
void set_angles_in_balls_(TreeType *tree);
+ void set_angles_in_balls_(TreeTypeA *tree);
+
void set_norms_in_cones_(CTreeTypeA *tree);
void set_norms_in_cones_(CTreeTypeB *tree);
@@ -288,6 +325,8 @@
delete query_tree_A_;
if (query_tree_B_ != NULL)
delete query_tree_B_;
+ if (query_tree_ != NULL)
+ delete query_tree_;
}
/**
@@ -317,6 +356,7 @@
void CheckPrune(CTreeTypeA* query_node, TreeType* ref_node);
void CheckPrune(CTreeTypeB* query_node, TreeType* ref_node);
+ void CheckPrune(TreeTypeA* query_node, TreeType* ref_node);
}; //class MaxIP
@@ -329,6 +369,8 @@
"maxip", 1);
PARAM_INT("leaf_size", "The leaf size for the ball-tree",
"maxip", 20);
+PARAM_DOUBLE("qleaf_size", "The leaf size for the query-tree",
+ "maxip", 0.0);
PARAM_FLAG("angle_prune", "The flag to trigger the tighter"
" pruning using the angles as well", "maxip");
@@ -341,7 +383,11 @@
PARAM_FLAG("alt_tree", "The flag to trigger the "
"alternate query tree.",
"maxip");
+PARAM_FLAG("alt_tree2", "The flag to trigger the "
+ "yet another alternate query tree.",
+ "maxip");
+
PARAM_FLAG("check_prune", "The flag to trigger the "
"checking of the prune.", "maxip");
PARAM_FLAG("alt_dual_traversal", "The flag to trigger the "
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/max_ip_main.cc
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/max_ip_main.cc 2011-11-08 17:37:25 UTC (rev 10185)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/max_ip_main.cc 2011-11-08 18:07:01 UTC (rev 10186)
@@ -31,8 +31,8 @@
PARAM_FLAG("dofastexact", "The flag to trigger the tree-based"
" search algorithm", "");
-// PARAM_FLAG("print_results", "The flag to trigger the "
-// "printing of the output", "");
+PARAM_FLAG("check_results", "The flag to trigger the "
+ "checking of the output", "");
// PARAM_STRING("maxip_file", "The file where the output "
// "will be written into", "", "results.txt");
@@ -116,46 +116,48 @@
fast_comp = fast_exact.ComputeNeighbors(&exc, &die);
Log::Info << "Tree-based Max IP Computed." << endl;
-// if (CLI::HasParam("print_results")) {
-// FILE *fp=fopen(CLI::GetParam<string>("maxip_file").c_str(), "w");
-// if (fp == NULL)
-// Log::Fatal << "Error while opening "
-// << CLI::GetParam<string>("maxip_file")
-// << endl;
-
-// for(size_t i = 0 ; i < exc.n_elem / knns ; i++) {
-// fprintf(fp, "%zu", i);
-// for(size_t j = 0; j < knns; j++)
-// fprintf(fp, ", %zu, %lg",
-// exc(i*knns+j), die(i*knns+j));
-// fprintf(fp, "\n");
-// }
-// fclose(fp);
-// }
- }
-
- if (CLI::HasParam("donaive") && CLI::HasParam("dofastexact")) {
- check_nn_utils::count_mismatched_neighbors(nac, din, exc, die);
Log::Warn << "Speed of fast-exact over naive: "
- << naive_comp << " / " << (float) fast_comp << " = "
- <<(float) (naive_comp / fast_comp) << endl;
- } else if (CLI::HasParam("dofastexact")) {
- Log::Warn << "Speed of fast-exact over naive: "
<< rdata.n_cols << " / " << (float) fast_comp << " = "
<<(float) (rdata.n_cols / fast_comp) << endl;
- if (CLI::GetParam<string>("rank_file") != "") {
- string rank_file = CLI::GetParam<string>("rank_file");
- check_nn_utils::compute_error(rank_file, rdata.n_cols, &exc);
- } else {
- check_nn_utils::compute_error(&rdata, &qdata, &exc);
+ // if (CLI::HasParam("print_results")) {
+ // FILE *fp=fopen(CLI::GetParam<string>("maxip_file").c_str(), "w");
+ // if (fp == NULL)
+ // Log::Fatal << "Error while opening "
+ // << CLI::GetParam<string>("maxip_file")
+ // << endl;
+
+ // for(size_t i = 0 ; i < exc.n_elem / knns ; i++) {
+ // fprintf(fp, "%zu", i);
+ // for(size_t j = 0; j < knns; j++)
+ // fprintf(fp, ", %zu, %lg",
+ // exc(i*knns+j), die(i*knns+j));
+ // fprintf(fp, "\n");
+ // }
+ // fclose(fp);
+ // }
+ }
+
+ if (CLI::HasParam("check_results")) {
+ if (CLI::HasParam("donaive") && CLI::HasParam("dofastexact")) {
+ check_nn_utils::count_mismatched_neighbors(nac, din, exc, die);
+ Log::Warn << "Speed of fast-exact over naive: "
+ << naive_comp << " / " << (float) fast_comp << " = "
+ <<(float) (naive_comp / fast_comp) << endl;
+ } else if (CLI::HasParam("dofastexact")) {
+ if (CLI::GetParam<string>("rank_file") != "") {
+ string rank_file = CLI::GetParam<string>("rank_file");
+ check_nn_utils::compute_error(rank_file, rdata.n_cols, &exc);
+ } else {
+ check_nn_utils::compute_error(&rdata, &qdata, &exc);
+ }
+ } else if (CLI::HasParam("donaive")) {
+ if (CLI::GetParam<string>("rank_file") != "") {
+ string rank_file = CLI::GetParam<string>("rank_file");
+ check_nn_utils::compute_error(rank_file, rdata.n_cols, &nac);
+ } else {
+ check_nn_utils::compute_error(&rdata, &qdata, &nac);
+ }
}
- } else if (CLI::HasParam("donaive")) {
- if (CLI::GetParam<string>("rank_file") != "") {
- string rank_file = CLI::GetParam<string>("rank_file");
- check_nn_utils::compute_error(rank_file, rdata.n_cols, &nac);
- } else {
- check_nn_utils::compute_error(&rdata, &qdata, &nac);
- }
}
}
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/max_ip_tester.cc
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/max_ip_tester.cc 2011-11-08 17:37:25 UTC (rev 10185)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/max_ip_tester.cc 2011-11-08 18:07:01 UTC (rev 10186)
@@ -1,8 +1,10 @@
-#include <fastlib/fastlib.h>
+#include <armadillo>
+#include <string>
+
+#include <mlpack/core.h>
+
#include "exact_max_ip.h"
-#include <string>
-#include <armadillo>
using namespace mlpack;
using namespace std;
@@ -53,128 +55,103 @@
string qfile = CLI::GetParam<string>("q");
Log::Warn << "Loading files..." << endl;
- if (!data::Load(rfile.c_str(), rdata))
+ if (rdata.load(rfile.c_str()) == false)
Log::Fatal << "Reference file "<< rfile << " not found." << endl;
- if (!data::Load(qfile.c_str(), qdata))
+ if (qdata.load(qfile.c_str()) == false)
Log::Fatal << "Query file " << qfile << " not found." << endl;
+ rdata = arma::trans(rdata);
+ qdata = arma::trans(qdata);
+
Log::Warn << "File loaded..." << endl;
Log::Warn << "R(" << rdata.n_rows << ", " << rdata.n_cols
<< "), Q(" << qdata.n_rows << ", " << qdata.n_cols
<< ")" << endl;
+
+ arma::Col<size_t> ks(10);
+ size_t max_k = 1, number_of_ks = 1;
+ ks(0) = 1;
- if (CLI::HasParam("max_k")) {
+ if (CLI::GetParam<int>("max_k") != 1) {
- MaxIP naive, fast_exact;
- double naive_comp = (double) rdata.n_cols;
- arma::Mat<size_t> nac;
- arma::mat din;
+ max_k = CLI::GetParam<int>("max_k");
+ number_of_ks = max_k;
+ ks.set_size(number_of_ks);
- size_t max_k = CLI::GetParam<int>("max_k");
- arma::vec speedups(max_k);
+ for (size_t i = 0; i < max_k; i++)
+ ks(i) = i+1;
+ } else if (CLI::GetParam<string>("k_values") != "") {
- if (CLI::HasParam("check_nn")) {
- Log::Warn << "Starting naive computation..." <<endl;
- naive.InitNaive(qdata, rdata);
- naive.WarmInit(max_k);
- naive_comp = naive.ComputeNaive(&nac, &din);
- Log::Warn << "Naive computation done..." << endl;
- }
-
- Log::Warn << "Starting loop for Fast Exact Search." << endl;
-
- fast_exact.Init(qdata, rdata);
-
- for (knns = 1; knns <= max_k; knns++) {
-
- printf("k = %zu", knns); fflush(NULL);
- arma::Mat<size_t> exc;
- arma::mat die;
- double fast_comp = fast_exact.ComputeNeighbors(&exc, &die);
-
- if (CLI::HasParam("check_nn")) {
- size_t errors = count_mismatched_neighbors(din, max_k, die, knns);
-
- if (errors > 0) {
- Log::Warn << knns << "-NN error: " << errors << " / "
- << exc.n_elem << endl;
- }
- }
-
- speedups(knns -1) = naive_comp / fast_comp;
- printf(": %lg\n", speedups(knns -1));
-
- fast_exact.WarmInit(knns+1);
- }
-
- printf("\n");
- }
-
- if (CLI::HasParam("k_values")) {
-
- MaxIP naive, fast_exact;
- double naive_comp = (double) rdata.n_cols;
- arma::Col<size_t> nac;
- arma::vec din;
-
- size_t number_of_ks = 0;
+ number_of_ks = 0;
string k_values = CLI::GetParam<string>("k_values");
- arma::Col<size_t> ks(10);
- char *temp = (char *) k_values.c_str();
- char *pch = strtok(temp, ",");
+
+ char *pch = strtok((char *) k_values.c_str(), ",");
while (pch != NULL) {
ks(number_of_ks++) = atoi(pch);
pch = strtok(NULL, ",");
}
- arma::vec speedups(number_of_ks);
+ free(pch);
+ max_k = ks(number_of_ks -1);
+ }
- if (CLI::HasParam("check_nn")) {
- Log::Warn << "Starting naive computation..." <<endl;
- naive.InitNaive(qdata, rdata);
- naive.WarmInit(ks(number_of_ks - 1));
- naive_comp = naive.ComputeNaive(&nac, &din);
- Log::Warn << "Naive computation done..." << endl;
- }
+ MaxIP naive, fast_exact;
+ double naive_comp = (double) rdata.n_cols;
+ arma::Mat<size_t> nac;
+ arma::mat din;
- Log::Warn << "Starting loop for Fast Exact Search." << endl;
+ arma::vec speedups(number_of_ks);
- fast_exact.Init(qdata, rdata);
+ if (CLI::HasParam("check_nn")) {
+ Log::Warn << "Starting naive computation..." <<endl;
+ naive.InitNaive(qdata, rdata);
+ naive.WarmInit(ks(number_of_ks - 1));
+ naive_comp = naive.ComputeNaive(&nac, &din);
+ Log::Warn << "Naive computation done..." << endl;
+ }
- for (size_t knns = 0; knns < number_of_ks; knns++) {
+ Log::Warn << "Starting loop for Fast Exact Search." << endl;
- printf("k = %zu", ks(knns)); fflush(NULL);
- arma::Col<size_t> exc;
- arma::vec die;
- double fast_comp = fast_exact.ComputeNeighbors(&exc, &die);
+ fast_exact.Init(qdata, rdata);
- if (CLI::HasParam("check_nn")) {
- size_t errors = count_mismatched_neighbors(din, ks(number_of_ks -1),
- die, ks(knns));
+ for (size_t knns = 0; knns < number_of_ks; knns++) {
+
+ arma::Mat<size_t> exc;
+ arma::mat die;
+ double fast_comp = fast_exact.ComputeNeighbors(&exc, &die);
- if (errors > 0) {
- Log::Warn << ks(knns) << "-NN error: " << errors << " / "
- << exc.n_elem << endl;
- }
- }
+// printf("Search done %lg\n", fast_comp); fflush(NULL);
- speedups(knns -1) = naive_comp / fast_comp;
- printf(": %lg\n", speedups(knns -1));
+ if (CLI::HasParam("check_nn")) {
+ size_t errors = count_mismatched_neighbors(din, ks(number_of_ks -1),
+ die, ks(knns));
- if (knns < number_of_ks - 1)
- fast_exact.WarmInit(ks(knns+1));
+ if (errors > 0) {
+ Log::Warn << ks(knns) << "-NN error: " << errors << " / "
+ << exc.n_elem << endl;
+ }
}
- printf("\n");
+ speedups(knns) = naive_comp / fast_comp;
+ printf("k = %zu", ks(knns)); fflush(NULL);
+ printf(": %lg\n", speedups(knns)); fflush(NULL);
+
+ if (knns < number_of_ks - 1)
+ fast_exact.WarmInit(ks(knns+1));
+
+ exc.reset();
+ die.reset();
}
+
+ printf("\n");
Log::Warn << "Search completed for all values of k...printing results now"
- << endl;
+ << endl;
// if (CLI::HasParam("print_speedups")) {
// string speedup_file = CLI::GetParam<string>("speedup_file");
More information about the mlpack-svn mailing list