[mlpack-git] master: Refactor FastMKS to new TreeType API. (f7c1693)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Jul 29 16:41:42 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/f8ceffae0613b350f4d6bdd46c6c8633a40b4897...6ee21879488fe98612a4619b17f8b51e8da5215b
>---------------------------------------------------------------
commit f7c1693c7d6b3eeb8e55d51f80cad99d487d9311
Author: ryan <ryan at ratml.org>
Date: Fri Jul 24 14:26:48 2015 -0400
Refactor FastMKS to new TreeType API.
>---------------------------------------------------------------
f7c1693c7d6b3eeb8e55d51f80cad99d487d9311
src/mlpack/methods/fastmks/fastmks.hpp | 26 +++---
src/mlpack/methods/fastmks/fastmks_impl.hpp | 119 ++++++++++++++++++----------
src/mlpack/methods/fastmks/fastmks_main.cpp | 13 +--
3 files changed, 97 insertions(+), 61 deletions(-)
diff --git a/src/mlpack/methods/fastmks/fastmks.hpp b/src/mlpack/methods/fastmks/fastmks.hpp
index 72ea34b..7b54835 100644
--- a/src/mlpack/methods/fastmks/fastmks.hpp
+++ b/src/mlpack/methods/fastmks/fastmks.hpp
@@ -48,12 +48,16 @@ namespace fastmks /** Fast max-kernel search. */ {
*/
template<
typename KernelType,
- typename TreeType = tree::CoverTree<metric::IPMetric<KernelType>,
- tree::FirstPointIsRoot, FastMKSStat>
+ typename MatType = arma::mat,
+ template<typename TreeMetricType, typename TreeStatType,
+ typename TreeMatType> class TreeType = tree::StandardCoverTree
>
class FastMKS
{
public:
+ //! Tree type used for search (built with the inner-product metric).
+ typedef TreeType<metric::IPMetric<KernelType>, FastMKSStat, MatType> Tree;
+
/**
* Create the FastMKS object with the given reference set (this is the set
* that is searched). Optionally, specify whether or not single-tree search
@@ -63,7 +67,7 @@ class FastMKS
* @param single Whether or not to run single-tree search.
* @param naive Whether or not to run brute-force (naive) search.
*/
- FastMKS(const typename TreeType::Mat& referenceSet,
+ FastMKS(const MatType& referenceSet,
const bool singleMode = false,
const bool naive = false);
@@ -78,7 +82,7 @@ class FastMKS
* @param single Whether or not to run single-tree search.
* @param naive Whether or not to run brute-force (naive) search.
*/
- FastMKS(const typename TreeType::Mat& referenceSet,
+ FastMKS(const MatType& referenceSet,
KernelType& kernel,
const bool singleMode = false,
const bool naive = false);
@@ -94,7 +98,7 @@ class FastMKS
* @param single Whether or not to run single-tree search.
* @param naive Whether or not to run brute-force (naive) search.
*/
- FastMKS(TreeType* referenceTree,
+ FastMKS(Tree* referenceTree,
const bool singleMode = false);
//! Destructor for the FastMKS object.
@@ -120,7 +124,7 @@ class FastMKS
* @param indices Matrix to store resulting indices of max-kernel search in.
* @param kernels Matrix to store resulting max-kernel values in.
*/
- void Search(const typename TreeType::Mat& querySet,
+ void Search(const MatType& querySet,
const size_t k,
arma::Mat<size_t>& indices,
arma::mat& kernels);
@@ -147,7 +151,7 @@ class FastMKS
* @param indices Matrix to store resulting indices of max-kernel search in.
* @param kernels Matrix to store resulting max-kernel values in.
*/
- void Search(TreeType* querySet,
+ void Search(Tree* querySet,
const size_t k,
arma::Mat<size_t>& indices,
arma::mat& kernels);
@@ -187,9 +191,9 @@ class FastMKS
private:
//! The reference dataset.
- const typename TreeType::Mat& referenceSet;
+ const MatType& referenceSet;
//! The tree built on the reference dataset.
- TreeType* referenceTree;
+ Tree* referenceTree;
//! If true, this object created the tree and is responsible for it.
bool treeOwner;
@@ -210,8 +214,8 @@ class FastMKS
const double distance);
};
-}; // namespace fastmks
-}; // namespace mlpack
+} // namespace fastmks
+} // namespace mlpack
// Include implementation.
#include "fastmks_impl.hpp"
diff --git a/src/mlpack/methods/fastmks/fastmks_impl.hpp b/src/mlpack/methods/fastmks/fastmks_impl.hpp
index 1835d8d..ead5d2e 100644
--- a/src/mlpack/methods/fastmks/fastmks_impl.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_impl.hpp
@@ -19,10 +19,14 @@ namespace mlpack {
namespace fastmks {
// No instantiated kernel.
-template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::FastMKS(const typename TreeType::Mat& referenceSet,
- const bool singleMode,
- const bool naive) :
+template<typename KernelType,
+ typename MatType,
+ template<typename TreeMetricType, typename TreeStatType,
+ typename TreeMatType> class TreeType>
+FastMKS<KernelType, MatType, TreeType>::FastMKS(
+ const MatType& referenceSet,
+ const bool singleMode,
+ const bool naive) :
referenceSet(referenceSet),
referenceTree(NULL),
treeOwner(true),
@@ -32,17 +36,20 @@ FastMKS<KernelType, TreeType>::FastMKS(const typename TreeType::Mat& referenceSe
Timer::Start("tree_building");
if (!naive)
- referenceTree = new TreeType(referenceSet);
+ referenceTree = new Tree(referenceSet);
Timer::Stop("tree_building");
}
// Instantiated kernel.
-template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::FastMKS(const typename TreeType::Mat& referenceSet,
- KernelType& kernel,
- const bool singleMode,
- const bool naive) :
+template<typename KernelType,
+ typename MatType,
+ template<typename TreeMetricType, typename TreeStatType,
+ typename TreeMatType> class TreeType>
+FastMKS<KernelType, MatType, TreeType>::FastMKS(const MatType& referenceSet,
+ KernelType& kernel,
+ const bool singleMode,
+ const bool naive) :
referenceSet(referenceSet),
referenceTree(NULL),
treeOwner(true),
@@ -54,15 +61,18 @@ FastMKS<KernelType, TreeType>::FastMKS(const typename TreeType::Mat& referenceSe
// If necessary, the reference tree should be built. There is no query tree.
if (!naive)
- referenceTree = new TreeType(referenceSet, metric);
+ referenceTree = new Tree(referenceSet, metric);
Timer::Stop("tree_building");
}
// One dataset, pre-built tree.
-template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::FastMKS(TreeType* referenceTree,
- const bool singleMode) :
+template<typename KernelType,
+ typename MatType,
+ template<typename TreeMetricType, typename TreeStatType,
+ typename TreeMatType> class TreeType>
+FastMKS<KernelType, MatType, TreeType>::FastMKS(Tree* referenceTree,
+ const bool singleMode) :
referenceSet(referenceTree->Dataset()),
referenceTree(referenceTree),
treeOwner(false),
@@ -73,17 +83,23 @@ FastMKS<KernelType, TreeType>::FastMKS(TreeType* referenceTree,
// Nothing to do.
}
-template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::~FastMKS()
+template<typename KernelType,
+ typename MatType,
+ template<typename TreeMetricType, typename TreeStatType,
+ typename TreeMatType> class TreeType>
+FastMKS<KernelType, MatType, TreeType>::~FastMKS()
{
// If we created the trees, we must delete them.
if (treeOwner && referenceTree)
delete referenceTree;
}
-template<typename KernelType, typename TreeType>
-void FastMKS<KernelType, TreeType>::Search(
- const typename TreeType::Mat& querySet,
+template<typename KernelType,
+ typename MatType,
+ template<typename TreeMetricType, typename TreeStatType,
+ typename TreeMatType> class TreeType>
+void FastMKS<KernelType, MatType, TreeType>::Search(
+ const MatType& querySet,
const size_t k,
arma::Mat<size_t>& indices,
arma::mat& kernels)
@@ -132,10 +148,10 @@ void FastMKS<KernelType, TreeType>::Search(
// Create rules object (this will store the results). This constructor
// precalculates each self-kernel value.
- typedef FastMKSRules<KernelType, TreeType> RuleType;
+ typedef FastMKSRules<KernelType, Tree> RuleType;
RuleType rules(referenceSet, querySet, indices, kernels, metric.Kernel());
- typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
+ typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
for (size_t i = 0; i < querySet.n_cols; ++i)
traverser.Traverse(i, *referenceTree);
@@ -151,17 +167,21 @@ void FastMKS<KernelType, TreeType>::Search(
// assuming it doesn't map anything...
Timer::Stop("computing_products");
Timer::Start("tree_building");
- TreeType queryTree(querySet);
+ Tree queryTree(querySet, metric);
Timer::Stop("tree_building");
Search(&queryTree, k, indices, kernels);
}
-template<typename KernelType, typename TreeType>
-void FastMKS<KernelType, TreeType>::Search(TreeType* queryTree,
- const size_t k,
- arma::Mat<size_t>& indices,
- arma::mat& kernels)
+template<typename KernelType,
+ typename MatType,
+ template<typename TreeMetricType, typename TreeStatType,
+ typename TreeMatType> class TreeType>
+void FastMKS<KernelType, MatType, TreeType>::Search(
+ Tree* queryTree,
+ const size_t k,
+ arma::Mat<size_t>& indices,
+ arma::mat& kernels)
{
// If either naive mode or single mode is specified, this must fail.
if (naive || singleMode)
@@ -176,11 +196,11 @@ void FastMKS<KernelType, TreeType>::Search(TreeType* queryTree,
kernels.fill(-DBL_MAX);
Timer::Start("computing_products");
- typedef FastMKSRules<KernelType, TreeType> RuleType;
+ typedef FastMKSRules<KernelType, Tree> RuleType;
RuleType rules(referenceSet, queryTree->Dataset(), indices, kernels,
metric.Kernel());
- typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+ typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
traverser.Traverse(*queryTree, *referenceTree);
@@ -190,10 +210,14 @@ void FastMKS<KernelType, TreeType>::Search(TreeType* queryTree,
Timer::Stop("computing_products");
}
-template<typename KernelType, typename TreeType>
-void FastMKS<KernelType, TreeType>::Search(const size_t k,
- arma::Mat<size_t>& indices,
- arma::mat& kernels)
+template<typename KernelType,
+ typename MatType,
+ template<typename TreeMetricType, typename TreeStatType,
+ typename TreeMatType> class TreeType>
+void FastMKS<KernelType, MatType, TreeType>::Search(
+ const size_t k,
+ arma::Mat<size_t>& indices,
+ arma::mat& kernels)
{
// No remapping will be necessary because we are using the cover tree.
Timer::Start("computing_products");
@@ -236,11 +260,11 @@ void FastMKS<KernelType, TreeType>::Search(const size_t k,
{
// Create rules object (this will store the results). This constructor
// precalculates each self-kernel value.
- typedef FastMKSRules<KernelType, TreeType> RuleType;
+ typedef FastMKSRules<KernelType, Tree> RuleType;
RuleType rules(referenceSet, referenceSet, indices, kernels,
metric.Kernel());
- typename TreeType::template SingleTreeTraverser<RuleType> traverser(rules);
+ typename Tree::template SingleTreeTraverser<RuleType> traverser(rules);
for (size_t i = 0; i < referenceSet.n_cols; ++i)
traverser.Traverse(i, *referenceTree);
@@ -271,13 +295,17 @@ void FastMKS<KernelType, TreeType>::Search(const size_t k,
* @param neighbor Index of reference point which is being inserted.
* @param distance Distance from query point to reference point.
*/
-template<typename KernelType, typename TreeType>
-void FastMKS<KernelType, TreeType>::InsertNeighbor(arma::Mat<size_t>& indices,
- arma::mat& products,
- const size_t queryIndex,
- const size_t pos,
- const size_t neighbor,
- const double distance)
+template<typename KernelType,
+ typename MatType,
+ template<typename TreeMetricType, typename TreeStatType,
+ typename TreeMatType> class TreeType>
+void FastMKS<KernelType, MatType, TreeType>::InsertNeighbor(
+ arma::Mat<size_t>& indices,
+ arma::mat& products,
+ const size_t queryIndex,
+ const size_t pos,
+ const size_t neighbor,
+ const double distance)
{
// We only memmove() if there is actually a need to shift something.
if (pos < (products.n_rows - 1))
@@ -297,8 +325,11 @@ void FastMKS<KernelType, TreeType>::InsertNeighbor(arma::Mat<size_t>& indices,
}
// Return string of object.
-template<typename KernelType, typename TreeType>
-std::string FastMKS<KernelType, TreeType>::ToString() const
+template<typename KernelType,
+ typename MatType,
+ template<typename TreeMetricType, typename TreeStatType,
+ typename TreeMatType> class TreeType>
+std::string FastMKS<KernelType, MatType, TreeType>::ToString() const
{
std::ostringstream convert;
convert << "FastMKS [" << this << "]" << std::endl;
diff --git a/src/mlpack/methods/fastmks/fastmks_main.cpp b/src/mlpack/methods/fastmks/fastmks_main.cpp
index 1e0dc94..dc1cdf8 100644
--- a/src/mlpack/methods/fastmks/fastmks_main.cpp
+++ b/src/mlpack/methods/fastmks/fastmks_main.cpp
@@ -88,13 +88,13 @@ void RunFastMKS(const arma::mat& referenceData,
else
{
// Create the tree with the specified base.
- typedef CoverTree<IPMetric<KernelType>, FirstPointIsRoot, FastMKSStat>
- TreeType;
+ typedef typename FastMKS<KernelType, arma::mat,
+ StandardCoverTree>::Tree TreeType;
IPMetric<KernelType> metric(kernel);
TreeType tree(referenceData, metric, base);
// Create FastMKS object.
- FastMKS<KernelType> fastmks(&tree, single);
+ FastMKS<KernelType, arma::mat, StandardCoverTree> fastmks(&tree, single);
// Now search with it.
fastmks.Search(k, indices, kernels);
@@ -122,14 +122,15 @@ void RunFastMKS(const arma::mat& referenceData,
else
{
// Create the tree with the specified base.
- typedef CoverTree<IPMetric<KernelType>, FirstPointIsRoot, FastMKSStat>
- TreeType;
+ typedef typename FastMKS<KernelType, arma::mat,
+ StandardCoverTree>::Tree TreeType;
IPMetric<KernelType> metric(kernel);
TreeType referenceTree(referenceData, metric, base);
TreeType queryTree(queryData, metric, base);
// Create FastMKS object.
- FastMKS<KernelType> fastmks(&referenceTree, single);
+ FastMKS<KernelType, arma::mat, StandardCoverTree> fastmks(&referenceTree,
+ single);
// Now search with it.
fastmks.Search(&queryTree, k, indices, kernels);
More information about the mlpack-git mailing list