[mlpack-git] master: Replace static SplitType in the RectangleTree by an instantiated object. (03a34ed)
gitdub at mlpack.org
gitdub at mlpack.org
Fri Apr 22 08:10:45 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/4c8a8d1ccbc33916794fe0f6142fa5378ff8503d...93eb34cc4daf08cc331d1a22fccf44950b276daa
>---------------------------------------------------------------
commit 03a34ed2e59c640f29d40483f02d3175f32abe78
Author: Mikhail Lozhnikov <lozhnikovma at gmail.com>
Date: Fri Apr 22 15:10:45 2016 +0300
Replace the static SplitType in RectangleTree with an instantiated object.
>---------------------------------------------------------------
03a34ed2e59c640f29d40483f02d3175f32abe78
.../tree/rectangle_tree/dual_tree_traverser.hpp | 2 +-
.../rectangle_tree/dual_tree_traverser_impl.hpp | 4 +-
.../core/tree/rectangle_tree/r_star_tree_split.hpp | 17 ++++--
.../tree/rectangle_tree/r_star_tree_split_impl.hpp | 31 +++++++---
.../core/tree/rectangle_tree/r_tree_split.hpp | 27 ++++-----
.../core/tree/rectangle_tree/r_tree_split_impl.hpp | 60 +++++++++++--------
.../core/tree/rectangle_tree/rectangle_tree.hpp | 6 +-
.../tree/rectangle_tree/rectangle_tree_impl.hpp | 68 ++++++++++++----------
.../tree/rectangle_tree/single_tree_traverser.hpp | 2 +-
.../rectangle_tree/single_tree_traverser_impl.hpp | 4 +-
src/mlpack/core/tree/rectangle_tree/traits.hpp | 2 +-
.../core/tree/rectangle_tree/x_tree_split.hpp | 17 ++++--
.../core/tree/rectangle_tree/x_tree_split_impl.hpp | 31 +++++++---
13 files changed, 169 insertions(+), 102 deletions(-)
diff --git a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
index 397a1c2..75a0501 100644
--- a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser.hpp
@@ -19,7 +19,7 @@ namespace tree {
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
template<typename RuleType>
class RectangleTree<MetricType, StatisticType, MatType, SplitType,
diff --git a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
index 42379d2..7f02ad9 100644
--- a/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/dual_tree_traverser_impl.hpp
@@ -20,7 +20,7 @@ namespace tree {
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
template<typename RuleType>
RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
@@ -35,7 +35,7 @@ DualTreeTraverser<RuleType>::DualTreeTraverser(RuleType& rule) :
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
template<typename RuleType>
void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
diff --git a/src/mlpack/core/tree/rectangle_tree/r_star_tree_split.hpp b/src/mlpack/core/tree/rectangle_tree/r_star_tree_split.hpp
index 54d1bba..3970f3a 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_star_tree_split.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_star_tree_split.hpp
@@ -18,25 +18,33 @@ namespace tree /** Trees and tree-building procedures. */ {
* nodes overflow, we split them, moving up the tree and splitting nodes
* as necessary.
*/
+template <typename TreeType>
class RStarTreeSplit
{
public:
+ //! Default constructor
+ RStarTreeSplit();
+
+ //! Construct this with specified node.
+ RStarTreeSplit(TreeType *node);
+
/**
* Split a leaf node using the algorithm described in "The R*-tree: An
* Efficient and Robust Access method for Points and Rectangles." If
* necessary, this split will propagate upwards through the tree.
*/
- template<typename TreeType>
- static void SplitLeafNode(TreeType* tree, std::vector<bool>& relevels);
+ void SplitLeafNode(std::vector<bool>& relevels);
/**
* Split a non-leaf node using the "default" algorithm. If this is a root
* node, the tree increases in depth.
*/
- template<typename TreeType>
- static bool SplitNonLeafNode(TreeType* tree, std::vector<bool>& relevels);
+ bool SplitNonLeafNode(std::vector<bool>& relevels);
private:
+ //! The node which has to be split.
+ TreeType *tree;
+
/**
* Class to allow for faster sorting.
*/
@@ -60,7 +68,6 @@ class RStarTreeSplit
/**
* Insert a node into another node.
*/
- template<typename TreeType>
static void InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode);
};
diff --git a/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp
index 9c0e95c..a2af8aa 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_star_tree_split_impl.hpp
@@ -15,6 +15,20 @@
namespace mlpack {
namespace tree {
+template<typename TreeType>
+RStarTreeSplit<TreeType>::RStarTreeSplit() :
+ tree(NULL)
+{
+
+}
+
+template<typename TreeType>
+RStarTreeSplit<TreeType>::RStarTreeSplit(TreeType *node) :
+ tree(node)
+{
+
+}
+
/**
* We call GetPointSeeds to get the two points which will be the initial points
* in the new nodes We then call AssignPointDestNode to assign the remaining
@@ -22,7 +36,7 @@ namespace tree {
* new nodes into the tree, spliting the parent if necessary.
*/
template<typename TreeType>
-void RStarTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
+void RStarTreeSplit<TreeType>::SplitLeafNode(std::vector<bool>& relevels)
{
// Convenience typedef.
typedef typename TreeType::ElemType ElemType;
@@ -41,7 +55,7 @@ void RStarTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
tree->Children()[(tree->NumChildren())++] = copy;
assert(tree->NumChildren() == 1);
- SplitLeafNode(copy, relevels);
+ copy->SplitNode(relevels);
return;
}
@@ -58,7 +72,7 @@ void RStarTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
size_t p = tree->MaxLeafSize() * 0.3; // The paper says this works the best.
if (p == 0)
{
- SplitLeafNode(tree, relevels);
+ tree->SplitNode(relevels);
return;
}
@@ -251,7 +265,7 @@ void RStarTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
// just in case, we use an assert.
assert(par->NumChildren() <= par->MaxNumChildren() + 1);
if (par->NumChildren() == par->MaxNumChildren() + 1)
- SplitNonLeafNode(par, relevels);
+ par->SplitNode(relevels);
assert(treeOne->Parent()->NumChildren() <= treeOne->MaxNumChildren());
assert(treeOne->Parent()->NumChildren() >= treeOne->MinNumChildren());
@@ -269,8 +283,7 @@ void RStarTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
* higher up the tree because they were already updated if necessary.
*/
template<typename TreeType>
-bool RStarTreeSplit::SplitNonLeafNode(TreeType* tree,
- std::vector<bool>& relevels)
+bool RStarTreeSplit<TreeType>::SplitNonLeafNode(std::vector<bool>& relevels)
{
// Convenience typedef.
typedef typename TreeType::ElemType ElemType;
@@ -288,7 +301,7 @@ bool RStarTreeSplit::SplitNonLeafNode(TreeType* tree,
tree->NullifyData();
tree->Children()[(tree->NumChildren())++] = copy;
- SplitNonLeafNode(copy, relevels);
+ copy->SplitNode(relevels);
return true;
}
@@ -644,7 +657,7 @@ bool RStarTreeSplit::SplitNonLeafNode(TreeType* tree,
assert(par->NumChildren() <= par->MaxNumChildren() + 1);
if (par->NumChildren() == par->MaxNumChildren() + 1)
{
- SplitNonLeafNode(par, relevels);
+ par->SplitNode(relevels);
}
// We have to update the children of each of these new nodes so that they
@@ -673,7 +686,7 @@ bool RStarTreeSplit::SplitNonLeafNode(TreeType* tree,
* numberOfChildren.
*/
template<typename TreeType>
-void RStarTreeSplit::InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode)
+void RStarTreeSplit<TreeType>::InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode)
{
destTree->Bound() |= srcNode->Bound();
destTree->Children()[destTree->NumChildren()++] = srcNode;
diff --git a/src/mlpack/core/tree/rectangle_tree/r_tree_split.hpp b/src/mlpack/core/tree/rectangle_tree/r_tree_split.hpp
index 1bf5745..8a412ad 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_tree_split.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_tree_split.hpp
@@ -18,42 +18,45 @@ namespace tree /** Trees and tree-building procedures. */ {
* nodes overflow, we split them, moving up the tree and splitting nodes
* as necessary.
*/
+template<typename TreeType>
class RTreeSplit
{
public:
+ //! Default constructor
+ RTreeSplit();
+
+ //! Construct this with specified node.
+ RTreeSplit(TreeType *node);
+
/**
* Split a leaf node using the "default" algorithm. If necessary, this split
* will propagate upwards through the tree.
*/
- template<typename TreeType>
- static void SplitLeafNode(TreeType* tree,
- std::vector<bool>& relevels);
+ void SplitLeafNode(std::vector<bool>& relevels);
/**
* Split a non-leaf node using the "default" algorithm. If this is a root
* node, the tree increases in depth.
*/
- template<typename TreeType>
- static bool SplitNonLeafNode(TreeType* tree,
- std::vector<bool>& relevels);
+ bool SplitNonLeafNode(std::vector<bool>& relevels);
private:
+ //! The node which has to be split.
+ TreeType *tree;
+
/**
* Get the seeds for splitting a leaf node.
*/
- template<typename TreeType>
- static void GetPointSeeds(const TreeType& tree, int& i, int& j);
+ void GetPointSeeds(int& i, int& j);
/**
* Get the seeds for splitting a non-leaf node.
*/
- template<typename TreeType>
- static void GetBoundSeeds(const TreeType& tree, int& i, int& j);
+ void GetBoundSeeds(int& i, int& j);
/**
* Assign points to the two new nodes.
*/
- template<typename TreeType>
static void AssignPointDestNode(TreeType* oldTree,
TreeType* treeOne,
TreeType* treeTwo,
@@ -63,7 +66,6 @@ class RTreeSplit
/**
* Assign nodes to the two new nodes.
*/
- template<typename TreeType>
static void AssignNodeDestNode(TreeType* oldTree,
TreeType* treeOne,
TreeType* treeTwo,
@@ -73,7 +75,6 @@ class RTreeSplit
/**
* Insert a node into another node.
*/
- template<typename TreeType>
static void InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode);
};
diff --git a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
index 15d70bd..5bc0461 100644
--- a/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/r_tree_split_impl.hpp
@@ -14,6 +14,20 @@
namespace mlpack {
namespace tree {
+template<typename TreeType>
+RTreeSplit<TreeType>::RTreeSplit() :
+ tree(NULL)
+{
+
+}
+
+template<typename TreeType>
+RTreeSplit<TreeType>::RTreeSplit(TreeType *node) :
+ tree(node)
+{
+
+}
+
/**
* We call GetPointSeeds to get the two points which will be the initial points
* in the new nodes We then call AssignPointDestNode to assign the remaining
@@ -21,7 +35,7 @@ namespace tree {
* new nodes into the tree, spliting the parent if necessary.
*/
template<typename TreeType>
-void RTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
+void RTreeSplit<TreeType>::SplitLeafNode(std::vector<bool>& relevels)
{
// If we are splitting the root node, we need will do things differently so
// that the constructor and other methods don't confuse the end user by giving
@@ -35,7 +49,7 @@ void RTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
tree->NullifyData();
// Because this was a leaf node, numChildren must be 0.
tree->Children()[(tree->NumChildren())++] = copy;
- SplitLeafNode(copy, relevels);
+ copy->SplitNode(relevels);
return;
}
@@ -46,7 +60,7 @@ void RTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
// rectangles, only points. We assume that the tree uses Euclidean Distance.
int i = 0;
int j = 0;
- GetPointSeeds(*tree, i, j);
+ GetPointSeeds(i, j);
TreeType* treeOne = new TreeType(tree->Parent());
TreeType* treeTwo = new TreeType(tree->Parent());
@@ -66,7 +80,7 @@ void RTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
// just in case, we use an assert.
assert(par->NumChildren() <= par->MaxNumChildren() + 1);
if (par->NumChildren() == par->MaxNumChildren() + 1)
- SplitNonLeafNode(par, relevels);
+ par->SplitNode(relevels);
assert(treeOne->Parent()->NumChildren() <= treeOne->MaxNumChildren());
assert(treeOne->Parent()->NumChildren() >= treeOne->MinNumChildren());
@@ -85,7 +99,7 @@ void RTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
* higher up the tree because they were already updated if necessary.
*/
template<typename TreeType>
-bool RTreeSplit::SplitNonLeafNode(TreeType* tree, std::vector<bool>& relevels)
+bool RTreeSplit<TreeType>::SplitNonLeafNode(std::vector<bool>& relevels)
{
// If we are splitting the root node, we need will do things differently so
// that the constructor and other methods don't confuse the end user by giving
@@ -98,13 +112,13 @@ bool RTreeSplit::SplitNonLeafNode(TreeType* tree, std::vector<bool>& relevels)
tree->NumChildren() = 0;
tree->NullifyData();
tree->Children()[(tree->NumChildren())++] = copy;
- SplitNonLeafNode(copy, relevels);
+ copy->SplitNode(relevels);
return true;
}
int i = 0;
int j = 0;
- GetBoundSeeds(*tree, i, j);
+ GetBoundSeeds(i, j);
assert(i != j);
@@ -131,7 +145,7 @@ bool RTreeSplit::SplitNonLeafNode(TreeType* tree, std::vector<bool>& relevels)
assert(par->NumChildren() <= par->MaxNumChildren() + 1);
if (par->NumChildren() == par->MaxNumChildren() + 1)
- SplitNonLeafNode(par, relevels);
+ par->SplitNode(relevels);
// We have to update the children of each of these new nodes so that they
// record the correct parent.
@@ -157,18 +171,18 @@ bool RTreeSplit::SplitNonLeafNode(TreeType* tree, std::vector<bool>& relevels)
* The indices of these points will be stored in iRet and jRet.
*/
template<typename TreeType>
-void RTreeSplit::GetPointSeeds(const TreeType& tree, int& iRet, int& jRet)
+void RTreeSplit<TreeType>::GetPointSeeds(int& iRet, int& jRet)
{
// Here we want to find the pair of points that it is worst to place in the
// same node. Because we are just using points, we will simply choose the two
// that would create the most voluminous hyperrectangle.
typename TreeType::ElemType worstPairScore = -1.0;
- for (size_t i = 0; i < tree.Count(); i++)
+ for (size_t i = 0; i < tree->Count(); i++)
{
- for (size_t j = i + 1; j < tree.Count(); j++)
+ for (size_t j = i + 1; j < tree->Count(); j++)
{
const typename TreeType::ElemType score = arma::prod(arma::abs(
- tree.LocalDataset().col(i) - tree.LocalDataset().col(j)));
+ tree->LocalDataset().col(i) - tree->LocalDataset().col(j)));
if (score > worstPairScore)
{
@@ -185,23 +199,23 @@ void RTreeSplit::GetPointSeeds(const TreeType& tree, int& iRet, int& jRet)
* indices of the bounds will be stored in iRet and jRet.
*/
template<typename TreeType>
-void RTreeSplit::GetBoundSeeds(const TreeType& tree, int& iRet, int& jRet)
+void RTreeSplit<TreeType>::GetBoundSeeds(int& iRet, int& jRet)
{
// Convenience typedef.
typedef typename TreeType::ElemType ElemType;
ElemType worstPairScore = -1.0;
- for (size_t i = 0; i < tree.NumChildren(); i++)
+ for (size_t i = 0; i < tree->NumChildren(); i++)
{
- for (size_t j = i + 1; j < tree.NumChildren(); j++)
+ for (size_t j = i + 1; j < tree->NumChildren(); j++)
{
ElemType score = 1.0;
- for (size_t k = 0; k < tree.Bound().Dim(); k++)
+ for (size_t k = 0; k < tree->Bound().Dim(); k++)
{
- const ElemType hiMax = std::max(tree.Children()[i]->Bound()[k].Hi(),
- tree.Children()[j]->Bound()[k].Hi());
- const ElemType loMin = std::min(tree.Children()[i]->Bound()[k].Lo(),
- tree.Children()[j]->Bound()[k].Lo());
+ const ElemType hiMax = std::max(tree->Children()[i]->Bound()[k].Hi(),
+ tree->Children()[j]->Bound()[k].Hi());
+ const ElemType loMin = std::min(tree->Children()[i]->Bound()[k].Lo(),
+ tree->Children()[j]->Bound()[k].Lo());
score *= (hiMax - loMin);
}
@@ -216,7 +230,7 @@ void RTreeSplit::GetBoundSeeds(const TreeType& tree, int& iRet, int& jRet)
}
template<typename TreeType>
-void RTreeSplit::AssignPointDestNode(TreeType* oldTree,
+void RTreeSplit<TreeType>::AssignPointDestNode(TreeType* oldTree,
TreeType* treeOne,
TreeType* treeTwo,
const int intI,
@@ -357,7 +371,7 @@ void RTreeSplit::AssignPointDestNode(TreeType* oldTree,
}
template<typename TreeType>
-void RTreeSplit::AssignNodeDestNode(TreeType* oldTree,
+void RTreeSplit<TreeType>::AssignNodeDestNode(TreeType* oldTree,
TreeType* treeOne,
TreeType* treeTwo,
const int intI,
@@ -522,7 +536,7 @@ void RTreeSplit::AssignNodeDestNode(TreeType* oldTree,
* numberOfChildren.
*/
template<typename TreeType>
-void RTreeSplit::InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode)
+void RTreeSplit<TreeType>::InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode)
{
destTree->Bound() |= srcNode->Bound();
destTree->Children()[destTree->NumChildren()++] = srcNode;
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
index e4d69d3..090a832 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree.hpp
@@ -35,13 +35,15 @@ namespace tree /** Trees and tree-building procedures. */ {
* @tparam DescentType The heuristic to use when descending the tree to insert
* points.
*/
+
template<typename MetricType = metric::EuclideanDistance,
typename StatisticType = EmptyStatistic,
typename MatType = arma::mat,
- typename SplitType = RTreeSplit,
+ template<typename> class SplitType = RTreeSplit,
typename DescentType = RTreeDescentHeuristic>
class RectangleTree
{
+ friend class SplitType<RectangleTree>;
// The metric *must* be the euclidean distance.
static_assert(boost::is_same<MetricType, metric::EuclideanDistance>::value,
"RectangleTree: MetricType must be metric::EuclideanDistance.");
@@ -118,6 +120,8 @@ class RectangleTree
std::vector<size_t> points;
//! The local dataset
MatType* localDataset;
+ //! The class that performs the split of the node.
+ SplitType<RectangleTree> split;
public:
//! A single traverser for rectangle type trees. See
diff --git a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
index 2c747ca..cca78de 100644
--- a/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/rectangle_tree_impl.hpp
@@ -20,7 +20,7 @@ namespace tree {
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
RectangleTree(const MatType& data,
@@ -50,6 +50,8 @@ RectangleTree(const MatType& data,
{
stat = StatisticType(*this);
+ split = SplitType<RectangleTree>(this);
+
// For now, just insert the points in order.
RectangleTree* root = this;
@@ -60,7 +62,7 @@ RectangleTree(const MatType& data,
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
RectangleTree(MatType&& data,
@@ -90,6 +92,8 @@ RectangleTree(MatType&& data,
{
stat = StatisticType(*this);
+ split = SplitType<RectangleTree>(this);
+
// For now, just insert the points in order.
RectangleTree* root = this;
@@ -100,7 +104,7 @@ RectangleTree(MatType&& data,
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
RectangleTree(
@@ -126,6 +130,7 @@ RectangleTree(
maxLeafSize + 1)))
{
stat = StatisticType(*this);
+ split = SplitType<RectangleTree>(this);
}
/**
@@ -135,7 +140,7 @@ RectangleTree(
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
RectangleTree(
@@ -159,6 +164,7 @@ RectangleTree(
points(other.Points()),
localDataset(NULL)
{
+ split = SplitType<RectangleTree>(this);
if (deepCopy)
{
if (numChildren > 0)
@@ -187,7 +193,7 @@ RectangleTree(
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
template<typename Archive>
RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
@@ -208,7 +214,7 @@ RectangleTree(
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
~RectangleTree()
@@ -229,7 +235,7 @@ RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
SoftDelete()
@@ -249,7 +255,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
NullifyData()
@@ -264,7 +270,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
InsertPoint(const size_t point)
@@ -301,7 +307,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
InsertPoint(const size_t point, std::vector<bool>& relevels)
@@ -336,7 +342,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
InsertNode(RectangleTree* node,
@@ -365,7 +371,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
DeletePoint(const size_t point)
@@ -410,7 +416,7 @@ bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
DeletePoint(const size_t point, std::vector<bool>& relevels)
@@ -445,7 +451,7 @@ bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
RemoveNode(const RectangleTree* node, std::vector<bool>& relevels)
@@ -474,7 +480,7 @@ bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
size_t RectangleTree<MetricType, StatisticType, MatType, SplitType,
DescentType>::TreeSize() const
@@ -489,7 +495,7 @@ size_t RectangleTree<MetricType, StatisticType, MatType, SplitType,
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
size_t RectangleTree<MetricType, StatisticType, MatType, SplitType,
DescentType>::TreeDepth() const
@@ -509,7 +515,7 @@ size_t RectangleTree<MetricType, StatisticType, MatType, SplitType,
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
inline bool RectangleTree<MetricType, StatisticType, MatType, SplitType,
DescentType>::IsLeaf() const
@@ -524,7 +530,7 @@ inline bool RectangleTree<MetricType, StatisticType, MatType, SplitType,
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
inline
typename RectangleTree<MetricType, StatisticType, MatType, SplitType,
@@ -549,7 +555,7 @@ RectangleTree<MetricType, StatisticType, MatType, SplitType,
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
inline
typename RectangleTree<MetricType, StatisticType, MatType, SplitType,
@@ -568,7 +574,7 @@ RectangleTree<MetricType, StatisticType, MatType, SplitType,
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
inline size_t RectangleTree<MetricType, StatisticType, MatType, SplitType,
DescentType>::NumPoints() const
@@ -585,7 +591,7 @@ inline size_t RectangleTree<MetricType, StatisticType, MatType, SplitType,
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
inline size_t RectangleTree<MetricType, StatisticType, MatType, SplitType,
DescentType>::NumDescendants() const
@@ -609,7 +615,7 @@ inline size_t RectangleTree<MetricType, StatisticType, MatType, SplitType,
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
inline size_t RectangleTree<MetricType, StatisticType, MatType, SplitType,
DescentType>::Descendant(const size_t index) const
@@ -641,7 +647,7 @@ inline size_t RectangleTree<MetricType, StatisticType, MatType, SplitType,
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
inline size_t RectangleTree<MetricType, StatisticType, MatType, SplitType,
DescentType>::Point(const size_t index) const
@@ -656,7 +662,7 @@ inline size_t RectangleTree<MetricType, StatisticType, MatType, SplitType,
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
SplitNode(std::vector<bool>& relevels)
@@ -669,7 +675,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
// If we are full, then we need to split (or at least try). The SplitType
// takes care of this and of moving up the tree if necessary.
- SplitType::SplitLeafNode(this, relevels);
+ split.SplitLeafNode(relevels);
}
else
{
@@ -679,7 +685,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
// If we are full, then we need to split (or at least try). The SplitType
// takes care of this and of moving up the tree if necessary.
- SplitType::SplitNonLeafNode(this, relevels);
+ split.SplitNonLeafNode(relevels);
}
}
@@ -687,7 +693,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
RectangleTree() :
@@ -715,7 +721,7 @@ RectangleTree() :
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
CondenseTree(const arma::vec& point,
@@ -844,7 +850,7 @@ void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
ShrinkBoundForPoint(const arma::vec& point)
@@ -940,7 +946,7 @@ bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
ShrinkBoundForBound(const bound::HRectBound<MetricType>& /* b */)
@@ -974,7 +980,7 @@ bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
template<typename Archive>
void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
diff --git a/src/mlpack/core/tree/rectangle_tree/single_tree_traverser.hpp b/src/mlpack/core/tree/rectangle_tree/single_tree_traverser.hpp
index 720a06e..7c4c938 100644
--- a/src/mlpack/core/tree/rectangle_tree/single_tree_traverser.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/single_tree_traverser.hpp
@@ -19,7 +19,7 @@ namespace tree {
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
template<typename RuleType>
class RectangleTree<MetricType, StatisticType, MatType, SplitType,
diff --git a/src/mlpack/core/tree/rectangle_tree/single_tree_traverser_impl.hpp b/src/mlpack/core/tree/rectangle_tree/single_tree_traverser_impl.hpp
index a5d9792..b67ae72 100644
--- a/src/mlpack/core/tree/rectangle_tree/single_tree_traverser_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/single_tree_traverser_impl.hpp
@@ -20,7 +20,7 @@ namespace tree {
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
template<typename RuleType>
RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
@@ -32,7 +32,7 @@ SingleTreeTraverser<RuleType>::SingleTreeTraverser(RuleType& rule) :
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
template<typename RuleType>
void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
diff --git a/src/mlpack/core/tree/rectangle_tree/traits.hpp b/src/mlpack/core/tree/rectangle_tree/traits.hpp
index 811d9ac..49b357f 100644
--- a/src/mlpack/core/tree/rectangle_tree/traits.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/traits.hpp
@@ -21,7 +21,7 @@ namespace tree {
template<typename MetricType,
typename StatisticType,
typename MatType,
- typename SplitType,
+ template<typename> class SplitType,
typename DescentType>
class TreeTraits<RectangleTree<MetricType, StatisticType, MatType, SplitType,
DescentType>>
diff --git a/src/mlpack/core/tree/rectangle_tree/x_tree_split.hpp b/src/mlpack/core/tree/rectangle_tree/x_tree_split.hpp
index ef9224c..3264353 100644
--- a/src/mlpack/core/tree/rectangle_tree/x_tree_split.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/x_tree_split.hpp
@@ -28,25 +28,33 @@ const double MAX_OVERLAP = 0.2;
* nodes overflow, we split them, moving up the tree and splitting nodes
* as necessary.
*/
+template<typename TreeType>
class XTreeSplit
{
public:
+ //! Default constructor
+ XTreeSplit();
+
+ //! Construct this with specified node.
+ XTreeSplit(TreeType *node);
+
/**
* Split a leaf node using the algorithm described in "The R*-tree: An
* Efficient and Robust Access method for Points and Rectangles." If
* necessary, this split will propagate upwards through the tree.
*/
- template<typename TreeType>
- static void SplitLeafNode(TreeType* tree, std::vector<bool>& relevels);
+ void SplitLeafNode(std::vector<bool>& relevels);
/**
* Split a non-leaf node using the "default" algorithm. If this is a root
* node, the tree increases in depth.
*/
- template<typename TreeType>
- static bool SplitNonLeafNode(TreeType* tree, std::vector<bool>& relevels);
+ bool SplitNonLeafNode(std::vector<bool>& relevels);
private:
+ //! The node which has to be split.
+ TreeType *tree;
+
/**
* Class to allow for faster sorting.
*/
@@ -71,7 +79,6 @@ class XTreeSplit
/**
* Insert a node into another node.
*/
- template<typename TreeType>
static void InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode);
};
diff --git a/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp b/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp
index 21d3da3..685b26a 100644
--- a/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp
+++ b/src/mlpack/core/tree/rectangle_tree/x_tree_split_impl.hpp
@@ -14,6 +14,21 @@
namespace mlpack {
namespace tree {
+template<typename TreeType>
+XTreeSplit<TreeType>::XTreeSplit() :
+ tree(NULL)
+{
+ // Default construction leaves tree NULL. NOTE(review): SplitLeafNode()/SplitNonLeafNode() dereference tree, so do not call them on a default-constructed object.
+}
+
+template<typename TreeType>
+XTreeSplit<TreeType>::XTreeSplit(TreeType *node) :
+ tree(node)
+{
+ // Binds this split object to node; SplitLeafNode()/SplitNonLeafNode() take no node argument and operate on this stored tree pointer.
+}
+
+
/**
* We call GetPointSeeds to get the two points which will be the initial points
* in the new nodes We then call AssignPointDestNode to assign the remaining
@@ -21,7 +36,7 @@ namespace tree {
* new nodes into the tree, spliting the parent if necessary.
*/
template<typename TreeType>
-void XTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
+void XTreeSplit<TreeType>::SplitLeafNode(std::vector<bool>& relevels)
{
// Convenience typedef.
typedef typename TreeType::ElemType ElemType;
@@ -39,7 +54,7 @@ void XTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
// Because this was a leaf node, numChildren must be 0.
tree->Children()[(tree->NumChildren())++] = copy;
assert(tree->NumChildren() == 1);
- XTreeSplit::SplitLeafNode(copy, relevels);
+ copy->SplitNode(relevels);
return;
}
@@ -57,7 +72,7 @@ void XTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
size_t p = tree->MaxLeafSize() * 0.3;
if (p == 0)
{
- SplitLeafNode(tree, relevels);
+ tree->SplitNode(relevels);
return;
}
@@ -270,7 +285,7 @@ void XTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
// in case, we use an assert.
assert(par->NumChildren() <= par->MaxNumChildren() + 1);
if (par->NumChildren() == par->MaxNumChildren() + 1)
- SplitNonLeafNode(par, relevels);
+ par->SplitNode(relevels);
assert(treeOne->Parent()->NumChildren() <=
treeOne->Parent()->MaxNumChildren());
@@ -292,7 +307,7 @@ void XTreeSplit::SplitLeafNode(TreeType* tree, std::vector<bool>& relevels)
* higher up the tree because they were already updated if necessary.
*/
template<typename TreeType>
-bool XTreeSplit::SplitNonLeafNode(TreeType* tree, std::vector<bool>& relevels)
+bool XTreeSplit<TreeType>::SplitNonLeafNode(std::vector<bool>& relevels)
{
// Convenience typedef.
typedef typename TreeType::ElemType ElemType;
@@ -309,7 +324,7 @@ bool XTreeSplit::SplitNonLeafNode(TreeType* tree, std::vector<bool>& relevels)
tree->NumChildren() = 0;
tree->NullifyData();
tree->Children()[(tree->NumChildren())++] = copy;
- XTreeSplit::SplitNonLeafNode(copy, relevels);
+ copy->SplitNode(relevels);
return true;
}
@@ -803,7 +818,7 @@ bool XTreeSplit::SplitNonLeafNode(TreeType* tree, std::vector<bool>& relevels)
if (par->NumChildren() == par->MaxNumChildren() + 1)
{
- SplitNonLeafNode(par, relevels);
+ par->SplitNode(relevels);
}
// We have to update the children of each of these new nodes so that they
@@ -832,7 +847,7 @@ bool XTreeSplit::SplitNonLeafNode(TreeType* tree, std::vector<bool>& relevels)
* numberOfChildren.
*/
template<typename TreeType>
-void XTreeSplit::InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode)
+void XTreeSplit<TreeType>::InsertNodeIntoTree(TreeType* destTree, TreeType* srcNode)
{
destTree->Bound() |= srcNode->Bound();
destTree->Children()[destTree->NumChildren()] = srcNode;
More information about the mlpack-git
mailing list