[mlpack-git] master: Add tested support to KNN/KFN for octree. (e4ca4c1)

gitdub at mlpack.org gitdub at mlpack.org
Sat Sep 24 12:44:52 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/9ef7339d40550a974b3939e9fcb966fac2c09065...ebdb5abeaa3fd621a06ae663862bb72df76d2b40

>---------------------------------------------------------------

commit e4ca4c1e0ae024ffaec24942a6e54feb6889ca95
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sat Sep 24 12:44:52 2016 -0400

    Add tested support to KNN/KFN for octree.


>---------------------------------------------------------------

e4ca4c1e0ae024ffaec24942a6e54feb6889ca95
 src/mlpack/methods/neighbor_search/kfn_main.cpp    |  9 +++++---
 src/mlpack/methods/neighbor_search/knn_main.cpp    |  9 +++++---
 src/mlpack/methods/neighbor_search/ns_model.hpp    | 13 +++++++++--
 .../methods/neighbor_search/ns_model_impl.hpp      | 27 +++++++++++++++++++---
 src/mlpack/tests/knn_test.cpp                      | 12 ++++++----
 5 files changed, 55 insertions(+), 15 deletions(-)

diff --git a/src/mlpack/methods/neighbor_search/kfn_main.cpp b/src/mlpack/methods/neighbor_search/kfn_main.cpp
index 2eb5bd6..9645472 100644
--- a/src/mlpack/methods/neighbor_search/kfn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/kfn_main.cpp
@@ -68,10 +68,10 @@ PARAM_INT_IN("k", "Number of furthest neighbors to find.", "k", 0);
 // building.
 PARAM_STRING_IN("tree_type", "Type of tree to use: 'kd', 'vp', 'rp', 'max-rp', "
     "'ub', 'cover', 'r', 'r-star', 'x', 'ball', 'hilbert-r', 'r-plus', "
-    "'r-plus-plus'.", "t", "kd");
+    "'r-plus-plus', 'octree'.", "t", "kd");
 PARAM_INT_IN("leaf_size", "Leaf size for tree building (used for kd-trees, "
     "vp trees, random projection trees, UB trees, R trees, R* trees, X trees, "
-    "Hilbert R trees, R+ trees and R++ trees).", "l", 20);
+    "Hilbert R trees, R+ trees, R++ trees, and octrees).", "l", 20);
 PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
     "random orthogonal basis.", "R");
 PARAM_INT_IN("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0);
@@ -262,10 +262,13 @@ int main(int argc, char *argv[])
       tree = KFNModel::MAX_RP_TREE;
     else if (treeType == "ub")
       tree = KFNModel::UB_TREE;
+    else if (treeType == "octree")
+      tree = KFNModel::OCTREE;
     else
       Log::Fatal << "Unknown tree type '" << treeType << "'; valid choices are "
           << "'kd', 'vp', 'rp', 'max-rp', 'ub', 'cover', 'r', 'r-star', 'x', "
-          << "'ball', 'hilbert-r', 'r-plus' and 'r-plus-plus'." << endl;
+          << "'ball', 'hilbert-r', 'r-plus', 'r-plus-plus', and 'octree'."
+          << endl;
 
     kfn.TreeType() = tree;
     kfn.RandomBasis() = randomBasis;
diff --git a/src/mlpack/methods/neighbor_search/knn_main.cpp b/src/mlpack/methods/neighbor_search/knn_main.cpp
index 0e29ff8..2127b79 100644
--- a/src/mlpack/methods/neighbor_search/knn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/knn_main.cpp
@@ -73,7 +73,8 @@ PARAM_STRING_IN("tree_type", "Type of tree to use: 'kd', 'vp', 'rp', 'max-rp', "
     "'r-plus-plus', 'spill'.", "t", "kd");
 PARAM_INT_IN("leaf_size", "Leaf size for tree building (used for kd-trees, vp "
     "trees, random projection trees, UB trees, R trees, R* trees, X trees, "
-    "Hilbert R trees, R+ trees, R++ trees and spill trees).", "l", 20);
+    "Hilbert R trees, R+ trees, R++ trees, spill trees, and octrees).", "l",
+    20);
 PARAM_DOUBLE_IN("tau", "Overlapping size (only valid for spill trees).", "u",
     0);
 PARAM_DOUBLE_IN("rho", "Balance threshold (only valid for spill trees).", "b",
@@ -276,11 +277,13 @@ int main(int argc, char *argv[])
       tree = KNNModel::MAX_RP_TREE;
     else if (treeType == "ub")
       tree = KNNModel::UB_TREE;
+    else if (treeType == "octree")
+      tree = KNNModel::OCTREE;
     else
       Log::Fatal << "Unknown tree type '" << treeType << "'; valid choices are "
           << "'kd', 'vp', 'rp', 'max-rp', 'ub', 'cover', 'r', 'r-star', 'x', "
-          << "'ball', 'hilbert-r', 'r-plus', 'r-plus-plus' and 'spill'."
-          << endl;
+          << "'ball', 'hilbert-r', 'r-plus', 'r-plus-plus', 'spill', and "
+          << "'octree'." << endl;
 
     knn.TreeType() = tree;
     knn.RandomBasis() = randomBasis;
diff --git a/src/mlpack/methods/neighbor_search/ns_model.hpp b/src/mlpack/methods/neighbor_search/ns_model.hpp
index 4f4c47d..b0ae690 100644
--- a/src/mlpack/methods/neighbor_search/ns_model.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model.hpp
@@ -14,6 +14,7 @@
 #include <mlpack/core/tree/cover_tree.hpp>
 #include <mlpack/core/tree/rectangle_tree.hpp>
 #include <mlpack/core/tree/spill_tree.hpp>
+#include <mlpack/core/tree/octree.hpp>
 #include <boost/variant.hpp>
 #include "neighbor_search.hpp"
 
@@ -133,6 +134,9 @@ class BiSearchVisitor : public boost::static_visitor<void>
   //! Bichromatic neighbor search specialized for SPTrees.
   void operator()(SpillKNN* ns) const;
 
+  //! Bichromatic neighbor search specialized for octrees.
+  void operator()(NSTypeT<tree::Octree>* ns) const;
+
   //! Construct the BiSearchVisitor.
   BiSearchVisitor(const arma::mat& querySet,
                   const size_t k,
@@ -188,6 +192,9 @@ class TrainVisitor : public boost::static_visitor<void>
   //! Train specialized for SPTrees.
   void operator()(SpillKNN* ns) const;
 
+  //! Train specialized for octrees.
+  void operator()(NSTypeT<tree::Octree>* ns) const;
+
   //! Construct the TrainVisitor object with the given reference set, leafSize
   //! for BinarySpaceTrees, and tau and rho for spill trees.
   TrainVisitor(arma::mat&& referenceSet,
@@ -287,7 +294,8 @@ class NSModel
     RP_TREE,
     MAX_RP_TREE,
     SPILL_TREE,
-    UB_TREE
+    UB_TREE,
+    OCTREE
   };
 
  private:
@@ -325,7 +333,8 @@ class NSModel
                  NSType<SortPolicy, tree::RPTree>*,
                  NSType<SortPolicy, tree::MaxRPTree>*,
                  SpillKNN*,
-                 NSType<SortPolicy, tree::UBTree>*> nSearch;
+                 NSType<SortPolicy, tree::UBTree>*,
+                 NSType<SortPolicy, tree::Octree>*> nSearch;
 
  public:
   /**
diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
index 062a6ec..c2d6c6b 100644
--- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
@@ -96,6 +96,15 @@ void BiSearchVisitor<SortPolicy>::operator()(SpillKNN* ns) const
     throw std::runtime_error("no neighbor search model initialized");
 }
 
+//! Bichromatic neighbor search specialized for octrees.
+template<typename SortPolicy>
+void BiSearchVisitor<SortPolicy>::operator()(NSTypeT<tree::Octree>* ns) const
+{
+  if (ns)
+    return SearchLeaf(ns);
+  throw std::runtime_error("no neighbor search model initialized");
+}
+
 //! Bichromatic neighbor search on the given NSType considering the leafSize.
 template<typename SortPolicy>
 template<typename NSType>
@@ -150,7 +159,7 @@ void TrainVisitor<SortPolicy>::operator()(NSTypeT<TreeType>* ns) const
 
 //! Train on the given NSType specialized for KDTrees.
 template<typename SortPolicy>
-void TrainVisitor<SortPolicy>::operator ()(NSTypeT<tree::KDTree>* ns) const
+void TrainVisitor<SortPolicy>::operator()(NSTypeT<tree::KDTree>* ns) const
 {
   if (ns)
     return TrainLeaf(ns);
@@ -159,7 +168,7 @@ void TrainVisitor<SortPolicy>::operator ()(NSTypeT<tree::KDTree>* ns) const
 
 //! Train on the given NSType specialized for BallTrees.
 template<typename SortPolicy>
-void TrainVisitor<SortPolicy>::operator ()(NSTypeT<tree::BallTree>* ns) const
+void TrainVisitor<SortPolicy>::operator()(NSTypeT<tree::BallTree>* ns) const
 {
   if (ns)
     return TrainLeaf(ns);
@@ -168,7 +177,7 @@ void TrainVisitor<SortPolicy>::operator ()(NSTypeT<tree::BallTree>* ns) const
 
 //! Train specialized for SPTrees.
 template<typename SortPolicy>
-void TrainVisitor<SortPolicy>::operator ()(SpillKNN* ns) const
+void TrainVisitor<SortPolicy>::operator()(SpillKNN* ns) const
 {
   if (ns)
   {
@@ -184,6 +193,15 @@ void TrainVisitor<SortPolicy>::operator ()(SpillKNN* ns) const
     throw std::runtime_error("no neighbor search model initialized");
 }
 
+//! Train specialized for Octrees.
+template<typename SortPolicy>
+void TrainVisitor<SortPolicy>::operator()(NSTypeT<tree::Octree>* ns) const
+{
+  if (ns)
+    return TrainLeaf(ns);
+  throw std::runtime_error("no neighbor search model initialized");
+}
+
 //! Train on the given NSType considering the leafSize.
 template<typename SortPolicy>
 template<typename NSType>
@@ -485,6 +503,9 @@ void NSModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
     case UB_TREE:
       nSearch = new NSType<SortPolicy, tree::UBTree>(searchMode, epsilon);
       break;
+    case OCTREE:
+      nSearch = new NSType<SortPolicy, tree::Octree>(searchMode, epsilon);
+      break;
   }
 
   TrainVisitor<SortPolicy> tn(std::move(referenceSet), leafSize, tau, rho);
diff --git a/src/mlpack/tests/knn_test.cpp b/src/mlpack/tests/knn_test.cpp
index 8045b71..6f956df 100644
--- a/src/mlpack/tests/knn_test.cpp
+++ b/src/mlpack/tests/knn_test.cpp
@@ -1066,7 +1066,7 @@ BOOST_AUTO_TEST_CASE(KNNModelTest)
   arma::mat referenceData = arma::randu<arma::mat>(10, 200);
 
   // Build all the possible models.
-  KNNModel models[26];
+  KNNModel models[28];
   models[0] = KNNModel(KNNModel::TreeTypes::KD_TREE, true);
   models[1] = KNNModel(KNNModel::TreeTypes::KD_TREE, false);
   models[2] = KNNModel(KNNModel::TreeTypes::COVER_TREE, true);
@@ -1093,6 +1093,8 @@ BOOST_AUTO_TEST_CASE(KNNModelTest)
   models[23] = KNNModel(KNNModel::TreeTypes::MAX_RP_TREE, false);
   models[24] = KNNModel(KNNModel::TreeTypes::UB_TREE, true);
   models[25] = KNNModel(KNNModel::TreeTypes::UB_TREE, false);
+  models[26] = KNNModel(KNNModel::TreeTypes::OCTREE, true);
+  models[27] = KNNModel(KNNModel::TreeTypes::OCTREE, false);
 
   for (size_t j = 0; j < 2; ++j)
   {
@@ -1102,7 +1104,7 @@ BOOST_AUTO_TEST_CASE(KNNModelTest)
     arma::mat baselineDistances;
     knn.Search(queryData, 3, baselineNeighbors, baselineDistances);
 
-    for (size_t i = 0; i < 26; ++i)
+    for (size_t i = 0; i < 28; ++i)
     {
       // We only have std::move() constructors so make a copy of our data.
       arma::mat referenceCopy(referenceData);
@@ -1147,7 +1149,7 @@ BOOST_AUTO_TEST_CASE(KNNModelMonochromaticTest)
   arma::mat referenceData = arma::randu<arma::mat>(10, 200);
 
   // Build all the possible models.
-  KNNModel models[26];
+  KNNModel models[28];
   models[0] = KNNModel(KNNModel::TreeTypes::KD_TREE, true);
   models[1] = KNNModel(KNNModel::TreeTypes::KD_TREE, false);
   models[2] = KNNModel(KNNModel::TreeTypes::COVER_TREE, true);
@@ -1174,6 +1176,8 @@ BOOST_AUTO_TEST_CASE(KNNModelMonochromaticTest)
   models[23] = KNNModel(KNNModel::TreeTypes::MAX_RP_TREE, false);
   models[24] = KNNModel(KNNModel::TreeTypes::UB_TREE, true);
   models[25] = KNNModel(KNNModel::TreeTypes::UB_TREE, false);
+  models[26] = KNNModel(KNNModel::TreeTypes::OCTREE, true);
+  models[27] = KNNModel(KNNModel::TreeTypes::OCTREE, false);
 
   for (size_t j = 0; j < 2; ++j)
   {
@@ -1183,7 +1187,7 @@ BOOST_AUTO_TEST_CASE(KNNModelMonochromaticTest)
     arma::mat baselineDistances;
     knn.Search(3, baselineNeighbors, baselineDistances);
 
-    for (size_t i = 0; i < 26; ++i)
+    for (size_t i = 0; i < 28; ++i)
     {
       // We only have a std::move() constructor... so copy the data.
       arma::mat referenceCopy(referenceData);




More information about the mlpack-git mailing list