[mlpack-git] master: Refactor FastMKS to allow sparse datasets. (a8637c1)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Sun Apr 5 18:48:19 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/c41bc8d54695f1a20b3de551bc80fe7221dd3cd1...81fe638aa0c8d9592bbf7b434110140eb9bb86e7

>---------------------------------------------------------------

commit a8637c19dd652188fb9149696ca0c94f80c8f209
Author: ryan <ryan at ratml.org>
Date:   Sun Apr 5 18:47:46 2015 -0400

    Refactor FastMKS to allow sparse datasets.


>---------------------------------------------------------------

a8637c19dd652188fb9149696ca0c94f80c8f209
 src/mlpack/methods/fastmks/fastmks.hpp            | 22 ++++++-------
 src/mlpack/methods/fastmks/fastmks_impl.hpp       | 38 ++++++++++++-----------
 src/mlpack/methods/fastmks/fastmks_rules.hpp      |  8 ++---
 src/mlpack/methods/fastmks/fastmks_rules_impl.hpp | 26 ++++++++--------
 src/mlpack/methods/fastmks/fastmks_stat.hpp       |  4 +--
 5 files changed, 50 insertions(+), 48 deletions(-)

diff --git a/src/mlpack/methods/fastmks/fastmks.hpp b/src/mlpack/methods/fastmks/fastmks.hpp
index 916290f..90ca995 100644
--- a/src/mlpack/methods/fastmks/fastmks.hpp
+++ b/src/mlpack/methods/fastmks/fastmks.hpp
@@ -63,7 +63,7 @@ class FastMKS
    * @param single Whether or not to run single-tree search.
    * @param naive Whether or not to run brute-force (naive) search.
    */
-  FastMKS(const arma::mat& referenceSet,
+  FastMKS(const typename TreeType::Mat& referenceSet,
           const bool single = false,
           const bool naive = false);
 
@@ -77,8 +77,8 @@ class FastMKS
    * @param single Whether or not to run single-tree search.
    * @param naive Whether or not to run brute-force (naive) search.
    */
-  FastMKS(const arma::mat& referenceSet,
-          const arma::mat& querySet,
+  FastMKS(const typename TreeType::Mat& referenceSet,
+          const typename TreeType::Mat& querySet,
           const bool single = false,
           const bool naive = false);
 
@@ -93,7 +93,7 @@ class FastMKS
    * @param single Whether or not to run single-tree search.
    * @param naive Whether or not to run brute-force (naive) search.
    */
-  FastMKS(const arma::mat& referenceSet,
+  FastMKS(const typename TreeType::Mat& referenceSet,
           KernelType& kernel,
           const bool single = false,
           const bool naive = false);
@@ -110,8 +110,8 @@ class FastMKS
    * @param single Whether or not to run single-tree search.
    * @param naive Whether or not to run brute-force (naive) search.
    */
-  FastMKS(const arma::mat& referenceSet,
-          const arma::mat& querySet,
+  FastMKS(const typename TreeType::Mat& referenceSet,
+          const typename TreeType::Mat& querySet,
           KernelType& kernel,
           const bool single = false,
           const bool naive = false);
@@ -128,7 +128,7 @@ class FastMKS
    * @param single Whether or not to run single-tree search.
    * @param naive Whether or not to run brute-force (naive) search.
    */
-  FastMKS(const arma::mat& referenceSet,
+  FastMKS(const typename TreeType::Mat& referenceSet,
           TreeType* referenceTree,
           const bool single = false,
           const bool naive = false);
@@ -146,9 +146,9 @@ class FastMKS
    * @param single Whether or not to use single-tree search.
    * @param naive Whether or not to use naive (brute-force) search.
    */
-  FastMKS(const arma::mat& referenceSet,
+  FastMKS(const typename TreeType::Mat& referenceSet,
           TreeType* referenceTree,
-          const arma::mat& querySet,
+          const typename TreeType::Mat& querySet,
           TreeType* queryTree,
           const bool single = false,
           const bool naive = false);
@@ -186,9 +186,9 @@ class FastMKS
 
  private:
   //! The reference dataset.
-  const arma::mat& referenceSet;
+  const typename TreeType::Mat& referenceSet;
   //! The query dataset.
-  const arma::mat& querySet;
+  const typename TreeType::Mat& querySet;
 
   //! The tree built on the reference dataset.
   TreeType* referenceTree;
diff --git a/src/mlpack/methods/fastmks/fastmks_impl.hpp b/src/mlpack/methods/fastmks/fastmks_impl.hpp
index e044795..19f466e 100644
--- a/src/mlpack/methods/fastmks/fastmks_impl.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_impl.hpp
@@ -20,7 +20,7 @@ namespace fastmks {
 
 // Single dataset, no instantiated kernel.
 template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
+FastMKS<KernelType, TreeType>::FastMKS(const typename TreeType::Mat& referenceSet,
                                        const bool single,
                                        const bool naive) :
     referenceSet(referenceSet),
@@ -44,8 +44,8 @@ FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
 
 // Two datasets, no instantiated kernel.
 template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
-                                       const arma::mat& querySet,
+FastMKS<KernelType, TreeType>::FastMKS(const typename TreeType::Mat& referenceSet,
+                                       const typename TreeType::Mat& querySet,
                                        const bool single,
                                        const bool naive) :
     referenceSet(referenceSet),
@@ -70,7 +70,7 @@ FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
 
 // One dataset, instantiated kernel.
 template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
+FastMKS<KernelType, TreeType>::FastMKS(const typename TreeType::Mat& referenceSet,
                                        KernelType& kernel,
                                        const bool single,
                                        const bool naive) :
@@ -97,8 +97,8 @@ FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
 
 // Two datasets, instantiated kernel.
 template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
-                                       const arma::mat& querySet,
+FastMKS<KernelType, TreeType>::FastMKS(const typename TreeType::Mat& referenceSet,
+                                       const typename TreeType::Mat& querySet,
                                        KernelType& kernel,
                                        const bool single,
                                        const bool naive) :
@@ -125,10 +125,11 @@ FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
 
 // One dataset, pre-built tree.
 template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
-                                       TreeType* referenceTree,
-                                       const bool single,
-                                       const bool naive) :
+FastMKS<KernelType, TreeType>::FastMKS(
+    const typename TreeType::Mat& referenceSet,
+    TreeType* referenceTree,
+    const bool single,
+    const bool naive) :
     referenceSet(referenceSet),
     querySet(referenceSet),
     referenceTree(referenceTree),
@@ -145,12 +146,13 @@ FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
 
 // Two datasets, pre-built trees.
 template<typename KernelType, typename TreeType>
-FastMKS<KernelType, TreeType>::FastMKS(const arma::mat& referenceSet,
-                                       TreeType* referenceTree,
-                                       const arma::mat& querySet,
-                                       TreeType* queryTree,
-                                       const bool single,
-                                       const bool naive) :
+FastMKS<KernelType, TreeType>::FastMKS(
+    const typename TreeType::Mat& referenceSet,
+    TreeType* referenceTree,
+    const typename TreeType::Mat& querySet,
+    TreeType* queryTree,
+    const bool single,
+    const bool naive) :
     referenceSet(referenceSet),
     querySet(querySet),
     referenceTree(referenceTree),
@@ -205,8 +207,8 @@ void FastMKS<KernelType, TreeType>::Search(const size_t k,
         if ((&querySet == &referenceSet) && (q == r))
           continue;
 
-        const double eval = metric.Kernel().Evaluate(querySet.unsafe_col(q),
-            referenceSet.unsafe_col(r));
+        const double eval = metric.Kernel().Evaluate(querySet.col(q),
+                                                     referenceSet.col(r));
 
         size_t insertPosition;
         for (insertPosition = 0; insertPosition < indices.n_rows;
diff --git a/src/mlpack/methods/fastmks/fastmks_rules.hpp b/src/mlpack/methods/fastmks/fastmks_rules.hpp
index 659e448..0e264bd 100644
--- a/src/mlpack/methods/fastmks/fastmks_rules.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_rules.hpp
@@ -22,8 +22,8 @@ template<typename KernelType, typename TreeType>
 class FastMKSRules
 {
  public:
-  FastMKSRules(const arma::mat& referenceSet,
-               const arma::mat& querySet,
+  FastMKSRules(const typename TreeType::Mat& referenceSet,
+               const typename TreeType::Mat& querySet,
                arma::Mat<size_t>& indices,
                arma::mat& products,
                KernelType& kernel);
@@ -98,9 +98,9 @@ class FastMKSRules
 
  private:
   //! The reference dataset.
-  const arma::mat& referenceSet;
+  const typename TreeType::Mat& referenceSet;
   //! The query dataset.
-  const arma::mat& querySet;
+  const typename TreeType::Mat& querySet;
 
   //! The indices of the maximum kernel results.
   arma::Mat<size_t>& indices;
diff --git a/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp b/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp
index 96d4fe7..34290a6 100644
--- a/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_rules_impl.hpp
@@ -14,11 +14,12 @@ namespace mlpack {
 namespace fastmks {
 
 template<typename KernelType, typename TreeType>
-FastMKSRules<KernelType, TreeType>::FastMKSRules(const arma::mat& referenceSet,
-                                                 const arma::mat& querySet,
-                                                 arma::Mat<size_t>& indices,
-                                                 arma::mat& products,
-                                                 KernelType& kernel) :
+FastMKSRules<KernelType, TreeType>::FastMKSRules(
+    const typename TreeType::Mat& referenceSet,
+    const typename TreeType::Mat& querySet,
+    arma::Mat<size_t>& indices,
+    arma::mat& products,
+    KernelType& kernel) :
     referenceSet(referenceSet),
     querySet(querySet),
     indices(indices),
@@ -33,13 +34,13 @@ FastMKSRules<KernelType, TreeType>::FastMKSRules(const arma::mat& referenceSet,
   // Precompute each self-kernel.
   queryKernels.set_size(querySet.n_cols);
   for (size_t i = 0; i < querySet.n_cols; ++i)
-    queryKernels[i] = sqrt(kernel.Evaluate(querySet.unsafe_col(i),
-                                           querySet.unsafe_col(i)));
+    queryKernels[i] = sqrt(kernel.Evaluate(querySet.col(i),
+                                           querySet.col(i)));
 
   referenceKernels.set_size(referenceSet.n_cols);
   for (size_t i = 0; i < referenceSet.n_cols; ++i)
-    referenceKernels[i] = sqrt(kernel.Evaluate(referenceSet.unsafe_col(i),
-                                               referenceSet.unsafe_col(i)));
+    referenceKernels[i] = sqrt(kernel.Evaluate(referenceSet.col(i),
+                                               referenceSet.col(i)));
 
   // Set to invalid memory, so that the first node combination does not try to
   // dereference null pointers.
@@ -69,8 +70,8 @@ double FastMKSRules<KernelType, TreeType>::BaseCase(
   }
 
   ++baseCases;
-  double kernelEval = kernel.Evaluate(querySet.unsafe_col(queryIndex),
-                                      referenceSet.unsafe_col(referenceIndex));
+  double kernelEval = kernel.Evaluate(querySet.col(queryIndex),
+                                      referenceSet.col(referenceIndex));
 
   // Update the last kernel value, if we need to.
   if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
@@ -156,11 +157,10 @@ double FastMKSRules<KernelType, TreeType>::Score(const size_t queryIndex,
   }
   else
   {
-    const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
     arma::vec refCentroid;
     referenceNode.Centroid(refCentroid);
 
-    kernelEval = kernel.Evaluate(queryPoint, refCentroid);
+    kernelEval = kernel.Evaluate(querySet.col(queryIndex), refCentroid);
   }
 
   referenceNode.Stat().LastKernel() = kernelEval;
diff --git a/src/mlpack/methods/fastmks/fastmks_stat.hpp b/src/mlpack/methods/fastmks/fastmks_stat.hpp
index f6357ae..9cbef10 100644
--- a/src/mlpack/methods/fastmks/fastmks_stat.hpp
+++ b/src/mlpack/methods/fastmks/fastmks_stat.hpp
@@ -58,8 +58,8 @@ class FastMKSStat
       else
       {
         selfKernel = sqrt(node.Metric().Kernel().Evaluate(
-            node.Dataset().unsafe_col(node.Point(0)),
-            node.Dataset().unsafe_col(node.Point(0))));
+            node.Dataset().col(node.Point(0)),
+            node.Dataset().col(node.Point(0))));
       }
     }
     else



More information about the mlpack-git mailing list