[mlpack-svn] r10186 - mlpack/trunk/src/contrib/pram/max_ip_search

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Nov 8 13:07:02 EST 2011


Author: pram
Date: 2011-11-08 13:07:01 -0500 (Tue, 08 Nov 2011)
New Revision: 10186

Modified:
   mlpack/trunk/src/contrib/pram/max_ip_search/CMakeLists.txt
   mlpack/trunk/src/contrib/pram/max_ip_search/check_nn_utils.h
   mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.cc
   mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.h
   mlpack/trunk/src/contrib/pram/max_ip_search/max_ip_main.cc
   mlpack/trunk/src/contrib/pram/max_ip_search/max_ip_tester.cc
Log:
MaxIP search commit after submission. Still need to clean up the code, add approxMaxIP, and fix the dual tree.

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/CMakeLists.txt	2011-11-08 17:37:25 UTC (rev 10185)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/CMakeLists.txt	2011-11-08 18:07:01 UTC (rev 10186)
@@ -112,4 +112,13 @@
 target_link_libraries(rank_inverter
    mlpack
    mlpack_contrib
+)
+
+add_executable(new_pair_dist
+   EXCLUDE_FROM_ALL
+   new_pairwise_dists.cc
+)
+target_link_libraries(new_pair_dist
+   mlpack
+   mlpack_contrib
 )
\ No newline at end of file

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/check_nn_utils.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/check_nn_utils.h	2011-11-08 17:37:25 UTC (rev 10185)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/check_nn_utils.h	2011-11-08 18:07:01 UTC (rev 10186)
@@ -236,8 +236,6 @@
     } // query-loop
 
 
-    fclose(rank_fp);
-
     double pdone = 100;
 
     if (pdone >= done_sky * perc_done) {
@@ -253,6 +251,7 @@
     Log::Info << "Errors Computed!" << endl;
 
 
+
 //     double avg_precision = (double) all_correct 
 //       / (double) (k * indices->n_cols);
 
@@ -281,6 +280,8 @@
 	      << "TCC: " << total_considerable_candidates
 	      << endl;
 
+    fclose(rank_fp);
+
   }
 
 

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.cc
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.cc	2011-11-08 17:37:25 UTC (rev 10185)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.cc	2011-11-08 18:07:01 UTC (rev 10186)
@@ -20,7 +20,7 @@
   arma::vec centroid = reference_node->bound().center();
 
   // +1: Can be cached in the reference tree
-  double c_norm = reference_node->stat().dist_to_origin();
+  double c_norm = reference_node->stat().center_norm();
 
   assert(arma::norm(q, 2) == query_norms_(query_));
 
@@ -96,7 +96,7 @@
 
   arma::vec centroid = reference_node->bound().center();
 
-  double c_norm = reference_node->stat().dist_to_origin();
+  double c_norm = reference_node->stat().center_norm();
   double rad = std::sqrt(reference_node->bound().radius());
 
   double max_cos_qp = 1.0;
@@ -186,7 +186,7 @@
 
   arma::vec centroid = reference_node->bound().center();
 
-  double c_norm = reference_node->stat().dist_to_origin();
+  double c_norm = reference_node->stat().center_norm();
   double rad = std::sqrt(reference_node->bound().radius());
 
   if (mlpack::CLI::HasParam("maxip/alt_angle_prune")) { 
@@ -240,6 +240,65 @@
 }
 
 
+// PLAIN OLD METRIC TREE -- returns actual IP (not |p| cos pq); callers pass the result on as upper_bound_ip
+double MaxIP::MaxNodeIP_(TreeTypeA* query_node,
+			 TreeType* reference_node) {
+  // NOTE(review): c_norm * q_norm and q_rad * c_norm are divided by below with no zero check -- confirm
+  // counting the split decisions
+  split_decisions_++;
+  // Query ball: center, sqrt of bound().radius(), and the center norm cached in the stat
+  arma::vec q_centroid = query_node->bound().center();
+  double q_rad = std::sqrt(query_node->bound().radius());
+  double q_norm = query_node->stat().center_norm();
+  // Reference ball: the same three quantities
+  arma::vec centroid = reference_node->bound().center();
+  double c_norm = reference_node->stat().center_norm();
+  double rad = std::sqrt(reference_node->bound().radius());
+
+  // +1 (presumably tallies one dot product -- confirm); cos_phi = cosine of the angle between the centers
+  double q_dot_r =  arma::dot(q_centroid, centroid);
+  double cos_phi = q_dot_r / (c_norm * q_norm);
+
+  double min_cos = 1.0; 
+  // min_cos is only tightened when cos_phi is below both nodes' cosine_origin()
+  if (cos_phi < query_node->stat().cosine_origin()
+      && cos_phi < reference_node->stat().cosine_origin()) {
+    // cos(w_q + w_r) via cos(a)cos(b) - sin(a)sin(b)
+    double cos_w_qr = query_node->stat().cosine_origin()
+      * reference_node->stat().cosine_origin()
+      - query_node->stat().sine_origin() * reference_node->stat().sine_origin();
+
+    if (cos_w_qr > cos_phi) { 
+      // balls do not intersect
+      // in that case, compute cos(phi - w_q - w_r)
+
+      double sin_phi = std::sqrt(1 - cos_phi * cos_phi);
+      double sin_w_qr = std::sqrt(1 - cos_w_qr * cos_w_qr);
+      // cos(phi - (w_q + w_r)) via cos(a)cos(b) + sin(a)sin(b)
+      min_cos = cos_phi * cos_w_qr + sin_phi * sin_w_qr;
+    } 
+    // else they do intersect so min_cos = 1.0
+  } // same as above
+
+  if (min_cos > 0.0) {
+    // there is some potential
+    // NOTE(review): k is inf/NaN when q_rad * c_norm == 0 -- confirm that cannot occur
+    double k = (rad * q_norm) / (q_rad * c_norm);
+
+    double term3 = (k + cos_phi) / std::sqrt(k * k + 2 * k * cos_phi + 1);
+    double term4 = (q_rad * c_norm + rad * q_norm * cos_phi) / (k + cos_phi)
+      + q_norm * rad;
+
+    return (q_dot_r + rad * q_rad + term3 * term4);
+
+  } else {
+    // min_cos <= 0: the value returned for this node pair is 0.0
+    return 0.0;
+
+  }
+}
+
+
 void MaxIP::ComputeBaseCase_(TreeType* reference_node) {
    
   assert(reference_node != NULL);
@@ -431,7 +490,48 @@
 
 } // ComputeBaseCase_
   
+// PLAIN OLD METRIC TREE: base case for a query node against a leaf reference node
+void MaxIP::ComputeBaseCase_(TreeTypeA* query_node, 
+			     TreeType* reference_node) {
 
+  // Check that the pointers are not NULL and that the reference node is a leaf
+  assert(reference_node != NULL);
+  assert(reference_node->is_leaf());
+  assert(query_node != NULL);
+
+  // query node may not be a leaf
+  //  assert(query_node->is_leaf());
+
+  // Smallest max_ips_(knns_ - 1, q) over this node's queries; becomes the node's bound
+  double query_worst_ip = DBL_MAX;
+  bool new_bound = false;
+    
+  // Iterating over the queries individually; query_ is a member, presumably read by the single-tree calls below
+  for (query_ = query_node->begin();
+       query_ < query_node->end(); query_++) {
+
+    // single-tree bound of the current query against this reference leaf
+    double query_to_node_max_ip = MaxNodeIP_(reference_node);
+
+    if (query_to_node_max_ip > max_ips_(knns_ -1, query_))
+      // this node has potential
+      ComputeBaseCase_(reference_node);
+    // re-read the entry after the (possible) single-tree base case above
+    double min_ip = max_ips_(knns_ -1, query_);
+
+    if (query_worst_ip > min_ip) {
+      query_worst_ip = min_ip;
+      new_bound = true;
+    }
+  } // for query_
+  
+  // Update the lower bound for the query_node (skipped if no query lowered query_worst_ip)
+  if (new_bound) 
+    query_node->stat().set_bound(query_worst_ip);
+
+} // ComputeBaseCase_
+
+
 // CONE TREE
 void MaxIP::CheckPrune(CTreeTypeA* query_node, TreeType* ref_node) {
 
@@ -526,7 +626,53 @@
 
 }
 
+// PLAIN OLD METRIC TREE: checks a prune by brute force over every query/reference pair of the two nodes
+void MaxIP::CheckPrune(TreeTypeA* query_node, TreeType* ref_node) {
 
+  size_t missed_nns = 0;
+  double max_ip = 0.0;
+  double min_ip = DBL_MAX;
+  // NOTE(review): max_ip starts at 0.0, so an all-negative maximum is reported as 0
+  // Iterating over the queries individually
+  for (query_ = query_node->begin();
+       query_ < query_node->end(); query_++) {
+
+    // Get the query point from the matrix
+    arma::vec q = queries_.unsafe_col(query_);
+    // min_ip tracks the smallest max_ips_(knns_ - 1, q) over the node's queries
+    double ip1 = max_ips_(knns_ -1, query_);
+    if (min_ip > ip1)
+      min_ip = ip1;
+
+    // We'll do the same for the references
+    for (size_t reference_index = ref_node->begin(); 
+	 reference_index < ref_node->end(); reference_index++) {
+
+      arma::vec r = references_.unsafe_col(reference_index);
+      // a reference beating the query's stored entry means the prune lost a candidate
+      double ip = arma::dot(q, r);
+      if (ip > max_ips_(knns_-1, query_))
+	missed_nns++;
+
+      if (ip > max_ip)
+	max_ip = ip;
+      
+    } // for reference_index
+  } // for query_
+  // Report and stop if a candidate was missed or the stored bound differs (exact != on doubles)
+  if (missed_nns > 0 || query_node->stat().bound() != min_ip) {
+    printf("Prune %zu - Missed candidates: %zu\n"
+	   "QLBound: %lg, ActualQLBound: %lg\n"
+	   "QRBound: %lg, ActualQRBound: %lg\n",
+	   number_of_prunes_, missed_nns,
+	   query_node->stat().bound(), min_ip, 
+	   MaxNodeIP_(query_node, ref_node), max_ip);
+    exit(0); // NOTE(review): terminates with status 0 even though the check failed
+  }
+
+}
+
+
 // CONE TREE
 void MaxIP::ComputeNeighborsRecursion_(CTreeTypeA* query_node,
 				       TreeType* reference_node, 
@@ -549,24 +695,57 @@
     ComputeBaseCase_(query_node, reference_node);
   } else if (query_node->is_leaf()) {
     // Only query is a leaf
+   
+    if (CLI::HasParam("maxip/alt_dual_traversal")) {
+      // Trying to do single-tree on the leaves
+
+      // Used to find the query node's new lower bound
+      double query_worst_p_cos_pq = DBL_MAX;
+      bool new_bound = false;
+    
+      // Iterating over the queries individually
+      for (query_ = query_node->begin();
+	   query_ < query_node->end(); query_++) {
+
+	// checking if this node has potential
+	double query_to_node_max_ip = MaxNodeIP_(reference_node);
+
+	if (query_to_node_max_ip > max_ips_(knns_ -1, query_))
+	  // this node has potential
+	  ComputeNeighborsRecursion_(reference_node, query_to_node_max_ip);
+
+	double p_cos_pq = max_ips_(knns_ -1, query_)
+	  / query_norms_(query_);
+
+	if (query_worst_p_cos_pq > p_cos_pq) {
+	  query_worst_p_cos_pq = p_cos_pq;
+	  new_bound = true;
+	}
+      } // for query_
+  
+      // Update the lower bound for the query_node
+      if (new_bound) 
+	query_node->stat().set_bound(query_worst_p_cos_pq);
+    } else {
+
+      // We'll order the computation by distance 
+      double left_p_cos_pq = MaxNodeIP_(query_node,
+					reference_node->left());
+      double right_p_cos_pq = MaxNodeIP_(query_node,
+					 reference_node->right());
       
-    // We'll order the computation by distance 
-    double left_p_cos_pq = MaxNodeIP_(query_node,
-				      reference_node->left());
-    double right_p_cos_pq = MaxNodeIP_(query_node,
-				       reference_node->right());
-      
-    if (left_p_cos_pq > right_p_cos_pq) {
-      ComputeNeighborsRecursion_(query_node, reference_node->left(), 
-				 left_p_cos_pq);
-      ComputeNeighborsRecursion_(query_node, reference_node->right(), 
-				 right_p_cos_pq);
-    } else {
-      ComputeNeighborsRecursion_(query_node, reference_node->right(), 
-				 right_p_cos_pq);
-      ComputeNeighborsRecursion_(query_node, reference_node->left(), 
-				 left_p_cos_pq);
-    }
+      if (left_p_cos_pq > right_p_cos_pq) {
+	ComputeNeighborsRecursion_(query_node, reference_node->left(), 
+				   left_p_cos_pq);
+	ComputeNeighborsRecursion_(query_node, reference_node->right(), 
+				   right_p_cos_pq);
+      } else {
+	ComputeNeighborsRecursion_(query_node, reference_node->right(), 
+				   right_p_cos_pq);
+	ComputeNeighborsRecursion_(query_node, reference_node->left(), 
+				   left_p_cos_pq);
+      }
+    } // alt-traversal
   } else if (reference_node->is_leaf()) {
     // Only reference is a leaf 
     double left_p_cos_pq
@@ -585,104 +764,52 @@
 					  query_node->right()->stat().bound()));
   } else {
 
-    if (CLI::HasParam("maxip/alt_dual_traversal")) {
-      // try something new 
-      // if query_node->radius() > reference_node->stat().cosine_origin()
-      // traverse down the reference tree
-      // else
-      // traverse down the query tree
-
-      if (query_node->bound().radius() 
-	  > reference_node->stat().cosine_origin()) {
-
-	// go down the reference tree
-	double left_p_cos_pq = MaxNodeIP_(query_node, 
-					  reference_node->left());
-	double right_p_cos_pq = MaxNodeIP_(query_node, 
-					   reference_node->right());
+    // Recurse on both as above
+    double left_p_cos_pq = MaxNodeIP_(query_node->left(), 
+				      reference_node->left());
+    double right_p_cos_pq = MaxNodeIP_(query_node->left(), 
+				       reference_node->right());
       
-	if (left_p_cos_pq > right_p_cos_pq) {
-	  ComputeNeighborsRecursion_(query_node,
-				     reference_node->left(), 
-				     left_p_cos_pq);
-	  ComputeNeighborsRecursion_(query_node,
-				     reference_node->right(), 
-				     right_p_cos_pq);
-	} else {
-	  ComputeNeighborsRecursion_(query_node,
-				     reference_node->right(), 
-				     right_p_cos_pq);
-	  ComputeNeighborsRecursion_(query_node,
-				     reference_node->left(), 
-				     left_p_cos_pq);
-	}
-      } else {
-      
-	// go down the query tree
-	double left_p_cos_pq = MaxNodeIP_(query_node->left(),
-					  reference_node);
-	double right_p_cos_pq = MaxNodeIP_(query_node->right(), 
-					   reference_node);
-      
-	ComputeNeighborsRecursion_(query_node->left(),
-				   reference_node, 
-				   left_p_cos_pq);
-	ComputeNeighborsRecursion_(query_node->right(),
-				   reference_node, 
-				   right_p_cos_pq);
+    if (left_p_cos_pq > right_p_cos_pq) {
+      ComputeNeighborsRecursion_(query_node->left(),
+				 reference_node->left(), 
+				 left_p_cos_pq);
+      ComputeNeighborsRecursion_(query_node->left(),
+				 reference_node->right(), 
+				 right_p_cos_pq);
+    } else {
+      ComputeNeighborsRecursion_(query_node->left(),
+				 reference_node->right(), 
+				 right_p_cos_pq);
+      ComputeNeighborsRecursion_(query_node->left(),
+				 reference_node->left(), 
+				 left_p_cos_pq);
+    }
 
-	// Update the upper bound as above
-	query_node->stat().set_bound(std::min(query_node->left()->stat().bound(),
-					      query_node->right()->stat().bound()));
-      }
-    }  else {
-      // Recurse on both as above
-      double left_p_cos_pq = MaxNodeIP_(query_node->left(), 
-					reference_node->left());
-      double right_p_cos_pq = MaxNodeIP_(query_node->left(), 
-					 reference_node->right());
+    left_p_cos_pq = MaxNodeIP_(query_node->right(),
+			       reference_node->left());
+    right_p_cos_pq = MaxNodeIP_(query_node->right(), 
+				reference_node->right());
       
-      if (left_p_cos_pq > right_p_cos_pq) {
-	ComputeNeighborsRecursion_(query_node->left(),
-				   reference_node->left(), 
-				   left_p_cos_pq);
-	ComputeNeighborsRecursion_(query_node->left(),
-				   reference_node->right(), 
-				   right_p_cos_pq);
-      } else {
-	ComputeNeighborsRecursion_(query_node->left(),
-				   reference_node->right(), 
-				   right_p_cos_pq);
-	ComputeNeighborsRecursion_(query_node->left(),
-				   reference_node->left(), 
-				   left_p_cos_pq);
-      }
-
-      left_p_cos_pq = MaxNodeIP_(query_node->right(),
-				 reference_node->left());
-      right_p_cos_pq = MaxNodeIP_(query_node->right(), 
-				  reference_node->right());
+    if (left_p_cos_pq > right_p_cos_pq) {
+      ComputeNeighborsRecursion_(query_node->right(),
+				 reference_node->left(), 
+				 left_p_cos_pq);
+      ComputeNeighborsRecursion_(query_node->right(),
+				 reference_node->right(), 
+				 right_p_cos_pq);
+    } else {
+      ComputeNeighborsRecursion_(query_node->right(),
+				 reference_node->right(), 
+				 right_p_cos_pq);
+      ComputeNeighborsRecursion_(query_node->right(),
+				 reference_node->left(), 
+				 left_p_cos_pq);
+    }
       
-      if (left_p_cos_pq > right_p_cos_pq) {
-	ComputeNeighborsRecursion_(query_node->right(),
-				   reference_node->left(), 
-				   left_p_cos_pq);
-	ComputeNeighborsRecursion_(query_node->right(),
-				   reference_node->right(), 
-				   right_p_cos_pq);
-      } else {
-	ComputeNeighborsRecursion_(query_node->right(),
-				   reference_node->right(), 
-				   right_p_cos_pq);
-	ComputeNeighborsRecursion_(query_node->right(),
-				   reference_node->left(), 
-				   left_p_cos_pq);
-      }
-      
-      // Update the upper bound as above
-      query_node->stat().set_bound(std::min(query_node->left()->stat().bound(),
-					    query_node->right()->stat().bound()));
-    } // alt-traversal
+    // Update the upper bound as above
+    query_node->stat().set_bound(std::min(query_node->left()->stat().bound(),
+					  query_node->right()->stat().bound()));
   } // All cases of dual-tree traversal
 } // ComputeNeighborsRecursion_
 
@@ -830,7 +957,147 @@
 } // ComputeNeighborsRecursion_
 
 
+// PLAIN OLD METRIC TREE: dual-tree recursion over a ball-tree query node and a reference node
+void MaxIP::ComputeNeighborsRecursion_(TreeTypeA* query_node,
+				       TreeType* reference_node, 
+				       double upper_bound_ip) {
 
+  assert(query_node != NULL);
+  assert(reference_node != NULL);
+  //assert(upper_bound_ip == MaxNodeIP_(query_node, reference_node));
+  // Prune when the pair's IP bound is below the bound stored on the query node
+  if (upper_bound_ip < query_node->stat().bound()) { 
+    // Pruned: one prune is credited per query point in this node
+    number_of_prunes_ += query_node->end() - query_node->begin();
+
+    if (CLI::HasParam("maxip/check_prune"))
+      CheckPrune(query_node, reference_node);
+  }
+  // node->is_leaf() works as one would expect
+  else if (query_node->is_leaf() && reference_node->is_leaf()) {
+    // Base Case
+    ComputeBaseCase_(query_node, reference_node);
+  } else if (query_node->is_leaf()) {
+    // Only query is a leaf
+      
+    if (CLI::HasParam("maxip/alt_dual_traversal")) {
+      // Run the single-tree search from this reference node for each query point
+
+      // Smallest max_ips_(knns_ - 1, q) over this node's queries; becomes the node's bound
+      double query_worst_ip = DBL_MAX;
+      bool new_bound = false;
+    
+      // Iterating over the queries individually
+      for (query_ = query_node->begin();
+	   query_ < query_node->end(); query_++) {
+
+	// checking if this node has potential
+	double query_to_node_max_ip = MaxNodeIP_(reference_node);
+
+	if (query_to_node_max_ip > max_ips_(knns_ -1, query_))
+	  // this node has potential
+	  ComputeNeighborsRecursion_(reference_node, query_to_node_max_ip);
+
+	double ip = max_ips_(knns_ -1, query_); 
+	
+	if (query_worst_ip > ip) {
+	  query_worst_ip = ip;
+	  new_bound = true;
+	}
+      } // for query_
+  
+      // Update the lower bound for the query_node
+      if (new_bound) 
+	query_node->stat().set_bound(query_worst_ip);
+    } else {
+
+      // Visit the reference child with the larger IP bound first
+      double left_ip = MaxNodeIP_(query_node,
+				  reference_node->left());
+      double right_ip = MaxNodeIP_(query_node,
+				   reference_node->right());
+      
+      if (left_ip > right_ip) {
+	ComputeNeighborsRecursion_(query_node, reference_node->left(), 
+				   left_ip);
+	ComputeNeighborsRecursion_(query_node, reference_node->right(), 
+				   right_ip);
+      } else {
+	ComputeNeighborsRecursion_(query_node, reference_node->right(), 
+				   right_ip);
+	ComputeNeighborsRecursion_(query_node, reference_node->left(), 
+				   left_ip);
+      }
+    } // alt-traversal
+  } else if (reference_node->is_leaf()) {
+    // Only reference is a leaf: recurse into both query children, left first
+    double left_ip
+      = MaxNodeIP_(query_node->left(), reference_node);
+    double right_ip
+      = MaxNodeIP_(query_node->right(), reference_node);
+      
+    ComputeNeighborsRecursion_(query_node->left(), reference_node, 
+			       left_ip);
+    ComputeNeighborsRecursion_(query_node->right(), reference_node, 
+			       right_ip);
+      
+    // Refresh this node's bound (the value the prune test above compares against)
+    // as the smaller of its two children's bounds
+    query_node->stat().set_bound(std::min(query_node->left()->stat().bound(),
+					  query_node->right()->stat().bound()));
+  } else {
+
+    // Both internal: for each query child, visit the reference child with the larger IP bound first
+    double left_ip = MaxNodeIP_(query_node->left(), 
+				reference_node->left());
+    double right_ip = MaxNodeIP_(query_node->left(), 
+				 reference_node->right());
+      
+    if (left_ip > right_ip) {
+      ComputeNeighborsRecursion_(query_node->left(),
+				 reference_node->left(), 
+				 left_ip);
+      ComputeNeighborsRecursion_(query_node->left(),
+				 reference_node->right(), 
+				 right_ip);
+    } else {
+      ComputeNeighborsRecursion_(query_node->left(),
+				 reference_node->right(), 
+				 right_ip);
+      ComputeNeighborsRecursion_(query_node->left(),
+				 reference_node->left(), 
+				 left_ip);
+    }
+    // Same again for the right query child; left_ip / right_ip are reused
+    left_ip = MaxNodeIP_(query_node->right(),
+			 reference_node->left());
+    right_ip = MaxNodeIP_(query_node->right(), 
+			  reference_node->right());
+      
+    if (left_ip > right_ip) {
+      ComputeNeighborsRecursion_(query_node->right(),
+				 reference_node->left(), 
+				 left_ip);
+      ComputeNeighborsRecursion_(query_node->right(),
+				 reference_node->right(), 
+				 right_ip);
+    } else {
+      ComputeNeighborsRecursion_(query_node->right(),
+				 reference_node->right(), 
+				 right_ip);
+      ComputeNeighborsRecursion_(query_node->right(),
+				 reference_node->left(), 
+				 left_ip);
+    }
+      
+    // Refresh this node's bound from its children (min), as above
+    query_node->stat().set_bound(std::min(query_node->left()->stat().bound(),
+					  query_node->right()->stat().bound()));
+  } // All cases of dual-tree traversal
+} // ComputeNeighborsRecursion_
+
+
+
 void MaxIP::Init(const arma::mat& queries_in,
 		 const arma::mat& references_in) {
     
@@ -877,22 +1144,41 @@
    
   if (mlpack::CLI::HasParam("maxip/dual_tree")) {
 
+    size_t qleaf_size;
+
+    if (mlpack::CLI::GetParam<double>("maxip/qleaf_size") != 0.0) 
+
+      qleaf_size = (size_t) (mlpack::CLI::GetParam<double>("maxip/qleaf_size")
+			     * (double) queries_.n_cols) / 100.0;
+    else 
+      qleaf_size = leaf_size_ - 1;
+
     if (mlpack::CLI::HasParam("maxip/alt_tree")) { 
       
       // using cosine tree 
       query_tree_B_
 	= proximity::MakeGenCosineTree<CTreeTypeB>(queries_,
-						   leaf_size_,
+						   qleaf_size + 1,
 						   &old_from_new_queries_,
 						   NULL);
       set_norms_in_cones_(query_tree_B_);
     
+    } else if (mlpack::CLI::HasParam("maxip/alt_tree2")) { 
+      
+      // using plain old Metric tree
+      query_tree_
+	= proximity::MakeGenMetricTree<TreeTypeA>(queries_,
+						  qleaf_size + 1,
+						  &old_from_new_queries_,
+						  NULL);
+      set_angles_in_balls_(query_tree_);
+    
     } else {
       
       // using cone tree
       query_tree_A_
 	= proximity::MakeGenConeTree<CTreeTypeA>(queries_,
-						 leaf_size_,
+						 qleaf_size + 1,
 						 &old_from_new_queries_,
 						 NULL);
       set_norms_in_cones_(query_tree_A_);
@@ -986,12 +1272,14 @@
 
       reset_tree_(reference_tree_);
     
+    } else if (mlpack::CLI::HasParam("maxip/alt_tree2")) {
+      if (query_tree_ != NULL) 
+	reset_tree_(query_tree_);
+    
     } else {
-	if (query_tree_A_ != NULL)
-	  reset_tree_(query_tree_A_);
+      if (query_tree_A_ != NULL)
+	reset_tree_(query_tree_A_);
     }
-
-
 } // WarmInit
 
 void MaxIP::reset_tree_(CTreeTypeA* tree) {
@@ -1019,7 +1307,19 @@
   return;
 } // reset_tree_
 
+void MaxIP::reset_tree_(TreeTypeA* tree) { // Sets the stored bound to 0.0 on this node and every descendant
+  assert(tree != NULL);
+  tree->stat().set_bound(0.0);
 
+  if (!tree->is_leaf()) { // non-leaf: recurse into both children
+    reset_tree_(tree->left());
+    reset_tree_(tree->right());
+  }
+
+  return;
+} // reset_tree_
+
+
 void MaxIP::reset_tree_(TreeType* tree) {
   assert(tree != NULL);
   tree->stat().reset();
@@ -1040,7 +1340,7 @@
   double c_norm = arma::norm(tree->bound().center(), 2);
   double rad = std::sqrt(tree->bound().radius());
   
-  tree->stat().set_dist_to_origin(c_norm);
+  tree->stat().set_center_norm(c_norm);
   if (rad <= c_norm)
     tree->stat().set_angles(rad / c_norm, (size_t) 1);
   else
@@ -1055,7 +1355,30 @@
   return;
 } // set_angles_in_balls_
 
+void MaxIP::set_angles_in_balls_(TreeTypeA* tree) { // Fills center norm and angle stats on this node and all descendants
 
+  assert(tree != NULL);
+
+  // set up node stats: norm of the ball center and sqrt of bound().radius()
+  double c_norm = arma::norm(tree->bound().center(), 2);
+  double rad = std::sqrt(tree->bound().radius());
+  
+  tree->stat().set_center_norm(c_norm);
+  if (rad <= c_norm) // NOTE(review): rad == c_norm == 0 yields 0/0 here -- confirm it cannot occur
+    tree->stat().set_angles(rad / c_norm, (size_t) 1); // type 1: rad / c_norm is stored as the sine
+  else
+    tree->stat().set_angles(-1.0, (size_t) 0); // type 0: cosine stored as -1.0 when rad > c_norm
+
+  // traverse down the children
+  if (!tree->is_leaf()) {
+    set_angles_in_balls_(tree->left());
+    set_angles_in_balls_(tree->right());
+  }
+
+  return;
+} // set_angles_in_balls_
+
+
 void MaxIP::set_norms_in_cones_(CTreeTypeA* tree) {
 
   assert(tree != NULL);
@@ -1089,8 +1412,6 @@
 } // set_norms_in_cones_
 
 
-
-
 double MaxIP::ComputeNeighbors(arma::Mat<size_t>* resulting_neighbors,
 			       arma::mat* ips) {
 
@@ -1106,6 +1427,9 @@
     if (mlpack::CLI::HasParam("maxip/alt_tree"))
       ComputeNeighborsRecursion_(query_tree_B_, reference_tree_,
 				 MaxNodeIP_(query_tree_B_, reference_tree_));
+    else if (mlpack::CLI::HasParam("maxip/alt_tree2"))
+      ComputeNeighborsRecursion_(query_tree_, reference_tree_,
+				 MaxNodeIP_(query_tree_, reference_tree_));
     else
       ComputeNeighborsRecursion_(query_tree_A_, reference_tree_,
 				 MaxNodeIP_(query_tree_A_, reference_tree_));

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.h	2011-11-08 17:37:25 UTC (rev 10185)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.h	2011-11-08 18:07:01 UTC (rev 10186)
@@ -35,10 +35,14 @@
   private:
     double bound_;
     double center_norm_;
+    double cosine_origin_;
+    double sine_origin_;
 
   public:
     double bound() { return bound_; }
     double center_norm() { return center_norm_; }
+    double cosine_origin() { return cosine_origin_; }
+    double sine_origin() { return sine_origin_; }
 
     void set_bound(double bound) { 
       bound_ = bound;
@@ -48,9 +52,21 @@
       center_norm_ = val; 
     }
 
+   void set_angles(double val, size_t type = 0) { // Stores val as cosine (type 0) or sine (otherwise) and derives the other
+      if (type == 0) { // given value is the cosine
+	cosine_origin_ = val;
+	sine_origin_ = std::sqrt(1 - val * val); // non-negative root of 1 - val^2
+      } else { // the given value is the sine
+	sine_origin_ = val;
+	cosine_origin_ = std::sqrt(1 - val * val); // non-negative root of 1 - val^2; NaN if |val| > 1
+      }
+    }
+
     QueryStat() {
       bound_ = 0.0;
       center_norm_ = 0.0;
+      cosine_origin_ = 0.0;
+      sine_origin_ = 0.0;
     }
 
     ~QueryStat() {}
@@ -58,27 +74,31 @@
     void Init(const arma::mat& data, size_t begin, size_t count) {
       bound_ = 0.0;
       center_norm_ = 0.0;
+      cosine_origin_ = 0.0;
+      sine_origin_ = 0.0;
     }
 
     void Init(const arma::mat& data, size_t begin, size_t count,
 	      QueryStat& left_stat, QueryStat& right_stat) {
       bound_ = 0.0;
       center_norm_ = 0.0;
+      cosine_origin_ = 0.0;
+      sine_origin_ = 0.0;
     }
   }; // QueryStat
 
   class RefStat {
   private:
+    double center_norm_;
     double cosine_origin_;
     double sine_origin_;
-    double dist_to_origin_;
     bool has_cos_phi_;
     double cos_phi_;
 
   public:
     double cosine_origin() { return cosine_origin_; }
     double sine_origin() { return sine_origin_; }
-    double dist_to_origin() { return dist_to_origin_; }
+    double center_norm() { return center_norm_; }
     bool has_cos_phi() { return has_cos_phi_; }
     double cos_phi() { return cos_phi_; }
 
@@ -93,8 +113,8 @@
       }
     }
 
-    void set_dist_to_origin(double val) {
-      dist_to_origin_ = val;
+    void set_center_norm(double val) {
+      center_norm_ = val;
     }
 
     void set_cos_phi(double cos_phi) {
@@ -111,7 +131,7 @@
     RefStat() {
       cosine_origin_ = 0.0;
       sine_origin_ = 0.0;
-      dist_to_origin_ = 0.0;
+      center_norm_ = 0.0;
       has_cos_phi_ = false;
       cos_phi_ = -1.0;
     }
@@ -120,7 +140,7 @@
     void Init(const arma::mat& data, size_t begin, size_t count) {
       cosine_origin_ = 0.0;
       sine_origin_ = 0.0;
-      dist_to_origin_ = 0.0;
+      center_norm_ = 0.0;
       has_cos_phi_ = false;
       cos_phi_ = -1.0;
     }
@@ -129,7 +149,7 @@
 	      RefStat& left_stat, RefStat& right_stat) {
       cosine_origin_ = 0.0;
       sine_origin_ = 0.0;
-      dist_to_origin_ = 0.0;
+      center_norm_ = 0.0;
       has_cos_phi_ = false;
       cos_phi_ = -1.0;
     }
@@ -141,6 +161,7 @@
   typedef GeneralBinarySpaceTree<bound::DBallBound<>, arma::mat, RefStat> TreeType;
   typedef GeneralBinarySpaceTree<DConeBound<>, arma::mat, QueryStat> CTreeTypeA;
   typedef GeneralBinarySpaceTree<DCosineBound<>, arma::mat, QueryStat> CTreeTypeB;
+  typedef GeneralBinarySpaceTree<bound::DBallBound<>, arma::mat, QueryStat> TreeTypeA;
    
   
   /////////////////////////////// Members ////////////////////////////
@@ -158,6 +179,8 @@
   CTreeTypeA* query_tree_A_;
   CTreeTypeB* query_tree_B_;
 
+  TreeTypeA* query_tree_;
+
   // The total number of prunes.
   size_t number_of_prunes_;
   size_t ball_has_origin_;
@@ -196,6 +219,7 @@
     reference_tree_ = NULL;
     query_tree_A_ = NULL;
     query_tree_B_ = NULL;
+    query_tree_ = NULL;
   } 
   
   /**
@@ -212,6 +236,9 @@
     if (query_tree_B_ != NULL)
       delete query_tree_B_;
 
+    if (query_tree_ != NULL)
+      delete query_tree_;
+
   }
     
   /////////////////////////// Helper Functions //////////////////////
@@ -231,6 +258,7 @@
    */
   double MaxNodeIP_(CTreeTypeA *query_node, TreeType* reference_node);
   double MaxNodeIP_(CTreeTypeB *query_node, TreeType* reference_node);
+  double MaxNodeIP_(TreeTypeA *query_node, TreeType* reference_node);
 
   /**
    * Performs exhaustive computation at the leaves.  
@@ -242,6 +270,7 @@
    */
   void ComputeBaseCase_(CTreeTypeA* query_node, TreeType* reference_node);
   void ComputeBaseCase_(CTreeTypeB* query_node, TreeType* reference_node);
+  void ComputeBaseCase_(TreeTypeA* query_node, TreeType* reference_node);
   
   /**
    * The recursive function
@@ -260,10 +289,18 @@
 				  TreeType* reference_node, 
 				  double upper_bound_ip);
 
+  void ComputeNeighborsRecursion_(TreeTypeA* query_node,
+				  TreeType* reference_node, 
+				  double upper_bound_ip);
+
   void reset_tree_(CTreeTypeA *tree);
   void reset_tree_(CTreeTypeB *tree);
+  void reset_tree_(TreeTypeA *tree);
   void reset_tree_(TreeType *tree);
+
   void set_angles_in_balls_(TreeType *tree);
+  void set_angles_in_balls_(TreeTypeA *tree);
+
   void set_norms_in_cones_(CTreeTypeA *tree);
   void set_norms_in_cones_(CTreeTypeB *tree);
 
@@ -288,6 +325,8 @@
       delete query_tree_A_;
     if (query_tree_B_ != NULL)
       delete query_tree_B_;
+    if (query_tree_ != NULL)
+      delete query_tree_;
   }
 
   /**
@@ -317,6 +356,7 @@
 
   void CheckPrune(CTreeTypeA* query_node, TreeType* ref_node);
   void CheckPrune(CTreeTypeB* query_node, TreeType* ref_node);
+  void CheckPrune(TreeTypeA* query_node, TreeType* ref_node);
 }; //class MaxIP
 
 
@@ -329,6 +369,8 @@
 	  "maxip", 1);
 PARAM_INT("leaf_size", "The leaf size for the ball-tree", 
 	  "maxip", 20);
+PARAM_DOUBLE("qleaf_size", "The leaf size for the query-tree", 
+	     "maxip", 0.0);
 
 PARAM_FLAG("angle_prune", "The flag to trigger the tighter"
 	   " pruning using the angles as well", "maxip");
@@ -341,7 +383,11 @@
 PARAM_FLAG("alt_tree", "The flag to trigger the "
 	   "alternate query tree.",
 	   "maxip");
+PARAM_FLAG("alt_tree2", "The flag to trigger the "
+	   "yet another alternate query tree.",
+	   "maxip");
 
+
 PARAM_FLAG("check_prune", "The flag to trigger the "
 	   "checking of the prune.", "maxip");
 PARAM_FLAG("alt_dual_traversal", "The flag to trigger the "

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/max_ip_main.cc
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/max_ip_main.cc	2011-11-08 17:37:25 UTC (rev 10185)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/max_ip_main.cc	2011-11-08 18:07:01 UTC (rev 10186)
@@ -31,8 +31,8 @@
 PARAM_FLAG("dofastexact", "The flag to trigger the tree-based"
 	   " search algorithm", "");
 
-// PARAM_FLAG("print_results", "The flag to trigger the "
-// 	   "printing of the output", "");
+PARAM_FLAG("check_results", "The flag to trigger the "
+ 	   "checking of the output", "");
 
 // PARAM_STRING("maxip_file", "The file where the output "
 // 	     "will be written into", "", "results.txt");
@@ -116,46 +116,48 @@
     fast_comp = fast_exact.ComputeNeighbors(&exc, &die);
     Log::Info << "Tree-based Max IP Computed." << endl;
 
-//     if (CLI::HasParam("print_results")) {
-//       FILE *fp=fopen(CLI::GetParam<string>("maxip_file").c_str(), "w");
-//       if (fp == NULL)
-// 	Log::Fatal << "Error while opening " 
-// 		  << CLI::GetParam<string>("maxip_file") 
-// 		  << endl;
-
-//       for(size_t i = 0 ; i < exc.n_elem / knns ; i++) {
-//         fprintf(fp, "%zu", i);
-//         for(size_t j = 0; j < knns; j++)
-//           fprintf(fp, ", %zu, %lg", 
-//                   exc(i*knns+j), die(i*knns+j));
-//         fprintf(fp, "\n");
-//       }
-//       fclose(fp);
-//     }
-  }
-
-  if (CLI::HasParam("donaive") && CLI::HasParam("dofastexact")) {
-    check_nn_utils::count_mismatched_neighbors(nac, din, exc, die);
     Log::Warn << "Speed of fast-exact over naive: "
-	      << naive_comp << " / " << (float) fast_comp << " = "
-	      <<(float) (naive_comp / fast_comp) << endl;
-  } else if (CLI::HasParam("dofastexact")) {
-    Log::Warn << "Speed of fast-exact over naive: "
 	      << rdata.n_cols  << " / " << (float) fast_comp << " = "
 	      <<(float) (rdata.n_cols / fast_comp) << endl;
 
-    if (CLI::GetParam<string>("rank_file") != "") {
-      string rank_file = CLI::GetParam<string>("rank_file");
-      check_nn_utils::compute_error(rank_file,  rdata.n_cols, &exc);
-    } else {
-      check_nn_utils::compute_error(&rdata, &qdata, &exc);
+    //     if (CLI::HasParam("print_results")) {
+    //       FILE *fp=fopen(CLI::GetParam<string>("maxip_file").c_str(), "w");
+    //       if (fp == NULL)
+    // 	Log::Fatal << "Error while opening " 
+    // 		  << CLI::GetParam<string>("maxip_file") 
+    // 		  << endl;
+
+    //       for(size_t i = 0 ; i < exc.n_elem / knns ; i++) {
+    //         fprintf(fp, "%zu", i);
+    //         for(size_t j = 0; j < knns; j++)
+    //           fprintf(fp, ", %zu, %lg", 
+    //                   exc(i*knns+j), die(i*knns+j));
+    //         fprintf(fp, "\n");
+    //       }
+    //       fclose(fp);
+    //     }
+  }
+
+  if (CLI::HasParam("check_results")) {
+    if (CLI::HasParam("donaive") && CLI::HasParam("dofastexact")) {
+      check_nn_utils::count_mismatched_neighbors(nac, din, exc, die);
+      Log::Warn << "Speed of fast-exact over naive: "
+		<< naive_comp << " / " << (float) fast_comp << " = "
+		<<(float) (naive_comp / fast_comp) << endl;
+    } else if (CLI::HasParam("dofastexact")) {
+      if (CLI::GetParam<string>("rank_file") != "") {
+	string rank_file = CLI::GetParam<string>("rank_file");
+	check_nn_utils::compute_error(rank_file,  rdata.n_cols, &exc);
+      } else {
+	check_nn_utils::compute_error(&rdata, &qdata, &exc);
+      }
+    } else if (CLI::HasParam("donaive")) {
+      if (CLI::GetParam<string>("rank_file") != "") {
+	string rank_file = CLI::GetParam<string>("rank_file");
+	check_nn_utils::compute_error(rank_file,  rdata.n_cols, &nac);
+      } else {
+	check_nn_utils::compute_error(&rdata, &qdata, &nac);
+      }
     }
-  } else if (CLI::HasParam("donaive")) {
-    if (CLI::GetParam<string>("rank_file") != "") {
-      string rank_file = CLI::GetParam<string>("rank_file");
-      check_nn_utils::compute_error(rank_file,  rdata.n_cols, &nac);
-    } else {
-      check_nn_utils::compute_error(&rdata, &qdata, &nac);
-    }
   }
 }

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/max_ip_tester.cc
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/max_ip_tester.cc	2011-11-08 17:37:25 UTC (rev 10185)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/max_ip_tester.cc	2011-11-08 18:07:01 UTC (rev 10186)
@@ -1,8 +1,10 @@
-#include <fastlib/fastlib.h>
+#include <armadillo>
+#include <string>
+
+#include <mlpack/core.h>
+
 #include "exact_max_ip.h"
 
-#include <string>
-#include <armadillo>
 
 using namespace mlpack;
 using namespace std;
@@ -53,128 +55,103 @@
   string qfile = CLI::GetParam<string>("q");
 
   Log::Warn << "Loading files..." << endl;
-  if (!data::Load(rfile.c_str(), rdata))
+  if (rdata.load(rfile.c_str()) == false)
     Log::Fatal << "Reference file "<< rfile << " not found." << endl;
 
-  if (!data::Load(qfile.c_str(), qdata)) 
+  if (qdata.load(qfile.c_str()) == false) 
     Log::Fatal << "Query file " << qfile << " not found." << endl;
 
+  rdata = arma::trans(rdata);
+  qdata = arma::trans(qdata);
+
   Log::Warn << "File loaded..." << endl;
   
   Log::Warn << "R(" << rdata.n_rows << ", " << rdata.n_cols 
 	   << "), Q(" << qdata.n_rows << ", " << qdata.n_cols 
 	   << ")" << endl;
 
+ 
+  arma::Col<size_t> ks(10);
+  size_t max_k = 1, number_of_ks = 1;
+  ks(0) = 1;
 
 
-  if (CLI::HasParam("max_k")) {
+  if (CLI::GetParam<int>("max_k") != 1) {
 
-    MaxIP naive, fast_exact;
-    double naive_comp = (double) rdata.n_cols;
-    arma::Mat<size_t> nac;
-    arma::mat din;
+    max_k = CLI::GetParam<int>("max_k");
+    number_of_ks = max_k;
+    ks.set_size(number_of_ks);
 
-    size_t max_k = CLI::GetParam<int>("max_k");
-    arma::vec speedups(max_k);
+    for (size_t i = 0; i < max_k; i++)
+      ks(i) = i+1;
+  } else if (CLI::GetParam<string>("k_values") != "") {
 
-    if (CLI::HasParam("check_nn")) { 
-      Log::Warn << "Starting naive computation..." <<endl;
-      naive.InitNaive(qdata, rdata);
-      naive.WarmInit(max_k);
-      naive_comp = naive.ComputeNaive(&nac, &din);
-      Log::Warn << "Naive computation done..." << endl;
-    }
-
-    Log::Warn << "Starting loop for Fast Exact Search." << endl;
-
-    fast_exact.Init(qdata, rdata);
-
-    for (knns = 1; knns <= max_k; knns++) {
-
-      printf("k = %zu", knns); fflush(NULL);
-      arma::Mat<size_t> exc;
-      arma::mat die;
-      double fast_comp = fast_exact.ComputeNeighbors(&exc, &die);
-
-      if (CLI::HasParam("check_nn")) {
-	size_t errors = count_mismatched_neighbors(din, max_k, die, knns);
-
-	if (errors > 0) {
-	  Log::Warn << knns << "-NN error: " << errors << " / "
-		    << exc.n_elem << endl;
-	}
-      }
-
-      speedups(knns -1) = naive_comp / fast_comp;
-      printf(": %lg\n", speedups(knns -1));
-
-      fast_exact.WarmInit(knns+1);
-    }
-
-    printf("\n");
-  }
-
-  if (CLI::HasParam("k_values")) {
-
-    MaxIP naive, fast_exact;
-    double naive_comp = (double) rdata.n_cols;
-    arma::Col<size_t> nac;
-    arma::vec din;
-
-    size_t number_of_ks = 0;
+    number_of_ks = 0;
     string k_values = CLI::GetParam<string>("k_values");
-    arma::Col<size_t> ks(10);
 
-    char *temp = (char *) k_values.c_str();
-    char *pch = strtok(temp, ",");
+
+    char *pch = strtok((char *) k_values.c_str(), ",");
     while (pch != NULL) {
       ks(number_of_ks++) = atoi(pch);
       pch = strtok(NULL, ",");
     }
 
-    arma::vec speedups(number_of_ks);
+    free(pch);
+    max_k = ks(number_of_ks -1);
+  }
 
-    if (CLI::HasParam("check_nn")) { 
-      Log::Warn << "Starting naive computation..." <<endl;
-      naive.InitNaive(qdata, rdata);
-      naive.WarmInit(ks(number_of_ks - 1));
-      naive_comp = naive.ComputeNaive(&nac, &din);
-      Log::Warn << "Naive computation done..." << endl;
-    }
+  MaxIP naive, fast_exact;
+  double naive_comp = (double) rdata.n_cols;
+  arma::Mat<size_t> nac;
+  arma::mat din;
 
-    Log::Warn << "Starting loop for Fast Exact Search." << endl;
+  arma::vec speedups(number_of_ks);
 
-    fast_exact.Init(qdata, rdata);
+  if (CLI::HasParam("check_nn")) { 
+    Log::Warn << "Starting naive computation..." <<endl;
+    naive.InitNaive(qdata, rdata);
+    naive.WarmInit(ks(number_of_ks - 1));
+    naive_comp = naive.ComputeNaive(&nac, &din);
+    Log::Warn << "Naive computation done..." << endl;
+  }
 
-    for (size_t knns = 0; knns < number_of_ks; knns++) {
+  Log::Warn << "Starting loop for Fast Exact Search." << endl;
 
-      printf("k = %zu", ks(knns)); fflush(NULL);
-      arma::Col<size_t> exc;
-      arma::vec die;
-      double fast_comp = fast_exact.ComputeNeighbors(&exc, &die);
+  fast_exact.Init(qdata, rdata);
 
-      if (CLI::HasParam("check_nn")) {
-	size_t errors = count_mismatched_neighbors(din, ks(number_of_ks -1),
-						   die, ks(knns));
+  for (size_t knns = 0; knns < number_of_ks; knns++) {
+    
+    arma::Mat<size_t> exc;
+    arma::mat die;
+    double fast_comp = fast_exact.ComputeNeighbors(&exc, &die);
 
-	if (errors > 0) {
-	  Log::Warn << ks(knns) << "-NN error: " << errors << " / "
-		    << exc.n_elem << endl;
-	}
-      }
+//     printf("Search done %lg\n", fast_comp); fflush(NULL);
 
-      speedups(knns -1) = naive_comp / fast_comp;
-      printf(": %lg\n", speedups(knns -1));
+    if (CLI::HasParam("check_nn")) {
+      size_t errors = count_mismatched_neighbors(din, ks(number_of_ks -1),
+						 die, ks(knns));
 
-      if (knns < number_of_ks - 1)
-	fast_exact.WarmInit(ks(knns+1));
+      if (errors > 0) {
+	Log::Warn << ks(knns) << "-NN error: " << errors << " / "
+		  << exc.n_elem << endl;
+      }
     }
 
-    printf("\n");
+    speedups(knns) = naive_comp / fast_comp;
+    printf("k = %zu", ks(knns)); fflush(NULL);
+    printf(": %lg\n", speedups(knns)); fflush(NULL);
+
+    if (knns < number_of_ks - 1)
+      fast_exact.WarmInit(ks(knns+1));
+
+    exc.reset();
+    die.reset();
   }
+
+  printf("\n");
   
   Log::Warn << "Search completed for all values of k...printing results now"
-	   << endl;
+	    << endl;
 
 //   if (CLI::HasParam("print_speedups")) {
 //     string speedup_file = CLI::GetParam<string>("speedup_file");




More information about the mlpack-svn mailing list