[mlpack-git] master: Refactor FastMKS to take individual queries in Search(). (3ad3877)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Thu Apr 23 10:40:29 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/57d0567dddff01feea73b348f38cc040dc3cf8e3...3ad38770911c7b9840901f0934bd1a81c2249046
>---------------------------------------------------------------
commit 3ad38770911c7b9840901f0934bd1a81c2249046
Author: ryan <ryan at ratml.org>
Date: Thu Apr 23 10:40:10 2015 -0400
Refactor FastMKS to take individual queries in Search().
>---------------------------------------------------------------
3ad38770911c7b9840901f0934bd1a81c2249046
src/mlpack/methods/fastmks/fastmks.hpp | 150 ++++-----
src/mlpack/methods/fastmks/fastmks_impl.hpp | 490 ++++++++--------------------
src/mlpack/methods/fastmks/fastmks_main.cpp | 111 ++++---
3 files changed, 270 insertions(+), 481 deletions(-)
diff --git a/src/mlpack/methods/fastmks/fastmks.hpp b/src/mlpack/methods/fastmks/fastmks.hpp
index 90ca995..72ea34b 100644
--- a/src/mlpack/methods/fastmks/fastmks.hpp
+++ b/src/mlpack/methods/fastmks/fastmks.hpp
@@ -55,106 +55,102 @@ class FastMKS
{
public:
/**
- * Create the FastMKS object using the reference set as the query set.
- * Optionally, specify whether or not single-tree search or naive
- * (brute-force) search should be used.
+ * Create the FastMKS object with the given reference set (this is the set
+ * that is searched). Optionally, specify whether or not single-tree search
+ * or naive (brute-force) search should be used.
*
- * @param referenceSet Set of data to run FastMKS on.
+ * @param referenceSet Set of reference data.
- * @param single Whether or not to run single-tree search.
+ * @param singleMode Whether or not to run single-tree search.
* @param naive Whether or not to run brute-force (naive) search.
*/
FastMKS(const typename TreeType::Mat& referenceSet,
- const bool single = false,
+ const bool singleMode = false,
const bool naive = false);
/**
- * Create the FastMKS object using separate reference and query sets.
- * Optionally, specify whether or not single-tree search or naive
- * (brute-force) search should be used.
+ * Create the FastMKS object using the reference set (this is the set that is
+ * searched) with an initialized kernel. This is useful for when the kernel
+ * stores state. Optionally, specify whether or not single-tree search or
+ * naive (brute-force) search should be used.
*
* @param referenceSet Reference set of data for FastMKS.
- * @param querySet Set of query points for FastMKS.
- * @param single Whether or not to run single-tree search.
- * @param naive Whether or not to run brute-force (naive) search.
- */
- FastMKS(const typename TreeType::Mat& referenceSet,
- const typename TreeType::Mat& querySet,
- const bool single = false,
- const bool naive = false);
-
- /**
- * Create the FastMKS object using the reference set as the query set, and
- * with an initialized kernel. This is useful for when the kernel stores
- * state. Optionally, specify whether or not single-tree search or naive
- * (brute-force) search should be used.
- *
- * @param referenceSet Reference set of data for FastMKS.
- * @param kernel Initialized kernel.
- * @param single Whether or not to run single-tree search.
- * @param naive Whether or not to run brute-force (naive) search.
- */
- FastMKS(const typename TreeType::Mat& referenceSet,
- KernelType& kernel,
- const bool single = false,
- const bool naive = false);
-
- /**
- * Create the FastMKS object using separate reference and query sets, and with
- * an initialized kernel. This is useful for when the kernel stores state.
- * Optionally, specify whether or not single-tree search or naive
- * (brute-force) search should be used.
- *
- * @param referenceSet Reference set of data for FastMKS.
- * @param querySet Set of query points for FastMKS.
* @param kernel Initialized kernel.
- * @param single Whether or not to run single-tree search.
+ * @param singleMode Whether or not to run single-tree search.
* @param naive Whether or not to run brute-force (naive) search.
*/
FastMKS(const typename TreeType::Mat& referenceSet,
- const typename TreeType::Mat& querySet,
KernelType& kernel,
- const bool single = false,
+ const bool singleMode = false,
const bool naive = false);
/**
* Create the FastMKS object with an already-initialized tree built on the
* reference points. Be sure that the tree is built with the metric type
- * IPMetric<KernelType>. For this constructor, the reference set and the
- * query set are the same points. Optionally, whether or not to run
- * single-tree search or brute-force (naive) search can be specified.
+ * IPMetric<KernelType>. Optionally, whether or not to run single-tree search
+ * can be specified. Brute-force search is not available with this
+ * constructor since a tree is given (use one of the other constructors).
*
- * @param referenceSet Reference set of data for FastMKS.
* @param referenceTree Tree built on reference data.
- * @param single Whether or not to run single-tree search.
- * @param naive Whether or not to run brute-force (naive) search.
+ * @param singleMode Whether or not to run single-tree search.
*/
- FastMKS(const typename TreeType::Mat& referenceSet,
- TreeType* referenceTree,
- const bool single = false,
- const bool naive = false);
+ FastMKS(TreeType* referenceTree,
+ const bool singleMode = false);
+
+ //! Destructor for the FastMKS object.
+ ~FastMKS();
/**
- * Create the FastMKS object with already-initialized trees built on the
- * reference and query points. Be sure that the trees are built with the
- * metric type IPMetric<KernelType>. Optionally, whether or not to run
- * single-tree search or naive (brute-force) search can be specified.
+ * Search for the points in the reference set with maximum kernel evaluation
+ * to each point in the given query set. The resulting kernel evaluations are
+ * stored in the kernels matrix, and the corresponding point indices are
+ * stored in the indices matrix. The results for each point in the query set
+ * are stored in the corresponding column of the kernels and indices
+ * matrices; for instance, the index of the point with maximum kernel
+ * evaluation to point 4 in the query set will be stored in row 0 and column 4
+ * of the indices matrix.
*
- * @param referenceSet Reference set of data for FastMKS.
- * @param referenceTree Tree built on reference data.
- * @param querySet Set of query points for FastMKS.
- * @param queryTree Tree built on query data.
- * @param single Whether or not to use single-tree search.
- * @param naive Whether or not to use naive (brute-force) search.
+ * If querySet only contains a few points, the extra overhead of building a
+ * tree to perform dual-tree search may not be warranted, and it may be faster
+ * to use single-tree search, either by setting singleMode to true in the
+ * constructor or with SingleMode().
+ *
+ * @param querySet Set of query points (can be a single point).
+ * @param k The number of maximum kernels to find.
+ * @param indices Matrix to store resulting indices of max-kernel search in.
+ * @param kernels Matrix to store resulting max-kernel values in.
*/
- FastMKS(const typename TreeType::Mat& referenceSet,
- TreeType* referenceTree,
- const typename TreeType::Mat& querySet,
- TreeType* queryTree,
- const bool single = false,
- const bool naive = false);
+ void Search(const typename TreeType::Mat& querySet,
+ const size_t k,
+ arma::Mat<size_t>& indices,
+ arma::mat& kernels);
- //! Destructor for the FastMKS object.
- ~FastMKS();
+ /**
+ * Search for the points in the reference set with maximum kernel evaluation
+ * to each point in the query set corresponding to the given pre-built query
+ * tree. The resulting kernel evaluations are stored in the kernels matrix,
+ * and the corresponding point indices are stored in the indices matrix. The
+ * results for each point in the query set are stored in the corresponding
+ * column of the kernels and indices matrices; for instance, the index of the
+ * point with maximum kernel evaluation to point 4 in the query set will be
+ * stored in row 0 and column 4 of the indices matrix.
+ *
+ * This throws std::invalid_argument if called while the FastMKS object is in
+ * single-tree mode (SingleMode() is true) or in naive (brute-force) mode; in
+ * those modes, pass the query points as a matrix to Search() instead.
+ *
+ * Be aware that if your tree modifies the original input matrix, the results
+ * here are with respect to the modified input matrix (that is,
+ * queryTree->Dataset()).
+ *
+ * @param queryTree Tree built on query points.
+ * @param k The number of maximum kernels to find.
+ * @param indices Matrix to store resulting indices of max-kernel search in.
+ * @param kernels Matrix to store resulting max-kernel values in.
+ * @throws std::invalid_argument if single-tree or naive search is enabled.
+ */
+ void Search(TreeType* queryTree,
+ const size_t k,
+ arma::Mat<size_t>& indices,
+ arma::mat& kernels);
/**
* Search for the maximum inner products of the query set (or if no query set
@@ -179,6 +175,11 @@ class FastMKS
//! Modify the inner-product metric induced by the given kernel.
metric::IPMetric<KernelType>& Metric() { return metric; }
+ //! Get whether or not single-tree search is used.
+ bool SingleMode() const { return singleMode; }
+ //! Modify whether or not single-tree search is used.
+ bool& SingleMode() { return singleMode; }
+
/**
* Returns a string representation of this object.
*/
@@ -187,20 +188,13 @@ class FastMKS
private:
//! The reference dataset.
const typename TreeType::Mat& referenceSet;
- //! The query dataset.
- const typename TreeType::Mat& querySet;
-
//! The tree built on the reference dataset.
TreeType* referenceTree;
- //! The tree built on the query dataset. This is NULL if there is no query
- //! set.
- TreeType* queryTree;
-
- //! If true, this object created the trees and is responsible for them.
+ //! If true, this object created the tree and is responsible for it.
bool treeOwner;
//! If true, single-tree search is used.
- bool single;
+ bool singleMode;
//! If true, naive (brute-force) search is used.
bool naive;
diff --git a/src/mlpack/methods/fastmks/fastmks_impl.hpp b/src/mlpack/methods/fastmks/fastmks_impl.hpp
index 19f466e..1835d8d 100644
--- a/src/mlpack/methods/fastmks/fastmks_impl.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_impl.hpp
@@ -18,17 +18,15 @@
namespace mlpack {
namespace fastmks {
-// Single dataset, no instantiated kernel.
+// No instantiated kernel.
template<typename KernelType, typename TreeType>
FastMKS<KernelType, TreeType>::FastMKS(const typename TreeType::Mat& referenceSet,
- const bool single,
+ const bool singleMode,
const bool naive) :
referenceSet(referenceSet),
- querySet(referenceSet),
referenceTree(NULL),
- queryTree(NULL),
treeOwner(true),
- single(single),
+ singleMode(singleMode),
naive(naive)
{
Timer::Start("tree_building");
@@ -36,50 +34,19 @@ FastMKS<KernelType, TreeType>::FastMKS(const typename TreeType::Mat& referenceSe
if (!naive)
referenceTree = new TreeType(referenceSet);
- if (!naive && !single)
- queryTree = new TreeType(referenceSet);
-
- Timer::Stop("tree_building");
-}
-
-// Two datasets, no instantiated kernel.
-template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::FastMKS(const typename TreeType::Mat& referenceSet,
- const typename TreeType::Mat& querySet,
- const bool single,
- const bool naive) :
- referenceSet(referenceSet),
- querySet(querySet),
- referenceTree(NULL),
- queryTree(NULL),
- treeOwner(true),
- single(single),
- naive(naive)
-{
- Timer::Start("tree_building");
-
- // If necessary, the trees should be built.
- if (!naive)
- referenceTree = new TreeType(referenceSet);
-
- if (!naive && !single)
- queryTree = new TreeType(querySet);
-
Timer::Stop("tree_building");
}
-// One dataset, instantiated kernel.
+// Instantiated kernel.
template<typename KernelType, typename TreeType>
FastMKS<KernelType, TreeType>::FastMKS(const typename TreeType::Mat& referenceSet,
KernelType& kernel,
- const bool single,
+ const bool singleMode,
const bool naive) :
referenceSet(referenceSet),
- querySet(referenceSet),
referenceTree(NULL),
- queryTree(NULL),
treeOwner(true),
- single(single),
+ singleMode(singleMode),
naive(naive),
metric(kernel)
{
@@ -89,135 +56,173 @@ FastMKS<KernelType, TreeType>::FastMKS(const typename TreeType::Mat& referenceSe
if (!naive)
referenceTree = new TreeType(referenceSet, metric);
- if (!naive && !single)
- queryTree = new TreeType(referenceSet, metric);
-
- Timer::Stop("tree_building");
-}
-
-// Two datasets, instantiated kernel.
-template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::FastMKS(const typename TreeType::Mat& referenceSet,
- const typename TreeType::Mat& querySet,
- KernelType& kernel,
- const bool single,
- const bool naive) :
- referenceSet(referenceSet),
- querySet(querySet),
- referenceTree(NULL),
- queryTree(NULL),
- treeOwner(true),
- single(single),
- naive(naive),
- metric(kernel)
-{
- Timer::Start("tree_building");
-
- // If necessary, the trees should be built.
- if (!naive)
- referenceTree = new TreeType(referenceSet, metric);
-
- if (!naive && !single)
- queryTree = new TreeType(querySet, metric);
-
Timer::Stop("tree_building");
}
// One dataset, pre-built tree.
template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::FastMKS(
- const typename TreeType::Mat& referenceSet,
- TreeType* referenceTree,
- const bool single,
- const bool naive) :
- referenceSet(referenceSet),
- querySet(referenceSet),
+FastMKS<KernelType, TreeType>::FastMKS(TreeType* referenceTree,
+ const bool singleMode) :
+ referenceSet(referenceTree->Dataset()),
referenceTree(referenceTree),
- queryTree(NULL),
treeOwner(false),
- single(single),
- naive(naive),
+ singleMode(singleMode),
+ naive(false),
metric(referenceTree->Metric())
{
- // The query tree cannot be the same as the reference tree.
- if (referenceTree)
- queryTree = new TreeType(*referenceTree);
+ // Nothing to do.
}
-// Two datasets, pre-built trees.
template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::FastMKS(
- const typename TreeType::Mat& referenceSet,
- TreeType* referenceTree,
- const typename TreeType::Mat& querySet,
- TreeType* queryTree,
- const bool single,
- const bool naive) :
- referenceSet(referenceSet),
- querySet(querySet),
- referenceTree(referenceTree),
- queryTree(queryTree),
- treeOwner(false),
- single(single),
- naive(naive),
- metric(referenceTree->Metric())
+FastMKS<KernelType, TreeType>::~FastMKS()
{
- // Nothing to do.
+ // If we created the trees, we must delete them.
+ if (treeOwner && referenceTree)
+ delete referenceTree;
}
template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::~FastMKS()
+void FastMKS<KernelType, TreeType>::Search(
+ const typename TreeType::Mat& querySet,
+ const size_t k,
+ arma::Mat<size_t>& indices,
+ arma::mat& kernels)
{
- // If we created the trees, we must delete them.
- if (treeOwner)
+ Timer::Start("computing_products");
+
+ // No remapping will be necessary because we are using the cover tree.
+ indices.set_size(k, querySet.n_cols);
+ kernels.set_size(k, querySet.n_cols);
+
+ // Naive implementation.
+ if (naive)
+ {
+ // Fill kernels.
+ kernels.fill(-DBL_MAX);
+
+ // Simple double loop. Stupid, slow, but a good benchmark.
+ for (size_t q = 0; q < querySet.n_cols; ++q)
+ {
+ for (size_t r = 0; r < referenceSet.n_cols; ++r)
+ {
+ const double eval = metric.Kernel().Evaluate(querySet.col(q),
+ referenceSet.col(r));
+
+ size_t insertPosition;
+ for (insertPosition = 0; insertPosition < indices.n_rows;
+ ++insertPosition)
+ if (eval > kernels(insertPosition, q))
+ break;
+
+ if (insertPosition < indices.n_rows)
+ InsertNeighbor(indices, kernels, q, insertPosition, r, eval);
+ }
+ }
+
+ Timer::Stop("computing_products");
+
+ return;
+ }
+
+ // Single-tree implementation.
+ if (singleMode)
{
- if (queryTree)
- delete queryTree;
- if (referenceTree)
- delete referenceTree;
+ // Fill kernels.
+ kernels.fill(-DBL_MAX);
+
+ // Create rules object (this will store the results). This constructor
+ // precalculates each self-kernel value.
+ typedef FastMKSRules<KernelType, TreeType> RuleType;
+ RuleType rules(referenceSet, querySet, indices, kernels, metric.Kernel());
+
+ typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
+
+ for (size_t i = 0; i < querySet.n_cols; ++i)
+ traverser.Traverse(i, *referenceTree);
+
+ Log::Info << rules.BaseCases() << " base cases." << std::endl;
+ Log::Info << rules.Scores() << " scores." << std::endl;
+
+ Timer::Stop("computing_products");
+ return;
}
- else if (&querySet == &referenceSet)
+
+ // Dual-tree implementation. First, we need to build the query tree. It is
+ // built with this object's metric (and therefore its kernel parameters), not
+ // a default-constructed one, so that the query tree's distance bounds agree
+ // with metric.Kernel(). We are assuming the tree doesn't map anything...
+ Timer::Stop("computing_products");
+ Timer::Start("tree_building");
+ TreeType queryTree(querySet, metric);
+ Timer::Stop("tree_building");
+
+ Search(&queryTree, k, indices, kernels);
+}
+
+template<typename KernelType, typename TreeType>
+void FastMKS<KernelType, TreeType>::Search(TreeType* queryTree,
+ const size_t k,
+ arma::Mat<size_t>& indices,
+ arma::mat& kernels)
+{
+ // If either naive mode or single mode is specified, this must fail.
+ if (naive || singleMode)
{
- // The user passed in a reference tree which we needed to copy.
- if (queryTree)
- delete queryTree;
+ throw std::invalid_argument("can't call Search() with a query tree when "
+ "single mode or naive search is enabled");
}
+
+ // No remapping will be necessary because we are using the cover tree.
+ indices.set_size(k, queryTree->Dataset().n_cols);
+ kernels.set_size(k, queryTree->Dataset().n_cols);
+ kernels.fill(-DBL_MAX);
+
+ Timer::Start("computing_products");
+ typedef FastMKSRules<KernelType, TreeType> RuleType;
+ RuleType rules(referenceSet, queryTree->Dataset(), indices, kernels,
+ metric.Kernel());
+
+ typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+
+ traverser.Traverse(*queryTree, *referenceTree);
+
+ Log::Info << rules.BaseCases() << " base cases." << std::endl;
+ Log::Info << rules.Scores() << " scores." << std::endl;
+
+ Timer::Stop("computing_products");
}
template<typename KernelType, typename TreeType>
void FastMKS<KernelType, TreeType>::Search(const size_t k,
arma::Mat<size_t>& indices,
- arma::mat& products)
+ arma::mat& kernels)
{
// No remapping will be necessary because we are using the cover tree.
- indices.set_size(k, querySet.n_cols);
- products.set_size(k, querySet.n_cols);
- products.fill(-DBL_MAX);
-
Timer::Start("computing_products");
+ indices.set_size(k, referenceSet.n_cols);
+ kernels.set_size(k, referenceSet.n_cols);
+ kernels.fill(-DBL_MAX);
// Naive implementation.
if (naive)
{
// Simple double loop. Stupid, slow, but a good benchmark.
- for (size_t q = 0; q < querySet.n_cols; ++q)
+ for (size_t q = 0; q < referenceSet.n_cols; ++q)
{
for (size_t r = 0; r < referenceSet.n_cols; ++r)
{
- if ((&querySet == &referenceSet) && (q == r))
- continue;
+ if (q == r)
+ continue; // Don't return the point as its own candidate.
- const double eval = metric.Kernel().Evaluate(querySet.col(q),
+ const double eval = metric.Kernel().Evaluate(referenceSet.col(q),
referenceSet.col(r));
size_t insertPosition;
for (insertPosition = 0; insertPosition < indices.n_rows;
++insertPosition)
- if (eval > products(insertPosition, q))
+ if (eval > kernels(insertPosition, q))
break;
if (insertPosition < indices.n_rows)
- InsertNeighbor(indices, products, q, insertPosition, r, eval);
+ InsertNeighbor(indices, kernels, q, insertPosition, r, eval);
}
}
@@ -227,16 +232,17 @@ void FastMKS<KernelType, TreeType>::Search(const size_t k,
}
// Single-tree implementation.
- if (single)
+ if (singleMode)
{
// Create rules object (this will store the results). This constructor
// precalculates each self-kernel value.
typedef FastMKSRules<KernelType, TreeType> RuleType;
- RuleType rules(referenceSet, querySet, indices, products, metric.Kernel());
+ RuleType rules(referenceSet, referenceSet, indices, kernels,
+ metric.Kernel());
typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
- for (size_t i = 0; i < querySet.n_cols; ++i)
+ for (size_t i = 0; i < referenceSet.n_cols; ++i)
traverser.Traverse(i, *referenceTree);
// Save the number of pruned nodes.
@@ -252,21 +258,9 @@ void FastMKS<KernelType, TreeType>::Search(const size_t k,
}
// Dual-tree implementation.
- typedef FastMKSRules<KernelType, TreeType> RuleType;
- RuleType rules(referenceSet, querySet, indices, products, metric.Kernel());
-
- typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
-
- traverser.Traverse(*queryTree, *referenceTree);
-
- const size_t numPrunes = traverser.NumPrunes();
-
- Log::Info << "Pruned " << numPrunes << " nodes." << std::endl;
- Log::Info << rules.BaseCases() << " base cases." << std::endl;
- Log::Info << rules.Scores() << " scores." << std::endl;
-
Timer::Stop("computing_products");
- return;
+
+ Search(referenceTree, k, indices, kernels);
}
/**
@@ -309,229 +303,13 @@ std::string FastMKS<KernelType, TreeType>::ToString() const
std::ostringstream convert;
convert << "FastMKS [" << this << "]" << std::endl;
convert << " Naive: " << naive << std::endl;
- convert << " Single: " << single << std::endl;
+ convert << " Single: " << singleMode << std::endl;
convert << " Metric: " << std::endl;
convert << mlpack::util::Indent(metric.ToString(),2);
convert << std::endl;
return convert.str();
}
-// Specialized implementation for tighter bounds for Gaussian.
-/*
-template<>
-void FastMKS<kernel::GaussianKernel>::Search(const size_t k,
- arma::Mat<size_t>& indices,
- arma::mat& products)
-{
- Log::Warn << "Alternate implementation!" << std::endl;
-
- // Terrible copypasta. Bad bad bad.
- // No remapping will be necessary.
- indices.set_size(k, querySet.n_cols);
- products.set_size(k, querySet.n_cols);
- products.fill(-1.0);
-
- Timer::Start("computing_products");
-
- size_t kernelEvaluations = 0;
-
- // Naive implementation.
- if (naive)
- {
- // Simple double loop. Stupid, slow, but a good benchmark.
- for (size_t q = 0; q < querySet.n_cols; ++q)
- {
- for (size_t r = 0; r < referenceSet.n_cols; ++r)
- {
- const double eval = metric.Kernel().Evaluate(querySet.unsafe_col(q),
- referenceSet.unsafe_col(r));
- ++kernelEvaluations;
-
- size_t insertPosition;
- for (insertPosition = 0; insertPosition < indices.n_rows;
- ++insertPosition)
- if (eval > products(insertPosition, q))
- break;
-
- if (insertPosition < indices.n_rows)
- InsertNeighbor(indices, products, q, insertPosition, r, eval);
- }
- }
-
- Timer::Stop("computing_products");
-
- Log::Info << "Kernel evaluations: " << kernelEvaluations << "." << std::endl;
- return;
- }
-
- // Single-tree implementation.
- if (single)
- {
- // Calculate number of pruned nodes.
- size_t numPrunes = 0;
-
- // Precalculate query products ( || q || for all q).
- arma::vec queryProducts(querySet.n_cols);
- for (size_t queryIndex = 0; queryIndex < querySet.n_cols; ++queryIndex)
- queryProducts[queryIndex] = sqrt(metric.Kernel().Evaluate(
- querySet.unsafe_col(queryIndex), querySet.unsafe_col(queryIndex)));
- kernelEvaluations += querySet.n_cols;
-
- // Screw the CoverTreeTraverser, we'll implement it by hand.
- for (size_t queryIndex = 0; queryIndex < querySet.n_cols; ++queryIndex)
- {
- // Use an array of priority queues?
- std::priority_queue<
- SearchFrame<tree::CoverTree<IPMetric<kernel::GaussianKernel> > >,
- std::vector<SearchFrame<tree::CoverTree<IPMetric<
- kernel::GaussianKernel> > > >,
- SearchFrameCompare<tree::CoverTree<IPMetric<
- kernel::GaussianKernel> > > >
- frameQueue;
-
- // Add initial frame.
- SearchFrame<tree::CoverTree<IPMetric<kernel::GaussianKernel> > >
- nextFrame;
- nextFrame.node = referenceTree;
- nextFrame.eval = metric.Kernel().Evaluate(querySet.unsafe_col(queryIndex),
- referenceSet.unsafe_col(referenceTree->Point()));
- Log::Assert(nextFrame.eval <= 1);
- ++kernelEvaluations;
-
- // The initial evaluation will be the best so far.
- indices(0, queryIndex) = referenceTree->Point();
- products(0, queryIndex) = nextFrame.eval;
-
- frameQueue.push(nextFrame);
-
- tree::CoverTree<IPMetric<kernel::GaussianKernel> >* referenceNode;
- double eval;
- double maxProduct;
-
- while (!frameQueue.empty())
- {
- // Get the information for this node.
- const SearchFrame<tree::CoverTree<IPMetric<kernel::GaussianKernel> > >&
- frame = frameQueue.top();
-
- referenceNode = frame.node;
- eval = frame.eval;
-
- // Loop through the children, seeing if we can prune them; if not, add
- // them to the queue. The self-child is different -- it has the same
- // parent (and therefore the same kernel evaluation).
- if (referenceNode->NumChildren() > 0)
- {
- SearchFrame<tree::CoverTree<IPMetric<kernel::GaussianKernel> > >
- childFrame;
-
- // We must handle the self-child differently, to avoid adding it to
- // the results twice.
- childFrame.node = &(referenceNode->Child(0));
- childFrame.eval = eval;
-
- // Alternate pruning rule.
- const double mdd = childFrame.node->FurthestDescendantDistance();
- if (eval >= (1 - std::pow(mdd, 2.0) / 2.0))
- maxProduct = 1;
- else
- maxProduct = eval * (1 - std::pow(mdd, 2.0) / 2.0) + sqrt(1 -
- std::pow(eval, 2.0)) * mdd * sqrt(1 - std::pow(mdd, 2.0) / 4.0);
-
- // Add self-child if we can't prune it.
- if (maxProduct > products(products.n_rows - 1, queryIndex))
- {
- // But only if it has children of its own.
- if (childFrame.node->NumChildren() > 0)
- frameQueue.push(childFrame);
- }
- else
- ++numPrunes;
-
- for (size_t i = 1; i < referenceNode->NumChildren(); ++i)
- {
- // Before we evaluate the child, let's see if it can possibly have
- // a better evaluation.
- const double mpdd = std::min(
- referenceNode->Child(i).ParentDistance() +
- referenceNode->Child(i).FurthestDescendantDistance(), 2.0);
- double maxChildEval = 1;
- if (eval < (1 - std::pow(mpdd, 2.0) / 2.0))
- maxChildEval = eval * (1 - std::pow(mpdd, 2.0) / 2.0) + sqrt(1 -
- std::pow(eval, 2.0)) * mpdd * sqrt(1 - std::pow(mpdd, 2.0)
- / 4.0);
-
- if (maxChildEval > products(products.n_rows - 1, queryIndex))
- {
- // Evaluate child.
- childFrame.node = &(referenceNode->Child(i));
- childFrame.eval = metric.Kernel().Evaluate(
- querySet.unsafe_col(queryIndex),
- referenceSet.unsafe_col(referenceNode->Child(i).Point()));
- ++kernelEvaluations;
-
- // Can we prune it? If we can, we can avoid putting it in the
- // queue (saves time).
- const double cmdd = childFrame.node->FurthestDescendantDistance();
- if (childFrame.eval > (1 - std::pow(cmdd, 2.0) / 2.0))
- maxProduct = 1;
- else
- maxProduct = childFrame.eval * (1 - std::pow(cmdd, 2.0) / 2.0)
- + sqrt(1 - std::pow(eval, 2.0)) * cmdd * sqrt(1 -
- std::pow(cmdd, 2.0) / 4.0);
-
- if (maxProduct > products(products.n_rows - 1, queryIndex))
- {
- // Good enough to recurse into. While we're at it, check the
- // actual evaluation and see if it's an improvement.
- if (childFrame.eval > products(products.n_rows - 1, queryIndex))
- {
- // This is a better result. Find out where to insert.
- size_t insertPosition = 0;
- for ( ; insertPosition < products.n_rows - 1;
- ++insertPosition)
- if (childFrame.eval > products(insertPosition, queryIndex))
- break;
-
- // Insert into the correct position; we are guaranteed that
- // insertPosition is valid.
- InsertNeighbor(indices, products, queryIndex, insertPosition,
- childFrame.node->Point(), childFrame.eval);
- }
-
- // Now add this to the queue (if it has any children which may
- // need to be recursed into).
- if (childFrame.node->NumChildren() > 0)
- frameQueue.push(childFrame);
- }
- else
- ++numPrunes;
- }
- else
- ++numPrunes;
- }
- }
-
- frameQueue.pop();
- }
- }
-
- Log::Info << "Pruned " << numPrunes << " nodes." << std::endl;
- Log::Info << "Kernel evaluations: " << kernelEvaluations << "."
- << std::endl;
- Log::Info << "Distance evaluations: " << distanceEvaluations << "."
- << std::endl;
-
- Timer::Stop("computing_products");
- return;
- }
-
- // Double-tree implementation.
- Log::Fatal << "Dual-tree search not implemented yet... oops..." << std::endl;
-
-}
-*/
-
}; // namespace fastmks
}; // namespace mlpack
diff --git a/src/mlpack/methods/fastmks/fastmks_main.cpp b/src/mlpack/methods/fastmks/fastmks_main.cpp
index abb7708..1e0dc94 100644
--- a/src/mlpack/methods/fastmks/fastmks_main.cpp
+++ b/src/mlpack/methods/fastmks/fastmks_main.cpp
@@ -28,12 +28,12 @@ PROGRAM_INFO("FastMKS (Fast Max-Kernel Search)",
"'kernels.csv' and the indices are stored in 'indices.csv'."
"\n\n"
"$ fastmks --k 5 --reference_file reference.csv --query_file query.csv\n"
- " --indices_file indices.csv --products_file kernels.csv --kernel linear"
+ " --indices_file indices.csv --kernels_file kernels.csv --kernel linear"
"\n\n"
"The output files are organized such that row i and column j in the indices"
" output file corresponds to the index of the point in the reference set "
"that has i'th largest kernel evaluation with the point in the query set "
- "with index j. Row i and column j in the products output file corresponds "
+ "with index j. Row i and column j in the kernels output file corresponds "
"to the kernel evaluation between those two points."
"\n\n"
"This executable performs FastMKS using a cover tree. The base used to "
@@ -44,10 +44,10 @@ PARAM_STRING_REQ("reference_file", "File containing the reference dataset.",
"r");
PARAM_STRING("query_file", "File containing the query dataset.", "q", "");
-PARAM_INT_REQ("k", "Number of maximum inner products to find.", "k");
+PARAM_INT_REQ("k", "Number of maximum kernels to find.", "k");
-PARAM_STRING("products_file", "File to save inner products into.", "p", "");
-PARAM_STRING("indices_file", "File to save indices of inner products into.",
+PARAM_STRING("kernels_file", "File to save kernels into.", "p", "");
+PARAM_STRING("indices_file", "File to save indices of kernels into.",
"i", "");
PARAM_STRING("kernel", "Kernel type to use: 'linear', 'polynomial', 'cosine', "
@@ -76,20 +76,29 @@ void RunFastMKS(const arma::mat& referenceData,
const double base,
const size_t k,
arma::Mat<size_t>& indices,
- arma::mat& products,
+ arma::mat& kernels,
KernelType& kernel)
{
- // Create the tree with the specified base.
- typedef CoverTree<IPMetric<KernelType>, FirstPointIsRoot, FastMKSStat>
- TreeType;
- IPMetric<KernelType> metric(kernel);
- TreeType tree(referenceData, metric, base);
+ if (naive)
+ {
+ // No need for trees.
+ FastMKS<KernelType> fastmks(referenceData, kernel, false, naive);
+ fastmks.Search(k, indices, kernels);
+ }
+ else
+ {
+ // Create the tree with the specified base.
+ typedef CoverTree<IPMetric<KernelType>, FirstPointIsRoot, FastMKSStat>
+ TreeType;
+ IPMetric<KernelType> metric(kernel);
+ TreeType tree(referenceData, metric, base);
- // Create FastMKS object.
- FastMKS<KernelType> fastmks(referenceData, &tree, (single && !naive), naive);
+ // Create FastMKS object.
+ FastMKS<KernelType> fastmks(&tree, single);
- // Now search with it.
- fastmks.Search(k, indices, products);
+ // Now search with it.
+ fastmks.Search(k, indices, kernels);
+ }
}
//! Run FastMKS for a given query and reference set using the given kernel type.
@@ -101,22 +110,30 @@ void RunFastMKS(const arma::mat& referenceData,
const double base,
const size_t k,
arma::Mat<size_t>& indices,
- arma::mat& products,
+ arma::mat& kernels,
KernelType& kernel)
{
- // Create the tree with the specified base.
- typedef CoverTree<IPMetric<KernelType>, FirstPointIsRoot, FastMKSStat>
- TreeType;
- IPMetric<KernelType> metric(kernel);
- TreeType referenceTree(referenceData, metric, base);
- TreeType queryTree(queryData, metric, base);
-
- // Create FastMKS object.
- FastMKS<KernelType> fastmks(referenceData, &referenceTree, queryData,
- &queryTree, (single && !naive), naive);
-
- // Now search with it.
- fastmks.Search(k, indices, products);
+ if (naive)
+ {
+ // No need for trees.
+ FastMKS<KernelType> fastmks(referenceData, kernel, false, naive);
+ fastmks.Search(queryData, k, indices, kernels);
+ }
+ else
+ {
+ // Create the reference tree with the specified base.
+ typedef CoverTree<IPMetric<KernelType>, FirstPointIsRoot, FastMKSStat>
+ TreeType;
+ IPMetric<KernelType> metric(kernel);
+ TreeType referenceTree(referenceData, metric, base);
+
+ // Create FastMKS object.
+ FastMKS<KernelType> fastmks(&referenceTree, single);
+
+ // Now search with it. FastMKS::Search() throws std::invalid_argument when
+ // given a query tree in single-tree mode, so in that mode pass the query
+ // points directly; otherwise build the query tree with the same base.
+ if (single)
+ {
+ fastmks.Search(queryData, k, indices, kernels);
+ }
+ else
+ {
+ TreeType queryTree(queryData, metric, base);
+ fastmks.Search(&queryTree, k, indices, kernels);
+ }
+ }
}
int main(int argc, char** argv)
@@ -194,7 +211,7 @@ int main(int argc, char** argv)
// Matrices for output storage.
arma::Mat<size_t> indices;
- arma::mat products;
+ arma::mat kernels;
// Construct FastMKS object.
if (queryData.n_elem == 0)
@@ -203,44 +220,44 @@ int main(int argc, char** argv)
{
LinearKernel lk;
RunFastMKS<LinearKernel>(referenceData, single, naive, base, k, indices,
- products, lk);
+ kernels, lk);
}
else if (kernelType == "polynomial")
{
PolynomialKernel pk(degree, offset);
RunFastMKS<PolynomialKernel>(referenceData, single, naive, base, k,
- indices, products, pk);
+ indices, kernels, pk);
}
else if (kernelType == "cosine")
{
CosineDistance cd;
RunFastMKS<CosineDistance>(referenceData, single, naive, base, k, indices,
- products, cd);
+ kernels, cd);
}
else if (kernelType == "gaussian")
{
GaussianKernel gk(bandwidth);
RunFastMKS<GaussianKernel>(referenceData, single, naive, base, k, indices,
- products, gk);
+ kernels, gk);
}
else if (kernelType == "epanechnikov")
{
EpanechnikovKernel ek(bandwidth);
RunFastMKS<EpanechnikovKernel>(referenceData, single, naive, base, k,
- indices, products, ek);
+ indices, kernels, ek);
}
else if (kernelType == "triangular")
{
TriangularKernel tk(bandwidth);
RunFastMKS<TriangularKernel>(referenceData, single, naive, base, k,
- indices, products, tk);
+ indices, kernels, tk);
}
else if (kernelType == "hyptan")
{
HyperbolicTangentKernel htk(scale, offset);
RunFastMKS<HyperbolicTangentKernel>(referenceData, single, naive, base, k,
- indices, products, htk);
+ indices, kernels, htk);
}
}
else
@@ -249,51 +266,51 @@ int main(int argc, char** argv)
{
LinearKernel lk;
RunFastMKS<LinearKernel>(referenceData, queryData, single, naive, base, k,
- indices, products, lk);
+ indices, kernels, lk);
}
else if (kernelType == "polynomial")
{
PolynomialKernel pk(degree, offset);
RunFastMKS<PolynomialKernel>(referenceData, queryData, single, naive,
- base, k, indices, products, pk);
+ base, k, indices, kernels, pk);
}
else if (kernelType == "cosine")
{
CosineDistance cd;
RunFastMKS<CosineDistance>(referenceData, queryData, single, naive, base,
- k, indices, products, cd);
+ k, indices, kernels, cd);
}
else if (kernelType == "gaussian")
{
GaussianKernel gk(bandwidth);
RunFastMKS<GaussianKernel>(referenceData, queryData, single, naive, base,
- k, indices, products, gk);
+ k, indices, kernels, gk);
}
else if (kernelType == "epanechnikov")
{
EpanechnikovKernel ek(bandwidth);
RunFastMKS<EpanechnikovKernel>(referenceData, queryData, single, naive,
- base, k, indices, products, ek);
+ base, k, indices, kernels, ek);
}
else if (kernelType == "triangular")
{
TriangularKernel tk(bandwidth);
RunFastMKS<TriangularKernel>(referenceData, queryData, single, naive,
- base, k, indices, products, tk);
+ base, k, indices, kernels, tk);
}
else if (kernelType == "hyptan")
{
HyperbolicTangentKernel htk(scale, offset);
RunFastMKS<HyperbolicTangentKernel>(referenceData, queryData, single,
- naive, base, k, indices, products, htk);
+ naive, base, k, indices, kernels, htk);
}
}
// Save output, if we were asked to.
- if (CLI::HasParam("products_file"))
+ if (CLI::HasParam("kernels_file"))
{
- const string productsFile = CLI::GetParam<string>("products_file");
- data::Save(productsFile, products, false);
+ const string kernelsFile = CLI::GetParam<string>("kernels_file");
+ data::Save(kernelsFile, kernels, false);
}
if (CLI::HasParam("indices_file"))
More information about the mlpack-git
mailing list