[mlpack-git] master: Add tests for sparse operation and fix sparse bugs. (d01b20f)

gitdub at mlpack.org gitdub at mlpack.org
Sun Oct 30 08:30:30 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/31995784e651e1c17c988c79d9f53c9dbad620f8...81fce4edfc8bfb4c26b48ed388f559ec1cee26dd

>---------------------------------------------------------------

commit d01b20fe87591e51421308dde5340816193429e3
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sun Oct 30 21:30:30 2016 +0900

    Add tests for sparse operation and fix sparse bugs.


>---------------------------------------------------------------

d01b20fe87591e51421308dde5340816193429e3
 .../methods/approx_kfn/drusilla_select_impl.hpp       | 13 ++++++++-----
 src/mlpack/tests/drusilla_select_test.cpp             | 19 +++++++++++++++++++
 2 files changed, 27 insertions(+), 5 deletions(-)

diff --git a/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp b/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp
index 9595374..942063b 100644
--- a/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp
+++ b/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp
@@ -76,12 +76,15 @@ void DrusillaSelect<MatType>::Train(
   candidateSet.set_size(referenceSet.n_rows, l * m);
   candidateIndices.set_size(l * m);
 
-  arma::vec dataMean = arma::mean(referenceSet, 1);
+  arma::vec dataMean(arma::mean(referenceSet, 1));
   arma::vec norms(referenceSet.n_cols);
 
-  arma::mat refCopy = referenceSet.each_col() - dataMean;
+  MatType refCopy(referenceSet.n_rows, referenceSet.n_cols);
   for (size_t i = 0; i < refCopy.n_cols; ++i)
-    norms[i] = arma::norm(refCopy.col(i) - dataMean);
+  {
+    refCopy.col(i) = referenceSet.col(i) - dataMean;
+    norms[i] = arma::norm(refCopy.col(i));
+  }
 
   // Find the top m points for each of the l projections...
   for (size_t i = 0; i < l; ++i)
@@ -90,7 +93,7 @@ void DrusillaSelect<MatType>::Train(
     arma::uword maxIndex;
     norms.max(maxIndex);
 
-    arma::vec line = refCopy.col(maxIndex) / arma::norm(refCopy.col(maxIndex));
+    arma::vec line(refCopy.col(maxIndex) / arma::norm(refCopy.col(maxIndex)));
     const size_t n_nonzero = (size_t) arma::sum(norms > 0);
 
     // Calculate distortion and offset and make scores.
@@ -176,7 +179,7 @@ void DrusillaSelect<MatType>::Search(const MatType& querySet,
   // TreeType.
   metric::EuclideanDistance metric;
   NeighborSearchRules<FurthestNeighborSort, metric::EuclideanDistance,
-      tree::KDTree<metric::EuclideanDistance, tree::EmptyStatistic, arma::mat>>
+      tree::KDTree<metric::EuclideanDistance, tree::EmptyStatistic, MatType>>
       rules(candidateSet, querySet, k, metric, 0, false);
 
   for (size_t q = 0; q < querySet.n_cols; ++q)
diff --git a/src/mlpack/tests/drusilla_select_test.cpp b/src/mlpack/tests/drusilla_select_test.cpp
index b60a1ad..cce2704 100644
--- a/src/mlpack/tests/drusilla_select_test.cpp
+++ b/src/mlpack/tests/drusilla_select_test.cpp
@@ -143,4 +143,23 @@ BOOST_AUTO_TEST_CASE(SerializationTest)
   }
 }
 
+// Make sure we can build DrusillaSelect on a sparse matrix and search with it.
+BOOST_AUTO_TEST_CASE(SparseTest)
+{
+  arma::sp_mat dataset;
+  dataset.sprandu(50, 1000, 0.3); // 50 rows x 1000 columns, ~30% nonzero.
+
+  DrusillaSelect<arma::sp_mat> ds(dataset, 5, 10);
+
+  // Query every column of the dataset for 3 results; only shapes are checked.
+  arma::mat distances;
+  arma::Mat<size_t> neighbors;
+  ds.Search(dataset, 3, neighbors, distances);
+
+  BOOST_REQUIRE_EQUAL(neighbors.n_cols, 1000);
+  BOOST_REQUIRE_EQUAL(neighbors.n_rows, 3);
+  BOOST_REQUIRE_EQUAL(distances.n_cols, 1000);
+  BOOST_REQUIRE_EQUAL(distances.n_rows, 3);
+}
+
 BOOST_AUTO_TEST_SUITE_END();




More information about the mlpack-git mailing list