[mlpack-git] master: Update knn/kfn methods and tests to consider NeighborSearchMode. (e9e0d7c)
gitdub at mlpack.org
gitdub at mlpack.org
Sat Aug 20 14:56:07 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/3274b05fcc545c3b36f783316fea2e22f79c3d03...1c77230c7d3b9c45fb102cd3c632d9c7248e085e
>---------------------------------------------------------------
commit e9e0d7c3f35f5c401ee7d6d7d5bce03f1e3f47c8
Author: MarcosPividori <marcos.pividori at gmail.com>
Date: Tue Aug 16 05:14:17 2016 -0300
Update knn/kfn methods and tests to consider NeighborSearchMode.
>---------------------------------------------------------------
e9e0d7c3f35f5c401ee7d6d7d5bce03f1e3f47c8
src/mlpack/methods/neighbor_search/kfn_main.cpp | 20 ++++++++++++++++----
src/mlpack/methods/neighbor_search/knn_main.cpp | 20 ++++++++++++++++----
src/mlpack/tests/aknn_test.cpp | 14 +++++++++-----
src/mlpack/tests/knn_test.cpp | 14 ++++++++------
4 files changed, 49 insertions(+), 19 deletions(-)
diff --git a/src/mlpack/methods/neighbor_search/kfn_main.cpp b/src/mlpack/methods/neighbor_search/kfn_main.cpp
index 58d68cc..0f20197 100644
--- a/src/mlpack/methods/neighbor_search/kfn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/kfn_main.cpp
@@ -80,6 +80,7 @@ PARAM_INT_IN("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0);
PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "N");
PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
"dual-tree search).", "s");
+PARAM_FLAG("greedy", "If true, greedy single-tree search is used.", "G");
PARAM_DOUBLE_IN("epsilon", "If specified, will do approximate furthest neighbor"
" search with given relative error. Must be in the range [0,1).", "e", 0);
PARAM_DOUBLE_IN("percentage", "If specified, will do approximate furthest "
@@ -179,8 +180,20 @@ int main(int argc, char *argv[])
// We either have to load the reference data, or we have to load the model.
NSModel<FurthestNeighborSort> kfn;
+
const bool naive = CLI::HasParam("naive");
const bool singleMode = CLI::HasParam("single_mode");
+ const bool greedy = CLI::HasParam("greedy");
+
+ NeighborSearchMode searchMode;
+ if (naive)
+ searchMode = NAIVE_MODE;
+ else if (singleMode)
+ searchMode = SINGLE_TREE_MODE;
+ else if (greedy)
+ searchMode = GREEDY_SINGLE_TREE_MODE;
+ else searchMode = DUAL_TREE_MODE;
+
if (CLI::HasParam("reference_file"))
{
// Get all the parameters.
@@ -227,8 +240,7 @@ int main(int argc, char *argv[])
Log::Info << "Loaded reference data from '" << referenceFile << "' ("
<< referenceSet.n_rows << "x" << referenceSet.n_cols << ")." << endl;
- kfn.BuildModel(std::move(referenceSet), size_t(lsInt), naive, singleMode,
- epsilon);
+ kfn.BuildModel(std::move(referenceSet), size_t(lsInt), searchMode, epsilon);
}
else
{
@@ -236,8 +248,8 @@ int main(int argc, char *argv[])
const string inputModelFile = CLI::GetParam<string>("input_model_file");
data::Load(inputModelFile, "kfn_model", kfn, true); // Fatal on failure.
- kfn.SingleMode() = CLI::HasParam("single_mode");
- kfn.Naive() = CLI::HasParam("naive");
+ // Adjust search mode.
+ kfn.SetSearchMode(searchMode);
kfn.Epsilon() = epsilon;
// If leaf_size wasn't provided, let's consider the current value in the
diff --git a/src/mlpack/methods/neighbor_search/knn_main.cpp b/src/mlpack/methods/neighbor_search/knn_main.cpp
index f0456f2..ce28303 100644
--- a/src/mlpack/methods/neighbor_search/knn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/knn_main.cpp
@@ -87,6 +87,7 @@ PARAM_INT_IN("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0);
PARAM_FLAG("naive", "If true, O(n^2) naive mode is used for computation.", "N");
PARAM_FLAG("single_mode", "If true, single-tree search is used (as opposed to "
"dual-tree search).", "S");
+PARAM_FLAG("greedy", "If true, greedy single-tree search is used.", "G");
PARAM_DOUBLE_IN("epsilon", "If specified, will do approximate nearest neighbor "
"search with given relative error.", "e", 0);
@@ -191,8 +192,20 @@ int main(int argc, char *argv[])
// We either have to load the reference data, or we have to load the model.
NSModel<NearestNeighborSort> knn;
+
const bool naive = CLI::HasParam("naive");
const bool singleMode = CLI::HasParam("single_mode");
+ const bool greedy = CLI::HasParam("greedy");
+
+ NeighborSearchMode searchMode;
+ if (naive)
+ searchMode = NAIVE_MODE;
+ else if (singleMode)
+ searchMode = SINGLE_TREE_MODE;
+ else if (greedy)
+ searchMode = GREEDY_SINGLE_TREE_MODE;
+ else searchMode = DUAL_TREE_MODE;
+
if (CLI::HasParam("reference_file"))
{
// Get all the parameters.
@@ -245,8 +258,7 @@ int main(int argc, char *argv[])
<< referenceSet.n_rows << " x " << referenceSet.n_cols << ")."
<< endl;
- knn.BuildModel(std::move(referenceSet), size_t(lsInt), naive, singleMode,
- epsilon);
+ knn.BuildModel(std::move(referenceSet), size_t(lsInt), searchMode, epsilon);
}
else
{
@@ -254,8 +266,8 @@ int main(int argc, char *argv[])
const string inputModelFile = CLI::GetParam<string>("input_model_file");
data::Load(inputModelFile, "knn_model", knn, true); // Fatal on failure.
- knn.SingleMode() = CLI::HasParam("single_mode");
- knn.Naive() = CLI::HasParam("naive");
+ // Adjust search mode.
+ knn.SetSearchMode(searchMode);
knn.Epsilon() = epsilon;
// If leaf_size wasn't provided, let's consider the current value in the
diff --git a/src/mlpack/tests/aknn_test.cpp b/src/mlpack/tests/aknn_test.cpp
index 4271b9a..c236be8 100644
--- a/src/mlpack/tests/aknn_test.cpp
+++ b/src/mlpack/tests/aknn_test.cpp
@@ -368,11 +368,13 @@ BOOST_AUTO_TEST_CASE(KNNModelTest)
arma::mat referenceCopy(referenceData);
arma::mat queryCopy(queryData);
if (j == 0)
- models[i].BuildModel(std::move(referenceCopy), 20, false, false, 0.05);
+ models[i].BuildModel(std::move(referenceCopy), 20, DUAL_TREE_MODE,
+ 0.05);
if (j == 1)
- models[i].BuildModel(std::move(referenceCopy), 20, false, true, 0.05);
+ models[i].BuildModel(std::move(referenceCopy), 20,
+ SINGLE_TREE_MODE, 0.05);
if (j == 2)
- models[i].BuildModel(std::move(referenceCopy), 20, true, false);
+ models[i].BuildModel(std::move(referenceCopy), 20, NAIVE_MODE);
arma::Mat<size_t> neighborsApprox;
arma::mat distancesApprox;
@@ -442,9 +444,11 @@ BOOST_AUTO_TEST_CASE(KNNModelMonochromaticTest)
// We only have a std::move() constructor... so copy the data.
arma::mat referenceCopy(referenceData);
if (j == 0)
- models[i].BuildModel(std::move(referenceCopy), 20, false, false, 0.05);
+ models[i].BuildModel(std::move(referenceCopy), 20, DUAL_TREE_MODE,
+ 0.05);
if (j == 1)
- models[i].BuildModel(std::move(referenceCopy), 20, false, true, 0.05);
+ models[i].BuildModel(std::move(referenceCopy), 20,
+ SINGLE_TREE_MODE, 0.05);
arma::Mat<size_t> neighborsApprox;
arma::mat distancesApprox;
diff --git a/src/mlpack/tests/knn_test.cpp b/src/mlpack/tests/knn_test.cpp
index 193c4c4..35a0f50 100644
--- a/src/mlpack/tests/knn_test.cpp
+++ b/src/mlpack/tests/knn_test.cpp
@@ -1107,11 +1107,12 @@ BOOST_AUTO_TEST_CASE(KNNModelTest)
arma::mat referenceCopy(referenceData);
arma::mat queryCopy(queryData);
if (j == 0)
- models[i].BuildModel(std::move(referenceCopy), 20, false, false);
+ models[i].BuildModel(std::move(referenceCopy), 20, DUAL_TREE_MODE);
if (j == 1)
- models[i].BuildModel(std::move(referenceCopy), 20, false, true);
+ models[i].BuildModel(std::move(referenceCopy), 20,
+ SINGLE_TREE_MODE);
if (j == 2)
- models[i].BuildModel(std::move(referenceCopy), 20, true, false);
+ models[i].BuildModel(std::move(referenceCopy), 20, NAIVE_MODE);
arma::Mat<size_t> neighbors;
arma::mat distances;
@@ -1184,11 +1185,12 @@ BOOST_AUTO_TEST_CASE(KNNModelMonochromaticTest)
// We only have a std::move() constructor... so copy the data.
arma::mat referenceCopy(referenceData);
if (j == 0)
- models[i].BuildModel(std::move(referenceCopy), 20, false, false);
+ models[i].BuildModel(std::move(referenceCopy), 20, DUAL_TREE_MODE);
if (j == 1)
- models[i].BuildModel(std::move(referenceCopy), 20, false, true);
+ models[i].BuildModel(std::move(referenceCopy), 20,
+ SINGLE_TREE_MODE);
if (j == 2)
- models[i].BuildModel(std::move(referenceCopy), 20, true, false);
+ models[i].BuildModel(std::move(referenceCopy), 20, NAIVE_MODE);
arma::Mat<size_t> neighbors;
arma::mat distances;
More information about the mlpack-git mailing list