[mlpack-svn] r14559 - mlpack/trunk/src/mlpack/methods/fastmks
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Mar 14 14:50:06 EDT 2013
Author: rcurtin
Date: 2013-03-14 14:50:06 -0400 (Thu, 14 Mar 2013)
New Revision: 14559
Added:
mlpack/trunk/src/mlpack/methods/fastmks/fastmks_stat.hpp
Modified:
mlpack/trunk/src/mlpack/methods/fastmks/fastmks.hpp
mlpack/trunk/src/mlpack/methods/fastmks/fastmks_impl.hpp
mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules.hpp
mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp
Log:
Make FastMKS work in the dual-tree setting. Hooray!
Modified: mlpack/trunk/src/mlpack/methods/fastmks/fastmks.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/fastmks/fastmks.hpp 2013-03-14 02:26:39 UTC (rev 14558)
+++ mlpack/trunk/src/mlpack/methods/fastmks/fastmks.hpp 2013-03-14 18:50:06 UTC (rev 14559)
@@ -9,7 +9,8 @@
#include <mlpack/core.hpp>
#include "ip_metric.hpp"
-#include <mlpack/core/tree/cover_tree/cover_tree.hpp>
+#include "fastmks_stat.hpp"
+#include <mlpack/core/tree/cover_tree.hpp>
namespace mlpack {
namespace fastmks {
@@ -42,9 +43,11 @@
const arma::mat& querySet;
- tree::CoverTree<IPMetric<KernelType> >* referenceTree;
+ tree::CoverTree<IPMetric<KernelType>, tree::FirstPointIsRoot, FastMKSStat>*
+ referenceTree;
- tree::CoverTree<IPMetric<KernelType> >* queryTree;
+ tree::CoverTree<IPMetric<KernelType>, tree::FirstPointIsRoot, FastMKSStat>*
+ queryTree;
bool single;
Modified: mlpack/trunk/src/mlpack/methods/fastmks/fastmks_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/fastmks/fastmks_impl.hpp 2013-03-14 02:26:39 UTC (rev 14558)
+++ mlpack/trunk/src/mlpack/methods/fastmks/fastmks_impl.hpp 2013-03-14 18:50:06 UTC (rev 14559)
@@ -18,191 +18,6 @@
namespace mlpack {
namespace fastmks {
-template<typename TreeType>
-void RecurseTreeCountLeaves(const TreeType& node, arma::vec& counts)
-{
- for (size_t i = 0; i < node.NumChildren(); ++i)
- {
- if (node.Child(i).NumChildren() == 0)
- counts[node.Child(i).Point()]++;
- else
- RecurseTreeCountLeaves<TreeType>(node.Child(i), counts);
- }
-}
-
-template<typename TreeType>
-void CheckSelfChild(const TreeType& node)
-{
- if (node.NumChildren() == 0)
- return; // No self-child applicable here.
-
- bool found = false;
- for (size_t i = 0; i < node.NumChildren(); ++i)
- {
- if (node.Child(i).Point() == node.Point())
- found = true;
-
- // Recursively check the children.
- CheckSelfChild(node.Child(i));
- }
-
- // Ensure this has its own self-child.
- Log::Assert(found == true);
-}
-
-template<typename TreeType, typename MetricType>
-void CheckCovering(const TreeType& node)
-{
- // Return if a leaf. No checking necessary.
- if (node.NumChildren() == 0)
- return;
-
- const arma::mat& dataset = node.Dataset();
- const size_t nodePoint = node.Point();
-
- // To ensure that this node satisfies the covering principle, we must ensure
- // that the distance to each child is less than pow(expansionConstant, scale).
- double maxDistance = pow(node.Base(), node.Scale());
- for (size_t i = 0; i < node.NumChildren(); ++i)
- {
- const size_t childPoint = node.Child(i).Point();
-
- double distance = MetricType::Evaluate(dataset.col(nodePoint),
- dataset.col(childPoint));
-
- Log::Assert(distance <= maxDistance);
-
- // Check the child.
- CheckCovering<TreeType, MetricType>(node.Child(i));
- }
-}
-
-template<typename TreeType, typename MetricType>
-void CheckIndividualSeparation(const TreeType& constantNode,
- const TreeType& node)
-{
- // Don't check points at a lower scale.
- if (node.Scale() < constantNode.Scale())
- return;
-
- // If at a higher scale, recurse.
- if (node.Scale() > constantNode.Scale())
- {
- for (size_t i = 0; i < node.NumChildren(); ++i)
- {
- // Don't recurse into leaves.
- if (node.Child(i).NumChildren() > 0)
- CheckIndividualSeparation<TreeType, MetricType>(constantNode,
- node.Child(i));
- }
-
- return;
- }
-
- // Don't compare the same point against itself.
- if (node.Point() == constantNode.Point())
- return;
-
- // Now we know we are at the same scale, so make the comparison.
- const arma::mat& dataset = constantNode.Dataset();
- const size_t constantPoint = constantNode.Point();
- const size_t nodePoint = node.Point();
-
- // Make sure the distance is at least the following value (in accordance with
- // the separation principle of cover trees).
- double minDistance = pow(constantNode.ExpansionConstant(),
- constantNode.Scale());
-
- double distance = MetricType::Evaluate(dataset.col(constantPoint),
- dataset.col(nodePoint));
-
-}
-
-template<typename TreeType, typename MetricType>
-void CheckSeparation(const TreeType& node, const TreeType& root)
-{
- // Check the separation between this point and all other points on this scale.
- CheckIndividualSeparation<TreeType, MetricType>(node, root);
-
- // Check the children, but only if they are not leaves. Leaves don't need to
- // be checked.
- for (size_t i = 0; i < node.NumChildren(); ++i)
- if (node.Child(i).NumChildren() > 0)
- CheckSeparation<TreeType, MetricType>(node.Child(i), root);
-}
-
-template<typename TreeType, typename MetricType>
-void GetMaxDistance(TreeType& node,
- TreeType& constantNode,
- double& best,
- size_t& index)
-{
- const arma::mat& dataset = node.Dataset();
- const double eval = MetricType::Evaluate(dataset.unsafe_col(node.Point()),
- dataset.unsafe_col(constantNode.Point()));
- if (eval > best)
- {
- best = eval;
- index = node.Point();
- }
-
- // Recurse into children.
- for (size_t i = 0; i < node.NumChildren(); ++i)
- GetMaxDistance<TreeType, MetricType>(node.Child(i), constantNode, best,
- index);
-}
-
-template<typename TreeType, typename MetricType>
-void CheckMaxDistances(TreeType& node)
-{
- // Check child distances.
- for (size_t i = 0; i < node.NumChildren(); ++i)
- {
- const arma::mat& dataset = node.Dataset();
- double eval = MetricType::Evaluate(dataset.unsafe_col(node.Point()),
- dataset.unsafe_col(node.Child(i).Point()));
-
- Log::Assert(std::abs(eval - node.Child(i).ParentDistance()) < 1e-10);
- }
-
- // Check all descendants.
- double maxDescendantDistance = 0;
- size_t maxIndex = 0;
- GetMaxDistance<TreeType, MetricType>(node, node, maxDescendantDistance,
- maxIndex);
-
- Log::Assert(std::abs(maxDescendantDistance -
- node.FurthestDescendantDistance()) < 1e-10);
-
- for (size_t i = 0; i < node.NumChildren(); ++i)
- CheckMaxDistances<TreeType, MetricType>(node.Child(i));
-}
-
-template<typename TreeType>
-struct SearchFrame
-{
- TreeType* node;
- double eval;
-};
-
-template<typename TreeType>
-class SearchFrameCompare
-{
- public:
- bool operator()(const SearchFrame<TreeType>& lhs,
- const SearchFrame<TreeType>& rhs)
- {
- // Compare scale.
- if (lhs.node->Scale() != rhs.node->Scale())
- return (lhs.node->Scale() < rhs.node->Scale());
- else
- {
- // Now we have to compare by evaluation.
- return (lhs.eval < rhs.eval);
- }
- }
-};
-
template<typename KernelType>
FastMKS<KernelType>::FastMKS(const arma::mat& referenceSet,
KernelType& kernel,
@@ -223,8 +38,9 @@
if (naive)
referenceTree = NULL;
else
- referenceTree = new tree::CoverTree<IPMetric<KernelType> >(referenceSet,
- expansionConstant, &metric);
+ referenceTree = new tree::CoverTree<IPMetric<KernelType>,
+ tree::FirstPointIsRoot, FastMKSStat>(referenceSet, expansionConstant,
+ &metric);
Timer::Stop("tree_building");
}
@@ -248,42 +64,17 @@
if (naive)
referenceTree = NULL;
else
- referenceTree = new tree::CoverTree<IPMetric<KernelType> >(referenceSet,
- expansionConstant, &metric);
+ referenceTree = new tree::CoverTree<IPMetric<KernelType>,
+ tree::FirstPointIsRoot, FastMKSStat>(referenceSet, expansionConstant,
+ &metric);
if (single || naive)
queryTree = NULL;
else
- queryTree = new tree::CoverTree<IPMetric<KernelType> >(querySet,
- expansionConstant, &metric);
+ queryTree = new tree::CoverTree<IPMetric<KernelType>,
+ tree::FirstPointIsRoot, FastMKSStat>(querySet, expansionConstant,
+ &metric);
-/* if (referenceTree != NULL)
- {
- Log::Debug << "Check counts" << std::endl;
- // Now loop through the tree and ensure that each leaf is only created once.
- arma::vec counts;
- counts.zeros(referenceSet.n_elem);
- RecurseTreeCountLeaves(*referenceTree, counts);
-
- // Each point should only have one leaf node representing it.
- for (size_t i = 0; i < 20; ++i)
- Log::Assert(counts[i] == 1);
-
- Log::Debug << "Check self child\n";
- // Each non-leaf should have a self-child.
- CheckSelfChild<tree::CoverTree<IPMetric<KernelType> > >(*referenceTree);
-
- Log::Debug << "Check covering\n";
- // Each node must satisfy the covering principle (its children must be less
- // than or equal to a certain distance apart).
- CheckCovering<tree::CoverTree<IPMetric<KernelType> >, IPMetric<KernelType> >(*referenceTree);
-
- Log::Debug << "Check max distances\n";
- // Check maximum distance of children and grandchildren.
- CheckMaxDistances<tree::CoverTree<IPMetric<KernelType> >, IPMetric<KernelType> >(*referenceTree);
- Log::Debug << "Done\n";
- }*/
-
Timer::Stop("tree_building");
}
@@ -301,10 +92,10 @@
arma::Mat<size_t>& indices,
arma::mat& products)
{
- // No remapping will be necessary.
+ // No remapping will be necessary because we are using the cover tree.
indices.set_size(k, querySet.n_cols);
products.set_size(k, querySet.n_cols);
- products.fill(-1.0);
+ products.fill(-DBL_MAX);
Timer::Start("computing_products");
@@ -318,6 +109,9 @@
{
for (size_t r = 0; r < referenceSet.n_cols; ++r)
{
+ if ((&querySet == &referenceSet) && (q == r))
+ continue;
+
const double eval = metric.Kernel().Evaluate(querySet.unsafe_col(q),
referenceSet.unsafe_col(r));
++kernelEvaluations;
@@ -343,144 +137,21 @@
// Single-tree implementation.
if (single)
{
- // Calculate number of pruned nodes.
- size_t numPrunes = 0;
+ // Create rules object (this will store the results). This constructor
+ // precalculates each self-kernel value.
+ typedef tree::CoverTree<IPMetric<KernelType>, tree::FirstPointIsRoot,
+ FastMKSStat> TreeType;
+ typedef FastMKSRules<KernelType, TreeType> RuleType;
+ RuleType rules(referenceSet, querySet, indices, products, metric.Kernel());
- // Precalculate query products ( || q || for all q).
- arma::vec queryProducts(querySet.n_cols);
- for (size_t queryIndex = 0; queryIndex < querySet.n_cols; ++queryIndex)
- queryProducts[queryIndex] = sqrt(metric.Kernel().Evaluate(
- querySet.unsafe_col(queryIndex), querySet.unsafe_col(queryIndex)));
- kernelEvaluations += querySet.n_cols;
+ typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
- // Screw the CoverTreeTraverser, we'll implement it by hand.
- for (size_t queryIndex = 0; queryIndex < querySet.n_cols; ++queryIndex)
- {
- // Use an array of priority queues?
- std::priority_queue<
- SearchFrame<tree::CoverTree<IPMetric<KernelType> > >,
- std::vector<SearchFrame<tree::CoverTree<IPMetric<KernelType> > > >,
- SearchFrameCompare<tree::CoverTree<IPMetric<KernelType> > > >
- frameQueue;
+ for (size_t i = 0; i < querySet.n_cols; ++i)
+ traverser.Traverse(i, *referenceTree);
- // Add initial frame.
- SearchFrame<tree::CoverTree<IPMetric<KernelType> > > nextFrame;
- nextFrame.node = referenceTree;
- nextFrame.eval = metric.Kernel().Evaluate(querySet.unsafe_col(queryIndex),
- referenceSet.unsafe_col(referenceTree->Point()));
- ++kernelEvaluations;
+ // Save the number of pruned nodes.
+ const size_t numPrunes = traverser.NumPrunes();
- // The initial evaluation will be the best so far.
- indices(0, queryIndex) = referenceTree->Point();
- products(0, queryIndex) = nextFrame.eval;
-
- frameQueue.push(nextFrame);
-
- tree::CoverTree<IPMetric<KernelType> >* referenceNode;
- double eval;
- double maxProduct;
-
- while (!frameQueue.empty())
- {
- // Get the information for this node.
- const SearchFrame<tree::CoverTree<IPMetric<KernelType> > >& frame =
- frameQueue.top();
-
- referenceNode = frame.node;
- eval = frame.eval;
-
- // Loop through the children, seeing if we can prune them; if not, add
- // them to the queue. The self-child is different -- it has the same
- // parent (and therefore the same kernel evaluation).
- if (referenceNode->NumChildren() > 0)
- {
- SearchFrame<tree::CoverTree<IPMetric<KernelType> > > childFrame;
-
- // We must handle the self-child differently, to avoid adding it to
- // the results twice.
- childFrame.node = &(referenceNode->Child(0));
- childFrame.eval = eval;
-
-// maxProduct = eval + std::pow(childFrame.node->ExpansionConstant(),
-// childFrame.node->Scale() + 1) * queryProducts[queryIndex];
- // Alternate pruning rule.
- maxProduct = eval + childFrame.node->FurthestDescendantDistance() *
- queryProducts[queryIndex];
-
- // Add self-child if we can't prune it.
- if (maxProduct > products(products.n_rows - 1, queryIndex))
- {
- // But only if it has children of its own.
- if (childFrame.node->NumChildren() > 0)
- frameQueue.push(childFrame);
- }
- else
- ++numPrunes;
-
- for (size_t i = 1; i < referenceNode->NumChildren(); ++i)
- {
- // Before we evaluate the child, let's see if it can possibly have
- // a better evaluation.
- double maxChildEval = eval + queryProducts[queryIndex] *
- (referenceNode->Child(i).ParentDistance() +
- referenceNode->Child(i).FurthestDescendantDistance());
-
- if (maxChildEval <= products(products.n_rows - 1, queryIndex))
- {
- ++numPrunes;
- continue; // Skip this child; it can't be any better.
- }
-
- // Evaluate child.
- childFrame.node = &(referenceNode->Child(i));
- childFrame.eval = metric.Kernel().Evaluate(
- querySet.unsafe_col(queryIndex),
- referenceSet.unsafe_col(referenceNode->Child(i).Point()));
- ++kernelEvaluations;
-
- // Can we prune it? If we can, we can avoid putting it in the queue
- // (saves time).
-// double maxProduct = childFrame.eval +
-// std::pow(childFrame.node->ExpansionConstant(),
-// childFrame.node->Scale() + 1) * queryProducts[queryIndex];
- // Alternate pruning rule.
- maxProduct = childFrame.eval + queryProducts[queryIndex] *
- childFrame.node->FurthestDescendantDistance();
-
- if (maxProduct > products(products.n_rows - 1, queryIndex))
- {
- // Good enough to recurse into. While we're at it, check the
- // actual evaluation and see if it's an improvement.
- if (childFrame.eval > products(products.n_rows - 1, queryIndex))
- {
- // This is a better result. Find out where to insert.
- size_t insertPosition = 0;
- for ( ; insertPosition < products.n_rows - 1; ++insertPosition)
- if (childFrame.eval > products(insertPosition, queryIndex))
- break;
-
- // Insert into the correct position; we are guaranteed that
- // insertPosition is valid.
- InsertNeighbor(indices, products, queryIndex, insertPosition,
- childFrame.node->Point(), childFrame.eval);
- }
-
- // Now add this to the queue (if it has any children which may
- // need to be recursed into).
- if (childFrame.node->NumChildren() > 0)
- frameQueue.push(childFrame);
- }
- else
- {
- ++numPrunes;
- }
- }
- }
-
- frameQueue.pop();
- }
- }
-
Log::Info << "Pruned " << numPrunes << " nodes." << std::endl;
Log::Info << "Kernel evaluations: " << kernelEvaluations << "."
<< std::endl;
@@ -491,8 +162,28 @@
return;
}
- // Double-tree implementation.
- Log::Fatal << "Dual-tree search not implemented yet... oops..." << std::endl;
+ // Dual-tree implementation.
+ typedef tree::CoverTree<IPMetric<KernelType>, tree::FirstPointIsRoot,
+ FastMKSStat> TreeType;
+ typedef FastMKSRules<KernelType, TreeType> RuleType;
+ RuleType rules(referenceSet, querySet, indices, products, metric.Kernel());
+
+ typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+
+ if (queryTree)
+ traverser.Traverse(*queryTree, *referenceTree);
+ else
+ traverser.Traverse(*referenceTree, *referenceTree);
+
+ const size_t numPrunes = traverser.NumPrunes();
+
+ Log::Info << "Pruned " << numPrunes << " nodes." << std::endl;
+ Log::Info << "Kernel evaluations: " << kernelEvaluations << "." << std::endl;
+ Log::Info << "Distance evaluations: " << distanceEvaluations << "."
+ << std::endl;
+
+ Timer::Stop("computing_products");
+ return;
}
/**
@@ -529,6 +220,7 @@
}
// Specialized implementation for tighter bounds for Gaussian.
+/*
template<>
void FastMKS<kernel::GaussianKernel>::Search(const size_t k,
arma::Mat<size_t>& indices,
@@ -741,6 +433,7 @@
Log::Fatal << "Dual-tree search not implemented yet... oops..." << std::endl;
}
+*/
}; // namespace fastmks
}; // namespace mlpack
Modified: mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules.hpp 2013-03-14 02:26:39 UTC (rev 14558)
+++ mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules.hpp 2013-03-14 18:50:06 UTC (rev 14559)
@@ -13,21 +13,97 @@
namespace mlpack {
namespace fastmks {
-template<typename MetricType>
+template<typename KernelType, typename TreeType>
class FastMKSRules
{
public:
FastMKSRules(const arma::mat& referenceSet,
const arma::mat& querySet,
arma::Mat<size_t>& indices,
- arma::mat& products);
+ arma::mat& products,
+ KernelType& kernel);
- void BaseCase(const size_t queryIndex, const size_t referenceIndex);
+ //! Compute the base case (kernel value) between two points.
+ double BaseCase(const size_t queryIndex, const size_t referenceIndex);
- bool CanPrune(const size_t queryIndex,
- tree::CoverTree<MetricType>& referenceNode,
- const size_t parentIndex);
+ /**
+ * Get the score for recursion order. A low score indicates priority for
+ * recursion, while DBL_MAX indicates that the node should not be recursed
+ * into at all (it should be pruned).
+ *
+ * @param queryIndex Index of query point.
+ * @param referenceNode Candidate to be recursed into.
+ */
+ double Score(const size_t queryIndex, TreeType& referenceNode) const;
+ /**
+ * Get the score for recursion order, passing the base case result (in the
+ * situation where it may be needed to calculate the recursion order). A low
+ * score indicates priority for recursion, while DBL_MAX indicates that the
+ * node should not be recursed into at all (it should be pruned).
+ *
+ * @param queryIndex Index of query point.
+ * @param referenceNode Candidate node to be recursed into.
+ * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode).
+ */
+ double Score(const size_t queryIndex,
+ TreeType& referenceNode,
+ const double baseCaseResult) const;
+
+ /**
+ * Get the score for recursion order. A low score indicates priority for
+ * recursion, while DBL_MAX indicates that the node should not be recursed
+ * into at all (it should be pruned).
+ *
+ * @param queryNode Candidate query node to be recursed into.
+ * @param referenceNode Candidate reference node to be recursed into.
+ */
+ double Score(TreeType& queryNode, TreeType& referenceNode) const;
+
+ /**
+ * Get the score for recursion order, passing the base case result (in the
+ * situation where it may be needed to calculate the recursion order). A low
+ * score indicates priority for recursion, while DBL_MAX indicates that the
+ * node should not be recursed into at all (it should be pruned).
+ *
+ * @param queryNode Candidate query node to be recursed into.
+ * @param referenceNode Candidate reference node to be recursed into.
+ * @param baseCaseResult Result of BaseCase(queryNode, referenceNode).
+ */
+ double Score(TreeType& queryNode,
+ TreeType& referenceNode,
+ const double baseCaseResult) const;
+
+ /**
+ * Re-evaluate the score for recursion order. A low score indicates priority
+ * for recursion, while DBL_MAX indicates that a node should not be recursed
+ * into at all (it should be pruned). This is used when the score has already
+ * been calculated, but another recursion may have modified the bounds for
+ * pruning. So the old score is checked against the new pruning bound.
+ *
+ * @param queryIndex Index of query point.
+ * @param referenceNode Candidate node to be recursed into.
+ * @param oldScore Old score produced by Score() (or Rescore()).
+ */
+ double Rescore(const size_t queryIndex,
+ TreeType& referenceNode,
+ const double oldScore) const;
+
+ /**
+ * Re-evaluate the score for recursion order. A low score indicates priority
+ * for recursion, while DBL_MAX indicates that a node should not be recursed
+ * into at all (it should be pruned). This is used when the score has already
+ * been calculated, but another recursion may have modified the bounds for
+ * pruning. So the old score is checked against the new pruning bound.
+ *
+ * @param queryNode Candidate query node to be recursed into.
+ * @param referenceNode Candidate reference node to be recursed into.
+ * @param oldScore Old score produced by Score() (or Rescore()).
+ */
+ double Rescore(TreeType& queryNode,
+ TreeType& referenceNode,
+ const double oldScore) const;
+
private:
const arma::mat& referenceSet;
@@ -38,7 +114,14 @@
arma::mat& products;
arma::vec queryKernels; // || q || for each q.
+ arma::vec referenceKernels;
+ //! The instantiated kernel.
+ KernelType& kernel;
+
+ //! Calculate the bound for a given query node.
+ double CalculateBound(TreeType& queryNode) const;
+
void InsertNeighbor(const size_t queryIndex,
const size_t pos,
const size_t neighbor,
Modified: mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp 2013-03-14 02:26:39 UTC (rev 14558)
+++ mlpack/trunk/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp 2013-03-14 18:50:06 UTC (rev 14559)
@@ -13,57 +13,237 @@
namespace mlpack {
namespace fastmks {
-template<typename MetricType>
-FastMKSRules<MetricType>::FastMKSRules(const arma::mat& referenceSet,
- const arma::mat& querySet,
- arma::Mat<size_t>& indices,
- arma::mat& products) :
+template<typename KernelType, typename TreeType>
+FastMKSRules<KernelType, TreeType>::FastMKSRules(const arma::mat& referenceSet,
+                                                 const arma::mat& querySet,
+                                                 arma::Mat<size_t>& indices,
+                                                 arma::mat& products,
+                                                 KernelType& kernel) :
referenceSet(referenceSet),
querySet(querySet),
indices(indices),
-    products(products)
+    products(products),
+    kernel(kernel)
{
// Precompute each self-kernel.
-//  queryKernels.set_size(querySet.n_cols);
-//  for (size_t i = 0; i < querySet.n_cols; ++i)
-//    queryKernels[i] = sqrt(MetricType::Kernel::Evaluate(querySet.unsafe_col(i),
-//        querySet.unsafe_col(i)));
+  queryKernels.set_size(querySet.n_cols); // queryKernels[i] = sqrt(K(q_i, q_i))
+  for (size_t i = 0; i < querySet.n_cols; ++i)
+    queryKernels[i] = sqrt(kernel.Evaluate(querySet.unsafe_col(i),
+                                           querySet.unsafe_col(i)));
+  // Same for the reference set: referenceKernels[i] = sqrt(K(r_i, r_i)).
+  referenceKernels.set_size(referenceSet.n_cols);
+  for (size_t i = 0; i < referenceSet.n_cols; ++i)
+    referenceKernels[i] = sqrt(kernel.Evaluate(referenceSet.unsafe_col(i),
+                                               referenceSet.unsafe_col(i)));
}
-template<typename MetricType>
-bool FastMKSRules<MetricType>::CanPrune(const size_t queryIndex,
- tree::CoverTree<MetricType>& referenceNode,
- const size_t parentIndex)
+template<typename KernelType, typename TreeType>
+double FastMKSRules<KernelType, TreeType>::BaseCase(
+    const size_t queryIndex,
+    const size_t referenceIndex)
{
-  // The maximum possible inner product is given by
-  //   <q, p_0> + R_p || q ||
-  // and since we are using cover trees, p_0 is the point referred to by the
-  // node, and R_p will be the expansion constant to the power of the scale plus
-  // one.
-  const double eval = MetricType::Kernel::Evaluate(querySet.col(queryIndex),
-      referenceSet.col(referenceNode.Point()));
-  // See if base case can be added.
-  if (eval > products(products.n_rows - 1, queryIndex))
+  double kernelEval = kernel.Evaluate(querySet.unsafe_col(queryIndex),
+                                      referenceSet.unsafe_col(referenceIndex));
+  // The kernel value is returned whether or not it enters the results.
+  // When query and reference sets are the same object (address comparison), a
+  // point's kernel with itself is still returned for bounding purposes, but it
+  // is never added to the results.
+  if ((&querySet == &referenceSet) && (queryIndex == referenceIndex))
+    return kernelEval;
+
+  // Not better than the current k-th best product: nothing to insert.
+  if (kernelEval < products(products.n_rows - 1, queryIndex))
+    return kernelEval;
+  // Find the first slot whose product this evaluation meets or exceeds.
+  size_t insertPosition = 0;
+  for ( ; insertPosition < products.n_rows; ++insertPosition)
+    if (kernelEval >= products(insertPosition, queryIndex))
+      break;
+
+  InsertNeighbor(queryIndex, insertPosition, referenceIndex, kernelEval);
+
+  return kernelEval;
+}
+
+template<typename MetricType, typename TreeType>
+double FastMKSRules<MetricType, TreeType>::Score(const size_t queryIndex,
+                                                 TreeType& referenceNode) const
+{
+  // Upper bound: K(q, centroid) + furthestDescendantDistance * sqrt(K(q, q)).
+  const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
+  const arma::vec refCentroid;
+  referenceNode.Bound().Centroid(refCentroid);
+  // NOTE(review): refCentroid is const yet Centroid() fills it -- TODO confirm.
+  const double maxKernel = kernel.Evaluate(queryPoint, refCentroid) +
+      referenceNode.FurthestDescendantDistance() * queryKernels[queryIndex];
+
+  // The k-th best product found so far for this query point.
+  const double bestKernel = products(products.n_rows - 1, queryIndex);
+
+  // Return 1 / maxKernel so larger bounds get lower scores (recursed first).
+  // NOTE(review): that ordering only holds when maxKernel > 0.
+  return (maxKernel > bestKernel) ? (1.0 / maxKernel) : DBL_MAX;
+}
+
+template<typename MetricType, typename TreeType>
+double FastMKSRules<MetricType, TreeType>::Score(
+    const size_t queryIndex,
+    TreeType& referenceNode,
+    const double baseCaseResult) const
+{
+  // baseCaseResult is the kernel already computed by BaseCase(); add slack.
+  const double maxKernel = baseCaseResult +
+      referenceNode.FurthestDescendantDistance() * queryKernels[queryIndex];
+  const double bestKernel = products(products.n_rows - 1, queryIndex);
+
+  // Return 1 / maxKernel so larger bounds get lower scores (recursed first).
+  // NOTE(review): that ordering only holds when maxKernel > 0.
+  return (maxKernel > bestKernel) ? (1.0 / maxKernel) : DBL_MAX;
+}
+
+template<typename MetricType, typename TreeType>
+double FastMKSRules<MetricType, TreeType>::Score(TreeType& queryNode,
+                                                 TreeType& referenceNode) const
+{
+  // Upper bound on the kernel between any query and reference descendant.
+  const arma::vec queryCentroid;
+  const arma::vec refCentroid;
+  queryNode.Bound().Centroid(queryCentroid);
+  referenceNode.Bound().Centroid(refCentroid);
+  // NOTE(review): both centroids are const yet Centroid() fills them; confirm.
+  const double refKernelTerm = queryNode.FurthestDescendantDistance() *
+      referenceNode.Stat().SelfKernel();
+  const double queryKernelTerm = referenceNode.FurthestDescendantDistance() *
+      queryNode.Stat().SelfKernel();
+  // Bound = K(c_q, c_r) + both cross terms + product of the two radii.
+  const double maxKernel = kernel.Evaluate(queryCentroid, refCentroid) +
+      refKernelTerm + queryKernelTerm +
+      (queryNode.FurthestDescendantDistance() *
+       referenceNode.FurthestDescendantDistance());
+
+  // Side effect: recompute and cache this query node's pruning bound.
+  queryNode.Stat().Bound() = CalculateBound(queryNode);
+  const double bestKernel = queryNode.Stat().Bound();
+
+  // Return 1 / maxKernel so larger bounds get lower scores (recursed first).
+  // NOTE(review): that ordering only holds when maxKernel > 0.
+  return (maxKernel > bestKernel) ? (1.0 / maxKernel) : DBL_MAX;
+}
+
+template<typename MetricType, typename TreeType>
+double FastMKSRules<MetricType, TreeType>::Score(
+    TreeType& queryNode,
+    TreeType& referenceNode,
+    const double baseCaseResult) const
+{
+  // Start from the precomputed base case and add the cross and radius terms.
+  const double refKernelTerm = queryNode.FurthestDescendantDistance() *
+      referenceNode.Stat().SelfKernel();
+  const double queryKernelTerm = referenceNode.FurthestDescendantDistance() *
+      queryNode.Stat().SelfKernel();
+
+  const double maxKernel = baseCaseResult + refKernelTerm + queryKernelTerm +
+      (queryNode.FurthestDescendantDistance() *
+       referenceNode.FurthestDescendantDistance());
+
+  // Side effect: recompute and cache this query node's pruning bound.
+  queryNode.Stat().Bound() = CalculateBound(queryNode);
+  const double bestKernel = queryNode.Stat().Bound();
+
+  // Return 1 / maxKernel so larger bounds get lower scores (recursed first).
+  // NOTE(review): that ordering only holds when maxKernel > 0.
+  return (maxKernel > bestKernel) ? (1.0 / maxKernel) : DBL_MAX;
+}
+
+template<typename MetricType, typename TreeType>
+double FastMKSRules<MetricType, TreeType>::Rescore(const size_t queryIndex,
+                                                   TreeType& /*referenceNode*/,
+                                                   const double oldScore) const
+{
+  const double bestKernel = products(products.n_rows - 1, queryIndex);
+  // 1 / oldScore recovers the maxKernel bound that Score() computed.
+  return ((1.0 / oldScore) > bestKernel) ? oldScore : DBL_MAX;
+}
+
+template<typename MetricType, typename TreeType>
+double FastMKSRules<MetricType, TreeType>::Rescore(TreeType& queryNode,
+                                                   TreeType& /*referenceNode*/,
+                                                   const double oldScore) const
+{
+  queryNode.Stat().Bound() = CalculateBound(queryNode);
+  const double bestKernel = queryNode.Stat().Bound();
+  // Keep the old score only if its bound (1 / oldScore) still beats this one.
+  return ((1.0 / oldScore) > bestKernel) ? oldScore : DBL_MAX;
+}
+
+/**
+ * Calculate the bound for the given query node.  This bound represents the
+ * minimum value which a node combination must achieve to guarantee an
+ * improvement in the results.
+ *
+ * @param queryNode Query node to calculate bound for.
+ */
+template<typename MetricType, typename TreeType>
+double FastMKSRules<MetricType, TreeType>::CalculateBound(TreeType& queryNode)
+    const
+{
+  // We have four possible bounds -- just like NeighborSearchRules, but they are
+  // slightly different in this context.
+  //
+  // (1) min ( min_{all points p in queryNode} P_p[k],
+  //           min_{all children c in queryNode} B(c) );
+  // (2) max_{all points p in queryNode} P_p[k] - 2 (furthest descendant
+  //           distance) sqrt(K(I_p[k], I_p[k]));
+  // (3) max_{all children c in queryNode} B(c) +      <-- not done yet. ignored.
+  // (4) B(parent of queryNode);
+  double worstPointKernel = DBL_MAX;
+  double bestAdjustedPointKernel = -DBL_MAX;
+  double bestPointSelfKernel = -DBL_MAX;
+  const double queryDescendantDistance = queryNode.FurthestDescendantDistance();
+  // NOTE(review): bestPointSelfKernel above is never read.
+  // Loop over all points in this node to find the best and worst.
+  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
{
-    size_t insertPosition;
-    for (insertPosition = 0; insertPosition < indices.n_rows; ++insertPosition)
-      if (eval > products(insertPosition, queryIndex))
-        break;
+    const size_t point = queryNode.Point(i);
+    if (products(products.n_rows - 1, point) < worstPointKernel)
+      worstPointKernel = products(products.n_rows - 1, point);
-    // We are guaranteed insertPosition is in the valid range.
-    InsertNeighbor(queryIndex, insertPosition, referenceNode.Point(), eval);
+    if (products(products.n_rows - 1, point) == -DBL_MAX)
+      continue; // No k-th result yet; subtracting would underflow.
+
+    const double candidateKernel = products(products.n_rows - 1, point) -
+        (2 * queryDescendantDistance) *
+        referenceKernels[indices(indices.n_rows - 1, point)];
+
+    if (candidateKernel > bestAdjustedPointKernel)
+      bestAdjustedPointKernel = candidateKernel;
}
-  double maxProduct = eval + std::pow(referenceNode.ExpansionConstant(),
-      referenceNode.Scale() + 1) *
-      sqrt(MetricType::Kernel::Evaluate(querySet.col(queryIndex),
-      querySet.col(queryIndex)));
+  // Loop over all the children in the node to find the worst child bound.
+  double worstChildKernel = DBL_MAX;
-  if (maxProduct > products(products.n_rows - 1, queryIndex))
-    return false;
-  else
-    return true;
+  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
+  {
+    if (queryNode.Child(i).Stat().Bound() < worstChildKernel)
+      worstChildKernel = queryNode.Child(i).Stat().Bound();
+  }
+
+  // Now assemble bound (1).
+  const double firstBound = (worstPointKernel < worstChildKernel) ?
+      worstPointKernel : worstChildKernel;
+
+  // Bound (2) is bestAdjustedPointKernel; bound (4) is the parent's bound.
+  const double fourthBound = (queryNode.Parent() == NULL) ? -DBL_MAX :
+      queryNode.Parent()->Stat().Bound();
+
+  // Pick the best (largest) of bounds (1), (2) and (4).
+  const double interA = (firstBound > bestAdjustedPointKernel) ? firstBound :
+      bestAdjustedPointKernel;
+//  const double interA = 0.0;
+  const double interB = fourthBound;
+
+  return (interA > interB) ? interA : interB;
}
/**
@@ -74,11 +254,11 @@
* @param neighbor Index of reference point which is being inserted.
* @param distance Distance from query point to reference point.
*/
-template<typename MetricType>
-void FastMKSRules<MetricType>::InsertNeighbor(const size_t queryIndex,
- const size_t pos,
- const size_t neighbor,
- const double distance)
+template<typename MetricType, typename TreeType>
+void FastMKSRules<MetricType, TreeType>::InsertNeighbor(const size_t queryIndex,
+ const size_t pos,
+ const size_t neighbor,
+ const double distance)
{
// We only memmove() if there is actually a need to shift something.
if (pos < (products.n_rows - 1))
Added: mlpack/trunk/src/mlpack/methods/fastmks/fastmks_stat.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/fastmks/fastmks_stat.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/fastmks/fastmks_stat.hpp 2013-03-14 18:50:06 UTC (rev 14559)
@@ -0,0 +1,89 @@
+/**
+ * @file fastmks_stat.hpp
+ * @author Ryan Curtin
+ *
+ * The statistic used in trees with FastMKS.
+ */
+#ifndef __MLPACK_METHODS_FASTMKS_FASTMKS_STAT_HPP
+#define __MLPACK_METHODS_FASTMKS_FASTMKS_STAT_HPP
+
+#include <mlpack/core.hpp>
+#include <mlpack/core/tree/tree_traits.hpp>
+
+namespace mlpack {
+namespace fastmks {
+
+/**
+ * The statistic used in trees with FastMKS. This stores both the bound and the
+ * self-kernels for each node in the tree.
+ */
+class FastMKSStat
+{
+ public:
+  /**
+   * Default initialization: bound = -DBL_MAX, selfKernel = 0.
+   */
+  FastMKSStat() : bound(-DBL_MAX), selfKernel(0.0) { }
+
+  /**
+   * Initialize this statistic for the given tree node.  The TreeType's metric
+   * must be IPMetric with some kernel type (that is, Metric().Kernel() must
+   * exist).  The bound always starts at -DBL_MAX.
+   *
+   * @param node Node that this statistic is built for.
+   */
+  template<typename TreeType>
+  FastMKSStat(const TreeType& node) :
+      bound(-DBL_MAX)
+  {
+    // If the first point is the centroid, use it directly; else compute one.
+    if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
+    {
+      // If this type of tree has self-children, then maybe the evaluation is
+      // already done.  These statistics are built bottom-up, so the child stat
+      // should already be done.
+      if ((tree::TreeTraits<TreeType>::HasSelfChildren) &&
+          (node.NumChildren() > 0) &&
+          (node.Point(0) == node.Child(0).Point(0)))
+      {
+        selfKernel = node.Child(0).Stat().SelfKernel();
+      }
+      else
+      {
+        selfKernel = sqrt(node.Metric().Kernel().Evaluate(
+            node.Dataset().unsafe_col(node.Point(0)),
+            node.Dataset().unsafe_col(node.Point(0))));
+      }
+    }
+    else
+    {
+      // Calculate the centroid.
+      arma::vec centroid;
+      node.Centroid(centroid);
+      // NOTE(review): plain if -- this branch must compile for every TreeType.
+      selfKernel = sqrt(node.Metric().Kernel().Evaluate(centroid, centroid));
+    }
+  }
+
+  //! Get the self-kernel.
+  double SelfKernel() const { return selfKernel; }
+  //! Modify the self-kernel.
+  double& SelfKernel() { return selfKernel; }
+
+  //! Get the bound.
+  double Bound() const { return bound; }
+  //! Modify the bound.
+  double& Bound() { return bound; }
+
+ private:
+  //! The bound for pruning (initialized to -DBL_MAX).
+  double bound;
+
+  //! The self-kernel evaluation: sqrt(K(centroid, centroid)).
+  double selfKernel;
+};
+
+}; // namespace fastmks
+}; // namespace mlpack
+
+#endif
More information about the mlpack-svn
mailing list