[mlpack-svn] r10056 - mlpack/trunk/src/contrib/pram/max_ip_search

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Oct 27 16:54:10 EDT 2011


Author: pram
Date: 2011-10-27 16:54:10 -0400 (Thu, 27 Oct 2011)
New Revision: 10056

Modified:
   mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.h
   mlpack/trunk/src/contrib/pram/max_ip_search/dconebound.h
   mlpack/trunk/src/contrib/pram/max_ip_search/dconebound_impl.h
   mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.cc
   mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.h
   mlpack/trunk/src/contrib/pram/max_ip_search/gen_metric_tree_impl.h
   mlpack/trunk/src/contrib/pram/max_ip_search/general_spacetree.h
   mlpack/trunk/src/contrib/pram/max_ip_search/max_ip_main.cc
Log:
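Made the changes needed to work with the new mlpack (.hpp headers, namespaced bound::DBallBound and mlpack::math::Range, Armadillo .load()), and added alternative angle pruning (alt_angle_prune).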

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.h	2011-10-27 20:49:55 UTC (rev 10055)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.h	2011-10-27 20:54:10 UTC (rev 10056)
@@ -8,8 +8,8 @@
 
 #include <assert.h>
 #include <mlpack/core.h>
-#include <mlpack/core/tree/bounds.h>
-#include <mlpack/core/tree/statistic.h>
+#include <mlpack/core/tree/bounds.hpp>
+#include <mlpack/core/tree/statistic.hpp>
 #include <vector>
 #include <armadillo>
 #include "general_spacetree.h"
@@ -113,7 +113,7 @@
   // TreeType are BinarySpaceTrees where the data are bounded by 
   // Euclidean bounding boxes, the data are stored in a Matrix, 
   // and each node has a QueryStat for its bound.
-  typedef GeneralBinarySpaceTree<DBallBound<>, arma::mat> TreeType;
+  typedef GeneralBinarySpaceTree<bound::DBallBound<>, arma::mat> TreeType;
   typedef GeneralBinarySpaceTree<DConeBound<>, arma::mat, QueryStat> CTreeType;
    
   

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/dconebound.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/dconebound.h	2011-10-27 20:49:55 UTC (rev 10055)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/dconebound.h	2011-10-27 20:54:10 UTC (rev 10056)
@@ -11,7 +11,7 @@
 #ifndef TREE_DCONEBOUND_H
 #define TREE_DCONEBOUND_H
 
-#include "mlpack/core/math/math_lib.h"
+#include "mlpack/core/math/range.hpp"
 
 // Awaiting transition
 #include "cosine.h"
@@ -101,7 +101,7 @@
      *
      * Example: bound1.MinDistanceSq(other) for minimum squared distance.
      */
-    Range RangeCosine(const DConeBound& other) const;
+    mlpack::math::Range RangeCosine(const DConeBound& other) const;
 
     /**
      * Calculates closest-to-their-midpoint bounding box distance,

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/dconebound_impl.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/dconebound_impl.h	2011-10-27 20:49:55 UTC (rev 10055)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/dconebound_impl.h	2011-10-27 20:54:10 UTC (rev 10056)
@@ -162,8 +162,8 @@
  * Calculates minimum and maximum bound-to-bound cosine.
  */
 template<typename TMetric, typename TPoint>
-Range DConeBound<TMetric, TPoint>::RangeCosine(const DConeBound& other) const {
-  return Range(MinCosine(other), MaxCosine(other));
+mlpack::math::Range DConeBound<TMetric, TPoint>::RangeCosine(const DConeBound& other) const {
+  return mlpack::math::Range(MinCosine(other), MaxCosine(other));
 }
 
 /**

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.cc
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.cc	2011-10-27 20:49:55 UTC (rev 10055)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.cc	2011-10-27 20:54:10 UTC (rev 10056)
@@ -19,16 +19,13 @@
   arma::vec q = queries_.col(query_);
   arma::vec centroid = reference_node->bound().center();
 
-//   // +1: Can be cached in the reference tree
-//   double c_norm = arma::norm(centroid, 2);
+  // +1: Can be cached in the reference tree
   double c_norm = reference_node->stat().dist_to_origin();
 
   assert(arma::norm(q, 2) == query_norms_(query_));
 
   double rad = std::sqrt(reference_node->bound().radius());
 
-  double max_cos_qr = 1.0;
-
   if (mlpack::CLI::HasParam("maxip/angle_prune")) { 
     // tighter bound of \max_{r \in B_p^R} <q,r> 
     //    = |q| \max_{r \in B_p^R} |r| cos <qr 
@@ -36,14 +33,13 @@
     //    \leq |q| (|p|+R) if <qp \leq \max_r <pr
     //    \leq |q| (|p|+R) cos( <qp - \max_r <pr ) otherwise
 
+    double max_cos_qr = 1.0;
     if (rad <= c_norm) {
       // +1
       double cos_qp = arma::dot(q, centroid) 
 	/ (query_norms_(query_) * c_norm);
       double sin_qp = std::sqrt(1 - cos_qp * cos_qp);
 
-//       double max_sin_pr = rad / c_norm;
-//       double min_cos_pr = std::sqrt(1 - max_sin_pr * max_sin_pr);
       double max_sin_pr = reference_node->stat().sine_origin();
       double min_cos_pr = reference_node->stat().cosine_origin();;
 
@@ -60,14 +56,25 @@
       ball_has_origin_++;
     }
 
-  }
+    return (query_norms_(query_) * (c_norm + rad) * max_cos_qr);
+  } // angle-prune
 
+  if (mlpack::CLI::HasParam("maxip/alt_angle_prune")) { 
+    // tighter bound of \max_{r \in B_p^R} <q,r> 
+    //    = |q| \max_{r \in B_p^R} |r| cos <qr 
+    //    \leq  |q| (|p| cos <qp + R )  (closed-form solution of
+    // the maximization above; the maximum is attained at r = p + R q / |q|)
+    //    = ( <q, p> + |q| R )
+
+    // +1
+    return (arma::dot(q, centroid) + (query_norms_(query_) * rad));
+  } // alt-angle-prune
+  
   // Otherwise :
   // simple bound of \max_{r \in B_p^R} <q,r> 
   //    = |q| \max_{r \in B_p^R} |r| cos <qr 
   //    \leq |q| \max_{r \in B_p^R} |r| \leq |q| (|p|+R)
-
-  return (query_norms_(query_) * (c_norm + rad) * max_cos_qr);
+  return (query_norms_(query_) * (c_norm + rad));
 }
 
 double MaxIP::MaxNodeIP_(CTreeType* query_node,
@@ -128,9 +135,22 @@
     } else {
       ball_has_origin_++;
     }
-  }
+    return ((c_norm + rad) * max_cos_qp);
+  } // angle-prune
 
-  return ((c_norm + rad) * max_cos_qp);
+  if (mlpack::CLI::HasParam("maxip/alt_angle_prune")) { 
+    // using the closed-form-maximization,
+    // |p| cos (phi - w) + R
+      // +1
+    double c_norm_cos_phi = arma::dot(q, centroid) / q_norm;
+    double c_norm_sin_phi = std::sqrt(c_norm * c_norm 
+				      - c_norm_cos_phi * c_norm_cos_phi);
+
+    return (c_norm_cos_phi * cos_w + c_norm_sin_phi * sin_w + rad);
+  } // alt-angle-prune
+
+  return (c_norm + rad);
+
 }
 
 
@@ -642,7 +662,10 @@
   double rad = std::sqrt(tree->bound().radius());
   
   tree->stat().set_dist_to_origin(c_norm);
-  tree->stat().set_angles(rad / c_norm, (size_t) 1);
+  if (rad <= c_norm)
+    tree->stat().set_angles(rad / c_norm, (size_t) 1);
+  else
+    tree->stat().set_angles(-1.0, (size_t) 0);
 
   // traverse down the children
   if (!tree->is_leaf()) {

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.h	2011-10-27 20:49:55 UTC (rev 10055)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.h	2011-10-27 20:54:10 UTC (rev 10056)
@@ -6,12 +6,12 @@
 #ifndef EXACT_MAX_IP_H
 #define EXACT_MAX_IP_H
 
-// #define NDEBUG
+#define NDEBUG
 
 #include <assert.h>
 #include <mlpack/core.h>
-#include <mlpack/core/tree/bounds.h>
-#include <mlpack/core/tree/statistic.h>
+#include <mlpack/core/tree/bounds.hpp>
+#include <mlpack/core/tree/statistic.hpp>
 #include <vector>
 #include <armadillo>
 #include "general_spacetree.h"
@@ -116,7 +116,7 @@
   // TreeType are BinarySpaceTrees where the data are bounded by 
   // Euclidean bounding boxes, the data are stored in a Matrix, 
   // and each node has a QueryStat for its bound.
-  typedef GeneralBinarySpaceTree<DBallBound<>, arma::mat, RefStat> TreeType;
+  typedef GeneralBinarySpaceTree<bound::DBallBound<>, arma::mat, RefStat> TreeType;
   typedef GeneralBinarySpaceTree<DConeBound<>, arma::mat, QueryStat> CTreeType;
    
   
@@ -290,6 +290,8 @@
 
 PARAM_FLAG("angle_prune", "The flag to trigger the tighter"
 	   " pruning using the angles as well", "maxip");
+PARAM_FLAG("alt_angle_prune", "The flag to trigger the tighter-er"
+	   " pruning using the angles as well", "maxip");
 PARAM_FLAG("dual_tree", "The flag to trigger dual-tree "
 	   "computation, using a cosine tree for the "
 	   "queries.", "maxip");

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/gen_metric_tree_impl.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/gen_metric_tree_impl.h	2011-10-27 20:49:55 UTC (rev 10055)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/gen_metric_tree_impl.h	2011-10-27 20:54:10 UTC (rev 10056)
@@ -10,7 +10,7 @@
 
 #include <assert.h>
 #include <mlpack/core.h>
-#include <mlpack/core/tree/bounds.h>
+#include <mlpack/core/tree/bounds.hpp>
 #include <armadillo>
 
 namespace tree_gen_metric_tree_private {

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/general_spacetree.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/general_spacetree.h	2011-10-27 20:49:55 UTC (rev 10055)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/general_spacetree.h	2011-10-27 20:54:10 UTC (rev 10056)
@@ -14,7 +14,7 @@
 
 #include <assert.h>
 #include <mlpack/core.h>
-#include <mlpack/core/tree/statistic.h>
+#include <mlpack/core/tree/statistic.hpp>
 //#include <mlpack/core.h>
 
 /**

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/max_ip_main.cc
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/max_ip_main.cc	2011-10-27 20:49:55 UTC (rev 10055)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/max_ip_main.cc	2011-10-27 20:54:10 UTC (rev 10056)
@@ -47,10 +47,10 @@
   string qfile = CLI::GetParam<string>("q");
 
   Log::Info << "Loading files..." << endl;
-  if (!data::Load(rfile.c_str(), rdata))
+  if (rdata.load(rfile.c_str()) == false)
     Log::Fatal << "Reference file "<< rfile << " not found." << endl;
 
-  if (!data::Load(qfile.c_str(), qdata)) 
+  if (qdata.load(qfile.c_str()) == false) 
     Log::Fatal << "Query file " << qfile << " not found." << endl;
 
   Log::Info << "File loaded..." << endl;
@@ -60,12 +60,12 @@
 	   << ")" << endl;
 
 
-  arma::Mat<size_t> nac, exc, apc;
-  arma::mat din, die, dia;
+  arma::Mat<size_t> nac, exc;
+  arma::mat din, die;
 
   //size_t knns = CLI::GetParam<int>("maxip/knns");
 
-  double naive_comp, fast_comp, approx_comp;
+  double naive_comp, fast_comp;
 
 
   // Naive computation




More information about the mlpack-svn mailing list