[mlpack-git] master: Added the UB tree to RSModel. Updated range search tests. Replaced 'maxPRTreeRS' with 'maxRPTreeRS'. (43a065b)
gitdub at mlpack.org
gitdub at mlpack.org
Fri Aug 26 16:34:49 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/1797a49c8f76d65814fec4a122d0d2fea01fc2d9...9e5cd0ac9c5cde9ac141bc84e7327bd11e19d42e
>---------------------------------------------------------------
commit 43a065b288999f3b044dd02a2cb6f942376c035c
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date: Fri Aug 26 23:34:49 2016 +0300
Added the UB tree to RSModel. Updated range search tests. Replaced 'maxPRTreeRS' with 'maxRPTreeRS'.
>---------------------------------------------------------------
43a065b288999f3b044dd02a2cb6f942376c035c
.../methods/range_search/range_search_main.cpp | 10 +++---
src/mlpack/methods/range_search/rs_model.cpp | 33 +++++++++++++++-----
src/mlpack/methods/range_search/rs_model.hpp | 8 +++--
src/mlpack/methods/range_search/rs_model_impl.hpp | 36 +++++++++++++++-------
src/mlpack/tests/range_search_test.cpp | 12 +++++---
5 files changed, 71 insertions(+), 28 deletions(-)
diff --git a/src/mlpack/methods/range_search/range_search_main.cpp b/src/mlpack/methods/range_search/range_search_main.cpp
index 2505024..9487832 100644
--- a/src/mlpack/methods/range_search/range_search_main.cpp
+++ b/src/mlpack/methods/range_search/range_search_main.cpp
@@ -71,10 +71,10 @@ PARAM_DOUBLE_IN("min", "Lower bound in range.", "L", 0.0);
// The user may specify the type of tree to use, and a few parameters for tree
// building.
PARAM_STRING_IN("tree_type", "Type of tree to use: 'kd', 'vp', 'rp', 'max-rp', "
- "'cover', 'r', 'r-star', 'x', 'ball', 'hilbert-r', 'r-plus', "
+ "'ub', 'cover', 'r', 'r-star', 'x', 'ball', 'hilbert-r', 'r-plus', "
"'r-plus-plus'.", "t", "kd");
PARAM_INT_IN("leaf_size", "Leaf size for tree building (used for kd-trees, "
- "vp trees, random projection trees, R trees, R* trees, X trees, "
+ "vp trees, random projection trees, UB trees, R trees, R* trees, X trees, "
"Hilbert R trees, R+ trees and R++ trees).", "l", 20);
PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
"random orthogonal basis.", "R");
@@ -189,10 +189,12 @@ int main(int argc, char *argv[])
tree = RSModel::RP_TREE;
else if (treeType == "max-rp")
tree = RSModel::MAX_RP_TREE;
+ else if (treeType == "ub")
+ tree = RSModel::UB_TREE;
else
Log::Fatal << "Unknown tree type '" << treeType << "; valid choices are "
- << "'kd', 'vp', 'rp', 'max-rp', 'cover', 'r', 'r-star', 'x', 'ball', "
- << "'hilbert-r', 'r-plus' and 'r-plus-plus'." << endl;
+ << "'kd', 'vp', 'rp', 'max-rp', 'ub', 'cover', 'r', 'r-star', 'x', "
+ << "'ball', 'hilbert-r', 'r-plus' and 'r-plus-plus'." << endl;
rs.TreeType() = tree;
rs.RandomBasis() = randomBasis;
diff --git a/src/mlpack/methods/range_search/rs_model.cpp b/src/mlpack/methods/range_search/rs_model.cpp
index c926e4b..6025471 100644
--- a/src/mlpack/methods/range_search/rs_model.cpp
+++ b/src/mlpack/methods/range_search/rs_model.cpp
@@ -28,7 +28,8 @@ RSModel::RSModel(TreeTypes treeType, bool randomBasis) :
rPlusPlusTreeRS(NULL),
vpTreeRS(NULL),
rpTreeRS(NULL),
- maxPRTreeRS(NULL)
+ maxRPTreeRS(NULL),
+ ubTreeRS(NULL)
{
// Nothing to do.
}
@@ -155,7 +156,12 @@ void RSModel::BuildModel(arma::mat&& referenceSet,
break;
case MAX_RP_TREE:
- maxPRTreeRS = new RSType<tree::MaxRPTree>(move(referenceSet),
+ maxRPTreeRS = new RSType<tree::MaxRPTree>(move(referenceSet),
+ naive, singleMode);
+ break;
+
+ case UB_TREE:
+ ubTreeRS = new RSType<tree::UBTree>(move(referenceSet),
naive, singleMode);
break;
}
@@ -289,7 +295,11 @@ void RSModel::Search(arma::mat&& querySet,
break;
case MAX_RP_TREE:
- maxPRTreeRS->Search(querySet, range, neighbors, distances);
+ maxRPTreeRS->Search(querySet, range, neighbors, distances);
+ break;
+
+ case UB_TREE:
+ ubTreeRS->Search(querySet, range, neighbors, distances);
break;
}
}
@@ -355,7 +365,11 @@ void RSModel::Search(const math::Range& range,
break;
case MAX_RP_TREE:
- maxPRTreeRS->Search(range, neighbors, distances);
+ maxRPTreeRS->Search(range, neighbors, distances);
+ break;
+
+ case UB_TREE:
+ ubTreeRS->Search(range, neighbors, distances);
break;
}
}
@@ -389,6 +403,8 @@ std::string RSModel::TreeName() const
return "random projection tree (mean split)";
case MAX_RP_TREE:
return "random projection tree (max split)";
+ case UB_TREE:
+ return "UB tree";
default:
return "unknown tree";
}
@@ -419,8 +435,10 @@ void RSModel::CleanMemory()
delete vpTreeRS;
if (rpTreeRS)
delete rpTreeRS;
- if (maxPRTreeRS)
- delete maxPRTreeRS;
+ if (maxRPTreeRS)
+ delete maxRPTreeRS;
+ if (ubTreeRS)
+ delete ubTreeRS;
kdTreeRS = NULL;
coverTreeRS = NULL;
@@ -433,5 +451,6 @@ void RSModel::CleanMemory()
rPlusPlusTreeRS = NULL;
vpTreeRS = NULL;
rpTreeRS = NULL;
- maxPRTreeRS = NULL;
+ maxRPTreeRS = NULL;
+ ubTreeRS = NULL;
}
diff --git a/src/mlpack/methods/range_search/rs_model.hpp b/src/mlpack/methods/range_search/rs_model.hpp
index a073968..659319d 100644
--- a/src/mlpack/methods/range_search/rs_model.hpp
+++ b/src/mlpack/methods/range_search/rs_model.hpp
@@ -35,7 +35,8 @@ class RSModel
R_PLUS_PLUS_TREE,
VP_TREE,
RP_TREE,
- MAX_RP_TREE
+ MAX_RP_TREE,
+ UB_TREE
};
private:
@@ -79,7 +80,10 @@ class RSModel
RSType<tree::RPTree>* rpTreeRS;
//! Random projection tree (max) based range search object
//! (NULL if not in use).
- RSType<tree::MaxRPTree>* maxPRTreeRS;
+ RSType<tree::MaxRPTree>* maxRPTreeRS;
+ //! Universal B tree based range search object
+ //! (NULL if not in use).
+ RSType<tree::UBTree>* ubTreeRS;
public:
/**
diff --git a/src/mlpack/methods/range_search/rs_model_impl.hpp b/src/mlpack/methods/range_search/rs_model_impl.hpp
index 183f599..2a8502f 100644
--- a/src/mlpack/methods/range_search/rs_model_impl.hpp
+++ b/src/mlpack/methods/range_search/rs_model_impl.hpp
@@ -75,7 +75,11 @@ void RSModel::Serialize(Archive& ar, const unsigned int /* version */)
break;
case MAX_RP_TREE:
- ar & CreateNVP(maxPRTreeRS, "range_search_model");
+ ar & CreateNVP(maxRPTreeRS, "range_search_model");
+ break;
+
+ case UB_TREE:
+ ar & CreateNVP(ubTreeRS, "range_search_model");
break;
}
}
@@ -104,8 +108,10 @@ inline const arma::mat& RSModel::Dataset() const
return vpTreeRS->ReferenceSet();
else if (rpTreeRS)
return rpTreeRS->ReferenceSet();
- else if (maxPRTreeRS)
- return maxPRTreeRS->ReferenceSet();
+ else if (maxRPTreeRS)
+ return maxRPTreeRS->ReferenceSet();
+ else if (ubTreeRS)
+ return ubTreeRS->ReferenceSet();
throw std::runtime_error("no range search model initialized");
}
@@ -134,8 +140,10 @@ inline bool RSModel::SingleMode() const
return vpTreeRS->SingleMode();
else if (rpTreeRS)
return rpTreeRS->SingleMode();
- else if (maxPRTreeRS)
- return maxPRTreeRS->SingleMode();
+ else if (maxRPTreeRS)
+ return maxRPTreeRS->SingleMode();
+ else if (ubTreeRS)
+ return ubTreeRS->SingleMode();
throw std::runtime_error("no range search model initialized");
}
@@ -164,8 +172,10 @@ inline bool& RSModel::SingleMode()
return vpTreeRS->SingleMode();
else if (rpTreeRS)
return rpTreeRS->SingleMode();
- else if (maxPRTreeRS)
- return maxPRTreeRS->SingleMode();
+ else if (maxRPTreeRS)
+ return maxRPTreeRS->SingleMode();
+ else if (ubTreeRS)
+ return ubTreeRS->SingleMode();
throw std::runtime_error("no range search model initialized");
}
@@ -194,8 +204,10 @@ inline bool RSModel::Naive() const
return vpTreeRS->Naive();
else if (rpTreeRS)
return rpTreeRS->Naive();
- else if (maxPRTreeRS)
- return maxPRTreeRS->Naive();
+ else if (maxRPTreeRS)
+ return maxRPTreeRS->Naive();
+ else if (ubTreeRS)
+ return ubTreeRS->Naive();
throw std::runtime_error("no range search model initialized");
}
@@ -224,8 +236,10 @@ inline bool& RSModel::Naive()
return vpTreeRS->Naive();
else if (rpTreeRS)
return rpTreeRS->Naive();
- else if (maxPRTreeRS)
- return maxPRTreeRS->Naive();
+ else if (maxRPTreeRS)
+ return maxRPTreeRS->Naive();
+ else if (ubTreeRS)
+ return ubTreeRS->Naive();
throw std::runtime_error("no range search model initialized");
}
diff --git a/src/mlpack/tests/range_search_test.cpp b/src/mlpack/tests/range_search_test.cpp
index fcfa9eb..0968b07 100644
--- a/src/mlpack/tests/range_search_test.cpp
+++ b/src/mlpack/tests/range_search_test.cpp
@@ -1249,7 +1249,7 @@ BOOST_AUTO_TEST_CASE(RSModelTest)
arma::mat referenceData = arma::randu<arma::mat>(10, 200);
// Build all the possible models.
- RSModel models[24];
+ RSModel models[26];
models[0] = RSModel(RSModel::TreeTypes::KD_TREE, true);
models[1] = RSModel(RSModel::TreeTypes::KD_TREE, false);
models[2] = RSModel(RSModel::TreeTypes::COVER_TREE, true);
@@ -1274,6 +1274,8 @@ BOOST_AUTO_TEST_CASE(RSModelTest)
models[21] = RSModel(RSModel::TreeTypes::RP_TREE, false);
models[22] = RSModel(RSModel::TreeTypes::MAX_RP_TREE, true);
models[23] = RSModel(RSModel::TreeTypes::MAX_RP_TREE, false);
+ models[24] = RSModel(RSModel::TreeTypes::UB_TREE, true);
+ models[25] = RSModel(RSModel::TreeTypes::UB_TREE, false);
for (size_t j = 0; j < 2; ++j)
{
@@ -1287,7 +1289,7 @@ BOOST_AUTO_TEST_CASE(RSModelTest)
vector<vector<pair<double, size_t>>> baselineSorted;
SortResults(baselineNeighbors, baselineDistances, baselineSorted);
- for (size_t i = 0; i < 24; ++i)
+ for (size_t i = 0; i < 26; ++i)
{
// We only have std::move() constructors, so make a copy of our data.
arma::mat referenceCopy(referenceData);
@@ -1331,7 +1333,7 @@ BOOST_AUTO_TEST_CASE(RSModelMonochromaticTest)
arma::mat referenceData = arma::randu<arma::mat>(10, 200);
// Build all the possible models.
- RSModel models[24];
+ RSModel models[26];
models[0] = RSModel(RSModel::TreeTypes::KD_TREE, true);
models[1] = RSModel(RSModel::TreeTypes::KD_TREE, false);
models[2] = RSModel(RSModel::TreeTypes::COVER_TREE, true);
@@ -1356,6 +1358,8 @@ BOOST_AUTO_TEST_CASE(RSModelMonochromaticTest)
models[21] = RSModel(RSModel::TreeTypes::RP_TREE, false);
models[22] = RSModel(RSModel::TreeTypes::MAX_RP_TREE, true);
models[23] = RSModel(RSModel::TreeTypes::MAX_RP_TREE, false);
+ models[24] = RSModel(RSModel::TreeTypes::MAX_RP_TREE, true);
+ models[25] = RSModel(RSModel::TreeTypes::MAX_RP_TREE, false);
for (size_t j = 0; j < 2; ++j)
{
@@ -1368,7 +1372,7 @@ BOOST_AUTO_TEST_CASE(RSModelMonochromaticTest)
vector<vector<pair<double, size_t>>> baselineSorted;
SortResults(baselineNeighbors, baselineDistances, baselineSorted);
- for (size_t i = 0; i < 24; ++i)
+ for (size_t i = 0; i < 26; ++i)
{
// We only have std::move() cosntructors, so make a copy of our data.
arma::mat referenceCopy(referenceData);
More information about the mlpack-git
mailing list