[mlpack-svn] r12723 - mlpack/trunk/src/mlpack/methods/maxip
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri May 18 18:00:36 EDT 2012
Author: rcurtin
Date: 2012-05-18 18:00:35 -0400 (Fri, 18 May 2012)
New Revision: 12723
Modified:
mlpack/trunk/src/mlpack/methods/maxip/max_ip_impl.hpp
mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules.hpp
mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules_impl.hpp
Log:
Reimplement single-tree search. It's faster now.
Modified: mlpack/trunk/src/mlpack/methods/maxip/max_ip_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/maxip/max_ip_impl.hpp 2012-05-18 21:56:22 UTC (rev 12722)
+++ mlpack/trunk/src/mlpack/methods/maxip/max_ip_impl.hpp 2012-05-18 22:00:35 UTC (rev 12723)
@@ -10,7 +10,7 @@
// In case it hasn't yet been included.
#include "max_ip.hpp"
-#include <mlpack/core/tree/traversers/single_tree_breadth_first_traverser.hpp>
+#include <mlpack/core/tree/traversers/single_cover_tree_traverser.hpp>
#include "max_ip_rules.hpp"
namespace mlpack {
@@ -112,18 +112,91 @@
// Single-tree implementation.
if (single)
{
- MaxIPRules<IPMetric<KernelType> > rules(referenceSet, querySet, indices,
- products);
+ // Calculate number of pruned nodes.
+ size_t numPrunes = 0;
- tree::SingleTreeBreadthFirstTraverser<
- tree::CoverTree<IPMetric<KernelType> >,
- MaxIPRules<IPMetric<KernelType> > > traverser(rules);
+ // Precalculate query products ( || q || for all q).
+ arma::vec queryProducts(querySet.n_cols);
+ for (size_t queryIndex = 0; queryIndex < querySet.n_cols; ++queryIndex)
+ queryProducts[queryIndex] = KernelType::Evaluate(
+ querySet.unsafe_col(queryIndex), querySet.unsafe_col(queryIndex));
- for (size_t i = 0; i < querySet.n_cols; ++i)
- traverser.Traverse(i, *referenceTree);
+ // Skip the CoverTreeTraverser; implement the breadth-first traversal by hand.
+ for (size_t queryIndex = 0; queryIndex < querySet.n_cols; ++queryIndex)
+ {
+ std::queue<tree::CoverTree<IPMetric<KernelType> >*> pointQueue;
+ std::queue<size_t> parentQueue;
+ std::queue<double> parentEvalQueue;
+ pointQueue.push(referenceTree);
+ parentQueue.push(size_t() - 1); // Has no parent.
+ parentEvalQueue.push(0); // No possible parent evaluation.
- Log::Info << "Pruned " << traverser.NumPrunes() << " nodes." << std::endl;
+ tree::CoverTree<IPMetric<KernelType> >* referenceNode;
+ size_t currentParent;
+ double currentParentEval;
+ double eval; // Kernel evaluation.
+ while (!pointQueue.empty())
+ {
+ // Get the information for this node.
+ referenceNode = pointQueue.front();
+ currentParent = parentQueue.front();
+ currentParentEval = parentEvalQueue.front();
+
+ pointQueue.pop();
+ parentQueue.pop();
+ parentEvalQueue.pop();
+
+ // See if this node holds the same point as its parent.
+ if (referenceNode->Point() == currentParent)
+ {
+ // We don't have to evaluate the kernel again.
+ eval = currentParentEval;
+ }
+ else
+ {
+ // Evaluate the kernel. Then see if it is a result to keep.
+ eval = KernelType::Evaluate(querySet.unsafe_col(queryIndex),
+ referenceSet.unsafe_col(referenceNode->Point()));
+
+ // Is the result good enough to be saved?
+ if (eval > products(products.n_rows - 1, queryIndex))
+ {
+ // Figure out where to insert.
+ size_t insertPosition = 0;
+ for ( ; insertPosition < products.n_rows - 1; ++insertPosition)
+ if (eval > products(insertPosition, queryIndex))
+ break;
+
+ // We are guaranteed that insertPosition is valid.
+ InsertNeighbor(indices, products, queryIndex, insertPosition,
+ referenceNode->Point(), eval);
+ }
+ }
+
+ // Now discover if we can prune this node or not.
+ double maxProduct = eval + std::pow(referenceNode->ExpansionConstant(),
+ referenceNode->Scale() + 1) * queryProducts[queryIndex];
+
+ if (maxProduct > products(products.n_rows - 1, queryIndex))
+ {
+ // We can't prune. So add our children.
+ for (size_t i = 0; i < referenceNode->NumChildren(); ++i)
+ {
+ pointQueue.push(&(referenceNode->Child(i)));
+ parentQueue.push(referenceNode->Point());
+ parentEvalQueue.push(eval);
+ }
+ }
+ else
+ {
+ numPrunes++;
+ }
+ }
+ }
+
+ Log::Info << "Pruned " << numPrunes << " nodes." << std::endl;
+
Timer::Stop("computing_products");
return;
}
Modified: mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules.hpp 2012-05-18 21:56:22 UTC (rev 12722)
+++ mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules.hpp 2012-05-18 22:00:35 UTC (rev 12723)
@@ -26,7 +26,8 @@
void BaseCase(const size_t queryIndex, const size_t referenceIndex);
bool CanPrune(const size_t queryIndex,
- tree::CoverTree<MetricType>& referenceNode);
+ tree::CoverTree<MetricType>& referenceNode,
+ const size_t parentIndex);
private:
const arma::mat& referenceSet;
@@ -37,6 +38,8 @@
arma::mat& products;
+ arma::vec queryKernels; // || q || for each q.
+
void InsertNeighbor(const size_t queryIndex,
const size_t pos,
const size_t neighbor,
Modified: mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules_impl.hpp 2012-05-18 21:56:22 UTC (rev 12722)
+++ mlpack/trunk/src/mlpack/methods/maxip/max_ip_rules_impl.hpp 2012-05-18 22:00:35 UTC (rev 12723)
@@ -22,27 +22,26 @@
querySet(querySet),
indices(indices),
products(products)
-{ /* Nothing left to do. */ }
-
-template<typename MetricType>
-void MaxIPRules<MetricType>::BaseCase(const size_t queryIndex,
- const size_t referenceIndex)
{
- // Should be optimized out...
+ // Precompute each self-kernel.
+// queryKernels.set_size(querySet.n_cols);
+// for (size_t i = 0; i < querySet.n_cols; ++i)
+// queryKernels[i] = sqrt(MetricType::Kernel::Evaluate(querySet.unsafe_col(i),
+// querySet.unsafe_col(i)));
}
template<typename MetricType>
bool MaxIPRules<MetricType>::CanPrune(const size_t queryIndex,
- tree::CoverTree<MetricType>& referenceNode)
+ tree::CoverTree<MetricType>& referenceNode,
+ const size_t parentIndex)
{
// The maximum possible inner product is given by
// <q, p_0> + R_p || q ||
// and since we are using cover trees, p_0 is the point referred to by the
// node, and R_p will be the expansion constant to the power of the scale plus
// one.
- const double eval = MetricType::Kernel::Evaluate(
- querySet.unsafe_col(queryIndex),
- referenceSet.unsafe_col(referenceNode.Point()));
+ const double eval = MetricType::Kernel::Evaluate(querySet.col(queryIndex),
+ referenceSet.col(referenceNode.Point()));
// See if base case can be added.
if (eval > products(products.n_rows - 1, queryIndex))
@@ -58,8 +57,8 @@
double maxProduct = eval + std::pow(referenceNode.ExpansionConstant(),
referenceNode.Scale() + 1) *
- sqrt(MetricType::Kernel::Evaluate(querySet.unsafe_col(queryIndex),
- querySet.unsafe_col(queryIndex)));
+sqrt(MetricType::Kernel::Evaluate(querySet.col(queryIndex),
+querySet.col(queryIndex)));
if (maxProduct > products(products.n_rows - 1, queryIndex))
return false;
More information about the mlpack-svn
mailing list