[mlpack-git] master: Add ball tree support to NSModel. (abce060)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Tue Oct 20 09:47:56 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/67e0a132c7f62820c734eb508fe1bc83128a3e13...00eccfdb0d315de3d94bfa1da84cc1dc65c8af39
>---------------------------------------------------------------
commit abce06073c625f8fe48d0caad26e4ca658c2a39d
Author: Ryan Curtin <ryan at ratml.org>
Date: Mon Oct 19 20:57:18 2015 +0000
Add ball tree support to NSModel.
>---------------------------------------------------------------
abce06073c625f8fe48d0caad26e4ca658c2a39d
src/mlpack/methods/neighbor_search/ns_model.hpp | 4 +-
.../methods/neighbor_search/ns_model_impl.hpp | 80 +++++++++++++++++++++-
src/mlpack/tests/allknn_test.cpp | 12 ++--
3 files changed, 90 insertions(+), 6 deletions(-)
diff --git a/src/mlpack/methods/neighbor_search/ns_model.hpp b/src/mlpack/methods/neighbor_search/ns_model.hpp
index 6412922..6242377 100644
--- a/src/mlpack/methods/neighbor_search/ns_model.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model.hpp
@@ -46,7 +46,8 @@ class NSModel
KD_TREE,
COVER_TREE,
R_TREE,
- R_STAR_TREE
+ R_STAR_TREE,
+ BALL_TREE
};
private:
@@ -70,6 +71,7 @@ class NSModel
NSType<tree::StandardCoverTree>* coverTreeNS;
NSType<tree::RTree>* rTreeNS;
NSType<tree::RStarTree>* rStarTreeNS;
+ NSType<tree::BallTree>* ballTreeNS;
public:
/**
diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
index 1f9ba6f..5eb5511 100644
--- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
@@ -27,7 +27,8 @@ NSModel<SortPolicy>::NSModel(int treeType, bool randomBasis) :
kdTreeNS(NULL),
coverTreeNS(NULL),
rTreeNS(NULL),
- rStarTreeNS(NULL)
+ rStarTreeNS(NULL),
+ ballTreeNS(NULL)
{
// Nothing to do.
}
@@ -44,6 +45,8 @@ NSModel<SortPolicy>::~NSModel()
delete rTreeNS;
if (rStarTreeNS)
delete rStarTreeNS;
+ if (ballTreeNS)
+ delete ballTreeNS;
}
//! Serialize the kNN model.
@@ -67,6 +70,8 @@ void NSModel<SortPolicy>::Serialize(Archive& ar,
delete rTreeNS;
if (rStarTreeNS)
delete rStarTreeNS;
+ if (ballTreeNS)
+ delete ballTreeNS;
// Set all the pointers to NULL.
kdTreeNS = NULL;
@@ -91,6 +96,9 @@ void NSModel<SortPolicy>::Serialize(Archive& ar,
case R_STAR_TREE:
ar & data::CreateNVP(rStarTreeNS, name);
break;
+ case BALL_TREE:
+ ar & data::CreateNVP(ballTreeNS, name);
+ break;
}
}
@@ -105,6 +113,8 @@ const arma::mat& NSModel<SortPolicy>::Dataset() const
return rTreeNS->ReferenceSet();
else if (rStarTreeNS)
return rStarTreeNS->ReferenceSet();
+ else if (ballTreeNS)
+ return ballTreeNS->ReferenceSet();
throw std::runtime_error("no neighbor search model initialized");
}
@@ -121,6 +131,8 @@ bool NSModel<SortPolicy>::SingleMode() const
return rTreeNS->SingleMode();
else if (rStarTreeNS)
return rStarTreeNS->SingleMode();
+ else if (ballTreeNS)
+ return ballTreeNS->SingleMode();
throw std::runtime_error("no neighbor search model initialized");
}
@@ -136,6 +148,8 @@ bool& NSModel<SortPolicy>::SingleMode()
return rTreeNS->SingleMode();
else if (rStarTreeNS)
return rStarTreeNS->SingleMode();
+ else if (ballTreeNS)
+ return ballTreeNS->SingleMode();
throw std::runtime_error("no neighbor search model initialized");
}
@@ -151,6 +165,8 @@ bool NSModel<SortPolicy>::Naive() const
return rTreeNS->Naive();
else if (rStarTreeNS)
return rStarTreeNS->Naive();
+ else if (ballTreeNS)
+ return ballTreeNS->Naive();
throw std::runtime_error("no neighbor search model initialized");
}
@@ -166,6 +182,8 @@ bool& NSModel<SortPolicy>::Naive()
return rTreeNS->Naive();
else if (rStarTreeNS)
return rStarTreeNS->Naive();
+ else if (ballTreeNS)
+ return ballTreeNS->Naive();
throw std::runtime_error("no neighbor search model initialized");
}
@@ -218,6 +236,8 @@ void NSModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
delete rTreeNS;
if (rStarTreeNS)
delete rStarTreeNS;
+ if (ballTreeNS)
+ delete ballTreeNS;
// Do we need to modify the reference set?
if (randomBasis)
@@ -267,6 +287,27 @@ void NSModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
rStarTreeNS = new NSType<tree::RStarTree>(std::move(referenceSet), naive,
singleMode);
break;
+ case BALL_TREE:
+ // If necessary, build the ball tree.
+ if (naive)
+ {
+ ballTreeNS = new NSType<tree::BallTree>(std::move(referenceSet), naive,
+ singleMode);
+ }
+ else
+ {
+ std::vector<size_t> oldFromNewReferences;
+ typename NSType<tree::BallTree>::Tree* ballTree =
+ new typename NSType<tree::BallTree>::Tree(std::move(referenceSet),
+ oldFromNewReferences, leafSize);
+ ballTreeNS = new NSType<tree::BallTree>(ballTree, singleMode);
+
+ // Give the model ownership of the tree and the mappings.
+ ballTreeNS->treeOwner = true;
+ ballTreeNS->oldFromNewReferences = std::move(oldFromNewReferences);
+ }
+
+ break;
}
if (!naive)
@@ -340,6 +381,38 @@ void NSModel<SortPolicy>::Search(arma::mat&& querySet,
// No mapping necessary.
rStarTreeNS->Search(querySet, k, neighbors, distances);
break;
+ case BALL_TREE:
+ if (!ballTreeNS->Naive() && !ballTreeNS->SingleMode())
+ {
+ // Build a second tree and search.
+ Timer::Start("tree_building");
+ Log::Info << "Building query tree..." << std::endl;
+ std::vector<size_t> oldFromNewQueries;
+ typename NSType<tree::BallTree>::Tree queryTree(std::move(querySet),
+ oldFromNewQueries, leafSize);
+ Log::Info << "Tree built." << std::endl;
+ Timer::Stop("tree_building");
+
+ arma::Mat<size_t> neighborsOut;
+ arma::mat distancesOut;
+ ballTreeNS->Search(&queryTree, k, neighborsOut, distancesOut);
+
+ // Unmap the query points.
+ distances.set_size(distancesOut.n_rows, distancesOut.n_cols);
+ neighbors.set_size(neighborsOut.n_rows, neighborsOut.n_cols);
+ for (size_t i = 0; i < neighborsOut.n_cols; ++i)
+ {
+ neighbors.col(oldFromNewQueries[i]) = neighborsOut.col(i);
+ distances.col(oldFromNewQueries[i]) = distancesOut.col(i);
+ }
+ }
+ else
+ {
+ // Search without building a second tree.
+ ballTreeNS->Search(querySet, k, neighbors, distances);
+ }
+
+ break;
}
}
@@ -371,6 +444,9 @@ void NSModel<SortPolicy>::Search(const size_t k,
case R_STAR_TREE:
rStarTreeNS->Search(k, neighbors, distances);
break;
+ case BALL_TREE:
+ ballTreeNS->Search(k, neighbors, distances);
+ break;
}
}
@@ -388,6 +464,8 @@ void NSModel<SortPolicy>::TreeName() const
return "R tree";
case R_STAR_TREE:
return "R* tree";
+ case BALL_TREE:
+ return "ball tree";
default:
return "unknown tree";
}
diff --git a/src/mlpack/tests/allknn_test.cpp b/src/mlpack/tests/allknn_test.cpp
index 150a765..82b8564 100644
--- a/src/mlpack/tests/allknn_test.cpp
+++ b/src/mlpack/tests/allknn_test.cpp
@@ -975,7 +975,7 @@ BOOST_AUTO_TEST_CASE(KNNModelTest)
arma::mat referenceData = arma::randu<arma::mat>(10, 200);
// Build all the possible models.
- KNNModel models[8];
+ KNNModel models[10];
models[0] = KNNModel(KNNModel::TreeTypes::KD_TREE, true);
models[1] = KNNModel(KNNModel::TreeTypes::KD_TREE, false);
models[2] = KNNModel(KNNModel::TreeTypes::COVER_TREE, true);
@@ -984,6 +984,8 @@ BOOST_AUTO_TEST_CASE(KNNModelTest)
models[5] = KNNModel(KNNModel::TreeTypes::R_TREE, false);
models[6] = KNNModel(KNNModel::TreeTypes::R_STAR_TREE, true);
models[7] = KNNModel(KNNModel::TreeTypes::R_STAR_TREE, false);
+ models[8] = KNNModel(KNNModel::TreeTypes::BALL_TREE, true);
+ models[9] = KNNModel(KNNModel::TreeTypes::BALL_TREE, false);
for (size_t j = 0; j < 2; ++j)
{
@@ -993,7 +995,7 @@ BOOST_AUTO_TEST_CASE(KNNModelTest)
arma::mat baselineDistances;
knn.Search(queryData, 3, baselineNeighbors, baselineDistances);
- for (size_t i = 0; i < 8; ++i)
+ for (size_t i = 0; i < 10; ++i)
{
// We only have std::move() constructors so make a copy of our data.
arma::mat referenceCopy(referenceData);
@@ -1037,7 +1039,7 @@ BOOST_AUTO_TEST_CASE(KNNModelMonochromaticTest)
arma::mat referenceData = arma::randu<arma::mat>(10, 200);
// Build all the possible models.
- KNNModel models[8];
+ KNNModel models[10];
models[0] = KNNModel(KNNModel::TreeTypes::KD_TREE, true);
models[1] = KNNModel(KNNModel::TreeTypes::KD_TREE, false);
models[2] = KNNModel(KNNModel::TreeTypes::COVER_TREE, true);
@@ -1046,6 +1048,8 @@ BOOST_AUTO_TEST_CASE(KNNModelMonochromaticTest)
models[5] = KNNModel(KNNModel::TreeTypes::R_TREE, false);
models[6] = KNNModel(KNNModel::TreeTypes::R_STAR_TREE, true);
models[7] = KNNModel(KNNModel::TreeTypes::R_STAR_TREE, false);
+ models[8] = KNNModel(KNNModel::TreeTypes::BALL_TREE, true);
+ models[9] = KNNModel(KNNModel::TreeTypes::BALL_TREE, false);
for (size_t j = 0; j < 2; ++j)
{
@@ -1055,7 +1059,7 @@ BOOST_AUTO_TEST_CASE(KNNModelMonochromaticTest)
arma::mat baselineDistances;
knn.Search(3, baselineNeighbors, baselineDistances);
- for (size_t i = 0; i < 8; ++i)
+ for (size_t i = 0; i < 10; ++i)
{
// We only have a std::move() constructor... so copy the data.
arma::mat referenceCopy(referenceData);
More information about the mlpack-git
mailing list