[mlpack-git] master: Fix potential bug and simplify memory requirements. (6d7e0ee)

gitdub at mlpack.org gitdub at mlpack.org
Tue Oct 25 03:57:05 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/31995784e651e1c17c988c79d9f53c9dbad620f8...81fce4edfc8bfb4c26b48ed388f559ec1cee26dd

>---------------------------------------------------------------

commit 6d7e0ee10359adff4a3dd15fa1cdeaf9f2f58921
Author: Ryan Curtin <ryan at ratml.org>
Date:   Tue Oct 25 16:57:05 2016 +0900

    Fix potential bug and simplify memory requirements.


>---------------------------------------------------------------

6d7e0ee10359adff4a3dd15fa1cdeaf9f2f58921
 src/mlpack/methods/approx_kfn/qdafn.hpp      | 15 +++------
 src/mlpack/methods/approx_kfn/qdafn_impl.hpp | 50 +++++++++++-----------------
 2 files changed, 25 insertions(+), 40 deletions(-)

diff --git a/src/mlpack/methods/approx_kfn/qdafn.hpp b/src/mlpack/methods/approx_kfn/qdafn.hpp
index 694fbf9..7617fc2 100644
--- a/src/mlpack/methods/approx_kfn/qdafn.hpp
+++ b/src/mlpack/methods/approx_kfn/qdafn.hpp
@@ -51,10 +51,11 @@ class QDAFN
               arma::Mat<size_t>& neighbors,
               arma::mat& distances);

- private:
-  //! The reference set.
-  const MatType& referenceSet;
+  //! Serialize the model.
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */);

+ private:
   //! The number of projections.
   const size_t l;
   //! The number of elements to store for each projection.
@@ -69,13 +70,7 @@ class QDAFN
   //! Values of a_i * x for each point in S.
   arma::mat sValues;

-  //! Insert a neighbor into a set of results for a given query point.
-  void InsertNeighbor(arma::mat& distances,
-                      arma::Mat<size_t>& neighbors,
-                      const size_t queryIndex,
-                      const size_t pos,
-                      const size_t neighbor,
-                      const double distance) const;
+  arma::cube candidateSet; //!< Slice i, column j is the point sIndices(j, i).
 };

 } // namespace neighbor
diff --git a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp
index bf462da..f1d04fa 100644
--- a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp
+++ b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp
@@ -21,7 +21,6 @@ template<typename MatType>
 QDAFN<MatType>::QDAFN(const MatType& referenceSet,
                       const size_t l,
                       const size_t m) :
-    referenceSet(referenceSet),
     l(l),
     m(m)
 {
@@ -40,6 +39,7 @@ QDAFN<MatType>::QDAFN(const MatType& referenceSet,
   // Loop over each projection and find the top m elements.
   sIndices.set_size(m, l);
   sValues.set_size(m, l);
+  candidateSet.set_size(referenceSet.n_rows, m, l); // One slice per table.
   for (size_t i = 0; i < l; ++i)
   {
     arma::uvec sortedIndices = arma::sort_index(projections.col(i), "descend");
@@ -49,6 +49,7 @@ QDAFN<MatType>::QDAFN(const MatType& referenceSet,
     {
       sIndices(j, i) = sortedIndices[j];
       sValues(j, i) = projections(sortedIndices[j], i);
+      candidateSet.slice(i).col(j) = referenceSet.col(sortedIndices[j]);
     }
   }
 }
@@ -77,8 +78,8 @@ void QDAFN<MatType>::Search(const MatType& querySet,
     std::priority_queue<std::pair<double, size_t>> queue;
     for (size_t i = 0; i < l; ++i)
     {
-      const double val = projections(0, i) - arma::dot(querySet.col(q),
-                                                       lines.col(i));
+      const double val = sValues(0, i) - arma::dot(querySet.col(q),
+          lines.col(i));
       queue.push(std::make_pair(val, i));
     }

@@ -97,17 +98,17 @@ void QDAFN<MatType>::Search(const MatType& querySet,
       queue.pop();

       // Get index of reference point to look at.
-      size_t referenceIndex = sIndices(tableLocations[p.second], p.second);
+      const size_t tableIndex = tableLocations[p.second]; // Row in sIndices.

       // Calculate distance from query point.
       const double dist = mlpack::metric::EuclideanDistance::Evaluate(
-          querySet.col(q), referenceSet.col(referenceIndex));
+          querySet.col(q), candidateSet.slice(p.second).col(tableIndex));

       // Is this neighbor good enough to insert into the results?
       if (dist > resultsQueue.top().first)
       {
         resultsQueue.pop();
-        resultsQueue.push(std::make_pair(dist, referenceIndex));
+        resultsQueue.push(std::make_pair(dist, sIndices(tableIndex, p.second)));
       }

       // Now (line 14) get the next element and insert into the queue.  Do this
@@ -116,9 +117,8 @@ void QDAFN<MatType>::Search(const MatType& querySet,
       if (i < m - 1)
       {
         tableLocations[p.second]++;
-        const double val = p.first -
-            projections(tableLocations[p.second] - 1, p.second) +
-            projections(tableLocations[p.second], p.second);
+        const double val = p.first - sValues(tableIndex, p.second) +
+            sValues(tableIndex + 1, p.second);

         queue.push(std::make_pair(val, p.second));
       }
@@ -135,28 +135,18 @@ void QDAFN<MatType>::Search(const MatType& querySet,
 }

 template<typename MatType>
-void QDAFN<MatType>::InsertNeighbor(arma::mat& distances,
-                                    arma::Mat<size_t>& neighbors,
-                                    const size_t queryIndex,
-                                    const size_t pos,
-                                    const size_t neighbor,
-                                    const double distance) const
+template<typename Archive>
+void QDAFN<MatType>::Serialize(Archive& ar, const unsigned int /* version */)
 {
-  // We only memmove() if there is actually a need to shift something.
-  if (pos < (distances.n_rows - 1))
-  {
-    const size_t len = (distances.n_rows - 1) - pos;
-    memmove(distances.colptr(queryIndex) + (pos + 1),
-        distances.colptr(queryIndex) + pos,
-        sizeof(double) * len);
-    memmove(neighbors.colptr(queryIndex) + (pos + 1),
-        neighbors.colptr(queryIndex) + pos,
-        sizeof(size_t) * len);
-  }
-
-  // Now put the new information in the right index.
-  distances(pos, queryIndex) = distance;
-  neighbors(pos, queryIndex) = neighbor;
+  using data::CreateNVP;
+
+  ar & CreateNVP(l, "l"); // NOTE(review): l is const; can't load into it.
+  ar & CreateNVP(m, "m");
+  ar & CreateNVP(lines, "lines");
+  ar & CreateNVP(projections, "projections");
+  ar & CreateNVP(sIndices, "sIndices");
+  ar & CreateNVP(sValues, "sValues");
+  ar & CreateNVP(candidateSet, "candidateSet");
 }

 } // namespace neighbor




More information about the mlpack-git mailing list