[mlpack-svn] r10771 - mlpack/trunk/src/mlpack/methods/emst
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Dec 14 07:25:58 EST 2011
Author: rcurtin
Date: 2011-12-14 07:25:58 -0500 (Wed, 14 Dec 2011)
New Revision: 10771
Modified:
mlpack/trunk/src/mlpack/methods/emst/dtb.hpp
mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp
Log:
Comment things a little better, and rename one of the timers.
Modified: mlpack/trunk/src/mlpack/methods/emst/dtb.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb.hpp 2011-12-14 12:16:49 UTC (rev 10770)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb.hpp 2011-12-14 12:25:58 UTC (rev 10771)
@@ -23,22 +23,19 @@
namespace emst {
/**
- * A Stat class for use with fastlib's trees. This one only stores two values.
- *
- * @param max_neighbor_distance The upper bound on the distance to the nearest
- * neighbor of any point in this node.
- *
- * @param component_membership The index of the component that all points in
- * this node belong to. This is the same index returned by UnionFind for all
- * points in this node. If points in this node are in different components,
- * this value will be negative.
+ * A statistic for use with MLPACK trees, which stores the upper bound on
+ * distance to nearest neighbors and the component which this node belongs to.
*/
class DTBStat
{
private:
- //! Maximum neighbor distance.
+ //! Upper bound on the distance to the nearest neighbor of any point in this
+ //! node.
double maxNeighborDistance;
- //! Component membership of this node.
+ //! The index of the component that all points in this node belong to. This
+ //! is the same index returned by UnionFind for all points in this node. If
+ //! points in this node are in different components, this value will be
+ //! negative.
int componentMembership;
public:
@@ -54,7 +51,7 @@
DTBStat(const MatType& dataset, const size_t start, const size_t count);
/**
- * An initializer for non-leaves. Simply calls the leaf initializer.
+ * An initializer for non-leaves.
*/
template<typename MatType>
DTBStat(const MatType& dataset, const size_t start, const size_t count,
@@ -73,7 +70,23 @@
}; // class DTBStat
/**
- * Performs the MST calculation using the Dual-Tree Boruvka algorithm.
+ * Performs the MST calculation using the Dual-Tree Boruvka algorithm, using any
+ * type of tree. At the moment this class does not support arbitrary distance
+ * metrics, and uses the squared Euclidean distance.
+ *
+ * For more information on the algorithm, see the following citation:
+ *
+ * @inproceedings{
+ * author = {March, W.B. and Ram, P. and Gray, A.G.},
+ * title = {{Fast Euclidean Minimum Spanning Tree: Algorithm, Analysis,
+ * Applications.}},
+ * booktitle = {Proceedings of the 16th ACM SIGKDD International Conference
+ * on Knowledge Discovery and Data Mining},
+ * series = {KDD '10},
+ * year = {2010}
+ * }
+ *
+ * @tparam TreeType Type of tree to use.
*/
template<
typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>, DTBStat>
@@ -82,9 +95,9 @@
{
private:
//! Copy of the data (if necessary).
- arma::mat dataCopy;
+ typename TreeType::Mat dataCopy;
//! Reference to the data (this is what should be used for accessing data).
- arma::mat& data;
+ typename TreeType::Mat& data;
//! Pointer to the root of the tree.
TreeType* tree;
@@ -109,7 +122,7 @@
//! List of edge distances.
arma::vec neighborsDistances;
- // output info
+ //! Total distance of the tree.
double totalDist;
// For sorting the edge list after the computation.
@@ -121,8 +134,6 @@
}
} SortFun;
-
-////////////////// Constructors ////////////////////////
public:
/**
* Create the tree from the given dataset. This copies the dataset to an
@@ -161,12 +172,16 @@
~DualTreeBoruvka();
/**
- * Call this function after Init. It will iteratively find the nearest
- * neighbor of each component until the MST is complete.
+ * Iteratively find the nearest neighbor of each component until the MST is
+ * complete. The results will be a 3xN matrix (with N equal to the number of
+ * edges in the minimum spanning tree). The first row will contain the lesser
+ * index of the edge; the second row will contain the greater index of the
+ * edge; and the third row will contain the distance between the two points.
+ *
+ * @param results Matrix which results will be stored in.
*/
void ComputeMST(arma::mat& results);
- ////////////////////////// Private Functions ////////////////////
private:
/**
* Adds a single edge to the edge list
Modified: mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp 2011-12-14 12:16:49 UTC (rev 10770)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp 2011-12-14 12:25:58 UTC (rev 10771)
@@ -69,7 +69,7 @@
connections(data.n_cols),
totalDist(0.0)
{
- Timer::Start("emst/treebuilding");
+ Timer::Start("emst/tree_building");
if (!naive)
{
@@ -83,7 +83,7 @@
tree = new TreeType(data, oldFromNew, data.n_cols);
}
- Timer::Stop("emst/treebuilding");
+ Timer::Stop("emst/tree_building");
edges.reserve(data.n_cols - 1); // Set size.
@@ -120,14 +120,16 @@
}
/**
- * Call this function after Init. It will iteratively find the nearest
- * neighbor of each component until the MST is complete.
+ * Iteratively find the nearest neighbor of each component until the MST is
+ * complete.
*/
template<typename TreeType>
void DualTreeBoruvka<TreeType>::ComputeMST(arma::mat& results)
{
Timer::Start("emst/mst_computation");
+ totalDist = 0; // Reset distance.
+
while (edges.size() < (data.n_cols - 1))
{
// Compute neighbors.
More information about the mlpack-svn mailing list