[mlpack-svn] r10363 - mlpack/trunk/src/mlpack/methods/emst
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Nov 23 17:22:16 EST 2011
Author: rcurtin
Date: 2011-11-23 17:22:16 -0500 (Wed, 23 Nov 2011)
New Revision: 10363
Modified:
mlpack/trunk/src/mlpack/methods/emst/dtb.hpp
mlpack/trunk/src/mlpack/methods/emst/emst.hpp
mlpack/trunk/src/mlpack/methods/emst/emst_main.cpp
mlpack/trunk/src/mlpack/methods/emst/union_find.hpp
Log:
Format EMST as per #153.
Modified: mlpack/trunk/src/mlpack/methods/emst/dtb.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb.hpp 2011-11-23 22:06:03 UTC (rev 10362)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb.hpp 2011-11-23 22:22:16 UTC (rev 10363)
@@ -10,7 +10,6 @@
* Spanning Tree: Algorithm, Analysis, Applications. In KDD, 2010.
*
*/
-
#ifndef __MLPACK_METHODS_EMST_DTB_HPP
#define __MLPACK_METHODS_EMST_DTB_HPP
@@ -25,7 +24,7 @@
namespace emst {
/**
-* A Stat class for use with fastlib's trees. This one only stores two values.
+ * A Stat class for use with fastlib's trees. This one only stores two values.
*
* @param max_neighbor_distance The upper bound on the distance to the nearest
* neighbor of any point in this node.
@@ -35,72 +34,83 @@
* points in this node. If points in this node are in different components,
* this value will be negative.
*/
-
-class DTBStat {
+class DTBStat
+{
private:
double max_neighbor_distance_;
int component_membership_;
public:
- void set_max_neighbor_distance(double distance) {
+ void set_max_neighbor_distance(double distance)
+ {
max_neighbor_distance_ = distance;
}
- double max_neighbor_distance() {
+ double max_neighbor_distance()
+ {
return max_neighbor_distance_;
}
- void set_component_membership(int membership) {
+ void set_component_membership(int membership)
+ {
component_membership_ = membership;
}
- int component_membership() {
+ int component_membership()
+ {
return component_membership_;
}
/**
- * A generic initializer.
- */
- DTBStat() {
+ * A generic initializer.
+ */
+ DTBStat()
+ {
set_max_neighbor_distance(DBL_MAX);
set_component_membership(-1);
}
/**
- * An initializer for leaves.
+ * An initializer for leaves.
*/
- DTBStat(const arma::mat& dataset, size_t start, size_t count) {
- if (count == 1) {
+ DTBStat(const arma::mat& dataset, size_t start, size_t count)
+ {
+ if (count == 1)
+ {
set_component_membership(start);
set_max_neighbor_distance(DBL_MAX);
- } else {
+ }
+ else
+ {
set_max_neighbor_distance(DBL_MAX);
set_component_membership(-1);
}
}
/**
- * An initializer for non-leaves. Simply calls the leaf initializer.
+ * An initializer for non-leaves. Simply calls the leaf initializer.
*/
DTBStat(const arma::mat& dataset, size_t start, size_t count,
- const DTBStat& left_stat, const DTBStat& right_stat) {
- if (count == 1) {
+ const DTBStat& left_stat, const DTBStat& right_stat)
+ {
+ if (count == 1)
+ {
set_component_membership(start);
set_max_neighbor_distance(DBL_MAX);
- } else {
+ }
+ else
+ {
set_max_neighbor_distance(DBL_MAX);
set_component_membership(-1);
}
}
-
}; // class DTBStat
-
/**
* Performs the MST calculation using the Dual-Tree Boruvka algorithm.
*/
-class DualTreeBoruvka {
-
+class DualTreeBoruvka
+{
public:
// For now, everything is in Euclidean space
static const size_t metric = 2;
@@ -110,7 +120,6 @@
//////// Member Variables /////////////////////
private:
-
size_t number_of_edges_;
std::vector<EdgePair> edges_; // must use vector with non-numerical types
size_t number_of_points_;
@@ -139,28 +148,23 @@
DTBTree* tree_;
-
////////////////// Constructors ////////////////////////
-
public:
+ DualTreeBoruvka() { }
- DualTreeBoruvka() {}
-
- ~DualTreeBoruvka() {
- if (tree_ != NULL) {
+ ~DualTreeBoruvka()
+ {
+ if (tree_ != NULL)
delete tree_;
- }
}
-
////////////////////////// Private Functions ////////////////////
private:
-
/**
- * Adds a single edge to the edge list
+ * Adds a single edge to the edge list
*/
- void AddEdge_(size_t e1, size_t e2, double distance) {
-
+ void AddEdge_(size_t e1, size_t e2, double distance)
+ {
//EdgePair edge;
mlpack::Log::Assert((e1 != e2),
"Indices are equal in DualTreeBoruvka.add_edge(...)");
@@ -168,28 +172,27 @@
mlpack::Log::Assert((distance >= 0.0),
"Negative distance input in DualTreeBoruvka.add_edge(...)");
- if (e1 < e2) {
+ if (e1 < e2)
edges_[number_of_edges_].Init(e1, e2, distance);
- }
- else {
+ else
edges_[number_of_edges_].Init(e2, e1, distance);
- }
number_of_edges_++;
} // AddEdge_
-
/**
* Adds all the edges found in one iteration to the list of neighbors.
*/
- void AddAllEdges_() {
-
- for (size_t i = 0; i < number_of_points_; i++) {
+ void AddAllEdges_()
+ {
+ for (size_t i = 0; i < number_of_points_; i++)
+ {
size_t component_i = connections_.Find(i);
size_t in_edge_i = neighbors_in_component_[component_i];
size_t out_edge_i = neighbors_out_component_[component_i];
- if (connections_.Find(in_edge_i) != connections_.Find(out_edge_i)) {
+ if (connections_.Find(in_edge_i) != connections_.Find(out_edge_i))
+ {
double dist = neighbors_distances_[component_i];
//total_dist_ = total_dist_ + dist;
// changed to make this agree with the cover tree code
@@ -198,58 +201,52 @@
connections_.Union(in_edge_i, out_edge_i);
}
}
-
} // AddAllEdges_
/**
- * Handles the base case computation. Also called by naive.
- */
+ * Handles the base case computation. Also called by naive.
+ */
double ComputeBaseCase_(size_t query_start, size_t query_end,
- size_t reference_start, size_t reference_end) {
-
+ size_t reference_start, size_t reference_end)
+ {
number_leaf_computations_++;
double new_upper_bound = -1.0;
for (size_t query_index = query_start; query_index < query_end;
- query_index++) {
-
+ query_index++)
+ {
// Find the index of the component the query is in
size_t query_component_index = connections_.Find(query_index);
arma::vec query_point = data_points_.col(query_index);
for (size_t reference_index = reference_start;
- reference_index < reference_end; reference_index++) {
-
+ reference_index < reference_end; reference_index++)
+ {
size_t reference_component_index = connections_.Find(reference_index);
- if (query_component_index != reference_component_index) {
-
+ if (query_component_index != reference_component_index)
+ {
arma::vec reference_point = data_points_.col(reference_index);
double distance = mlpack::metric::LMetric<2>::Evaluate(query_point,
reference_point);
- if (distance < neighbors_distances_[query_component_index]) {
-
+ if (distance < neighbors_distances_[query_component_index])
+ {
mlpack::Log::Assert(query_index != reference_index);
neighbors_distances_[query_component_index] = distance;
neighbors_in_component_[query_component_index] = query_index;
neighbors_out_component_[query_component_index] = reference_index;
-
} // if distance
-
} // if indices not equal
-
-
} // for reference_index
- if (new_upper_bound < neighbors_distances_[query_component_index]) {
+ if (new_upper_bound < neighbors_distances_[query_component_index])
new_upper_bound = neighbors_distances_[query_component_index];
- }
} // for query_index
@@ -260,178 +257,183 @@
/**
- * Handles the recursive calls to find the nearest neighbors in an iteration
+ * Handles the recursive calls to find the nearest neighbors in an iteration
*/
void ComputeNeighborsRecursion_(DTBTree *query_node, DTBTree *reference_node,
- double incoming_distance) {
-
- // Check for a distance prune
- if (query_node->stat().max_neighbor_distance() < incoming_distance) {
- //pruned by distance
+ double incoming_distance)
+ {
+ // Check for a distance prune.
+ if (query_node->stat().max_neighbor_distance() < incoming_distance)
+ {
+ // Pruned by distance.
number_distance_prunes_++;
}
- // Check for a component prune
+ // Check for a component prune.
else if ((query_node->stat().component_membership() >= 0)
- && (query_node->stat().component_membership() ==
- reference_node->stat().component_membership())) {
- //pruned by component membership
-
+ && (query_node->stat().component_membership() ==
+ reference_node->stat().component_membership()))
+ {
+ // Pruned by component membership.
mlpack::Log::Assert(reference_node->stat().component_membership() >= 0);
+ mlpack::Log::Info << query_node->stat().component_membership()
+ << "q mem\n";
+ mlpack::Log::Info << reference_node->stat().component_membership()
+ << "r mem\n";
- mlpack::Log::Info << query_node->stat().component_membership() << "q mem\n";
- mlpack::Log::Info << reference_node->stat().component_membership() << "r mem\n";
-
number_component_prunes_++;
}
- // The base case
- else if (query_node->is_leaf() && reference_node->is_leaf()) {
+ else if (query_node->is_leaf() && reference_node->is_leaf()) // Base case.
+ {
+ double new_bound = ComputeBaseCase_(query_node->begin(),
+ query_node->end(), reference_node->begin(), reference_node->end());
- double new_bound =
- ComputeBaseCase_(query_node->begin(), query_node->end(),
- reference_node->begin(), reference_node->end());
-
query_node->stat().set_max_neighbor_distance(new_bound);
-
}
- // Other recursive calls
- else if (query_node->is_leaf()) {
- //recurse on reference_node only
+ else if (query_node->is_leaf()) // Other recursive calls.
+ {
+ // Recurse on reference_node only.
number_r_recursions_++;
double left_dist =
- query_node->bound().MinDistance(reference_node->left()->bound());
+ query_node->bound().MinDistance(reference_node->left()->bound());
double right_dist =
- query_node->bound().MinDistance(reference_node->right()->bound());
+ query_node->bound().MinDistance(reference_node->right()->bound());
mlpack::Log::Assert(left_dist >= 0.0);
mlpack::Log::Assert(right_dist >= 0.0);
- if (left_dist < right_dist) {
- ComputeNeighborsRecursion_(query_node,
- reference_node->left(), left_dist);
- ComputeNeighborsRecursion_(query_node,
- reference_node->right(), right_dist);
+ if (left_dist < right_dist)
+ {
+ ComputeNeighborsRecursion_(query_node, reference_node->left(),
+ left_dist);
+ ComputeNeighborsRecursion_(query_node, reference_node->right(),
+ right_dist);
}
- else {
- ComputeNeighborsRecursion_(query_node,
- reference_node->right(), right_dist);
- ComputeNeighborsRecursion_(query_node,
- reference_node->left(), left_dist);
+ else
+ {
+ ComputeNeighborsRecursion_(query_node, reference_node->right(),
+ right_dist);
+ ComputeNeighborsRecursion_(query_node, reference_node->left(),
+ left_dist);
}
-
}
- else if (reference_node->is_leaf()) {
- //recurse on query_node only
-
+ else if (reference_node->is_leaf())
+ {
+ // Recurse on query_node only.
number_q_recursions_++;
double left_dist =
- query_node->left()->bound().MinDistance(reference_node->bound());
+ query_node->left()->bound().MinDistance(reference_node->bound());
double right_dist =
- query_node->right()->bound().MinDistance(reference_node->bound());
+ query_node->right()->bound().MinDistance(reference_node->bound());
- ComputeNeighborsRecursion_(query_node->left(),
- reference_node, left_dist);
- ComputeNeighborsRecursion_(query_node->right(),
- reference_node, right_dist);
+ ComputeNeighborsRecursion_(query_node->left(), reference_node, left_dist);
+ ComputeNeighborsRecursion_(query_node->right(), reference_node,
+ right_dist);
- // Update query_node's stat
+ // Update query_node's stat.
query_node->stat().set_max_neighbor_distance(
std::max(query_node->left()->stat().max_neighbor_distance(),
- query_node->right()->stat().max_neighbor_distance()));
+ query_node->right()->stat().max_neighbor_distance()));
}
- else {
- //recurse on both
-
+ else
+ {
+ // Recurse on both.
number_both_recursions_++;
- double left_dist = query_node->left()->bound().
- MinDistance(reference_node->left()->bound());
- double right_dist = query_node->left()->bound().
- MinDistance(reference_node->right()->bound());
+ double left_dist = query_node->left()->bound().MinDistance(
+ reference_node->left()->bound());
+ double right_dist = query_node->left()->bound().MinDistance(
+ reference_node->right()->bound());
- if (left_dist < right_dist) {
- ComputeNeighborsRecursion_(query_node->left(),
- reference_node->left(), left_dist);
- ComputeNeighborsRecursion_(query_node->left(),
- reference_node->right(), right_dist);
+ if (left_dist < right_dist)
+ {
+ ComputeNeighborsRecursion_(query_node->left(), reference_node->left(),
+ left_dist);
+ ComputeNeighborsRecursion_(query_node->left(), reference_node->right(),
+ right_dist);
}
- else {
- ComputeNeighborsRecursion_(query_node->left(),
- reference_node->right(), right_dist);
- ComputeNeighborsRecursion_(query_node->left(),
- reference_node->left(), left_dist);
+ else
+ {
+ ComputeNeighborsRecursion_(query_node->left(), reference_node->right(),
+ right_dist);
+ ComputeNeighborsRecursion_(query_node->left(), reference_node->left(),
+ left_dist);
}
- left_dist = query_node->right()->bound().
- MinDistance(reference_node->left()->bound());
- right_dist = query_node->right()->bound().
- MinDistance(reference_node->right()->bound());
+ left_dist = query_node->right()->bound().MinDistance(
+ reference_node->left()->bound());
+ right_dist = query_node->right()->bound().MinDistance(
+ reference_node->right()->bound());
- if (left_dist < right_dist) {
- ComputeNeighborsRecursion_(query_node->right(),
- reference_node->left(), left_dist);
- ComputeNeighborsRecursion_(query_node->right(),
- reference_node->right(), right_dist);
+ if (left_dist < right_dist)
+ {
+ ComputeNeighborsRecursion_(query_node->right(), reference_node->left(),
+ left_dist);
+ ComputeNeighborsRecursion_(query_node->right(), reference_node->right(),
+ right_dist);
}
- else {
- ComputeNeighborsRecursion_(query_node->right(),
- reference_node->right(), right_dist);
- ComputeNeighborsRecursion_(query_node->right(),
- reference_node->left(), left_dist);
+ else
+ {
+ ComputeNeighborsRecursion_(query_node->right(), reference_node->right(),
+ right_dist);
+ ComputeNeighborsRecursion_(query_node->right(), reference_node->left(),
+ left_dist);
}
query_node->stat().set_max_neighbor_distance(
std::max(query_node->left()->stat().max_neighbor_distance(),
- query_node->right()->stat().max_neighbor_distance()));
-
- }// end else
-
+ query_node->right()->stat().max_neighbor_distance()));
+ }
} // ComputeNeighborsRecursion_
/**
- * Computes the nearest neighbor of each point in each iteration
+ * Computes the nearest neighbor of each point in each iteration
* of the algorithm
*/
- void ComputeNeighbors_() {
- if (do_naive_) {
+ void ComputeNeighbors_()
+ {
+ if (do_naive_)
+ {
ComputeBaseCase_(0, number_of_points_, 0, number_of_points_);
}
- else {
+ else
+ {
ComputeNeighborsRecursion_(tree_, tree_, DBL_MAX);
}
} // ComputeNeighbors_
-
- struct SortEdgesHelper_ {
- bool operator() (const EdgePair& pairA, const EdgePair& pairB) {
+ struct SortEdgesHelper_
+ {
+ bool operator() (const EdgePair& pairA, const EdgePair& pairB)
+ {
return (pairA.distance() < pairB.distance());
}
} SortFun;
- void SortEdges_() {
-
+ void SortEdges_()
+ {
std::sort(edges_.begin(), edges_.end(), SortFun);
-
} // SortEdges_()
/**
- * Unpermute the edge list and output it to results
+ * Unpermute the edge list and output it to results
*
* TODO: Make this sort the edge list by distance as well for hierarchical
* clusterings.
*/
- void EmitResults_(arma::mat& results) {
-
+ void EmitResults_(arma::mat& results)
+ {
SortEdges_();
mlpack::Log::Assert(number_of_edges_ == number_of_points_ - 1);
results.set_size(number_of_edges_, 3);
- // need to unpermute the point labels
- if (!do_naive_) {
- for (size_t i = 0; i < (number_of_points_ - 1); i++) {
-
+ // Need to unpermute the point labels.
+ if (!do_naive_)
+ {
+ for (size_t i = 0; i < (number_of_points_ - 1); i++)
+ {
// Make sure the edge list stores the smaller index first to
// make checking correctness easier
size_t ind1, ind2;
@@ -444,81 +446,79 @@
results(i, 0) = edges_[i].lesser_index();
results(i, 1) = edges_[i].greater_index();
results(i, 2) = sqrt(edges_[i].distance());
-
}
}
- else {
-
- for (size_t i = 0; i < number_of_edges_; i++) {
+ else
+ {
+ for (size_t i = 0; i < number_of_edges_; i++)
+ {
results(i, 0) = edges_[i].lesser_index();
results(i, 1) = edges_[i].greater_index();
results(i, 2) = sqrt(edges_[i].distance());
}
-
}
-
} // EmitResults_
-
-
/**
- * This function resets the values in the nodes of the tree
- * nearest neighbor distance, check for fully connected nodes
+ * This function resets the values in the nodes of the tree nearest neighbor
+ * distance, check for fully connected nodes
*/
- void CleanupHelper_(DTBTree* tree) {
-
+ void CleanupHelper_(DTBTree* tree)
+ {
tree->stat().set_max_neighbor_distance(DBL_MAX);
- if (!tree->is_leaf()) {
+ if (!tree->is_leaf())
+ {
CleanupHelper_(tree->left());
CleanupHelper_(tree->right());
if ((tree->left()->stat().component_membership() >= 0)
&& (tree->left()->stat().component_membership() ==
- tree->right()->stat().component_membership())) {
+ tree->right()->stat().component_membership()))
+ {
tree->stat().set_component_membership(tree->left()->stat().
- component_membership());
+ component_membership());
}
}
- else {
-
+ else
+ {
size_t new_membership = connections_.Find(tree->begin());
- for (size_t i = tree->begin(); i < tree->end(); i++) {
- if (new_membership != connections_.Find(i)) {
+ for (size_t i = tree->begin(); i < tree->end(); i++)
+ {
+ if (new_membership != connections_.Find(i))
+ {
new_membership = -1;
mlpack::Log::Assert(tree->stat().component_membership() < 0);
return;
}
}
tree->stat().set_component_membership(new_membership);
-
}
-
} // CleanupHelper_
/**
- * The values stored in the tree must be reset on each iteration.
+ * The values stored in the tree must be reset on each iteration.
*/
- void Cleanup_() {
-
- for (size_t i = 0; i < number_of_points_; i++) {
+ void Cleanup_()
+ {
+ for (size_t i = 0; i < number_of_points_; i++)
+ {
neighbors_distances_[i] = DBL_MAX;
- //DEBUG_ONLY(neighbors_in_component_[i] = BIG_BAD_NUMBER);
- //DEBUG_ONLY(neighbors_out_component_[i] = BIG_BAD_NUMBER);
}
number_of_loops_++;
- if (!do_naive_) {
+ if (!do_naive_)
+ {
CleanupHelper_(tree_);
}
}
/**
- * Format and output the results
+ * Format and output the results
*/
- void OutputResults_() {
-
+ void OutputResults_()
+ {
/* fx_result_double(module_, "total_squared_length", total_dist_);
fx_result_int(module_, "number_of_points", number_of_points_);
fx_result_int(module_, "dimension", data_points_.n_rows);
@@ -544,31 +544,28 @@
*/
mlpack::CLI::GetParam<double>("dtb/total_squared_length") = total_dist_;
-
} // OutputResults_
/////////// Public Functions ///////////////////
-
public:
-
- size_t number_of_edges() {
+ size_t number_of_edges()
+ {
return number_of_edges_;
}
-
/**
- * Takes in a reference to the data set. Copies the data,
- * builds the tree, and initializes all of the member variables.
- *
+ * Takes in a reference to the data set. Copies the data, builds the tree,
+ * and initializes all of the member variables.
*/
- void Init(const arma::mat& data) {
-
+ void Init(const arma::mat& data)
+ {
number_of_edges_ = 0;
data_points_ = data; // copy
do_naive_ = CLI::GetParam<bool>("naive/do_naive");
- if (!do_naive_) {
+ if (!do_naive_)
+ {
// Default leaf size is 1
// This gives best pruning empirically
// Use leaf_size=1 unless space is a big concern
@@ -580,13 +577,11 @@
tree_ = new DTBTree(data_points_, old_from_new_permutation_);
Timers::StopTimer("emst/tree_building");
-
}
- else {
-
+ else
+ {
tree_ = NULL;
old_from_new_permutation_.resize(0);
-
}
number_of_points_ = data_points_.n_cols;
@@ -608,24 +603,22 @@
number_both_recursions_ = 0;
} // Init
-
/**
* Call this function after Init. It will iteratively find the nearest
* neighbor of each component until the MST is complete.
*/
- void ComputeMST(arma::mat& results) {
-
+ void ComputeMST(arma::mat& results)
+ {
Timers::StartTimer("emst/MST_computation");
- while (number_of_edges_ < (number_of_points_ - 1)) {
-
+ while (number_of_edges_ < (number_of_points_ - 1))
+ {
ComputeNeighbors_();
AddAllEdges_();
Cleanup_();
-
Log::Info << "Finished loop number: " << number_of_loops_ << std::endl;
Log::Info << number_of_edges_ << " edges found so far.\n\n";
/*
@@ -636,7 +629,6 @@
Log::Info << number_q_recursions_ << " query recursions.\n";
Log::Info << number_both_recursions_ << " dual recursions.\n\n";
*/
-
}
Timers::StopTimer("emst/MST_computation");
@@ -644,10 +636,9 @@
EmitResults_(results);
OutputResults_();
-
} // ComputeMST
-}; //class DualTreeBoruvka
+}; // class DualTreeBoruvka
}; // namespace emst
}; // namespace mlpack
Modified: mlpack/trunk/src/mlpack/methods/emst/emst.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/emst.hpp 2011-11-23 22:06:03 UTC (rev 10362)
+++ mlpack/trunk/src/mlpack/methods/emst/emst.hpp 2011-11-23 22:22:16 UTC (rev 10363)
@@ -1,12 +1,11 @@
/**
-* @file emst.h
-*
-* @author Bill March (march at gatech.edu)
-*
-* This file contains utilities necessary for all of the minimum spanning tree
-* algorithms.
-*/
-
+ * @file emst.h
+ *
+ * @author Bill March (march at gatech.edu)
+ *
+ * This file contains utilities necessary for all of the minimum spanning tree
+ * algorithms.
+ */
#ifndef __MLPACK_METHODS_EMST_EMST_HPP
#define __MLPACK_METHODS_EMST_EMST_HPP
@@ -20,59 +19,61 @@
/**
* An edge pair is simply two indices and a distance. It is used as the
* basic element of an edge list when computing a minimum spanning tree.
-*/
-class EdgePair {
-
-private:
+ */
+class EdgePair
+{
+ private:
size_t lesser_index_;
size_t greater_index_;
double distance_;
-public:
+ public:
+ /**
+ * Initialize an EdgePair with two indices and a distance. The indices are
+ * called lesser and greater, implying that they be sorted before calling
+ * Init. However, this is not necessary for functionality; it is just a way
+ * to keep the edge list organized in other code.
+ */
+ void Init(size_t lesser, size_t greater, double dist)
+ {
+ mlpack::Log::Assert(lesser != greater,
+ "indices equal when creating EdgePair, lesser == greater");
+ lesser_index_ = lesser;
+ greater_index_ = greater;
+ distance_ = dist;
+ }
-
- /**
- * Initialize an EdgePair with two indices and a distance. The indices are
- * called lesser and greater, implying that they be sorted before calling
- * Init. However, this is not necessary for functionality; it is just a way
- * to keep the edge list organized in other code.
- */
- void Init(size_t lesser, size_t greater, double dist) {
-
- mlpack::Log::Assert(lesser != greater,
- "indices equal when creating EdgePair, lesser == greater");
- lesser_index_ = lesser;
- greater_index_ = greater;
- distance_ = dist;
-
- }
-
- size_t lesser_index() {
+ size_t lesser_index()
+ {
return lesser_index_;
}
- void set_lesser_index(size_t index) {
+ void set_lesser_index(size_t index)
+ {
lesser_index_ = index;
}
- size_t greater_index() {
+ size_t greater_index()
+ {
return greater_index_;
}
- void set_greater_index(size_t index) {
+ void set_greater_index(size_t index)
+ {
greater_index_ = index;
}
- double distance() const {
+ double distance() const
+ {
return distance_;
}
- void set_distance(double new_dist) {
+ void set_distance(double new_dist)
+ {
distance_ = new_dist;
}
+}; // class EdgePair
-};// class EdgePair
-
}; // namespace emst
}; // namespace mlpack
Modified: mlpack/trunk/src/mlpack/methods/emst/emst_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/emst_main.cpp 2011-11-23 22:06:03 UTC (rev 10362)
+++ mlpack/trunk/src/mlpack/methods/emst/emst_main.cpp 2011-11-23 22:22:16 UTC (rev 10363)
@@ -1,5 +1,5 @@
/**
-* @file emst.cc
+ * @file emst.cc
*
* Calls the DualTreeBoruvka algorithm from dtb.h
* Can optionally call Naive Boruvka's method
@@ -10,7 +10,7 @@
* In KDD, 2010.
*
* @author Bill March (march at gatech.edu)
-*/
+ */
#include "dtb.hpp"
@@ -27,11 +27,10 @@
using namespace mlpack;
using namespace mlpack::emst;
-int main(int argc, char* argv[]) {
-
+int main(int argc, char* argv[])
+{
CLI::ParseCommandLine(argc, argv);
-
///////////////// READ IN DATA //////////////////////////////////
std::string data_file_name = CLI::GetParam<std::string>("emst/input_file");
@@ -41,8 +40,8 @@
data::Load(data_file_name.c_str(), data_points, true);
// Do naive
- if (CLI::GetParam<bool>("naive/do_naive")) {
-
+ if (CLI::GetParam<bool>("naive/do_naive"))
+ {
Log::Info << "Running naive algorithm.\n";
DualTreeBoruvka naive;
@@ -58,8 +57,8 @@
data::Save(naive_output_filename.c_str(), naive_results, true);
}
- else {
-
+ else
+ {
Log::Info << "Data read, building tree.\n";
/////////////// Initialize DTB //////////////////////
@@ -83,5 +82,4 @@
}
return 0;
-
}
Modified: mlpack/trunk/src/mlpack/methods/emst/union_find.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/union_find.hpp 2011-11-23 22:06:03 UTC (rev 10362)
+++ mlpack/trunk/src/mlpack/methods/emst/union_find.hpp 2011-11-23 22:22:16 UTC (rev 10363)
@@ -8,7 +8,6 @@
* Calling unionfind.Union(x, y) unites the components indexed by x and y.
* unionfind.Find(x) returns the index of the component containing point x.
*/
-
#ifndef __MLPACK_METHODS_EMST_UNION_FIND_HPP
#define __MLPACK_METHODS_EMST_UNION_FIND_HPP
@@ -18,20 +17,16 @@
namespace emst {
/**
- * @class UnionFind
- *
- *A Union-Find data structure. See Cormen, Rivest, & Stein for details.
+ * A Union-Find data structure. See Cormen, Rivest, & Stein for details.
*/
-class UnionFind {
- friend class TestUnionFind;
-private:
-
+class UnionFind
+{
+ private:
arma::Col<size_t> parent_;
arma::ivec rank_;
size_t number_of_elements_;
-public:
-
+ public:
UnionFind() {}
~UnionFind() {}
@@ -42,17 +37,16 @@
*
* @param size The number of elements to be tracked.
*/
-
- void Init(size_t size) {
-
+ void Init(size_t size)
+ {
number_of_elements_ = size;
parent_.set_size(number_of_elements_);
rank_.set_size(number_of_elements_);
- for (size_t i = 0; i < number_of_elements_; i++) {
+ for (size_t i = 0; i < number_of_elements_; i++)
+ {
parent_[i] = i;
rank_[i] = 0;
}
-
}
/**
@@ -61,17 +55,18 @@
* @param x the component to be found
* @return The index of the component containing x
*/
- size_t Find(size_t x) {
-
- if (parent_[x] == x) {
+ size_t Find(size_t x)
+ {
+ if (parent_[x] == x)
+ {
return x;
}
- else {
+ else
+ {
// This ensures that the tree has a small depth
parent_[x] = Find(parent_[x]);
return parent_[x];
}
-
}
/**
@@ -82,29 +77,31 @@
* @param x one component
* @param y the other component
*/
- void Union(size_t x, size_t y) {
-
+ void Union(size_t x, size_t y)
+ {
size_t x_root = Find(x);
size_t y_root = Find(y);
- if (x_root == y_root) {
+ if (x_root == y_root)
+ {
return;
}
- else if (rank_[x_root] == rank_[y_root]) {
+ else if (rank_[x_root] == rank_[y_root])
+ {
parent_[y_root] = parent_[x_root];
rank_[x_root] = rank_[x_root] + 1;
}
- else if (rank_[x_root] > rank_[y_root]) {
+ else if (rank_[x_root] > rank_[y_root])
+ {
parent_[y_root] = x_root;
}
- else {
+ else
+ {
parent_[x_root] = y_root;
}
-
}
+}; // class UnionFind
-}; //class UnionFind
-
}; // namespace emst
}; // namespace mlpack
More information about the mlpack-svn mailing list