[mlpack-git] master: Add octree to RangeSearch and RASearch. (3bdd609)
gitdub at mlpack.org
gitdub at mlpack.org
Sat Sep 24 13:20:55 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/9ef7339d40550a974b3939e9fcb966fac2c09065...ebdb5abeaa3fd621a06ae663862bb72df76d2b40
>---------------------------------------------------------------
commit 3bdd6099936bc75173366a00aedb3bb93086bd37
Author: Ryan Curtin <ryan at ratml.org>
Date: Sat Sep 24 13:20:55 2016 -0400
Add octree to RangeSearch and RASearch.
>---------------------------------------------------------------
3bdd6099936bc75173366a00aedb3bb93086bd37
.../methods/range_search/range_search_main.cpp | 8 +-
src/mlpack/methods/range_search/rs_model.cpp | 104 ++++++++----
src/mlpack/methods/range_search/rs_model.hpp | 6 +-
src/mlpack/methods/range_search/rs_model_impl.hpp | 14 ++
src/mlpack/methods/rann/krann_main.cpp | 11 +-
src/mlpack/methods/rann/ra_model.hpp | 6 +-
src/mlpack/methods/rann/ra_model_impl.hpp | 176 ++++++++++++++-------
src/mlpack/tests/krann_search_test.cpp | 6 +-
src/mlpack/tests/range_search_test.cpp | 12 +-
9 files changed, 246 insertions(+), 97 deletions(-)
diff --git a/src/mlpack/methods/range_search/range_search_main.cpp b/src/mlpack/methods/range_search/range_search_main.cpp
index 9487832..990e3c1 100644
--- a/src/mlpack/methods/range_search/range_search_main.cpp
+++ b/src/mlpack/methods/range_search/range_search_main.cpp
@@ -72,10 +72,10 @@ PARAM_DOUBLE_IN("min", "Lower bound in range.", "L", 0.0);
// building.
PARAM_STRING_IN("tree_type", "Type of tree to use: 'kd', 'vp', 'rp', 'max-rp', "
"'ub', 'cover', 'r', 'r-star', 'x', 'ball', 'hilbert-r', 'r-plus', "
- "'r-plus-plus'.", "t", "kd");
+ "'r-plus-plus', 'octree'.", "t", "kd");
PARAM_INT_IN("leaf_size", "Leaf size for tree building (used for kd-trees, "
"vp trees, random projection trees, UB trees, R trees, R* trees, X trees, "
- "Hilbert R trees, R+ trees and R++ trees).", "l", 20);
+ "Hilbert R trees, R+ trees, R++ trees, and octrees).", "l", 20);
PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
"random orthogonal basis.", "R");
PARAM_INT_IN("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0);
@@ -191,10 +191,12 @@ int main(int argc, char *argv[])
tree = RSModel::MAX_RP_TREE;
else if (treeType == "ub")
tree = RSModel::UB_TREE;
+ else if (treeType == "octree")
+ tree = RSModel::OCTREE;
else
- Log::Fatal << "Unknown tree type '" << treeType << "; valid choices are "
+ Log::Fatal << "Unknown tree type '" << treeType << "'; valid choices are "
<< "'kd', 'vp', 'rp', 'max-rp', 'ub', 'cover', 'r', 'r-star', 'x', "
- << "'ball', 'hilbert-r', 'r-plus' and 'r-plus-plus'." << endl;
+ << "'ball', 'hilbert-r', 'r-plus', 'r-plus-plus', and 'octree'." << endl;
rs.TreeType() = tree;
rs.RandomBasis() = randomBasis;
diff --git a/src/mlpack/methods/range_search/rs_model.cpp b/src/mlpack/methods/range_search/rs_model.cpp
index 6025471..140b53c 100644
--- a/src/mlpack/methods/range_search/rs_model.cpp
+++ b/src/mlpack/methods/range_search/rs_model.cpp
@@ -29,7 +29,8 @@ RSModel::RSModel(TreeTypes treeType, bool randomBasis) :
vpTreeRS(NULL),
rpTreeRS(NULL),
maxRPTreeRS(NULL),
- ubTreeRS(NULL)
+ ubTreeRS(NULL),
+ octreeRS(NULL)
{
// Nothing to do.
}
@@ -164,6 +165,28 @@ void RSModel::BuildModel(arma::mat&& referenceSet,
ubTreeRS = new RSType<tree::UBTree>(move(referenceSet),
naive, singleMode);
break;
+
+ case OCTREE:
+ // If necessary, build the octree.
+ if (naive)
+ {
+ octreeRS = new RSType<tree::Octree>(move(referenceSet), naive,
+ singleMode);
+ }
+ else
+ {
+ vector<size_t> oldFromNewReferences;
+ RSType<tree::Octree>::Tree* octree =
+ new RSType<tree::Octree>::Tree(move(referenceSet),
+ oldFromNewReferences, leafSize);
+ octreeRS = new RSType<tree::Octree>(octree, singleMode);
+
+ // Give the model ownership of the tree and the mappings.
+ octreeRS->treeOwner = true;
+ octreeRS->oldFromNewReferences = move(oldFromNewReferences);
+ }
+
+ break;
}
if (!naive)
@@ -301,6 +324,38 @@ void RSModel::Search(arma::mat&& querySet,
case UB_TREE:
ubTreeRS->Search(querySet, range, neighbors, distances);
break;
+
+ case OCTREE:
+ if (!octreeRS->Naive() && !octreeRS->SingleMode())
+ {
+ // Build a query tree and search.
+ Timer::Start("tree_building");
+ Log::Info << "Building query tree..." << endl;
+ vector<size_t> oldFromNewQueries;
+ RSType<tree::Octree>::Tree queryTree(move(querySet), oldFromNewQueries,
+ leafSize);
+ Log::Info << "Tree built." << endl;
+ Timer::Stop("tree_building");
+
+ vector<vector<size_t>> neighborsOut;
+ vector<vector<double>> distancesOut;
+ octreeRS->Search(&queryTree, range, neighborsOut, distancesOut);
+
+ // Remap the query points.
+ neighbors.resize(queryTree.Dataset().n_cols);
+ distances.resize(queryTree.Dataset().n_cols);
+ for (size_t i = 0; i < queryTree.Dataset().n_cols; ++i)
+ {
+ neighbors[oldFromNewQueries[i]] = neighborsOut[i];
+ distances[oldFromNewQueries[i]] = distancesOut[i];
+ }
+ }
+ else
+ {
+ // Search without building a second tree.
+ octreeRS->Search(querySet, range, neighbors, distances);
+ }
+ break;
}
}
@@ -371,6 +426,10 @@ void RSModel::Search(const math::Range& range,
case UB_TREE:
ubTreeRS->Search(range, neighbors, distances);
break;
+
+ case OCTREE:
+ octreeRS->Search(range, neighbors, distances);
+ break;
}
}
@@ -405,6 +464,8 @@ std::string RSModel::TreeName() const
return "random projection tree (max split)";
case UB_TREE:
return "UB tree";
+ case OCTREE:
+ return "octree";
default:
return "unknown tree";
}
@@ -413,32 +474,20 @@ std::string RSModel::TreeName() const
// Clean memory.
void RSModel::CleanMemory()
{
- if (kdTreeRS)
- delete kdTreeRS;
- if (coverTreeRS)
- delete coverTreeRS;
- if (rTreeRS)
- delete rTreeRS;
- if (rStarTreeRS)
- delete rStarTreeRS;
- if (ballTreeRS)
- delete ballTreeRS;
- if (xTreeRS)
- delete xTreeRS;
- if (hilbertRTreeRS)
- delete hilbertRTreeRS;
- if (rPlusTreeRS)
- delete rPlusTreeRS;
- if (rPlusPlusTreeRS)
- delete rPlusPlusTreeRS;
- if (vpTreeRS)
- delete vpTreeRS;
- if (rpTreeRS)
- delete rpTreeRS;
- if (maxRPTreeRS)
- delete maxRPTreeRS;
- if (ubTreeRS)
- delete ubTreeRS;
+ delete kdTreeRS;
+ delete coverTreeRS;
+ delete rTreeRS;
+ delete rStarTreeRS;
+ delete ballTreeRS;
+ delete xTreeRS;
+ delete hilbertRTreeRS;
+ delete rPlusTreeRS;
+ delete rPlusPlusTreeRS;
+ delete vpTreeRS;
+ delete rpTreeRS;
+ delete maxRPTreeRS;
+ delete ubTreeRS;
+ delete octreeRS;
kdTreeRS = NULL;
coverTreeRS = NULL;
@@ -453,4 +502,5 @@ void RSModel::CleanMemory()
rpTreeRS = NULL;
maxRPTreeRS = NULL;
ubTreeRS = NULL;
+ octreeRS = NULL;
}
diff --git a/src/mlpack/methods/range_search/rs_model.hpp b/src/mlpack/methods/range_search/rs_model.hpp
index 659319d..bea33d5 100644
--- a/src/mlpack/methods/range_search/rs_model.hpp
+++ b/src/mlpack/methods/range_search/rs_model.hpp
@@ -13,6 +13,7 @@
#include <mlpack/core/tree/binary_space_tree.hpp>
#include <mlpack/core/tree/cover_tree.hpp>
#include <mlpack/core/tree/rectangle_tree.hpp>
+#include <mlpack/core/tree/octree.hpp>
#include "range_search.hpp"
@@ -36,7 +37,8 @@ class RSModel
VP_TREE,
RP_TREE,
MAX_RP_TREE,
- UB_TREE
+ UB_TREE,
+ OCTREE
};
private:
@@ -84,6 +86,8 @@ class RSModel
//! Universal B tree based range search object
//! (NULL if not in use).
RSType<tree::UBTree>* ubTreeRS;
+ //! Octree-based range search object (NULL if not in use).
+ RSType<tree::Octree>* octreeRS;
public:
/**
diff --git a/src/mlpack/methods/range_search/rs_model_impl.hpp b/src/mlpack/methods/range_search/rs_model_impl.hpp
index 2a8502f..69b7b71 100644
--- a/src/mlpack/methods/range_search/rs_model_impl.hpp
+++ b/src/mlpack/methods/range_search/rs_model_impl.hpp
@@ -81,6 +81,10 @@ void RSModel::Serialize(Archive& ar, const unsigned int /* version */)
case UB_TREE:
ar & CreateNVP(ubTreeRS, "range_search_model");
break;
+
+ case OCTREE:
+ ar & CreateNVP(octreeRS, "range_search_model");
+ break;
}
}
@@ -112,6 +116,8 @@ inline const arma::mat& RSModel::Dataset() const
return maxRPTreeRS->ReferenceSet();
else if (ubTreeRS)
return ubTreeRS->ReferenceSet();
+ else if (octreeRS)
+ return octreeRS->ReferenceSet();
throw std::runtime_error("no range search model initialized");
}
@@ -144,6 +150,8 @@ inline bool RSModel::SingleMode() const
return maxRPTreeRS->SingleMode();
else if (ubTreeRS)
return ubTreeRS->SingleMode();
+ else if (octreeRS)
+ return octreeRS->SingleMode();
throw std::runtime_error("no range search model initialized");
}
@@ -176,6 +184,8 @@ inline bool& RSModel::SingleMode()
return maxRPTreeRS->SingleMode();
else if (ubTreeRS)
return ubTreeRS->SingleMode();
+ else if (octreeRS)
+ return octreeRS->SingleMode();
throw std::runtime_error("no range search model initialized");
}
@@ -208,6 +218,8 @@ inline bool RSModel::Naive() const
return maxRPTreeRS->Naive();
else if (ubTreeRS)
return ubTreeRS->Naive();
+ else if (octreeRS)
+ return octreeRS->Naive();
throw std::runtime_error("no range search model initialized");
}
@@ -240,6 +252,8 @@ inline bool& RSModel::Naive()
return maxRPTreeRS->Naive();
else if (ubTreeRS)
return ubTreeRS->Naive();
+ else if (octreeRS)
+ return octreeRS->Naive();
throw std::runtime_error("no range search model initialized");
}
diff --git a/src/mlpack/methods/rann/krann_main.cpp b/src/mlpack/methods/rann/krann_main.cpp
index 591d741..7830f89 100644
--- a/src/mlpack/methods/rann/krann_main.cpp
+++ b/src/mlpack/methods/rann/krann_main.cpp
@@ -65,10 +65,11 @@ PARAM_INT_IN("k", "Number of nearest neighbors to find.", "k", 0);
// The user may specify the type of tree to use, and a few parameters for tree
// building.
PARAM_STRING_IN("tree_type", "Type of tree to use: 'kd', 'ub', 'cover', 'r', "
- "'x', 'r-star', 'hilbert-r', 'r-plus', 'r-plus-plus'.", "t", "kd");
+ "'x', 'r-star', 'hilbert-r', 'r-plus', 'r-plus-plus', 'octree'.", "t",
+ "kd");
PARAM_INT_IN("leaf_size", "Leaf size for tree building (used for kd-trees, "
- "UB trees, R trees, R* trees, X trees, Hilbert R trees, R+ trees and "
- "R++ trees).", "l", 20);
+ "UB trees, R trees, R* trees, X trees, Hilbert R trees, R+ trees, "
+ "R++ trees, and octrees).", "l", 20);
PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
"random orthogonal basis.", "R");
PARAM_INT_IN("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0);
@@ -182,10 +183,12 @@ int main(int argc, char *argv[])
tree = RANNModel::R_PLUS_PLUS_TREE;
else if (treeType == "ub")
tree = RANNModel::UB_TREE;
+ else if (treeType == "octree")
+ tree = RANNModel::OCTREE;
else
Log::Fatal << "Unknown tree type '" << treeType << "'; valid choices are "
<< "'kd', 'ub', 'cover', 'r', 'r-star', 'x', 'hilbert-r', "
- << "'r-plus' and 'r-plus-plus'." << endl;
+ << "'r-plus', 'r-plus-plus', and 'octree'." << endl;
rann.TreeType() = tree;
rann.RandomBasis() = randomBasis;
diff --git a/src/mlpack/methods/rann/ra_model.hpp b/src/mlpack/methods/rann/ra_model.hpp
index 48be9e9..676a555 100644
--- a/src/mlpack/methods/rann/ra_model.hpp
+++ b/src/mlpack/methods/rann/ra_model.hpp
@@ -12,6 +12,7 @@
#include <mlpack/core/tree/binary_space_tree.hpp>
#include <mlpack/core/tree/cover_tree.hpp>
#include <mlpack/core/tree/rectangle_tree.hpp>
+#include <mlpack/core/tree/octree.hpp>
#include "ra_search.hpp"
@@ -44,7 +45,8 @@ class RAModel
HILBERT_R_TREE,
R_PLUS_TREE,
R_PLUS_PLUS_TREE,
- UB_TREE
+ UB_TREE,
+ OCTREE
};
private:
@@ -85,6 +87,8 @@ class RAModel
RAType<tree::RPlusPlusTree>* rPlusPlusTreeRA;
//! Non-NULL if the UB tree is used.
RAType<tree::UBTree>* ubTreeRA;
+ //! Non-NULL if the octree is used.
+ RAType<tree::Octree>* octreeRA;
public:
/**
diff --git a/src/mlpack/methods/rann/ra_model_impl.hpp b/src/mlpack/methods/rann/ra_model_impl.hpp
index 30452ef..8184c35 100644
--- a/src/mlpack/methods/rann/ra_model_impl.hpp
+++ b/src/mlpack/methods/rann/ra_model_impl.hpp
@@ -26,7 +26,8 @@ RAModel<SortPolicy>::RAModel(const TreeTypes treeType, const bool randomBasis) :
hilbertRTreeRA(NULL),
rPlusTreeRA(NULL),
rPlusPlusTreeRA(NULL),
- ubTreeRA(NULL)
+ ubTreeRA(NULL),
+ octreeRA(NULL)
{
// Nothing to do.
}
@@ -34,24 +35,16 @@ RAModel<SortPolicy>::RAModel(const TreeTypes treeType, const bool randomBasis) :
template<typename SortPolicy>
RAModel<SortPolicy>::~RAModel()
{
- if (kdTreeRA)
- delete kdTreeRA;
- if (coverTreeRA)
- delete coverTreeRA;
- if (rTreeRA)
- delete rTreeRA;
- if (rStarTreeRA)
- delete rStarTreeRA;
- if (xTreeRA)
- delete xTreeRA;
- if (hilbertRTreeRA)
- delete hilbertRTreeRA;
- if (rPlusTreeRA)
- delete rPlusTreeRA;
- if (rPlusPlusTreeRA)
- delete rPlusPlusTreeRA;
- if (ubTreeRA)
- delete ubTreeRA;
+ delete kdTreeRA;
+ delete coverTreeRA;
+ delete rTreeRA;
+ delete rStarTreeRA;
+ delete xTreeRA;
+ delete hilbertRTreeRA;
+ delete rPlusTreeRA;
+ delete rPlusPlusTreeRA;
+ delete ubTreeRA;
+ delete octreeRA;
}
template<typename SortPolicy>
@@ -66,24 +59,17 @@ void RAModel<SortPolicy>::Serialize(Archive& ar,
// This should never happen, but just in case, be clean with memory.
if (Archive::is_loading::value)
{
- if (kdTreeRA)
- delete kdTreeRA;
- if (coverTreeRA)
- delete coverTreeRA;
- if (rTreeRA)
- delete rTreeRA;
- if (rStarTreeRA)
- delete rStarTreeRA;
- if (xTreeRA)
- delete xTreeRA;
- if (hilbertRTreeRA)
- delete hilbertRTreeRA;
- if (rPlusTreeRA)
- delete rPlusTreeRA;
- if (rPlusPlusTreeRA)
- delete rPlusPlusTreeRA;
- if (ubTreeRA)
- delete ubTreeRA;
+ delete kdTreeRA;
+ delete coverTreeRA;
+ delete rTreeRA;
+ delete rStarTreeRA;
+ delete xTreeRA;
+ delete hilbertRTreeRA;
+ delete rPlusTreeRA;
+ delete rPlusPlusTreeRA;
+ delete ubTreeRA;
+ delete octreeRA;
// Set all the pointers to NULL.
+ octreeRA = NULL;
kdTreeRA = NULL;
@@ -127,6 +112,9 @@ void RAModel<SortPolicy>::Serialize(Archive& ar,
case UB_TREE:
ar & data::CreateNVP(ubTreeRA, "ra_model");
break;
+ case OCTREE:
+ ar & data::CreateNVP(octreeRA, "ra_model");
+ break;
}
}
@@ -151,6 +139,8 @@ const arma::mat& RAModel<SortPolicy>::Dataset() const
return rPlusPlusTreeRA->ReferenceSet();
else if (ubTreeRA)
return ubTreeRA->ReferenceSet();
+ else if (octreeRA)
+ return octreeRA->ReferenceSet();
throw std::runtime_error("no rank-approximate nearest neighbor search model "
"initialized");
@@ -177,6 +167,8 @@ bool RAModel<SortPolicy>::Naive() const
return rPlusPlusTreeRA->Naive();
else if (ubTreeRA)
return ubTreeRA->Naive();
+ else if (octreeRA)
+ return octreeRA->Naive();
throw std::runtime_error("no rank-approximate nearest neighbor search model "
"initialized");
@@ -203,6 +195,8 @@ bool& RAModel<SortPolicy>::Naive()
return rPlusPlusTreeRA->Naive();
else if (ubTreeRA)
return ubTreeRA->Naive();
+ else if (octreeRA)
+ return octreeRA->Naive();
throw std::runtime_error("no rank-approximate nearest neighbor search model "
"initialized");
@@ -229,6 +223,8 @@ bool RAModel<SortPolicy>::SingleMode() const
return rPlusPlusTreeRA->SingleMode();
else if (ubTreeRA)
return ubTreeRA->SingleMode();
+ else if (octreeRA)
+ return octreeRA->SingleMode();
throw std::runtime_error("no rank-approximate nearest neighbor search model "
"initialized");
@@ -255,6 +251,8 @@ bool& RAModel<SortPolicy>::SingleMode()
return rPlusPlusTreeRA->SingleMode();
else if (ubTreeRA)
return ubTreeRA->SingleMode();
+ else if (octreeRA)
+ return octreeRA->SingleMode();
throw std::runtime_error("no rank-approximate nearest neighbor search model "
"initialized");
@@ -281,6 +279,8 @@ double RAModel<SortPolicy>::Tau() const
return rPlusPlusTreeRA->Tau();
else if (ubTreeRA)
return ubTreeRA->Tau();
+ else if (octreeRA)
+ return octreeRA->Tau();
throw std::runtime_error("no rank-approximate nearest neighbor search model "
"initialized");
@@ -307,6 +307,8 @@ double& RAModel<SortPolicy>::Tau()
return rPlusPlusTreeRA->Tau();
else if (ubTreeRA)
return ubTreeRA->Tau();
+ else if (octreeRA)
+ return octreeRA->Tau();
throw std::runtime_error("no rank-approximate nearest neighbor search model "
"initialized");
@@ -333,6 +335,8 @@ double RAModel<SortPolicy>::Alpha() const
return rPlusPlusTreeRA->Alpha();
else if (ubTreeRA)
return ubTreeRA->Alpha();
+ else if (octreeRA)
+ return octreeRA->Alpha();
throw std::runtime_error("no rank-approximate nearest neighbor search model "
"initialized");
@@ -359,6 +363,8 @@ double& RAModel<SortPolicy>::Alpha()
return rPlusPlusTreeRA->Alpha();
else if (ubTreeRA)
return ubTreeRA->Alpha();
+ else if (octreeRA)
+ return octreeRA->Alpha();
throw std::runtime_error("no rank-approximate nearest neighbor search model "
"initialized");
@@ -385,6 +391,8 @@ bool RAModel<SortPolicy>::SampleAtLeaves() const
return rPlusPlusTreeRA->SampleAtLeaves();
else if (ubTreeRA)
return ubTreeRA->SampleAtLeaves();
+ else if (octreeRA)
+ return octreeRA->SampleAtLeaves();
throw std::runtime_error("no rank-approximate nearest neighbor search model "
"initialized");
@@ -411,6 +419,8 @@ bool& RAModel<SortPolicy>::SampleAtLeaves()
return rPlusPlusTreeRA->SampleAtLeaves();
else if (ubTreeRA)
return ubTreeRA->SampleAtLeaves();
+ else if (octreeRA)
+ return octreeRA->SampleAtLeaves();
throw std::runtime_error("no rank-approximate nearest neighbor search model "
"initialized");
@@ -437,6 +447,8 @@ bool RAModel<SortPolicy>::FirstLeafExact() const
return rPlusPlusTreeRA->FirstLeafExact();
else if (ubTreeRA)
return ubTreeRA->FirstLeafExact();
+ else if (octreeRA)
+ return octreeRA->FirstLeafExact();
throw std::runtime_error("no rank-approximate nearest neighbor search model "
"initialized");
@@ -463,6 +475,8 @@ bool& RAModel<SortPolicy>::FirstLeafExact()
return rPlusPlusTreeRA->FirstLeafExact();
else if (ubTreeRA)
return ubTreeRA->FirstLeafExact();
+ else if (octreeRA)
+ return octreeRA->FirstLeafExact();
throw std::runtime_error("no rank-approximate nearest neighbor search model "
"initialized");
@@ -489,6 +503,8 @@ size_t RAModel<SortPolicy>::SingleSampleLimit() const
return rPlusPlusTreeRA->SingleSampleLimit();
else if (ubTreeRA)
return ubTreeRA->SingleSampleLimit();
+ else if (octreeRA)
+ return octreeRA->SingleSampleLimit();
throw std::runtime_error("no rank-approximate nearest neighbor search model "
"initialized");
@@ -515,6 +531,8 @@ size_t& RAModel<SortPolicy>::SingleSampleLimit()
return rPlusPlusTreeRA->SingleSampleLimit();
else if (ubTreeRA)
return ubTreeRA->SingleSampleLimit();
+ else if (octreeRA)
+ return octreeRA->SingleSampleLimit();
throw std::runtime_error("no rank-approximate nearest neighbor search model "
"initialized");
@@ -570,24 +588,28 @@ void RAModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
}
// Clean memory, if necessary.
- if (kdTreeRA)
- delete kdTreeRA;
- if (coverTreeRA)
- delete coverTreeRA;
- if (rTreeRA)
- delete rTreeRA;
- if (rStarTreeRA)
- delete rStarTreeRA;
- if (xTreeRA)
- delete xTreeRA;
- if (hilbertRTreeRA)
- delete hilbertRTreeRA;
- if (rPlusTreeRA)
- delete rPlusTreeRA;
- if (rPlusPlusTreeRA)
- delete rPlusPlusTreeRA;
- if (ubTreeRA)
- delete ubTreeRA;
+ delete kdTreeRA;
+ delete coverTreeRA;
+ delete rTreeRA;
+ delete rStarTreeRA;
+ delete xTreeRA;
+ delete hilbertRTreeRA;
+ delete rPlusTreeRA;
+ delete rPlusPlusTreeRA;
+ delete ubTreeRA;
+ delete octreeRA;
+ // Reset every pointer: delete does not clear them, and only the pointer for
+ // the selected tree type is assigned in the switch below.
+ kdTreeRA = NULL;
+ coverTreeRA = NULL;
+ rTreeRA = NULL;
+ rStarTreeRA = NULL;
+ xTreeRA = NULL;
+ hilbertRTreeRA = NULL;
+ rPlusTreeRA = NULL;
+ rPlusPlusTreeRA = NULL;
+ ubTreeRA = NULL;
+ octreeRA = NULL;
if (randomBasis)
referenceSet = q * referenceSet;
@@ -652,6 +662,26 @@ void RAModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
ubTreeRA = new RAType<tree::UBTree>(std::move(referenceSet),
naive, singleMode);
break;
+ case OCTREE:
+ // Build tree, if necessary.
+ if (naive)
+ {
+ octreeRA = new RAType<tree::Octree>(std::move(referenceSet), naive,
+ singleMode);
+ }
+ else
+ {
+ std::vector<size_t> oldFromNewReferences;
+ typename RAType<tree::Octree>::Tree* octree =
+ new typename RAType<tree::Octree>::Tree(std::move(referenceSet),
+ oldFromNewReferences, leafSize);
+ octreeRA = new RAType<tree::Octree>(octree, singleMode);
+
+ // Give the model ownership of the tree and the mappings.
+ octreeRA->treeOwner = true;
+ octreeRA->oldFromNewReferences = std::move(oldFromNewReferences);
+ }
+ break;
}
if (!naive)
if (!naive)
@@ -745,6 +775,37 @@ void RAModel<SortPolicy>::Search(arma::mat&& querySet,
// No mapping necessary.
ubTreeRA->Search(querySet, k, neighbors, distances);
break;
+ case OCTREE:
+ if (!octreeRA->Naive() && !octreeRA->SingleMode())
+ {
+ // Build a second tree and search.
+ Timer::Start("tree_building");
+ Log::Info << "Building query tree..." << std::endl;
+ std::vector<size_t> oldFromNewQueries;
+ typename RAType<tree::Octree>::Tree queryTree(std::move(querySet),
+ oldFromNewQueries, leafSize);
+ Log::Info << "Tree built." << std::endl;
+ Timer::Stop("tree_building");
+
+ arma::Mat<size_t> neighborsOut;
+ arma::mat distancesOut;
+ octreeRA->Search(&queryTree, k, neighborsOut, distancesOut);
+
+ // Unmap the query points.
+ distances.set_size(distancesOut.n_rows, distancesOut.n_cols);
+ neighbors.set_size(neighborsOut.n_rows, neighborsOut.n_cols);
+ for (size_t i = 0; i < neighborsOut.n_cols; ++i)
+ {
+ neighbors.col(oldFromNewQueries[i]) = neighborsOut.col(i);
+ distances.col(oldFromNewQueries[i]) = distancesOut.col(i);
+ }
+ }
+ else
+ {
+ // Search without building a second tree.
+ octreeRA->Search(querySet, k, neighbors, distances);
+ }
+ break;
}
}
@@ -791,6 +852,9 @@ void RAModel<SortPolicy>::Search(const size_t k,
case UB_TREE:
ubTreeRA->Search(k, neighbors, distances);
break;
+ case OCTREE:
+ octreeRA->Search(k, neighbors, distances);
+ break;
}
}
@@ -817,6 +881,8 @@ std::string RAModel<SortPolicy>::TreeName() const
return "R++ tree";
case UB_TREE:
return "UB tree";
+ case OCTREE:
+ return "octree";
default:
return "unknown tree";
}
diff --git a/src/mlpack/tests/krann_search_test.cpp b/src/mlpack/tests/krann_search_test.cpp
index 3d3f918..c0bcadb 100644
--- a/src/mlpack/tests/krann_search_test.cpp
+++ b/src/mlpack/tests/krann_search_test.cpp
@@ -625,7 +625,7 @@ BOOST_AUTO_TEST_CASE(RAModelTest)
data::Load("rann_test_q_3_100.csv", queryData, true);
// Build all the possible models.
- KNNModel models[18];
+ KNNModel models[20];
models[0] = KNNModel(KNNModel::TreeTypes::KD_TREE, false);
models[1] = KNNModel(KNNModel::TreeTypes::KD_TREE, true);
models[2] = KNNModel(KNNModel::TreeTypes::COVER_TREE, false);
@@ -644,13 +644,15 @@ BOOST_AUTO_TEST_CASE(RAModelTest)
models[15] = KNNModel(KNNModel::TreeTypes::R_PLUS_PLUS_TREE, true);
models[16] = KNNModel(KNNModel::TreeTypes::UB_TREE, false);
models[17] = KNNModel(KNNModel::TreeTypes::UB_TREE, true);
+ models[18] = KNNModel(KNNModel::TreeTypes::OCTREE, false);
+ models[19] = KNNModel(KNNModel::TreeTypes::OCTREE, true);
arma::Mat<size_t> qrRanks;
data::Load("rann_test_qr_ranks.csv", qrRanks, true, false); // No transpose.
for (size_t j = 0; j < 3; ++j)
{
- for (size_t i = 0; i < 18; ++i)
+ for (size_t i = 0; i < 20; ++i)
{
// We only have std::move() constructors so make a copy of our data.
arma::mat referenceCopy(referenceData);
diff --git a/src/mlpack/tests/range_search_test.cpp b/src/mlpack/tests/range_search_test.cpp
index 0968b07..3619691 100644
--- a/src/mlpack/tests/range_search_test.cpp
+++ b/src/mlpack/tests/range_search_test.cpp
@@ -1249,7 +1249,7 @@ BOOST_AUTO_TEST_CASE(RSModelTest)
arma::mat referenceData = arma::randu<arma::mat>(10, 200);
// Build all the possible models.
- RSModel models[26];
+ RSModel models[28];
models[0] = RSModel(RSModel::TreeTypes::KD_TREE, true);
models[1] = RSModel(RSModel::TreeTypes::KD_TREE, false);
models[2] = RSModel(RSModel::TreeTypes::COVER_TREE, true);
@@ -1276,6 +1276,8 @@ BOOST_AUTO_TEST_CASE(RSModelTest)
models[23] = RSModel(RSModel::TreeTypes::MAX_RP_TREE, false);
models[24] = RSModel(RSModel::TreeTypes::UB_TREE, true);
models[25] = RSModel(RSModel::TreeTypes::UB_TREE, false);
+ models[26] = RSModel(RSModel::TreeTypes::OCTREE, true);
+ models[27] = RSModel(RSModel::TreeTypes::OCTREE, false);
for (size_t j = 0; j < 2; ++j)
{
@@ -1289,7 +1291,7 @@ BOOST_AUTO_TEST_CASE(RSModelTest)
vector<vector<pair<double, size_t>>> baselineSorted;
SortResults(baselineNeighbors, baselineDistances, baselineSorted);
- for (size_t i = 0; i < 26; ++i)
+ for (size_t i = 0; i < 28; ++i)
{
// We only have std::move() constructors, so make a copy of our data.
arma::mat referenceCopy(referenceData);
@@ -1333,7 +1335,7 @@ BOOST_AUTO_TEST_CASE(RSModelMonochromaticTest)
arma::mat referenceData = arma::randu<arma::mat>(10, 200);
// Build all the possible models.
- RSModel models[26];
+ RSModel models[28];
models[0] = RSModel(RSModel::TreeTypes::KD_TREE, true);
models[1] = RSModel(RSModel::TreeTypes::KD_TREE, false);
models[2] = RSModel(RSModel::TreeTypes::COVER_TREE, true);
@@ -1360,6 +1362,8 @@ BOOST_AUTO_TEST_CASE(RSModelMonochromaticTest)
models[23] = RSModel(RSModel::TreeTypes::MAX_RP_TREE, false);
- models[24] = RSModel(RSModel::TreeTypes::MAX_RP_TREE, true);
- models[25] = RSModel(RSModel::TreeTypes::MAX_RP_TREE, false);
+ models[24] = RSModel(RSModel::TreeTypes::UB_TREE, true);
+ models[25] = RSModel(RSModel::TreeTypes::UB_TREE, false);
+ models[26] = RSModel(RSModel::TreeTypes::OCTREE, true);
+ models[27] = RSModel(RSModel::TreeTypes::OCTREE, false);
for (size_t j = 0; j < 2; ++j)
{
@@ -1372,7 +1376,7 @@ BOOST_AUTO_TEST_CASE(RSModelMonochromaticTest)
vector<vector<pair<double, size_t>>> baselineSorted;
SortResults(baselineNeighbors, baselineDistances, baselineSorted);
- for (size_t i = 0; i < 26; ++i)
+ for (size_t i = 0; i < 28; ++i)
{
// We only have std::move() cosntructors, so make a copy of our data.
arma::mat referenceCopy(referenceData);
More information about the mlpack-git
mailing list