[mlpack-git] master: Refactor QDAFN to better handle sparse data matrices. (4b35ecc)
gitdub at mlpack.org
gitdub at mlpack.org
Sun Oct 30 07:50:40 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/31995784e651e1c17c988c79d9f53c9dbad620f8...81fce4edfc8bfb4c26b48ed388f559ec1cee26dd
>---------------------------------------------------------------
commit 4b35ecc9e8490fcb2fa499c685494b6604789ce4
Author: Ryan Curtin <ryan at ratml.org>
Date: Sun Oct 30 20:50:40 2016 +0900
Refactor QDAFN to better handle sparse data matrices.
>---------------------------------------------------------------
4b35ecc9e8490fcb2fa499c685494b6604789ce4
src/mlpack/methods/approx_kfn/qdafn.hpp | 8 +++++++-
src/mlpack/methods/approx_kfn/qdafn_impl.hpp | 7 ++++---
2 files changed, 11 insertions(+), 4 deletions(-)
diff --git a/src/mlpack/methods/approx_kfn/qdafn.hpp b/src/mlpack/methods/approx_kfn/qdafn.hpp
index ad9e206..f7949db 100644
--- a/src/mlpack/methods/approx_kfn/qdafn.hpp
+++ b/src/mlpack/methods/approx_kfn/qdafn.hpp
@@ -64,6 +64,11 @@ class QDAFN
template<typename Archive>
void Serialize(Archive& ar, const unsigned int /* version */);
+ //! Get the candidate set for projection table t: the m stored reference points, one per column.
+ const MatType& CandidateSet(const size_t t) const { return candidateSet[t]; }
+ //! Modify the candidate set for projection table t.  Careful: edits are not mirrored into sIndices/sValues.
+ MatType& CandidateSet(const size_t t) { return candidateSet[t]; }
+
private:
//! The number of projections.
size_t l;
@@ -79,7 +84,8 @@ class QDAFN
//! Values of a_i * x for each point in S.
arma::mat sValues;
- arma::cube candidateSet;
+ // Candidate sets; one element in the vector for each table.
+ std::vector<MatType> candidateSet;
};
} // namespace neighbor
diff --git a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp
index 85ec99a..de6c882 100644
--- a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp
+++ b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp
@@ -43,9 +43,10 @@ QDAFN<MatType>::QDAFN(const MatType& referenceSet,
// Loop over each projection and find the top m elements.
sIndices.set_size(m, l);
sValues.set_size(m, l);
- candidateSet.set_size(referenceSet.n_rows, m, l);
+ candidateSet.resize(l);
for (size_t i = 0; i < l; ++i)
{
+ candidateSet[i].set_size(referenceSet.n_rows, m);
arma::uvec sortedIndices = arma::sort_index(projections.col(i), "descend");
// Grab the top m elements.
@@ -53,7 +54,7 @@ QDAFN<MatType>::QDAFN(const MatType& referenceSet,
{
sIndices(j, i) = sortedIndices[j];
sValues(j, i) = projections(sortedIndices[j], i);
- candidateSet.slice(i).col(j) = referenceSet.col(sortedIndices[j]);
+ candidateSet[i].col(j) = referenceSet.col(sortedIndices[j]);
}
}
}
@@ -106,7 +107,7 @@ void QDAFN<MatType>::Search(const MatType& querySet,
// Calculate distance from query point.
const double dist = mlpack::metric::EuclideanDistance::Evaluate(
- querySet.col(q), candidateSet.slice(p.second).col(tableIndex));
+ querySet.col(q), candidateSet[p.second].col(tableIndex));
// Is this neighbor good enough to insert into the results?
if (dist > resultsQueue.top().first)
More information about the mlpack-git mailing list