[mlpack-svn] r10076 - mlpack/trunk/src/contrib/pram/max_ip_search
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Sat Oct 29 23:16:22 EDT 2011
Author: pram
Date: 2011-10-29 23:16:22 -0400 (Sat, 29 Oct 2011)
New Revision: 10076
Added:
mlpack/trunk/src/contrib/pram/max_ip_search/dcosinebound.h
Modified:
mlpack/trunk/src/contrib/pram/max_ip_search/CMakeLists.txt
mlpack/trunk/src/contrib/pram/max_ip_search/cosine_tree_test.cc
mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.cc
mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.h
mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree.h
mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree_impl.h
mlpack/trunk/src/contrib/pram/max_ip_search/gen_cosine_tree.h
mlpack/trunk/src/contrib/pram/max_ip_search/gen_cosine_tree_impl.h
mlpack/trunk/src/contrib/pram/max_ip_search/gen_metric_tree.h
mlpack/trunk/src/contrib/pram/max_ip_search/gen_metric_tree_impl.h
Log:
Cosine tree -- co-axial cone tree added and tested
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/CMakeLists.txt 2011-10-29 23:16:02 UTC (rev 10075)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/CMakeLists.txt 2011-10-30 03:16:22 UTC (rev 10076)
@@ -10,14 +10,17 @@
gen_metric_tree.h
gen_metric_tree_impl.h
gen_metric_tree_impl.cc
+
+ cosine.h
+
+ # Cosine tree
+ dcosinebound.h
+ gen_cosine_tree.h
+ gen_cosine_tree_impl.h
+ #gen_cosine_tree_impl.cc
# Cone bound
- cosine.h
dconebound.h
dconebound_impl.h
- # Cosine tree
- #gen_cosine_tree.h
- #gen_cosine_tree_impl.h
- #gen_cosine_tree_impl.cc
# cone tree
gen_cone_tree.h
gen_cone_tree_impl.h
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/cosine_tree_test.cc
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/cosine_tree_test.cc 2011-10-29 23:16:02 UTC (rev 10075)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/cosine_tree_test.cc 2011-10-30 03:16:22 UTC (rev 10076)
@@ -1,9 +1,12 @@
#include <armadillo>
#include <string>
#include "general_spacetree.h"
-#include "dconebound.h"
+// #include "dconebound.h"
+// #include "gen_cosine_tree.h"
+#include "dcosinebound.h"
#include "gen_cosine_tree.h"
+
using namespace mlpack;
using namespace std;
@@ -32,21 +35,20 @@
rdata = arma::trans(rdata);
- typedef GeneralBinarySpaceTree<DConeBound<>, arma::mat> CTreeType;
+ typedef GeneralBinarySpaceTree<DCosineBound<>, arma::mat> CTreeType;
arma::Col<size_t> old_from_new_data;
CTreeType *test_tree
= proximity::MakeGenCosineTree<CTreeType>(rdata, 20,
- &old_from_new_data,
- NULL);
+ &old_from_new_data); //,
+// NULL);
if (CLI::HasParam("print_tree")) {
test_tree->Print();
} else {
Log::Info << "Tree built" << endl;
}
-
} // end main
Added: mlpack/trunk/src/contrib/pram/max_ip_search/dcosinebound.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/dcosinebound.h (rev 0)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/dcosinebound.h 2011-10-30 03:16:22 UTC (rev 10076)
@@ -0,0 +1,139 @@
+/**
+ * @file dcosinebound.h
+ *
+ * Bounds that are useful for binary space partitioning trees.
+ * Interface to a cosine bound that works with cosine
+ * similarity only (for now)
+ *
+ * @experimental
+ */
+
+#ifndef TREE_DCOSINEBOUND_H
+#define TREE_DCOSINEBOUND_H
+
+#include "mlpack/core/math/range.hpp"
+
+// Awaiting transition
+#include "cosine.h"
+
+#include <armadillo>
+
+/**
+ * Bound holding a center point and a [rad_min, rad_max] interval.
+ *
+ * TMetric defaults to Cosine and TPoint defaults to arma::vec.
+ *
+ * To initialize this, set the interval with @c set_radius
+ * and set the center by assigning through @c center() directly.
+ */
+template<typename TMetric = Cosine, typename TPoint = arma::vec>
+class DCosineBound {
+public:
+ typedef TPoint Point;
+ typedef TMetric Metric;
+
+private:
+ double rad_min_;
+ double rad_max_;
+ double radius_;
+ TPoint center_;
+
+public:
+ /***
+ * Return the stored minimum, maximum, and width (max - min) of the bound.
+ */
+ double rad_min() const { return rad_min_; }
+ double rad_max() const { return rad_max_; }
+ double radius() const { return radius_; }
+
+ /***
+ * Set rad_min and rad_max; radius is recomputed as rad_max - rad_min.
+ */
+ void set_radius(double rad_min, double rad_max) {
+
+ rad_min_ = rad_min;
+ rad_max_ = rad_max;
+ radius_ = rad_max - rad_min;
+ }
+
+ /***
+ * Return the center point (read-only reference).
+ */
+ const TPoint& center() const { return center_; }
+
+ /***
+ * Return the center point (modifiable reference).
+ */
+ TPoint& center() { return center_; }
+
+ // IMPLEMENT THESE LATER IN CASE THIS WORKS OUT
+ // /**
+ // * Determines if a point is within this bound.
+ // */
+ // bool Contains(const Point& point) const;
+
+ // /**
+ // * Gets the center.
+ // *
+ // * Don't really use this directly. This is only here for consistency
+ // * with DHrectBound, so it can plug in more directly if a "centroid"
+ // * is needed.
+ // */
+ // void CalculateMidpoint(Point *centroid) const;
+
+ // /**
+ // * Calculates maximum bound-to-point cosine.
+ // */
+ // double MaxCosine(const Point& point) const;
+
+ // /**
+ // * Calculates maximum bound-to-bound cosine.
+ // */
+ // double MaxCosine(const DCosineBound& other) const;
+
+ // /**
+ // * Calculates minimum bound-to-point cosine (intended; not implemented).
+ // */
+ // double MinCosine(const Point& point) const;
+
+ // /**
+ // * Calculates minimum bound-to-bound cosine (intended; not implemented).
+ // */
+ // double MinCosine(const DCosineBound& other) const;
+
+ // /**
+ // * Calculates minimum and maximum bound-to-bound cosine (intended).
+ // *
+ // * Example: bound1.RangeCosine(other) for the range of cosines.
+ // */
+ // mlpack::math::Range RangeCosine(const DCosineBound& other) const;
+
+ // /**
+ // * Intended: maximum cosine from this bound to the other bound's
+ // * midpoint, i.e. calculates their midpoint and finds the maximum
+ // * bound-to-point cosine.
+ // *
+ // * Equivalent to:
+ // * <code>
+ // * other.CalculateMidpoint(&other_midpoint)
+ // * return MaxCosine(other_midpoint)
+ // * </code>
+ // */
+ // double MaxToMid(const DCosineBound& other) const;
+
+ // /**
+ // * Computes minimax distance, where the other node is trying to avoid me.
+ // */
+ // double MinimaxCosine(const DCosineBound& other) const;
+
+ // /**
+ // * Calculates midpoint-to-midpoint bounding box distance.
+ // */
+ // double MidCosine(const DCosineBound& other) const;
+ // double MidCosine(const Point& point) const;
+
+}; // DCosineBound
+
+// #include "dcosinebound_impl.h"
+
+#endif
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.cc
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.cc 2011-10-29 23:16:02 UTC (rev 10075)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.cc 2011-10-30 03:16:22 UTC (rev 10076)
@@ -77,7 +77,9 @@
return (query_norms_(query_) * (c_norm + rad));
}
-double MaxIP::MaxNodeIP_(CTreeType* query_node,
+
+// CONE TREE
+double MaxIP::MaxNodeIP_(CTreeTypeA* query_node,
TreeType* reference_node) {
// counting the split decisions
@@ -85,16 +87,15 @@
// min_{q', q} cos <qq' = cos_w
arma::vec q = query_node->bound().center();
+
+ // FOR CONE TREE
double cos_w = query_node->bound().radius();
double sin_w = query_node->bound().radius_conjugate();
- // +1: Can cache it in the query tree
- // double q_norm = arma::norm(q, 2);
double q_norm = query_node->stat().center_norm();
arma::vec centroid = reference_node->bound().center();
- // +1: can be cached in the reference tree
double c_norm = reference_node->stat().dist_to_origin();
double rad = std::sqrt(reference_node->bound().radius());
@@ -164,7 +165,78 @@
} // alt-angle-prune
return (c_norm + rad);
+}
+
+// COSINE TREE
+double MaxIP::MaxNodeIP_(CTreeTypeB* query_node,
+ TreeType* reference_node) {
+
+ // counting the split decisions
+ split_decisions_++;
+
+ // min_{q', q} cos <qq' = cos_w
+ arma::vec q = query_node->bound().center();
+
+ // FOR COSINE TREE
+ double cos_w_min = query_node->bound().rad_min();
+ double cos_w_max = query_node->bound().rad_max();
+
+ double q_norm = query_node->stat().center_norm();
+
+ arma::vec centroid = reference_node->bound().center();
+
+ double c_norm = reference_node->stat().dist_to_origin();
+ double rad = std::sqrt(reference_node->bound().radius());
+
+ if (mlpack::CLI::HasParam("maxip/alt_angle_prune")) {
+ // HAVE TO WORK THIS OUT AND CHECK IT ANALYTICALLY
+ // using the closed-form-maximization,
+ // |p| cos (phi - w) + R
+
+ double cos_phi, sin_phi;
+
+ if (!reference_node->stat().has_cos_phi()) {
+ // +1
+ cos_phi = arma::dot(q, centroid) / (c_norm * q_norm);
+ reference_node->stat().set_cos_phi(cos_phi);
+ } else {
+ cos_phi = reference_node->stat().cos_phi();
+ split_decisions_--;
+ }
+
+ sin_phi = std::sqrt(1 - cos_phi * cos_phi);
+
+ if (cos_phi < cos_w_min) {
+
+ // bound with cos(phi - w_min)
+ double sin_w_min = std::sqrt(1 - cos_w_min * cos_w_min);
+ return (c_norm * (cos_phi * cos_w_min + sin_phi * sin_w_min) + rad);
+
+ } else if (cos_phi > cos_w_max) {
+
+ // bound with cos (w_max - phi)
+ double sin_w_max = std::sqrt(1 - cos_w_max * cos_w_max);
+ return (c_norm * (cos_phi * cos_w_max + sin_phi * sin_w_max) + rad);
+
+ } else {
+ // cos w_min <= cos phi <= cos w_max;
+ // hence the query cone has the centroid
+ // so no useful bounding
+
+ cone_has_centroid_ += query_node->end() - query_node->begin();
+
+ // maybe trigger a flag which lets the peep know that
+ // query cone too wide
+ // Although, not sure how that will benefit since the
+ // current traversal kind of takes care of this.
+
+ return (c_norm + rad);
+ }
+ } // alt-angle-prune
+
+ return (c_norm + rad);
+
}
@@ -273,7 +345,8 @@
} // ComputeNeighborsRecursion_
-void MaxIP::ComputeBaseCase_(CTreeType* query_node,
+// CONE TREE
+void MaxIP::ComputeBaseCase_(CTreeTypeA* query_node,
TreeType* reference_node) {
// Check that the pointers are not NULL
@@ -313,10 +386,55 @@
query_node->stat().set_bound(query_worst_p_cos_pq);
} // ComputeBaseCase_
+
+
+
+// COSINE TREE
+void MaxIP::ComputeBaseCase_(CTreeTypeB* query_node,
+ TreeType* reference_node) {
+
+ // Check that the pointers are not NULL
+ assert(reference_node != NULL);
+ assert(reference_node->is_leaf());
+ assert(query_node != NULL);
+
+ // query node may not be a leaf
+ // assert(query_node->is_leaf());
+
+ // Used to find the query node's new lower bound
+ double query_worst_p_cos_pq = DBL_MAX;
+ bool new_bound = false;
+
+ // Iterating over the queries individually
+ for (query_ = query_node->begin();
+ query_ < query_node->end(); query_++) {
+
+ // checking if this node has potential
+ double query_to_node_max_ip = MaxNodeIP_(reference_node);
+
+ if (query_to_node_max_ip > max_ips_(knns_ -1, query_))
+ // this node has potential
+ ComputeBaseCase_(reference_node);
+
+ double p_cos_pq = max_ips_(knns_ -1, query_)
+ / query_norms_(query_);
+
+ if (query_worst_p_cos_pq > p_cos_pq) {
+ query_worst_p_cos_pq = p_cos_pq;
+ new_bound = true;
+ }
+ } // for query_
+ // Update the lower bound for the query_node
+ if (new_bound)
+ query_node->stat().set_bound(query_worst_p_cos_pq);
-void MaxIP::CheckPrune(CTreeType* query_node, TreeType* ref_node) {
+} // ComputeBaseCase_
+
+// CONE TREE
+void MaxIP::CheckPrune(CTreeTypeA* query_node, TreeType* ref_node) {
+
size_t missed_nns = 0;
double max_p_cos_pq = 0.0;
double min_p_cos_pq = DBL_MAX;
@@ -360,7 +478,57 @@
}
-void MaxIP::ComputeNeighborsRecursion_(CTreeType* query_node,
+
+
+// COSINE TREE
+void MaxIP::CheckPrune(CTreeTypeB* query_node, TreeType* ref_node) {
+
+ size_t missed_nns = 0;
+ double max_p_cos_pq = 0.0;
+ double min_p_cos_pq = DBL_MAX;
+
+ // Iterating over the queries individually
+ for (query_ = query_node->begin();
+ query_ < query_node->end(); query_++) {
+
+ // Get the query point from the matrix
+ arma::vec q = queries_.unsafe_col(query_);
+
+ double p_cos_qp = max_ips_(knns_ -1, query_) / query_norms_(query_);
+ if (min_p_cos_pq > p_cos_qp)
+ min_p_cos_pq = p_cos_qp;
+
+ // We'll do the same for the references
+ for (size_t reference_index = ref_node->begin();
+ reference_index < ref_node->end(); reference_index++) {
+
+ arma::vec r = references_.unsafe_col(reference_index);
+
+ double ip = arma::dot(q, r);
+ if (ip > max_ips_(knns_-1, query_))
+ missed_nns++;
+
+ double p_cos_pq = ip / query_norms_(query_);
+
+ if (p_cos_pq > max_p_cos_pq)
+ max_p_cos_pq = p_cos_pq;
+
+ } // for reference_index
+ } // for query_
+
+ if (missed_nns > 0 || query_node->stat().bound() != min_p_cos_pq)
+ printf("Prune %zu - Missed candidates: %zu\n"
+ "QLBound: %lg, ActualQLBound: %lg\n"
+ "QRBound: %lg, ActualQRBound: %lg\n",
+ number_of_prunes_, missed_nns,
+ query_node->stat().bound(), min_p_cos_pq,
+ MaxNodeIP_(query_node, ref_node), max_p_cos_pq);
+
+}
+
+
+// CONE TREE
+void MaxIP::ComputeNeighborsRecursion_(CTreeTypeA* query_node,
TreeType* reference_node,
double upper_bound_p_cos_pq) {
@@ -517,8 +685,152 @@
} // alt-traversal
} // All cases of dual-tree traversal
} // ComputeNeighborsRecursion_
+
+
+
+// COSINE TREE
+void MaxIP::ComputeNeighborsRecursion_(CTreeTypeB* query_node,
+ TreeType* reference_node,
+ double upper_bound_p_cos_pq) {
+
+ assert(query_node != NULL);
+ assert(reference_node != NULL);
+ //assert(upper_bound_p_cos_pq == MaxNodeIP_(query_node, reference_node));
+
+ if (upper_bound_p_cos_pq < query_node->stat().bound()) {
+ // Pruned
+ number_of_prunes_ += query_node->end() - query_node->begin();
+
+ if (CLI::HasParam("maxip/check_prune"))
+ CheckPrune(query_node, reference_node);
+ }
+ // node->is_leaf() works as one would expect
+ else if (query_node->is_leaf() && reference_node->is_leaf()) {
+ // Base Case
+ ComputeBaseCase_(query_node, reference_node);
+ } else if (query_node->is_leaf()) {
+ // Only query is a leaf
+
+
+ if (CLI::HasParam("maxip/alt_dual_traversal")) {
+ // Trying to do single-tree on the leaves
+
+ // Used to find the query node's new lower bound
+ double query_worst_p_cos_pq = DBL_MAX;
+ bool new_bound = false;
+
+ // Iterating over the queries individually
+ for (query_ = query_node->begin();
+ query_ < query_node->end(); query_++) {
+
+ // checking if this node has potential
+ double query_to_node_max_ip = MaxNodeIP_(reference_node);
+
+ if (query_to_node_max_ip > max_ips_(knns_ -1, query_))
+ // this node has potential
+ ComputeNeighborsRecursion_(reference_node, query_to_node_max_ip);
+
+ double p_cos_pq = max_ips_(knns_ -1, query_)
+ / query_norms_(query_);
+
+ if (query_worst_p_cos_pq > p_cos_pq) {
+ query_worst_p_cos_pq = p_cos_pq;
+ new_bound = true;
+ }
+ } // for query_
+ // Update the lower bound for the query_node
+ if (new_bound)
+ query_node->stat().set_bound(query_worst_p_cos_pq);
+ } else {
+ // Visit the reference child with the larger MaxNodeIP_ value first
+ double left_p_cos_pq = MaxNodeIP_(query_node,
+ reference_node->left());
+ double right_p_cos_pq = MaxNodeIP_(query_node,
+ reference_node->right());
+
+ if (left_p_cos_pq > right_p_cos_pq) {
+ ComputeNeighborsRecursion_(query_node, reference_node->left(),
+ left_p_cos_pq);
+ ComputeNeighborsRecursion_(query_node, reference_node->right(),
+ right_p_cos_pq);
+ } else {
+ ComputeNeighborsRecursion_(query_node, reference_node->right(),
+ right_p_cos_pq);
+ ComputeNeighborsRecursion_(query_node, reference_node->left(),
+ left_p_cos_pq);
+ }
+ } // alt-traversal
+ } else if (reference_node->is_leaf()) {
+ // Only reference is a leaf
+ double left_p_cos_pq
+ = MaxNodeIP_(query_node->left(), reference_node);
+ double right_p_cos_pq
+ = MaxNodeIP_(query_node->right(), reference_node);
+
+ ComputeNeighborsRecursion_(query_node->left(), reference_node,
+ left_p_cos_pq);
+ ComputeNeighborsRecursion_(query_node->right(), reference_node,
+ right_p_cos_pq);
+
+ // We need to update this node's bound to the smaller of the bounds
+ // now stored in its two children
+ query_node->stat().set_bound(std::min(query_node->left()->stat().bound(),
+ query_node->right()->stat().bound()));
+ } else {
+
+ // Recurse on both as above
+ double left_p_cos_pq = MaxNodeIP_(query_node->left(),
+ reference_node->left());
+ double right_p_cos_pq = MaxNodeIP_(query_node->left(),
+ reference_node->right());
+
+ if (left_p_cos_pq > right_p_cos_pq) {
+ ComputeNeighborsRecursion_(query_node->left(),
+ reference_node->left(),
+ left_p_cos_pq);
+ ComputeNeighborsRecursion_(query_node->left(),
+ reference_node->right(),
+ right_p_cos_pq);
+ } else {
+ ComputeNeighborsRecursion_(query_node->left(),
+ reference_node->right(),
+ right_p_cos_pq);
+ ComputeNeighborsRecursion_(query_node->left(),
+ reference_node->left(),
+ left_p_cos_pq);
+ }
+
+ left_p_cos_pq = MaxNodeIP_(query_node->right(),
+ reference_node->left());
+ right_p_cos_pq = MaxNodeIP_(query_node->right(),
+ reference_node->right());
+
+ if (left_p_cos_pq > right_p_cos_pq) {
+ ComputeNeighborsRecursion_(query_node->right(),
+ reference_node->left(),
+ left_p_cos_pq);
+ ComputeNeighborsRecursion_(query_node->right(),
+ reference_node->right(),
+ right_p_cos_pq);
+ } else {
+ ComputeNeighborsRecursion_(query_node->right(),
+ reference_node->right(),
+ right_p_cos_pq);
+ ComputeNeighborsRecursion_(query_node->right(),
+ reference_node->left(),
+ left_p_cos_pq);
+ }
+
+ // Update this node's bound to the smaller of its children's bounds
+ query_node->stat().set_bound(std::min(query_node->left()->stat().bound(),
+ query_node->right()->stat().bound()));
+ } // All cases of dual-tree traversal
+} // ComputeNeighborsRecursion_
+
+
+
void MaxIP::Init(const arma::mat& queries_in,
const arma::mat& references_in) {
@@ -564,12 +876,27 @@
set_angles_in_balls_(reference_tree_);
if (mlpack::CLI::HasParam("maxip/dual_tree")) {
- query_tree_
- = proximity::MakeGenConeTree<CTreeType>(queries_,
- leaf_size_,
- &old_from_new_queries_,
- NULL);
- set_norms_in_cones_(query_tree_);
+
+ if (mlpack::CLI::HasParam("maxip/alt_tree")) {
+
+ // using cosine tree
+ query_tree_B_
+ = proximity::MakeGenCosineTree<CTreeTypeB>(queries_,
+ leaf_size_,
+ &old_from_new_queries_,
+ NULL);
+ set_norms_in_cones_(query_tree_B_);
+
+ } else {
+
+ // using cone tree
+ query_tree_A_
+ = proximity::MakeGenConeTree<CTreeTypeA>(queries_,
+ leaf_size_,
+ &old_from_new_queries_,
+ NULL);
+ set_norms_in_cones_(query_tree_A_);
+ }
}
// saving the query norms beforehand to use
@@ -652,13 +979,22 @@
max_ips_ = 0.0 * arma::ones<arma::mat>(knns_, queries_.n_cols);
// need to reset the querystats in the Query Tree
- if (mlpack::CLI::HasParam("maxip/dual_tree"))
- if (query_tree_ != NULL)
- reset_tree_(query_tree_);
+ if (mlpack::CLI::HasParam("maxip/dual_tree"))
+ if (mlpack::CLI::HasParam("maxip/alt_tree")) {
+ if (query_tree_B_ != NULL)
+ reset_tree_(query_tree_B_);
+ reset_tree_(reference_tree_);
+
+ } else {
+ if (query_tree_A_ != NULL)
+ reset_tree_(query_tree_A_);
+ }
+
+
} // WarmInit
-void MaxIP::reset_tree_(CTreeType* tree) {
+void MaxIP::reset_tree_(CTreeTypeA* tree) {
assert(tree != NULL);
tree->stat().set_bound(0.0);
@@ -670,6 +1006,32 @@
return;
} // reset_tree_
+
+void MaxIP::reset_tree_(CTreeTypeB* tree) {
+ assert(tree != NULL);
+ tree->stat().set_bound(0.0);
+
+ if (!tree->is_leaf()) {
+ reset_tree_(tree->left());
+ reset_tree_(tree->right());
+ }
+
+ return;
+} // reset_tree_
+
+
+void MaxIP::reset_tree_(TreeType* tree) {
+ assert(tree != NULL);
+ tree->stat().reset();
+
+ if (!tree->is_leaf()) {
+ reset_tree_(tree->left());
+ reset_tree_(tree->right());
+ }
+
+ return;
+} // reset_tree_
+
void MaxIP::set_angles_in_balls_(TreeType* tree) {
assert(tree != NULL);
@@ -694,7 +1056,7 @@
} // set_angles_in_balls_
-void MaxIP::set_norms_in_cones_(CTreeType* tree) {
+void MaxIP::set_norms_in_cones_(CTreeTypeA* tree) {
assert(tree != NULL);
@@ -710,7 +1072,25 @@
return;
} // set_norms_in_cones_
+void MaxIP::set_norms_in_cones_(CTreeTypeB* tree) {
+ assert(tree != NULL);
+
+ // set up node stats
+ tree->stat().set_center_norm(arma::norm(tree->bound().center(), 2));
+
+ // traverse down the children
+ if (!tree->is_leaf()) {
+ set_norms_in_cones_(tree->left());
+ set_norms_in_cones_(tree->right());
+ }
+
+ return;
+} // set_norms_in_cones_
+
+
+
+
double MaxIP::ComputeNeighbors(arma::Mat<size_t>* resulting_neighbors,
arma::mat* ips) {
@@ -722,8 +1102,14 @@
CLI::StartTimer("maxip/fast_dual");
- ComputeNeighborsRecursion_(query_tree_, reference_tree_,
- MaxNodeIP_(query_tree_, reference_tree_));
+
+ if (mlpack::CLI::HasParam("maxip/alt_tree"))
+ ComputeNeighborsRecursion_(query_tree_B_, reference_tree_,
+ MaxNodeIP_(query_tree_B_, reference_tree_));
+ else
+ ComputeNeighborsRecursion_(query_tree_A_, reference_tree_,
+ MaxNodeIP_(query_tree_A_, reference_tree_));
+
CLI::StopTimer("maxip/fast_dual");
resulting_neighbors->set_size(max_ip_indices_.n_rows,
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.h 2011-10-29 23:16:02 UTC (rev 10075)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.h 2011-10-30 03:16:22 UTC (rev 10076)
@@ -6,7 +6,7 @@
#ifndef EXACT_MAX_IP_H
#define EXACT_MAX_IP_H
-//#define NDEBUG
+#define NDEBUG
#include <assert.h>
#include <mlpack/core.h>
@@ -18,6 +18,8 @@
#include "gen_metric_tree.h"
#include "dconebound.h"
#include "gen_cone_tree.h"
+#include "dcosinebound.h"
+#include "gen_cosine_tree.h"
using namespace mlpack;
@@ -70,11 +72,15 @@
double cosine_origin_;
double sine_origin_;
double dist_to_origin_;
+ bool has_cos_phi_;
+ double cos_phi_;
public:
double cosine_origin() { return cosine_origin_; }
double sine_origin() { return sine_origin_; }
double dist_to_origin() { return dist_to_origin_; }
+ bool has_cos_phi() { return has_cos_phi_; }
+ double cos_phi() { return cos_phi_; }
void set_angles(double val, size_t type = 0) {
@@ -91,11 +97,23 @@
dist_to_origin_ = val;
}
+ void set_cos_phi(double cos_phi) {
+ cos_phi_ = cos_phi;
+ has_cos_phi_ = true;
+ }
+
+ void reset() {
+ cos_phi_ = -1.0;
+ has_cos_phi_ = false;
+ }
+
// FILL OUT THE INIT FUNCTIONS APPROPRIATELY LATER
RefStat() {
cosine_origin_ = 0.0;
sine_origin_ = 0.0;
dist_to_origin_ = 0.0;
+ has_cos_phi_ = false;
+ cos_phi_ = -1.0;
}
~RefStat() {}
@@ -103,6 +121,8 @@
cosine_origin_ = 0.0;
sine_origin_ = 0.0;
dist_to_origin_ = 0.0;
+ has_cos_phi_ = false;
+ cos_phi_ = -1.0;
}
void Init(const arma::mat& data, size_t begin, size_t count,
@@ -110,6 +130,8 @@
cosine_origin_ = 0.0;
sine_origin_ = 0.0;
dist_to_origin_ = 0.0;
+ has_cos_phi_ = false;
+ cos_phi_ = -1.0;
}
}; // RefStat
@@ -117,7 +139,8 @@
// Euclidean bounding boxes, the data are stored in a Matrix,
// and each node has a QueryStat for its bound.
typedef GeneralBinarySpaceTree<bound::DBallBound<>, arma::mat, RefStat> TreeType;
- typedef GeneralBinarySpaceTree<DConeBound<>, arma::mat, QueryStat> CTreeType;
+ typedef GeneralBinarySpaceTree<DConeBound<>, arma::mat, QueryStat> CTreeTypeA;
+ typedef GeneralBinarySpaceTree<DCosineBound<>, arma::mat, QueryStat> CTreeTypeB;
/////////////////////////////// Members ////////////////////////////
@@ -132,7 +155,8 @@
// Pointers to the roots of the two trees.
TreeType* reference_tree_;
- CTreeType* query_tree_;
+ CTreeTypeA* query_tree_A_;
+ CTreeTypeB* query_tree_B_;
// The total number of prunes.
size_t number_of_prunes_;
@@ -170,7 +194,8 @@
*/
MaxIP() {
reference_tree_ = NULL;
- query_tree_ = NULL;
+ query_tree_A_ = NULL;
+ query_tree_B_ = NULL;
}
/**
@@ -181,8 +206,12 @@
if (reference_tree_ != NULL)
delete reference_tree_;
- if (query_tree_ != NULL)
- delete query_tree_;
+ if (query_tree_A_ != NULL)
+ delete query_tree_A_;
+
+ if (query_tree_B_ != NULL)
+ delete query_tree_B_;
+
}
/////////////////////////// Helper Functions //////////////////////
@@ -200,7 +229,8 @@
* ignoring the norm of any query.
* So it is computing \max_(q,r) |r| cos <qr.
*/
- double MaxNodeIP_(CTreeType *query_node, TreeType* reference_node);
+ double MaxNodeIP_(CTreeTypeA *query_node, TreeType* reference_node);
+ double MaxNodeIP_(CTreeTypeB *query_node, TreeType* reference_node);
/**
* Performs exhaustive computation at the leaves.
@@ -210,7 +240,8 @@
/**
* Dual-tree: Performs exhaustive computation between two leaves.
*/
- void ComputeBaseCase_(CTreeType* query_node, TreeType* reference_node);
+ void ComputeBaseCase_(CTreeTypeA* query_node, TreeType* reference_node);
+ void ComputeBaseCase_(CTreeTypeB* query_node, TreeType* reference_node);
/**
* The recursive function
@@ -221,13 +252,20 @@
/**
* Dual-tree: The recursive function
*/
- void ComputeNeighborsRecursion_(CTreeType* query_node,
+ void ComputeNeighborsRecursion_(CTreeTypeA* query_node,
TreeType* reference_node,
double upper_bound_ip);
- void reset_tree_(CTreeType *tree);
+ void ComputeNeighborsRecursion_(CTreeTypeB* query_node,
+ TreeType* reference_node,
+ double upper_bound_ip);
+
+ void reset_tree_(CTreeTypeA *tree);
+ void reset_tree_(CTreeTypeB *tree);
+ void reset_tree_(TreeType *tree);
void set_angles_in_balls_(TreeType *tree);
- void set_norms_in_cones_(CTreeType *tree);
+ void set_norms_in_cones_(CTreeTypeA *tree);
+ void set_norms_in_cones_(CTreeTypeB *tree);
size_t SortValue(double value);
@@ -246,8 +284,10 @@
if (reference_tree_ != NULL)
delete reference_tree_;
- if (query_tree_ != NULL)
- delete query_tree_;
+ if (query_tree_A_ != NULL)
+ delete query_tree_A_;
+ if (query_tree_B_ != NULL)
+ delete query_tree_B_;
}
/**
@@ -275,7 +315,8 @@
arma::mat* ips);
- void CheckPrune(CTreeType* query_node, TreeType* ref_node);
+ void CheckPrune(CTreeTypeA* query_node, TreeType* ref_node);
+ void CheckPrune(CTreeTypeB* query_node, TreeType* ref_node);
}; //class MaxIP
@@ -293,9 +334,13 @@
" pruning using the angles as well", "maxip");
PARAM_FLAG("alt_angle_prune", "The flag to trigger the tighter-er"
" pruning using the angles as well", "maxip");
+
PARAM_FLAG("dual_tree", "The flag to trigger dual-tree "
"computation, using a cone tree for the "
"queries.", "maxip");
+PARAM_FLAG("alt_tree", "The flag to trigger the "
+ "alternate query tree.",
+ "maxip");
PARAM_FLAG("check_prune", "The flag to trigger the "
"checking of the prune.", "maxip");
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree.h 2011-10-29 23:16:02 UTC (rev 10075)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree.h 2011-10-30 03:16:22 UTC (rev 10076)
@@ -11,6 +11,7 @@
#ifndef GEN_CONE_TREE_H
#define GEN_CONE_TREE_H
+#define NDEBUG
#include "general_spacetree.h"
#include "gen_cone_tree_impl.h"
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree_impl.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree_impl.h 2011-10-29 23:16:02 UTC (rev 10075)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree_impl.h 2011-10-30 03:16:22 UTC (rev 10076)
@@ -7,6 +7,8 @@
#ifndef GEN_CONE_TREE_IMPL_H
#define GEN_CONE_TREE_IMPL_H
+#define NDEBUG
+
#include <assert.h>
#include <mlpack/core.h>
#include <armadillo>
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/gen_cosine_tree.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/gen_cosine_tree.h 2011-10-29 23:16:02 UTC (rev 10075)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/gen_cosine_tree.h 2011-10-30 03:16:22 UTC (rev 10076)
@@ -11,7 +11,11 @@
#ifndef GEN_COSINE_TREE_H
#define GEN_COSINE_TREE_H
+#define NDEBUG
+#include <armadillo>
+#include <assert.h>
+
#include "general_spacetree.h"
#include "gen_cosine_tree_impl.h"
@@ -59,6 +63,10 @@
node->Init(0, matrix.n_cols);
node->bound().center().set_size(matrix.n_rows);
+ node->bound().center() = arma::mean(matrix, 1);
+
+ assert(node->bound().center().n_elem == matrix.n_rows);
+
tree_gen_cosine_tree_private::SplitGenCosineTree<TCosineTree>
(matrix, node, leaf_size, old_from_new_ptr);
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/gen_cosine_tree_impl.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/gen_cosine_tree_impl.h 2011-10-29 23:16:02 UTC (rev 10075)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/gen_cosine_tree_impl.h 2011-10-30 03:16:22 UTC (rev 10076)
@@ -7,6 +7,8 @@
#ifndef GEN_COSINE_TREE_IMPL_H
#define GEN_COSINE_TREE_IMPL_H
+#define NDEBUG
+
#include <assert.h>
#include <mlpack/core.h>
#include <armadillo>
@@ -14,119 +16,24 @@
namespace tree_gen_cosine_tree_private {
- size_t FurthestColumnIndex(const arma::vec& pivot,
- const arma::mat& matrix,
- size_t begin, size_t count,
- double *furthest_cosine);
-
- // This function assumes that we have points embedded in Euclidean
- // space. The representative point (center) is chosen as the mean
- // of the all unit directions of all the vectors in this set.
-
- // fixed!!
+ // not compiled
template<typename TBound>
void MakeLeafCosineTreeNode(const arma::mat& matrix,
size_t begin, size_t count,
TBound *bounds) {
- bounds->center().zeros();
- size_t end = begin + count;
- for (size_t i = begin; i < end; i++) {
- bounds->center()
- += (matrix.unsafe_col(i)
- / arma::norm(matrix.unsafe_col(i), 2));
- }
- bounds->center() /= (double) count;
-// bounds->center() /= arma::norm(bounds->center(), 2);
-
-// printf("c_norm: %lg\n", arma::norm(bounds->center(), 2));
-
- double furthest_cosine;
- FurthestColumnIndex(bounds->center(), matrix, begin, count,
- &furthest_cosine);
- bounds->set_radius(furthest_cosine);
- }
-
-
- // fixed until compiled
- template<typename TBound>
- size_t MatrixPartition(arma::mat& matrix, size_t first, size_t count,
- TBound &left_bound, TBound &right_bound,
- size_t *old_from_new) {
+ arma::vec cosine_vec(count);
- size_t end = first + count;
- size_t left_count = 0;
-
- std::vector<bool> left_membership;
- left_membership.reserve(count);
+ for (size_t i = 0; i < count; i++)
+ cosine_vec(i) = Cosine::Evaluate(bounds->center(),
+ matrix.unsafe_col(i + begin));
- for (size_t left = first; left < end; left++) {
-
- // Make alias of the current point.
- arma::vec point = matrix.unsafe_col(left);
-
- // Compute the cosines from the two pivots.
- double cosine_from_left_pivot =
- Cosine::Evaluate(point, left_bound.center());
- double cosine_from_right_pivot =
- Cosine::Evaluate(point, right_bound.center());
-
- // We swap if the point is more angled away from the left pivot.
- if(cosine_from_left_pivot < cosine_from_right_pivot) {
- left_membership[left - first] = false;
- } else {
- left_membership[left - first] = true;
- left_count++;
- }
- }
-
- size_t left = first;
- size_t right = first + count - 1;
-
- /* At any point:
- *
- * everything < left is correct
- * everything > right is correct
- */
- for (;;) {
- while (left_membership[left - first] && (left <= right)) {
- left++;
- }
-
- while (!left_membership[right - first] && (left <= right)) {
- right--;
- }
-
- if (left > right) {
- /* left == right + 1 */
- break;
- }
-
- // Swap the left vector with the right vector.
- matrix.swap_cols(left, right);
-
- bool tmp = left_membership[left - first];
- left_membership[left - first] = left_membership[right - first];
- left_membership[right - first] = tmp;
-
- if (old_from_new) {
- size_t t = old_from_new[left];
- old_from_new[left] = old_from_new[right];
- old_from_new[right] = t;
- }
-
- assert(left <= right);
- right--;
- }
-
- assert(left == right + 1);
-
- return left_count;
+ bounds->set_radius(arma::min(cosine_vec), arma::max(cosine_vec));
}
-
+
- // fixed
+ // not compiled
template<typename TCosineTree>
bool AttemptSplitting(arma::mat& matrix,
TCosineTree *node,
@@ -135,83 +42,100 @@
size_t leaf_size,
size_t *old_from_new) {
- // Pick a random row.
- size_t random_row
- = math::RandInt(node->begin(),
- node->begin() + node->count());
+ // obtain the list of cosine values to all the points
+ // in the set
+ arma::vec cosine_vec(node->count());
- // why is this here?
- // random_row = node->begin();
+ for (size_t i = 0; i < node->count(); i++)
+ cosine_vec(i)
+ = Cosine::Evaluate(node->bound().center(),
+ matrix.unsafe_col(i + node->begin()));
- arma::vec random_row_vec = matrix.unsafe_col(random_row);
-
- // Now figure out the furthest point from the random row picked
- // above.
- double furthest_cosine;
- size_t furthest_from_random_row =
- FurthestColumnIndex(random_row_vec, matrix, node->begin(), node->count(),
- &furthest_cosine);
- arma::vec furthest_from_random_row_vec = matrix.unsafe_col(furthest_from_random_row);
-
- // Then figure out the furthest point from the furthest point.
- double furthest_from_furthest_cosine;
- size_t furthest_from_furthest_random_row =
- FurthestColumnIndex(furthest_from_random_row_vec, matrix, node->begin(),
- node->count(), &furthest_from_furthest_cosine);
- arma::vec furthest_from_furthest_random_row_vec =
- matrix.unsafe_col(furthest_from_furthest_random_row);
-
- if(furthest_from_furthest_cosine > (1.0 - DBL_EPSILON)) {
- // everything in a really tight narrow cone
+ if(arma::max(cosine_vec) - arma::min(cosine_vec) < DBL_EPSILON) {
+ // everything in a really tight narrow co-axial cone-ring
return false;
} else {
*left = new TCosineTree();
*right = new TCosineTree();
- // not necessary, vec::operator=() takes care of resetting the size
-// ((*left)->bound().center()).set_size(matrix.n_rows);
-// ((*right)->bound().center()).set_size(matrix.n_rows);
+ ((*left)->bound().center()) = node->bound().center();
+ ((*right)->bound().center()) = node->bound().center();
- ((*left)->bound().center()) = furthest_from_random_row_vec;
- ((*right)->bound().center()) = furthest_from_furthest_random_row_vec;
+ node->bound().set_radius(arma::min(cosine_vec), arma::max(cosine_vec));
- size_t left_count
- = MatrixPartition(matrix, node->begin(), node->count(),
- (*left)->bound(), (*right)->bound(),
- old_from_new);
+// printf("%lg, %lg\n", node->bound().rad_min(), node->bound().rad_max());
- (*left)->Init(node->begin(), left_count);
- (*right)->Init(node->begin() + left_count, node->count() - left_count);
- }
+ size_t first = node->begin();
+ size_t end = first + node->count();
+ size_t left_count = 0;
- return true;
- }
+ double median_cosine_value = arma::median(cosine_vec);
- // fixed
- template<typename TCosineTree>
- void CombineBounds(arma::mat& matrix, TCosineTree *node,
- TCosineTree *left, TCosineTree *right) {
+ std::vector<bool> left_membership;
+ left_membership.reserve(node->count());
- // First clear the internal node center.
- node->bound().center().zeros();
+ for (size_t left_ind = first; left_ind < end; left_ind++) {
+
+ if(cosine_vec(left_ind - first) < median_cosine_value) {
+ // the outer ring
+ left_membership[left_ind - first] = false;
+ } else {
+ // the inner ring
+ left_membership[left_ind - first] = true;
+ left_count++;
+ }
+ }
- // Compute the weighted sum of the two pivots
- node->bound().center() += left->count() * left->bound().center();
- node->bound().center() += right->count() * right->bound().center();
- node->bound().center() /= (double) node->count();
-// node->bound().center() /= arma::norm(node->bound().center(), 2);
+ size_t left_ind = first;
+ size_t right_ind = end - 1;
-// printf("c_norm: %lg\n", arma::norm(node->bound().center(), 2));
+ /* At any point:
+ *
+ * everything < left_ind is correct
+ * everything > right_ind is correct
+ */
+ for (;;) {
+ while (left_membership[left_ind - first] && (left_ind <= right_ind)) {
+ left_ind++;
+ }
- double left_min_cosine, right_min_cosine;
- FurthestColumnIndex(node->bound().center(), matrix, left->begin(),
- left->count(), &left_min_cosine);
- FurthestColumnIndex(node->bound().center(), matrix, right->begin(),
- right->count(), &right_min_cosine);
- node->bound().set_radius(std::min(left_min_cosine, right_min_cosine));
+ while (!left_membership[right_ind - first] && (left_ind <= right_ind)) {
+ right_ind--;
+ }
+
+ if (left_ind > right_ind) {
+ /* left_ind == right_ind + 1 */
+ break;
+ }
+
+ // Swap the column at left_ind with the column at right_ind.
+ matrix.swap_cols(left_ind, right_ind);
+
+ bool tmp = left_membership[left_ind - first];
+ left_membership[left_ind - first] = left_membership[right_ind - first];
+ left_membership[right_ind - first] = tmp;
+
+ if (old_from_new) {
+ size_t t = old_from_new[left_ind];
+ old_from_new[left_ind] = old_from_new[right_ind];
+ old_from_new[right_ind] = t;
+ }
+
+ assert(left_ind <= right_ind);
+ right_ind--;
+ }
+
+ assert(left_ind == right_ind + 1);
+
+ (*left)->Init(node->begin(), left_count);
+ (*right)->Init(node->begin() + left_count, node->count() - left_count);
+
+ return true;
+ }
}
- // fixed
+
+ // not compiled
template<typename TCosineTree>
void SplitGenCosineTree(arma::mat& matrix, TCosineTree *node,
size_t leaf_size, size_t *old_from_new) {
@@ -233,7 +157,6 @@
if(can_cut) {
SplitGenCosineTree(matrix, left, leaf_size, old_from_new);
SplitGenCosineTree(matrix, right, leaf_size, old_from_new);
- CombineBounds(matrix, node, left, right);
}
else {
MakeLeafCosineTreeNode(matrix, node->begin(),
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/gen_metric_tree.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/gen_metric_tree.h 2011-10-29 23:16:02 UTC (rev 10075)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/gen_metric_tree.h 2011-10-30 03:16:22 UTC (rev 10076)
@@ -14,6 +14,8 @@
#ifndef GEN_METRIC_TREE_H
#define GEN_METRIC_TREE_H
+#define NDEBUG
+
#include "general_spacetree.h"
#include "gen_metric_tree_impl.h"
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/gen_metric_tree_impl.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/gen_metric_tree_impl.h 2011-10-29 23:16:02 UTC (rev 10075)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/gen_metric_tree_impl.h 2011-10-30 03:16:22 UTC (rev 10076)
@@ -8,6 +8,8 @@
#ifndef GEN_METRIC_TREE_IMPL_H
#define GEN_METRIC_TREE_IMPL_H
+#define NDEBUG
+
#include <assert.h>
#include <mlpack/core.h>
#include <mlpack/core/tree/bounds.hpp>
More information about the mlpack-svn
mailing list