[mlpack-git] master: Add support for spill trees in knn search. (2e67697)
gitdub at mlpack.org
gitdub at mlpack.org
Thu Aug 18 13:39:00 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/0f4b25acd6aaa14294c044874ba6cc0751712baa...0a19d07bd39e6223991976474bc79671ba8aa0f0
>---------------------------------------------------------------
commit 2e676974ef119f128147c2f0567706e306c0e7e7
Author: MarcosPividori <marcos.pividori at gmail.com>
Date: Tue Jul 12 20:52:09 2016 -0300
Add support for spill trees in knn search.
>---------------------------------------------------------------
2e676974ef119f128147c2f0567706e306c0e7e7
src/mlpack/core/tree/spill_tree/spill_tree.hpp | 6 +-
src/mlpack/methods/neighbor_search/knn_main.cpp | 26 +++++++--
src/mlpack/methods/neighbor_search/ns_model.hpp | 35 +++++++++--
.../methods/neighbor_search/ns_model_impl.hpp | 67 ++++++++++++++++++++--
4 files changed, 115 insertions(+), 19 deletions(-)
diff --git a/src/mlpack/core/tree/spill_tree/spill_tree.hpp b/src/mlpack/core/tree/spill_tree/spill_tree.hpp
index e563b49..d5c13fc 100644
--- a/src/mlpack/core/tree/spill_tree/spill_tree.hpp
+++ b/src/mlpack/core/tree/spill_tree/spill_tree.hpp
@@ -127,7 +127,7 @@ class SpillTree
* @param rho Balance threshold.
*/
SpillTree(const MatType& data,
- const double tau,
+ const double tau = 0,
const size_t maxLeafSize = 20,
const double rho = 0.7);
@@ -143,7 +143,7 @@ class SpillTree
* @param rho Balance threshold.
*/
SpillTree(MatType&& data,
- const double tau,
+ const double tau = 0,
const size_t maxLeafSize = 20,
const double rho = 0.7);
@@ -163,7 +163,7 @@ class SpillTree
SpillTree(SpillTree* parent,
std::vector<size_t>& points,
const size_t overlapIndex,
- const double tau,
+ const double tau = 0,
const size_t maxLeafSize = 20,
const double rho = 0.7);
diff --git a/src/mlpack/methods/neighbor_search/knn_main.cpp b/src/mlpack/methods/neighbor_search/knn_main.cpp
index 14e07db..c201db9 100644
--- a/src/mlpack/methods/neighbor_search/knn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/knn_main.cpp
@@ -64,10 +64,13 @@ PARAM_INT_IN("k", "Number of nearest neighbors to find.", "k", 0);
// The user may specify the type of tree to use, and a few parameters for tree
// building.
PARAM_STRING_IN("tree_type", "Type of tree to use: 'kd', 'cover', 'r', "
- "'r-star', 'x', 'ball', 'hilbert-r', 'r-plus', 'r-plus-plus'.", "t", "kd");
+ "'r-star', 'x', 'ball', 'hilbert-r', 'r-plus', 'r-plus-plus', 'spill'.",
+ "t", "kd");
PARAM_INT_IN("leaf_size", "Leaf size for tree building (used for kd-trees, R "
- "trees, R* trees, X trees, Hilbert R trees, R+ trees and R++ trees).", "l",
- 20);
+ "trees, R* trees, X trees, Hilbert R trees, R+ trees, R++ trees, and Spill "
+ "trees).", "l", 20);
+PARAM_DOUBLE_IN("tau", "Overlapping size (for spill trees).", "u", 0);
+
PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
"random orthogonal basis.", "R");
PARAM_INT_IN("seed", "Random seed (if 0, std::time(NULL) is used).", "s", 0);
@@ -111,6 +114,9 @@ int main(int argc, char *argv[])
if (CLI::HasParam("leaf_size"))
Log::Warn << "--leaf_size (-l) will be ignored because --input_model_file"
<< " is specified." << endl;
+ if (CLI::HasParam("tau"))
+ Log::Warn << "--tau (-u) will be ignored because --input_model_file"
+ << " is specified." << endl;
if (CLI::HasParam("random_basis"))
Log::Warn << "--random_basis (-R) will be ignored because "
<< "--input_model_file is specified." << endl;
@@ -144,6 +150,13 @@ int main(int argc, char *argv[])
Log::Fatal << "Invalid leaf size: " << lsInt << ". Must be greater "
"than 0." << endl;
+ // Sanity check on tau.
+ const double tau = CLI::GetParam<double>("tau");
+ if (tau < 0)
+ Log::Fatal << "Invalid tau: " << tau << ". Must be non-negative. " << endl;
+ if (CLI::HasParam("tau") && "spill" != CLI::GetParam<string>("tree_type"))
+ Log::Fatal << "Tau parameter is only valid for spill trees." << endl;
+
// Sanity check on epsilon.
const double epsilon = CLI::GetParam<double>("epsilon");
if (epsilon < 0)
@@ -180,13 +193,17 @@ int main(int argc, char *argv[])
tree = KNNModel::R_PLUS_TREE;
else if (treeType == "r-plus-plus")
tree = KNNModel::R_PLUS_PLUS_TREE;
+ else if (treeType == "spill")
+ tree = KNNModel::SPILL_TREE;
else
Log::Fatal << "Unknown tree type '" << treeType << "'; valid choices are "
<< "'kd', 'cover', 'r', 'r-star', 'x', 'ball', 'hilbert-r', "
- << "'r-plus' and 'r-plus-plus'." << endl;
+ << "'r-plus', 'r-plus-plus' and 'spill'." << endl;
knn.TreeType() = tree;
knn.RandomBasis() = randomBasis;
+ knn.LeafSize() = size_t(lsInt);
+ knn.Tau() = tau;
arma::mat referenceSet;
data::Load(referenceFile, referenceSet, true);
@@ -213,6 +230,7 @@ int main(int argc, char *argv[])
knn.Naive() = CLI::HasParam("naive");
knn.LeafSize() = size_t(lsInt);
knn.Epsilon() = epsilon;
+ knn.Tau() = tau;
}
// Perform search, if desired.
diff --git a/src/mlpack/methods/neighbor_search/ns_model.hpp b/src/mlpack/methods/neighbor_search/ns_model.hpp
index 38e4748..6a31d5b 100644
--- a/src/mlpack/methods/neighbor_search/ns_model.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model.hpp
@@ -13,6 +13,7 @@
#include <mlpack/core/tree/binary_space_tree.hpp>
#include <mlpack/core/tree/cover_tree.hpp>
#include <mlpack/core/tree/rectangle_tree.hpp>
+#include <mlpack/core/tree/spill_tree.hpp>
#include <boost/variant.hpp>
#include "neighbor_search.hpp"
@@ -101,6 +102,8 @@ class BiSearchVisitor : public boost::static_visitor<void>
arma::mat& distances;
//! The number of points in a leaf (for BinarySpaceTrees).
const size_t leafSize;
+ //! Overlapping size (for spill trees).
+ const double tau;
//! Bichromatic neighbor search on the given NSType considering the leafSize.
template<typename NSType>
@@ -125,12 +128,16 @@ class BiSearchVisitor : public boost::static_visitor<void>
//! Bichromatic neighbor search on the given NSType specialized for BallTrees.
void operator()(NSTypeT<tree::BallTree>* ns) const;
+ //! Bichromatic neighbor search on the given NSType specialized for SPTrees.
+ void operator()(NSTypeT<tree::SPTree>* ns) const;
+
//! Construct the BiSearchVisitor.
BiSearchVisitor(const arma::mat& querySet,
const size_t k,
arma::Mat<size_t>& neighbors,
arma::mat& distances,
- const size_t leafSize);
+ const size_t leafSize,
+ const double tau);
};
/**
@@ -147,6 +154,8 @@ class TrainVisitor : public boost::static_visitor<void>
arma::mat&& referenceSet;
//! The leaf size, used only by BinarySpaceTree.
size_t leafSize;
+ //! Overlapping size (for spill trees).
+ const double tau;
//! Train on the given NSType considering the leafSize.
template<typename NSType>
@@ -171,9 +180,14 @@ class TrainVisitor : public boost::static_visitor<void>
//! Train on the given NSType specialized for BallTrees.
void operator()(NSTypeT<tree::BallTree>* ns) const;
- //! Construct the TrainVisitor object with the given reference set and leaf
- //! size for BinarySpaceTrees.
- TrainVisitor(arma::mat&& referenceSet, const size_t leafSize);
+ //! Train on the given NSType specialized for SPTrees.
+ void operator()(NSTypeT<tree::SPTree>* ns) const;
+
+ //! Construct the TrainVisitor object with the given reference set, leafSize
+ //! for BinarySpaceTrees, and tau for spill trees.
+ TrainVisitor(arma::mat&& referenceSet,
+ const size_t leafSize,
+ const double tau);
};
/**
@@ -256,7 +270,8 @@ class NSModel
X_TREE,
HILBERT_R_TREE,
R_PLUS_TREE,
- R_PLUS_PLUS_TREE
+ R_PLUS_PLUS_TREE,
+ SPILL_TREE
};
private:
@@ -266,6 +281,9 @@ class NSModel
//! For tree types that accept the maxLeafSize parameter.
size_t leafSize;
+ //! Overlapping size (for spill trees).
+ double tau;
+
//! If true, random projections are used.
bool randomBasis;
//! This is the random projection matrix; only used if randomBasis is true.
@@ -284,7 +302,8 @@ class NSModel
NSType<SortPolicy, tree::XTree>*,
NSType<SortPolicy, tree::HilbertRTree>*,
NSType<SortPolicy, tree::RPlusTree>*,
- NSType<SortPolicy, tree::RPlusPlusTree>*> nSearch;
+ NSType<SortPolicy, tree::RPlusPlusTree>*,
+ NSType<SortPolicy, tree::SPTree>*> nSearch;
public:
/**
@@ -319,6 +338,10 @@ class NSModel
size_t LeafSize() const { return leafSize; }
size_t& LeafSize() { return leafSize; }
+ //! Expose tau.
+ double Tau() const { return tau; }
+ double& Tau() { return tau; }
+
//! Expose treeType.
TreeTypes TreeType() const { return treeType; }
TreeTypes& TreeType() { return treeType; }
diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
index acbed6c..71430e6 100644
--- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
@@ -33,12 +33,14 @@ BiSearchVisitor<SortPolicy>::BiSearchVisitor(const arma::mat& querySet,
const size_t k,
arma::Mat<size_t>& neighbors,
arma::mat& distances,
- const size_t leafSize) :
+ const size_t leafSize,
+ const double tau) :
querySet(querySet),
k(k),
neighbors(neighbors),
distances(distances),
- leafSize(leafSize)
+ leafSize(leafSize),
+ tau(tau)
{}
//! Default Bichromatic neighbor search on the given NSType instance.
@@ -71,6 +73,25 @@ void BiSearchVisitor<SortPolicy>::operator()(NSTypeT<tree::BallTree>* ns) const
throw std::runtime_error("no neighbor search model initialized");
}
+//! Bichromatic neighbor search on the given NSType specialized for SPTrees.
+//!
+//! In naive or single-tree mode the query matrix is searched directly.
+//! Otherwise a spill tree is built on the query set with this visitor's tau
+//! and leafSize, and a dual-tree search is run against it.
+//! Throws std::runtime_error if ns is null.
+template<typename SortPolicy>
+void BiSearchVisitor<SortPolicy>::operator()(NSTypeT<tree::SPTree>* ns) const
+{
+  if (!ns)
+    throw std::runtime_error("no neighbor search model initialized");
+
+  if (ns->Naive() || ns->SingleMode())
+  {
+    ns->Search(querySet, k, neighbors, distances);
+    return;
+  }
+
+  // This operator is const, so querySet is const here. std::move() would only
+  // produce a const rvalue, which still selects the copying
+  // SpillTree(const MatType&, ...) constructor. Pass the matrix as-is so it is
+  // explicit that the query tree works on a copy.
+  typename NSTypeT<tree::SPTree>::Tree queryTree(querySet, tau, leafSize);
+  ns->Search(&queryTree, k, neighbors, distances);
+}
+
//! Bichromatic neighbor search on the given NSType considering the leafSize.
template<typename SortPolicy>
template<typename NSType>
@@ -102,9 +123,11 @@ void BiSearchVisitor<SortPolicy>::SearchLeaf(NSType* ns) const
//! Save parameters for Train.
template<typename SortPolicy>
TrainVisitor<SortPolicy>::TrainVisitor(arma::mat&& referenceSet,
- const size_t leafSize) :
+ const size_t leafSize,
+ const double tau) :
referenceSet(std::move(referenceSet)),
- leafSize(leafSize)
+ leafSize(leafSize),
+ tau(tau)
{}
//! Default Train on the given NSType instance.
@@ -137,6 +160,27 @@ void TrainVisitor<SortPolicy>::operator ()(NSTypeT<tree::BallTree>* ns) const
throw std::runtime_error("no neighbor search model initialized");
}
+//! Train on the given NSType specialized for SPTrees.
+//!
+//! In naive mode the reference matrix is handed straight to Train(). Otherwise
+//! a spill tree is built from the reference matrix with this visitor's tau and
+//! leafSize, passed to Train(), and the model is marked as the tree's owner.
+//! Throws std::runtime_error if ns is null.
+template<typename SortPolicy>
+void TrainVisitor<SortPolicy>::operator ()(NSTypeT<tree::SPTree>* ns) const
+{
+  if (!ns)
+    throw std::runtime_error("no neighbor search model initialized");
+
+  if (ns->Naive())
+  {
+    ns->Train(std::move(referenceSet));
+    return;
+  }
+
+  using SpillTreeType = typename NSTypeT<tree::SPTree>::Tree;
+  // The reference matrix is moved into a heap-allocated tree, because ns keeps
+  // the pointer after this call returns (see the ownership flag below).
+  SpillTreeType* referenceTree =
+      new SpillTreeType(std::move(referenceSet), tau, leafSize);
+  ns->Train(referenceTree);
+  // Give the model ownership of the tree.
+  ns->treeOwner = true;
+}
+
//! Train on the given NSType considering the leafSize.
template<typename SortPolicy>
template<typename NSType>
@@ -209,6 +253,8 @@ void DeleteVisitor::operator()(NSType* ns) const
template<typename SortPolicy>
NSModel<SortPolicy>::NSModel(TreeTypes treeType, bool randomBasis) :
treeType(treeType),
+ leafSize(20),
+ tau(0),
randomBasis(randomBasis)
{
// Nothing to do.
@@ -249,6 +295,8 @@ void NSModel<SortPolicy>::Serialize(Archive& ar,
const unsigned int /* version */)
{
ar & data::CreateNVP(treeType, "treeType");
+ ar & data::CreateNVP(leafSize, "leafSize");
+ ar & data::CreateNVP(tau, "tau");
ar & data::CreateNVP(randomBasis, "randomBasis");
ar & data::CreateNVP(q, "q");
@@ -313,6 +361,7 @@ void NSModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
const bool singleMode,
const double epsilon)
{
+ this->leafSize = leafSize;
// Initialize random basis if necessary.
if (randomBasis)
{
@@ -394,9 +443,13 @@ void NSModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
nSearch = new NSType<SortPolicy, tree::RPlusPlusTree>(naive, singleMode,
epsilon);
break;
+ case SPILL_TREE:
+ nSearch = new NSType<SortPolicy, tree::SPTree>(naive, singleMode,
+ epsilon);
+ break;
}
- TrainVisitor<SortPolicy> tn(std::move(referenceSet), leafSize);
+ TrainVisitor<SortPolicy> tn(std::move(referenceSet), leafSize, tau);
boost::apply_visitor(tn, nSearch);
if (!naive)
@@ -429,7 +482,7 @@ void NSModel<SortPolicy>::Search(arma::mat&& querySet,
<< std::endl;
BiSearchVisitor<SortPolicy> search(querySet, k, neighbors, distances,
- leafSize);
+ leafSize, tau);
boost::apply_visitor(search, nSearch);
}
@@ -478,6 +531,8 @@ std::string NSModel<SortPolicy>::TreeName() const
return "R+ tree";
case R_PLUS_PLUS_TREE:
return "R++ tree";
+ case SPILL_TREE:
+ return "Spill tree";
default:
return "unknown tree";
}
More information about the mlpack-git mailing list