[mlpack-git] master: Refactor QDAFN to better handle sparse data matrices. (4b35ecc)

gitdub at mlpack.org gitdub at mlpack.org
Sun Oct 30 07:50:40 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/31995784e651e1c17c988c79d9f53c9dbad620f8...81fce4edfc8bfb4c26b48ed388f559ec1cee26dd

>---------------------------------------------------------------

commit 4b35ecc9e8490fcb2fa499c685494b6604789ce4
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sun Oct 30 20:50:40 2016 +0900

    Refactor QDAFN to better handle sparse data matrices.


>---------------------------------------------------------------

4b35ecc9e8490fcb2fa499c685494b6604789ce4
 src/mlpack/methods/approx_kfn/qdafn.hpp      | 8 +++++++-
 src/mlpack/methods/approx_kfn/qdafn_impl.hpp | 7 ++++---
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/src/mlpack/methods/approx_kfn/qdafn.hpp b/src/mlpack/methods/approx_kfn/qdafn.hpp
index ad9e206..f7949db 100644
--- a/src/mlpack/methods/approx_kfn/qdafn.hpp
+++ b/src/mlpack/methods/approx_kfn/qdafn.hpp
@@ -64,6 +64,11 @@ class QDAFN
   template<typename Archive>
   void Serialize(Archive& ar, const unsigned int /* version */);
 
+  //! Get the candidate set for the given projection table.
+  const MatType& CandidateSet(const size_t t) const { return candidateSet[t]; }
+  //! Modify the candidate set for the given projection table.  Careful!
+  MatType& CandidateSet(const size_t t) { return candidateSet[t]; }
+
  private:
   //! The number of projections.
   size_t l;
@@ -79,7 +84,8 @@ class QDAFN
   //! Values of a_i * x for each point in S.
   arma::mat sValues;
 
-  arma::cube candidateSet;
+  // Candidate sets; one element in the vector for each table.
+  std::vector<MatType> candidateSet;
 };
 
 } // namespace neighbor
diff --git a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp
index 85ec99a..de6c882 100644
--- a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp
+++ b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp
@@ -43,9 +43,10 @@ QDAFN<MatType>::QDAFN(const MatType& referenceSet,
   // Loop over each projection and find the top m elements.
   sIndices.set_size(m, l);
   sValues.set_size(m, l);
-  candidateSet.set_size(referenceSet.n_rows, m, l);
+  candidateSet.resize(l);
   for (size_t i = 0; i < l; ++i)
   {
+    candidateSet[i].set_size(referenceSet.n_rows, m);
     arma::uvec sortedIndices = arma::sort_index(projections.col(i), "descend");
 
     // Grab the top m elements.
@@ -53,7 +54,7 @@ QDAFN<MatType>::QDAFN(const MatType& referenceSet,
     {
       sIndices(j, i) = sortedIndices[j];
       sValues(j, i) = projections(sortedIndices[j], i);
-      candidateSet.slice(i).col(j) = referenceSet.col(sortedIndices[j]);
+      candidateSet[i].col(j) = referenceSet.col(sortedIndices[j]);
     }
   }
 }
@@ -106,7 +107,7 @@ void QDAFN<MatType>::Search(const MatType& querySet,
 
       // Calculate distance from query point.
       const double dist = mlpack::metric::EuclideanDistance::Evaluate(
-          querySet.col(q), candidateSet.slice(p.second).col(tableIndex));
+          querySet.col(q), candidateSet[p.second].col(tableIndex));
 
       // Is this neighbor good enough to insert into the results?
       if (dist > resultsQueue.top().first)




More information about the mlpack-git mailing list