[mlpack-git] master: Refactor FastMKS to allow sparse datasets. (a8637c1)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Sun Apr 5 18:48:19 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/c41bc8d54695f1a20b3de551bc80fe7221dd3cd1...81fe638aa0c8d9592bbf7b434110140eb9bb86e7
>---------------------------------------------------------------
commit a8637c19dd652188fb9149696ca0c94f80c8f209
Author: ryan <ryan at ratml.org>
Date: Sun Apr 5 18:47:46 2015 -0400
Refactor FastMKS to allow sparse datasets.
>---------------------------------------------------------------
a8637c19dd652188fb9149696ca0c94f80c8f209
src/mlpack/methods/fastmks/fastmks.hpp | 22 ++++++-------
src/mlpack/methods/fastmks/fastmks_impl.hpp | 38 ++++++++++++-----------
src/mlpack/methods/fastmks/fastmks_rules.hpp | 8 ++---
src/mlpack/methods/fastmks/fastmks_rules_impl.hpp | 26 ++++++++--------
src/mlpack/methods/fastmks/fastmks_stat.hpp | 4 +--
5 files changed, 50 insertions(+), 48 deletions(-)
diff --git a/src/mlpack/methods/fastmks/fastmks.hpp b/src/mlpack/methods/fastmks/fastmks.hpp
index 916290f..90ca995 100644
--- a/src/mlpack/methods/fastmks/fastmks.hpp
+++ b/src/mlpack/methods/fastmks/fastmks.hpp
@@ -63,7 +63,7 @@ class FastMKS
* @param single Whether or not to run single-tree search.
* @param naive Whether or not to run brute-force (naive) search.
*/
- FastMKS(const arma::mat& referenceSet,
+ FastMKS(const typename TreeType::Mat& referenceSet,
const bool single = false,
const bool naive = false);
@@ -77,8 +77,8 @@ class FastMKS
* @param single Whether or not to run single-tree search.
* @param naive Whether or not to run brute-force (naive) search.
*/
- FastMKS(const arma::mat& referenceSet,
- const arma::mat& querySet,
+ FastMKS(const typename TreeType::Mat& referenceSet,
+ const typename TreeType::Mat& querySet,
const bool single = false,
const bool naive = false);
@@ -93,7 +93,7 @@ class FastMKS
* @param single Whether or not to run single-tree search.
* @param naive Whether or not to run brute-force (naive) search.
*/
- FastMKS(const arma::mat& referenceSet,
+ FastMKS(const typename TreeType::Mat& referenceSet,
KernelType& kernel,
const bool single = false,
const bool naive = false);
@@ -110,8 +110,8 @@ class FastMKS
* @param single Whether or not to run single-tree search.
* @param naive Whether or not to run brute-force (naive) search.
*/
- FastMKS(const arma::mat& referenceSet,
- const arma::mat& querySet,
+ FastMKS(const typename TreeType::Mat& referenceSet,
+ const typename TreeType::Mat& querySet,
KernelType& kernel,
const bool single = false,
const bool naive = false);
@@ -128,7 +128,7 @@ class FastMKS
* @param single Whether or not to run single-tree search.
* @param naive Whether or not to run brute-force (naive) search.
*/
- FastMKS(const arma::mat& referenceSet,
+ FastMKS(const typename TreeType::Mat& referenceSet,
TreeType* referenceTree,
const bool single = false,
const bool naive = false);
@@ -146,9 +146,9 @@ class FastMKS
* @param single Whether or not to use single-tree search.
* @param naive Whether or not to use naive (brute-force) search.
*/
- FastMKS(const arma::mat& referenceSet,
+ FastMKS(const typename TreeType::Mat& referenceSet,
TreeType* referenceTree,
- const arma::mat& querySet,
+ const typename TreeType::Mat& querySet,
TreeType* queryTree,
const bool single = false,
const bool naive = false);
@@ -186,9 +186,9 @@ class FastMKS
private:
//! The reference dataset.
- const arma::mat& referenceSet;
+ const typename TreeType::Mat& referenceSet;
//! The query dataset.
- const arma::mat& querySet;
+ const typename TreeType::Mat& querySet;
//! The tree built on the reference dataset.
TreeType* referenceTree;
diff --git a/src/mlpack/methods/fastmks/fastmks_impl.hpp b/src/mlpack/methods/fastmks/fastmks_impl.hpp
index e044795..19f466e 100644
--- a/src/mlpack/methods/fastmks/fastmks_impl.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_impl.hpp
@@ -20,7 +20,7 @@ namespace fastmks {
// Single dataset, no instantiated kernel.
template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
+FastMKS<KernelType, TreeType>::FastMKS(const typename TreeType::Mat& referenceSet,
const bool single,
const bool naive) :
referenceSet(referenceSet),
@@ -44,8 +44,8 @@ FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
// Two datasets, no instantiated kernel.
template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
- const arma::mat& querySet,
+FastMKS<KernelType, TreeType>::FastMKS(const typename TreeType::Mat& referenceSet,
+ const typename TreeType::Mat& querySet,
const bool single,
const bool naive) :
referenceSet(referenceSet),
@@ -70,7 +70,7 @@ FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
// One dataset, instantiated kernel.
template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
+FastMKS<KernelType, TreeType>::FastMKS(const typename TreeType::Mat& referenceSet,
KernelType& kernel,
const bool single,
const bool naive) :
@@ -97,8 +97,8 @@ FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
// Two datasets, instantiated kernel.
template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
- const arma::mat& querySet,
+FastMKS<KernelType, TreeType>::FastMKS(const typename TreeType::Mat& referenceSet,
+ const typename TreeType::Mat& querySet,
KernelType& kernel,
const bool single,
const bool naive) :
@@ -125,10 +125,11 @@ FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
// One dataset, pre-built tree.
template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
- TreeType* referenceTree,
- const bool single,
- const bool naive) :
+FastMKS<KernelType, TreeType>::FastMKS(
+ const typename TreeType::Mat& referenceSet,
+ TreeType* referenceTree,
+ const bool single,
+ const bool naive) :
referenceSet(referenceSet),
querySet(referenceSet),
referenceTree(referenceTree),
@@ -145,12 +146,13 @@ FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
// Two datasets, pre-built trees.
template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
- TreeType* referenceTree,
- const arma::mat& querySet,
- TreeType* queryTree,
- const bool single,
- const bool naive) :
+FastMKS<KernelType, TreeType>::FastMKS(
+ const typename TreeType::Mat& referenceSet,
+ TreeType* referenceTree,
+ const typename TreeType::Mat& querySet,
+ TreeType* queryTree,
+ const bool single,
+ const bool naive) :
referenceSet(referenceSet),
querySet(querySet),
referenceTree(referenceTree),
@@ -205,8 +207,8 @@ void FastMKS<KernelType, TreeType>::Search(const size_t k,
if ((&querySet == &referenceSet) && (q == r))
continue;
- const double eval = metric.Kernel().Evaluate(querySet.unsafe_col(q),
- referenceSet.unsafe_col(r));
+ const double eval = metric.Kernel().Evaluate(querySet.col(q),
+ referenceSet.col(r));
size_t insertPosition;
for (insertPosition = 0; insertPosition < indices.n_rows;
diff --git a/src/mlpack/methods/fastmks/fastmks_rules.hpp b/src/mlpack/methods/fastmks/fastmks_rules.hpp
index 659e448..0e264bd 100644
--- a/src/mlpack/methods/fastmks/fastmks_rules.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_rules.hpp
@@ -22,8 +22,8 @@ template<typename KernelType, typename TreeType>
class FastMKSRules
{
public:
- FastMKSRules(const arma::mat& referenceSet,
- const arma::mat& querySet,
+ FastMKSRules(const typename TreeType::Mat& referenceSet,
+ const typename TreeType::Mat& querySet,
arma::Mat<size_t>& indices,
arma::mat& products,
KernelType& kernel);
@@ -98,9 +98,9 @@ class FastMKSRules
private:
//! The reference dataset.
- const arma::mat& referenceSet;
+ const typename TreeType::Mat& referenceSet;
//! The query dataset.
- const arma::mat& querySet;
+ const typename TreeType::Mat& querySet;
//! The indices of the maximum kernel results.
arma::Mat<size_t>& indices;
diff --git a/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp b/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp
index 96d4fe7..34290a6 100644
--- a/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp
@@ -14,11 +14,12 @@ namespace mlpack {
namespace fastmks {
template<typename KernelType, typename TreeType>
-FastMKSRules<KernelType, TreeType>::FastMKSRules(const arma::mat& referenceSet,
- const arma::mat& querySet,
- arma::Mat<size_t>& indices,
- arma::mat& products,
- KernelType& kernel) :
+FastMKSRules<KernelType, TreeType>::FastMKSRules(
+ const typename TreeType::Mat& referenceSet,
+ const typename TreeType::Mat& querySet,
+ arma::Mat<size_t>& indices,
+ arma::mat& products,
+ KernelType& kernel) :
referenceSet(referenceSet),
querySet(querySet),
indices(indices),
@@ -33,13 +34,13 @@ FastMKSRules<KernelType, TreeType>::FastMKSRules(const arma::mat& referenceSet,
// Precompute each self-kernel.
queryKernels.set_size(querySet.n_cols);
for (size_t i = 0; i < querySet.n_cols; ++i)
- queryKernels[i] = sqrt(kernel.Evaluate(querySet.unsafe_col(i),
- querySet.unsafe_col(i)));
+ queryKernels[i] = sqrt(kernel.Evaluate(querySet.col(i),
+ querySet.col(i)));
referenceKernels.set_size(referenceSet.n_cols);
for (size_t i = 0; i < referenceSet.n_cols; ++i)
- referenceKernels[i] = sqrt(kernel.Evaluate(referenceSet.unsafe_col(i),
- referenceSet.unsafe_col(i)));
+ referenceKernels[i] = sqrt(kernel.Evaluate(referenceSet.col(i),
+ referenceSet.col(i)));
// Set to invalid memory, so that the first node combination does not try to
// dereference null pointers.
@@ -69,8 +70,8 @@ double FastMKSRules<KernelType, TreeType>::BaseCase(
}
++baseCases;
- double kernelEval = kernel.Evaluate(querySet.unsafe_col(queryIndex),
- referenceSet.unsafe_col(referenceIndex));
+ double kernelEval = kernel.Evaluate(querySet.col(queryIndex),
+ referenceSet.col(referenceIndex));
// Update the last kernel value, if we need to.
if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
@@ -156,11 +157,10 @@ double FastMKSRules<KernelType, TreeType>::Score(const size_t queryIndex,
}
else
{
- const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
arma::vec refCentroid;
referenceNode.Centroid(refCentroid);
- kernelEval = kernel.Evaluate(queryPoint, refCentroid);
+ kernelEval = kernel.Evaluate(querySet.col(queryIndex), refCentroid);
}
referenceNode.Stat().LastKernel() = kernelEval;
diff --git a/src/mlpack/methods/fastmks/fastmks_stat.hpp b/src/mlpack/methods/fastmks/fastmks_stat.hpp
index f6357ae..9cbef10 100644
--- a/src/mlpack/methods/fastmks/fastmks_stat.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_stat.hpp
@@ -58,8 +58,8 @@ class FastMKSStat
else
{
selfKernel = sqrt(node.Metric().Kernel().Evaluate(
- node.Dataset().unsafe_col(node.Point(0)),
- node.Dataset().unsafe_col(node.Point(0))));
+ node.Dataset().col(node.Point(0)),
+ node.Dataset().col(node.Point(0))));
}
}
else
More information about the mlpack-git mailing list