[mlpack-git] master: Add octree to RangeSearch and RASearch. (3bdd609)

gitdub at mlpack.org gitdub at mlpack.org
Sat Sep 24 13:20:55 EDT 2016


Repository : https://github.com/mlpack/mlpack
On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/9ef7339d40550a974b3939e9fcb966fac2c09065...ebdb5abeaa3fd621a06ae663862bb72df76d2b40

>---------------------------------------------------------------

commit 3bdd6099936bc75173366a00aedb3bb93086bd37
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sat Sep 24 13:20:55 2016 -0400

    Add octree to RangeSearch and RASearch.


>---------------------------------------------------------------

3bdd6099936bc75173366a00aedb3bb93086bd37
 .../methods/range_search/range_search_main.cpp     |   8 +-
 src/mlpack/methods/range_search/rs_model.cpp       | 104 ++++++++----
 src/mlpack/methods/range_search/rs_model.hpp       |   6 +-
 src/mlpack/methods/range_search/rs_model_impl.hpp  |  14 ++
 src/mlpack/methods/rann/krann_main.cpp             |  11 +-
 src/mlpack/methods/rann/ra_model.hpp               |   6 +-
 src/mlpack/methods/rann/ra_model_impl.hpp          | 176 ++++++++++++++-------
 src/mlpack/tests/krann_search_test.cpp             |   6 +-
 src/mlpack/tests/range_search_test.cpp             |  12 +-
 9 files changed, 246 insertions(+), 97 deletions(-)

diff --git a/src/mlpack/methods/range_search/range_search_main.cpp b/src/mlpack/methods/range_search/range_search_main.cpp
index 9487832..990e3c1 100644
--- a/src/mlpack/methods/range_search/range_search_main.cpp
+++ b/src/mlpack/methods/range_search/range_search_main.cpp
@@ -72,10 +72,10 @@ PARAM_DOUBLE_IN("min", "Lower bound in range.", "L", 0.0);
 // building.
 PARAM_STRING_IN("tree_type", "Type of tree to use: 'kd', 'vp', 'rp', 'max-rp', "
     "'ub', 'cover', 'r', 'r-star', 'x', 'ball', 'hilbert-r', 'r-plus', "
-    "'r-plus-plus'.", "t", "kd");
+    "'r-plus-plus', 'octree'.", "t", "kd");
 PARAM_INT_IN("leaf_size", "Leaf size for tree building (used for kd-trees, "
     "vp trees, random projection trees, UB trees, R trees, R* trees, X trees, "
-    "Hilbert R trees, R+ trees and R++ trees).", "l", 20);
+    "Hilbert R trees, R+ trees, R++ trees, and octrees).", "l", 20);
 PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
     "random orthogonal basis.", "R");
 PARAM_INT_IN("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0);
@@ -191,10 +191,12 @@ int main(int argc, char *argv[])
       tree = RSModel::MAX_RP_TREE;
     else if (treeType == "ub")
       tree = RSModel::UB_TREE;
+    else if (treeType == "octree")
+      tree = RSModel::OCTREE;
     else
       Log::Fatal << "Unknown tree type '" << treeType << "; valid choices are "
           << "'kd', 'vp', 'rp', 'max-rp', 'ub', 'cover', 'r', 'r-star', 'x', "
-          << "'ball', 'hilbert-r', 'r-plus' and 'r-plus-plus'." << endl;
+          << "'ball', 'hilbert-r', 'r-plus', 'r-plus-plus', and 'octree'." << endl;
 
     rs.TreeType() = tree;
     rs.RandomBasis() = randomBasis;
diff --git a/src/mlpack/methods/range_search/rs_model.cpp b/src/mlpack/methods/range_search/rs_model.cpp
index 6025471..140b53c 100644
--- a/src/mlpack/methods/range_search/rs_model.cpp
+++ b/src/mlpack/methods/range_search/rs_model.cpp
@@ -29,7 +29,8 @@ RSModel::RSModel(TreeTypes treeType, bool randomBasis) :
     vpTreeRS(NULL),
     rpTreeRS(NULL),
     maxRPTreeRS(NULL),
-    ubTreeRS(NULL)
+    ubTreeRS(NULL),
+    octreeRS(NULL)
 {
   // Nothing to do.
 }
@@ -164,6 +165,28 @@ void RSModel::BuildModel(arma::mat&& referenceSet,
       ubTreeRS = new RSType<tree::UBTree>(move(referenceSet),
           naive, singleMode);
       break;
+
+    case OCTREE:
+      // If necessary, build the octree.
+      if (naive)
+      {
+        octreeRS = new RSType<tree::Octree>(move(referenceSet), naive,
+            singleMode);
+      }
+      else
+      {
+        vector<size_t> oldFromNewReferences;
+        RSType<tree::Octree>::Tree* octree =
+            new RSType<tree::Octree>::Tree(move(referenceSet),
+            oldFromNewReferences, leafSize);
+        octreeRS = new RSType<tree::Octree>(octree, singleMode);
+
+        // Give the model ownership of the tree and the mappings.
+        octreeRS->treeOwner = true;
+        octreeRS->oldFromNewReferences = move(oldFromNewReferences);
+      }
+
+      break;
   }
 
   if (!naive)
@@ -301,6 +324,38 @@ void RSModel::Search(arma::mat&& querySet,
     case UB_TREE:
       ubTreeRS->Search(querySet, range, neighbors, distances);
       break;
+
+    case OCTREE:
+      if (!octreeRS->Naive() && !octreeRS->SingleMode())
+      {
+        // Build a query tree and search.
+        Timer::Start("tree_building");
+        Log::Info << "Building query tree..." << endl;
+        vector<size_t> oldFromNewQueries;
+        RSType<tree::Octree>::Tree queryTree(move(querySet), oldFromNewQueries,
+            leafSize);
+        Log::Info << "Tree built." << endl;
+        Timer::Stop("tree_building");
+
+        vector<vector<size_t>> neighborsOut;
+        vector<vector<double>> distancesOut;
+        octreeRS->Search(&queryTree, range, neighborsOut, distancesOut);
+
+        // Remap the query points.
+        neighbors.resize(queryTree.Dataset().n_cols);
+        distances.resize(queryTree.Dataset().n_cols);
+        for (size_t i = 0; i < queryTree.Dataset().n_cols; ++i)
+        {
+          neighbors[oldFromNewQueries[i]] = neighborsOut[i];
+          distances[oldFromNewQueries[i]] = distancesOut[i];
+        }
+      }
+      else
+      {
+        // Search without building a second tree.
+        octreeRS->Search(querySet, range, neighbors, distances);
+      }
+      break;
   }
 }
 
@@ -371,6 +426,10 @@ void RSModel::Search(const math::Range& range,
     case UB_TREE:
       ubTreeRS->Search(range, neighbors, distances);
       break;
+
+    case OCTREE:
+      octreeRS->Search(range, neighbors, distances);
+      break;
   }
 }
 
@@ -405,6 +464,8 @@ std::string RSModel::TreeName() const
       return "random projection tree (max split)";
     case UB_TREE:
       return "UB tree";
+    case OCTREE:
+      return "octree";
     default:
       return "unknown tree";
   }
@@ -413,32 +474,20 @@ std::string RSModel::TreeName() const
 // Clean memory.
 void RSModel::CleanMemory()
 {
-  if (kdTreeRS)
-    delete kdTreeRS;
-  if (coverTreeRS)
-    delete coverTreeRS;
-  if (rTreeRS)
-    delete rTreeRS;
-  if (rStarTreeRS)
-    delete rStarTreeRS;
-  if (ballTreeRS)
-    delete ballTreeRS;
-  if (xTreeRS)
-    delete xTreeRS;
-  if (hilbertRTreeRS)
-    delete hilbertRTreeRS;
-  if (rPlusTreeRS)
-    delete rPlusTreeRS;
-  if (rPlusPlusTreeRS)
-    delete rPlusPlusTreeRS;
-  if (vpTreeRS)
-    delete vpTreeRS;
-  if (rpTreeRS)
-    delete rpTreeRS;
-  if (maxRPTreeRS)
-    delete maxRPTreeRS;
-  if (ubTreeRS)
-    delete ubTreeRS;
+  delete kdTreeRS;
+  delete coverTreeRS;
+  delete rTreeRS;
+  delete rStarTreeRS;
+  delete ballTreeRS;
+  delete xTreeRS;
+  delete hilbertRTreeRS;
+  delete rPlusTreeRS;
+  delete rPlusPlusTreeRS;
+  delete vpTreeRS;
+  delete rpTreeRS;
+  delete maxRPTreeRS;
+  delete ubTreeRS;
+  delete octreeRS;
 
   kdTreeRS = NULL;
   coverTreeRS = NULL;
@@ -453,4 +502,5 @@ void RSModel::CleanMemory()
   rpTreeRS = NULL;
   maxRPTreeRS = NULL;
   ubTreeRS = NULL;
+  octreeRS = NULL;
 }
diff --git a/src/mlpack/methods/range_search/rs_model.hpp b/src/mlpack/methods/range_search/rs_model.hpp
index 659319d..bea33d5 100644
--- a/src/mlpack/methods/range_search/rs_model.hpp
+++ b/src/mlpack/methods/range_search/rs_model.hpp
@@ -13,6 +13,7 @@
 #include <mlpack/core/tree/binary_space_tree.hpp>
 #include <mlpack/core/tree/cover_tree.hpp>
 #include <mlpack/core/tree/rectangle_tree.hpp>
+#include <mlpack/core/tree/octree.hpp>
 
 #include "range_search.hpp"
 
@@ -36,7 +37,8 @@ class RSModel
     VP_TREE,
     RP_TREE,
     MAX_RP_TREE,
-    UB_TREE
+    UB_TREE,
+    OCTREE
   };
 
  private:
@@ -84,6 +86,8 @@ class RSModel
   //! Universal B tree based range search object
   //! (NULL if not in use).
   RSType<tree::UBTree>* ubTreeRS;
+  //! Octree-based range search object (NULL if not in use).
+  RSType<tree::Octree>* octreeRS;
 
  public:
   /**
diff --git a/src/mlpack/methods/range_search/rs_model_impl.hpp b/src/mlpack/methods/range_search/rs_model_impl.hpp
index 2a8502f..69b7b71 100644
--- a/src/mlpack/methods/range_search/rs_model_impl.hpp
+++ b/src/mlpack/methods/range_search/rs_model_impl.hpp
@@ -81,6 +81,10 @@ void RSModel::Serialize(Archive& ar, const unsigned int /* version */)
     case UB_TREE:
       ar & CreateNVP(ubTreeRS, "range_search_model");
       break;
+
+    case OCTREE:
+      ar & CreateNVP(octreeRS, "range_search_model");
+      break;
   }
 }
 
@@ -112,6 +116,8 @@ inline const arma::mat& RSModel::Dataset() const
     return maxRPTreeRS->ReferenceSet();
   else if (ubTreeRS)
     return ubTreeRS->ReferenceSet();
+  else if (octreeRS)
+    return octreeRS->ReferenceSet();
 
   throw std::runtime_error("no range search model initialized");
 }
@@ -144,6 +150,8 @@ inline bool RSModel::SingleMode() const
     return maxRPTreeRS->SingleMode();
   else if (ubTreeRS)
     return ubTreeRS->SingleMode();
+  else if (octreeRS)
+    return octreeRS->SingleMode();
 
   throw std::runtime_error("no range search model initialized");
 }
@@ -176,6 +184,8 @@ inline bool& RSModel::SingleMode()
     return maxRPTreeRS->SingleMode();
   else if (ubTreeRS)
     return ubTreeRS->SingleMode();
+  else if (octreeRS)
+    return octreeRS->SingleMode();
 
   throw std::runtime_error("no range search model initialized");
 }
@@ -208,6 +218,8 @@ inline bool RSModel::Naive() const
     return maxRPTreeRS->Naive();
   else if (ubTreeRS)
     return ubTreeRS->Naive();
+  else if (octreeRS)
+    return octreeRS->Naive();
 
   throw std::runtime_error("no range search model initialized");
 }
@@ -240,6 +252,8 @@ inline bool& RSModel::Naive()
     return maxRPTreeRS->Naive();
   else if (ubTreeRS)
     return ubTreeRS->Naive();
+  else if (octreeRS)
+    return octreeRS->Naive();
 
   throw std::runtime_error("no range search model initialized");
 }
diff --git a/src/mlpack/methods/rann/krann_main.cpp b/src/mlpack/methods/rann/krann_main.cpp
index 591d741..7830f89 100644
--- a/src/mlpack/methods/rann/krann_main.cpp
+++ b/src/mlpack/methods/rann/krann_main.cpp
@@ -65,10 +65,11 @@ PARAM_INT_IN("k", "Number of nearest neighbors to find.", "k", 0);
 // The user may specify the type of tree to use, and a few parameters for tree
 // building.
 PARAM_STRING_IN("tree_type", "Type of tree to use: 'kd', 'ub', 'cover', 'r', "
-    "'x', 'r-star', 'hilbert-r', 'r-plus', 'r-plus-plus'.", "t", "kd");
+    "'x', 'r-star', 'hilbert-r', 'r-plus', 'r-plus-plus', 'octree'.", "t",
+    "kd");
 PARAM_INT_IN("leaf_size", "Leaf size for tree building (used for kd-trees, "
-    "UB trees, R trees, R* trees, X trees, Hilbert R trees, R+ trees and "
-    "R++ trees).", "l", 20);
+    "UB trees, R trees, R* trees, X trees, Hilbert R trees, R+ trees, "
+    "R++ trees, and octrees).", "l", 20);
 PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
     "random orthogonal basis.", "R");
 PARAM_INT_IN("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0);
@@ -182,10 +183,12 @@ int main(int argc, char *argv[])
       tree = RANNModel::R_PLUS_PLUS_TREE;
     else if (treeType == "ub")
       tree = RANNModel::UB_TREE;
+    else if (treeType == "octree")
+      tree = RANNModel::OCTREE;
     else
       Log::Fatal << "Unknown tree type '" << treeType << "'; valid choices are "
           << "'kd', 'ub', 'cover', 'r', 'r-star', 'x', 'hilbert-r', "
-          << "'r-plus' and 'r-plus-plus'." << endl;
+          << "'r-plus', 'r-plus-plus', 'octree'." << endl;
 
     rann.TreeType() = tree;
     rann.RandomBasis() = randomBasis;
diff --git a/src/mlpack/methods/rann/ra_model.hpp b/src/mlpack/methods/rann/ra_model.hpp
index 48be9e9..676a555 100644
--- a/src/mlpack/methods/rann/ra_model.hpp
+++ b/src/mlpack/methods/rann/ra_model.hpp
@@ -12,6 +12,7 @@
 #include <mlpack/core/tree/binary_space_tree.hpp>
 #include <mlpack/core/tree/cover_tree.hpp>
 #include <mlpack/core/tree/rectangle_tree.hpp>
+#include <mlpack/core/tree/octree.hpp>
 
 #include "ra_search.hpp"
 
@@ -44,7 +45,8 @@ class RAModel
     HILBERT_R_TREE,
     R_PLUS_TREE,
     R_PLUS_PLUS_TREE,
-    UB_TREE
+    UB_TREE,
+    OCTREE
   };
 
  private:
@@ -85,6 +87,8 @@ class RAModel
   RAType<tree::RPlusPlusTree>* rPlusPlusTreeRA;
   //! Non-NULL if the UB tree is used.
   RAType<tree::UBTree>* ubTreeRA;
+  //! Non-NULL if the octree is used.
+  RAType<tree::Octree>* octreeRA;
 
  public:
   /**
diff --git a/src/mlpack/methods/rann/ra_model_impl.hpp b/src/mlpack/methods/rann/ra_model_impl.hpp
index 30452ef..8184c35 100644
--- a/src/mlpack/methods/rann/ra_model_impl.hpp
+++ b/src/mlpack/methods/rann/ra_model_impl.hpp
@@ -26,7 +26,8 @@ RAModel<SortPolicy>::RAModel(const TreeTypes treeType, const bool randomBasis) :
     hilbertRTreeRA(NULL),
     rPlusTreeRA(NULL),
     rPlusPlusTreeRA(NULL),
-    ubTreeRA(NULL)
+    ubTreeRA(NULL),
+    octreeRA(NULL)
 {
   // Nothing to do.
 }
@@ -34,24 +35,16 @@ RAModel<SortPolicy>::RAModel(const TreeTypes treeType, const bool randomBasis) :
 template<typename SortPolicy>
 RAModel<SortPolicy>::~RAModel()
 {
-  if (kdTreeRA)
-    delete kdTreeRA;
-  if (coverTreeRA)
-    delete coverTreeRA;
-  if (rTreeRA)
-    delete rTreeRA;
-  if (rStarTreeRA)
-    delete rStarTreeRA;
-  if (xTreeRA)
-    delete xTreeRA;
-  if (hilbertRTreeRA)
-    delete hilbertRTreeRA;
-  if (rPlusTreeRA)
-    delete rPlusTreeRA;
-  if (rPlusPlusTreeRA)
-    delete rPlusPlusTreeRA;
-  if (ubTreeRA)
-    delete ubTreeRA;
+  delete kdTreeRA;
+  delete coverTreeRA;
+  delete rTreeRA;
+  delete rStarTreeRA;
+  delete xTreeRA;
+  delete hilbertRTreeRA;
+  delete rPlusTreeRA;
+  delete rPlusPlusTreeRA;
+  delete ubTreeRA;
+  delete octreeRA;
 }
 
 template<typename SortPolicy>
@@ -66,24 +59,16 @@ void RAModel<SortPolicy>::Serialize(Archive& ar,
   // This should never happen, but just in case, be clean with memory.
   if (Archive::is_loading::value)
   {
-    if (kdTreeRA)
-      delete kdTreeRA;
-    if (coverTreeRA)
-      delete coverTreeRA;
-    if (rTreeRA)
-      delete rTreeRA;
-    if (rStarTreeRA)
-      delete rStarTreeRA;
-    if (xTreeRA)
-      delete xTreeRA;
-    if (hilbertRTreeRA)
-      delete hilbertRTreeRA;
-    if (rPlusTreeRA)
-      delete rPlusTreeRA;
-    if (rPlusPlusTreeRA)
-      delete rPlusPlusTreeRA;
-    if (ubTreeRA)
-      delete ubTreeRA;
+    delete kdTreeRA;
+    delete coverTreeRA;
+    delete rTreeRA;
+    delete rStarTreeRA;
+    delete xTreeRA;
+    delete hilbertRTreeRA;
+    delete rPlusTreeRA;
+    delete rPlusPlusTreeRA;
+    delete ubTreeRA;
+    delete octreeRA;
 
     // Set all the pointers to NULL.
     kdTreeRA = NULL;
@@ -127,6 +112,9 @@ void RAModel<SortPolicy>::Serialize(Archive& ar,
     case UB_TREE:
       ar & data::CreateNVP(ubTreeRA, "ra_model");
       break;
+    case OCTREE:
+      ar & data::CreateNVP(octreeRA, "ra_model");
+      break;
   }
 }
 
@@ -151,6 +139,8 @@ const arma::mat& RAModel<SortPolicy>::Dataset() const
     return rPlusPlusTreeRA->ReferenceSet();
   else if (ubTreeRA)
     return ubTreeRA->ReferenceSet();
+  else if (octreeRA)
+    return octreeRA->ReferenceSet();
 
   throw std::runtime_error("no rank-approximate nearest neighbor search model "
       "initialized");
@@ -177,6 +167,8 @@ bool RAModel<SortPolicy>::Naive() const
     return rPlusPlusTreeRA->Naive();
   else if (ubTreeRA)
     return ubTreeRA->Naive();
+  else if (octreeRA)
+    return octreeRA->Naive();
 
   throw std::runtime_error("no rank-approximate nearest neighbor search model "
       "initialized");
@@ -203,6 +195,8 @@ bool& RAModel<SortPolicy>::Naive()
     return rPlusPlusTreeRA->Naive();
   else if (ubTreeRA)
     return ubTreeRA->Naive();
+  else if (octreeRA)
+    return octreeRA->Naive();
 
   throw std::runtime_error("no rank-approximate nearest neighbor search model "
       "initialized");
@@ -229,6 +223,8 @@ bool RAModel<SortPolicy>::SingleMode() const
     return rPlusPlusTreeRA->SingleMode();
   else if (ubTreeRA)
     return ubTreeRA->SingleMode();
+  else if (octreeRA)
+    return octreeRA->SingleMode();
 
   throw std::runtime_error("no rank-approximate nearest neighbor search model "
       "initialized");
@@ -255,6 +251,8 @@ bool& RAModel<SortPolicy>::SingleMode()
     return rPlusPlusTreeRA->SingleMode();
   else if (ubTreeRA)
     return ubTreeRA->SingleMode();
+  else if (octreeRA)
+    return octreeRA->SingleMode();
 
   throw std::runtime_error("no rank-approximate nearest neighbor search model "
       "initialized");
@@ -281,6 +279,8 @@ double RAModel<SortPolicy>::Tau() const
     return rPlusPlusTreeRA->Tau();
   else if (ubTreeRA)
     return ubTreeRA->Tau();
+  else if (octreeRA)
+    return octreeRA->Tau();
 
   throw std::runtime_error("no rank-approximate nearest neighbor search model "
       "initialized");
@@ -307,6 +307,8 @@ double& RAModel<SortPolicy>::Tau()
     return rPlusPlusTreeRA->Tau();
   else if (ubTreeRA)
     return ubTreeRA->Tau();
+  else if (octreeRA)
+    return octreeRA->Tau();
 
   throw std::runtime_error("no rank-approximate nearest neighbor search model "
       "initialized");
@@ -333,6 +335,8 @@ double RAModel<SortPolicy>::Alpha() const
     return rPlusPlusTreeRA->Alpha();
   else if (ubTreeRA)
     return ubTreeRA->Alpha();
+  else if (octreeRA)
+    return octreeRA->Alpha();
 
   throw std::runtime_error("no rank-approximate nearest neighbor search model "
       "initialized");
@@ -359,6 +363,8 @@ double& RAModel<SortPolicy>::Alpha()
     return rPlusPlusTreeRA->Alpha();
   else if (ubTreeRA)
     return ubTreeRA->Alpha();
+  else if (octreeRA)
+    return octreeRA->Alpha();
 
   throw std::runtime_error("no rank-approximate nearest neighbor search model "
       "initialized");
@@ -385,6 +391,8 @@ bool RAModel<SortPolicy>::SampleAtLeaves() const
     return rPlusPlusTreeRA->SampleAtLeaves();
   else if (ubTreeRA)
     return ubTreeRA->SampleAtLeaves();
+  else if (octreeRA)
+    return octreeRA->SampleAtLeaves();
 
   throw std::runtime_error("no rank-approximate nearest neighbor search model "
       "initialized");
@@ -411,6 +419,8 @@ bool& RAModel<SortPolicy>::SampleAtLeaves()
     return rPlusPlusTreeRA->SampleAtLeaves();
   else if (ubTreeRA)
     return ubTreeRA->SampleAtLeaves();
+  else if (octreeRA)
+    return octreeRA->SampleAtLeaves();
 
   throw std::runtime_error("no rank-approximate nearest neighbor search model "
       "initialized");
@@ -437,6 +447,8 @@ bool RAModel<SortPolicy>::FirstLeafExact() const
     return rPlusPlusTreeRA->FirstLeafExact();
   else if (ubTreeRA)
     return ubTreeRA->FirstLeafExact();
+  else if (octreeRA)
+    return octreeRA->FirstLeafExact();
 
   throw std::runtime_error("no rank-approximate nearest neighbor search model "
       "initialized");
@@ -463,6 +475,8 @@ bool& RAModel<SortPolicy>::FirstLeafExact()
     return rPlusPlusTreeRA->FirstLeafExact();
   else if (ubTreeRA)
     return ubTreeRA->FirstLeafExact();
+  else if (octreeRA)
+    return octreeRA->FirstLeafExact();
 
   throw std::runtime_error("no rank-approximate nearest neighbor search model "
       "initialized");
@@ -489,6 +503,8 @@ size_t RAModel<SortPolicy>::SingleSampleLimit() const
     return rPlusPlusTreeRA->SingleSampleLimit();
   else if (ubTreeRA)
     return ubTreeRA->SingleSampleLimit();
+  else if (octreeRA)
+    return octreeRA->SingleSampleLimit();
 
   throw std::runtime_error("no rank-approximate nearest neighbor search model "
       "initialized");
@@ -515,6 +531,8 @@ size_t& RAModel<SortPolicy>::SingleSampleLimit()
     return rPlusPlusTreeRA->SingleSampleLimit();
   else if (ubTreeRA)
     return ubTreeRA->SingleSampleLimit();
+  else if (octreeRA)
+    return octreeRA->SingleSampleLimit();
 
   throw std::runtime_error("no rank-approximate nearest neighbor search model "
       "initialized");
@@ -570,24 +588,16 @@ void RAModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
   }
 
   // Clean memory, if necessary.
-  if (kdTreeRA)
-    delete kdTreeRA;
-  if (coverTreeRA)
-    delete coverTreeRA;
-  if (rTreeRA)
-    delete rTreeRA;
-  if (rStarTreeRA)
-    delete rStarTreeRA;
-  if (xTreeRA)
-    delete xTreeRA;
-  if (hilbertRTreeRA)
-    delete hilbertRTreeRA;
-  if (rPlusTreeRA)
-    delete rPlusTreeRA;
-  if (rPlusPlusTreeRA)
-    delete rPlusPlusTreeRA;
-  if (ubTreeRA)
-    delete ubTreeRA;
+  delete kdTreeRA;
+  delete coverTreeRA;
+  delete rTreeRA;
+  delete rStarTreeRA;
+  delete xTreeRA;
+  delete hilbertRTreeRA;
+  delete rPlusTreeRA;
+  delete rPlusPlusTreeRA;
+  delete ubTreeRA;
+  delete octreeRA;
 
   if (randomBasis)
     referenceSet = q * referenceSet;
@@ -652,6 +662,26 @@ void RAModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
       ubTreeRA = new RAType<tree::UBTree>(std::move(referenceSet),
           naive, singleMode);
       break;
+    case OCTREE:
+      // Build tree, if necessary.
+      if (naive)
+      {
+        octreeRA = new RAType<tree::Octree>(std::move(referenceSet), naive,
+            singleMode);
+      }
+      else
+      {
+        std::vector<size_t> oldFromNewReferences;
+        typename RAType<tree::Octree>::Tree* octree =
+            new typename RAType<tree::Octree>::Tree(std::move(referenceSet),
+            oldFromNewReferences, leafSize);
+        octreeRA = new RAType<tree::Octree>(octree, singleMode);
+
+        // Give the model ownership of the tree.
+        octreeRA->treeOwner = true;
+        octreeRA->oldFromNewReferences = oldFromNewReferences;
+      }
+      break;
   }
 
   if (!naive)
@@ -745,6 +775,37 @@ void RAModel<SortPolicy>::Search(arma::mat&& querySet,
       // No mapping necessary.
       ubTreeRA->Search(querySet, k, neighbors, distances);
       break;
+    case OCTREE:
+      if (!octreeRA->Naive() && !octreeRA->SingleMode())
+      {
+        // Build a second tree and search.
+        Timer::Start("tree_building");
+        Log::Info << "Building query tree..." << std::endl;
+        std::vector<size_t> oldFromNewQueries;
+        typename RAType<tree::Octree>::Tree queryTree(std::move(querySet),
+            oldFromNewQueries, leafSize);
+        Log::Info << "Tree built." << std::endl;
+        Timer::Stop("tree_building");
+
+        arma::Mat<size_t> neighborsOut;
+        arma::mat distancesOut;
+        octreeRA->Search(&queryTree, k, neighborsOut, distancesOut);
+
+        // Unmap the query points.
+        distances.set_size(distancesOut.n_rows, distancesOut.n_cols);
+        neighbors.set_size(neighborsOut.n_rows, neighborsOut.n_cols);
+        for (size_t i = 0; i < neighborsOut.n_cols; ++i)
+        {
+          neighbors.col(oldFromNewQueries[i]) = neighborsOut.col(i);
+          distances.col(oldFromNewQueries[i]) = distancesOut.col(i);
+        }
+      }
+      else
+      {
+        // Search without building a second tree.
+        octreeRA->Search(querySet, k, neighbors, distances);
+      }
+      break;
   }
 }
 
@@ -791,6 +852,9 @@ void RAModel<SortPolicy>::Search(const size_t k,
     case UB_TREE:
       ubTreeRA->Search(k, neighbors, distances);
       break;
+    case OCTREE:
+      octreeRA->Search(k, neighbors, distances);
+      break;
   }
 }
 
@@ -817,6 +881,8 @@ std::string RAModel<SortPolicy>::TreeName() const
       return "R++ tree";
     case UB_TREE:
       return "UB tree";
+    case OCTREE:
+      return "octree";
     default:
       return "unknown tree";
   }
diff --git a/src/mlpack/tests/krann_search_test.cpp b/src/mlpack/tests/krann_search_test.cpp
index 3d3f918..c0bcadb 100644
--- a/src/mlpack/tests/krann_search_test.cpp
+++ b/src/mlpack/tests/krann_search_test.cpp
@@ -625,7 +625,7 @@ BOOST_AUTO_TEST_CASE(RAModelTest)
   data::Load("rann_test_q_3_100.csv", queryData, true);
 
   // Build all the possible models.
-  KNNModel models[18];
+  KNNModel models[20];
   models[0] = KNNModel(KNNModel::TreeTypes::KD_TREE, false);
   models[1] = KNNModel(KNNModel::TreeTypes::KD_TREE, true);
   models[2] = KNNModel(KNNModel::TreeTypes::COVER_TREE, false);
@@ -644,13 +644,15 @@ BOOST_AUTO_TEST_CASE(RAModelTest)
   models[15] = KNNModel(KNNModel::TreeTypes::R_PLUS_PLUS_TREE, true);
   models[16] = KNNModel(KNNModel::TreeTypes::UB_TREE, false);
   models[17] = KNNModel(KNNModel::TreeTypes::UB_TREE, true);
+  models[18] = KNNModel(KNNModel::TreeTypes::OCTREE, false);
+  models[19] = KNNModel(KNNModel::TreeTypes::OCTREE, true);
 
   arma::Mat<size_t> qrRanks;
   data::Load("rann_test_qr_ranks.csv", qrRanks, true, false); // No transpose.
 
   for (size_t j = 0; j < 3; ++j)
   {
-    for (size_t i = 0; i < 18; ++i)
+    for (size_t i = 0; i < 20; ++i)
     {
       // We only have std::move() constructors so make a copy of our data.
       arma::mat referenceCopy(referenceData);
diff --git a/src/mlpack/tests/range_search_test.cpp b/src/mlpack/tests/range_search_test.cpp
index 0968b07..3619691 100644
--- a/src/mlpack/tests/range_search_test.cpp
+++ b/src/mlpack/tests/range_search_test.cpp
@@ -1249,7 +1249,7 @@ BOOST_AUTO_TEST_CASE(RSModelTest)
   arma::mat referenceData = arma::randu<arma::mat>(10, 200);
 
   // Build all the possible models.
-  RSModel models[26];
+  RSModel models[28];
   models[0] = RSModel(RSModel::TreeTypes::KD_TREE, true);
   models[1] = RSModel(RSModel::TreeTypes::KD_TREE, false);
   models[2] = RSModel(RSModel::TreeTypes::COVER_TREE, true);
@@ -1276,6 +1276,8 @@ BOOST_AUTO_TEST_CASE(RSModelTest)
   models[23] = RSModel(RSModel::TreeTypes::MAX_RP_TREE, false);
   models[24] = RSModel(RSModel::TreeTypes::UB_TREE, true);
   models[25] = RSModel(RSModel::TreeTypes::UB_TREE, false);
+  models[26] = RSModel(RSModel::TreeTypes::OCTREE, true);
+  models[27] = RSModel(RSModel::TreeTypes::OCTREE, false);
 
   for (size_t j = 0; j < 2; ++j)
   {
@@ -1289,7 +1291,7 @@ BOOST_AUTO_TEST_CASE(RSModelTest)
     vector<vector<pair<double, size_t>>> baselineSorted;
     SortResults(baselineNeighbors, baselineDistances, baselineSorted);
 
-    for (size_t i = 0; i < 26; ++i)
+    for (size_t i = 0; i < 28; ++i)
     {
       // We only have std::move() constructors, so make a copy of our data.
       arma::mat referenceCopy(referenceData);
@@ -1333,7 +1335,7 @@ BOOST_AUTO_TEST_CASE(RSModelMonochromaticTest)
   arma::mat referenceData = arma::randu<arma::mat>(10, 200);
 
   // Build all the possible models.
-  RSModel models[26];
+  RSModel models[28];
   models[0] = RSModel(RSModel::TreeTypes::KD_TREE, true);
   models[1] = RSModel(RSModel::TreeTypes::KD_TREE, false);
   models[2] = RSModel(RSModel::TreeTypes::COVER_TREE, true);
@@ -1360,6 +1362,8 @@ BOOST_AUTO_TEST_CASE(RSModelMonochromaticTest)
   models[23] = RSModel(RSModel::TreeTypes::MAX_RP_TREE, false);
   models[24] = RSModel(RSModel::TreeTypes::MAX_RP_TREE, true);
   models[25] = RSModel(RSModel::TreeTypes::MAX_RP_TREE, false);
+  models[26] = RSModel(RSModel::TreeTypes::OCTREE, true);
+  models[27] = RSModel(RSModel::TreeTypes::OCTREE, false);
 
   for (size_t j = 0; j < 2; ++j)
   {
@@ -1372,7 +1376,7 @@ BOOST_AUTO_TEST_CASE(RSModelMonochromaticTest)
     vector<vector<pair<double, size_t>>> baselineSorted;
     SortResults(baselineNeighbors, baselineDistances, baselineSorted);
 
-    for (size_t i = 0; i < 26; ++i)
+    for (size_t i = 0; i < 28; ++i)
     {
       // We only have std::move() cosntructors, so make a copy of our data.
       arma::mat referenceCopy(referenceData);




More information about the mlpack-git mailing list