[mlpack-svn] r10056 - mlpack/trunk/src/contrib/pram/max_ip_search
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Oct 27 16:54:10 EDT 2011
Author: pram
Date: 2011-10-27 16:54:10 -0400 (Thu, 27 Oct 2011)
New Revision: 10056
Modified:
mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.h
mlpack/trunk/src/contrib/pram/max_ip_search/dconebound.h
mlpack/trunk/src/contrib/pram/max_ip_search/dconebound_impl.h
mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.cc
mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.h
mlpack/trunk/src/contrib/pram/max_ip_search/gen_metric_tree_impl.h
mlpack/trunk/src/contrib/pram/max_ip_search/general_spacetree.h
mlpack/trunk/src/contrib/pram/max_ip_search/max_ip_main.cc
Log:
appropriate changes made to work with new mlpack + alt-angle-pruning added
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.h 2011-10-27 20:49:55 UTC (rev 10055)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.h 2011-10-27 20:54:10 UTC (rev 10056)
@@ -8,8 +8,8 @@
#include <assert.h>
#include <mlpack/core.h>
-#include <mlpack/core/tree/bounds.h>
-#include <mlpack/core/tree/statistic.h>
+#include <mlpack/core/tree/bounds.hpp>
+#include <mlpack/core/tree/statistic.hpp>
#include <vector>
#include <armadillo>
#include "general_spacetree.h"
@@ -113,7 +113,7 @@
// TreeType are BinarySpaceTrees where the data are bounded by
// Euclidean bounding boxes, the data are stored in a Matrix,
// and each node has a QueryStat for its bound.
- typedef GeneralBinarySpaceTree<DBallBound<>, arma::mat> TreeType;
+ typedef GeneralBinarySpaceTree<bound::DBallBound<>, arma::mat> TreeType;
typedef GeneralBinarySpaceTree<DConeBound<>, arma::mat, QueryStat> CTreeType;
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/dconebound.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/dconebound.h 2011-10-27 20:49:55 UTC (rev 10055)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/dconebound.h 2011-10-27 20:54:10 UTC (rev 10056)
@@ -11,7 +11,7 @@
#ifndef TREE_DCONEBOUND_H
#define TREE_DCONEBOUND_H
-#include "mlpack/core/math/math_lib.h"
+#include "mlpack/core/math/range.hpp"
// Awaiting transition
#include "cosine.h"
@@ -101,7 +101,7 @@
*
* Example: bound1.MinDistanceSq(other) for minimum squared distance.
*/
- Range RangeCosine(const DConeBound& other) const;
+ mlpack::math::Range RangeCosine(const DConeBound& other) const;
/**
* Calculates closest-to-their-midpoint bounding box distance,
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/dconebound_impl.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/dconebound_impl.h 2011-10-27 20:49:55 UTC (rev 10055)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/dconebound_impl.h 2011-10-27 20:54:10 UTC (rev 10056)
@@ -162,8 +162,8 @@
* Calculates minimum and maximum bound-to-bound cosine.
*/
template<typename TMetric, typename TPoint>
-Range DConeBound<TMetric, TPoint>::RangeCosine(const DConeBound& other) const {
- return Range(MinCosine(other), MaxCosine(other));
+mlpack::math::Range DConeBound<TMetric, TPoint>::RangeCosine(const DConeBound& other) const {
+ return mlpack::math::Range(MinCosine(other), MaxCosine(other));
}
/**
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.cc
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.cc 2011-10-27 20:49:55 UTC (rev 10055)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.cc 2011-10-27 20:54:10 UTC (rev 10056)
@@ -19,16 +19,13 @@
arma::vec q = queries_.col(query_);
arma::vec centroid = reference_node->bound().center();
-// // +1: Can be cached in the reference tree
-// double c_norm = arma::norm(centroid, 2);
+ // +1: Can be cached in the reference tree
double c_norm = reference_node->stat().dist_to_origin();
assert(arma::norm(q, 2) == query_norms_(query_));
double rad = std::sqrt(reference_node->bound().radius());
- double max_cos_qr = 1.0;
-
if (mlpack::CLI::HasParam("maxip/angle_prune")) {
// tighter bound of \max_{r \in B_p^R} <q,r>
// = |q| \max_{r \in B_p^R} |r| cos <qr
@@ -36,14 +33,13 @@
// \leq |q| (|p|+R) if <qp \leq \max_r <pr
// \leq |q| (|p|+R) cos( <qp - \max_r <pr ) otherwise
+ double max_cos_qr = 1.0;
if (rad <= c_norm) {
// +1
double cos_qp = arma::dot(q, centroid)
/ (query_norms_(query_) * c_norm);
double sin_qp = std::sqrt(1 - cos_qp * cos_qp);
-// double max_sin_pr = rad / c_norm;
-// double min_cos_pr = std::sqrt(1 - max_sin_pr * max_sin_pr);
double max_sin_pr = reference_node->stat().sine_origin();
double min_cos_pr = reference_node->stat().cosine_origin();;
@@ -60,14 +56,25 @@
ball_has_origin_++;
}
- }
+ return (query_norms_(query_) * (c_norm + rad) * max_cos_qr);
+ } // angle-prune
+ if (mlpack::CLI::HasParam("maxip/alt_angle_prune")) {
+ // tighter bound of \max_{r \in B_p^R} <q,r>
+ // = |q| \max_{r \in B_p^R} |r| cos <qr
+ // \leq |q| (|p| cos <qp + R ) (closed-form solution of
+ // the maximization above (I think it is correct))
+ // = ( <q, p> + |q| R )
+
+ // +1
+ return (arma::dot(q, centroid) + (query_norms_(query_) * rad));
+ } // alt-angle-prune
+
// Otherwise :
// simple bound of \max_{r \in B_p^R} <q,r>
// = |q| \max_{r \in B_p^R} |r| cos <qr
// \leq |q| \max_{r \in B_p^R} |r| \leq |q| (|p|+R)
-
- return (query_norms_(query_) * (c_norm + rad) * max_cos_qr);
+ return (query_norms_(query_) * (c_norm + rad));
}
double MaxIP::MaxNodeIP_(CTreeType* query_node,
@@ -128,9 +135,22 @@
} else {
ball_has_origin_++;
}
- }
+ return ((c_norm + rad) * max_cos_qp);
+ } // angle-prune
- return ((c_norm + rad) * max_cos_qp);
+ if (mlpack::CLI::HasParam("maxip/alt_angle_prune")) {
+ // using the closed-form-maximization,
+ // |p| cos (phi - w) + R
+ // +1
+ double c_norm_cos_phi = arma::dot(q, centroid) / q_norm;
+ double c_norm_sin_phi = std::sqrt(c_norm * c_norm
+ - c_norm_cos_phi * c_norm_cos_phi);
+
+ return (c_norm_cos_phi * cos_w + c_norm_sin_phi * sin_w + rad);
+ } // alt-angle-prune
+
+ return (c_norm + rad);
+
}
@@ -642,7 +662,10 @@
double rad = std::sqrt(tree->bound().radius());
tree->stat().set_dist_to_origin(c_norm);
- tree->stat().set_angles(rad / c_norm, (size_t) 1);
+ if (rad <= c_norm)
+ tree->stat().set_angles(rad / c_norm, (size_t) 1);
+ else
+ tree->stat().set_angles(-1.0, (size_t) 0);
// traverse down the children
if (!tree->is_leaf()) {
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.h 2011-10-27 20:49:55 UTC (rev 10055)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.h 2011-10-27 20:54:10 UTC (rev 10056)
@@ -6,12 +6,12 @@
#ifndef EXACT_MAX_IP_H
#define EXACT_MAX_IP_H
-// #define NDEBUG
+#define NDEBUG
#include <assert.h>
#include <mlpack/core.h>
-#include <mlpack/core/tree/bounds.h>
-#include <mlpack/core/tree/statistic.h>
+#include <mlpack/core/tree/bounds.hpp>
+#include <mlpack/core/tree/statistic.hpp>
#include <vector>
#include <armadillo>
#include "general_spacetree.h"
@@ -116,7 +116,7 @@
// TreeType are BinarySpaceTrees where the data are bounded by
// Euclidean bounding boxes, the data are stored in a Matrix,
// and each node has a QueryStat for its bound.
- typedef GeneralBinarySpaceTree<DBallBound<>, arma::mat, RefStat> TreeType;
+ typedef GeneralBinarySpaceTree<bound::DBallBound<>, arma::mat, RefStat> TreeType;
typedef GeneralBinarySpaceTree<DConeBound<>, arma::mat, QueryStat> CTreeType;
@@ -290,6 +290,8 @@
PARAM_FLAG("angle_prune", "The flag to trigger the tighter"
" pruning using the angles as well", "maxip");
+PARAM_FLAG("alt_angle_prune", "The flag to trigger the tighter-er"
+ " pruning using the angles as well", "maxip");
PARAM_FLAG("dual_tree", "The flag to trigger dual-tree "
"computation, using a cosine tree for the "
"queries.", "maxip");
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/gen_metric_tree_impl.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/gen_metric_tree_impl.h 2011-10-27 20:49:55 UTC (rev 10055)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/gen_metric_tree_impl.h 2011-10-27 20:54:10 UTC (rev 10056)
@@ -10,7 +10,7 @@
#include <assert.h>
#include <mlpack/core.h>
-#include <mlpack/core/tree/bounds.h>
+#include <mlpack/core/tree/bounds.hpp>
#include <armadillo>
namespace tree_gen_metric_tree_private {
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/general_spacetree.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/general_spacetree.h 2011-10-27 20:49:55 UTC (rev 10055)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/general_spacetree.h 2011-10-27 20:54:10 UTC (rev 10056)
@@ -14,7 +14,7 @@
#include <assert.h>
#include <mlpack/core.h>
-#include <mlpack/core/tree/statistic.h>
+#include <mlpack/core/tree/statistic.hpp>
//#include <mlpack/core.h>
/**
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/max_ip_main.cc
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/max_ip_main.cc 2011-10-27 20:49:55 UTC (rev 10055)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/max_ip_main.cc 2011-10-27 20:54:10 UTC (rev 10056)
@@ -47,10 +47,10 @@
string qfile = CLI::GetParam<string>("q");
Log::Info << "Loading files..." << endl;
- if (!data::Load(rfile.c_str(), rdata))
+ if (rdata.load(rfile.c_str()) == false)
Log::Fatal << "Reference file "<< rfile << " not found." << endl;
- if (!data::Load(qfile.c_str(), qdata))
+ if (qdata.load(qfile.c_str()) == false)
Log::Fatal << "Query file " << qfile << " not found." << endl;
Log::Info << "File loaded..." << endl;
@@ -60,12 +60,12 @@
<< ")" << endl;
- arma::Mat<size_t> nac, exc, apc;
- arma::mat din, die, dia;
+ arma::Mat<size_t> nac, exc;
+ arma::mat din, die;
//size_t knns = CLI::GetParam<int>("maxip/knns");
- double naive_comp, fast_comp, approx_comp;
+ double naive_comp, fast_comp;
// Naive computation
More information about the mlpack-svn
mailing list