[mlpack-svn] r10074 - mlpack/trunk/src/contrib/pram/max_ip_search

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Sat Oct 29 18:46:48 EDT 2011


Author: pram
Date: 2011-10-29 18:46:47 -0400 (Sat, 29 Oct 2011)
New Revision: 10074

Added:
   mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree.h
   mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree_impl.cc
   mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree_impl.h
Modified:
   mlpack/trunk/src/contrib/pram/max_ip_search/CMakeLists.txt
   mlpack/trunk/src/contrib/pram/max_ip_search/approx_ip_main.cc
   mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.cc
   mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.h
   mlpack/trunk/src/contrib/pram/max_ip_search/cosine_tree_test.cc
   mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.cc
   mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.h
   mlpack/trunk/src/contrib/pram/max_ip_search/gen_cosine_tree_impl.h
Log:
Cone-tree code named appropriately

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/CMakeLists.txt	2011-10-29 17:02:13 UTC (rev 10073)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/CMakeLists.txt	2011-10-29 22:46:47 UTC (rev 10074)
@@ -10,13 +10,18 @@
    gen_metric_tree.h
    gen_metric_tree_impl.h
    gen_metric_tree_impl.cc
-   # Cosine tree
+   # Cone bound
    cosine.h
    dconebound.h
    dconebound_impl.h
-   gen_cosine_tree.h
-   gen_cosine_tree_impl.h
-   gen_cosine_tree_impl.cc
+   # Cosine tree
+   #gen_cosine_tree.h
+   #gen_cosine_tree_impl.h
+   #gen_cosine_tree_impl.cc
+   # cone tree
+   gen_cone_tree.h
+   gen_cone_tree_impl.h
+   gen_cone_tree_impl.cc
    # Max-inner-product search class
    exact_max_ip.h
    exact_max_ip.cc

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/approx_ip_main.cc
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/approx_ip_main.cc	2011-10-29 17:02:13 UTC (rev 10073)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/approx_ip_main.cc	2011-10-29 22:46:47 UTC (rev 10074)
@@ -45,12 +45,15 @@
   string qfile = CLI::GetParam<string>("q");
 
   Log::Info << "Loading files..." << endl;
-  if (!data::Load(rfile.c_str(), rdata))
+  if (rdata.load(rfile.c_str()) == false)
     Log::Fatal << "Reference file "<< rfile << " not found." << endl;
-
-  if (!data::Load(qfile.c_str(), qdata)) 
+  
+  if (qdata.load(qfile.c_str()) == false)
     Log::Fatal << "Query file " << qfile << " not found." << endl;
 
+  rdata = arma::trans(rdata);
+  qdata = arma::trans(qdata);
+
   Log::Info << "File loaded..." << endl;
   
   Log::Info << "R(" << rdata.n_rows << ", " << rdata.n_cols 

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.cc
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.cc	2011-10-29 17:02:13 UTC (rev 10073)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.cc	2011-10-29 22:46:47 UTC (rev 10074)
@@ -963,10 +963,10 @@
     
   if (mlpack::CLI::HasParam("approx_maxip/dual_tree")) {
     query_tree_
-      = proximity::MakeGenCosineTree<CTreeType>(queries_,
-						leaf_size_,
-						&old_from_new_queries_,
-						NULL);
+      = proximity::MakeGenConeTree<CTreeType>(queries_,
+					      leaf_size_,
+					      &old_from_new_queries_,
+					      NULL);
   }
 
   // Stop the timer we started above

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.h	2011-10-29 17:02:13 UTC (rev 10073)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/approx_max_ip.h	2011-10-29 22:46:47 UTC (rev 10074)
@@ -15,7 +15,7 @@
 #include "general_spacetree.h"
 #include "gen_metric_tree.h"
 #include "dconebound.h"
-#include "gen_cosine_tree.h"
+#include "gen_cone_tree.h"
 
 using namespace mlpack;
 
@@ -40,7 +40,7 @@
 PARAM_FLAG("angle_prune", "The flag to trigger the tighter"
 	   " pruning using the angles as well", "approx_maxip");
 PARAM_FLAG("dual_tree", "The flag to trigger dual-tree "
-	   "computation, using a cosine tree for the "
+	   "computation, using a cone tree for the "
 	   "queries.", "approx_maxip");
 
 PARAM_FLAG("check_prune", "The flag to trigger the "

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/cosine_tree_test.cc
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/cosine_tree_test.cc	2011-10-29 17:02:13 UTC (rev 10073)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/cosine_tree_test.cc	2011-10-29 22:46:47 UTC (rev 10074)
@@ -13,8 +13,7 @@
 
 
 PARAM_STRING_REQ("r", "The data set to be indexed", "");
-PARAM_FLAG("xx_print_tree", "The flag to print the tree", "");
-PARAM_FLAG("some_flag", "some test flag", "");
+PARAM_FLAG("print_tree", "The flag to print the tree", "");
 
 int main (int argc, char *argv[]) {
 
@@ -24,13 +23,14 @@
   string rfile = CLI::GetParam<string>("r");
 
   Log::Info << "Loading files..." << endl;
-  if (!data::Load(rfile.c_str(), rdata)) 
+  if (rdata.load(rfile.c_str()) == false) 
     Log::Fatal << "Data file " << rfile << " not found!" << endl;
 
   Log::Info << "Files loaded." << endl
 	   << "Data (" << rdata.n_rows << ", "
 	   << rdata.n_cols << ")" << endl;
 
+  rdata = arma::trans(rdata);
 
   typedef GeneralBinarySpaceTree<DConeBound<>, arma::mat> CTreeType;
 
@@ -41,14 +41,12 @@
 					      &old_from_new_data,
 					      NULL);
 
-  if (CLI::HasParam("xx_print_tree")) {
+  if (CLI::HasParam("print_tree")) {
     test_tree->Print();
   } else {
     Log::Info << "Tree built" << endl;
   }
 
-  if (CLI::HasParam("some_flag"))
-    printf("The flag is working!\n");
-}
+} // end main
 
 

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.cc
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.cc	2011-10-29 17:02:13 UTC (rev 10073)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.cc	2011-10-29 22:46:47 UTC (rev 10074)
@@ -95,7 +95,6 @@
   arma::vec centroid = reference_node->bound().center();
 
   // +1: can be cached in the reference tree
-//   double c_norm = arma::norm(centroid, 2);
   double c_norm = reference_node->stat().dist_to_origin();
   double rad = std::sqrt(reference_node->bound().radius());
 
@@ -111,8 +110,6 @@
       double sin_phi = std::sqrt(1 - cos_phi * cos_phi);
 
       // max_r sin <pr = sin_theta
-//       double sin_theta = rad / c_norm;
-//       double cos_theta = std::sqrt(1 - sin_theta * sin_theta);
       double sin_theta = reference_node->stat().sine_origin();
       double cos_theta = reference_node->stat().cosine_origin();
 
@@ -133,20 +130,37 @@
 	}
       }
     } else {
-      ball_has_origin_++;
+      ball_has_origin_ += query_node->end() - query_node->begin();
     }
     return ((c_norm + rad) * max_cos_qp);
   } // angle-prune
 
   if (mlpack::CLI::HasParam("maxip/alt_angle_prune")) { 
+    // HAVE TO WORK THIS OUT AND CHECK IT ANALYTICALLY
     // using the closed-form-maximization,
     // |p| cos (phi - w) + R
-      // +1
-    double c_norm_cos_phi = arma::dot(q, centroid) / q_norm;
-    double c_norm_sin_phi = std::sqrt(c_norm * c_norm 
-				      - c_norm_cos_phi * c_norm_cos_phi);
 
-    return (c_norm_cos_phi * cos_w + c_norm_sin_phi * sin_w + rad);
+    // +1
+    double cos_phi = arma::dot(q, centroid) / (c_norm * q_norm);
+    double sin_phi = std::sqrt(1 - cos_phi * cos_phi);
+
+    if (cos_phi < cos_w) { 
+     
+      return (c_norm * (cos_phi * cos_w + sin_phi * sin_w) + rad);
+
+    } else {
+      // phi < w; hence the query cone has the centroid
+      // so no useful bounding
+
+      cone_has_centroid_ += query_node->end() - query_node->begin();
+
+      // maybe trigger a flag which lets the peep know that 
+      // query cone too wide
+      // Although, not sure how that will benefit since the 
+      // current traversal kind of takes care of this.
+
+      return (c_norm + rad);
+    }
   } // alt-angle-prune
 
   return (c_norm + rad);
@@ -356,7 +370,7 @@
 
   if (upper_bound_p_cos_pq < query_node->stat().bound()) { 
     // Pruned
-    number_of_prunes_++;
+    number_of_prunes_ += query_node->end() - query_node->begin();
 
     if (CLI::HasParam("maxip/check_prune"))
       CheckPrune(query_node, reference_node);
@@ -512,6 +526,7 @@
   // track the number of prunes and computations
   number_of_prunes_ = 0;
   ball_has_origin_ = 0;
+  cone_has_centroid_ = 0;
   distance_computations_ = 0;
   split_decisions_ = 0;
     
@@ -550,10 +565,10 @@
    
   if (mlpack::CLI::HasParam("maxip/dual_tree")) {
     query_tree_
-      = proximity::MakeGenCosineTree<CTreeType>(queries_,
-						leaf_size_,
-						&old_from_new_queries_,
-						NULL);
+      = proximity::MakeGenConeTree<CTreeType>(queries_,
+					      leaf_size_,
+					      &old_from_new_queries_,
+					      NULL);
     set_norms_in_cones_(query_tree_);
   }
 
@@ -579,6 +594,7 @@
   // track the number of prunes and computations
   number_of_prunes_ = 0;
   ball_has_origin_ = 0;
+  cone_has_centroid_ = 0;
   distance_computations_ = 0;
   split_decisions_ = 0;
     
@@ -620,6 +636,7 @@
   // track the number of prunes and computations
   number_of_prunes_ = 0;
   ball_has_origin_ = 0;
+  cone_has_centroid_ = 0;
   distance_computations_ = 0;
   split_decisions_ = 0;
     
@@ -763,14 +780,15 @@
     }
   }
 
-  mlpack::Log::Info << "Tree-based Search - Number of prunes: " 
+  mlpack::Log::Warn << "Tree-based Search - Number of prunes: " 
 		    << number_of_prunes_ << ", Ball has origin: "
-		    << ball_has_origin_ << std::endl;
+		    << ball_has_origin_ << ", Cone has centroid: "
+		    << cone_has_centroid_ << std::endl;
   mlpack::Log::Info << "\t \t Avg. # of DC: " 
-		   << (double) distance_computations_ 
+		    << (double) distance_computations_ 
     / (double) queries_.n_cols << std::endl;
   mlpack::Log::Info << "\t \t Avg. # of SD: " 
-		   << (double) split_decisions_ 
+		    << (double) split_decisions_ 
     / (double) queries_.n_cols << std::endl;
 
   return (double) (distance_computations_ + split_decisions_)
@@ -808,12 +826,12 @@
   }
 
   mlpack::Log::Info << "Brute-force Search - Number of prunes: " 
-		   << number_of_prunes_ << std::endl;
+		    << number_of_prunes_ << std::endl;
   mlpack::Log::Info << "\t \t Avg. # of DC: " 
-		   << (double) distance_computations_ 
+		    << (double) distance_computations_ 
     / (double) queries_.n_cols << std::endl;
   mlpack::Log::Info << "\t \t Avg. # of SD: " 
-		   << (double) split_decisions_ 
+		    << (double) split_decisions_ 
     / (double) queries_.n_cols << std::endl;
 
   return (double) (distance_computations_ + split_decisions_)

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.h	2011-10-29 17:02:13 UTC (rev 10073)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/exact_max_ip.h	2011-10-29 22:46:47 UTC (rev 10074)
@@ -6,7 +6,7 @@
 #ifndef EXACT_MAX_IP_H
 #define EXACT_MAX_IP_H
 
-#define NDEBUG
+//#define NDEBUG
 
 #include <assert.h>
 #include <mlpack/core.h>
@@ -17,7 +17,7 @@
 #include "general_spacetree.h"
 #include "gen_metric_tree.h"
 #include "dconebound.h"
-#include "gen_cosine_tree.h"
+#include "gen_cone_tree.h"
 
 using namespace mlpack;
 
@@ -137,6 +137,7 @@
   // The total number of prunes.
   size_t number_of_prunes_;
   size_t ball_has_origin_;
+  size_t cone_has_centroid_;
 
   // A permutation of the indices for tree building.
   arma::Col<size_t> old_from_new_queries_;
@@ -293,7 +294,7 @@
 PARAM_FLAG("alt_angle_prune", "The flag to trigger the tighter-er"
 	   " pruning using the angles as well", "maxip");
 PARAM_FLAG("dual_tree", "The flag to trigger dual-tree "
-	   "computation, using a cosine tree for the "
+	   "computation, using a cone tree for the "
 	   "queries.", "maxip");
 
 PARAM_FLAG("check_prune", "The flag to trigger the "

Added: mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree.h	                        (rev 0)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree.h	2011-10-29 22:46:47 UTC (rev 10074)
@@ -0,0 +1,77 @@
+// Copyright 2007 Georgia Institute of Technology. All rights reserved.
+// ABSOLUTELY NOT FOR DISTRIBUTION
+/**
+ * @file gen_cone_tree.h
+ *
+ * Tools for cone-trees.
+ *
+ * @experimental
+ */
+
+#ifndef GEN_CONE_TREE_H
+#define GEN_CONE_TREE_H
+
+
+#include "general_spacetree.h"
+#include "gen_cone_tree_impl.h"
+
+
+/**
+ * Regular pointer-style trees (as opposed to THOR trees).
+ */
+namespace proximity {
+
+  /**
+   * Creates a cone tree from data.
+   *
+   * @experimental
+   *
+   * The columns of the matrix are re-ordered in place while the tree is
+   * built. The two optional index vectors are resized and overwritten
+   * here, so they may be passed in uninitialized; they let the caller
+   * account for the re-ordering of the matrix.
+   *
+   * @param matrix data where each column is a point, WHICH WILL BE RE-ORDERED
+   * @param leaf_size the maximum points in a leaf
+   * @param old_from_new optional output; if non-NULL it will map
+   *        new indices to original indices
+   * @param new_from_old optional output; if non-NULL it will map
+   *        original indices to new indices. It may be requested even when
+   *        old_from_new is NULL.
+   * @return the root node of the tree, allocated here with new
+   */
+  template<typename TConeTree>
+    TConeTree *MakeGenConeTree(arma::mat& matrix, size_t leaf_size,
+                               arma::Col<size_t> *old_from_new = NULL,
+                               arma::Col<size_t> *new_from_old = NULL) {
+
+    // new_from_old is the inverse of the old_from_new permutation, so the
+    // permutation must be tracked whenever either output is requested.
+    // A local scratch vector stands in when the caller only asked for
+    // new_from_old (previously this case dereferenced a NULL pointer).
+    arma::Col<size_t> scratch_old_from_new;
+    arma::Col<size_t> *permutation = old_from_new;
+    if (!permutation && new_from_old) {
+      permutation = &scratch_old_from_new;
+    }
+
+    // Start from the identity permutation; the splitting code swaps the
+    // entries whenever it swaps the corresponding columns.
+    size_t *permutation_ptr = NULL;
+    if (permutation) {
+      permutation->set_size(matrix.n_cols);
+
+      for (size_t i = 0; i < matrix.n_cols; i++) {
+        (*permutation)[i] = i;
+      }
+
+      permutation_ptr = permutation->memptr();
+    }
+
+    // The root covers every column of the matrix.
+    TConeTree *node = new TConeTree();
+    node->Init(0, matrix.n_cols);
+    node->bound().center().set_size(matrix.n_rows);
+    tree_gen_cone_tree_private::SplitGenConeTree<TConeTree>
+      (matrix, node, leaf_size, permutation_ptr);
+
+    // Invert the permutation for the caller, if requested.
+    if (new_from_old) {
+      new_from_old->set_size(matrix.n_cols);
+      for (size_t i = 0; i < matrix.n_cols; i++) {
+        (*new_from_old)[(*permutation)[i]] = i;
+      }
+    }
+
+    return node;
+  }
+
+};
+
+#endif

Added: mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree_impl.cc
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree_impl.cc	                        (rev 0)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree_impl.cc	2011-10-29 22:46:47 UTC (rev 10074)
@@ -0,0 +1,30 @@
+#include "gen_cone_tree_impl.h"
+
+
+namespace tree_gen_cone_tree_private {
+  /**
+   * Scans the columns [begin, begin + count) of the matrix and finds the
+   * one whose Cosine::Evaluate() value against the pivot is smallest.
+   *
+   * @param pivot the vector every column is compared against
+   * @param matrix data where each column is a point
+   * @param begin index of the first column to consider
+   * @param count number of columns to consider
+   * @param furthest_cosine output; the smallest value found, or 1.0 if no
+   *        column produced a value strictly below 1.0
+   * @return index of the column that produced *furthest_cosine. If no
+   *         column scored strictly below 1.0 (including an empty range),
+   *         begin is returned.
+   */
+  size_t FurthestColumnIndex(const arma::vec& pivot,
+                             const arma::mat& matrix,
+                             size_t begin, size_t count,
+                             double *furthest_cosine) {
+
+    // Default to the first column of the range instead of (size_t) -1:
+    // callers feed the result straight into matrix.unsafe_col(), so the
+    // old sentinel became an out-of-range column index whenever every
+    // point was aligned with the pivot.
+    size_t furthest_index = begin;
+    const size_t end = begin + count;
+    *furthest_cosine = 1.0;
+
+    for (size_t i = begin; i < end; i++) {
+      const double cosine_between_pivot_and_point =
+        Cosine::Evaluate(pivot, matrix.unsafe_col(i));
+
+      // Strict comparison: the first column attaining the minimum wins.
+      if (cosine_between_pivot_and_point < (*furthest_cosine)) {
+        *furthest_cosine = cosine_between_pivot_and_point;
+        furthest_index = i;
+      }
+    }
+
+    assert((*furthest_cosine) >= -1.0);
+
+    return furthest_index;
+  }
+
+};

Added: mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree_impl.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree_impl.h	                        (rev 0)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/gen_cone_tree_impl.h	2011-10-29 22:46:47 UTC (rev 10074)
@@ -0,0 +1,232 @@
+/** @file gen_cone_tree_impl.h
+ *
+ *  Implementation for the regular pointer-style cone-tree builder.
+ *
+ *  @author Parikshit Ram (pram at cc.gatech.edu)
+ */
+#ifndef GEN_CONE_TREE_IMPL_H
+#define GEN_CONE_TREE_IMPL_H
+
+#include <assert.h>
+#include <mlpack/core.h>
+#include <armadillo>
+#include "cosine.h"
+
+namespace tree_gen_cone_tree_private {
+
+  size_t FurthestColumnIndex(const arma::vec& pivot,
+			     const arma::mat& matrix, 
+			     size_t begin, size_t count,
+			     double *furthest_cosine);
+    
+  // Builds the bound of a leaf covering columns [begin, begin + count).
+  // The points are assumed to be embedded in Euclidean space. The center
+  // is set to the mean of the columns after each has been scaled to unit
+  // 2-norm, and the radius is set to the smallest cosine that
+  // FurthestColumnIndex() reports between that center and the columns.
+
+  template<typename TBound>
+    void MakeLeafConeTreeNode(const arma::mat& matrix,
+                              size_t begin, size_t count,
+                              TBound *bounds) {
+
+    // Sum the normalized points into a zeroed center, then average.
+    bounds->center().zeros();
+    for (size_t column = begin, stop = begin + count;
+         column < stop; column++) {
+      bounds->center()
+        += (matrix.unsafe_col(column)
+            / arma::norm(matrix.unsafe_col(column), 2));
+    }
+    bounds->center() /= (double) count;
+
+    // Only the cosine is needed here; the returned index is ignored.
+    double min_cosine;
+    FurthestColumnIndex(bounds->center(), matrix, begin, count,
+                        &min_cosine);
+    bounds->set_radius(min_cosine);
+  }
+  
+
+  /**
+   * Re-orders columns [first, first + count) of the matrix in place so
+   * that the points assigned to the left pivot precede the points
+   * assigned to the right pivot.
+   *
+   * A point is assigned to the right pivot only when its
+   * Cosine::Evaluate() value against the left center is strictly smaller
+   * than against the right center; ties stay on the left.
+   *
+   * @param matrix data where each column is a point; columns are swapped
+   * @param first index of the first column of the range
+   * @param count number of columns in the range
+   * @param left_bound bound whose center() is the left pivot
+   * @param right_bound bound whose center() is the right pivot
+   * @param old_from_new optional index array (may be NULL) whose entries
+   *        are swapped together with the columns
+   * @return the number of points assigned to the left pivot; after the
+   *         call they occupy columns [first, first + left_count)
+   */
+  template<typename TBound>
+    size_t MatrixPartition(arma::mat& matrix, size_t first, size_t count,
+                           TBound &left_bound, TBound &right_bound,
+                           size_t *old_from_new) {
+
+    const size_t end = first + count;
+    size_t left_count = 0;
+
+    // One flag per point, sized up front. reserve() alone would leave the
+    // vector empty and make every indexed access below undefined.
+    std::vector<bool> left_membership(count, false);
+
+    for (size_t i = first; i < end; i++) {
+
+      // The current point.
+      arma::vec point = matrix.unsafe_col(i);
+
+      // Compute the cosines from the two pivots.
+      double cosine_from_left_pivot =
+        Cosine::Evaluate(point, left_bound.center());
+      double cosine_from_right_pivot =
+        Cosine::Evaluate(point, right_bound.center());
+
+      // The point stays on the left unless it is more angled away from
+      // the left pivot than from the right pivot.
+      if (!(cosine_from_left_pivot < cosine_from_right_pivot)) {
+        left_membership[i - first] = true;
+        left_count++;
+      }
+    }
+
+    /* Half-open scan over [left, right). At any point:
+     *
+     *   everything < left is correct
+     *   everything >= right is correct
+     *
+     * Keeping right one past the candidate means it never has to drop
+     * below first, so it cannot wrap around when first == 0, and the
+     * range test always runs before the flags are indexed.
+     */
+    size_t left = first;
+    size_t right = end;
+
+    for (;;) {
+      while ((left < right) && left_membership[left - first]) {
+        left++;
+      }
+
+      while ((left < right) && !left_membership[right - 1 - first]) {
+        right--;
+      }
+
+      if (left >= right) {
+        break;
+      }
+
+      // Column left belongs on the right and column right - 1 belongs on
+      // the left: swap the two vectors.
+      const size_t misplaced_right = right - 1;
+      matrix.swap_cols(left, misplaced_right);
+
+      left_membership[left - first] = true;
+      left_membership[misplaced_right - first] = false;
+
+      if (old_from_new) {
+        size_t t = old_from_new[left];
+        old_from_new[left] = old_from_new[misplaced_right];
+        old_from_new[misplaced_right] = t;
+      }
+
+      // Both swapped columns are now in their final halves.
+      left++;
+      right--;
+    }
+
+    assert(left == right);
+    assert(left - first == left_count);
+
+    return left_count;
+  }
+	
+
+  /**
+   * Tries to split a node into two children around two far-apart pivots.
+   *
+   * A random point of the node is picked; the left pivot is the point
+   * furthest (smallest cosine) from it, and the right pivot is the point
+   * furthest from the left pivot. The node's columns are then partitioned
+   * between the two pivots with MatrixPartition().
+   *
+   * @param matrix data where each column is a point; columns of the node
+   *        may be re-ordered
+   * @param node the node to split
+   * @param left output; set to a newly allocated left child on success
+   * @param right output; set to a newly allocated right child on success
+   * @param leaf_size not used by this function
+   * @param old_from_new optional index array (may be NULL) kept in step
+   *        with the column re-ordering
+   * @return true if the node was split; false if no usable pair of pivots
+   *         exists, in which case *left and *right are left untouched
+   */
+  template<typename TConeTree>
+    bool AttemptSplitting(arma::mat& matrix,
+                          TConeTree *node,
+                          TConeTree **left,
+                          TConeTree **right,
+                          size_t leaf_size,
+                          size_t *old_from_new) {
+
+    const size_t node_begin = node->begin();
+    const size_t node_end = node_begin + node->count();
+
+    // Pick a random point (column) of this node.
+    size_t random_col = math::RandInt(node_begin, node_end);
+
+    arma::vec random_col_vec = matrix.unsafe_col(random_col);
+
+    // Now figure out the furthest point from the random point picked
+    // above.
+    double furthest_cosine;
+    size_t furthest_from_random_col =
+      FurthestColumnIndex(random_col_vec, matrix, node_begin, node->count(),
+                          &furthest_cosine);
+
+    // FurthestColumnIndex() can hand back an index outside the node when
+    // no point is angled away from the pivot; reading that column would
+    // be out of bounds, and there is nothing to split on anyway.
+    if ((furthest_from_random_col < node_begin)
+        || (furthest_from_random_col >= node_end)) {
+      return false;
+    }
+    arma::vec furthest_from_random_col_vec =
+      matrix.unsafe_col(furthest_from_random_col);
+
+    // Then figure out the furthest point from the furthest point.
+    double furthest_from_furthest_cosine;
+    size_t furthest_from_furthest_random_col =
+      FurthestColumnIndex(furthest_from_random_col_vec, matrix, node_begin,
+                          node->count(), &furthest_from_furthest_cosine);
+
+    // Checked before the second pivot column is read: whenever the cosine
+    // is below this threshold the index above refers to a real column.
+    if (furthest_from_furthest_cosine > (1.0 - DBL_EPSILON)) {
+      // everything in a really tight narrow cone
+      return false;
+    }
+    arma::vec furthest_from_furthest_random_col_vec =
+      matrix.unsafe_col(furthest_from_furthest_random_col);
+
+    *left = new TConeTree();
+    *right = new TConeTree();
+
+    // The pivots are copied into the children before the partition
+    // re-orders the columns they were read from.
+    ((*left)->bound().center()) = furthest_from_random_col_vec;
+    ((*right)->bound().center()) = furthest_from_furthest_random_col_vec;
+
+    size_t left_count
+      = MatrixPartition(matrix, node_begin, node->count(),
+                        (*left)->bound(), (*right)->bound(),
+                        old_from_new);
+
+    (*left)->Init(node_begin, left_count);
+    (*right)->Init(node_begin + left_count, node->count() - left_count);
+
+    return true;
+  }
+
+  // Derives the bound of an internal node from its two children. The
+  // center becomes the count-weighted average of the child centers, and
+  // the radius becomes the smallest cosine FurthestColumnIndex() reports
+  // between that center and the points of either child.
+  template<typename TConeTree>
+    void CombineBounds(arma::mat& matrix, TConeTree *node,
+                       TConeTree *left, TConeTree *right) {
+
+    // Weighted sum of the two child centers, accumulated into a zeroed
+    // node center and then divided by the node's point count.
+    node->bound().center().zeros();
+    node->bound().center() += left->count() * left->bound().center();
+    node->bound().center() += right->count() * right->bound().center();
+    node->bound().center() /= (double) node->count();
+
+    // Smallest cosine against the new center within the left child ...
+    double min_cosine_in_left;
+    FurthestColumnIndex(node->bound().center(), matrix, left->begin(),
+                        left->count(), &min_cosine_in_left);
+
+    // ... and within the right child.
+    double min_cosine_in_right;
+    FurthestColumnIndex(node->bound().center(), matrix, right->begin(),
+                        right->count(), &min_cosine_in_right);
+
+    // The node keeps the smaller of the two.
+    if (min_cosine_in_right < min_cosine_in_left) {
+      node->bound().set_radius(min_cosine_in_right);
+    } else {
+      node->bound().set_radius(min_cosine_in_left);
+    }
+  }
+
+  // Recursively builds the subtree rooted at node. A node with fewer
+  // than leaf_size points, or one that AttemptSplitting() declines to
+  // split, becomes a leaf; otherwise both children are built first and
+  // the node's bound is combined from them. old_from_new (may be NULL) is
+  // passed through to the splitting code.
+  template<typename TConeTree>
+    void SplitGenConeTree(arma::mat& matrix, TConeTree *node,
+                          size_t leaf_size, size_t *old_from_new) {
+
+    TConeTree *left_child = NULL;
+    TConeTree *right_child = NULL;
+
+    // Splitting is only attempted when the node is not too small.
+    bool was_split = false;
+    if (!(node->count() < leaf_size)) {
+      was_split = AttemptSplitting(matrix, node, &left_child, &right_child,
+                                   leaf_size, old_from_new);
+    }
+
+    if (was_split) {
+      // Children first: the node's bound is derived from theirs.
+      SplitGenConeTree(matrix, left_child, leaf_size, old_from_new);
+      SplitGenConeTree(matrix, right_child, leaf_size, old_from_new);
+      CombineBounds(matrix, node, left_child, right_child);
+    } else {
+      MakeLeafConeTreeNode(matrix, node->begin(), node->count(),
+                           &(node->bound()));
+    }
+
+    // Set children information appropriately (both NULL for a leaf).
+    node->set_children(matrix, left_child, right_child);
+  }
+};
+
+#endif

Modified: mlpack/trunk/src/contrib/pram/max_ip_search/gen_cosine_tree_impl.h
===================================================================
--- mlpack/trunk/src/contrib/pram/max_ip_search/gen_cosine_tree_impl.h	2011-10-29 17:02:13 UTC (rev 10073)
+++ mlpack/trunk/src/contrib/pram/max_ip_search/gen_cosine_tree_impl.h	2011-10-29 22:46:47 UTC (rev 10074)
@@ -34,12 +34,14 @@
     size_t end = begin + count;
     for (size_t i = begin; i < end; i++) {
       bounds->center() 
-	+= (matrix.unsafe_col(i) 
-	    / arma::norm(matrix.unsafe_col(i), 2));
+	+= (matrix.unsafe_col(i)
+ 	    / arma::norm(matrix.unsafe_col(i), 2));
     }
     bounds->center() /= (double) count;
-    bounds->center() /= arma::norm(bounds->center(), 2);
+//     bounds->center() /= arma::norm(bounds->center(), 2);
 
+//     printf("c_norm: %lg\n", arma::norm(bounds->center(), 2));
+
     double furthest_cosine;
     FurthestColumnIndex(bounds->center(), matrix, begin, count,
 			&furthest_cosine);
@@ -197,7 +199,10 @@
     node->bound().center() += left->count() * left->bound().center();
     node->bound().center() += right->count() * right->bound().center();
     node->bound().center() /= (double) node->count();
+//     node->bound().center() /= arma::norm(node->bound().center(), 2);
     
+//     printf("c_norm: %lg\n", arma::norm(node->bound().center(), 2));
+
     double left_min_cosine, right_min_cosine;
     FurthestColumnIndex(node->bound().center(), matrix, left->begin(), 
 			left->count(), &left_min_cosine);




More information about the mlpack-svn mailing list