[mlpack-git] master: Add tested support to KNN/KFN for octree. (e4ca4c1)
gitdub at mlpack.org
gitdub at mlpack.org
Sat Sep 24 12:44:52 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/9ef7339d40550a974b3939e9fcb966fac2c09065...ebdb5abeaa3fd621a06ae663862bb72df76d2b40
>---------------------------------------------------------------
commit e4ca4c1e0ae024ffaec24942a6e54feb6889ca95
Author: Ryan Curtin <ryan at ratml.org>
Date: Sat Sep 24 12:44:52 2016 -0400
Add tested support to KNN/KFN for octree.
>---------------------------------------------------------------
e4ca4c1e0ae024ffaec24942a6e54feb6889ca95
src/mlpack/methods/neighbor_search/kfn_main.cpp | 9 +++++---
src/mlpack/methods/neighbor_search/knn_main.cpp | 9 +++++---
src/mlpack/methods/neighbor_search/ns_model.hpp | 13 +++++++++--
.../methods/neighbor_search/ns_model_impl.hpp | 27 +++++++++++++++++++---
src/mlpack/tests/knn_test.cpp | 12 ++++++----
5 files changed, 55 insertions(+), 15 deletions(-)
diff --git a/src/mlpack/methods/neighbor_search/kfn_main.cpp b/src/mlpack/methods/neighbor_search/kfn_main.cpp
index 2eb5bd6..9645472 100644
--- a/src/mlpack/methods/neighbor_search/kfn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/kfn_main.cpp
@@ -68,10 +68,10 @@ PARAM_INT_IN("k", "Number of furthest neighbors to find.", "k", 0);
// building.
PARAM_STRING_IN("tree_type", "Type of tree to use: 'kd', 'vp', 'rp', 'max-rp', "
"'ub', 'cover', 'r', 'r-star', 'x', 'ball', 'hilbert-r', 'r-plus', "
- "'r-plus-plus'.", "t", "kd");
+ "'r-plus-plus', 'octree'.", "t", "kd");
PARAM_INT_IN("leaf_size", "Leaf size for tree building (used for kd-trees, "
"vp trees, random projection trees, UB trees, R trees, R* trees, X trees, "
- "Hilbert R trees, R+ trees and R++ trees).", "l", 20);
+ "Hilbert R trees, R+ trees, R++ trees, and octrees).", "l", 20);
PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
"random orthogonal basis.", "R");
PARAM_INT_IN("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0);
@@ -262,10 +262,13 @@ int main(int argc, char *argv[])
tree = KFNModel::MAX_RP_TREE;
else if (treeType == "ub")
tree = KFNModel::UB_TREE;
+ else if (treeType == "octree")
+ tree = KFNModel::OCTREE;
else
Log::Fatal << "Unknown tree type '" << treeType << "'; valid choices are "
<< "'kd', 'vp', 'rp', 'max-rp', 'ub', 'cover', 'r', 'r-star', 'x', "
- << "'ball', 'hilbert-r', 'r-plus' and 'r-plus-plus'." << endl;
+ << "'ball', 'hilbert-r', 'r-plus', 'r-plus-plus', and 'octree'."
+ << endl;
kfn.TreeType() = tree;
kfn.RandomBasis() = randomBasis;
diff --git a/src/mlpack/methods/neighbor_search/knn_main.cpp b/src/mlpack/methods/neighbor_search/knn_main.cpp
index 0e29ff8..2127b79 100644
--- a/src/mlpack/methods/neighbor_search/knn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/knn_main.cpp
@@ -73,7 +73,8 @@ PARAM_STRING_IN("tree_type", "Type of tree to use: 'kd', 'vp', 'rp', 'max-rp', "
"'r-plus-plus', 'spill'.", "t", "kd");
PARAM_INT_IN("leaf_size", "Leaf size for tree building (used for kd-trees, vp "
"trees, random projection trees, UB trees, R trees, R* trees, X trees, "
- "Hilbert R trees, R+ trees, R++ trees and spill trees).", "l", 20);
+ "Hilbert R trees, R+ trees, R++ trees, spill trees, and octrees).", "l",
+ 20);
PARAM_DOUBLE_IN("tau", "Overlapping size (only valid for spill trees).", "u",
0);
PARAM_DOUBLE_IN("rho", "Balance threshold (only valid for spill trees).", "b",
@@ -276,11 +277,13 @@ int main(int argc, char *argv[])
tree = KNNModel::MAX_RP_TREE;
else if (treeType == "ub")
tree = KNNModel::UB_TREE;
+ else if (treeType == "octree")
+ tree = KNNModel::OCTREE;
else
Log::Fatal << "Unknown tree type '" << treeType << "'; valid choices are "
<< "'kd', 'vp', 'rp', 'max-rp', 'ub', 'cover', 'r', 'r-star', 'x', "
- << "'ball', 'hilbert-r', 'r-plus', 'r-plus-plus' and 'spill'."
- << endl;
+ << "'ball', 'hilbert-r', 'r-plus', 'r-plus-plus', 'spill', and "
+ << "'octree'." << endl;
knn.TreeType() = tree;
knn.RandomBasis() = randomBasis;
diff --git a/src/mlpack/methods/neighbor_search/ns_model.hpp b/src/mlpack/methods/neighbor_search/ns_model.hpp
index 4f4c47d..b0ae690 100644
--- a/src/mlpack/methods/neighbor_search/ns_model.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model.hpp
@@ -14,6 +14,7 @@
#include <mlpack/core/tree/cover_tree.hpp>
#include <mlpack/core/tree/rectangle_tree.hpp>
#include <mlpack/core/tree/spill_tree.hpp>
+#include <mlpack/core/tree/octree.hpp>
#include <boost/variant.hpp>
#include "neighbor_search.hpp"
@@ -133,6 +134,9 @@ class BiSearchVisitor : public boost::static_visitor<void>
//! Bichromatic neighbor search specialized for SPTrees.
void operator()(SpillKNN* ns) const;
+ //! Bichromatic neighbor search specialized for octrees.
+ void operator()(NSTypeT<tree::Octree>* ns) const;
+
//! Construct the BiSearchVisitor.
BiSearchVisitor(const arma::mat& querySet,
const size_t k,
@@ -188,6 +192,9 @@ class TrainVisitor : public boost::static_visitor<void>
//! Train specialized for SPTrees.
void operator()(SpillKNN* ns) const;
+ //! Train specialized for octrees.
+ void operator()(NSTypeT<tree::Octree>* ns) const;
+
//! Construct the TrainVisitor object with the given reference set, leafSize
//! for BinarySpaceTrees, and tau and rho for spill trees.
TrainVisitor(arma::mat&& referenceSet,
@@ -287,7 +294,8 @@ class NSModel
RP_TREE,
MAX_RP_TREE,
SPILL_TREE,
- UB_TREE
+ UB_TREE,
+ OCTREE
};
private:
@@ -325,7 +333,8 @@ class NSModel
NSType<SortPolicy, tree::RPTree>*,
NSType<SortPolicy, tree::MaxRPTree>*,
SpillKNN*,
- NSType<SortPolicy, tree::UBTree>*> nSearch;
+ NSType<SortPolicy, tree::UBTree>*,
+ NSType<SortPolicy, tree::Octree>*> nSearch;
public:
/**
diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
index 062a6ec..c2d6c6b 100644
--- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
@@ -96,6 +96,15 @@ void BiSearchVisitor<SortPolicy>::operator()(SpillKNN* ns) const
throw std::runtime_error("no neighbor search model initialized");
}
+//! Bichromatic neighbor search specialized for octrees.
+template<typename SortPolicy>
+void BiSearchVisitor<SortPolicy>::operator()(NSTypeT<tree::Octree>* ns) const
+{
+ if (ns)
+ return SearchLeaf(ns);
+ throw std::runtime_error("no neighbor search model initialized");
+}
+
//! Bichromatic neighbor search on the given NSType considering the leafSize.
template<typename SortPolicy>
template<typename NSType>
@@ -150,7 +159,7 @@ void TrainVisitor<SortPolicy>::operator()(NSTypeT<TreeType>* ns) const
//! Train on the given NSType specialized for KDTrees.
template<typename SortPolicy>
-void TrainVisitor<SortPolicy>::operator ()(NSTypeT<tree::KDTree>* ns) const
+void TrainVisitor<SortPolicy>::operator()(NSTypeT<tree::KDTree>* ns) const
{
if (ns)
return TrainLeaf(ns);
@@ -159,7 +168,7 @@ void TrainVisitor<SortPolicy>::operator ()(NSTypeT<tree::KDTree>* ns) const
//! Train on the given NSType specialized for BallTrees.
template<typename SortPolicy>
-void TrainVisitor<SortPolicy>::operator ()(NSTypeT<tree::BallTree>* ns) const
+void TrainVisitor<SortPolicy>::operator()(NSTypeT<tree::BallTree>* ns) const
{
if (ns)
return TrainLeaf(ns);
@@ -168,7 +177,7 @@ void TrainVisitor<SortPolicy>::operator ()(NSTypeT<tree::BallTree>* ns) const
//! Train specialized for SPTrees.
template<typename SortPolicy>
-void TrainVisitor<SortPolicy>::operator ()(SpillKNN* ns) const
+void TrainVisitor<SortPolicy>::operator()(SpillKNN* ns) const
{
if (ns)
{
@@ -184,6 +193,15 @@ void TrainVisitor<SortPolicy>::operator ()(SpillKNN* ns) const
throw std::runtime_error("no neighbor search model initialized");
}
+//! Train specialized for Octrees.
+template<typename SortPolicy>
+void TrainVisitor<SortPolicy>::operator()(NSTypeT<tree::Octree>* ns) const
+{
+ if (ns)
+ return TrainLeaf(ns);
+ throw std::runtime_error("no neighbor search model initialized");
+}
+
//! Train on the given NSType considering the leafSize.
template<typename SortPolicy>
template<typename NSType>
@@ -485,6 +503,9 @@ void NSModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
case UB_TREE:
nSearch = new NSType<SortPolicy, tree::UBTree>(searchMode, epsilon);
break;
+ case OCTREE:
+ nSearch = new NSType<SortPolicy, tree::Octree>(searchMode, epsilon);
+ break;
}
TrainVisitor<SortPolicy> tn(std::move(referenceSet), leafSize, tau, rho);
diff --git a/src/mlpack/tests/knn_test.cpp b/src/mlpack/tests/knn_test.cpp
index 8045b71..6f956df 100644
--- a/src/mlpack/tests/knn_test.cpp
+++ b/src/mlpack/tests/knn_test.cpp
@@ -1066,7 +1066,7 @@ BOOST_AUTO_TEST_CASE(KNNModelTest)
arma::mat referenceData = arma::randu<arma::mat>(10, 200);
// Build all the possible models.
- KNNModel models[26];
+ KNNModel models[28];
models[0] = KNNModel(KNNModel::TreeTypes::KD_TREE, true);
models[1] = KNNModel(KNNModel::TreeTypes::KD_TREE, false);
models[2] = KNNModel(KNNModel::TreeTypes::COVER_TREE, true);
@@ -1093,6 +1093,8 @@ BOOST_AUTO_TEST_CASE(KNNModelTest)
models[23] = KNNModel(KNNModel::TreeTypes::MAX_RP_TREE, false);
models[24] = KNNModel(KNNModel::TreeTypes::UB_TREE, true);
models[25] = KNNModel(KNNModel::TreeTypes::UB_TREE, false);
+ models[26] = KNNModel(KNNModel::TreeTypes::OCTREE, true);
+ models[27] = KNNModel(KNNModel::TreeTypes::OCTREE, false);
for (size_t j = 0; j < 2; ++j)
{
@@ -1102,7 +1104,7 @@ BOOST_AUTO_TEST_CASE(KNNModelTest)
arma::mat baselineDistances;
knn.Search(queryData, 3, baselineNeighbors, baselineDistances);
- for (size_t i = 0; i < 26; ++i)
+ for (size_t i = 0; i < 28; ++i)
{
// We only have std::move() constructors so make a copy of our data.
arma::mat referenceCopy(referenceData);
@@ -1147,7 +1149,7 @@ BOOST_AUTO_TEST_CASE(KNNModelMonochromaticTest)
arma::mat referenceData = arma::randu<arma::mat>(10, 200);
// Build all the possible models.
- KNNModel models[26];
+ KNNModel models[28];
models[0] = KNNModel(KNNModel::TreeTypes::KD_TREE, true);
models[1] = KNNModel(KNNModel::TreeTypes::KD_TREE, false);
models[2] = KNNModel(KNNModel::TreeTypes::COVER_TREE, true);
@@ -1174,6 +1176,8 @@ BOOST_AUTO_TEST_CASE(KNNModelMonochromaticTest)
models[23] = KNNModel(KNNModel::TreeTypes::MAX_RP_TREE, false);
models[24] = KNNModel(KNNModel::TreeTypes::UB_TREE, true);
models[25] = KNNModel(KNNModel::TreeTypes::UB_TREE, false);
+ models[26] = KNNModel(KNNModel::TreeTypes::OCTREE, true);
+ models[27] = KNNModel(KNNModel::TreeTypes::OCTREE, false);
for (size_t j = 0; j < 2; ++j)
{
@@ -1183,7 +1187,7 @@ BOOST_AUTO_TEST_CASE(KNNModelMonochromaticTest)
arma::mat baselineDistances;
knn.Search(3, baselineNeighbors, baselineDistances);
- for (size_t i = 0; i < 26; ++i)
+ for (size_t i = 0; i < 28; ++i)
{
// We only have a std::move() constructor... so copy the data.
arma::mat referenceCopy(referenceData);
More information about the mlpack-git mailing list