[mlpack-git] master: Add rho and leafSize members on SpillSearch, and a command line parameter '--rho' for mlpack_knn. (fa81be4)
gitdub at mlpack.org
gitdub at mlpack.org
Thu Aug 11 11:16:49 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/0f4b25acd6aaa14294c044874ba6cc0751712baa...0a19d07bd39e6223991976474bc79671ba8aa0f0
>---------------------------------------------------------------
commit fa81be4b559909c3c066645dc707efc5535e9ca6
Author: MarcosPividori <marcos.pividori at gmail.com>
Date: Thu Aug 11 12:16:49 2016 -0300
Add rho and leafSize members on SpillSearch, and a command line parameter '--rho' for mlpack_knn.
>---------------------------------------------------------------
fa81be4b559909c3c066645dc707efc5535e9ca6
src/mlpack/methods/neighbor_search/knn_main.cpp | 14 ++++++++
src/mlpack/methods/neighbor_search/ns_model.hpp | 18 ++++++++--
.../methods/neighbor_search/ns_model_impl.hpp | 24 ++++++++-----
.../methods/neighbor_search/spill_search.hpp | 28 +++++++++++++++
.../methods/neighbor_search/spill_search_impl.hpp | 42 +++++++++++++++++-----
5 files changed, 106 insertions(+), 20 deletions(-)
diff --git a/src/mlpack/methods/neighbor_search/knn_main.cpp b/src/mlpack/methods/neighbor_search/knn_main.cpp
index c201db9..97db49d 100644
--- a/src/mlpack/methods/neighbor_search/knn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/knn_main.cpp
@@ -70,6 +70,7 @@ PARAM_INT_IN("leaf_size", "Leaf size for tree building (used for kd-trees, R "
"trees, R* trees, X trees, Hilbert R trees, R+ trees, R++ trees, and Spill "
"trees).", "l", 20);
PARAM_DOUBLE_IN("tau", "Overlapping size (for spill trees).", "u", 0);
+PARAM_DOUBLE_IN("rho", "Balance threshold (for spill trees).", "b", 0.7);
PARAM_FLAG("random_basis", "Before tree-building, project the data onto a "
"random orthogonal basis.", "R");
@@ -117,6 +118,9 @@ int main(int argc, char *argv[])
if (CLI::HasParam("tau"))
Log::Warn << "--tau (-u) will be ignored because --input_model_file"
<< " is specified." << endl;
+ if (CLI::HasParam("rho"))
+ Log::Warn << "--rho (-b) will be ignored because --input_model_file"
+ << " is specified." << endl;
if (CLI::HasParam("random_basis"))
Log::Warn << "--random_basis (-R) will be ignored because "
<< "--input_model_file is specified." << endl;
@@ -157,6 +161,14 @@ int main(int argc, char *argv[])
if (CLI::HasParam("tau") && "spill" != CLI::GetParam<string>("tree_type"))
Log::Fatal << "Tau parameter is only valid for spill trees." << endl;
+ // Sanity check on rho.
+ const double rho = CLI::GetParam<double>("rho");
+ if (rho < 0 || rho > 1)
+ Log::Fatal << "Invalid rho: " << rho << ". Must be in the range [0,1]. "
+ << endl;
+ if (CLI::HasParam("rho") && "spill" != CLI::GetParam<string>("tree_type"))
+ Log::Fatal << "Rho parameter is only valid for spill trees." << endl;
+
// Sanity check on epsilon.
const double epsilon = CLI::GetParam<double>("epsilon");
if (epsilon < 0)
@@ -204,6 +216,7 @@ int main(int argc, char *argv[])
knn.RandomBasis() = randomBasis;
knn.LeafSize() = size_t(lsInt);
knn.Tau() = tau;
+ knn.Rho() = rho;
arma::mat referenceSet;
data::Load(referenceFile, referenceSet, true);
@@ -231,6 +244,7 @@ int main(int argc, char *argv[])
knn.LeafSize() = size_t(lsInt);
knn.Epsilon() = epsilon;
knn.Tau() = tau;
+ knn.Rho() = rho;
}
// Perform search, if desired.
diff --git a/src/mlpack/methods/neighbor_search/ns_model.hpp b/src/mlpack/methods/neighbor_search/ns_model.hpp
index 03855a1..4f34e07 100644
--- a/src/mlpack/methods/neighbor_search/ns_model.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model.hpp
@@ -110,6 +110,8 @@ class BiSearchVisitor : public boost::static_visitor<void>
const size_t leafSize;
//! Overlapping size (for spill trees).
const double tau;
+ //! Balance threshold (for spill trees).
+ const double rho;
//! Bichromatic neighbor search on the given NSType considering the leafSize.
template<typename NSType>
@@ -143,7 +145,8 @@ class BiSearchVisitor : public boost::static_visitor<void>
arma::Mat<size_t>& neighbors,
arma::mat& distances,
const size_t leafSize,
- const double tau);
+ const double tau,
+ const double rho);
};
/**
@@ -162,6 +165,8 @@ class TrainVisitor : public boost::static_visitor<void>
size_t leafSize;
//! Overlapping size (for spill trees).
const double tau;
+ //! Balance threshold (for spill trees).
+ const double rho;
//! Train on the given NSType considering the leafSize.
template<typename NSType>
@@ -190,10 +195,11 @@ class TrainVisitor : public boost::static_visitor<void>
void operator()(NSSpillType* ns) const;
//! Construct the TrainVisitor object with the given reference set, leafSize
- //! for BinarySpaceTrees, and tau for spill trees.
+ //! for BinarySpaceTrees, and tau and rho for spill trees.
TrainVisitor(arma::mat&& referenceSet,
const size_t leafSize,
- const double tau);
+ const double tau,
+ const double rho);
};
/**
@@ -289,6 +295,8 @@ class NSModel
//! Overlapping size (for spill trees).
double tau;
+ //! Balance threshold (for spill trees).
+ double rho;
//! If true, random projections are used.
bool randomBasis;
@@ -348,6 +356,10 @@ class NSModel
double Tau() const { return tau; }
double& Tau() { return tau; }
+ //! Expose rho.
+ double Rho() const { return rho; }
+ double& Rho() { return rho; }
+
//! Expose treeType.
TreeTypes TreeType() const { return treeType; }
TreeTypes& TreeType() { return treeType; }
diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
index d46b7e5..29cc167 100644
--- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
@@ -34,13 +34,15 @@ BiSearchVisitor<SortPolicy>::BiSearchVisitor(const arma::mat& querySet,
arma::Mat<size_t>& neighbors,
arma::mat& distances,
const size_t leafSize,
- const double tau) :
+ const double tau,
+ const double rho) :
querySet(querySet),
k(k),
neighbors(neighbors),
distances(distances),
leafSize(leafSize),
- tau(tau)
+ tau(tau),
+ rho(rho)
{}
//! Default Bichromatic neighbor search on the given NSType instance.
@@ -84,7 +86,7 @@ void BiSearchVisitor<SortPolicy>::operator()(NSSpillType* ns) const
// For Dual Tree Search on SpillTrees, the queryTree must be built with
// non overlapping (tau = 0).
typename NSSpillType::Tree queryTree(std::move(querySet), 0 /* tau*/,
- leafSize);
+ leafSize, rho);
ns->Search(&queryTree, k, neighbors, distances);
}
else
@@ -126,10 +128,12 @@ void BiSearchVisitor<SortPolicy>::SearchLeaf(NSType* ns) const
template<typename SortPolicy>
TrainVisitor<SortPolicy>::TrainVisitor(arma::mat&& referenceSet,
const size_t leafSize,
- const double tau) :
+ const double tau,
+ const double rho) :
referenceSet(std::move(referenceSet)),
leafSize(leafSize),
- tau(tau)
+ tau(tau),
+ rho(rho)
{}
//! Default Train on the given NSType instance.
@@ -173,7 +177,7 @@ void TrainVisitor<SortPolicy>::operator ()(NSSpillType* ns) const
else
{
typename NSSpillType::Tree* tree = new typename NSSpillType::Tree(
- std::move(referenceSet), tau, leafSize);
+ std::move(referenceSet), tau, leafSize, rho);
ns->Train(tree);
// Give the model ownership of the tree.
ns->neighborSearch.treeOwner = true;
@@ -257,6 +261,7 @@ NSModel<SortPolicy>::NSModel(TreeTypes treeType, bool randomBasis) :
treeType(treeType),
leafSize(20),
tau(0),
+ rho(0.7),
randomBasis(randomBasis)
{
// Nothing to do.
@@ -307,6 +312,7 @@ void NSModel<SortPolicy>::Serialize(Archive& ar,
ar & data::CreateNVP(treeType, "treeType");
ar & data::CreateNVP(leafSize, "leafSize");
ar & data::CreateNVP(tau, "tau");
+ ar & data::CreateNVP(rho, "rho");
ar & data::CreateNVP(randomBasis, "randomBasis");
ar & data::CreateNVP(q, "q");
@@ -454,11 +460,11 @@ void NSModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
epsilon);
break;
case SPILL_TREE:
- nSearch = new NSSpillType(naive, singleMode, tau, epsilon);
+ nSearch = new NSSpillType(naive, singleMode, tau, leafSize, rho, epsilon);
break;
}
- TrainVisitor<SortPolicy> tn(std::move(referenceSet), leafSize, tau);
+ TrainVisitor<SortPolicy> tn(std::move(referenceSet), leafSize, tau, rho);
boost::apply_visitor(tn, nSearch);
if (!naive)
@@ -491,7 +497,7 @@ void NSModel<SortPolicy>::Search(arma::mat&& querySet,
<< std::endl;
BiSearchVisitor<SortPolicy> search(querySet, k, neighbors, distances,
- leafSize, tau);
+ leafSize, tau, rho);
boost::apply_visitor(search, nSearch);
}
diff --git a/src/mlpack/methods/neighbor_search/spill_search.hpp b/src/mlpack/methods/neighbor_search/spill_search.hpp
index e1d4008..8437012 100644
--- a/src/mlpack/methods/neighbor_search/spill_search.hpp
+++ b/src/mlpack/methods/neighbor_search/spill_search.hpp
@@ -68,6 +68,8 @@ class SpillSearch
* @param singleMode If true, single-tree search will be used (as opposed to
* dual-tree search).
* @param tau Overlapping size (non-negative).
+ * @param leafSize Max size of each leaf in the tree.
+ * @param rho Balance threshold (non-negative).
* @param epsilon Relative approximate error (non-negative).
* @param metric An optional instance of the MetricType class.
*/
@@ -75,6 +77,8 @@ class SpillSearch
const bool naive = false,
const bool singleMode = false,
const double tau = 0,
+ const double leafSize = 20,
+ const double rho = 0.7,
const double epsilon = 0,
const MetricType metric = MetricType());
@@ -91,6 +95,8 @@ class SpillSearch
* @param singleMode If true, single-tree search will be used (as opposed to
* dual-tree search).
* @param tau Overlapping size (non-negative).
+ * @param leafSize Max size of each leaf in the tree.
+ * @param rho Balance threshold (non-negative).
* @param epsilon Relative approximate error (non-negative).
* @param metric An optional instance of the MetricType class.
*/
@@ -98,6 +104,8 @@ class SpillSearch
const bool naive = false,
const bool singleMode = false,
const double tau = 0,
+ const double leafSize = 20,
+ const double rho = 0.7,
const double epsilon = 0,
const MetricType metric = MetricType());
@@ -113,12 +121,16 @@ class SpillSearch
* @param singleMode Whether single-tree computation should be used (as
* opposed to dual-tree computation).
* @param tau Overlapping size (non-negative).
+ * @param leafSize Max size of each leaf in the tree.
+ * @param rho Balance threshold (non-negative).
* @param epsilon Relative approximate error (non-negative).
* @param metric Instantiated distance metric.
*/
SpillSearch(Tree* referenceTree,
const bool singleMode = false,
const double tau = 0,
+ const double leafSize = 20,
+ const double rho = 0.7,
const double epsilon = 0,
const MetricType metric = MetricType());
@@ -131,12 +143,16 @@ class SpillSearch
* @param singleMode Whether single-tree computation should be used (as
* opposed to dual-tree computation).
* @param tau Overlapping size (non-negative).
+ * @param leafSize Max size of each leaf in the tree.
+ * @param rho Balance threshold (non-negative).
* @param epsilon Relative approximate error (non-negative).
* @param metric Instantiated metric.
*/
SpillSearch(const bool naive = false,
const bool singleMode = false,
const double tau = 0,
+ const double leafSize = 20,
+ const double rho = 0.7,
const double epsilon = 0,
const MetricType metric = MetricType());
@@ -262,6 +278,12 @@ class SpillSearch
//! Access the overlapping size.
double Tau() const { return tau; }
+ //! Access the balance threshold.
+ double Rho() const { return rho; }
+
+ //! Access the leaf size.
+ double LeafSize() const { return leafSize; }
+
//! Access the reference dataset.
const MatType& ReferenceSet() const { return neighborSearch.ReferenceSet(); }
@@ -277,6 +299,12 @@ class SpillSearch
//! Overlapping size.
double tau;
+ //! Balance threshold.
+ double rho;
+
+ //! Max leaf size.
+ double leafSize;
+
//! The NSModel class should have access to internal members.
template<typename SortPolicy>
friend class TrainVisitor;
diff --git a/src/mlpack/methods/neighbor_search/spill_search_impl.hpp b/src/mlpack/methods/neighbor_search/spill_search_impl.hpp
index 0a06037..5638959 100644
--- a/src/mlpack/methods/neighbor_search/spill_search_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/spill_search_impl.hpp
@@ -25,13 +25,19 @@ SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::SpillSearch(
const bool naive,
const bool singleMode,
const double tau,
+ const double leafSize,
+ const double rho,
const double epsilon,
const MetricType metric) :
neighborSearch(naive, singleMode, epsilon, metric),
- tau(tau)
+ tau(tau),
+ rho(rho),
+ leafSize(leafSize)
{
if (tau < 0)
throw std::invalid_argument("tau must be non-negative");
+ if (rho < 0 || rho > 1)
+ throw std::invalid_argument("rho must be in the range [0,1]");
Train(referenceSetIn);
}
@@ -45,13 +51,19 @@ SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::SpillSearch(
const bool naive,
const bool singleMode,
const double tau,
+ const double leafSize,
+ const double rho,
const double epsilon,
const MetricType metric) :
neighborSearch(naive, singleMode, epsilon, metric),
- tau(tau)
+ tau(tau),
+ rho(rho),
+ leafSize(leafSize)
{
if (tau < 0)
throw std::invalid_argument("tau must be non-negative");
+ if (rho < 0 || rho > 1)
+ throw std::invalid_argument("rho must be in the range [0,1]");
Train(std::move(referenceSetIn));
}
@@ -64,13 +76,19 @@ SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::SpillSearch(
Tree* referenceTree,
const bool singleMode,
const double tau,
+ const double leafSize,
+ const double rho,
const double epsilon,
const MetricType metric) :
neighborSearch(singleMode, epsilon, metric),
- tau(tau)
+ tau(tau),
+ rho(rho),
+ leafSize(leafSize)
{
if (tau < 0)
throw std::invalid_argument("tau must be non-negative");
+ if (rho < 0 || rho > 1)
+ throw std::invalid_argument("rho must be in the range [0,1]");
Train(referenceTree);
}
@@ -83,13 +101,19 @@ SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::SpillSearch(
const bool naive,
const bool singleMode,
const double tau,
+ const double leafSize,
+ const double rho,
const double epsilon,
const MetricType metric) :
neighborSearch(naive, singleMode, epsilon, metric),
- tau(tau)
+ tau(tau),
+ rho(rho),
+ leafSize(leafSize)
{
if (tau < 0)
throw std::invalid_argument("tau must be non-negative");
+ if (rho < 0 || rho > 1)
+ throw std::invalid_argument("rho must be in the range [0,1]");
}
// Clean memory.
@@ -115,7 +139,7 @@ Train(const MatType& referenceSet)
else
{
// Build reference tree with proper value for tau.
- Tree* tree = new Tree(referenceSet, tau);
+ Tree* tree = new Tree(referenceSet, tau, leafSize, rho);
neighborSearch.Train(tree);
// Give the model ownership of the tree.
neighborSearch.treeOwner = true;
@@ -134,7 +158,7 @@ Train(MatType&& referenceSetIn)
else
{
// Build reference tree with proper value for tau.
- Tree* tree = new Tree(std::move(referenceSetIn), tau);
+ Tree* tree = new Tree(std::move(referenceSetIn), tau, leafSize, rho);
neighborSearch.Train(tree);
// Give the model ownership of the tree.
neighborSearch.treeOwner = true;
@@ -167,7 +191,7 @@ Search(const MatType& querySet,
{
// For Dual Tree Search on SpillTrees, the queryTree must be built with non
// overlapping (tau = 0).
- Tree queryTree(querySet, 0 /* tau */);
+ Tree queryTree(querySet, 0 /* tau */, leafSize, rho);
neighborSearch.Search(&queryTree, k, neighbors, distances);
}
}
@@ -201,7 +225,7 @@ Search(const size_t k,
// For Dual Tree Search on SpillTrees, the queryTree must be built with non
// overlapping (tau = 0). If the referenceTree was built with a non-zero
// value for tau, we need to build a new queryTree.
- Tree queryTree(ReferenceSet(), 0 /* tau */);
+ Tree queryTree(ReferenceSet(), 0 /* tau */, leafSize, rho);
neighborSearch.Search(&queryTree, k, neighbors, distances, true);
}
}
@@ -217,6 +241,8 @@ void SpillSearch<MetricType, MatType, HyperplaneType, SplitType>::
{
ar & data::CreateNVP(neighborSearch, "neighborSearch");
ar & data::CreateNVP(tau, "tau");
+ ar & data::CreateNVP(rho, "rho");
+ ar & data::CreateNVP(leafSize, "leafSize");
}
} // namespace neighbor
More information about the mlpack-git
mailing list