[mlpack-svn] r10076 - mlpack/trunk/src/contrib/pram/max_ip_search

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Sat Oct 29 23:16:22 EDT 2011


Author: pram
Date: 2011-10-29 23:16:22 -0400 (Sat, 29 Oct 2011)
New Revision: 10076

Added:
   mlpack/trunk/src/contrib/pram/max_ip_search/dcosinebound.h
Modified:
   mlpack/trunk/src/contrib/pram/max_ip_search/CMakeLists.txt
   mlpack/trunk/src/contrib/pram/max_ip_search/cosine_tree_test.cc
   mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.cc
   mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.h
   mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree.h
   mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree_impl.h
   mlpack/trunk/src/contrib/pram/max_ip_search/gen_cosine_tree.h
   mlpack/trunk/src/contrib/pram/max_ip_search/gen_cosine_tree_impl.h
   mlpack/trunk/src/contrib/pram/max_ip_search/gen_metric_tree.h
   mlpack/trunk/src/contrib/pram/max_ip_search/gen_metric_tree_impl.h
Log:
Cosine tree -- co-axial cone tree added and tested

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/CMakeLists.txt	2011-10-29 23:16:02 UTC (rev 10075)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/CMakeLists.txt	2011-10-30 03:16:22 UTC (rev 10076)
@@ -10,14 +10,17 @@
    gen_metric_tree.h
    gen_metric_tree_impl.h
    gen_metric_tree_impl.cc
+
+   cosine.h
+
+   # Cosine tree
+   dcosinebound.h
+   gen_cosine_tree.h
+   gen_cosine_tree_impl.h
+   #gen_cosine_tree_impl.cc
    # Cone bound
-   cosine.h
    dconebound.h
    dconebound_impl.h
-   # Cosine tree
-   #gen_cosine_tree.h
-   #gen_cosine_tree_impl.h
-   #gen_cosine_tree_impl.cc
    # cone tree
    gen_cone_tree.h
    gen_cone_tree_impl.h

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/cosine_tree_test.cc
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/cosine_tree_test.cc	2011-10-29 23:16:02 UTC (rev 10075)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/cosine_tree_test.cc	2011-10-30 03:16:22 UTC (rev 10076)
@@ -1,9 +1,12 @@
 #include <armadillo>
 #include <string>
 #include "general_spacetree.h"
-#include "dconebound.h"
+// #include "dconebound.h"
+// #include "gen_cosine_tree.h"
+#include "dcosinebound.h"
 #include "gen_cosine_tree.h"
 
+
 using namespace mlpack;
 using namespace std;
 
@@ -32,21 +35,20 @@
 
   rdata = arma::trans(rdata);
 
-  typedef GeneralBinarySpaceTree<DConeBound<>, arma::mat> CTreeType;
+  typedef GeneralBinarySpaceTree<DCosineBound<>, arma::mat> CTreeType;
 
   arma::Col<size_t> old_from_new_data;
 
   CTreeType *test_tree
     = proximity::MakeGenCosineTree<CTreeType>(rdata, 20,
-					      &old_from_new_data,
-					      NULL);
+					      &old_from_new_data); //,
+// 					      NULL);
 
   if (CLI::HasParam("print_tree")) {
     test_tree->Print();
   } else {
     Log::Info << "Tree built" << endl;
   }
-
 } // end main
 
 

Added: mlpack/trunk/src/contrib/pram/max_ip_search/dcosinebound.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/dcosinebound.h	                        (rev 0)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/dcosinebound.h	2011-10-30 03:16:22 UTC (rev 10076)
@@ -0,0 +1,139 @@
+/**
+ * @file dcosinebound.h
+ *
+ * Bounds that are useful for binary space partitioning trees.
+ * Interface to a cosine bound that works with cosine 
+ * similarity only (for now)
+ *
+ * @experimental
+ */
+
+#ifndef TREE_DCOSINEBOUND_H
+#define TREE_DCOSINEBOUND_H
+
+#include "mlpack/core/math/range.hpp"
+
+// Awaiting transition
+#include "cosine.h"
+
+#include <armadillo>
+
+/**
+ * Cosine bound: a center point plus a [rad_min, rad_max] range, which
+ * MaxIP::MaxNodeIP_ reads as cosines (cos_w_min, cos_w_max).
+ * TMetric defaults to Cosine; only cosine similarity is supported for now.
+ *
+ * To initialize, call @c set_radius(rad_min, rad_max) and set the center
+ * by assigning through the non-const @c center() accessor.
+ */
+template<typename TMetric = Cosine, typename TPoint = arma::vec>
+class DCosineBound {
+public:
+  typedef TPoint Point;
+  typedef TMetric Metric;
+
+private:
+  double rad_min_;  // lower end of the range
+  double rad_max_;  // upper end of the range
+  double radius_;   // rad_max_ - rad_min_, kept in sync by set_radius()
+  TPoint center_;
+  // NOTE(review): no constructor; the doubles above are uninitialized until set_radius().
+public:
+  /***
+   * Lower end, upper end, and width (rad_max - rad_min) of the range.
+   */
+  double rad_min() const { return rad_min_; }
+  double rad_max() const { return rad_max_; }
+  double radius() const { return radius_; }
+
+  /***
+   * Set the range [rad_min, rad_max]; radius() becomes rad_max - rad_min.
+   */
+  void set_radius(double rad_min, double rad_max) { 
+    // NOTE(review): rad_min <= rad_max is not checked.
+    rad_min_ = rad_min;
+    rad_max_ = rad_max;
+    radius_ = rad_max - rad_min;
+  }
+
+  /***
+   * Return the center point (read-only).
+   */
+  const TPoint& center() const { return center_; }
+
+  /***
+   * Return the center point (mutable; use this to set it).
+   */
+  TPoint& center() { return center_; }
+
+  // IMPLEMENT THESE LATER IN CASE THIS WORKS OUT
+  //   /**
+  //    * Determines if a point is within this bound.
+  //    */
+  //   bool Contains(const Point& point) const;
+
+  //   /**
+  //    * Gets the center.
+  //    *
+  //    * Don't really use this directly.  This is only here for consistency
+  //    * with DHrectBound, so it can plug in more directly if a "centroid"
+  //    * is needed.
+  //    */
+  //   void CalculateMidpoint(Point *centroid) const;
+
+  //   /**
+  //    * Calculates maximum bound-to-point cosine.
+  //    */
+  //   double MaxCosine(const Point& point) const;
+
+  //   /**
+  //    * Calculates maximum bound-to-bound cosine.
+  //    */
+  //   double MaxCosine(const DCosineBound& other) const;
+
+  //   /**
+  //    * Computes minimum bound-to-point cosine.
+  //    */
+  //   double MinCosine(const Point& point) const;
+
+  //   /**
+  //    * Computes minimum bound-to-bound cosine.
+  //    */
+  //   double MinCosine(const DCosineBound& other) const;
+
+  //   /**
+  //    * Calculates minimum and maximum bound-to-bound cosine.
+  //    *
+  //    * Example: bound1.RangeCosine(other) for both ends at once.
+  //    */
+  //   mlpack::math::Range RangeCosine(const DCosineBound& other) const;
+
+  //   /**
+  //    * Calculates closest-to-their-midpoint bounding box distance,
+  //    * i.e. calculates their midpoint and finds the minimum box-to-point
+  //    * distance.
+  //    *
+  //    * Equivalent to:
+  //    * <code>
+  //    * other.CalcMidpoint(&other_midpoint)
+  //    * return MaxCosineToPoint(other_midpoint)
+  //    * </code>
+  //    */
+  //   double MaxToMid(const DCosineBound& other) const;
+
+  //   /**
+  //    * Computes minimax distance, where the other node is trying to avoid me.
+  //    */
+  //   double MinimaxCosine(const DCosineBound& other) const;
+
+  //   /**
+  //    * Calculates midpoint-to-midpoint bounding box distance.
+  //    */
+  //   double MidCosine(const DCosineBound& other) const;
+  //   double MidCosine(const Point& point) const;
+
+}; // DCosineBound
+
+// #include "dcosinebound_impl.h"
+
+#endif

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.cc
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.cc	2011-10-29 23:16:02 UTC (rev 10075)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.cc	2011-10-30 03:16:22 UTC (rev 10076)
@@ -77,7 +77,9 @@
   return (query_norms_(query_) * (c_norm + rad));
 }
 
-double MaxIP::MaxNodeIP_(CTreeType* query_node,
+
+// CONE TREE
+double MaxIP::MaxNodeIP_(CTreeTypeA* query_node,
 			 TreeType* reference_node) {
 
   // counting the split decisions 
@@ -85,16 +87,15 @@
 
   // min_{q', q} cos <qq' = cos_w
   arma::vec q = query_node->bound().center();
+
+  // FOR CONE TREE
   double cos_w = query_node->bound().radius();
   double sin_w = query_node->bound().radius_conjugate();
 
-  // +1: Can cache it in the query tree
-  //  double q_norm = arma::norm(q, 2);
   double q_norm = query_node->stat().center_norm();
 
   arma::vec centroid = reference_node->bound().center();
 
-  // +1: can be cached in the reference tree
   double c_norm = reference_node->stat().dist_to_origin();
   double rad = std::sqrt(reference_node->bound().radius());
 
@@ -164,7 +165,78 @@
   } // alt-angle-prune
 
   return (c_norm + rad);
+}
 
+
+// COSINE TREE: bound on max over (q, r) of |r| cos <qr for this node pair
+double MaxIP::MaxNodeIP_(CTreeTypeB* query_node,
+			 TreeType* reference_node) {
+
+  // counting the split decisions 
+  split_decisions_++;
+
+  // q: center of the query node's bound
+  arma::vec q = query_node->bound().center();
+
+  // [cos_w_min, cos_w_max]: range stored in the query node's bound
+  double cos_w_min = query_node->bound().rad_min();
+  double cos_w_max = query_node->bound().rad_max();
+  // |q|, cached by set_norms_in_cones_()
+  double q_norm = query_node->stat().center_norm();
+
+  arma::vec centroid = reference_node->bound().center();
+  // reference ball: |centroid| and radius (the bound radius appears squared)
+  double c_norm = reference_node->stat().dist_to_origin();
+  double rad = std::sqrt(reference_node->bound().radius());
+
+  if (mlpack::CLI::HasParam("maxip/alt_angle_prune")) { 
+    // HAVE TO WORK THIS OUT AND CHECK IT ANALYTICALLY
+    // using the closed-form-maximization,
+    // |p| cos (phi - w) + R
+    // phi: angle between q and the reference centroid
+    double cos_phi, sin_phi;
+    // NOTE(review): cached per reference node only, so valid only if all query centers are equal
+    if (!reference_node->stat().has_cos_phi()) {
+      // +1 dot product; result cached on the reference node
+      cos_phi = arma::dot(q, centroid) / (c_norm * q_norm);
+      reference_node->stat().set_cos_phi(cos_phi);
+    } else {
+      cos_phi = reference_node->stat().cos_phi();
+      split_decisions_--;  // cache hit: not counted as a split decision
+    }
+    // NOTE(review): rounding can make 1 - cos_phi^2 < 0 (sqrt -> NaN)
+    sin_phi = std::sqrt(1 - cos_phi * cos_phi);
+    // NOTE(review): c_norm * q_norm == 0 makes cos_phi above non-finite
+    if (cos_phi < cos_w_min) { 
+      // centroid angle lies beyond the cos_w_min end of the query range:
+      // bound with cos(phi - w_min)
+      double sin_w_min = std::sqrt(1 - cos_w_min * cos_w_min);
+      return (c_norm * (cos_phi * cos_w_min + sin_phi * sin_w_min) + rad);
+
+    } else if (cos_phi > cos_w_max) {
+      // centroid angle lies beyond the cos_w_max end of the query range:
+      // bound with cos (w_max - phi)
+      double sin_w_max = std::sqrt(1 - cos_w_max * cos_w_max);
+      return (c_norm * (cos_phi * cos_w_max + sin_phi * sin_w_max) + rad);
+
+    } else {
+      // cos w_min < cos phi < cos w_max; 
+      // hence the query cone has the centroid
+      // so no useful bounding
+
+      cone_has_centroid_ += query_node->end() - query_node->begin();
+
+      // maybe trigger a flag which lets the caller know that the
+      // query cone is too wide.
+      // Although, not sure how that would help, since the
+      // current traversal already more or less handles this.
+
+      return (c_norm + rad);
+    }
+  } // alt-angle-prune
+  // Without alt_angle_prune: the plain ball bound |centroid| + R
+  return (c_norm + rad);
+
 }
 
 
@@ -273,7 +345,8 @@
 } // ComputeNeighborsRecursion_
 
 
-void MaxIP::ComputeBaseCase_(CTreeType* query_node, 
+// CONE TREE
+void MaxIP::ComputeBaseCase_(CTreeTypeA* query_node, 
 			     TreeType* reference_node) {
 
   // Check that the pointers are not NULL
@@ -313,10 +386,55 @@
     query_node->stat().set_bound(query_worst_p_cos_pq);
 
 } // ComputeBaseCase_
+
+
+
+// COSINE TREE: dual-tree base case against a reference leaf
+void MaxIP::ComputeBaseCase_(CTreeTypeB* query_node, 
+			     TreeType* reference_node) {
+
+  // Check that the pointers are not NULL
+  assert(reference_node != NULL);
+  assert(reference_node->is_leaf());
+  assert(query_node != NULL);
+
+  // query node may not be a leaf
+  //  assert(query_node->is_leaf());
+  // New lower bound for query_node: the minimum over its queries of
+  // the k-th best max_ip divided by |q| (p_cos_pq below)
+  double query_worst_p_cos_pq = DBL_MAX;
+  bool new_bound = false;
+  // Iterate over the queries individually; query_ is a member that the
+  // single-argument MaxNodeIP_ / ComputeBaseCase_ presumably read.
+  for (query_ = query_node->begin();
+       query_ < query_node->end(); query_++) {
+
+    // checking if this node has potential
+    double query_to_node_max_ip = MaxNodeIP_(reference_node);
+
+    if (query_to_node_max_ip > max_ips_(knns_ -1, query_))
+      // this node has potential
+      ComputeBaseCase_(reference_node);
+    // k-th best inner product for this query, divided by |q|
+    double p_cos_pq = max_ips_(knns_ -1, query_)
+      / query_norms_(query_);
+
+    if (query_worst_p_cos_pq > p_cos_pq) {
+      query_worst_p_cos_pq = p_cos_pq;
+      new_bound = true;
+    }
+  } // for query_
  
+  // Update the lower bound for the query_node
+  if (new_bound) 
+    query_node->stat().set_bound(query_worst_p_cos_pq);
 
-void MaxIP::CheckPrune(CTreeType* query_node, TreeType* ref_node) {
+} // ComputeBaseCase_
+  
 
+// CONE TREE
+void MaxIP::CheckPrune(CTreeTypeA* query_node, TreeType* ref_node) {
+
   size_t missed_nns = 0;
   double max_p_cos_pq = 0.0;
   double min_p_cos_pq = DBL_MAX;
@@ -360,7 +478,57 @@
 
 }
 
-void MaxIP::ComputeNeighborsRecursion_(CTreeType* query_node,
+
+
+// COSINE TREE: debug-only brute-force check of a prune (maxip/check_prune)
+void MaxIP::CheckPrune(CTreeTypeB* query_node, TreeType* ref_node) {
+
+  size_t missed_nns = 0;
+  double max_p_cos_pq = 0.0;
+  double min_p_cos_pq = DBL_MAX;
+  // NOTE(review): reuses the member query_, leaving it at query_node->end()
+  // Iterating over the queries individually
+  for (query_ = query_node->begin();
+       query_ < query_node->end(); query_++) {
+
+    // Get the query point from the matrix
+    arma::vec q = queries_.unsafe_col(query_);
+    // True node lower bound: min over queries of k-th best max_ip / |q|
+    double p_cos_qp = max_ips_(knns_ -1, query_) / query_norms_(query_);
+    if (min_p_cos_pq > p_cos_qp)
+      min_p_cos_pq = p_cos_qp;
+
+    // We'll do the same for the references
+    for (size_t reference_index = ref_node->begin(); 
+	 reference_index < ref_node->end(); reference_index++) {
+
+      arma::vec r = references_.unsafe_col(reference_index);
+      // A reference beating the current k-th best is a missed candidate
+      double ip = arma::dot(q, r);
+      if (ip > max_ips_(knns_-1, query_))
+	missed_nns++;
+
+      double p_cos_pq = ip / query_norms_(query_);
+
+      if (p_cos_pq > max_p_cos_pq)
+	max_p_cos_pq = p_cos_pq;
+      
+    } // for reference_index
+  } // for query_
+  // NOTE(review): exact != on doubles; MaxNodeIP_ below may change split_decisions_
+  if (missed_nns > 0 || query_node->stat().bound() != min_p_cos_pq) 
+    printf("Prune %zu - Missed candidates: %zu\n"
+	   "QLBound: %lg, ActualQLBound: %lg\n"
+	   "QRBound: %lg, ActualQRBound: %lg\n",
+	   number_of_prunes_, missed_nns,
+	   query_node->stat().bound(), min_p_cos_pq, 
+	   MaxNodeIP_(query_node, ref_node), max_p_cos_pq);
+
+}
+
+
+// CONE TREE
+void MaxIP::ComputeNeighborsRecursion_(CTreeTypeA* query_node,
 				       TreeType* reference_node, 
 				       double upper_bound_p_cos_pq) {
 
@@ -517,8 +685,152 @@
     } // alt-traversal
   } // All cases of dual-tree traversal
 } // ComputeNeighborsRecursion_
+
+
+
+// COSINE TREE: dual-tree recursion (query: cosine tree, reference: ball tree)
+void MaxIP::ComputeNeighborsRecursion_(CTreeTypeB* query_node,
+				       TreeType* reference_node, 
+				       double upper_bound_p_cos_pq) {
+
+  assert(query_node != NULL);
+  assert(reference_node != NULL);
+  //assert(upper_bound_p_cos_pq == MaxNodeIP_(query_node, reference_node));
+
+  if (upper_bound_p_cos_pq < query_node->stat().bound()) { 
+    // Pruned: the IP upper bound is below the node's lower bound
+    number_of_prunes_ += query_node->end() - query_node->begin();
+
+    if (CLI::HasParam("maxip/check_prune"))
+      CheckPrune(query_node, reference_node);
+  }
+  // node->is_leaf() works as one would expect
+  else if (query_node->is_leaf() && reference_node->is_leaf()) {
+    // Base Case
+    ComputeBaseCase_(query_node, reference_node);
+  } else if (query_node->is_leaf()) {
+    // Only query is a leaf
+      
+
+    if (CLI::HasParam("maxip/alt_dual_traversal")) {
+      // Trying to do single-tree on the leaves
+
+      // Used to find the query node's new lower bound
+      double query_worst_p_cos_pq = DBL_MAX;
+      bool new_bound = false;
+      // (query_ is a member; the single-argument calls below presumably read it)
+      // Iterating over the queries individually
+      for (query_ = query_node->begin();
+	   query_ < query_node->end(); query_++) {
+
+	// checking if this node has potential
+	double query_to_node_max_ip = MaxNodeIP_(reference_node);
+
+	if (query_to_node_max_ip > max_ips_(knns_ -1, query_))
+	  // this node has potential
+	  ComputeNeighborsRecursion_(reference_node, query_to_node_max_ip);
+	// k-th best inner product for this query, divided by |q|
+	double p_cos_pq = max_ips_(knns_ -1, query_)
+	  / query_norms_(query_);
+
+	if (query_worst_p_cos_pq > p_cos_pq) {
+	  query_worst_p_cos_pq = p_cos_pq;
+	  new_bound = true;
+	}
+      } // for query_
  
+      // Update the lower bound for the query_node
+      if (new_bound) 
+	query_node->stat().set_bound(query_worst_p_cos_pq);
+    } else {
 
+      // Visit the reference child with the larger IP upper bound first
+      double left_p_cos_pq = MaxNodeIP_(query_node,
+					reference_node->left());
+      double right_p_cos_pq = MaxNodeIP_(query_node,
+					 reference_node->right());
+      
+      if (left_p_cos_pq > right_p_cos_pq) {
+	ComputeNeighborsRecursion_(query_node, reference_node->left(), 
+				   left_p_cos_pq);
+	ComputeNeighborsRecursion_(query_node, reference_node->right(), 
+				   right_p_cos_pq);
+      } else {
+	ComputeNeighborsRecursion_(query_node, reference_node->right(), 
+				   right_p_cos_pq);
+	ComputeNeighborsRecursion_(query_node, reference_node->left(), 
+				   left_p_cos_pq);
+      }
+    } // alt-traversal
+  } else if (reference_node->is_leaf()) {
+    // Only reference is a leaf 
+    double left_p_cos_pq
+      = MaxNodeIP_(query_node->left(), reference_node);
+    double right_p_cos_pq
+      = MaxNodeIP_(query_node->right(), reference_node);
+      
+    ComputeNeighborsRecursion_(query_node->left(), reference_node, 
+			       left_p_cos_pq);
+    ComputeNeighborsRecursion_(query_node->right(), reference_node, 
+			       right_p_cos_pq);
+      
+    // Refresh this node's lower bound: the minimum of its children's
+    // (possibly tightened) lower bounds
+    query_node->stat().set_bound(std::min(query_node->left()->stat().bound(),
+					  query_node->right()->stat().bound()));
+  } else {
+
+    // Recurse on both: each query child visits the better reference child first
+    double left_p_cos_pq = MaxNodeIP_(query_node->left(), 
+				      reference_node->left());
+    double right_p_cos_pq = MaxNodeIP_(query_node->left(), 
+				       reference_node->right());
+      
+    if (left_p_cos_pq > right_p_cos_pq) {
+      ComputeNeighborsRecursion_(query_node->left(),
+				 reference_node->left(), 
+				 left_p_cos_pq);
+      ComputeNeighborsRecursion_(query_node->left(),
+				 reference_node->right(), 
+				 right_p_cos_pq);
+    } else {
+      ComputeNeighborsRecursion_(query_node->left(),
+				 reference_node->right(), 
+				 right_p_cos_pq);
+      ComputeNeighborsRecursion_(query_node->left(),
+				 reference_node->left(), 
+				 left_p_cos_pq);
+    }
+    // Same for the right query child
+    left_p_cos_pq = MaxNodeIP_(query_node->right(),
+			       reference_node->left());
+    right_p_cos_pq = MaxNodeIP_(query_node->right(), 
+				reference_node->right());
+      
+    if (left_p_cos_pq > right_p_cos_pq) {
+      ComputeNeighborsRecursion_(query_node->right(),
+				 reference_node->left(), 
+				 left_p_cos_pq);
+      ComputeNeighborsRecursion_(query_node->right(),
+				 reference_node->right(), 
+				 right_p_cos_pq);
+    } else {
+      ComputeNeighborsRecursion_(query_node->right(),
+				 reference_node->right(), 
+				 right_p_cos_pq);
+      ComputeNeighborsRecursion_(query_node->right(),
+				 reference_node->left(), 
+				 left_p_cos_pq);
+    }
+      
+    // Refresh this node's lower bound as above
+    query_node->stat().set_bound(std::min(query_node->left()->stat().bound(),
+					  query_node->right()->stat().bound()));
+  } // All cases of dual-tree traversal
+} // ComputeNeighborsRecursion_
+
+
+
 void MaxIP::Init(const arma::mat& queries_in,
 		 const arma::mat& references_in) {
     
@@ -564,12 +876,27 @@
   set_angles_in_balls_(reference_tree_);
    
   if (mlpack::CLI::HasParam("maxip/dual_tree")) {
-    query_tree_
-      = proximity::MakeGenConeTree<CTreeType>(queries_,
-					      leaf_size_,
-					      &old_from_new_queries_,
-					      NULL);
-    set_norms_in_cones_(query_tree_);
+
+    if (mlpack::CLI::HasParam("maxip/alt_tree")) { 
+      
+      // using cosine tree 
+      query_tree_B_
+	= proximity::MakeGenCosineTree<CTreeTypeB>(queries_,
+						   leaf_size_,
+						   &old_from_new_queries_,
+						   NULL);
+      set_norms_in_cones_(query_tree_B_);
+    
+    } else {
+      
+      // using cone tree
+      query_tree_A_
+	= proximity::MakeGenConeTree<CTreeTypeA>(queries_,
+						 leaf_size_,
+						 &old_from_new_queries_,
+						 NULL);
+      set_norms_in_cones_(query_tree_A_);
+    }
   }
 
   // saving the query norms beforehand to use 
@@ -652,13 +979,22 @@
   max_ips_ =  0.0 * arma::ones<arma::mat>(knns_, queries_.n_cols);
 
   // need to reset the querystats in the Query Tree
-  if (mlpack::CLI::HasParam("maxip/dual_tree"))
-    if (query_tree_ != NULL)
-      reset_tree_(query_tree_);
+  if (mlpack::CLI::HasParam("maxip/dual_tree")) 
+    if (mlpack::CLI::HasParam("maxip/alt_tree")) {
+      if (query_tree_B_ != NULL) 
+	reset_tree_(query_tree_B_);
 
+      reset_tree_(reference_tree_);
+    
+    } else {
+	if (query_tree_A_ != NULL)
+	  reset_tree_(query_tree_A_);
+    }
+
+
 } // WarmInit
 
-void MaxIP::reset_tree_(CTreeType* tree) {
+void MaxIP::reset_tree_(CTreeTypeA* tree) {
   assert(tree != NULL);
   tree->stat().set_bound(0.0);
 
@@ -670,6 +1006,32 @@
   return;
 } // reset_tree_
 
+// Recursively reset the bound of every node in a cosine query tree to 0.0.
+void MaxIP::reset_tree_(CTreeTypeB* tree) {
+  assert(tree != NULL);
+  tree->stat().set_bound(0.0);
+
+  if (!tree->is_leaf()) {
+    reset_tree_(tree->left());
+    reset_tree_(tree->right());
+  }
+
+  return;
+} // reset_tree_
+
+// Recursively clear the cached cos_phi of every reference node (RefStat::reset).
+void MaxIP::reset_tree_(TreeType* tree) {
+  assert(tree != NULL);
+  tree->stat().reset();
+
+  if (!tree->is_leaf()) {
+    reset_tree_(tree->left());
+    reset_tree_(tree->right());
+  }
+
+  return;
+} // reset_tree_
+
 void MaxIP::set_angles_in_balls_(TreeType* tree) {
 
   assert(tree != NULL);
@@ -694,7 +1056,7 @@
 } // set_angles_in_balls_
 
 
-void MaxIP::set_norms_in_cones_(CTreeType* tree) {
+void MaxIP::set_norms_in_cones_(CTreeTypeA* tree) {
 
   assert(tree != NULL);
 
@@ -710,7 +1072,25 @@
   return;
 } // set_norms_in_cones_
 
+void MaxIP::set_norms_in_cones_(CTreeTypeB* tree) {
 
+  assert(tree != NULL);
+  // Recursively cache |center| of every node's bound (read as q_norm by MaxNodeIP_).
+  // set up node stats
+  tree->stat().set_center_norm(arma::norm(tree->bound().center(), 2));
+
+  // traverse down the children
+  if (!tree->is_leaf()) {
+    set_norms_in_cones_(tree->left());
+    set_norms_in_cones_(tree->right());
+  }
+
+  return;
+} // set_norms_in_cones_
+
+
+
+
 double MaxIP::ComputeNeighbors(arma::Mat<size_t>* resulting_neighbors,
 			       arma::mat* ips) {
 
@@ -722,8 +1102,14 @@
 
 
     CLI::StartTimer("maxip/fast_dual");
-    ComputeNeighborsRecursion_(query_tree_, reference_tree_,
-			       MaxNodeIP_(query_tree_, reference_tree_));
+
+    if (mlpack::CLI::HasParam("maxip/alt_tree"))
+      ComputeNeighborsRecursion_(query_tree_B_, reference_tree_,
+				 MaxNodeIP_(query_tree_B_, reference_tree_));
+    else
+      ComputeNeighborsRecursion_(query_tree_A_, reference_tree_,
+				 MaxNodeIP_(query_tree_A_, reference_tree_));
+    
     CLI::StopTimer("maxip/fast_dual");
 
     resulting_neighbors->set_size(max_ip_indices_.n_rows,

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.h	2011-10-29 23:16:02 UTC (rev 10075)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.h	2011-10-30 03:16:22 UTC (rev 10076)
@@ -6,7 +6,7 @@
 #ifndef EXACT_MAX_IP_H
 #define EXACT_MAX_IP_H
 
-//#define NDEBUG
+#define NDEBUG  // NOTE(review): silently disables assert() for every includer; prefer -DNDEBUG in the build
 
 #include <assert.h>
 #include <mlpack/core.h>
@@ -18,6 +18,8 @@
 #include "gen_metric_tree.h"
 #include "dconebound.h"
 #include "gen_cone_tree.h"
+#include "dcosinebound.h"
+#include "gen_cosine_tree.h"
 
 using namespace mlpack;
 
@@ -70,11 +72,15 @@
     double cosine_origin_;
     double sine_origin_;
     double dist_to_origin_;
+    bool has_cos_phi_;  // whether cos_phi_ holds a cached value
+    double cos_phi_;    // cached cos_phi for MaxNodeIP_ (cosine tree); -1.0 when unset
 
   public:
     double cosine_origin() { return cosine_origin_; }
     double sine_origin() { return sine_origin_; }
     double dist_to_origin() { return dist_to_origin_; }
+    bool has_cos_phi() { return has_cos_phi_; }  // true once set_cos_phi() ran
+    double cos_phi() { return cos_phi_; }  // cached value; -1.0 when unset
 
 
     void set_angles(double val, size_t type = 0) { 
@@ -91,11 +97,23 @@
       dist_to_origin_ = val;
     }
 
+    void set_cos_phi(double cos_phi) {  // cache cos_phi and mark it valid
+      cos_phi_ = cos_phi;
+      has_cos_phi_ = true;
+    }
+    // Drop the cached value, restoring the constructor/Init defaults.
+    void reset() {
+      cos_phi_ = -1.0;
+      has_cos_phi_ = false;
+    }
+
     // FILL OUT THE INIT FUNCTIONS APPROPRIATELY LATER
     RefStat() {
       cosine_origin_ = 0.0;
       sine_origin_ = 0.0;
       dist_to_origin_ = 0.0;
+      has_cos_phi_ = false;  // no cached cos_phi yet
+      cos_phi_ = -1.0;
     }
 
     ~RefStat() {}
@@ -103,6 +121,8 @@
       cosine_origin_ = 0.0;
       sine_origin_ = 0.0;
       dist_to_origin_ = 0.0;
+      has_cos_phi_ = false;
+      cos_phi_ = -1.0;
     }
 
     void Init(const arma::mat& data, size_t begin, size_t count,
@@ -110,6 +130,8 @@
       cosine_origin_ = 0.0;
       sine_origin_ = 0.0;
       dist_to_origin_ = 0.0;
+      has_cos_phi_ = false;
+      cos_phi_ = -1.0;
     }
   }; // RefStat
 
@@ -117,7 +139,8 @@
   // Euclidean bounding boxes, the data are stored in a Matrix, 
   // and each node has a QueryStat for its bound.
   typedef GeneralBinarySpaceTree<bound::DBallBound<>, arma::mat, RefStat> TreeType;
-  typedef GeneralBinarySpaceTree<DConeBound<>, arma::mat, QueryStat> CTreeType;
+  typedef GeneralBinarySpaceTree<DConeBound<>, arma::mat, QueryStat> CTreeTypeA;  // cone query tree (default)
+  typedef GeneralBinarySpaceTree<DCosineBound<>, arma::mat, QueryStat> CTreeTypeB;  // cosine query tree (maxip/alt_tree)
    
   
   /////////////////////////////// Members ////////////////////////////
@@ -132,7 +155,8 @@
 
   // Pointers to the roots of the two trees.
   TreeType* reference_tree_;
-  CTreeType* query_tree_;
+  CTreeTypeA* query_tree_A_;
+  CTreeTypeB* query_tree_B_;
 
   // The total number of prunes.
   size_t number_of_prunes_;
@@ -170,7 +194,8 @@
    */
   MaxIP() {
     reference_tree_ = NULL;
-    query_tree_ = NULL;
+    query_tree_A_ = NULL;  // cone query tree, built unless maxip/alt_tree
+    query_tree_B_ = NULL;  // cosine query tree, built with maxip/alt_tree
   } 
   
   /**
@@ -181,8 +206,12 @@
     if (reference_tree_ != NULL) 
       delete reference_tree_;
  
-    if (query_tree_ != NULL)
-      delete query_tree_;
+    if (query_tree_A_ != NULL)
+      delete query_tree_A_;
+
+    if (query_tree_B_ != NULL)
+      delete query_tree_B_;
+
   }
     
   /////////////////////////// Helper Functions //////////////////////
@@ -200,7 +229,8 @@
    * ignoring the norm of any query.
    * So it is computing \max_(q,r) |r| cos <qr.
    */
-  double MaxNodeIP_(CTreeType *query_node, TreeType* reference_node);
+  double MaxNodeIP_(CTreeTypeA *query_node, TreeType* reference_node);
+  double MaxNodeIP_(CTreeTypeB *query_node, TreeType* reference_node);
 
   /**
    * Performs exhaustive computation at the leaves.  
@@ -210,7 +240,8 @@
   /**
    * Dual-tree: Performs exhaustive computation between two leaves.  
    */
-  void ComputeBaseCase_(CTreeType* query_node, TreeType* reference_node);
+  void ComputeBaseCase_(CTreeTypeA* query_node, TreeType* reference_node);
+  void ComputeBaseCase_(CTreeTypeB* query_node, TreeType* reference_node);
   
   /**
    * The recursive function
@@ -221,13 +252,20 @@
   /**
    * Dual-tree: The recursive function
    */
-  void ComputeNeighborsRecursion_(CTreeType* query_node,
+  void ComputeNeighborsRecursion_(CTreeTypeA* query_node,
 				  TreeType* reference_node, 
 				  double upper_bound_ip);
 
-  void reset_tree_(CTreeType *tree);
+  void ComputeNeighborsRecursion_(CTreeTypeB* query_node,
+				  TreeType* reference_node, 
+				  double upper_bound_ip);
+
+  void reset_tree_(CTreeTypeA *tree);
+  void reset_tree_(CTreeTypeB *tree);
+  void reset_tree_(TreeType *tree);
   void set_angles_in_balls_(TreeType *tree);
-  void set_norms_in_cones_(CTreeType *tree);
+  void set_norms_in_cones_(CTreeTypeA *tree);
+  void set_norms_in_cones_(CTreeTypeB *tree);
 
   size_t SortValue(double value);
 
@@ -246,8 +284,10 @@
     if (reference_tree_ != NULL)
       delete reference_tree_;
     
-    if (query_tree_ != NULL)
-      delete query_tree_;
+    if (query_tree_A_ != NULL)
+      delete query_tree_A_;
+    if (query_tree_B_ != NULL)
+      delete query_tree_B_;
   }
 
   /**
@@ -275,7 +315,8 @@
 		      arma::mat* ips);
 
 
-  void CheckPrune(CTreeType* query_node, TreeType* ref_node);
+  void CheckPrune(CTreeTypeA* query_node, TreeType* ref_node);
+  void CheckPrune(CTreeTypeB* query_node, TreeType* ref_node);
 }; //class MaxIP
 
 
@@ -293,9 +334,13 @@
 	   " pruning using the angles as well", "maxip");
 PARAM_FLAG("alt_angle_prune", "The flag to trigger the tighter-er"
 	   " pruning using the angles as well", "maxip");
+
 PARAM_FLAG("dual_tree", "The flag to trigger dual-tree "
 	   "computation, using a cone tree for the "
 	   "queries.", "maxip");
+PARAM_FLAG("alt_tree", "The flag to trigger the "
+	   "alternate query tree.",
+	   "maxip");  // alternate = cosine tree (CTreeTypeB) instead of the cone tree
 
 PARAM_FLAG("check_prune", "The flag to trigger the "
 	   "checking of the prune.", "maxip");

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree.h	2011-10-29 23:16:02 UTC (rev 10075)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree.h	2011-10-30 03:16:22 UTC (rev 10076)
@@ -11,6 +11,7 @@
 #ifndef GEN_CONE_TREE_H
 #define GEN_CONE_TREE_H
 
+#define NDEBUG
 
 #include "general_spacetree.h"
 #include "gen_cone_tree_impl.h"

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree_impl.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree_impl.h	2011-10-29 23:16:02 UTC (rev 10075)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree_impl.h	2011-10-30 03:16:22 UTC (rev 10076)
@@ -7,6 +7,8 @@
 #ifndef GEN_CONE_TREE_IMPL_H
 #define GEN_CONE_TREE_IMPL_H
 
+#define NDEBUG
+
 #include <assert.h>
 #include <mlpack/core.h>
 #include <armadillo>

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/gen_cosine_tree.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/gen_cosine_tree.h	2011-10-29 23:16:02 UTC (rev 10075)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/gen_cosine_tree.h	2011-10-30 03:16:22 UTC (rev 10076)
@@ -11,7 +11,11 @@
 #ifndef GEN_COSINE_TREE_H
 #define GEN_COSINE_TREE_H
 
+#define NDEBUG
 
+#include <armadillo>
+#include <assert.h>
+
 #include "general_spacetree.h"
 #include "gen_cosine_tree_impl.h"
 
@@ -59,6 +63,10 @@
 
     node->Init(0, matrix.n_cols);
     node->bound().center().set_size(matrix.n_rows);
+    node->bound().center() = arma::mean(matrix, 1);
+
+    assert(node->bound().center().n_elem == matrix.n_rows); 
+
     tree_gen_cosine_tree_private::SplitGenCosineTree<TCosineTree>
       (matrix, node, leaf_size, old_from_new_ptr);
     

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/gen_cosine_tree_impl.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/gen_cosine_tree_impl.h	2011-10-29 23:16:02 UTC (rev 10075)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/gen_cosine_tree_impl.h	2011-10-30 03:16:22 UTC (rev 10076)
@@ -7,6 +7,8 @@
 #ifndef GEN_COSINE_TREE_IMPL_H
 #define GEN_COSINE_TREE_IMPL_H
 
+#define NDEBUG
+
 #include <assert.h>
 #include <mlpack/core.h>
 #include <armadillo>
@@ -14,119 +16,24 @@
 
 namespace tree_gen_cosine_tree_private {
 
-  size_t FurthestColumnIndex(const arma::vec& pivot,
-			     const arma::mat& matrix, 
-			     size_t begin, size_t count,
-			     double *furthest_cosine);
-    
-  // This function assumes that we have points embedded in Euclidean
-  // space. The representative point (center) is chosen as the mean 
-  // of the all unit directions of all the vectors in this set.
-
-  // fixed!!
+  // Leaf bound: set_radius() to the [min, max] cosine of the node's points to bounds->center().
   template<typename TBound>
   void MakeLeafCosineTreeNode(const arma::mat& matrix,
 			      size_t begin, size_t count,
 			      TBound *bounds) {
 
-    bounds->center().zeros();
 
-    size_t end = begin + count;
-    for (size_t i = begin; i < end; i++) {
-      bounds->center() 
-	+= (matrix.unsafe_col(i)
- 	    / arma::norm(matrix.unsafe_col(i), 2));
-    }
-    bounds->center() /= (double) count;
-//     bounds->center() /= arma::norm(bounds->center(), 2);
-
-//     printf("c_norm: %lg\n", arma::norm(bounds->center(), 2));
-
-    double furthest_cosine;
-    FurthestColumnIndex(bounds->center(), matrix, begin, count,
-			&furthest_cosine);
-    bounds->set_radius(furthest_cosine);
-  }
-  
-
-  // fixed until compiled
-  template<typename TBound>
-  size_t MatrixPartition(arma::mat& matrix, size_t first, size_t count,
-			  TBound &left_bound, TBound &right_bound,
-			  size_t *old_from_new) {
+    arma::vec cosine_vec(count);
     
-    size_t end = first + count;
-    size_t left_count = 0;
-
-    std::vector<bool> left_membership;
-    left_membership.reserve(count);
+    for (size_t i = 0; i < count; i++)
+      cosine_vec(i) = Cosine::Evaluate(bounds->center(), 
+				       matrix.unsafe_col(i + begin));
     
-    for (size_t left = first; left < end; left++) {
-
-      // Make alias of the current point.
-      arma::vec point = matrix.unsafe_col(left);
-
-      // Compute the cosines from the two pivots.
-      double cosine_from_left_pivot =
-	Cosine::Evaluate(point, left_bound.center());
-      double cosine_from_right_pivot =
-	Cosine::Evaluate(point, right_bound.center());
-
-      // We swap if the point is more angled away from the left pivot.
-      if(cosine_from_left_pivot < cosine_from_right_pivot) {	
-	left_membership[left - first] = false;
-      } else {
-	left_membership[left - first] = true;
-	left_count++;
-      }
-    }
-
-    size_t left = first;
-    size_t right = first + count - 1;
-    
-    /* At any point:
-     *
-     *   everything < left is correct
-     *   everything > right is correct
-     */
-    for (;;) {
-      while (left_membership[left - first] && (left <= right)) {
-        left++;
-      }
-
-      while (!left_membership[right - first] && (left <= right)) {
-        right--;
-      }
-
-      if (left > right) {
-        /* left == right + 1 */
-        break;
-      }
-
-      // Swap the left vector with the right vector.
-      matrix.swap_cols(left, right);
-
-      bool tmp = left_membership[left - first];
-      left_membership[left - first] = left_membership[right - first];
-      left_membership[right - first] = tmp;
-      
-      if (old_from_new) {
-        size_t t = old_from_new[left];
-        old_from_new[left] = old_from_new[right];
-        old_from_new[right] = t;
-      }
-      
-      assert(left <= right);
-      right--;
-    }
-    
-    assert(left == right + 1);
-
-    return left_count;
+    bounds->set_radius(arma::min(cosine_vec), arma::max(cosine_vec));
   }
-	
+  
 
-  // fixed
+  // Split node into inner (left) / outer (right) rings by cosine to its center; false if no spread.
   template<typename TCosineTree>
   bool AttemptSplitting(arma::mat& matrix,
 			TCosineTree *node,
@@ -135,83 +42,99 @@
 			size_t leaf_size,
 			size_t *old_from_new) {
 
-    // Pick a random row.
-    size_t random_row 
-      = math::RandInt(node->begin(),
-		      node->begin() + node->count());
+    // obtain the list of cosine values to all the points
+    // in the set
+    arma::vec cosine_vec(node->count());
 
-    // why is this here?
-    // random_row = node->begin();
+    for (size_t i = 0; i < node->count(); i++)
+      cosine_vec(i)
+	= Cosine::Evaluate(node->bound().center(),
+			   matrix.unsafe_col(i + node->begin()));
 
-    arma::vec random_row_vec = matrix.unsafe_col(random_row);
-
-    // Now figure out the furthest point from the random row picked
-    // above.
-    double furthest_cosine;
-    size_t furthest_from_random_row =
-      FurthestColumnIndex(random_row_vec, matrix, node->begin(), node->count(),
-			  &furthest_cosine);
-    arma::vec furthest_from_random_row_vec = matrix.unsafe_col(furthest_from_random_row);
-
-    // Then figure out the furthest point from the furthest point.
-    double furthest_from_furthest_cosine;
-    size_t furthest_from_furthest_random_row =
-      FurthestColumnIndex(furthest_from_random_row_vec, matrix, node->begin(),
-			  node->count(), &furthest_from_furthest_cosine);
-    arma::vec furthest_from_furthest_random_row_vec =
-       matrix.unsafe_col(furthest_from_furthest_random_row);
-
-    if(furthest_from_furthest_cosine > (1.0 - DBL_EPSILON)) {
-      // everything in a really tight narrow cone
+    if(arma::max(cosine_vec) - arma::min(cosine_vec) < DBL_EPSILON) {
+      // everything in a really tight narrow co-axial cone-ring
       return false;
     } else {
       *left = new TCosineTree();
       *right = new TCosineTree();
 
-      // not necessary, vec::operator=() takes care of resetting the size
-//      ((*left)->bound().center()).set_size(matrix.n_rows);
-//      ((*right)->bound().center()).set_size(matrix.n_rows);
+      ((*left)->bound().center()) = node->bound().center();
+      ((*right)->bound().center()) = node->bound().center();
 
-      ((*left)->bound().center()) = furthest_from_random_row_vec;
-      ((*right)->bound().center()) = furthest_from_furthest_random_row_vec;
+      node->bound().set_radius(arma::min(cosine_vec), arma::max(cosine_vec));
 
-      size_t left_count
-	= MatrixPartition(matrix, node->begin(), node->count(),
-			  (*left)->bound(), (*right)->bound(),
-			  old_from_new);
+//       printf("%lg, %lg\n", node->bound().rad_min(), node->bound().rad_max());
 
-      (*left)->Init(node->begin(), left_count);
-      (*right)->Init(node->begin() + left_count, node->count() - left_count);
-    }
+      size_t first = node->begin();
+      size_t end = first + node->count();
+      size_t left_count = 0;
 
-    return true;
-  }
+      // Indices of the points sorted by ascending cosine to the center.
+      arma::uvec order = arma::sort_index(cosine_vec);
 
-  // fixed
-  template<typename TCosineTree>
-  void CombineBounds(arma::mat& matrix, TCosineTree *node,
-		     TCosineTree *left, TCosineTree *right) {
+      // Sized (not merely reserved) so that operator[] below is in bounds.
+      std::vector<bool> left_membership(node->count(), false);
     
-    // First clear the internal node center.
-    node->bound().center().zeros();
+      // The upper half by rank (largest cosines) is the inner ring and goes
+      // left; the rest is the outer ring. For distinct cosines this matches
+      // a split at the median value, but ties at the median can no longer put
+      // every point on one side (an empty child next to a sibling identical
+      // to its parent, which would recurse forever).
+      for (size_t r = node->count() / 2; r < node->count(); r++) {
+	left_membership[order(r)] = true;
+	left_count++;
+      }
 
-    // Compute the weighted sum of the two pivots
-    node->bound().center() += left->count() * left->bound().center();
-    node->bound().center() += right->count() * right->bound().center();
-    node->bound().center() /= (double) node->count();
-//     node->bound().center() /= arma::norm(node->bound().center(), 2);
+      size_t left_ind = first;
+      size_t right_ind = end - 1;
     
-//     printf("c_norm: %lg\n", arma::norm(node->bound().center(), 2));
+      /* At any point:
+       *
+       *   everything < left_ind is correct
+       *   everything > right_ind is correct
+       */
+      for (;;) {
+	while ((left_ind <= right_ind) && left_membership[left_ind - first]) {
+	  left_ind++;
+	}
 
-    double left_min_cosine, right_min_cosine;
-    FurthestColumnIndex(node->bound().center(), matrix, left->begin(), 
-			left->count(), &left_min_cosine);
-    FurthestColumnIndex(node->bound().center(), matrix, right->begin(), 
-			right->count(), &right_min_cosine);    
-    node->bound().set_radius(std::min(left_min_cosine, right_min_cosine));
+	while ((left_ind <= right_ind) && !left_membership[right_ind - first]) {
+	  right_ind--;
+	}
+
+	if (left_ind > right_ind) {
+	  /* left == right_ind + 1 */
+	  break;
+	}
+
+	// Swap the left vector with the right_ind vector.
+	matrix.swap_cols(left_ind, right_ind);
+
+	bool tmp = left_membership[left_ind - first];
+	left_membership[left_ind - first] = left_membership[right_ind - first];
+	left_membership[right_ind - first] = tmp;
+      
+	if (old_from_new) {
+	  size_t t = old_from_new[left_ind];
+	  old_from_new[left_ind] = old_from_new[right_ind];
+	  old_from_new[right_ind] = t;
+	}
+      
+	assert(left_ind <= right_ind);
+	right_ind--;
+      }
+    
+      assert(left_ind == right_ind + 1);
+
+      (*left)->Init(node->begin(), left_count);
+      (*right)->Init(node->begin() + left_count, node->count() - left_count);
+	
+      return true;    
+    }
   }
 
-  // fixed
+
+  // Recursively splits node; a node that cannot be split becomes a leaf.
   template<typename TCosineTree>
   void SplitGenCosineTree(arma::mat& matrix, TCosineTree *node,
 			  size_t leaf_size, size_t *old_from_new) {
@@ -233,7 +156,6 @@
       if(can_cut) {
 	SplitGenCosineTree(matrix, left, leaf_size, old_from_new);
 	SplitGenCosineTree(matrix, right, leaf_size, old_from_new);
-	CombineBounds(matrix, node, left, right);
       }
       else {
 	MakeLeafCosineTreeNode(matrix, node->begin(),

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/gen_metric_tree.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/gen_metric_tree.h	2011-10-29 23:16:02 UTC (rev 10075)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/gen_metric_tree.h	2011-10-30 03:16:22 UTC (rev 10076)
@@ -14,6 +14,8 @@
 #ifndef GEN_METRIC_TREE_H
 #define GEN_METRIC_TREE_H
 
+// Assertions are controlled by the build (-DNDEBUG), not by this header.
+
 #include "general_spacetree.h"
 #include "gen_metric_tree_impl.h"
 

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/gen_metric_tree_impl.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/gen_metric_tree_impl.h	2011-10-29 23:16:02 UTC (rev 10075)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/gen_metric_tree_impl.h	2011-10-30 03:16:22 UTC (rev 10076)
@@ -8,6 +8,8 @@
 #ifndef GEN_METRIC_TREE_IMPL_H
 #define GEN_METRIC_TREE_IMPL_H
 
+// Assertions are controlled by the build (-DNDEBUG), not by this header.
+
 #include <assert.h>
 #include <mlpack/core.h>
 #include <mlpack/core/tree/bounds.hpp>




More information about the mlpack-svn mailing list