[mlpack-git] master: Refactor DTB to use new TreeType API. (b37e976)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Wed Jul 29 16:41:35 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/f8ceffae0613b350f4d6bdd46c6c8633a40b4897...6ee21879488fe98612a4619b17f8b51e8da5215b
>---------------------------------------------------------------
commit b37e976adbb17d94a6f7fb5b76696faaa23b0697
Author: ryan <ryan at ratml.org>
Date: Wed Jul 22 23:21:44 2015 -0400
Refactor DTB to use new TreeType API.
>---------------------------------------------------------------
b37e976adbb17d94a6f7fb5b76696faaa23b0697
src/mlpack/methods/emst/dtb.hpp | 28 +++++---
src/mlpack/methods/emst/dtb_impl.hpp | 127 +++++++++++++++++++++++-----------
src/mlpack/methods/emst/emst_main.cpp | 5 +-
3 files changed, 105 insertions(+), 55 deletions(-)
diff --git a/src/mlpack/methods/emst/dtb.hpp b/src/mlpack/methods/emst/dtb.hpp
index 667da69..f706b13 100644
--- a/src/mlpack/methods/emst/dtb.hpp
+++ b/src/mlpack/methods/emst/dtb.hpp
@@ -70,19 +70,25 @@ namespace emst /** Euclidean Minimum Spanning Trees. */ {
* @tparam TreeType Type of tree to use. Should use DTBStat as a statistic.
*/
template<
- typename MetricType = metric::EuclideanDistance,
- typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>, DTBStat>
+ typename MetricType = metric::EuclideanDistance,
+ typename MatType = arma::mat,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType = tree::KDTree
>
class DualTreeBoruvka
{
+ public:
+ //! Convenience typedef.
+ typedef TreeType<MetricType, DTBStat, MatType> Tree;
+
private:
//! Copy of the data (if necessary).
- typename TreeType::Mat dataCopy;
+ MatType dataCopy;
//! Reference to the data (this is what should be used for accessing data).
- const typename TreeType::Mat& data;
+ const MatType& data;
//! Pointer to the root of the tree.
- TreeType* tree;
+ Tree* tree;
//! Indicates whether or not we "own" the tree.
bool ownTree;
@@ -128,7 +134,7 @@ class DualTreeBoruvka
* @param naive Whether the computation should be done in O(n^2) naive mode.
* @param leafSize The leaf size to be used during tree construction.
*/
- DualTreeBoruvka(const typename TreeType::Mat& dataset,
+ DualTreeBoruvka(const MatType& dataset,
const bool naive = false,
const MetricType metric = MetricType());
@@ -149,8 +155,8 @@ class DualTreeBoruvka
* @param tree Pre-built tree.
* @param dataset Dataset corresponding to the pre-built tree.
*/
- DualTreeBoruvka(TreeType* tree,
- const typename TreeType::Mat& dataset,
+ DualTreeBoruvka(Tree* tree,
+ const MatType& dataset,
const MetricType metric = MetricType());
/**
@@ -194,7 +200,7 @@ class DualTreeBoruvka
* This function resets the values in the nodes of the tree nearest neighbor
* distance, and checks for fully connected nodes.
*/
- void CleanupHelper(TreeType* tree);
+ void CleanupHelper(Tree* tree);
/**
* The values stored in the tree must be reset on each iteration.
@@ -203,8 +209,8 @@ class DualTreeBoruvka
}; // class DualTreeBoruvka
-}; // namespace emst
-}; // namespace mlpack
+} // namespace emst
+} // namespace mlpack
#include "dtb_impl.hpp"
diff --git a/src/mlpack/methods/emst/dtb_impl.hpp b/src/mlpack/methods/emst/dtb_impl.hpp
index 1a14ad7..68580b3 100644
--- a/src/mlpack/methods/emst/dtb_impl.hpp
+++ b/src/mlpack/methods/emst/dtb_impl.hpp
@@ -13,9 +13,9 @@ namespace mlpack {
namespace emst {
//! Call the tree constructor that does mapping.
-template<typename TreeType>
+template<typename MatType, typename TreeType>
TreeType* BuildTree(
- typename TreeType::Mat& dataset,
+ MatType& dataset,
std::vector<size_t>& oldFromNew,
typename boost::enable_if_c<
tree::TreeTraits<TreeType>::RearrangesDataset == true, TreeType*
@@ -25,9 +25,9 @@ TreeType* BuildTree(
}
//! Call the tree constructor that does not do mapping.
-template<typename TreeType>
+template<typename MatType, typename TreeType>
TreeType* BuildTree(
- const typename TreeType::Mat& dataset,
+ const MatType& dataset,
const std::vector<size_t>& /* oldFromNew */,
const typename boost::enable_if_c<
tree::TreeTraits<TreeType>::RearrangesDataset == false, TreeType*
@@ -40,12 +40,17 @@ TreeType* BuildTree(
* Takes in a reference to the data set. Copies the data, builds the tree,
* and initializes all of the member variables.
*/
-template<typename MetricType, typename TreeType>
-DualTreeBoruvka<MetricType, TreeType>::DualTreeBoruvka(
- const typename TreeType::Mat& dataset,
+template<
+ typename MetricType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType>
+DualTreeBoruvka<MetricType, MatType, TreeType>::DualTreeBoruvka(
+ const MatType& dataset,
const bool naive,
const MetricType metric) :
- data((tree::TreeTraits<TreeType>::RearrangesDataset && !naive) ? dataCopy : dataset),
+ data((tree::TreeTraits<Tree>::RearrangesDataset && !naive) ? dataCopy :
+ dataset),
ownTree(!naive),
naive(naive),
connections(dataset.n_cols),
@@ -57,11 +62,10 @@ DualTreeBoruvka<MetricType, TreeType>::DualTreeBoruvka(
if (!naive)
{
// Copy the dataset, if it will be modified during tree construction.
- if (tree::TreeTraits<TreeType>::RearrangesDataset)
+ if (tree::TreeTraits<Tree>::RearrangesDataset)
dataCopy = dataset;
- tree = BuildTree<TreeType>(const_cast<typename TreeType::Mat&>(data),
- oldFromNew);
+ tree = BuildTree<MatType, Tree>(const_cast<MatType&>(data), oldFromNew);
}
Timer::Stop("emst/tree_building");
@@ -72,12 +76,16 @@ DualTreeBoruvka<MetricType, TreeType>::DualTreeBoruvka(
neighborsOutComponent.set_size(data.n_cols);
neighborsDistances.set_size(data.n_cols);
neighborsDistances.fill(DBL_MAX);
-} // Constructor
+}
-template<typename MetricType, typename TreeType>
-DualTreeBoruvka<MetricType, TreeType>::DualTreeBoruvka(
- TreeType* tree,
- const typename TreeType::Mat& dataset,
+template<
+ typename MetricType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType>
+DualTreeBoruvka<MetricType, MatType, TreeType>::DualTreeBoruvka(
+ Tree* tree,
+ const MatType& dataset,
const MetricType metric) :
data(dataset),
tree(tree),
@@ -95,8 +103,12 @@ DualTreeBoruvka<MetricType, TreeType>::DualTreeBoruvka(
neighborsDistances.fill(DBL_MAX);
}
-template<typename MetricType, typename TreeType>
-DualTreeBoruvka<MetricType, TreeType>::~DualTreeBoruvka()
+template<
+ typename MetricType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType>
+DualTreeBoruvka<MetricType, MatType, TreeType>::~DualTreeBoruvka()
{
if (ownTree)
delete tree;
@@ -106,14 +118,19 @@ DualTreeBoruvka<MetricType, TreeType>::~DualTreeBoruvka()
* Iteratively find the nearest neighbor of each component until the MST is
* complete.
*/
-template<typename MetricType, typename TreeType>
-void DualTreeBoruvka<MetricType, TreeType>::ComputeMST(arma::mat& results)
+template<
+ typename MetricType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType>
+void DualTreeBoruvka<MetricType, MatType, TreeType>::ComputeMST(
+ arma::mat& results)
{
Timer::Start("emst/mst_computation");
totalDist = 0; // Reset distance.
- typedef DTBRules<MetricType, TreeType> RuleType;
+ typedef DTBRules<MetricType, Tree> RuleType;
RuleType rules(data, connections, neighborsDistances, neighborsInComponent,
neighborsOutComponent, metric);
while (edges.size() < (data.n_cols - 1))
@@ -127,7 +144,7 @@ void DualTreeBoruvka<MetricType, TreeType>::ComputeMST(arma::mat& results)
}
else
{
- typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+ typename Tree::template DualTreeTraverser<RuleType> traverser(rules);
traverser.Traverse(*tree, *tree);
}
@@ -154,10 +171,15 @@ void DualTreeBoruvka<MetricType, TreeType>::ComputeMST(arma::mat& results)
/**
* Adds a single edge to the edge list
*/
-template<typename MetricType, typename TreeType>
-void DualTreeBoruvka<MetricType, TreeType>::AddEdge(const size_t e1,
- const size_t e2,
- const double distance)
+template<
+ typename MetricType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType>
+void DualTreeBoruvka<MetricType, MatType, TreeType>::AddEdge(
+ const size_t e1,
+ const size_t e2,
+ const double distance)
{
Log::Assert((distance >= 0.0),
"DualTreeBoruvka::AddEdge(): distance cannot be negative.");
@@ -166,13 +188,17 @@ void DualTreeBoruvka<MetricType, TreeType>::AddEdge(const size_t e1,
edges.push_back(EdgePair(e1, e2, distance));
else
edges.push_back(EdgePair(e2, e1, distance));
-} // AddEdge
+}
/**
* Adds all the edges found in one iteration to the list of neighbors.
*/
-template<typename MetricType, typename TreeType>
-void DualTreeBoruvka<MetricType, TreeType>::AddAllEdges()
+template<
+ typename MetricType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType>
+void DualTreeBoruvka<MetricType, MatType, TreeType>::AddAllEdges()
{
for (size_t i = 0; i < data.n_cols; i++)
{
@@ -188,13 +214,18 @@ void DualTreeBoruvka<MetricType, TreeType>::AddAllEdges()
connections.Union(inEdge, outEdge);
}
}
-} // AddAllEdges
+}
/**
* Unpermute the edge list (if necessary) and output it to results.
*/
-template<typename MetricType, typename TreeType>
-void DualTreeBoruvka<MetricType, TreeType>::EmitResults(arma::mat& results)
+template<
+ typename MetricType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType>
+void DualTreeBoruvka<MetricType, MatType, TreeType>::EmitResults(
+ arma::mat& results)
{
// Sort the edges.
std::sort(edges.begin(), edges.end(), SortFun);
@@ -203,7 +234,7 @@ void DualTreeBoruvka<MetricType, TreeType>::EmitResults(arma::mat& results)
results.set_size(3, edges.size());
// Need to unpermute the point labels.
- if (!naive && ownTree && tree::TreeTraits<TreeType>::RearrangesDataset)
+ if (!naive && ownTree && tree::TreeTraits<Tree>::RearrangesDataset)
{
for (size_t i = 0; i < (data.n_cols - 1); i++)
{
@@ -237,14 +268,18 @@ void DualTreeBoruvka<MetricType, TreeType>::EmitResults(arma::mat& results)
results(2, i) = edges[i].Distance();
}
}
-} // EmitResults
+}
/**
* This function resets the values in the nodes of the tree nearest neighbor
* distance and checks for fully connected nodes.
*/
-template<typename MetricType, typename TreeType>
-void DualTreeBoruvka<MetricType, TreeType>::CleanupHelper(TreeType* tree)
+template<
+ typename MetricType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType>
+void DualTreeBoruvka<MetricType, MatType, TreeType>::CleanupHelper(Tree* tree)
{
// Reset the statistic information.
tree->Stat().MaxNeighborDistance() = DBL_MAX;
@@ -278,8 +313,12 @@ void DualTreeBoruvka<MetricType, TreeType>::CleanupHelper(TreeType* tree)
/**
* The values stored in the tree must be reset on each iteration.
*/
-template<typename MetricType, typename TreeType>
-void DualTreeBoruvka<MetricType, TreeType>::Cleanup()
+template<
+ typename MetricType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType>
+void DualTreeBoruvka<MetricType, MatType, TreeType>::Cleanup()
{
for (size_t i = 0; i < data.n_cols; i++)
neighborsDistances[i] = DBL_MAX;
@@ -289,8 +328,12 @@ void DualTreeBoruvka<MetricType, TreeType>::Cleanup()
}
// convert the object to a string
-template<typename MetricType, typename TreeType>
-std::string DualTreeBoruvka<MetricType, TreeType>::ToString() const
+template<
+ typename MetricType,
+ typename MatType,
+ template<typename MetricType, typename StatisticType, typename MatType>
+ class TreeType>
+std::string DualTreeBoruvka<MetricType, MatType, TreeType>::ToString() const
{
std::ostringstream convert;
convert << "DualTreeBoruvka [" << this << "]" << std::endl;
@@ -303,7 +346,7 @@ std::string DualTreeBoruvka<MetricType, TreeType>::ToString() const
return convert.str();
}
-}; // namespace emst
-}; // namespace mlpack
+} // namespace emst
+} // namespace mlpack
#endif
diff --git a/src/mlpack/methods/emst/emst_main.cpp b/src/mlpack/methods/emst/emst_main.cpp
index 56a9c4b..6f51673 100644
--- a/src/mlpack/methods/emst/emst_main.cpp
+++ b/src/mlpack/methods/emst/emst_main.cpp
@@ -44,6 +44,7 @@ PARAM_INT("leaf_size", "Leaf size in the kd-tree. One-element leaves give the "
using namespace mlpack;
using namespace mlpack::emst;
using namespace mlpack::tree;
+using namespace mlpack::metric;
using namespace std;
int main(int argc, char* argv[])
@@ -86,8 +87,8 @@ int main(int argc, char* argv[])
Timer::Start("tree_building");
std::vector<size_t> oldFromNew;
- tree::BinarySpaceTree<bound::HRectBound<2>, DTBStat> tree(dataPoints,
- oldFromNew, leafSize);
+ KDTree<EuclideanDistance, DTBStat, arma::mat> tree(dataPoints, oldFromNew,
+ leafSize);
metric::LMetric<2, true> metric;
Timer::Stop("tree_building");
More information about the mlpack-git mailing list