[mlpack-svn] r12746 - mlpack/trunk/src/mlpack/methods/maxip
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon May 21 14:51:22 EDT 2012
Author: rcurtin
Date: 2012-05-21 14:51:21 -0400 (Mon, 21 May 2012)
New Revision: 12746
Modified:
mlpack/trunk/src/mlpack/methods/maxip/max_ip_impl.hpp
Log:
Revamp the breadth-first descent to take tree levels into account, so that it
is a proper breadth-first descent.
Modified: mlpack/trunk/src/mlpack/methods/maxip/max_ip_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/maxip/max_ip_impl.hpp 2012-05-21 18:39:45 UTC (rev 12745)
+++ mlpack/trunk/src/mlpack/methods/maxip/max_ip_impl.hpp 2012-05-21 18:51:21 UTC (rev 12746)
@@ -19,10 +19,27 @@
struct SearchFrame
{
TreeType* node;
- size_t parent;
- double parentEval;
+ double eval;
};
+template<typename TreeType>
+class SearchFrameCompare
+{
+ public:
+ bool operator()(const SearchFrame<TreeType>& lhs,
+ const SearchFrame<TreeType>& rhs)
+ {
+ // Compare scale.
+ if (lhs.node->Scale() != rhs.node->Scale())
+ return (lhs.node->Scale() < rhs.node->Scale());
+ else
+ {
+ // Now we have to compare by evaluation.
+ return (lhs.eval < rhs.eval);
+ }
+ }
+};
+
template<typename KernelType>
MaxIP<KernelType>::MaxIP(const arma::mat& referenceSet,
bool single,
@@ -117,7 +134,7 @@
}
Timer::Stop("computing_products");
-
+
Log::Info << "Kernel evaluations: " << kernelEvaluations << "." << std::endl;
return;
}
@@ -138,88 +155,116 @@
// Screw the CoverTreeTraverser, we'll implement it by hand.
for (size_t queryIndex = 0; queryIndex < querySet.n_cols; ++queryIndex)
{
- std::queue<SearchFrame<tree::CoverTree<IPMetric<KernelType> > > >
+ // Use an array of priority queues?
+ std::priority_queue<
+ SearchFrame<tree::CoverTree<IPMetric<KernelType> > >,
+ std::vector<SearchFrame<tree::CoverTree<IPMetric<KernelType> > > >,
+ SearchFrameCompare<tree::CoverTree<IPMetric<KernelType> > > >
frameQueue;
+ // Add initial frame.
SearchFrame<tree::CoverTree<IPMetric<KernelType> > > nextFrame;
nextFrame.node = referenceTree;
- nextFrame.parent = size_t() - 1;
- nextFrame.parentEval = 0;
+ nextFrame.eval = KernelType::Evaluate(querySet.unsafe_col(queryIndex),
+ referenceSet.unsafe_col(referenceTree->Point()));
+ // The initial evaluation will be the best so far.
+ indices(0, queryIndex) = referenceTree->Point();
+ products(0, queryIndex) = nextFrame.eval;
+
frameQueue.push(nextFrame);
tree::CoverTree<IPMetric<KernelType> >* referenceNode;
- size_t currentParent;
- double currentParentEval;
- double eval; // Kernel evaluation.
+ double eval;
+ double maxProduct;
while (!frameQueue.empty())
{
// Get the information for this node.
const SearchFrame<tree::CoverTree<IPMetric<KernelType> > >& frame =
- frameQueue.front();
+ frameQueue.top();
+
referenceNode = frame.node;
- currentParent = frame.parent;
- currentParentEval = frame.parentEval;
+ eval = frame.eval;
- frameQueue.pop();
-
- // See if this has the same parent.
- if (referenceNode->Point() == currentParent)
+ // Loop through the children, seeing if we can prune them; if not, add
+ // them to the queue. The self-child is different -- it has the same
+ // parent (and therefore the same kernel evaluation).
+ if (referenceNode->NumChildren() > 0)
{
- // We don't have to evaluate the kernel again.
- eval = currentParentEval;
- }
- else
- {
- // Evaluate the kernel. Then see if it is a result to keep.
- eval = KernelType::Evaluate(querySet.unsafe_col(queryIndex),
- referenceSet.unsafe_col(referenceNode->Point()));
- ++kernelEvaluations;
+ SearchFrame<tree::CoverTree<IPMetric<KernelType> > > childFrame;
- // Is the result good enough to be saved?
- if (eval > products(products.n_rows - 1, queryIndex))
+ // We must handle the self-child differently, to avoid adding it to
+ // the results twice.
+ childFrame.node = &(referenceNode->Child(0));
+ childFrame.eval = eval;
+
+ maxProduct = eval + std::pow(childFrame.node->ExpansionConstant(),
+ childFrame.node->Scale() + 1) * queryProducts[queryIndex];
+
+ // Add self-child if we can't prune it.
+ if (maxProduct > products(products.n_rows - 1, queryIndex))
{
- // Figure out where to insert.
- size_t insertPosition = 0;
- for ( ; insertPosition < products.n_rows - 1; ++insertPosition)
- if (eval > products(insertPosition, queryIndex))
- break;
-
- // We are guaranteed that insertPosition is valid.
- InsertNeighbor(indices, products, queryIndex, insertPosition,
- referenceNode->Point(), eval);
+ // But only if it has children of its own.
+ if (childFrame.node->NumChildren() > 0)
+ frameQueue.push(childFrame);
}
- }
+ else
+ ++numPrunes;
- // Now discover if we can prune this node or not.
- double maxProduct = eval + std::pow(referenceNode->ExpansionConstant(),
- referenceNode->Scale() + 1) * queryProducts[queryIndex];
-
- if (maxProduct > products(products.n_rows - 1, queryIndex))
- {
- // We can't prune. So add our children.
- for (size_t i = 0; i < referenceNode->NumChildren(); ++i)
+ for (size_t i = 1; i < referenceNode->NumChildren(); ++i)
{
- SearchFrame<tree::CoverTree<IPMetric<KernelType> > > nextFrame;
+ // Evaluate child.
+ childFrame.node = &(referenceNode->Child(i));
+ childFrame.eval = KernelType::Evaluate(
+ querySet.unsafe_col(queryIndex),
+ referenceSet.unsafe_col(referenceNode->Child(i).Point()));
- nextFrame.node = &(referenceNode->Child(i));
- nextFrame.parent = referenceNode->Point();
- nextFrame.parentEval = eval;
+ // Can we prune it? If we can, we can avoid putting it in the queue
+ // (saves time).
+ double maxProduct = childFrame.eval +
+ std::pow(childFrame.node->ExpansionConstant(),
+ childFrame.node->Scale() + 1) * queryProducts[queryIndex];
- frameQueue.push(nextFrame);
+ if (maxProduct > products(products.n_rows - 1, queryIndex))
+ {
+ // Good enough to recurse into. While we're at it, check the
+ // actual evaluation and see if it's an improvement.
+ if (childFrame.eval > products(products.n_rows - 1, queryIndex))
+ {
+ // This is a better result. Find out where to insert.
+ size_t insertPosition = 0;
+ for ( ; insertPosition < products.n_rows - 1; ++insertPosition)
+ if (childFrame.eval > products(insertPosition, queryIndex))
+ break;
+
+ // Insert into the correct position; we are guaranteed that
+ // insertPosition is valid.
+ InsertNeighbor(indices, products, queryIndex, insertPosition,
+ childFrame.node->Point(), childFrame.eval);
+ }
+
+ // Now add this to the queue (if it has any children which may
+ // need to be recursed into).
+ if (childFrame.node->NumChildren() > 0)
+ frameQueue.push(childFrame);
+ }
+ else
+ {
+ ++numPrunes;
+ }
}
}
- else
- {
- numPrunes++;
- }
+
+ frameQueue.pop();
}
}
Log::Info << "Pruned " << numPrunes << " nodes." << std::endl;
- Log::Info << "Kernel evaluations: " << kernelEvaluations << "." << std::endl;
- Log::Info << "Distance evaluations: " << distanceEvaluations << "." << std::endl;
+ Log::Info << "Kernel evaluations: " << kernelEvaluations << "."
+ << std::endl;
+ Log::Info << "Distance evaluations: " << distanceEvaluations << "."
+ << std::endl;
Timer::Stop("computing_products");
return;
More information about the mlpack-svn
mailing list