[mlpack-git] master: Added tests for the random projection tree split method. Various fixes of random projection trees. (bc8994a)
gitdub at mlpack.org
gitdub at mlpack.org
Sun Aug 7 13:11:41 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/a7794bde8082c691553152393e1e230098f5e920...87776e52cf9ead63fa458118a0cfd2fe46b23466
>---------------------------------------------------------------
commit bc8994a37481fac281dfce43edb265d454afa6a9
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date: Sun Aug 7 20:11:41 2016 +0300
Added tests for the random projection tree split method.
Various fixes of random projection trees.
>---------------------------------------------------------------
bc8994a37481fac281dfce43edb265d454afa6a9
.../tree/binary_space_tree/rp_tree_max_split.hpp | 7 -
.../binary_space_tree/rp_tree_max_split_impl.hpp | 32 +----
.../tree/binary_space_tree/rp_tree_mean_split.hpp | 3 +
.../binary_space_tree/rp_tree_mean_split_impl.hpp | 45 ++----
src/mlpack/core/tree/binary_space_tree/traits.hpp | 8 +-
src/mlpack/core/tree/binary_space_tree/typedef.hpp | 57 +++++++-
src/mlpack/methods/neighbor_search/kfn_main.cpp | 14 +-
src/mlpack/methods/neighbor_search/knn_main.cpp | 14 +-
src/mlpack/methods/neighbor_search/ns_model.hpp | 8 +-
.../methods/neighbor_search/ns_model_impl.hpp | 16 +--
.../methods/range_search/range_search_main.cpp | 12 +-
src/mlpack/methods/range_search/rs_model.cpp | 54 +++----
src/mlpack/methods/range_search/rs_model.hpp | 12 +-
src/mlpack/methods/range_search/rs_model_impl.hpp | 48 +++----
src/mlpack/tests/aknn_test.cpp | 16 +--
src/mlpack/tests/knn_test.cpp | 16 +--
src/mlpack/tests/range_search_test.cpp | 16 +--
src/mlpack/tests/tree_test.cpp | 155 ++++++++++++++++++++-
18 files changed, 342 insertions(+), 191 deletions(-)
diff --git a/src/mlpack/core/tree/binary_space_tree/rp_tree_max_split.hpp b/src/mlpack/core/tree/binary_space_tree/rp_tree_max_split.hpp
index f2cddb2..5c31e84 100644
--- a/src/mlpack/core/tree/binary_space_tree/rp_tree_max_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/rp_tree_max_split.hpp
@@ -68,13 +68,6 @@ class RPTreeMaxSplit
private:
/**
- * Get a random unit vector of size direction.n_elem.
- *
- * @param direction The variable into which the method saves the vector.
- */
- static void GetRandomDirection(arma::Col<ElemType>& direction);
-
- /**
* Get random deviation from the median of points multiplied by the direction
* obtained in GetRandomDirection().
*
diff --git a/src/mlpack/core/tree/binary_space_tree/rp_tree_max_split_impl.hpp b/src/mlpack/core/tree/binary_space_tree/rp_tree_max_split_impl.hpp
index 223bdda..d5cd714 100644
--- a/src/mlpack/core/tree/binary_space_tree/rp_tree_max_split_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/rp_tree_max_split_impl.hpp
@@ -9,6 +9,7 @@
#define MLPACK_CORE_TREE_BINARY_SPACE_TREE_RP_TREE_MAX_SPLIT_IMPL_HPP
#include "rp_tree_max_split.hpp"
+#include "rp_tree_mean_split.hpp"
namespace mlpack {
namespace tree {
@@ -23,7 +24,8 @@ bool RPTreeMaxSplit<BoundType, MatType>::SplitNode(const BoundType& /* bound */,
splitInfo.direction.zeros(data.n_rows);
// Get the normal to the hyperplane.
- GetRandomDirection(splitInfo.direction);
+ RPTreeMeanSplit<BoundType, MatType>::GetRandomDirection(
+ splitInfo.direction);
// Get the value according to which we will perform the split.
if (!GetSplitVal(data, begin, count, splitInfo.direction, splitInfo.splitVal))
@@ -33,34 +35,6 @@ bool RPTreeMaxSplit<BoundType, MatType>::SplitNode(const BoundType& /* bound */,
}
template<typename BoundType, typename MatType>
-void RPTreeMaxSplit<BoundType, MatType>::GetRandomDirection(
- arma::Col<ElemType>& direction)
-{
- arma::Col<ElemType> origin;
-
- origin.zeros(direction.n_rows);
-
- for (size_t k = 0; k < direction.n_rows; k++)
- direction[k] = math::Random(-1.0, 1.0);
-
- ElemType length = metric::EuclideanDistance::Evaluate(origin, direction);
-
- if (length > 0)
- direction /= length;
- else
- {
- // If the vector is equal to 0, choose an arbitrary dimension.
- size_t k = math::RandInt(direction.n_rows);
-
- direction[k] = 1.0;
-
- length = metric::EuclideanDistance::Evaluate(origin, direction);
-
- direction[k] /= length;
- }
-}
-
-template<typename BoundType, typename MatType>
typename MatType::elem_type RPTreeMaxSplit<BoundType, MatType>::
GetRandomDeviation(const MatType& data,
const size_t begin,
diff --git a/src/mlpack/core/tree/binary_space_tree/rp_tree_mean_split.hpp b/src/mlpack/core/tree/binary_space_tree/rp_tree_mean_split.hpp
index 1f4d853..ad198a2 100644
--- a/src/mlpack/core/tree/binary_space_tree/rp_tree_mean_split.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/rp_tree_mean_split.hpp
@@ -9,6 +9,7 @@
#define MLPACK_CORE_TREE_BINARY_SPACE_TREE_RP_TREE_MEAN_SPLIT_HPP
#include <mlpack/core.hpp>
+#include "rp_tree_max_split.hpp"
namespace mlpack {
namespace tree /** Trees and tree-building procedures. */ {
@@ -133,6 +134,8 @@ class RPTreeMeanSplit
const arma::uvec& samples,
arma::Col<ElemType>& mean,
ElemType& splitVal);
+
+ friend RPTreeMaxSplit<BoundType, MatType>;
};
} // namespace tree
diff --git a/src/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp b/src/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp
index c23db87..ec5ad1d 100644
--- a/src/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/rp_tree_mean_split_impl.hpp
@@ -102,28 +102,21 @@ template<typename BoundType, typename MatType>
void RPTreeMeanSplit<BoundType, MatType>::GetRandomDirection(
arma::Col<ElemType>& direction)
{
- arma::Col<ElemType> origin;
+ direction.randu(); // Fill with [0, 1].
+ direction -= 0.5; // Shift to [-0.5, 0.5].
- origin.zeros(direction.n_rows);
+ // Get the length of the vector.
+ const ElemType norm = arma::norm(direction);
- for (size_t k = 0; k < direction.n_rows; k++)
- direction[k] = math::Random(-1.0, 1.0);
-
- ElemType length = metric::EuclideanDistance::Evaluate(origin, direction);
-
- if (length > 0)
- direction /= length;
- else
+ if (norm == 0)
{
// If the vector is equal to 0, choose an arbitrary dimension.
size_t k = math::RandInt(direction.n_rows);
direction[k] = 1.0;
-
- length = metric::EuclideanDistance::Evaluate(origin, direction);
-
- direction[k] /= length;
}
+ else
+ direction /= norm; // Normalize the vector.
}
template<typename BoundType, typename MatType>
@@ -133,17 +126,15 @@ bool RPTreeMeanSplit<BoundType, MatType>::GetDotMedian(
const arma::Col<ElemType>& direction,
ElemType& splitVal)
{
- std::vector<ElemType> values(samples.n_elem);
+ arma::Col<ElemType> values(samples.n_elem);
for (size_t k = 0; k < samples.n_elem; k++)
values[k] = arma::dot(data.col(samples[k]), direction);
- std::sort(values.begin(), values.end());
-
- if (values[0] == values[values.size() - 1])
+ if (arma::min(values) == arma::max(values))
return false;
- splitVal = values[values.size() / 2];
+ splitVal = arma::median(values);
return true;
}
@@ -155,15 +146,11 @@ bool RPTreeMeanSplit<BoundType, MatType>::GetMeanMedian(
arma::Col<ElemType>& mean,
ElemType& splitVal)
{
- std::vector<ElemType> values(samples.n_elem);
+ arma::Col<ElemType> values(samples.n_elem);
- mean.zeros(data.n_rows);
+ mean = arma::mean(data.cols(samples));
- for (size_t k = 0; k < samples.n_elem; k++)
- mean += data.col(samples[k]);
-
- mean /= samples.n_elem;
- arma::Col<ElemType> tmp(data.n_elem);
+ arma::Col<ElemType> tmp(data.n_rows);
for (size_t k = 0; k < samples.n_elem; k++)
{
@@ -173,12 +160,10 @@ bool RPTreeMeanSplit<BoundType, MatType>::GetMeanMedian(
values[k] = arma::dot(tmp, tmp);
}
- std::sort(values.begin(), values.end());
-
- if (values[0] == values[values.size() - 1])
+ if (arma::min(values) == arma::max(values))
return false;
- splitVal = values[values.size() / 2];
+ splitVal = arma::median(values);
return true;
}
diff --git a/src/mlpack/core/tree/binary_space_tree/traits.hpp b/src/mlpack/core/tree/binary_space_tree/traits.hpp
index 88cbcae..ade4356 100644
--- a/src/mlpack/core/tree/binary_space_tree/traits.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/traits.hpp
@@ -65,9 +65,7 @@ class TreeTraits<BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
{
public:
/**
- * Each binary space tree node has two children which represent
- * non-overlapping subsets of the space which the node represents. Therefore,
- * children are not overlapping.
+ * Children of a random projection tree node may overlap.
*/
static const bool HasOverlappingChildren = true;
@@ -101,9 +99,7 @@ class TreeTraits<BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
{
public:
/**
- * Each binary space tree node has two children which represent
- * non-overlapping subsets of the space which the node represents. Therefore,
- * children are not overlapping.
+ * Children of a random projection tree node may overlap.
*/
static const bool HasOverlappingChildren = true;
diff --git a/src/mlpack/core/tree/binary_space_tree/typedef.hpp b/src/mlpack/core/tree/binary_space_tree/typedef.hpp
index 7015f6d..cad3af6 100644
--- a/src/mlpack/core/tree/binary_space_tree/typedef.hpp
+++ b/src/mlpack/core/tree/binary_space_tree/typedef.hpp
@@ -135,15 +135,68 @@ using MeanSplitBallTree = BinarySpaceTree<MetricType,
bound::BallBound,
MeanSplit>;
+/**
+ * A max-split random projection tree. When recursively splitting nodes, the
+ * MaxSplitRPTree class selects a random hyperplane and splits a node by the
+ * hyperplane. The tree holds points in leaf nodes. In contrast to the k-d tree,
+ * children of a MaxSplitRPTree node may overlap.
+ *
+ * @code
+ * @inproceedings{dasgupta2008,
+ * author = {Dasgupta, Sanjoy and Freund, Yoav},
+ * title = {Random Projection Trees and Low Dimensional Manifolds},
+ * booktitle = {Proceedings of the Fortieth Annual ACM Symposium on Theory of
+ * Computing},
+ * series = {STOC '08},
+ * year = {2008},
+ * pages = {537--546},
+ * numpages = {10},
+ * publisher = {ACM},
+ * address = {New York, NY, USA},
+ * }
+ * @endcode
+ *
+ * This template typedef satisfies the TreeType policy API.
+ *
+ * @see @ref trees, BinarySpaceTree, BallTree, MeanSplitKDTree
+ */
+
template<typename MetricType, typename StatisticType, typename MatType>
-using RPTreeMax = BinarySpaceTree<MetricType,
+using MaxSplitRPTree = BinarySpaceTree<MetricType,
StatisticType,
MatType,
bound::HRectBound,
RPTreeMaxSplit>;
+/**
+ * A mean-split random projection tree. When recursively splitting nodes, the
+ * RPTree class may perform one of two different kinds of split.
+ * Depending on the diameter and the average distance between points, the node
+ * may be split by a random hyperplane or according to the distance from the
+ * mean point. The tree holds points in leaf nodes. In contrast to the k-d tree,
+ * children of an RPTree node may overlap.
+ *
+ * @code
+ * @inproceedings{dasgupta2008,
+ * author = {Dasgupta, Sanjoy and Freund, Yoav},
+ * title = {Random Projection Trees and Low Dimensional Manifolds},
+ * booktitle = {Proceedings of the Fortieth Annual ACM Symposium on Theory of
+ * Computing},
+ * series = {STOC '08},
+ * year = {2008},
+ * pages = {537--546},
+ * numpages = {10},
+ * publisher = {ACM},
+ * address = {New York, NY, USA},
+ * }
+ * @endcode
+ *
+ * This template typedef satisfies the TreeType policy API.
+ *
+ * @see @ref trees, BinarySpaceTree, BallTree, MeanSplitKDTree
+ */
template<typename MetricType, typename StatisticType, typename MatType>
-using RPTreeMean = BinarySpaceTree<MetricType,
+using RPTree = BinarySpaceTree<MetricType,
StatisticType,
MatType,
bound::HRectBound,
diff --git a/src/mlpack/methods/neighbor_search/kfn_main.cpp b/src/mlpack/methods/neighbor_search/kfn_main.cpp
index 1dc93e9..628d26b 100644
--- a/src/mlpack/methods/neighbor_search/kfn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/kfn_main.cpp
@@ -61,8 +61,8 @@ PARAM_INT("k", "Number of furthest neighbors to find.", "k", 0);
// The user may specify the type of tree to use, and a few parameters for tree
// building.
-PARAM_STRING("tree_type", "Type of tree to use: 'kd', 'rp-tree-max', "
- "'rp-tree-mean', 'cover', 'r', 'r-star', 'x', 'ball', 'hilbert-r', "
+PARAM_STRING("tree_type", "Type of tree to use: 'kd', 'rp-tree', "
+ "'max-split-rp-tree', 'cover', 'r', 'r-star', 'x', 'ball', 'hilbert-r', "
"'r-plus', 'r-plus-plus'.", "t", "kd");
PARAM_INT("leaf_size", "Leaf size for tree building (used for kd-trees, "
"random projection trees, R trees, R* trees, X trees, Hilbert R trees, "
@@ -195,13 +195,13 @@ int main(int argc, char *argv[])
tree = KFNModel::R_PLUS_TREE;
else if (treeType == "r-plus-plus")
tree = KFNModel::R_PLUS_PLUS_TREE;
- else if (treeType == "rp-tree-max")
- tree = KFNModel::RP_TREE_MAX;
- else if (treeType == "rp-tree-mean")
- tree = KFNModel::RP_TREE_MEAN;
+ else if (treeType == "rp-tree")
+ tree = KFNModel::RP_TREE;
+ else if (treeType == "max-split-rp-tree")
+ tree = KFNModel::MAX_SPLIT_RP_TREE;
else
Log::Fatal << "Unknown tree type '" << treeType << "'; valid choices are "
- << "'kd', 'rp-tree-max', 'rp-tree-mean', 'cover', 'r', 'r-star', "
+ << "'kd', 'rp-tree', 'max-split-rp-tree', 'cover', 'r', 'r-star', "
<< "'x', 'ball', 'hilbert-r', 'r-plus' and 'r-plus-plus'." << endl;
kfn.TreeType() = tree;
diff --git a/src/mlpack/methods/neighbor_search/knn_main.cpp b/src/mlpack/methods/neighbor_search/knn_main.cpp
index a4c4aa7..cfecb49 100644
--- a/src/mlpack/methods/neighbor_search/knn_main.cpp
+++ b/src/mlpack/methods/neighbor_search/knn_main.cpp
@@ -62,8 +62,8 @@ PARAM_INT("k", "Number of nearest neighbors to find.", "k", 0);
// The user may specify the type of tree to use, and a few parameters for tree
// building.
-PARAM_STRING("tree_type", "Type of tree to use: 'kd', 'rp-tree-max', "
- "'rp-tree-mean', 'cover', 'r', 'r-star', 'x', 'ball', 'hilbert-r', "
+PARAM_STRING("tree_type", "Type of tree to use: 'kd', 'rp-tree', "
+ "'max-split-rp-tree', 'cover', 'r', 'r-star', 'x', 'ball', 'hilbert-r', "
"'r-plus', 'r-plus-plus'.", "t", "kd");
PARAM_INT("leaf_size", "Leaf size for tree building (used for kd-trees, "
"random projection trees, R trees, R* trees, X trees, Hilbert R trees, "
@@ -180,13 +180,13 @@ int main(int argc, char *argv[])
tree = KNNModel::R_PLUS_TREE;
else if (treeType == "r-plus-plus")
tree = KNNModel::R_PLUS_PLUS_TREE;
- else if (treeType == "rp-tree-max")
- tree = KNNModel::RP_TREE_MAX;
- else if (treeType == "rp-tree-mean")
- tree = KNNModel::RP_TREE_MEAN;
+ else if (treeType == "rp-tree")
+ tree = KNNModel::RP_TREE;
+ else if (treeType == "max-split-rp-tree")
+ tree = KNNModel::MAX_SPLIT_RP_TREE;
else
Log::Fatal << "Unknown tree type '" << treeType << "'; valid choices are "
- << "'kd', 'rp-tree-max', 'rp-tree-mean', 'cover', 'r', 'r-star', "
+ << "'kd', 'rp-tree', 'max-split-rp-tree', 'cover', 'r', 'r-star', "
<< "'x', 'ball', 'hilbert-r', 'r-plus' and 'r-plus-plus'." << endl;
knn.TreeType() = tree;
diff --git a/src/mlpack/methods/neighbor_search/ns_model.hpp b/src/mlpack/methods/neighbor_search/ns_model.hpp
index f43359f..f2e2fbd 100644
--- a/src/mlpack/methods/neighbor_search/ns_model.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model.hpp
@@ -257,8 +257,8 @@ class NSModel
HILBERT_R_TREE,
R_PLUS_TREE,
R_PLUS_PLUS_TREE,
- RP_TREE_MAX,
- RP_TREE_MEAN
+ RP_TREE,
+ MAX_SPLIT_RP_TREE
};
private:
@@ -287,8 +287,8 @@ class NSModel
NSType<SortPolicy, tree::HilbertRTree>*,
NSType<SortPolicy, tree::RPlusTree>*,
NSType<SortPolicy, tree::RPlusPlusTree>*,
- NSType<SortPolicy, tree::RPTreeMax>*,
- NSType<SortPolicy, tree::RPTreeMean>*> nSearch;
+ NSType<SortPolicy, tree::RPTree>*,
+ NSType<SortPolicy, tree::MaxSplitRPTree>*> nSearch;
public:
/**
diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
index 50e0cf0..d6ff6cc 100644
--- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
+++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp
@@ -394,12 +394,12 @@ void NSModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
nSearch = new NSType<SortPolicy, tree::RPlusPlusTree>(naive, singleMode,
epsilon);
break;
- case RP_TREE_MAX:
- nSearch = new NSType<SortPolicy, tree::RPTreeMax>(naive, singleMode,
+ case RP_TREE:
+ nSearch = new NSType<SortPolicy, tree::RPTree>(naive, singleMode,
epsilon);
break;
- case RP_TREE_MEAN:
- nSearch = new NSType<SortPolicy, tree::RPTreeMean>(naive, singleMode,
+ case MAX_SPLIT_RP_TREE:
+ nSearch = new NSType<SortPolicy, tree::MaxSplitRPTree>(naive, singleMode,
epsilon);
break;
}
@@ -486,10 +486,10 @@ std::string NSModel<SortPolicy>::TreeName() const
return "R+ tree";
case R_PLUS_PLUS_TREE:
return "R++ tree";
- case RP_TREE_MAX:
- return "Random projection tree";
- case RP_TREE_MEAN:
- return "Random projection tree";
+ case RP_TREE:
+ return "Random projection tree (mean split)";
+ case MAX_SPLIT_RP_TREE:
+ return "Random projection tree (max split)";
default:
return "unknown tree";
}
diff --git a/src/mlpack/methods/range_search/range_search_main.cpp b/src/mlpack/methods/range_search/range_search_main.cpp
index cb151da..b837b7c 100644
--- a/src/mlpack/methods/range_search/range_search_main.cpp
+++ b/src/mlpack/methods/range_search/range_search_main.cpp
@@ -69,8 +69,8 @@ PARAM_DOUBLE("min", "Lower bound in range.", "L", 0.0);
// The user may specify the type of tree to use, and a few parameters for tree
// building.
-PARAM_STRING("tree_type", "Type of tree to use: 'kd', 'rp-tree-max', "
- "'rp-tree-mean', 'cover', 'r', 'r-star', 'x', 'ball', 'hilbert-r', "
+PARAM_STRING("tree_type", "Type of tree to use: 'kd', 'rp-tree', "
+ "'max-split-rp-tree', 'cover', 'r', 'r-star', 'x', 'ball', 'hilbert-r', "
"'r-plus', 'r-plus-plus'.", "t", "kd");
PARAM_INT("leaf_size", "Leaf size for tree building (used for kd-trees, "
"random projection trees, R trees, R* trees, X trees, Hilbert R trees, "
@@ -183,10 +183,10 @@ int main(int argc, char *argv[])
tree = RSModel::R_PLUS_TREE;
else if (treeType == "r-plus-plus")
tree = RSModel::R_PLUS_PLUS_TREE;
- else if (treeType == "rp-tree-max")
- tree = RSModel::RP_TREE_MAX;
- else if (treeType == "rp-tree-mean")
- tree = RSModel::RP_TREE_MEAN;
+ else if (treeType == "rp-tree")
+ tree = RSModel::RP_TREE;
+ else if (treeType == "max-split-rp-tree")
+ tree = RSModel::MAX_SPLIT_RP_TREE;
else
Log::Fatal << "Unknown tree type '" << treeType << "; valid choices are "
<< "'kd', 'rp-tree-max', 'rp-tree-mean', 'cover', 'r', 'r-star', "
diff --git a/src/mlpack/methods/range_search/rs_model.cpp b/src/mlpack/methods/range_search/rs_model.cpp
index 906cf17..e3d5a8e 100644
--- a/src/mlpack/methods/range_search/rs_model.cpp
+++ b/src/mlpack/methods/range_search/rs_model.cpp
@@ -26,8 +26,8 @@ RSModel::RSModel(TreeTypes treeType, bool randomBasis) :
hilbertRTreeRS(NULL),
rPlusTreeRS(NULL),
rPlusPlusTreeRS(NULL),
- rpTreeMaxRS(NULL),
- rpTreeMeanRS(NULL)
+ rpTreeRS(NULL),
+ maxSplitPRTreeRS(NULL)
{
// Nothing to do.
}
@@ -139,18 +139,18 @@ void RSModel::BuildModel(arma::mat&& referenceSet,
break;
case R_PLUS_PLUS_TREE:
- rPlusPlusTreeRS = new RSType<tree::RPlusPlusTree>(move(referenceSet), naive,
- singleMode);
+ rPlusPlusTreeRS = new RSType<tree::RPlusPlusTree>(move(referenceSet),
+ naive, singleMode);
break;
- case RP_TREE_MAX:
- rpTreeMaxRS = new RSType<tree::RPTreeMax>(move(referenceSet), naive,
+ case RP_TREE:
+ rpTreeRS = new RSType<tree::RPTree>(move(referenceSet), naive,
singleMode);
break;
- case RP_TREE_MEAN:
- rpTreeMeanRS = new RSType<tree::RPTreeMean>(move(referenceSet), naive,
- singleMode);
+ case MAX_SPLIT_RP_TREE:
+ maxSplitPRTreeRS = new RSType<tree::MaxSplitRPTree>(move(referenceSet),
+ naive, singleMode);
break;
}
@@ -274,12 +274,12 @@ void RSModel::Search(arma::mat&& querySet,
rPlusPlusTreeRS->Search(querySet, range, neighbors, distances);
break;
- case RP_TREE_MAX:
- rpTreeMaxRS->Search(querySet, range, neighbors, distances);
+ case RP_TREE:
+ rpTreeRS->Search(querySet, range, neighbors, distances);
break;
- case RP_TREE_MEAN:
- rpTreeMeanRS->Search(querySet, range, neighbors, distances);
+ case MAX_SPLIT_RP_TREE:
+ maxSplitPRTreeRS->Search(querySet, range, neighbors, distances);
break;
}
}
@@ -336,12 +336,12 @@ void RSModel::Search(const math::Range& range,
rPlusPlusTreeRS->Search(range, neighbors, distances);
break;
- case RP_TREE_MAX:
- rpTreeMaxRS->Search(range, neighbors, distances);
+ case RP_TREE:
+ rpTreeRS->Search(range, neighbors, distances);
break;
- case RP_TREE_MEAN:
- rpTreeMeanRS->Search(range, neighbors, distances);
+ case MAX_SPLIT_RP_TREE:
+ maxSplitPRTreeRS->Search(range, neighbors, distances);
break;
}
}
@@ -369,10 +369,10 @@ std::string RSModel::TreeName() const
return "R+ tree";
case R_PLUS_PLUS_TREE:
return "R++ tree";
- case RP_TREE_MAX:
- return "Random projection tree (max)";
- case RP_TREE_MEAN:
- return "Random projection tree (mean)";
+ case RP_TREE:
+ return "Random projection tree (mean split)";
+ case MAX_SPLIT_RP_TREE:
+ return "Random projection tree (max split)";
default:
return "unknown tree";
}
@@ -399,10 +399,10 @@ void RSModel::CleanMemory()
delete rPlusTreeRS;
if (rPlusPlusTreeRS)
delete rPlusPlusTreeRS;
- if (rpTreeMaxRS)
- delete rpTreeMaxRS;
- if (rpTreeMeanRS)
- delete rpTreeMeanRS;
+ if (rpTreeRS)
+ delete rpTreeRS;
+ if (maxSplitPRTreeRS)
+ delete maxSplitPRTreeRS;
kdTreeRS = NULL;
coverTreeRS = NULL;
@@ -413,6 +413,6 @@ void RSModel::CleanMemory()
hilbertRTreeRS = NULL;
rPlusTreeRS = NULL;
rPlusPlusTreeRS = NULL;
- rpTreeMaxRS = NULL;
- rpTreeMeanRS = NULL;
+ rpTreeRS = NULL;
+ maxSplitPRTreeRS = NULL;
}
diff --git a/src/mlpack/methods/range_search/rs_model.hpp b/src/mlpack/methods/range_search/rs_model.hpp
index 3ab9802..e5c5335 100644
--- a/src/mlpack/methods/range_search/rs_model.hpp
+++ b/src/mlpack/methods/range_search/rs_model.hpp
@@ -33,8 +33,8 @@ class RSModel
HILBERT_R_TREE,
R_PLUS_TREE,
R_PLUS_PLUS_TREE,
- RP_TREE_MAX,
- RP_TREE_MEAN
+ RP_TREE,
+ MAX_SPLIT_RP_TREE
};
private:
@@ -71,12 +71,12 @@ class RSModel
RSType<tree::RPlusTree>* rPlusTreeRS;
//! R++ tree based range search object (NULL if not in use).
RSType<tree::RPlusPlusTree>* rPlusPlusTreeRS;
- //! Random projection tree (max) based range search object
- //! (NULL if not in use).
- RSType<tree::RPTreeMax>* rpTreeMaxRS;
//! Random projection tree (mean) based range search object
//! (NULL if not in use).
- RSType<tree::RPTreeMean>* rpTreeMeanRS;
+ RSType<tree::RPTree>* rpTreeRS;
+ //! Random projection tree (max) based range search object
+ //! (NULL if not in use).
+ RSType<tree::MaxSplitRPTree>* maxSplitPRTreeRS;
public:
/**
diff --git a/src/mlpack/methods/range_search/rs_model_impl.hpp b/src/mlpack/methods/range_search/rs_model_impl.hpp
index c074279..5f67d17 100644
--- a/src/mlpack/methods/range_search/rs_model_impl.hpp
+++ b/src/mlpack/methods/range_search/rs_model_impl.hpp
@@ -66,12 +66,12 @@ void RSModel::Serialize(Archive& ar, const unsigned int /* version */)
ar & CreateNVP(rPlusPlusTreeRS, "range_search_model");
break;
- case RP_TREE_MAX:
- ar & CreateNVP(rpTreeMaxRS, "range_search_model");
+ case RP_TREE:
+ ar & CreateNVP(rpTreeRS, "range_search_model");
break;
- case RP_TREE_MEAN:
- ar & CreateNVP(rpTreeMeanRS, "range_search_model");
+ case MAX_SPLIT_RP_TREE:
+ ar & CreateNVP(maxSplitPRTreeRS, "range_search_model");
break;
}
}
@@ -96,10 +96,10 @@ inline const arma::mat& RSModel::Dataset() const
return rPlusTreeRS->ReferenceSet();
else if (rPlusPlusTreeRS)
return rPlusPlusTreeRS->ReferenceSet();
- else if (rpTreeMaxRS)
- return rpTreeMaxRS->ReferenceSet();
- else if (rpTreeMeanRS)
- return rpTreeMeanRS->ReferenceSet();
+ else if (rpTreeRS)
+ return rpTreeRS->ReferenceSet();
+ else if (maxSplitPRTreeRS)
+ return maxSplitPRTreeRS->ReferenceSet();
throw std::runtime_error("no range search model initialized");
}
@@ -124,10 +124,10 @@ inline bool RSModel::SingleMode() const
return rPlusTreeRS->SingleMode();
else if (rPlusPlusTreeRS)
return rPlusPlusTreeRS->SingleMode();
- else if (rpTreeMaxRS)
- return rpTreeMaxRS->SingleMode();
- else if (rpTreeMeanRS)
- return rpTreeMeanRS->SingleMode();
+ else if (rpTreeRS)
+ return rpTreeRS->SingleMode();
+ else if (maxSplitPRTreeRS)
+ return maxSplitPRTreeRS->SingleMode();
throw std::runtime_error("no range search model initialized");
}
@@ -152,10 +152,10 @@ inline bool& RSModel::SingleMode()
return rPlusTreeRS->SingleMode();
else if (rPlusPlusTreeRS)
return rPlusPlusTreeRS->SingleMode();
- else if (rpTreeMaxRS)
- return rpTreeMaxRS->SingleMode();
- else if (rpTreeMeanRS)
- return rpTreeMeanRS->SingleMode();
+ else if (rpTreeRS)
+ return rpTreeRS->SingleMode();
+ else if (maxSplitPRTreeRS)
+ return maxSplitPRTreeRS->SingleMode();
throw std::runtime_error("no range search model initialized");
}
@@ -180,10 +180,10 @@ inline bool RSModel::Naive() const
return rPlusTreeRS->Naive();
else if (rPlusPlusTreeRS)
return rPlusPlusTreeRS->Naive();
- else if (rpTreeMaxRS)
- return rpTreeMaxRS->Naive();
- else if (rpTreeMeanRS)
- return rpTreeMeanRS->Naive();
+ else if (rpTreeRS)
+ return rpTreeRS->Naive();
+ else if (maxSplitPRTreeRS)
+ return maxSplitPRTreeRS->Naive();
throw std::runtime_error("no range search model initialized");
}
@@ -208,10 +208,10 @@ inline bool& RSModel::Naive()
return rPlusTreeRS->Naive();
else if (rPlusPlusTreeRS)
return rPlusPlusTreeRS->Naive();
- else if (rpTreeMaxRS)
- return rpTreeMaxRS->Naive();
- else if (rpTreeMeanRS)
- return rpTreeMeanRS->Naive();
+ else if (rpTreeRS)
+ return rpTreeRS->Naive();
+ else if (maxSplitPRTreeRS)
+ return maxSplitPRTreeRS->Naive();
throw std::runtime_error("no range search model initialized");
}
diff --git a/src/mlpack/tests/aknn_test.cpp b/src/mlpack/tests/aknn_test.cpp
index 47bda1b..8da2476 100644
--- a/src/mlpack/tests/aknn_test.cpp
+++ b/src/mlpack/tests/aknn_test.cpp
@@ -306,10 +306,10 @@ BOOST_AUTO_TEST_CASE(KNNModelTest)
models[15] = KNNModel(KNNModel::TreeTypes::R_PLUS_TREE, false);
models[16] = KNNModel(KNNModel::TreeTypes::R_PLUS_PLUS_TREE, true);
models[17] = KNNModel(KNNModel::TreeTypes::R_PLUS_PLUS_TREE, false);
- models[18] = KNNModel(KNNModel::TreeTypes::RP_TREE_MAX, true);
- models[19] = KNNModel(KNNModel::TreeTypes::RP_TREE_MAX, false);
- models[20] = KNNModel(KNNModel::TreeTypes::RP_TREE_MEAN, true);
- models[21] = KNNModel(KNNModel::TreeTypes::RP_TREE_MEAN, false);
+ models[18] = KNNModel(KNNModel::TreeTypes::RP_TREE, true);
+ models[19] = KNNModel(KNNModel::TreeTypes::RP_TREE, false);
+ models[20] = KNNModel(KNNModel::TreeTypes::MAX_SPLIT_RP_TREE, true);
+ models[21] = KNNModel(KNNModel::TreeTypes::MAX_SPLIT_RP_TREE, false);
for (size_t j = 0; j < 3; ++j)
{
@@ -379,10 +379,10 @@ BOOST_AUTO_TEST_CASE(KNNModelMonochromaticTest)
models[15] = KNNModel(KNNModel::TreeTypes::R_PLUS_TREE, false);
models[16] = KNNModel(KNNModel::TreeTypes::R_PLUS_PLUS_TREE, true);
models[17] = KNNModel(KNNModel::TreeTypes::R_PLUS_PLUS_TREE, false);
- models[18] = KNNModel(KNNModel::TreeTypes::RP_TREE_MAX, true);
- models[19] = KNNModel(KNNModel::TreeTypes::RP_TREE_MAX, false);
- models[20] = KNNModel(KNNModel::TreeTypes::RP_TREE_MEAN, true);
- models[21] = KNNModel(KNNModel::TreeTypes::RP_TREE_MEAN, false);
+ models[18] = KNNModel(KNNModel::TreeTypes::RP_TREE, true);
+ models[19] = KNNModel(KNNModel::TreeTypes::RP_TREE, false);
+ models[20] = KNNModel(KNNModel::TreeTypes::MAX_SPLIT_RP_TREE, true);
+ models[21] = KNNModel(KNNModel::TreeTypes::MAX_SPLIT_RP_TREE, false);
for (size_t j = 0; j < 2; ++j)
{
diff --git a/src/mlpack/tests/knn_test.cpp b/src/mlpack/tests/knn_test.cpp
index d9303a1..fe91ab3 100644
--- a/src/mlpack/tests/knn_test.cpp
+++ b/src/mlpack/tests/knn_test.cpp
@@ -996,10 +996,10 @@ BOOST_AUTO_TEST_CASE(KNNModelTest)
models[15] = KNNModel(KNNModel::TreeTypes::R_PLUS_TREE, false);
models[16] = KNNModel(KNNModel::TreeTypes::R_PLUS_PLUS_TREE, true);
models[17] = KNNModel(KNNModel::TreeTypes::R_PLUS_PLUS_TREE, false);
- models[18] = KNNModel(KNNModel::TreeTypes::RP_TREE_MAX, true);
- models[19] = KNNModel(KNNModel::TreeTypes::RP_TREE_MAX, false);
- models[20] = KNNModel(KNNModel::TreeTypes::RP_TREE_MEAN, true);
- models[21] = KNNModel(KNNModel::TreeTypes::RP_TREE_MEAN, false);
+ models[18] = KNNModel(KNNModel::TreeTypes::RP_TREE, true);
+ models[19] = KNNModel(KNNModel::TreeTypes::RP_TREE, false);
+ models[20] = KNNModel(KNNModel::TreeTypes::MAX_SPLIT_RP_TREE, true);
+ models[21] = KNNModel(KNNModel::TreeTypes::MAX_SPLIT_RP_TREE, false);
for (size_t j = 0; j < 2; ++j)
{
@@ -1072,10 +1072,10 @@ BOOST_AUTO_TEST_CASE(KNNModelMonochromaticTest)
models[15] = KNNModel(KNNModel::TreeTypes::R_PLUS_TREE, false);
models[16] = KNNModel(KNNModel::TreeTypes::R_PLUS_PLUS_TREE, true);
models[17] = KNNModel(KNNModel::TreeTypes::R_PLUS_PLUS_TREE, false);
- models[18] = KNNModel(KNNModel::TreeTypes::RP_TREE_MAX, true);
- models[19] = KNNModel(KNNModel::TreeTypes::RP_TREE_MAX, false);
- models[20] = KNNModel(KNNModel::TreeTypes::RP_TREE_MEAN, true);
- models[21] = KNNModel(KNNModel::TreeTypes::RP_TREE_MEAN, false);
+ models[18] = KNNModel(KNNModel::TreeTypes::RP_TREE, true);
+ models[19] = KNNModel(KNNModel::TreeTypes::RP_TREE, false);
+ models[20] = KNNModel(KNNModel::TreeTypes::MAX_SPLIT_RP_TREE, true);
+ models[21] = KNNModel(KNNModel::TreeTypes::MAX_SPLIT_RP_TREE, false);
for (size_t j = 0; j < 2; ++j)
{
diff --git a/src/mlpack/tests/range_search_test.cpp b/src/mlpack/tests/range_search_test.cpp
index ec03278..217288c 100644
--- a/src/mlpack/tests/range_search_test.cpp
+++ b/src/mlpack/tests/range_search_test.cpp
@@ -1268,10 +1268,10 @@ BOOST_AUTO_TEST_CASE(RSModelTest)
models[15] = RSModel(RSModel::TreeTypes::R_PLUS_TREE, false);
models[16] = RSModel(RSModel::TreeTypes::R_PLUS_PLUS_TREE, true);
models[17] = RSModel(RSModel::TreeTypes::R_PLUS_PLUS_TREE, false);
- models[18] = RSModel(RSModel::TreeTypes::RP_TREE_MAX, true);
- models[19] = RSModel(RSModel::TreeTypes::RP_TREE_MAX, false);
- models[20] = RSModel(RSModel::TreeTypes::RP_TREE_MEAN, true);
- models[21] = RSModel(RSModel::TreeTypes::RP_TREE_MEAN, false);
+ models[18] = RSModel(RSModel::TreeTypes::RP_TREE, true);
+ models[19] = RSModel(RSModel::TreeTypes::RP_TREE, false);
+ models[20] = RSModel(RSModel::TreeTypes::MAX_SPLIT_RP_TREE, true);
+ models[21] = RSModel(RSModel::TreeTypes::MAX_SPLIT_RP_TREE, false);
for (size_t j = 0; j < 2; ++j)
{
@@ -1348,10 +1348,10 @@ BOOST_AUTO_TEST_CASE(RSModelMonochromaticTest)
models[15] = RSModel(RSModel::TreeTypes::R_PLUS_TREE, false);
models[16] = RSModel(RSModel::TreeTypes::R_PLUS_PLUS_TREE, true);
models[17] = RSModel(RSModel::TreeTypes::R_PLUS_PLUS_TREE, false);
- models[18] = RSModel(RSModel::TreeTypes::RP_TREE_MAX, true);
- models[19] = RSModel(RSModel::TreeTypes::RP_TREE_MAX, false);
- models[20] = RSModel(RSModel::TreeTypes::RP_TREE_MEAN, true);
- models[21] = RSModel(RSModel::TreeTypes::RP_TREE_MEAN, false);
+ models[18] = RSModel(RSModel::TreeTypes::RP_TREE, true);
+ models[19] = RSModel(RSModel::TreeTypes::RP_TREE, false);
+ models[20] = RSModel(RSModel::TreeTypes::MAX_SPLIT_RP_TREE, true);
+ models[21] = RSModel(RSModel::TreeTypes::MAX_SPLIT_RP_TREE, false);
for (size_t j = 0; j < 2; ++j)
{
diff --git a/src/mlpack/tests/tree_test.cpp b/src/mlpack/tests/tree_test.cpp
index 48de0fc..9670048 100644
--- a/src/mlpack/tests/tree_test.cpp
+++ b/src/mlpack/tests/tree_test.cpp
@@ -1354,9 +1354,9 @@ BOOST_AUTO_TEST_CASE(KdTreeTest)
TreeType root(dataset);
}
-BOOST_AUTO_TEST_CASE(RPTreeMaxTest)
+BOOST_AUTO_TEST_CASE(MaxSplitRPTreeTest)
{
- typedef RPTreeMax<EuclideanDistance, EmptyStatistic, arma::mat> TreeType;
+ typedef MaxSplitRPTree<EuclideanDistance, EmptyStatistic, arma::mat> TreeType;
size_t maxRuns = 10; // Ten total tests.
size_t pointIncrements = 1000; // Range is from 2000 points to 11000.
@@ -1396,9 +1396,109 @@ BOOST_AUTO_TEST_CASE(RPTreeMaxTest)
}
}
-BOOST_AUTO_TEST_CASE(RPTreeMeanTest)
+/**
+ * Try to find a hyperplane that separates the descendants of the left child
+ * of the given node from the descendants of its right child.
+ *
+ * The search is an iterative correction scheme limited to a fixed number of
+ * passes, so a false result only means that no separating hyperplane was found
+ * within that limit.
+ *
+ * @param tree Node to examine.  Both Left() and Right() are dereferenced, so
+ *     the node must have children.
+ * @return true if a separating hyperplane was found, false otherwise.
+ */
+template<typename TreeType>
+bool CheckHyperplaneSplit(const TreeType& tree)
{
- typedef RPTreeMean<EuclideanDistance, EmptyStatistic, arma::mat> TreeType;
+  typedef typename TreeType::ElemType ElemType;
+
+  const typename TreeType::Mat& dataset = tree.Dataset();
+  const size_t numLeft = tree.Left()->NumDescendants();
+  const size_t numRight = tree.Right()->NumDescendants();
+  arma::Mat<ElemType> mat(dataset.n_rows + 1, numLeft + numRight);
+
+  // We will try to find a hyperplane that splits the node.
+  // The hyperplane may be represented as
+  // a_1 * x_1 + ... + a_n * x_n + a_{n + 1} = 0.
+  // We have to solve the system of inequalities (mat^t) * x <= 0,
+  // where x[0], ... , x[dataset.n_rows - 1] are the components of the normal
+  // to the hyperplane and x[dataset.n_rows] is the position of the hyperplane
+  // i.e. x = (a_1, ... , a_{n + 1}).
+  // Each column of the matrix consists of a point and 1; the columns that
+  // belong to the left child are negated.
+  // In such a way, the inner product of a column and x is equal to the value
+  // of the hyperplane expression (with the sign flipped for the left child).
+  // The hyperplane splits the node if the expression takes on values of
+  // opposite sign on the node's children.
+
+  for (size_t i = 0; i < numLeft; i++)
+  {
+    for (size_t k = 0; k < dataset.n_rows; k++)
+      mat(k, i) = -dataset(k, tree.Left()->Descendant(i));
+
+    mat(dataset.n_rows, i) = -1;
+  }
+
+  for (size_t i = 0; i < numRight; i++)
+  {
+    for (size_t k = 0; k < dataset.n_rows; k++)
+      mat(k, i + numLeft) = dataset(k, tree.Right()->Descendant(i));
+
+    mat(dataset.n_rows, i + numLeft) = 1;
+  }
+
+  arma::Col<ElemType> x(dataset.n_rows + 1);
+  x.zeros();
+  // Define an initial value: the normal points along the first axis and the
+  // hyperplane passes through the mean of the first coordinate of the points
+  // of the node.  The position is the last component of x; storing it in x[1]
+  // would overwrite a component of the normal whenever there is more than one
+  // dimension.
+  x[0] = 1.0;
+  x[dataset.n_rows] = -arma::mean(
+      dataset.cols(tree.Begin(), tree.Begin() + tree.Count() - 1).row(0));
+
+  const size_t numIters = 1000000;
+  const ElemType delta = 1e-4;
+
+  // We will solve the system using a simple gradient method: every violated
+  // inequality moves x against the corresponding column by a small step.
+  bool success = false;
+  for (size_t it = 0; it < numIters; it++)
+  {
+    success = true;
+    // Iterate over the columns that were actually filled above.
+    for (size_t k = 0; k < mat.n_cols; k++)
+    {
+      ElemType result = arma::dot(mat.col(k), x);
+      if (result > 0)
+      {
+        x -= mat.col(k) * delta;
+        success = false;
+      }
+    }
+
+    // The norm of the direction shouldn't be equal to zero.
+    if (arma::norm(x.rows(0, dataset.n_rows - 1)) < 1e-8)
+    {
+      x[math::RandInt(0, dataset.n_rows)] = 1.0;
+      success = false;
+    }
+
+    // A full pass without any correction means that all inequalities hold.
+    if (success)
+      break;
+  }
+
+  return success;
+}
+
+/**
+ * Recursively require that CheckHyperplaneSplit() succeeds for the given node
+ * and for every non-leaf node below it.
+ *
+ * @param tree Root of the subtree to verify.
+ */
+template<typename TreeType>
+void CheckMaxRPTreeSplit(const TreeType& tree)
+{
+  // Leaves have no split to verify; only nodes with children are examined.
+  if (!tree.IsLeaf())
+  {
+    BOOST_REQUIRE_EQUAL(CheckHyperplaneSplit(tree), true);
+
+    CheckMaxRPTreeSplit(*tree.Left());
+    CheckMaxRPTreeSplit(*tree.Right());
+  }
+}
+
+/**
+ * Build a MaxSplitRPTree on uniformly distributed random data and run
+ * CheckMaxRPTreeSplit() on its root.
+ */
+BOOST_AUTO_TEST_CASE(MaxSplitRPTreeSplitTest)
+{
+  typedef MaxSplitRPTree<EuclideanDistance, EmptyStatistic, arma::mat> TreeType;
+
+  // The data matrix has one column per point.
+  const size_t dimensions = 8;
+  const size_t numPoints = 1000;
+
+  arma::mat points;
+  points.randu(dimensions, numPoints);
+
+  TreeType tree(points);
+  CheckMaxRPTreeSplit(tree);
+}
+
+BOOST_AUTO_TEST_CASE(RPTreeTest)
+{
+ typedef RPTree<EuclideanDistance, EmptyStatistic, arma::mat> TreeType;
size_t maxRuns = 10; // Ten total tests.
size_t pointIncrements = 1000; // Range is from 2000 points to 11000.
@@ -1438,6 +1538,53 @@ BOOST_AUTO_TEST_CASE(RPTreeMeanTest)
}
}
+/**
+ * Recursively verify the split of every non-leaf node of a random projection
+ * tree.  When CheckHyperplaneSplit() does not succeed for a node, the split is
+ * instead required to look like a mean split: no descendant of the right child
+ * may be closer to the center of the left child's bound than the farthest
+ * descendant of the left child, up to a small relative tolerance.
+ *
+ * @tparam TreeType Type of the tree.
+ * @tparam MetricType Metric whose static Evaluate() measures the distances.
+ * @param tree Root of the subtree to verify.
+ */
+template<typename TreeType, typename MetricType>
+void CheckRPTreeSplit(const TreeType& tree)
+{
+  typedef typename TreeType::ElemType ElemType;
+
+  if (tree.IsLeaf())
+    return;
+
+  const bool hyperplaneSplit = CheckHyperplaneSplit(tree);
+  if (!hyperplaneSplit)
+  {
+    // Check if that was mean split.
+    arma::Col<ElemType> center;
+    tree.Left()->Bound().Center(center);
+
+    // Find the largest distance from the center to a descendant of the left
+    // child.
+    ElemType maxDist = 0;
+    const size_t numLeft = tree.Left()->NumDescendants();
+    for (size_t i = 0; i < numLeft; ++i)
+    {
+      const ElemType dist = MetricType::Evaluate(center,
+          tree.Dataset().col(tree.Left()->Descendant(i)));
+
+      if (maxDist < dist)
+        maxDist = dist;
+    }
+
+    // Every descendant of the right child has to be at least that far away.
+    const size_t numRight = tree.Right()->NumDescendants();
+    for (size_t i = 0; i < numRight; ++i)
+    {
+      const ElemType dist = MetricType::Evaluate(center,
+          tree.Dataset().col(tree.Right()->Descendant(i)));
+
+      BOOST_REQUIRE_LE(maxDist, dist *
+          (1.0 + 10.0 * std::numeric_limits<ElemType>::epsilon()));
+    }
+  }
+
+  CheckRPTreeSplit<TreeType, MetricType>(*tree.Left());
+  CheckRPTreeSplit<TreeType, MetricType>(*tree.Right());
+}
+
+/**
+ * Build an RPTree on uniformly distributed random data and run
+ * CheckRPTreeSplit() with the Euclidean metric on its root.
+ */
+BOOST_AUTO_TEST_CASE(RPTreeSplitTest)
+{
+  typedef RPTree<EuclideanDistance, EmptyStatistic, arma::mat> TreeType;
+
+  // The data matrix has one column per point.
+  const size_t dimensions = 8;
+  const size_t numPoints = 1000;
+
+  arma::mat points;
+  points.randu(dimensions, numPoints);
+
+  TreeType tree(points);
+  CheckRPTreeSplit<TreeType, EuclideanDistance>(tree);
+}
+
// Recursively checks that each node contains all points that it claims to have.
template<typename TreeType>
bool CheckPointBounds(TreeType& node)
More information about the mlpack-git
mailing list