[mlpack-git] master: Fix failing tests and bugs. (84bed62)

gitdub at mlpack.org gitdub at mlpack.org
Mon Oct 24 05:09:05 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/31995784e651e1c17c988c79d9f53c9dbad620f8...81fce4edfc8bfb4c26b48ed388f559ec1cee26dd

>---------------------------------------------------------------

commit 84bed62e1a30dbc7b2cea5db7333e78c980d4529
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Oct 24 05:08:49 2016 -0400

    Fix failing tests and bugs.


>---------------------------------------------------------------

84bed62e1a30dbc7b2cea5db7333e78c980d4529
 .../methods/approx_kfn/drusilla_select_impl.hpp    | 54 +++++++++++-----------
 src/mlpack/tests/drusilla_select_test.cpp          | 17 +++----
 2 files changed, 36 insertions(+), 35 deletions(-)

diff --git a/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp b/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp
index a84b304..f264e64 100644
--- a/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp
+++ b/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp
@@ -24,7 +24,8 @@ template<typename MatType>
 DrusillaSelect<MatType>::DrusillaSelect(const MatType& referenceSet,
                                         const size_t l,
                                         const size_t m) :
-    candidateSet(referenceSet.n_rows, l * m),
+    candidateSet(referenceSet.n_cols, l * m),
+    candidateIndices(l * m),
     l(l),
     m(m)
 {
@@ -41,6 +42,8 @@ DrusillaSelect<MatType>::DrusillaSelect(const MatType& referenceSet,
 // Constructor with no training.
 template<typename MatType>
 DrusillaSelect<MatType>::DrusillaSelect(const size_t l, const size_t m) :
+    candidateSet(0, l * m),
+    candidateIndices(l * m),
     l(l),
     m(m)
 {
@@ -70,6 +73,9 @@ void DrusillaSelect<MatType>::Train(
         "large!  Choose smaller values.  l*m must be smaller than the number "
         "of points in the dataset.");
 
+  candidateSet.set_size(referenceSet.n_rows, l * m);
+  candidateIndices.set_size(l * m);
+
   arma::vec dataMean = arma::mean(referenceSet, 1);
   arma::vec norms(referenceSet.n_cols);
 
@@ -87,29 +93,24 @@ void DrusillaSelect<MatType>::Train(
     arma::vec line = refCopy.col(maxIndex) / arma::norm(refCopy.col(maxIndex));
     const size_t n_nonzero = (size_t) arma::sum(norms > 0);
 
-    // Calculate distortion and offset.
-    arma::vec distortions(referenceSet.n_cols);
-    arma::vec offsets(referenceSet.n_cols);
+    // Calculate distortion and offset and make scores.
+    std::vector<bool> closeAngle(referenceSet.n_cols, false);
+    arma::vec sums(referenceSet.n_cols);
     for (size_t j = 0; j < referenceSet.n_cols; ++j)
     {
       if (norms[j] > 0.0)
       {
-        offsets[j] = arma::dot(refCopy.col(j), line);
-        distortions[j] = arma::norm(refCopy.col(j) - offsets[j] *
-            line);
+        const double offset = arma::dot(refCopy.col(j), line);
+        const double distortion = arma::norm(refCopy.col(j) - offset * line);
+        sums[j] = std::abs(offset) - std::abs(distortion);
+        closeAngle[j] =
+            (std::atan(distortion / std::abs(offset)) >= (M_PI / 8.0));
       }
       else
       {
-        offsets[j] = 0.0;
-        distortions[j] = 0.0;
+        sums[j] = norms[j];
       }
     }
-    arma::vec sums = arma::abs(offsets) - arma::abs(distortions);
-    arma::uvec sortedSums = arma::sort_index(sums, "descend");
-
-    arma::vec bestSums(m);
-    arma::Col<size_t> bestIndices(m);
-    bestSums.fill(-DBL_MAX);
 
     // Find the top m elements using a priority queue.
     typedef std::pair<double, size_t> Candidate;
@@ -117,11 +118,11 @@ void DrusillaSelect<MatType>::Train(
     {
       bool operator()(const Candidate& c1, const Candidate& c2)
       {
-        return c2.first > c1.first;
+        return c2.first < c1.first;
       }
     };
 
-    std::vector<Candidate> clist(m, std::make_pair(size_t(-1), double(0.0)));
+    std::vector<Candidate> clist(m, std::make_pair(double(-1.0), size_t(-1)));
     std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp>
         pq(CandidateCmp(), std::move(clist));
 
@@ -141,16 +142,17 @@ void DrusillaSelect<MatType>::Train(
       const size_t index = pq.top().second;
       pq.pop();
       candidateSet.col(i * m + j) = referenceSet.col(index);
+      candidateIndices[i * m + j] = index;
 
-      // Mark the norm as 0 so we don't see this point again.
-      norms[index] = 0.0;
+      // Mark the norm as -1 so we don't see this point again.
+      norms[index] = -1.0;
     }
 
     // Calculate angles from the current projection.  Anything close enough,
     // mark the norm as 0.
-    arma::vec farPoints = arma::conv_to<arma::vec>::from(
-        arma::atan(distortions / arma::abs(offsets)) >= (M_PI / 8.0));
-    norms %= farPoints;
+    for (size_t j = 0; j < norms.n_elem; ++j)
+      if (norms[j] > 0.0 && closeAngle[j])
+        norms[j] = 0.0;
   }
 }
 
@@ -175,16 +177,14 @@ void DrusillaSelect<MatType>::Search(const MatType& querySet,
   metric::EuclideanDistance metric;
   NeighborSearchRules<FurthestNeighborSort, metric::EuclideanDistance,
       tree::KDTree<metric::EuclideanDistance, tree::EmptyStatistic, arma::mat>>
-      rules(querySet, candidateSet, k, metric, 0, false);
-
-  neighbors.set_size(k, querySet.n_cols);
-  neighbors.fill(size_t() - 1);
-  distances.zeros(k, querySet.n_cols);
+      rules(candidateSet, querySet, k, metric, 0, false);
 
   for (size_t q = 0; q < querySet.n_cols; ++q)
     for (size_t r = 0; r < candidateSet.n_cols; ++r)
       rules.BaseCase(q, r);
 
+  rules.GetResults(neighbors, distances);
+
   // Map the neighbors back to their original indices in the reference set.
   for (size_t i = 0; i < neighbors.n_elem; ++i)
     neighbors[i] = candidateIndices[neighbors[i]];
diff --git a/src/mlpack/tests/drusilla_select_test.cpp b/src/mlpack/tests/drusilla_select_test.cpp
index 504fd62..b60a1ad 100644
--- a/src/mlpack/tests/drusilla_select_test.cpp
+++ b/src/mlpack/tests/drusilla_select_test.cpp
@@ -21,7 +21,7 @@ BOOST_AUTO_TEST_SUITE(DrusillaSelectTest);
 BOOST_AUTO_TEST_CASE(DrusillaSelectExtremeOutlierTest)
 {
   arma::mat dataset = arma::randu<arma::mat>(5, 100);
-  dataset.col(100) += 100; // Make last column very large.
+  dataset.col(99) += 100; // Make last column very large.
 
   // Construct with some reasonable parameters.
   DrusillaSelect<> ds(dataset, 5, 5);
@@ -29,7 +29,7 @@ BOOST_AUTO_TEST_CASE(DrusillaSelectExtremeOutlierTest)
   // Query with every point except the extreme point.
   arma::mat distances;
   arma::Mat<size_t> neighbors;
-  ds.Search(dataset.cols(0, 99), 1, neighbors, distances);
+  ds.Search(dataset.cols(0, 98), 1, neighbors, distances);
 
   BOOST_REQUIRE_EQUAL(neighbors.n_cols, 99);
   BOOST_REQUIRE_EQUAL(neighbors.n_rows, 1);
@@ -37,7 +37,9 @@ BOOST_AUTO_TEST_CASE(DrusillaSelectExtremeOutlierTest)
   BOOST_REQUIRE_EQUAL(distances.n_rows, 1);
 
   for (size_t i = 0; i < 99; ++i)
-    BOOST_REQUIRE_EQUAL(neighbors[i], 100);
+  {
+    BOOST_REQUIRE_EQUAL(neighbors[i], 99);
+  }
 }
 
 // If we use only one projection with the number of points equal to what is in
@@ -82,7 +84,6 @@ BOOST_AUTO_TEST_CASE(RetrainTest)
   arma::Mat<size_t> neighbors;
   ds.Search(dataset, 1, neighbors, distances);
 
-  BOOST_REQUIRE_EQUAL(dataset.n_elem, 0);
   BOOST_REQUIRE_EQUAL(neighbors.n_cols, 200);
   BOOST_REQUIRE_EQUAL(neighbors.n_rows, 1);
   BOOST_REQUIRE_EQUAL(distances.n_cols, 200);
@@ -97,11 +98,11 @@ BOOST_AUTO_TEST_CASE(SerializationTest)
 
   DrusillaSelect<> ds(dataset, 3, 3);
 
-  arma::mat fakeDataset1 = arma::randu<arma::mat>(2, 5);
-  arma::mat fakeDataset2 = arma::randu<arma::mat>(10, 8);
-  DrusillaSelect<> dsXml(fakeDataset1, 10, 10);
+  arma::mat fakeDataset1 = arma::randu<arma::mat>(2, 15);
+  arma::mat fakeDataset2 = arma::randu<arma::mat>(10, 18);
+  DrusillaSelect<> dsXml(fakeDataset1, 5, 3);
   DrusillaSelect<> dsText(2, 2);
-  DrusillaSelect<> dsBinary(5, 6);
+  DrusillaSelect<> dsBinary(5, 2);
   dsBinary.Train(fakeDataset2);
 
   // Now do the serialization.




More information about the mlpack-git mailing list