[mlpack-git] master: Fix incorrect inequality. (fc4c3b8)
gitdub at mlpack.org
gitdub at mlpack.org
Tue Oct 25 04:35:34 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/31995784e651e1c17c988c79d9f53c9dbad620f8...81fce4edfc8bfb4c26b48ed388f559ec1cee26dd
>---------------------------------------------------------------
commit fc4c3b87846e1b001fb1fcc64ccce8c11adad886
Author: Ryan Curtin <ryan at ratml.org>
Date: Tue Oct 25 04:35:34 2016 -0400
Fix incorrect inequality.
>---------------------------------------------------------------
fc4c3b87846e1b001fb1fcc64ccce8c11adad886
src/mlpack/methods/approx_kfn/approx_kfn_main.cpp | 73 ++++++++++------------
.../methods/approx_kfn/drusilla_select_impl.hpp | 2 +-
src/mlpack/methods/approx_kfn/qdafn.hpp | 13 +++-
src/mlpack/methods/approx_kfn/qdafn_impl.hpp | 4 ++
4 files changed, 49 insertions(+), 43 deletions(-)
diff --git a/src/mlpack/methods/approx_kfn/approx_kfn_main.cpp b/src/mlpack/methods/approx_kfn/approx_kfn_main.cpp
index 4d6ef67..b5f2ac2 100644
--- a/src/mlpack/methods/approx_kfn/approx_kfn_main.cpp
+++ b/src/mlpack/methods/approx_kfn/approx_kfn_main.cpp
@@ -33,7 +33,7 @@ PROGRAM_INFO("Approximate furthest neighbor search",
"\n\n"
"Specify a reference set (set to search in) with --reference_file, "
"specify a query set with --query_file, and specify algorithm parameters "
- "with --num_tables (-l) and --num_projections (-m) (or don't and defaults "
+ "with --num_tables (-t) and --num_projections (-p) (or don't and defaults "
"will be used). The algorithm to be used (either 'ds'---the default---or "
"'qdafn') may be specified with --algorithm. Also specify the number of "
"neighbors to search for with --k. Each of those options also has short "
@@ -54,13 +54,13 @@ PARAM_STRING_IN("query_file", "File containing query points.", "q", "");
// Model loading and saving.
PARAM_STRING_IN("input_model_file", "File containing input model.", "m", "");
-PARAM_STRING_OUT("output_model_file", "File to save output model to.", "M", "");
+PARAM_STRING_OUT("output_model_file", "File to save output model to.", "M");
-PARAM_INT_IN("k", "Number of furthest neighbors to search for.", "k");
+PARAM_INT_IN("k", "Number of furthest neighbors to search for.", "k", 0);
-PARAM_INT_IN("num_tables", "Number of hash tables to use.", "l", 5);
+PARAM_INT_IN("num_tables", "Number of hash tables to use.", "t", 5);
PARAM_INT_IN("num_projections", "Number of projections to use in each hash "
- "table.", "m", 5);
+ "table.", "p", 5);
PARAM_STRING_IN("algorithm", "Algorithm to use: 'ds' or 'qdafn'.", "a", "ds");
PARAM_STRING_IN("neighbors_file", "File to save furthest neighbor indices to.",
@@ -79,10 +79,11 @@ class ApproxKFNModel
{
public:
int type;
- boost::any model;
+ DrusillaSelect<> ds;
+ QDAFN<> qdafn;
//! Constructor, which does nothing.
- ApproxKFNModel() : type(0) { /* Nothing to do. */ }
+ ApproxKFNModel() : type(0), ds(1, 1), qdafn(1, 1) { }
//! Serialize the model.
template<typename Archive>
@@ -90,9 +91,13 @@ class ApproxKFNModel
{
ar & data::CreateNVP(type, "type");
if (type == 0)
- ar & data::CreateNVP(boost::any_cast<DrusillaSelect<>>(model), "model");
+ {
+ ar & data::CreateNVP(ds, "model");
+ }
else
- ar & data::CreateNVP(boost::any_cast<QDAFN<>>(model), "model");
+ {
+ ar & data::CreateNVP(qdafn, "model");
+ }
}
};
@@ -162,8 +167,7 @@ int main(int argc, char** argv)
Timer::Start("drusilla_select_construct");
Log::Info << "Building DrusillaSelect model..." << endl;
m.type = 0;
- m.model = boost::any(DrusillaSelect<>(referenceSet, numTables,
- numProjections));
+ m.ds = DrusillaSelect<>(referenceSet, numTables, numProjections);
Timer::Stop("drusilla_select_construct");
}
else
@@ -171,7 +175,7 @@ int main(int argc, char** argv)
Timer::Start("qdafn_construct");
Log::Info << "Building QDAFN model..." << endl;
m.type = 1;
- m.model = boost::any(QDAFN<>(referenceSet, numTables, numProjections));
+ m.qdafn = QDAFN<>(referenceSet, numTables, numProjections);
Timer::Stop("qdafn_construct");
}
}
@@ -179,52 +183,41 @@ int main(int argc, char** argv)
{
// We must load the model from file.
const string inputModelFile = CLI::GetParam<string>("input_model_file");
- data::Load(inputModelFile, m);
+ data::Load(inputModelFile, "approx_kfn", m);
}
// Now, do we need to do any queries?
if (CLI::HasParam("k"))
{
+ arma::mat querySet; // This may or may not be used.
const size_t k = (size_t) CLI::GetParam<int>("k");
arma::Mat<size_t> neighbors;
arma::mat distances;
+ arma::mat& set = CLI::HasParam("query_file") ? querySet : referenceSet;
if (CLI::HasParam("query_file"))
{
const string queryFile = CLI::GetParam<string>("query_file");
- arma::mat querySet;
- data::Load(querySet, queryFile);
+ data::Load(queryFile, querySet);
+ }
- if (m.type == 0)
- {
- Timer::Start("drusilla_select_search");
- boost::any_cast<DrusillaSelect<>>(m.model).Search(querySet, k,
- neighbors, distances);
- Timer::Stop("drusilla_select_search");
- }
- else
- {
- Timer::Start("qdafn_search");
- boost::any_cast<QDAFN<>>(m.model).Search(querySet, k, neighbors,
- distances);
- Timer::Stop("qdafn_search");
- }
+ if (m.type == 0)
+ {
+ Timer::Start("drusilla_select_search");
+ m.ds.Search(set, k, neighbors, distances);
+ Timer::Stop("drusilla_select_search");
}
else
{
- // We will do search with the reference set.
- if (m.type == 0)
- boost::any_cast<DrusillaSelect<>>(m.model).Search(k, neighbors,
- distances);
- else
- boost::any_cast<QDAFN<>>(m.model).Search(k, neighbors, distances);
+ Timer::Start("qdafn_search");
+ m.qdafn.Search(set, k, neighbors, distances);
+ Timer::Stop("qdafn_search");
}
// Should we calculate error?
if (CLI::HasParam("calculate_error"))
{
- arma::mat& set = CLI::HasParam("query_file") ? querySet : referenceSet;
arma::mat exactDistances;
if (CLI::HasParam("exact_distances_file"))
{
@@ -239,10 +232,10 @@ int main(int argc, char** argv)
arma::Mat<size_t> exactNeighbors;
kfn.Search(set, k, exactNeighbors, exactDistances);
- const double averageError = arma::sum(trueDistances / distances.row(0))
+ const double averageError = arma::sum(exactDistances / distances.row(0))
/ distances.n_cols;
- const double minError = arma::min(trueDistances / distances.row(0));
- const double maxError = arma::max(trueDistances / distances.row(0));
+ const double minError = arma::min(exactDistances / distances.row(0));
+ const double maxError = arma::max(exactDistances / distances.row(0));
Log::Info << "Average error: " << averageError << "." << endl;
Log::Info << "Maximum error: " << maxError << "." << endl;
@@ -259,5 +252,5 @@ int main(int argc, char** argv)
// Should we save the model?
if (CLI::HasParam("output_model_file"))
- data::Save(CLI::GetParam<string>("output_model_file"), m);
+ data::Save(CLI::GetParam<string>("output_model_file"), "approx_kfn", m);
}
diff --git a/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp b/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp
index f264e64..9595374 100644
--- a/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp
+++ b/src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp
@@ -104,7 +104,7 @@ void DrusillaSelect<MatType>::Train(
const double distortion = arma::norm(refCopy.col(j) - offset * line);
sums[j] = std::abs(offset) - std::abs(distortion);
closeAngle[j] =
- (std::atan(distortion / std::abs(offset)) >= (M_PI / 8.0));
+ (std::atan(distortion / std::abs(offset)) < (M_PI / 8.0));
}
else
{
diff --git a/src/mlpack/methods/approx_kfn/qdafn.hpp b/src/mlpack/methods/approx_kfn/qdafn.hpp
index 7617fc2..ad9e206 100644
--- a/src/mlpack/methods/approx_kfn/qdafn.hpp
+++ b/src/mlpack/methods/approx_kfn/qdafn.hpp
@@ -29,6 +29,15 @@ class QDAFN
{
public:
/**
+ * Construct the QDAFN object but do not train it. Be sure to call Train()
+ * before calling Search().
+ *
+ * @param l Number of projections.
+ * @param m Number of elements to store for each projection.
+ */
+ QDAFN(const size_t l, const size_t m);
+
+ /**
* Construct the QDAFN object with the given reference set (this is the set
* that will be searched).
*
@@ -57,9 +66,9 @@ class QDAFN
private:
//! The number of projections.
- const size_t l;
+ size_t l;
//! The number of elements to store for each projection.
- const size_t m;
+ size_t m;
//! The random lines we are projecting onto. Has l columns.
arma::mat lines;
//! Projections of each point onto each random line.
diff --git a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp
index 9220989..85ec99a 100644
--- a/src/mlpack/methods/approx_kfn/qdafn_impl.hpp
+++ b/src/mlpack/methods/approx_kfn/qdafn_impl.hpp
@@ -16,6 +16,10 @@
namespace mlpack {
namespace neighbor {
+// Non-training constructor.
+template<typename MatType>
+QDAFN<MatType>::QDAFN(const size_t l, const size_t m) : l(l), m(m) { }
+
// Constructor.
template<typename MatType>
QDAFN<MatType>::QDAFN(const MatType& referenceSet,
More information about the mlpack-git
mailing list