[mlpack-git] master: Added the UB tree to RAModel. Fixed UB tree traits (added UniqueNumDescendants). (d3563e9)
gitdub at mlpack.org
gitdub at mlpack.org
Fri Aug 26 17:55:09 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/1797a49c8f76d65814fec4a122d0d2fea01fc2d9...9e5cd0ac9c5cde9ac141bc84e7327bd11e19d42e
>---------------------------------------------------------------
commit d3563e9b121054e4e5c1f04849b5edcb7416d1dd
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date: Sat Aug 27 00:55:09 2016 +0300
Added the UB tree to RAModel. Fixed UB tree traits (added UniqueNumDescendants).
>---------------------------------------------------------------
d3563e9b121054e4e5c1f04849b5edcb7416d1dd
src/mlpack/core/tree/binary_space_tree/traits.hpp | 1 +
src/mlpack/methods/rann/krann_main.cpp | 12 +++--
src/mlpack/methods/rann/ra_model.hpp | 5 +-
src/mlpack/methods/rann/ra_model_impl.hpp | 56 ++++++++++++++++++++++-
src/mlpack/tests/krann_search_test.cpp | 6 ++-
5 files changed, 71 insertions(+), 9 deletions(-)
diff --git a/src/mlpack/core/tree/binary_space_tree/traits.hpp b/src/mlpack/core/tree/binary_space_tree/traits.hpp
index 69bda7c..ac6dc18 100644
--- a/src/mlpack/core/tree/binary_space_tree/traits.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/traits.hpp
@@ -237,6 +237,7 @@ class TreeTraits<BinarySpaceTree<MetricType, StatisticType, MatType,
static const bool HasSelfChildren = false;
static const bool RearrangesDataset = true;
static const bool BinaryTree = true;
+ static const bool UniqueNumDescendants = true;
};
} // namespace tree
diff --git a/src/mlpack/methods/rann/krann_main.cpp b/src/mlpack/methods/rann/krann_main.cpp
index 9950cd2..591d741 100644
--- a/src/mlpack/methods/rann/krann_main.cpp
+++ b/src/mlpack/methods/rann/krann_main.cpp
@@ -64,11 +64,11 @@ PARAM_INT_IN("k", "Number of nearest neighbors to find.", "k", 0);
// The user may specify the type of tree to use, and a few parameters for tree
// building.
-PARAM_STRING_IN("tree_type", "Type of tree to use: 'kd', 'cover', 'r', or "
+PARAM_STRING_IN("tree_type", "Type of tree to use: 'kd', 'ub', 'cover', 'r', "
"'x', 'r-star', 'hilbert-r', 'r-plus', 'r-plus-plus'.", "t", "kd");
-PARAM_INT_IN("leaf_size", "Leaf size for tree building (used for kd-trees, R "
- "trees, R* trees, X trees, Hilbert R trees, R+ trees and R++ trees).", "l",
- 20);
+PARAM_INT_IN("leaf_size", "Leaf size for tree building (used for kd-trees, "
+ "UB trees, R trees, R* trees, X trees, Hilbert R trees, R+ trees and "
+ "R++ trees).", "l", 20);
PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
"random orthogonal basis.", "R");
PARAM_INT_IN("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0);
@@ -180,9 +180,11 @@ int main(int argc, char *argv[])
tree = RANNModel::R_PLUS_TREE;
else if (treeType == "r-plus-plus")
tree = RANNModel::R_PLUS_PLUS_TREE;
+ else if (treeType == "ub")
+ tree = RANNModel::UB_TREE;
else
Log::Fatal << "Unknown tree type '" << treeType << "'; valid choices are "
- << "'kd', 'cover', 'r', 'r-star', 'x', 'hilbert-r', "
+ << "'kd', 'ub', 'cover', 'r', 'r-star', 'x', 'hilbert-r', "
<< "'r-plus' and 'r-plus-plus'." << endl;
rann.TreeType() = tree;
diff --git a/src/mlpack/methods/rann/ra_model.hpp b/src/mlpack/methods/rann/ra_model.hpp
index 1e755d3..48be9e9 100644
--- a/src/mlpack/methods/rann/ra_model.hpp
+++ b/src/mlpack/methods/rann/ra_model.hpp
@@ -43,7 +43,8 @@ class RAModel
X_TREE,
HILBERT_R_TREE,
R_PLUS_TREE,
- R_PLUS_PLUS_TREE
+ R_PLUS_PLUS_TREE,
+ UB_TREE
};
private:
@@ -82,6 +83,8 @@ class RAModel
RAType<tree::RPlusTree>* rPlusTreeRA;
//! Non-NULL if the R++ tree is used.
RAType<tree::RPlusPlusTree>* rPlusPlusTreeRA;
+ //! Non-NULL if the UB tree is used.
+ RAType<tree::UBTree>* ubTreeRA;
public:
/**
diff --git a/src/mlpack/methods/rann/ra_model_impl.hpp b/src/mlpack/methods/rann/ra_model_impl.hpp
index f096540..30452ef 100644
--- a/src/mlpack/methods/rann/ra_model_impl.hpp
+++ b/src/mlpack/methods/rann/ra_model_impl.hpp
@@ -25,7 +25,8 @@ RAModel<SortPolicy>::RAModel(const TreeTypes treeType, const bool randomBasis) :
xTreeRA(NULL),
hilbertRTreeRA(NULL),
rPlusTreeRA(NULL),
- rPlusPlusTreeRA(NULL)
+ rPlusPlusTreeRA(NULL),
+ ubTreeRA(NULL)
{
// Nothing to do.
}
@@ -49,6 +50,8 @@ RAModel<SortPolicy>::~RAModel()
delete rPlusTreeRA;
if (rPlusPlusTreeRA)
delete rPlusPlusTreeRA;
+ if (ubTreeRA)
+ delete ubTreeRA;
}
template<typename SortPolicy>
@@ -79,6 +82,8 @@ void RAModel<SortPolicy>::Serialize(Archive& ar,
delete rPlusTreeRA;
if (rPlusPlusTreeRA)
delete rPlusPlusTreeRA;
+ if (ubTreeRA)
+ delete ubTreeRA;
// Set all the pointers to NULL.
kdTreeRA = NULL;
@@ -89,6 +94,7 @@ void RAModel<SortPolicy>::Serialize(Archive& ar,
hilbertRTreeRA = NULL;
rPlusPlusTreeRA = NULL;
rPlusTreeRA = NULL;
+ ubTreeRA = NULL;
}
// We only need to serialize one of the kRANN objects.
@@ -118,6 +124,9 @@ void RAModel<SortPolicy>::Serialize(Archive& ar,
case R_PLUS_PLUS_TREE:
ar & data::CreateNVP(rPlusPlusTreeRA, "ra_model");
break;
+ case UB_TREE:
+ ar & data::CreateNVP(ubTreeRA, "ra_model");
+ break;
}
}
@@ -140,6 +149,8 @@ const arma::mat& RAModel<SortPolicy>::Dataset() const
return rPlusTreeRA->ReferenceSet();
else if (rPlusPlusTreeRA)
return rPlusPlusTreeRA->ReferenceSet();
+ else if (ubTreeRA)
+ return ubTreeRA->ReferenceSet();
throw std::runtime_error("no rank-approximate nearest neighbor search model "
"initialized");
@@ -164,6 +175,8 @@ bool RAModel<SortPolicy>::Naive() const
return rPlusTreeRA->Naive();
else if (rPlusPlusTreeRA)
return rPlusPlusTreeRA->Naive();
+ else if (ubTreeRA)
+ return ubTreeRA->Naive();
throw std::runtime_error("no rank-approximate nearest neighbor search model "
"initialized");
@@ -188,6 +201,8 @@ bool& RAModel<SortPolicy>::Naive()
return rPlusTreeRA->Naive();
else if (rPlusPlusTreeRA)
return rPlusPlusTreeRA->Naive();
+ else if (ubTreeRA)
+ return ubTreeRA->Naive();
throw std::runtime_error("no rank-approximate nearest neighbor search model "
"initialized");
@@ -212,6 +227,8 @@ bool RAModel<SortPolicy>::SingleMode() const
return rPlusTreeRA->SingleMode();
else if (rPlusPlusTreeRA)
return rPlusPlusTreeRA->SingleMode();
+ else if (ubTreeRA)
+ return ubTreeRA->SingleMode();
throw std::runtime_error("no rank-approximate nearest neighbor search model "
"initialized");
@@ -236,6 +253,8 @@ bool& RAModel<SortPolicy>::SingleMode()
return rPlusTreeRA->SingleMode();
else if (rPlusPlusTreeRA)
return rPlusPlusTreeRA->SingleMode();
+ else if (ubTreeRA)
+ return ubTreeRA->SingleMode();
throw std::runtime_error("no rank-approximate nearest neighbor search model "
"initialized");
@@ -260,6 +279,8 @@ double RAModel<SortPolicy>::Tau() const
return rPlusTreeRA->Tau();
else if (rPlusPlusTreeRA)
return rPlusPlusTreeRA->Tau();
+ else if (ubTreeRA)
+ return ubTreeRA->Tau();
throw std::runtime_error("no rank-approximate nearest neighbor search model "
"initialized");
@@ -284,6 +305,8 @@ double& RAModel<SortPolicy>::Tau()
return rPlusTreeRA->Tau();
else if (rPlusPlusTreeRA)
return rPlusPlusTreeRA->Tau();
+ else if (ubTreeRA)
+ return ubTreeRA->Tau();
throw std::runtime_error("no rank-approximate nearest neighbor search model "
"initialized");
@@ -308,6 +331,8 @@ double RAModel<SortPolicy>::Alpha() const
return rPlusTreeRA->Alpha();
else if (rPlusPlusTreeRA)
return rPlusPlusTreeRA->Alpha();
+ else if (ubTreeRA)
+ return ubTreeRA->Alpha();
throw std::runtime_error("no rank-approximate nearest neighbor search model "
"initialized");
@@ -332,6 +357,8 @@ double& RAModel<SortPolicy>::Alpha()
return rPlusTreeRA->Alpha();
else if (rPlusPlusTreeRA)
return rPlusPlusTreeRA->Alpha();
+ else if (ubTreeRA)
+ return ubTreeRA->Alpha();
throw std::runtime_error("no rank-approximate nearest neighbor search model "
"initialized");
@@ -356,6 +383,8 @@ bool RAModel<SortPolicy>::SampleAtLeaves() const
return rPlusTreeRA->SampleAtLeaves();
else if (rPlusPlusTreeRA)
return rPlusPlusTreeRA->SampleAtLeaves();
+ else if (ubTreeRA)
+ return ubTreeRA->SampleAtLeaves();
throw std::runtime_error("no rank-approximate nearest neighbor search model "
"initialized");
@@ -380,6 +409,8 @@ bool& RAModel<SortPolicy>::SampleAtLeaves()
return rPlusTreeRA->SampleAtLeaves();
else if (rPlusPlusTreeRA)
return rPlusPlusTreeRA->SampleAtLeaves();
+ else if (ubTreeRA)
+ return ubTreeRA->SampleAtLeaves();
throw std::runtime_error("no rank-approximate nearest neighbor search model "
"initialized");
@@ -404,6 +435,8 @@ bool RAModel<SortPolicy>::FirstLeafExact() const
return rPlusTreeRA->FirstLeafExact();
else if (rPlusPlusTreeRA)
return rPlusPlusTreeRA->FirstLeafExact();
+ else if (ubTreeRA)
+ return ubTreeRA->FirstLeafExact();
throw std::runtime_error("no rank-approximate nearest neighbor search model "
"initialized");
@@ -428,6 +461,8 @@ bool& RAModel<SortPolicy>::FirstLeafExact()
return rPlusTreeRA->FirstLeafExact();
else if (rPlusPlusTreeRA)
return rPlusPlusTreeRA->FirstLeafExact();
+ else if (ubTreeRA)
+ return ubTreeRA->FirstLeafExact();
throw std::runtime_error("no rank-approximate nearest neighbor search model "
"initialized");
@@ -452,6 +487,8 @@ size_t RAModel<SortPolicy>::SingleSampleLimit() const
return rPlusTreeRA->SingleSampleLimit();
else if (rPlusPlusTreeRA)
return rPlusPlusTreeRA->SingleSampleLimit();
+ else if (ubTreeRA)
+ return ubTreeRA->SingleSampleLimit();
throw std::runtime_error("no rank-approximate nearest neighbor search model "
"initialized");
@@ -476,6 +513,8 @@ size_t& RAModel<SortPolicy>::SingleSampleLimit()
return rPlusTreeRA->SingleSampleLimit();
else if (rPlusPlusTreeRA)
return rPlusPlusTreeRA->SingleSampleLimit();
+ else if (ubTreeRA)
+ return ubTreeRA->SingleSampleLimit();
throw std::runtime_error("no rank-approximate nearest neighbor search model "
"initialized");
@@ -547,6 +586,8 @@ void RAModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
delete rPlusTreeRA;
if (rPlusPlusTreeRA)
delete rPlusPlusTreeRA;
+ if (ubTreeRA)
+ delete ubTreeRA;
if (randomBasis)
referenceSet = q * referenceSet;
@@ -607,6 +648,10 @@ void RAModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
rPlusPlusTreeRA = new RAType<tree::RPlusPlusTree>(std::move(referenceSet),
naive, singleMode);
break;
+ case UB_TREE:
+ ubTreeRA = new RAType<tree::UBTree>(std::move(referenceSet),
+ naive, singleMode);
+ break;
}
if (!naive)
@@ -696,6 +741,10 @@ void RAModel<SortPolicy>::Search(arma::mat&& querySet,
// No mapping necessary.
rPlusPlusTreeRA->Search(querySet, k, neighbors, distances);
break;
+ case UB_TREE:
+ // No mapping necessary.
+ ubTreeRA->Search(querySet, k, neighbors, distances);
+ break;
}
}
@@ -739,6 +788,9 @@ void RAModel<SortPolicy>::Search(const size_t k,
case R_PLUS_PLUS_TREE:
rPlusPlusTreeRA->Search(k, neighbors, distances);
break;
+ case UB_TREE:
+ ubTreeRA->Search(k, neighbors, distances);
+ break;
}
}
@@ -763,6 +815,8 @@ std::string RAModel<SortPolicy>::TreeName() const
return "R+ tree";
case R_PLUS_PLUS_TREE:
return "R++ tree";
+ case UB_TREE:
+ return "UB tree";
default:
return "unknown tree";
}
diff --git a/src/mlpack/tests/krann_search_test.cpp b/src/mlpack/tests/krann_search_test.cpp
index 34f87d7..3d3f918 100644
--- a/src/mlpack/tests/krann_search_test.cpp
+++ b/src/mlpack/tests/krann_search_test.cpp
@@ -625,7 +625,7 @@ BOOST_AUTO_TEST_CASE(RAModelTest)
data::Load("rann_test_q_3_100.csv", queryData, true);
// Build all the possible models.
- KNNModel models[16];
+ KNNModel models[18];
models[0] = KNNModel(KNNModel::TreeTypes::KD_TREE, false);
models[1] = KNNModel(KNNModel::TreeTypes::KD_TREE, true);
models[2] = KNNModel(KNNModel::TreeTypes::COVER_TREE, false);
@@ -642,13 +642,15 @@ BOOST_AUTO_TEST_CASE(RAModelTest)
models[13] = KNNModel(KNNModel::TreeTypes::R_PLUS_TREE, true);
models[14] = KNNModel(KNNModel::TreeTypes::R_PLUS_PLUS_TREE, false);
models[15] = KNNModel(KNNModel::TreeTypes::R_PLUS_PLUS_TREE, true);
+ models[16] = KNNModel(KNNModel::TreeTypes::UB_TREE, false);
+ models[17] = KNNModel(KNNModel::TreeTypes::UB_TREE, true);
arma::Mat<size_t> qrRanks;
data::Load("rann_test_qr_ranks.csv", qrRanks, true, false); // No transpose.
for (size_t j = 0; j < 3; ++j)
{
- for (size_t i = 0; i < 16; ++i)
+ for (size_t i = 0; i < 18; ++i)
{
// We only have std::move() constructors so make a copy of our data.
arma::mat referenceCopy(referenceData);
More information about the mlpack-git
mailing list