[mlpack-git] master: Fix incorrect inequality. (fc4c3b8)

gitdub at mlpack.org gitdub at mlpack.org
Tue Oct 25 04:35:34 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/31995784e651e1c17c988c79d9f53c9dbad620f8...81fce4edfc8bfb4c26b48ed388f559ec1cee26dd

>---------------------------------------------------------------

commit fc4c3b87846e1b001fb1fcc64ccce8c11adad886
Author: Ryan Curtin <ryan at ratml.org>
Date:   Tue Oct 25 04:35:34 2016 -0400

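    Fix incorrect inequality.

    DrusillaSelect::Train() now marks a point as close in angle when
    atan(distortion / |offset|) < pi/8 (previously the test was >=).

    Also fix approx_kfn_main and QDAFN so they build and run correctly:
     - store DrusillaSelect<> and QDAFN<> directly in ApproxKFNModel
       instead of inside a boost::any;
     - add a non-training QDAFN(l, m) constructor and make l and m
       non-const so QDAFN objects can be assigned;
     - pass the "approx_kfn" name to data::Load()/data::Save() for the
       model, and fix the argument order of data::Load() for the query set;
     - rename trueDistances to exactDistances in the error calculation;
     - use -t/-p as the short options for --num_tables/--num_projections
       (-m is already taken by --input_model_file), and give --k a default.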


>---------------------------------------------------------------

fc4c3b87846e1b001fb1fcc64ccce8c11adad886
 src/mlpack/methods/approx_kfn/approx_kfn_main.cpp  | 73 ++++++++++------------
 .../methods/approx_kfn/drusilla_select_impl.hpp    |  2 +-
 src/mlpack/methods/approx_kfn/qdafn.hpp            | 13 +++-
 src/mlpack/methods/approx_kfn/qdafn_impl.hpp       |  4 ++
 4 files changed, 49 insertions(+), 43 deletions(-)

diff --git a/src/mlpack/methods/approx_kfn/approx_kfn_main.cpp b/src/mlpack/methods/approx_kfn/approx_kfn_main.cpp
index 4d6ef67..b5f2ac2 100644
--- a/src/mlpack/methods/approx_kfn/approx_kfn_main.cpp
+++ b/src/mlpack/methods/approx_kfn/approx_kfn_main.cpp
@@ -33,7 +33,7 @@ PROGRAM_INFO("Approximate furthest neighbor search",
     "\n\n"
     "Specify a reference set (set to search in) with --reference_file, "
     "specify a query set with --query_file, and specify algorithm parameters "
-    "with --num_tables (-l) and --num_projections (-m) (or don't and defaults "
+    "with --num_tables (-t) and --num_projections (-p) (or don't and defaults "
     "will be used).  The algorithm to be used (either 'ds'---the default---or "
     "'qdafn') may be specified with --algorithm.  Also specify the number of "
     "neighbors to search for with --k.  Each of those options also has short "
@@ -54,13 +54,13 @@ PARAM_STRING_IN("query_file", "File containing query points.", "q", "");
 
 // Model loading and saving.
 PARAM_STRING_IN("input_model_file", "File containing input model.", "m", "");
-PARAM_STRING_OUT("output_model_file", "File to save output model to.", "M", "");
+PARAM_STRING_OUT("output_model_file", "File to save output model to.", "M");
 
-PARAM_INT_IN("k", "Number of furthest neighbors to search for.", "k");
+PARAM_INT_IN("k", "Number of furthest neighbors to search for.", "k", 0);
 
-PARAM_INT_IN("num_tables", "Number of hash tables to use.", "l", 5);
+PARAM_INT_IN("num_tables", "Number of hash tables to use.", "t", 5);
 PARAM_INT_IN("num_projections", "Number of projections to use in each hash "
-    "table.", "m", 5);
+    "table.", "p", 5);
 PARAM_STRING_IN("algorithm", "Algorithm to use: 'ds' or 'qdafn'.", "a", "ds");
 
 PARAM_STRING_IN("neighbors_file", "File to save furthest neighbor indices to.",
@@ -79,10 +79,11 @@ class ApproxKFNModel
 {
  public:
   int type;
-  boost::any model;
+  DrusillaSelect<> ds;
+  QDAFN<> qdafn;
 
   //! Constructor, which does nothing.
-  ApproxKFNModel() : type(0) { /* Nothing to do. */ }
+  ApproxKFNModel() : type(0), ds(1, 1), qdafn(1, 1) { }
 
   //! Serialize the model.
   template<typename Archive>
@@ -90,9 +91,13 @@ class ApproxKFNModel
   {
     ar & data::CreateNVP(type, "type");
     if (type == 0)
-      ar & data::CreateNVP(boost::any_cast<DrusillaSelect<>>(model), "model");
+    {
+      ar & data::CreateNVP(ds, "model");
+    }
     else
-      ar & data::CreateNVP(boost::any_cast<QDAFN<>>(model), "model");
+    {
+      ar & data::CreateNVP(qdafn, "model");
+    }
   }
 };
 
@@ -162,8 +167,7 @@ int main(int argc, char** argv)
       Timer::Start("drusilla_select_construct");
       Log::Info << "Building DrusillaSelect model..." << endl;
       m.type = 0;
-      m.model = boost::any(DrusillaSelect<>(referenceSet, numTables,
-          numProjections));
+      m.ds = DrusillaSelect<>(referenceSet, numTables, numProjections);
       Timer::Stop("drusilla_select_construct");
     }
     else
@@ -171,7 +175,7 @@ int main(int argc, char** argv)
       Timer::Start("qdafn_construct");
       Log::Info << "Building QDAFN model..." << endl;
       m.type = 1;
-      m.model = boost::any(QDAFN<>(referenceSet, numTables, numProjections));
+      m.qdafn = QDAFN<>(referenceSet, numTables, numProjections);
       Timer::Stop("qdafn_construct");
     }
   }
@@ -179,52 +183,41 @@ int main(int argc, char** argv)
   {
     // We must load the model from file.
     const string inputModelFile = CLI::GetParam<string>("input_model_file");
-    data::Load(inputModelFile, m);
+    data::Load(inputModelFile, "approx_kfn", m);
   }
 
   // Now, do we need to do any queries?
   if (CLI::HasParam("k"))
   {
+    arma::mat querySet; // This may or may not be used.
     const size_t k = (size_t) CLI::GetParam<int>("k");
 
     arma::Mat<size_t> neighbors;
     arma::mat distances;
 
+    arma::mat& set = CLI::HasParam("query_file") ? querySet : referenceSet;
     if (CLI::HasParam("query_file"))
     {
       const string queryFile = CLI::GetParam<string>("query_file");
-      arma::mat querySet;
-      data::Load(querySet, queryFile);
+      data::Load(queryFile, querySet);
+    }
 
-      if (m.type == 0)
-      {
-        Timer::Start("drusilla_select_search");
-        boost::any_cast<DrusillaSelect<>>(m.model).Search(querySet, k,
-            neighbors, distances);
-        Timer::Stop("drusilla_select_search");
-      }
-      else
-      {
-        Timer::Start("qdafn_search");
-        boost::any_cast<QDAFN<>>(m.model).Search(querySet, k, neighbors,
-            distances);
-        Timer::Stop("qdafn_search");
-      }
+    if (m.type == 0)
+    {
+      Timer::Start("drusilla_select_search");
+      m.ds.Search(set, k, neighbors, distances);
+      Timer::Stop("drusilla_select_search");
     }
     else
     {
-      // We will do search with the reference set.
-      if (m.type == 0)
-        boost::any_cast<DrusillaSelect<>>(m.model).Search(k, neighbors,
-            distances);
-      else
-        boost::any_cast<QDAFN<>>(m.model).Search(k, neighbors, distances);
+      Timer::Start("qdafn_search");
+      m.qdafn.Search(set, k, neighbors, distances);
+      Timer::Stop("qdafn_search");
     }
 
     // Should we calculate error?
     if (CLI::HasParam("calculate_error"))
     {
-      arma::mat& set = CLI::HasParam("query_file") ? querySet : referenceSet;
       arma::mat exactDistances;
       if (CLI::HasParam("exact_distances_file"))
       {
@@ -239,10 +232,10 @@ int main(int argc, char** argv)
         arma::Mat<size_t> exactNeighbors;
         kfn.Search(set, k, exactNeighbors, exactDistances);
 
-        const double averageError = arma::sum(trueDistances / distances.row(0))
+        const double averageError = arma::sum(exactDistances / distances.row(0))
             / distances.n_cols;
-        const double minError = arma::min(trueDistances / distances.row(0));
-        const double maxError = arma::max(trueDistances / distances.row(0));
+        const double minError = arma::min(exactDistances / distances.row(0));
+        const double maxError = arma::max(exactDistances / distances.row(0));
 
         Log::Info << "Average error: " << averageError << "." << endl;
         Log::Info << "Maximum error: " << maxError << "." << endl;
@@ -259,5 +252,5 @@ int main(int argc, char** argv)
 
   // Should we save the model?
   if (CLI::HasParam("output_model_file"))
-    data::Save(CLI::GetParam<string>("output_model_file"), m);
+    data::Save(CLI::GetParam<string>("output_model_file"), "approx_kfn", m);
 }
diff --git a/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp b/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp
index f264e64..9595374 100644
--- a/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp
+++ b/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp
@@ -104,7 +104,7 @@ void DrusillaSelect<MatType>::Train(
         const double distortion = arma::norm(refCopy.col(j) - offset * line);
         sums[j] = std::abs(offset) - std::abs(distortion);
         closeAngle[j] =
-            (std::atan(distortion / std::abs(offset)) >= (M_PI / 8.0));
+            (std::atan(distortion / std::abs(offset)) < (M_PI / 8.0));
       }
       else
       {
diff --git a/src/mlpack/methods/approx_kfn/qdafn.hpp b/src/mlpack/methods/approx_kfn/qdafn.hpp
index 7617fc2..ad9e206 100644
--- a/src/mlpack/methods/approx_kfn/qdafn.hpp
+++ b/src/mlpack/methods/approx_kfn/qdafn.hpp
@@ -29,6 +29,15 @@ class QDAFN
 {
  public:
   /**
+   * Construct the QDAFN object but do not train it.  Be sure to call Train()
+   * before calling Search().
+   *
+   * @param l Number of projections.
+   * @param m Number of elements to store for each projection.
+   */
+  QDAFN(const size_t l, const size_t m);
+
+  /**
    * Construct the QDAFN object with the given reference set (this is the set
    * that will be searched).
    *
@@ -57,9 +66,9 @@ class QDAFN
 
  private:
   //! The number of projections.
-  const size_t l;
+  size_t l;
   //! The number of elements to store for each projection.
-  const size_t m;
+  size_t m;
   //! The random lines we are projecting onto.  Has l columns.
   arma::mat lines;
   //! Projections of each point onto each random line.
diff --git a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp
index 9220989..85ec99a 100644
--- a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp
+++ b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp
@@ -16,6 +16,10 @@
 namespace mlpack {
 namespace neighbor {
 
+// Non-training constructor.
+template<typename MatType>
+QDAFN<MatType>::QDAFN(const size_t l, const size_t m) : l(l), m(m) { }
+
 // Constructor.
 template<typename MatType>
 QDAFN<MatType>::QDAFN(const MatType& referenceSet,




More information about the mlpack-git mailing list