[mlpack-svn] r10074 - mlpack/trunk/src/contrib/pram/max_ip_search
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Sat Oct 29 18:46:48 EDT 2011
Author: pram
Date: 2011-10-29 18:46:47 -0400 (Sat, 29 Oct 2011)
New Revision: 10074
Added:
mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree.h
mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree_impl.cc
mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree_impl.h
Modified:
mlpack/trunk/src/contrib/pram/max_ip_search/CMakeLists.txt
mlpack/trunk/src/contrib/pram/max_ip_search/approx_ip_main.cc
mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.cc
mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.h
mlpack/trunk/src/contrib/pram/max_ip_search/cosine_tree_test.cc
mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.cc
mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.h
mlpack/trunk/src/contrib/pram/max_ip_search/gen_cosine_tree_impl.h
Log:
Cone-tree code named appropriately
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/CMakeLists.txt 2011-10-29 17:02:13 UTC (rev 10073)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/CMakeLists.txt 2011-10-29 22:46:47 UTC (rev 10074)
@@ -10,13 +10,18 @@
gen_metric_tree.h
gen_metric_tree_impl.h
gen_metric_tree_impl.cc
- # Cosine tree
+ # Cone bound
cosine.h
dconebound.h
dconebound_impl.h
- gen_cosine_tree.h
- gen_cosine_tree_impl.h
- gen_cosine_tree_impl.cc
+ # Cosine tree
+ #gen_cosine_tree.h
+ #gen_cosine_tree_impl.h
+ #gen_cosine_tree_impl.cc
+ # cone tree
+ gen_cone_tree.h
+ gen_cone_tree_impl.h
+ gen_cone_tree_impl.cc
# Max-inner-product search class
exact_max_ip.h
exact_max_ip.cc
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/approx_ip_main.cc
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/approx_ip_main.cc 2011-10-29 17:02:13 UTC (rev 10073)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/approx_ip_main.cc 2011-10-29 22:46:47 UTC (rev 10074)
@@ -45,12 +45,15 @@
string qfile = CLI::GetParam<string>("q");
Log::Info << "Loading files..." << endl;
- if (!data::Load(rfile.c_str(), rdata))
+ if (rdata.load(rfile.c_str()) == false)
Log::Fatal << "Reference file "<< rfile << " not found." << endl;
-
- if (!data::Load(qfile.c_str(), qdata))
+
+ if (qdata.load(qfile.c_str()) == false)
Log::Fatal << "Query file " << qfile << " not found." << endl;
+ // mat::load() maps each line of a text file to a matrix row, while the
+ // search code expects one point per column: transpose BOTH data sets.
+ // (Assignment, not '==': a comparison here would silently leave the
+ // queries untransposed.)
+ rdata = arma::trans(rdata);
+ qdata = arma::trans(qdata);
+
Log::Info << "File loaded..." << endl;
Log::Info << "R(" << rdata.n_rows << ", " << rdata.n_cols
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.cc
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.cc 2011-10-29 17:02:13 UTC (rev 10073)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.cc 2011-10-29 22:46:47 UTC (rev 10074)
@@ -963,10 +963,10 @@
if (mlpack::CLI::HasParam("approx_maxip/dual_tree")) {
query_tree_
- = proximity::MakeGenCosineTree<CTreeType>(queries_,
- leaf_size_,
- &old_from_new_queries_,
- NULL);
+ = proximity::MakeGenConeTree<CTreeType>(queries_,
+ leaf_size_,
+ &old_from_new_queries_,
+ NULL);
}
// Stop the timer we started above
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.h 2011-10-29 17:02:13 UTC (rev 10073)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.h 2011-10-29 22:46:47 UTC (rev 10074)
@@ -15,7 +15,7 @@
#include "general_spacetree.h"
#include "gen_metric_tree.h"
#include "dconebound.h"
-#include "gen_cosine_tree.h"
+#include "gen_cone_tree.h"
using namespace mlpack;
@@ -40,7 +40,7 @@
PARAM_FLAG("angle_prune", "The flag to trigger the tighter"
" pruning using the angles as well", "approx_maxip");
PARAM_FLAG("dual_tree", "The flag to trigger dual-tree "
- "computation, using a cosine tree for the "
+ "computation, using a cone tree for the "
"queries.", "approx_maxip");
PARAM_FLAG("check_prune", "The flag to trigger the "
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/cosine_tree_test.cc
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/cosine_tree_test.cc 2011-10-29 17:02:13 UTC (rev 10073)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/cosine_tree_test.cc 2011-10-29 22:46:47 UTC (rev 10074)
@@ -13,8 +13,7 @@
PARAM_STRING_REQ("r", "The data set to be indexed", "");
-PARAM_FLAG("xx_print_tree", "The flag to print the tree", "");
-PARAM_FLAG("some_flag", "some test flag", "");
+PARAM_FLAG("print_tree", "The flag to print the tree", "");
int main (int argc, char *argv[]) {
@@ -24,13 +23,14 @@
string rfile = CLI::GetParam<string>("r");
Log::Info << "Loading files..." << endl;
- if (!data::Load(rfile.c_str(), rdata))
+ if (rdata.load(rfile.c_str()) == false)
Log::Fatal << "Data file " << rfile << " not found!" << endl;
Log::Info << "Files loaded." << endl
<< "Data (" << rdata.n_rows << ", "
<< rdata.n_cols << ")" << endl;
+ rdata = arma::trans(rdata);
typedef GeneralBinarySpaceTree<DConeBound<>, arma::mat> CTreeType;
@@ -41,14 +41,12 @@
&old_from_new_data,
NULL);
- if (CLI::HasParam("xx_print_tree")) {
+ if (CLI::HasParam("print_tree")) {
test_tree->Print();
} else {
Log::Info << "Tree built" << endl;
}
- if (CLI::HasParam("some_flag"))
- printf("The flag is working!\n");
-}
+} // end main
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.cc
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.cc 2011-10-29 17:02:13 UTC (rev 10073)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.cc 2011-10-29 22:46:47 UTC (rev 10074)
@@ -95,7 +95,6 @@
arma::vec centroid = reference_node->bound().center();
// +1: can be cached in the reference tree
-// double c_norm = arma::norm(centroid, 2);
double c_norm = reference_node->stat().dist_to_origin();
double rad = std::sqrt(reference_node->bound().radius());
@@ -111,8 +110,6 @@
double sin_phi = std::sqrt(1 - cos_phi * cos_phi);
// max_r sin <pr = sin_theta
-// double sin_theta = rad / c_norm;
-// double cos_theta = std::sqrt(1 - sin_theta * sin_theta);
double sin_theta = reference_node->stat().sine_origin();
double cos_theta = reference_node->stat().cosine_origin();
@@ -133,20 +130,37 @@
}
}
} else {
- ball_has_origin_++;
+ ball_has_origin_ += query_node->end() - query_node->begin();
}
return ((c_norm + rad) * max_cos_qp);
} // angle-prune
if (mlpack::CLI::HasParam("maxip/alt_angle_prune")) {
+ // HAVE TO WORK THIS OUT AND CHECK IT ANALYTICALLY
// using the closed-form-maximization,
// |p| cos (phi - w) + R
- // +1
- double c_norm_cos_phi = arma::dot(q, centroid) / q_norm;
- double c_norm_sin_phi = std::sqrt(c_norm * c_norm
- - c_norm_cos_phi * c_norm_cos_phi);
- return (c_norm_cos_phi * cos_w + c_norm_sin_phi * sin_w + rad);
+ // +1
+ double cos_phi = arma::dot(q, centroid) / (c_norm * q_norm);
+ double sin_phi = std::sqrt(1 - cos_phi * cos_phi);
+
+ if (cos_phi < cos_w) {
+
+ return (c_norm * (cos_phi * cos_w + sin_phi * sin_w) + rad);
+
+ } else {
+ // phi < w; hence the query cone has the centroid
+ // so no useful bounding
+
+ cone_has_centroid_ += query_node->end() - query_node->begin();
+
+ // maybe trigger a flag which lets the peep know that
+ // query cone too wide
+ // Although, not sure how that will benefit since the
+ // current traversal kind of takes care of this.
+
+ return (c_norm + rad);
+ }
} // alt-angle-prune
return (c_norm + rad);
@@ -356,7 +370,7 @@
if (upper_bound_p_cos_pq < query_node->stat().bound()) {
// Pruned
- number_of_prunes_++;
+ number_of_prunes_ += query_node->end() - query_node->begin();
if (CLI::HasParam("maxip/check_prune"))
CheckPrune(query_node, reference_node);
@@ -512,6 +526,7 @@
// track the number of prunes and computations
number_of_prunes_ = 0;
ball_has_origin_ = 0;
+ cone_has_centroid_ = 0;
distance_computations_ = 0;
split_decisions_ = 0;
@@ -550,10 +565,10 @@
if (mlpack::CLI::HasParam("maxip/dual_tree")) {
query_tree_
- = proximity::MakeGenCosineTree<CTreeType>(queries_,
- leaf_size_,
- &old_from_new_queries_,
- NULL);
+ = proximity::MakeGenConeTree<CTreeType>(queries_,
+ leaf_size_,
+ &old_from_new_queries_,
+ NULL);
set_norms_in_cones_(query_tree_);
}
@@ -579,6 +594,7 @@
// track the number of prunes and computations
number_of_prunes_ = 0;
ball_has_origin_ = 0;
+ cone_has_centroid_ = 0;
distance_computations_ = 0;
split_decisions_ = 0;
@@ -620,6 +636,7 @@
// track the number of prunes and computations
number_of_prunes_ = 0;
ball_has_origin_ = 0;
+ cone_has_centroid_ = 0;
distance_computations_ = 0;
split_decisions_ = 0;
@@ -763,14 +780,15 @@
}
}
- mlpack::Log::Info << "Tree-based Search - Number of prunes: "
+ mlpack::Log::Warn << "Tree-based Search - Number of prunes: "
<< number_of_prunes_ << ", Ball has origin: "
- << ball_has_origin_ << std::endl;
+ << ball_has_origin_ << ", Cone has centroid: "
+ << cone_has_centroid_ << std::endl;
mlpack::Log::Info << "\t \t Avg. # of DC: "
- << (double) distance_computations_
+ << (double) distance_computations_
/ (double) queries_.n_cols << std::endl;
mlpack::Log::Info << "\t \t Avg. # of SD: "
- << (double) split_decisions_
+ << (double) split_decisions_
/ (double) queries_.n_cols << std::endl;
return (double) (distance_computations_ + split_decisions_)
@@ -808,12 +826,12 @@
}
mlpack::Log::Info << "Brute-force Search - Number of prunes: "
- << number_of_prunes_ << std::endl;
+ << number_of_prunes_ << std::endl;
mlpack::Log::Info << "\t \t Avg. # of DC: "
- << (double) distance_computations_
+ << (double) distance_computations_
/ (double) queries_.n_cols << std::endl;
mlpack::Log::Info << "\t \t Avg. # of SD: "
- << (double) split_decisions_
+ << (double) split_decisions_
/ (double) queries_.n_cols << std::endl;
return (double) (distance_computations_ + split_decisions_)
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.h 2011-10-29 17:02:13 UTC (rev 10073)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.h 2011-10-29 22:46:47 UTC (rev 10074)
@@ -6,7 +6,7 @@
#ifndef EXACT_MAX_IP_H
#define EXACT_MAX_IP_H
-#define NDEBUG
+//#define NDEBUG
#include <assert.h>
#include <mlpack/core.h>
@@ -17,7 +17,7 @@
#include "general_spacetree.h"
#include "gen_metric_tree.h"
#include "dconebound.h"
-#include "gen_cosine_tree.h"
+#include "gen_cone_tree.h"
using namespace mlpack;
@@ -137,6 +137,7 @@
// The total number of prunes.
size_t number_of_prunes_;
size_t ball_has_origin_;
+ size_t cone_has_centroid_;
// A permutation of the indices for tree building.
arma::Col<size_t> old_from_new_queries_;
@@ -293,7 +294,7 @@
PARAM_FLAG("alt_angle_prune", "The flag to trigger the tighter-er"
" pruning using the angles as well", "maxip");
PARAM_FLAG("dual_tree", "The flag to trigger dual-tree "
- "computation, using a cosine tree for the "
+ "computation, using a cone tree for the "
"queries.", "maxip");
PARAM_FLAG("check_prune", "The flag to trigger the "
Added: mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree.h (rev 0)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree.h 2011-10-29 22:46:47 UTC (rev 10074)
@@ -0,0 +1,77 @@
+// Copyright 2007 Georgia Institute of Technology. All rights reserved.
+// ABSOLUTELY NOT FOR DISTRIBUTION
+/**
+ * @file gen_cone_tree.h
+ *
+ * Tools for cone-trees.
+ *
+ * @experimental
+ */
+
+#ifndef GEN_CONE_TREE_H
+#define GEN_CONE_TREE_H
+
+
+#include "general_spacetree.h"
+#include "gen_cone_tree_impl.h"
+
+
+/**
+ * Regular pointer-style trees (as opposed to THOR trees).
+ */
+namespace proximity {
+
+  /**
+   * Creates a cone tree from data.
+   *
+   * @experimental
+   *
+   * The columns of `matrix` are reordered while the tree is built. The
+   * optional index vectors let callers map between the original and the new
+   * column order. Their previous contents are overwritten.
+   *
+   * @param matrix data where each column is a point, WHICH WILL BE RE-ORDERED
+   * @param leaf_size nodes with fewer than this many points are not split
+   * @param old_from_new if non-NULL, receives the map from new column index
+   *   to original column index
+   * @param new_from_old if non-NULL, receives the map from original column
+   *   index to new column index (filled even when old_from_new is NULL)
+   * @return root node allocated with new; the caller is responsible for
+   *   deleting it
+   */
+  template<typename TConeTree>
+  TConeTree *MakeGenConeTree(arma::mat& matrix, size_t leaf_size,
+                             arma::Col<size_t> *old_from_new = NULL,
+                             arma::Col<size_t> *new_from_old = NULL) {
+    // new_from_old is derived from the old_from_new permutation, so track
+    // that permutation locally when the caller only asked for new_from_old
+    // (otherwise the inversion loop below would dereference NULL).
+    arma::Col<size_t> local_old_from_new;
+    arma::Col<size_t> *mapping = old_from_new;
+    if (mapping == NULL && new_from_old != NULL)
+      mapping = &local_old_from_new;
+
+    // Start from the identity permutation; the splitter permutes it in
+    // lock-step with the matrix columns.
+    size_t *old_from_new_ptr = NULL;
+    if (mapping != NULL) {
+      mapping->set_size(matrix.n_cols);
+      for (size_t i = 0; i < matrix.n_cols; i++)
+        (*mapping)[i] = i;
+      old_from_new_ptr = mapping->memptr();
+    }
+
+    TConeTree *node = new TConeTree();
+    node->Init(0, matrix.n_cols);
+    node->bound().center().set_size(matrix.n_rows);
+    tree_gen_cone_tree_private::SplitGenConeTree<TConeTree>
+        (matrix, node, leaf_size, old_from_new_ptr);
+
+    // Invert the final permutation.
+    if (new_from_old != NULL) {
+      new_from_old->set_size(matrix.n_cols);
+      for (size_t i = 0; i < matrix.n_cols; i++)
+        (*new_from_old)[(*mapping)[i]] = i;
+    }
+
+    return node;
+  }
+
+}  // namespace proximity
+
+#endif
Added: mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree_impl.cc
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree_impl.cc (rev 0)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree_impl.cc 2011-10-29 22:46:47 UTC (rev 10074)
@@ -0,0 +1,30 @@
+#include "gen_cone_tree_impl.h"
+
+
+namespace tree_gen_cone_tree_private {
+
+  /**
+   * Finds the column in matrix[begin, begin + count) whose cosine with
+   * `pivot` is smallest, i.e. the point angularly furthest from the pivot.
+   *
+   * @param pivot reference direction
+   * @param matrix data, one point per column
+   * @param begin first column of the range
+   * @param count number of columns in the range
+   * @param furthest_cosine out: the smaller of 1.0 and the smallest cosine
+   *   found in the range
+   * @return index of the furthest column. If no cosine falls below 1.0
+   *   (e.g. every point is parallel to the pivot), returns `begin`. For
+   *   count == 0 the returned `begin` is not a column of the (empty) range.
+   */
+  size_t FurthestColumnIndex(const arma::vec& pivot,
+                             const arma::mat& matrix,
+                             size_t begin, size_t count,
+                             double *furthest_cosine) {
+    // Default to a real column rather than an out-of-range sentinel:
+    // callers use the result directly as a column index, and with duplicate
+    // or parallel points no cosine is strictly below the 1.0 start value.
+    size_t furthest_index = begin;
+    const size_t end = begin + count;
+    *furthest_cosine = 1.0;
+
+    for (size_t i = begin; i < end; i++) {
+      const double cosine_between_center_and_point =
+          Cosine::Evaluate(pivot, matrix.unsafe_col(i));
+
+      if (cosine_between_center_and_point < (*furthest_cosine)) {
+        *furthest_cosine = cosine_between_center_and_point;
+        furthest_index = i;
+      }
+    }
+
+    assert((*furthest_cosine) >= -1.0);
+
+    return furthest_index;
+  }
+
+}  // namespace tree_gen_cone_tree_private
Added: mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree_impl.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree_impl.h (rev 0)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree_impl.h 2011-10-29 22:46:47 UTC (rev 10074)
@@ -0,0 +1,232 @@
+/** @file gen_cone_tree_impl.h
+ *
+ * Implementation for the regular pointer-style cone-tree builder.
+ *
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ */
+#ifndef GEN_CONE_TREE_IMPL_H
+#define GEN_CONE_TREE_IMPL_H
+
+#include <assert.h>
+#include <mlpack/core.h>
+#include <armadillo>
+#include "cosine.h"
+
+namespace tree_gen_cone_tree_private {
+
+  /**
+   * Returns the index of the column in matrix[begin, begin + count) with the
+   * smallest cosine to `pivot`, and writes that cosine (capped at 1.0) to
+   * *furthest_cosine. Defined in gen_cone_tree_impl.cc. The callers below
+   * range-check the returned index before using it as a column.
+   */
+  size_t FurthestColumnIndex(const arma::vec& pivot,
+                             const arma::mat& matrix,
+                             size_t begin, size_t count,
+                             double *furthest_cosine);
+
+  /**
+   * Fills `bounds` with the cone bound of the leaf holding columns
+   * [begin, begin + count) of `matrix`.
+   *
+   * Points are assumed to live in Euclidean space. The center is the sum of
+   * the unit-normalized points divided by `count`. The radius is stored as
+   * the smallest cosine between that center and any point of the leaf.
+   *
+   * @param matrix data, one point per column
+   * @param begin first column of the leaf
+   * @param count number of columns in the leaf (expected > 0)
+   * @param bounds bound to fill; its center must already have matrix.n_rows
+   *   elements
+   */
+  template<typename TBound>
+  void MakeLeafConeTreeNode(const arma::mat& matrix,
+                            size_t begin, size_t count,
+                            TBound *bounds) {
+    bounds->center().zeros();
+
+    const size_t end = begin + count;
+    for (size_t i = begin; i < end; i++) {
+      const double point_norm = arma::norm(matrix.unsafe_col(i), 2);
+      // A zero vector has no direction. Normalizing it would produce NaNs
+      // that poison this center and, via CombineBounds, every ancestor.
+      if (point_norm > 0.0)
+        bounds->center() += (matrix.unsafe_col(i) / point_norm);
+    }
+    bounds->center() /= (double) count;
+
+    double furthest_cosine;
+    FurthestColumnIndex(bounds->center(), matrix, begin, count,
+                        &furthest_cosine);
+    bounds->set_radius(furthest_cosine);
+  }
+
+  /**
+   * Reorders columns [first, first + count) of `matrix` so that the points
+   * assigned to the left pivot come first, followed by the points assigned
+   * to the right pivot. A point goes right only if its cosine to the right
+   * center is strictly larger than its cosine to the left center.
+   *
+   * @param old_from_new if non-NULL, permuted in lock-step with the columns
+   * @return the number of points assigned to the left pivot
+   */
+  template<typename TBound>
+  size_t MatrixPartition(arma::mat& matrix, size_t first, size_t count,
+                         TBound &left_bound, TBound &right_bound,
+                         size_t *old_from_new) {
+    // Pass 1: decide every point's side before moving anything. The vector
+    // is sized (not merely reserved) so the indexed writes are valid.
+    std::vector<bool> left_membership(count, false);
+    size_t left_count = 0;
+
+    for (size_t offset = 0; offset < count; offset++) {
+      // Copy of the current point.
+      arma::vec point = matrix.unsafe_col(first + offset);
+
+      double cosine_from_left_pivot =
+          Cosine::Evaluate(point, left_bound.center());
+      double cosine_from_right_pivot =
+          Cosine::Evaluate(point, right_bound.center());
+
+      // Ties (and NaN comparisons) keep the point on the left.
+      if (cosine_from_left_pivot < cosine_from_right_pivot) {
+        left_membership[offset] = false;
+      } else {
+        left_membership[offset] = true;
+        left_count++;
+      }
+    }
+
+    // Pass 2: two-pointer swap over the half-open window [lo, hi) of
+    // offsets from `first`. Invariant: offsets < lo hold left points and
+    // offsets >= hi hold right points. The bounds are tested before the
+    // vector is read, and hi never drops below lo, so nothing is read out
+    // of range and nothing underflows.
+    size_t lo = 0;
+    size_t hi = count;
+    for (;;) {
+      while (lo < hi && left_membership[lo])
+        lo++;
+      while (lo < hi && !left_membership[hi - 1])
+        hi--;
+      if (lo >= hi)
+        break;
+
+      // Offset lo holds a right point and offset hi - 1 a left point.
+      const size_t a = first + lo;
+      const size_t b = first + hi - 1;
+      matrix.swap_cols(a, b);
+      if (old_from_new)
+        std::swap(old_from_new[a], old_from_new[b]);
+
+      lo++;
+      hi--;
+    }
+
+    assert(lo == left_count);
+
+    return left_count;
+  }
+
+  /**
+   * Tries to split `node` into two children. The two pivots are the point
+   * angularly furthest from a random point of the node, and then the point
+   * furthest from that one. The node's columns are partitioned between them.
+   *
+   * @param leaf_size currently unused
+   * @param old_from_new if non-NULL, permuted along with the columns
+   * @return true, with *left and *right allocated and initialized, on
+   *   success. Returns false without touching *left / *right when the node's
+   *   points lie in too narrow a cone to split.
+   */
+  template<typename TConeTree>
+  bool AttemptSplitting(arma::mat& matrix,
+                        TConeTree *node,
+                        TConeTree **left,
+                        TConeTree **right,
+                        size_t leaf_size,
+                        size_t *old_from_new) {
+    const size_t begin = node->begin();
+    const size_t end = node->begin() + node->count();
+
+    // Pick a random point of this node.
+    size_t random_row = math::RandInt(begin, end);
+    arma::vec random_row_vec = matrix.unsafe_col(random_row);
+
+    // Point angularly furthest from the random point.
+    double furthest_cosine;
+    size_t furthest_from_random_row =
+        FurthestColumnIndex(random_row_vec, matrix, begin, node->count(),
+                            &furthest_cosine);
+    // An out-of-range index means no point deviates from the random one
+    // (e.g. all points parallel), so there is nothing to split.
+    if (furthest_from_random_row < begin || furthest_from_random_row >= end)
+      return false;
+    arma::vec furthest_from_random_row_vec =
+        matrix.unsafe_col(furthest_from_random_row);
+
+    // Point angularly furthest from that one.
+    double furthest_from_furthest_cosine;
+    size_t furthest_from_furthest_random_row =
+        FurthestColumnIndex(furthest_from_random_row_vec, matrix, begin,
+                            node->count(), &furthest_from_furthest_cosine);
+
+    // Everything lies in a really tight, narrow cone. This check runs before
+    // the index is used, because the index may be out of range exactly then.
+    if (furthest_from_furthest_cosine > (1.0 - DBL_EPSILON) ||
+        furthest_from_furthest_random_row < begin ||
+        furthest_from_furthest_random_row >= end)
+      return false;
+
+    arma::vec furthest_from_furthest_random_row_vec =
+        matrix.unsafe_col(furthest_from_furthest_random_row);
+
+    *left = new TConeTree();
+    *right = new TConeTree();
+
+    (*left)->bound().center() = furthest_from_random_row_vec;
+    (*right)->bound().center() = furthest_from_furthest_random_row_vec;
+
+    size_t left_count = MatrixPartition(matrix, begin, node->count(),
+                                        (*left)->bound(), (*right)->bound(),
+                                        old_from_new);
+
+    (*left)->Init(begin, left_count);
+    (*right)->Init(begin + left_count, node->count() - left_count);
+
+    return true;
+  }
+
+  /**
+   * Sets `node`'s cone bound from its two already-built children. The center
+   * is the count-weighted mean of the children's centers. The radius is the
+   * smallest cosine between that center and any point under the node.
+   */
+  template<typename TConeTree>
+  void CombineBounds(arma::mat& matrix, TConeTree *node,
+                     TConeTree *left, TConeTree *right) {
+    // Weighted mean of the two child centers.
+    node->bound().center().zeros();
+    node->bound().center() += left->count() * left->bound().center();
+    node->bound().center() += right->count() * right->bound().center();
+    node->bound().center() /= (double) node->count();
+
+    // The children tile [node->begin(), node->begin() + node->count()), so
+    // one scan over the node's range yields min(left min, right min).
+    double min_cosine;
+    FurthestColumnIndex(node->bound().center(), matrix, node->begin(),
+                        node->count(), &min_cosine);
+    node->bound().set_radius(min_cosine);
+  }
+
+  /**
+   * Recursively builds the cone tree below `node`. The node's columns of
+   * `matrix` (and `old_from_new`, if non-NULL) are permuted as it splits.
+   * Nodes with fewer than `leaf_size` points, or whose points cannot be
+   * split, become leaves with NULL children.
+   */
+  template<typename TConeTree>
+  void SplitGenConeTree(arma::mat& matrix, TConeTree *node,
+                        size_t leaf_size, size_t *old_from_new) {
+    TConeTree *left = NULL;
+    TConeTree *right = NULL;
+
+    bool was_split = false;
+    if (node->count() >= leaf_size)
+      was_split = AttemptSplitting(matrix, node, &left, &right,
+                                   leaf_size, old_from_new);
+
+    if (was_split) {
+      SplitGenConeTree(matrix, left, leaf_size, old_from_new);
+      SplitGenConeTree(matrix, right, leaf_size, old_from_new);
+      CombineBounds(matrix, node, left, right);
+    } else {
+      MakeLeafConeTreeNode(matrix, node->begin(), node->count(),
+                           &(node->bound()));
+    }
+
+    // Set children information appropriately (NULL for leaves).
+    node->set_children(matrix, left, right);
+  }
+
+}  // namespace tree_gen_cone_tree_private
+
+#endif
Modified: mlpack/trunk/src/contrib/pram/max_ip_search/gen_cosine_tree_impl.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/gen_cosine_tree_impl.h 2011-10-29 17:02:13 UTC (rev 10073)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/gen_cosine_tree_impl.h 2011-10-29 22:46:47 UTC (rev 10074)
@@ -34,12 +34,14 @@
size_t end = begin + count;
for (size_t i = begin; i < end; i++) {
bounds->center()
- += (matrix.unsafe_col(i)
- / arma::norm(matrix.unsafe_col(i), 2));
+ += (matrix.unsafe_col(i)
+ / arma::norm(matrix.unsafe_col(i), 2));
}
bounds->center() /= (double) count;
- bounds->center() /= arma::norm(bounds->center(), 2);
+// bounds->center() /= arma::norm(bounds->center(), 2);
+// printf("c_norm: %lg\n", arma::norm(bounds->center(), 2));
+
double furthest_cosine;
FurthestColumnIndex(bounds->center(), matrix, begin, count,
&furthest_cosine);
@@ -197,7 +199,10 @@
node->bound().center() += left->count() * left->bound().center();
node->bound().center() += right->count() * right->bound().center();
node->bound().center() /= (double) node->count();
+// node->bound().center() /= arma::norm(node->bound().center(), 2);
+// printf("c_norm: %lg\n", arma::norm(node->bound().center(), 2));
+
double left_min_cosine, right_min_cosine;
FurthestColumnIndex(node->bound().center(), matrix, left->begin(),
left->count(), &left_min_cosine);
More information about the mlpack-svn
mailing list