[mlpack-svn] r10039 - mlpack/trunk/src/mlpack/methods/emst
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Oct 26 12:45:00 EDT 2011
Author: jcline3
Date: 2011-10-26 12:44:59 -0400 (Wed, 26 Oct 2011)
New Revision: 10039
Added:
mlpack/trunk/src/mlpack/methods/emst/dtb.hpp
mlpack/trunk/src/mlpack/methods/emst/emst.hpp
mlpack/trunk/src/mlpack/methods/emst/emst_main.cpp
mlpack/trunk/src/mlpack/methods/emst/union_find.hpp
mlpack/trunk/src/mlpack/methods/emst/union_find_test.cpp
Removed:
mlpack/trunk/src/mlpack/methods/emst/dtb.h
mlpack/trunk/src/mlpack/methods/emst/emst.h
mlpack/trunk/src/mlpack/methods/emst/emst_main.cc
mlpack/trunk/src/mlpack/methods/emst/union_find.h
mlpack/trunk/src/mlpack/methods/emst/union_find_test.cc
Modified:
mlpack/trunk/src/mlpack/methods/emst/CMakeLists.txt
Log:
Move emst to hpp, cpp extensions
Change #ifndef, #defines
Modified: mlpack/trunk/src/mlpack/methods/emst/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/CMakeLists.txt 2011-10-26 16:28:23 UTC (rev 10038)
+++ mlpack/trunk/src/mlpack/methods/emst/CMakeLists.txt 2011-10-26 16:44:59 UTC (rev 10039)
@@ -4,10 +4,10 @@
# Anything not in this list will not be compiled into MLPACK.
set(SOURCES
# union_find
- union_find.h
+ union_find.hpp
# dtb
- dtb.h
- emst.h
+ dtb.hpp
+ emst.hpp
)
# Add directory name to sources.
@@ -20,14 +20,14 @@
set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
add_executable(emst
- emst_main.cc
+ emst_main.cpp
)
target_link_libraries(emst
mlpack
)
add_executable(union_find_test
- union_find_test.cc
+ union_find_test.cpp
)
target_link_libraries(union_find_test
mlpack
Deleted: mlpack/trunk/src/mlpack/methods/emst/dtb.h
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb.h 2011-10-26 16:28:23 UTC (rev 10038)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb.h 2011-10-26 16:44:59 UTC (rev 10039)
@@ -1,644 +0,0 @@
-/**
-* @file dtb.h
- *
- * @author Bill March (march at gatech.edu)
- *
- * Contains an implementation of the DualTreeBoruvka algorithm for finding a
- * Euclidean Minimum Spanning Tree.
- *
- * Citation: March, W. B.; Ram, P.; and Gray, A. G. Fast Euclidean Minimum Spanning
- * Tree: Algorithm, Analysis, Applications. In KDD, 2010.
- *
- */
-
-#ifndef DTB_H
-#define DTB_H
-
-#include "emst.h"
-
-#include <mlpack/core.h>
-#include <mlpack/core/tree/bounds.hpp>
-#include <mlpack/core/tree/spacetree.hpp>
-#include <mlpack/core/kernels/lmetric.hpp>
-
-namespace mlpack {
-namespace emst {
-
-/*
-const fx_submodule_doc dtb_submodules[] = {
-FX_SUBMODULE_DOC_DONE
-};
- */
-
-/**
-* A Stat class for use with fastlib's trees. This one only stores two values.
- *
- * @param max_neighbor_distance The upper bound on the distance to the nearest
- * neighbor of any point in this node.
- *
- * @param component_membership The index of the component that all points in
- * this node belong to. This is the same index returned by UnionFind for all
- * points in this node. If points in this node are in different components,
- * this value will be negative.
- */
-class DTBStat {
- private:
- double max_neighbor_distance_;
- size_t component_membership_;
-
- public:
- void set_max_neighbor_distance(double distance) {
- max_neighbor_distance_ = distance;
- }
-
- double max_neighbor_distance() {
- return max_neighbor_distance_;
- }
-
- void set_component_membership(size_t membership) {
- component_membership_ = membership;
- }
-
- size_t component_membership() {
- return component_membership_;
- }
-
- /**
- * A generic initializer.
- */
- DTBStat() {
- set_max_neighbor_distance(DBL_MAX);
- set_component_membership(-1);
- }
-
- /**
- * An initializer for leaves.
- */
- DTBStat(const arma::mat& dataset, size_t start, size_t count) {
- if (count == 1) {
- set_component_membership(start);
- set_max_neighbor_distance(DBL_MAX);
- } else {
- set_max_neighbor_distance(DBL_MAX);
- set_component_membership(-1);
- }
- }
-
- /**
- * An initializer for non-leaves. Simply calls the leaf initializer.
- */
- DTBStat(const arma::mat& dataset, size_t start, size_t count,
- const DTBStat& left_stat, const DTBStat& right_stat) {
- if (count == 1) {
- set_component_membership(start);
- set_max_neighbor_distance(DBL_MAX);
- } else {
- set_max_neighbor_distance(DBL_MAX);
- set_component_membership(-1);
- }
- }
-
-}; // class DTBStat
-
-
-/**
- * Performs the MST calculation using the Dual-Tree Boruvka algorithm.
- */
-class DualTreeBoruvka {
-
-// FORBID_ACCIDENTAL_COPIES(DualTreeBoruvka);
-
- public:
- // For now, everything is in Euclidean space
- static const size_t metric = 2;
-
- typedef tree::BinarySpaceTree<bound::HRectBound<metric>, DTBStat> DTBTree;
-
- //////// Member Variables /////////////////////
-
- private:
-
- size_t number_of_edges_;
- std::vector<EdgePair> edges_; // must use vector with non-numerical types
- size_t number_of_points_;
- UnionFind connections_;
- struct datanode* module_;
- arma::mat data_points_;
- size_t leaf_size_;
-
- // lists
- std::vector<size_t> old_from_new_permutation_;
- arma::Col<size_t> neighbors_in_component_;
- arma::Col<size_t> neighbors_out_component_;
- arma::vec neighbors_distances_;
-
- // output info
- double total_dist_;
- size_t number_of_loops_;
- size_t number_distance_prunes_;
- size_t number_component_prunes_;
- size_t number_leaf_computations_;
- size_t number_q_recursions_;
- size_t number_r_recursions_;
- size_t number_both_recursions_;
-
- int do_naive_;
-
- DTBTree* tree_;
-
-
-////////////////// Constructors ////////////////////////
-
- public:
-
- DualTreeBoruvka() {}
-
- ~DualTreeBoruvka() {
- if (tree_ != NULL) {
- delete tree_;
- }
- }
-
-
- ////////////////////////// Private Functions ////////////////////
- private:
-
- /**
- * Adds a single edge to the edge list
- */
- void AddEdge_(size_t e1, size_t e2, double distance) {
-
- //EdgePair edge;
- mlpack::Log::Assert((e1 != e2),
- "Indices are equal in DualTreeBoruvka.add_edge(...)");
-
- mlpack::Log::Assert((distance >= 0.0),
- "Negative distance input in DualTreeBoruvka.add_edge(...)");
-
- if (e1 < e2) {
- edges_[number_of_edges_].Init(e1, e2, distance);
- }
- else {
- edges_[number_of_edges_].Init(e2, e1, distance);
- }
-
- number_of_edges_++;
-
- } // AddEdge_
-
-
- /**
- * Adds all the edges found in one iteration to the list of neighbors.
- */
- void AddAllEdges_() {
-
- for (size_t i = 0; i < number_of_points_; i++) {
- size_t component_i = connections_.Find(i);
- size_t in_edge_i = neighbors_in_component_[component_i];
- size_t out_edge_i = neighbors_out_component_[component_i];
- if (connections_.Find(in_edge_i) != connections_.Find(out_edge_i)) {
- double dist = neighbors_distances_[component_i];
- //total_dist_ = total_dist_ + dist;
- // changed to make this agree with the cover tree code
- total_dist_ = total_dist_ + sqrt(dist);
- AddEdge_(in_edge_i, out_edge_i, dist);
- connections_.Union(in_edge_i, out_edge_i);
- }
- }
-
- } // AddAllEdges_
-
-
- /**
- * Handles the base case computation. Also called by naive.
- */
- double ComputeBaseCase_(size_t query_start, size_t query_end,
- size_t reference_start, size_t reference_end) {
-
- number_leaf_computations_++;
-
- double new_upper_bound = -1.0;
-
- for (size_t query_index = query_start; query_index < query_end;
- query_index++) {
-
- // Find the index of the component the query is in
- size_t query_component_index = connections_.Find(query_index);
-
- arma::vec query_point = data_points_.col(query_index);
-
- for (size_t reference_index = reference_start;
- reference_index < reference_end; reference_index++) {
-
- size_t reference_component_index = connections_.Find(reference_index);
-
- if (query_component_index != reference_component_index) {
-
- arma::vec reference_point = data_points_.col(reference_index);
-
- double distance = mlpack::kernel::LMetric<2>::Evaluate(query_point,
- reference_point);
-
- if (distance < neighbors_distances_[query_component_index]) {
-
- mlpack::Log::Assert(query_index != reference_index);
-
- neighbors_distances_[query_component_index] = distance;
- neighbors_in_component_[query_component_index] = query_index;
- neighbors_out_component_[query_component_index] = reference_index;
-
- } // if distance
-
- } // if indices not equal
-
-
- } // for reference_index
-
- if (new_upper_bound < neighbors_distances_[query_component_index]) {
- new_upper_bound = neighbors_distances_[query_component_index];
- }
-
- } // for query_index
-
- mlpack::Log::Assert(new_upper_bound >= 0.0);
- return new_upper_bound;
-
- } // ComputeBaseCase_
-
-
- /**
- * Handles the recursive calls to find the nearest neighbors in an iteration
- */
- void ComputeNeighborsRecursion_(DTBTree *query_node, DTBTree *reference_node,
- double incoming_distance) {
-
- // Check for a distance prune
- if (query_node->stat().max_neighbor_distance() < incoming_distance) {
- //pruned by distance
- number_distance_prunes_++;
- }
- // Check for a component prune
- else if ((query_node->stat().component_membership() >= 0)
- && (query_node->stat().component_membership() ==
- reference_node->stat().component_membership())) {
- //pruned by component membership
-
- mlpack::Log::Assert(reference_node->stat().component_membership() >= 0);
-
- number_component_prunes_++;
- }
- // The base case
- else if (query_node->is_leaf() && reference_node->is_leaf()) {
-
- double new_bound =
- ComputeBaseCase_(query_node->begin(), query_node->end(),
- reference_node->begin(), reference_node->end());
-
- query_node->stat().set_max_neighbor_distance(new_bound);
-
- }
- // Other recursive calls
- else if (query_node->is_leaf()) {
- //recurse on reference_node only
- number_r_recursions_++;
-
- double left_dist =
- query_node->bound().MinDistance(reference_node->left()->bound());
- double right_dist =
- query_node->bound().MinDistance(reference_node->right()->bound());
- mlpack::Log::Assert(left_dist >= 0.0);
- mlpack::Log::Assert(right_dist >= 0.0);
-
- if (left_dist < right_dist) {
- ComputeNeighborsRecursion_(query_node,
- reference_node->left(), left_dist);
- ComputeNeighborsRecursion_(query_node,
- reference_node->right(), right_dist);
- }
- else {
- ComputeNeighborsRecursion_(query_node,
- reference_node->right(), right_dist);
- ComputeNeighborsRecursion_(query_node,
- reference_node->left(), left_dist);
- }
-
- }
- else if (reference_node->is_leaf()) {
- //recurse on query_node only
-
- number_q_recursions_++;
-
- double left_dist =
- query_node->left()->bound().MinDistance(reference_node->bound());
- double right_dist =
- query_node->right()->bound().MinDistance(reference_node->bound());
-
- ComputeNeighborsRecursion_(query_node->left(),
- reference_node, left_dist);
- ComputeNeighborsRecursion_(query_node->right(),
- reference_node, right_dist);
-
- // Update query_node's stat
- query_node->stat().set_max_neighbor_distance(
- std::max(query_node->left()->stat().max_neighbor_distance(),
- query_node->right()->stat().max_neighbor_distance()));
-
- }
- else {
- //recurse on both
-
- number_both_recursions_++;
-
- double left_dist = query_node->left()->bound().
- MinDistance(reference_node->left()->bound());
- double right_dist = query_node->left()->bound().
- MinDistance(reference_node->right()->bound());
-
- if (left_dist < right_dist) {
- ComputeNeighborsRecursion_(query_node->left(),
- reference_node->left(), left_dist);
- ComputeNeighborsRecursion_(query_node->left(),
- reference_node->right(), right_dist);
- }
- else {
- ComputeNeighborsRecursion_(query_node->left(),
- reference_node->right(), right_dist);
- ComputeNeighborsRecursion_(query_node->left(),
- reference_node->left(), left_dist);
- }
-
- left_dist = query_node->right()->bound().
- MinDistance(reference_node->left()->bound());
- right_dist = query_node->right()->bound().
- MinDistance(reference_node->right()->bound());
-
- if (left_dist < right_dist) {
- ComputeNeighborsRecursion_(query_node->right(),
- reference_node->left(), left_dist);
- ComputeNeighborsRecursion_(query_node->right(),
- reference_node->right(), right_dist);
- }
- else {
- ComputeNeighborsRecursion_(query_node->right(),
- reference_node->right(), right_dist);
- ComputeNeighborsRecursion_(query_node->right(),
- reference_node->left(), left_dist);
- }
-
- query_node->stat().set_max_neighbor_distance(
- std::max(query_node->left()->stat().max_neighbor_distance(),
- query_node->right()->stat().max_neighbor_distance()));
-
- }// end else
-
- } // ComputeNeighborsRecursion_
-
- /**
- * Computes the nearest neighbor of each point in each iteration
- * of the algorithm
- */
- void ComputeNeighbors_() {
- if (do_naive_) {
- ComputeBaseCase_(0, number_of_points_, 0, number_of_points_);
- }
- else {
- ComputeNeighborsRecursion_(tree_, tree_, DBL_MAX);
- }
- } // ComputeNeighbors_
-
-
- struct SortEdgesHelper_ {
- bool operator() (const EdgePair& pairA, const EdgePair& pairB) {
- return (pairA.distance() > pairB.distance());
- }
- } SortFun;
-
- void SortEdges_() {
-
- std::sort(edges_.begin(), edges_.end(), SortFun);
-
- } // SortEdges_()
-
- /**
- * Unpermute the edge list and output it to results
- *
- * TODO: Make this sort the edge list by distance as well for hierarchical
- * clusterings.
- */
- void EmitResults_(arma::mat& results) {
-
- SortEdges_();
-
- mlpack::Log::Assert(number_of_edges_ == number_of_points_ - 1);
- results.set_size(3, number_of_edges_);
-
- if (!do_naive_) {
- for (size_t i = 0; i < (number_of_points_ - 1); i++) {
-
- edges_[i].set_lesser_index(old_from_new_permutation_[edges_[i]
- .lesser_index()]);
-
- edges_[i].set_greater_index(old_from_new_permutation_[edges_[i]
- .greater_index()]);
-
- results(0, i) = edges_[i].lesser_index();
- results(1, i) = edges_[i].greater_index();
- results(2, i) = sqrt(edges_[i].distance());
-
- }
- }
- else {
-
- for (size_t i = 0; i < number_of_edges_; i++) {
- results(0, i) = edges_[i].lesser_index();
- results(1, i) = edges_[i].greater_index();
- results(2, i) = sqrt(edges_[i].distance());
- }
-
- }
-
- } // EmitResults_
-
-
-
- /**
- * This function resets the values in the nodes of the tree
- * nearest neighbor distance, check for fully connected nodes
- */
- void CleanupHelper_(DTBTree* tree) {
-
- tree->stat().set_max_neighbor_distance(DBL_MAX);
-
- if (!tree->is_leaf()) {
- CleanupHelper_(tree->left());
- CleanupHelper_(tree->right());
-
- if ((tree->left()->stat().component_membership() >= 0)
- && (tree->left()->stat().component_membership() ==
- tree->right()->stat().component_membership())) {
- tree->stat().set_component_membership(tree->left()->stat().
- component_membership());
- }
- }
- else {
-
- size_t new_membership = connections_.Find(tree->begin());
-
- for (size_t i = tree->begin(); i < tree->end(); i++) {
- if (new_membership != connections_.Find(i)) {
- new_membership = -1;
- mlpack::Log::Assert(tree->stat().component_membership() < 0);
- return;
- }
- }
- tree->stat().set_component_membership(new_membership);
-
- }
-
- } // CleanupHelper_
-
- /**
- * The values stored in the tree must be reset on each iteration.
- */
- void Cleanup_() {
-
- for (size_t i = 0; i < number_of_points_; i++) {
- neighbors_distances_[i] = DBL_MAX;
- //DEBUG_ONLY(neighbors_in_component_[i] = BIG_BAD_NUMBER);
- //DEBUG_ONLY(neighbors_out_component_[i] = BIG_BAD_NUMBER);
- }
- number_of_loops_++;
-
- if (!do_naive_) {
- CleanupHelper_(tree_);
- }
- }
-
- /**
- * Format and output the results
- */
- void OutputResults_() {
-
- /* fx_result_double(module_, "total_squared_length", total_dist_);
- fx_result_int(module_, "number_of_points", number_of_points_);
- fx_result_int(module_, "dimension", data_points_.n_rows);
- fx_result_int(module_, "number_of_loops", number_of_loops_);
- fx_result_int(module_, "number_distance_prunes", number_distance_prunes_);
- fx_result_int(module_, "number_component_prunes", number_component_prunes_);
- fx_result_int(module_, "number_leaf_computations", number_leaf_computations_);
- fx_result_int(module_, "number_q_recursions", number_q_recursions_);
- fx_result_int(module_, "number_r_recursions", number_r_recursions_);
- fx_result_int(module_, "number_both_recursions", number_both_recursions_);*/
- // TODO, not sure how I missed this last time.
- mlpack::Log::Info << "total_squared_length" << total_dist_ << std::endl;
- mlpack::Log::Info << "number_of_points" << number_of_points_ << std::endl;
- mlpack::Log::Info << "dimension" << data_points_.n_rows << std::endl;
- mlpack::Log::Info << "number_of_loops" << std::endl;
- mlpack::Log::Info << "number_distance_prunes" << std::endl;
- mlpack::Log::Info << "number_component_prunes" << std::endl;
- mlpack::Log::Info << "number_leaf_computations" << std::endl;
- mlpack::Log::Info << "number_q_recursions" << std::endl;
- mlpack::Log::Info << "number_r_recursions" << std::endl;
- mlpack::Log::Info << "number_both_recursions" << std::endl;
- } // OutputResults_
-
- /////////// Public Functions ///////////////////
-
- public:
-
- size_t number_of_edges() {
- return number_of_edges_;
- }
-
-
- /**
- * Takes in a reference to the data set and a module. Copies the data,
- * builds the tree, and initializes all of the member variables.
- *
- * This module will be checked for the optional parameters "leaf_size" and
- * "do_naive".
- */
- void Init(const arma::mat& data) {
-
- number_of_edges_ = 0;
- data_points_ = data; // copy
-
- do_naive_ = CLI::GetParam<bool>("naive/do_naive");
-
- if (!do_naive_) {
- // Default leaf size is 1
- // This gives best pruning empirically
- // Use leaf_size=1 unless space is a big concern
- CLI::GetParam<int>("tree/leaf_size") =
- CLI::GetParam<size_t>("naive/leaf_size");
-
- CLI::StartTimer("naive/tree_building");
-
- tree_ = new DTBTree(data_points_, old_from_new_permutation_);
-
- CLI::StopTimer("naive/tree_building");
- }
- else {
- tree_ = NULL;
- old_from_new_permutation_.resize(0);
- }
-
- number_of_points_ = data_points_.n_cols;
- edges_.resize(number_of_points_ - 1, EdgePair()); // fill with EdgePairs
- connections_.Init(number_of_points_);
-
- neighbors_in_component_.set_size(number_of_points_);
- neighbors_out_component_.set_size(number_of_points_);
- neighbors_distances_.set_size(number_of_points_);
- neighbors_distances_.fill(DBL_MAX);
-
- total_dist_ = 0.0;
- number_of_loops_ = 0;
- number_distance_prunes_ = 0;
- number_component_prunes_ = 0;
- number_leaf_computations_ = 0;
- number_q_recursions_ = 0;
- number_r_recursions_ = 0;
- number_both_recursions_ = 0;
- } // Init
-
-
- /**
- * Call this function after Init. It will iteratively find the nearest
- * neighbor of each component until the MST is complete.
- */
- void ComputeMST(arma::mat& results) {
-
- CLI::StartTimer("emst/MST_computation");
-
- while (number_of_edges_ < (number_of_points_ - 1)) {
- ComputeNeighbors_();
-
- AddAllEdges_();
-
- Cleanup_();
-
- Log::Info << "number_of_loops = " << number_of_loops_ << std::endl;
- }
-
- CLI::StopTimer("emst/MST_computation");
-
-// if (results != NULL) {
-
- EmitResults_(results);
-
-// }
-
-
- OutputResults_();
-
- } // ComputeMST
-
-}; //class DualTreeBoruvka
-
-}; // namespace emst
-}; // namespace mlpack
-
-PARAM(size_t, "leaf_size", "Size of the leaves.", "naive", 1, false);
-
-#endif // inclusion guards
Copied: mlpack/trunk/src/mlpack/methods/emst/dtb.hpp (from rev 10033, mlpack/trunk/src/mlpack/methods/emst/dtb.h)
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb.hpp 2011-10-26 16:44:59 UTC (rev 10039)
@@ -0,0 +1,644 @@
+/**
+ * @file dtb.hpp
+ *
+ * @author Bill March (march at gatech.edu)
+ *
+ * Contains an implementation of the DualTreeBoruvka algorithm for finding a
+ * Euclidean Minimum Spanning Tree.
+ *
+ * Citation: March, W. B.; Ram, P.; and Gray, A. G. Fast Euclidean Minimum Spanning
+ * Tree: Algorithm, Analysis, Applications. In KDD, 2010.
+ *
+ */
+
+#ifndef __MLPACK_METHODS_EMST_DTB_HPP
+#define __MLPACK_METHODS_EMST_DTB_HPP
+
+#include "emst.hpp"
+
+#include <mlpack/core.h>
+#include <mlpack/core/tree/bounds.hpp>
+#include <mlpack/core/tree/spacetree.hpp>
+#include <mlpack/core/kernels/lmetric.hpp>
+
+PARAM(size_t, "leaf_size", "Size of the leaves.", "naive", 1, false);
+
+namespace mlpack {
+namespace emst {
+
+/*
+const fx_submodule_doc dtb_submodules[] = {
+FX_SUBMODULE_DOC_DONE
+};
+ */
+
+/**
+ * A Stat class for use with the BinarySpaceTree built by DualTreeBoruvka.
+ * This one only stores two values.
+ *
+ * @param max_neighbor_distance The upper bound on the distance to the nearest
+ * neighbor of any point in this node.
+ *
+ * @param component_membership The index of the component that all points in
+ * this node belong to. This is the same index returned by UnionFind for all
+ * points in this node. If points in this node are in different components,
+ * this value is -1. It is kept in a signed type: in a size_t, -1 wraps to the
+ * largest index, so ">= 0" always holds and any two mixed nodes compare equal.
+ */
+class DTBStat {
+ private:
+  double max_neighbor_distance_;
+  ptrdiff_t component_membership_; // signed so that -1 stays negative
+
+ public:
+  void set_max_neighbor_distance(double distance) {
+    max_neighbor_distance_ = distance;
+  }
+
+  double max_neighbor_distance() const {
+    return max_neighbor_distance_;
+  }
+
+  /** Accepts a component index, or -1 for "points in different components". */
+  void set_component_membership(ptrdiff_t membership) {
+    component_membership_ = membership;
+  }
+
+  /** @return The component index, or a negative value for a mixed node. */
+  ptrdiff_t component_membership() const {
+    return component_membership_;
+  }
+
+  /**
+   * A generic initializer: maximal distance bound, no single component (-1).
+   */
+  DTBStat() {
+    set_max_neighbor_distance(DBL_MAX);
+    set_component_membership(-1);
+  }
+
+  /**
+   * An initializer for leaves. When count is 1 the node starts in component
+   * start; otherwise it starts with no single component (-1).
+   *
+   * @param dataset Not read.
+   * @param start Becomes the component index when count is 1.
+   * @param count Only the case count == 1 is treated specially.
+   */
+  DTBStat(const arma::mat& dataset, size_t start, size_t count) {
+    set_max_neighbor_distance(DBL_MAX);
+    set_component_membership((count == 1) ? static_cast<ptrdiff_t>(start) : -1);
+  }
+
+  /**
+   * An initializer for non-leaves. It applies the same rule as the leaf
+   * initializer; left_stat and right_stat are not consulted.
+   */
+  DTBStat(const arma::mat& dataset, size_t start, size_t count,
+      const DTBStat& left_stat, const DTBStat& right_stat) {
+    set_max_neighbor_distance(DBL_MAX);
+    set_component_membership((count == 1) ? static_cast<ptrdiff_t>(start) : -1);
+  }
+
+}; // class DTBStat
+
+
+/**
+ * Performs the MST calculation using the Dual-Tree Boruvka algorithm.
+ */
+class DualTreeBoruvka {
+
+// FORBID_ACCIDENTAL_COPIES(DualTreeBoruvka);
+
+ public:
+ // For now, everything is in Euclidean space
+ static const size_t metric = 2;
+
+ typedef tree::BinarySpaceTree<bound::HRectBound<metric>, DTBStat> DTBTree;
+
+ //////// Member Variables /////////////////////
+
+ private:
+
+ size_t number_of_edges_;
+ std::vector<EdgePair> edges_; // must use vector with non-numerical types
+ size_t number_of_points_;
+ UnionFind connections_;
+ struct datanode* module_;
+ arma::mat data_points_;
+ size_t leaf_size_;
+
+ // lists
+ std::vector<size_t> old_from_new_permutation_;
+ arma::Col<size_t> neighbors_in_component_;
+ arma::Col<size_t> neighbors_out_component_;
+ arma::vec neighbors_distances_;
+
+ // output info
+ double total_dist_;
+ size_t number_of_loops_;
+ size_t number_distance_prunes_;
+ size_t number_component_prunes_;
+ size_t number_leaf_computations_;
+ size_t number_q_recursions_;
+ size_t number_r_recursions_;
+ size_t number_both_recursions_;
+
+ int do_naive_;
+
+ DTBTree* tree_;
+
+
+////////////////// Constructors ////////////////////////
+
+ public:
+
+ DualTreeBoruvka() {}
+
+ ~DualTreeBoruvka() {
+ if (tree_ != NULL) {
+ delete tree_;
+ }
+ }
+
+
+ ////////////////////////// Private Functions ////////////////////
+ private:
+
+ /**
+ * Adds a single edge to the edge list
+ */
+ void AddEdge_(size_t e1, size_t e2, double distance) {
+
+ //EdgePair edge;
+ mlpack::Log::Assert((e1 != e2),
+ "Indices are equal in DualTreeBoruvka.add_edge(...)");
+
+ mlpack::Log::Assert((distance >= 0.0),
+ "Negative distance input in DualTreeBoruvka.add_edge(...)");
+
+ if (e1 < e2) {
+ edges_[number_of_edges_].Init(e1, e2, distance);
+ }
+ else {
+ edges_[number_of_edges_].Init(e2, e1, distance);
+ }
+
+ number_of_edges_++;
+
+ } // AddEdge_
+
+
+ /**
+ * Adds all the edges found in one iteration to the list of neighbors.
+ */
+ void AddAllEdges_() {
+
+ for (size_t i = 0; i < number_of_points_; i++) {
+ size_t component_i = connections_.Find(i);
+ size_t in_edge_i = neighbors_in_component_[component_i];
+ size_t out_edge_i = neighbors_out_component_[component_i];
+ if (connections_.Find(in_edge_i) != connections_.Find(out_edge_i)) {
+ double dist = neighbors_distances_[component_i];
+ //total_dist_ = total_dist_ + dist;
+ // changed to make this agree with the cover tree code
+ total_dist_ = total_dist_ + sqrt(dist);
+ AddEdge_(in_edge_i, out_edge_i, dist);
+ connections_.Union(in_edge_i, out_edge_i);
+ }
+ }
+
+ } // AddAllEdges_
+
+
+ /**
+ * Handles the base case computation. Also called by naive.
+ */
+ double ComputeBaseCase_(size_t query_start, size_t query_end,
+ size_t reference_start, size_t reference_end) {
+
+ number_leaf_computations_++;
+
+ double new_upper_bound = -1.0;
+
+ for (size_t query_index = query_start; query_index < query_end;
+ query_index++) {
+
+ // Find the index of the component the query is in
+ size_t query_component_index = connections_.Find(query_index);
+
+ arma::vec query_point = data_points_.col(query_index);
+
+ for (size_t reference_index = reference_start;
+ reference_index < reference_end; reference_index++) {
+
+ size_t reference_component_index = connections_.Find(reference_index);
+
+ if (query_component_index != reference_component_index) {
+
+ arma::vec reference_point = data_points_.col(reference_index);
+
+ double distance = mlpack::kernel::LMetric<2>::Evaluate(query_point,
+ reference_point);
+
+ if (distance < neighbors_distances_[query_component_index]) {
+
+ mlpack::Log::Assert(query_index != reference_index);
+
+ neighbors_distances_[query_component_index] = distance;
+ neighbors_in_component_[query_component_index] = query_index;
+ neighbors_out_component_[query_component_index] = reference_index;
+
+ } // if distance
+
+ } // if indices not equal
+
+
+ } // for reference_index
+
+ if (new_upper_bound < neighbors_distances_[query_component_index]) {
+ new_upper_bound = neighbors_distances_[query_component_index];
+ }
+
+ } // for query_index
+
+ mlpack::Log::Assert(new_upper_bound >= 0.0);
+ return new_upper_bound;
+
+ } // ComputeBaseCase_
+
+
+ /**
+ * Handles the recursive calls to find the nearest neighbors in an iteration
+ */
+ void ComputeNeighborsRecursion_(DTBTree *query_node, DTBTree *reference_node,
+ double incoming_distance) {
+
+ // Check for a distance prune
+ if (query_node->stat().max_neighbor_distance() < incoming_distance) {
+ //pruned by distance
+ number_distance_prunes_++;
+ }
+ // Check for a component prune
+ else if ((query_node->stat().component_membership() >= 0)
+ && (query_node->stat().component_membership() ==
+ reference_node->stat().component_membership())) {
+ //pruned by component membership
+
+ mlpack::Log::Assert(reference_node->stat().component_membership() >= 0);
+
+ number_component_prunes_++;
+ }
+ // The base case
+ else if (query_node->is_leaf() && reference_node->is_leaf()) {
+
+ double new_bound =
+ ComputeBaseCase_(query_node->begin(), query_node->end(),
+ reference_node->begin(), reference_node->end());
+
+ query_node->stat().set_max_neighbor_distance(new_bound);
+
+ }
+ // Other recursive calls
+ else if (query_node->is_leaf()) {
+ //recurse on reference_node only
+ number_r_recursions_++;
+
+ double left_dist =
+ query_node->bound().MinDistance(reference_node->left()->bound());
+ double right_dist =
+ query_node->bound().MinDistance(reference_node->right()->bound());
+ mlpack::Log::Assert(left_dist >= 0.0);
+ mlpack::Log::Assert(right_dist >= 0.0);
+
+ if (left_dist < right_dist) {
+ ComputeNeighborsRecursion_(query_node,
+ reference_node->left(), left_dist);
+ ComputeNeighborsRecursion_(query_node,
+ reference_node->right(), right_dist);
+ }
+ else {
+ ComputeNeighborsRecursion_(query_node,
+ reference_node->right(), right_dist);
+ ComputeNeighborsRecursion_(query_node,
+ reference_node->left(), left_dist);
+ }
+
+ }
+ else if (reference_node->is_leaf()) {
+ //recurse on query_node only
+
+ number_q_recursions_++;
+
+ double left_dist =
+ query_node->left()->bound().MinDistance(reference_node->bound());
+ double right_dist =
+ query_node->right()->bound().MinDistance(reference_node->bound());
+
+ ComputeNeighborsRecursion_(query_node->left(),
+ reference_node, left_dist);
+ ComputeNeighborsRecursion_(query_node->right(),
+ reference_node, right_dist);
+
+ // Update query_node's stat
+ query_node->stat().set_max_neighbor_distance(
+ std::max(query_node->left()->stat().max_neighbor_distance(),
+ query_node->right()->stat().max_neighbor_distance()));
+
+ }
+ else {
+ //recurse on both
+
+ number_both_recursions_++;
+
+ double left_dist = query_node->left()->bound().
+ MinDistance(reference_node->left()->bound());
+ double right_dist = query_node->left()->bound().
+ MinDistance(reference_node->right()->bound());
+
+ if (left_dist < right_dist) {
+ ComputeNeighborsRecursion_(query_node->left(),
+ reference_node->left(), left_dist);
+ ComputeNeighborsRecursion_(query_node->left(),
+ reference_node->right(), right_dist);
+ }
+ else {
+ ComputeNeighborsRecursion_(query_node->left(),
+ reference_node->right(), right_dist);
+ ComputeNeighborsRecursion_(query_node->left(),
+ reference_node->left(), left_dist);
+ }
+
+ left_dist = query_node->right()->bound().
+ MinDistance(reference_node->left()->bound());
+ right_dist = query_node->right()->bound().
+ MinDistance(reference_node->right()->bound());
+
+ if (left_dist < right_dist) {
+ ComputeNeighborsRecursion_(query_node->right(),
+ reference_node->left(), left_dist);
+ ComputeNeighborsRecursion_(query_node->right(),
+ reference_node->right(), right_dist);
+ }
+ else {
+ ComputeNeighborsRecursion_(query_node->right(),
+ reference_node->right(), right_dist);
+ ComputeNeighborsRecursion_(query_node->right(),
+ reference_node->left(), left_dist);
+ }
+
+ query_node->stat().set_max_neighbor_distance(
+ std::max(query_node->left()->stat().max_neighbor_distance(),
+ query_node->right()->stat().max_neighbor_distance()));
+
+ }// end else
+
+ } // ComputeNeighborsRecursion_
+
+ /**
+ * Computes the nearest neighbor of each point in each iteration
+ * of the algorithm
+ */
+ void ComputeNeighbors_() {
+ if (do_naive_) {
+ ComputeBaseCase_(0, number_of_points_, 0, number_of_points_);
+ }
+ else {
+ ComputeNeighborsRecursion_(tree_, tree_, DBL_MAX);
+ }
+ } // ComputeNeighbors_
+
+
+ struct SortEdgesHelper_ {
+ bool operator() (const EdgePair& pairA, const EdgePair& pairB) {
+ return (pairA.distance() > pairB.distance());
+ }
+ } SortFun;
+
+ void SortEdges_() {
+
+ std::sort(edges_.begin(), edges_.end(), SortFun);
+
+ } // SortEdges_()
+
+ /**
+ * Unpermute the edge list and output it to results
+ *
+ * TODO: Make this sort the edge list by distance as well for hierarchical
+ * clusterings.
+ */
+ void EmitResults_(arma::mat& results) {
+
+ SortEdges_();
+
+ mlpack::Log::Assert(number_of_edges_ == number_of_points_ - 1);
+ results.set_size(3, number_of_edges_);
+
+ if (!do_naive_) {
+ for (size_t i = 0; i < (number_of_points_ - 1); i++) {
+
+ edges_[i].set_lesser_index(old_from_new_permutation_[edges_[i]
+ .lesser_index()]);
+
+ edges_[i].set_greater_index(old_from_new_permutation_[edges_[i]
+ .greater_index()]);
+
+ results(0, i) = edges_[i].lesser_index();
+ results(1, i) = edges_[i].greater_index();
+ results(2, i) = sqrt(edges_[i].distance());
+
+ }
+ }
+ else {
+
+ for (size_t i = 0; i < number_of_edges_; i++) {
+ results(0, i) = edges_[i].lesser_index();
+ results(1, i) = edges_[i].greater_index();
+ results(2, i) = sqrt(edges_[i].distance());
+ }
+
+ }
+
+ } // EmitResults_
+
+
+
+ /**
+ * Recursively resets max_neighbor_distance to DBL_MAX on tree and all of
+ * its descendants, and records a component membership on nodes whose
+ * points (leaf) or two children (internal node) all share one component.
+ */
+ void CleanupHelper_(DTBTree* tree) {
+ tree->stat().set_max_neighbor_distance(DBL_MAX);
+
+ if (!tree->is_leaf()) {
+ CleanupHelper_(tree->left());
+ CleanupHelper_(tree->right());
+ // Adopt the children's label only when both carry the same value >= 0.
+ if ((tree->left()->stat().component_membership() >= 0)
+ && (tree->left()->stat().component_membership() ==
+ tree->right()->stat().component_membership())) {
+ tree->stat().set_component_membership(tree->left()->stat().
+ component_membership());
+ }
+ }
+ else {
+ // Leaf: compare the component of every point against the first point's.
+ size_t new_membership = connections_.Find(tree->begin());
+
+ for (size_t i = tree->begin(); i < tree->end(); i++) {
+ if (new_membership != connections_.Find(i)) {
+ new_membership = -1; // NOTE(review): dead store, we return just below
+ mlpack::Log::Assert(tree->stat().component_membership() < 0);
+ return;
+ }
+ }
+ tree->stat().set_component_membership(new_membership);
+
+ }
+
+ } // CleanupHelper_
+
+ /**
+ * Resets the per-iteration state: every entry of neighbors_distances_ goes
+ * back to DBL_MAX, the loop counter advances, and (unless do_naive_ is set)
+ * the tree statistics are refreshed through CleanupHelper_.
+ */
+ void Cleanup_() {
+ number_of_loops_++;
+
+ for (size_t point = 0; point != number_of_points_; ++point) {
+ neighbors_distances_[point] = DBL_MAX;
+ }
+ if (do_naive_) {
+ return; // nothing tree-related to reset
+ }
+ CleanupHelper_(tree_);
+ }
+
+ /**
+ * Logs the statistics of the finished run to mlpack::Log::Info, one
+ * "name = value" line per statistic (the same format ComputeMST uses for
+ * number_of_loops).
+ */
+ void OutputResults_() {
+
+ // Each counter is streamed right after its label; none is label-only.
+ mlpack::Log::Info << "total_squared_length = " << total_dist_ << std::endl;
+ mlpack::Log::Info << "number_of_points = " << number_of_points_
+ << std::endl;
+ mlpack::Log::Info << "dimension = " << data_points_.n_rows << std::endl;
+ mlpack::Log::Info << "number_of_loops = " << number_of_loops_ << std::endl;
+ mlpack::Log::Info << "number_distance_prunes = "
+ << number_distance_prunes_ << std::endl;
+ mlpack::Log::Info << "number_component_prunes = "
+ << number_component_prunes_ << std::endl;
+ mlpack::Log::Info << "number_leaf_computations = "
+ << number_leaf_computations_ << std::endl;
+ mlpack::Log::Info << "number_q_recursions = " << number_q_recursions_
+ << std::endl;
+ mlpack::Log::Info << "number_r_recursions = " << number_r_recursions_
+ << std::endl;
+ mlpack::Log::Info << "number_both_recursions = "
+ << number_both_recursions_ << std::endl;
+
+ } // OutputResults_
+
+ /////////// Public Functions ///////////////////
+
+ public:
+
+ size_t number_of_edges() const { // current edge counter; Init() zeroes it
+ return number_of_edges_;
+ }
+
+
+ /**
+ * Copies the data set, builds the tree unless the CLI parameter
+ * "naive/do_naive" is set, and resets the edge list, connections_, the
+ * neighbor vectors and all statistics counters.
+ *
+ * @param data Data set to copy; the number of points is taken from n_cols.
+ */
+ void Init(const arma::mat& data) {
+
+ number_of_edges_ = 0;
+ data_points_ = data; // copy
+
+ do_naive_ = CLI::GetParam<bool>("naive/do_naive");
+
+ if (!do_naive_) {
+ // Copy "naive/leaf_size" into "tree/leaf_size" before building the tree.
+ // A leaf size of 1 gives the best pruning empirically; use it unless
+ // space is a big concern. NOTE(review): read as size_t, stored as int.
+ CLI::GetParam<int>("tree/leaf_size") =
+ CLI::GetParam<size_t>("naive/leaf_size");
+
+ CLI::StartTimer("naive/tree_building");
+
+ tree_ = new DTBTree(data_points_, old_from_new_permutation_);
+
+ CLI::StopTimer("naive/tree_building");
+ }
+ else {
+ tree_ = NULL;
+ old_from_new_permutation_.resize(0);
+ }
+ // NOTE(review): if the count is unsigned, "- 1" below wraps for empty data.
+ number_of_points_ = data_points_.n_cols;
+ edges_.resize(number_of_points_ - 1, EdgePair()); // fill with EdgePairs
+ connections_.Init(number_of_points_);
+
+ neighbors_in_component_.set_size(number_of_points_);
+ neighbors_out_component_.set_size(number_of_points_);
+ neighbors_distances_.set_size(number_of_points_);
+ neighbors_distances_.fill(DBL_MAX);
+
+ total_dist_ = 0.0;
+ number_of_loops_ = 0;
+ number_distance_prunes_ = 0;
+ number_component_prunes_ = 0;
+ number_leaf_computations_ = 0;
+ number_q_recursions_ = 0;
+ number_r_recursions_ = 0;
+ number_both_recursions_ = 0;
+ } // Init
+
+
+ /**
+ * Builds the minimum spanning tree; Init must have been called first.
+ *
+ * Repeats neighbor search, edge insertion and cleanup until the edge
+ * counter reaches number_of_points_ - 1, then writes the edges to results
+ * and logs the run statistics.
+ *
+ * @param results Output matrix, written by EmitResults_.
+ */
+ void ComputeMST(arma::mat& results) {
+
+ CLI::StartTimer("emst/MST_computation");
+
+ // One pass per iteration: search, connect, reset.
+ while ((number_of_points_ - 1) > number_of_edges_) {
+ ComputeNeighbors_();
+ AddAllEdges_();
+ Cleanup_();
+
+ mlpack::Log::Info << "number_of_loops = " << number_of_loops_
+ << std::endl;
+ }
+
+ CLI::StopTimer("emst/MST_computation");
+
+ // The timer covers the loop only; emitting and logging happen after it.
+ EmitResults_(results);
+ OutputResults_();
+
+ } // ComputeMST
+
+}; //class DualTreeBoruvka
+
+}; // namespace emst
+}; // namespace mlpack
+
+#endif // __MLPACK_METHODS_EMST_DTB_HPP
Deleted: mlpack/trunk/src/mlpack/methods/emst/emst.h
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/emst.h 2011-10-26 16:28:23 UTC (rev 10038)
+++ mlpack/trunk/src/mlpack/methods/emst/emst.h 2011-10-26 16:44:59 UTC (rev 10039)
@@ -1,79 +0,0 @@
-/**
-* @file emst.h
-*
-* @author Bill March (march at gatech.edu)
-*
-* This file contains utilities necessary for all of the minimum spanning tree
-* algorithms.
-*/
-
-#ifndef EMST_H
-#define EMST_H
-
-#include <mlpack/core.h>
-
-#include "union_find.h"
-
-namespace mlpack {
-namespace emst {
-
-/**
- * An edge pair is simply two indices and a distance. It is used as the
- * basic element of an edge list when computing a minimum spanning tree.
-*/
-class EdgePair {
-
-private:
- size_t lesser_index_;
- size_t greater_index_;
- double distance_;
-
-public:
-
-
- /**
- * Initialize an EdgePair with two indices and a distance. The indices are
- * called lesser and greater, implying that they be sorted before calling
- * Init. However, this is not necessary for functionality; it is just a way
- * to keep the edge list organized in other code.
- */
- void Init(size_t lesser, size_t greater, double dist) {
-
- mlpack::Log::Assert(lesser != greater,
- "indices equal when creating EdgePair, lesser == greater");
- lesser_index_ = lesser;
- greater_index_ = greater;
- distance_ = dist;
-
- }
-
- size_t lesser_index() {
- return lesser_index_;
- }
-
- void set_lesser_index(size_t index) {
- lesser_index_ = index;
- }
-
- size_t greater_index() {
- return greater_index_;
- }
-
- void set_greater_index(size_t index) {
- greater_index_ = index;
- }
-
- double distance() const {
- return distance_;
- }
-
- void set_distance(double new_dist) {
- distance_ = new_dist;
- }
-
-};// class EdgePair
-
-}; // namespace emst
-}; // namespace mlpack
-
-#endif
Copied: mlpack/trunk/src/mlpack/methods/emst/emst.hpp (from rev 10030, mlpack/trunk/src/mlpack/methods/emst/emst.h)
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/emst.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/emst/emst.hpp 2011-10-26 16:44:59 UTC (rev 10039)
@@ -0,0 +1,79 @@
+/**
+* @file emst.hpp
+*
+* @author Bill March (march at gatech.edu)
+*
+* This file contains utilities necessary for all of the minimum spanning tree
+* algorithms.
+*/
+
+#ifndef __MLPACK_METHODS_EMST_EMST_HPP
+#define __MLPACK_METHODS_EMST_EMST_HPP
+
+#include <mlpack/core.h>
+
+#include "union_find.hpp"
+
+namespace mlpack {
+namespace emst {
+
+/**
+ * An edge pair is simply two indices and a distance. It is used as the
+ * basic element of an edge list when computing a minimum spanning tree.
+*/
+class EdgePair {
+
+private:
+ size_t lesser_index_;
+ size_t greater_index_;
+ double distance_;
+
+public:
+
+
+ /**
+ * Initialize an EdgePair with two indices and a distance. The indices are
+ * called lesser and greater, implying that they be sorted before calling
+ * Init. However, this is not necessary for functionality; it is just a way
+ * to keep the edge list organized in other code.
+ */
+ void Init(size_t lesser, size_t greater, double dist) {
+
+ mlpack::Log::Assert(lesser != greater,
+ "indices equal when creating EdgePair, lesser == greater");
+ lesser_index_ = lesser;
+ greater_index_ = greater;
+ distance_ = dist;
+
+ }
+
+ size_t lesser_index() {
+ return lesser_index_;
+ }
+
+ void set_lesser_index(size_t index) {
+ lesser_index_ = index;
+ }
+
+ size_t greater_index() {
+ return greater_index_;
+ }
+
+ void set_greater_index(size_t index) {
+ greater_index_ = index;
+ }
+
+ double distance() const {
+ return distance_;
+ }
+
+ void set_distance(double new_dist) {
+ distance_ = new_dist;
+ }
+
+};// class EdgePair
+
+}; // namespace emst
+}; // namespace mlpack
+
+#endif // __MLPACK_METHODS_EMST_EMST_HPP
Deleted: mlpack/trunk/src/mlpack/methods/emst/emst_main.cc
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/emst_main.cc 2011-10-26 16:28:23 UTC (rev 10038)
+++ mlpack/trunk/src/mlpack/methods/emst/emst_main.cc 2011-10-26 16:44:59 UTC (rev 10039)
@@ -1,160 +0,0 @@
-/**
-* @file emst.cc
- *
- * Calls the DualTreeBoruvka algorithm from dtb.h
- * Can optionally call Naive Boruvka's method
- * See README for command line options.
- *
- * @author Bill March (march at gatech.edu)
-*/
-
-#include "dtb.h"
-
-#include <mlpack/core.h>
-
-PARAM_FLAG("using_thor", "For when an implementation of thor is around",
- "emst");
-PARAM_STRING_REQ("input_file", "Data input file.", "emst");
-PARAM_STRING("output_file", "Data output file.", "emst", "emst_output.csv");
-
-PARAM_FLAG("do_naive", "Check against naive.", "naive");
-PARAM_STRING("output_file", "Naive data output file.", "naive",
- "naive_output.csv");
-
-PARAM(double, "total_squared_length", "Calculation result.", "dtb", 0.0, false);
-
-using namespace mlpack;
-using namespace mlpack::emst;
-
-int main(int argc, char* argv[]) {
- CLI::ParseCommandLine(argc, argv);
-
- // For when I implement a thor version
- bool using_thor = CLI::GetParam<bool>("emst/using_thor");
-
-
- if (using_thor) {
- Log::Warn << "thor is not yet supported" << std::endl;
- }
- else {
-
- ///////////////// READ IN DATA //////////////////////////////////
-
- std::string data_file_name = CLI::GetParam<std::string>("emst/input_file");
-
- arma::mat data_points;
- data::Load(data_file_name.c_str(), data_points);
-
- /////////////// Initialize DTB //////////////////////
- DualTreeBoruvka dtb;
-
- ////////////// Run DTB /////////////////////
- arma::mat results;
-
- dtb.ComputeMST(results);
-
- //////////////// Check against naive //////////////////////////
- if (CLI::GetParam<bool>("naive/do_naive")) {
-
- DualTreeBoruvka naive;
- CLI::GetParam<bool>("naive/do_naive") = true;
-
- naive.Init(data_points);
-
- arma::mat naive_results;
- naive.ComputeMST(naive_results);
-
- /* Compare the naive output to the DTB output */
-
- CLI::StartTimer("naive/comparison");
-
-
- // Check if the edge lists are the same
- // Loop over the naive edge list
- int is_correct = 1;
- /*
- for (size_t naive_index = 0; naive_index < results.size();
- naive_index++) {
-
- int this_loop_correct = 0;
- size_t naive_lesser_index = results[naive_index].lesser_index();
- size_t naive_greater_index = results[naive_index].greater_index();
- double naive_distance = results[naive_index].distance();
-
- // Loop over the DTB edge list and compare against naive
- // Break when an edge is found that matches the current naive edge
- for (size_t dual_index = 0; dual_index < naive_results.size();
- dual_index++) {
-
- size_t dual_lesser_index = results[dual_index].lesser_index();
- size_t dual_greater_index = results[dual_index].greater_index();
- double dual_distance = results[dual_index].distance();
-
- if (naive_lesser_index == dual_lesser_index) {
- if (naive_greater_index == dual_greater_index) {
- DEBUG_ASSERT(naive_distance == dual_distance);
- this_loop_correct = 1;
- break;
- }
- }
-
- }
-
- if (this_loop_correct == 0) {
- is_correct = 0;
- break;
- }
-
- }
- */
- if (is_correct == 0) {
-
- Log::Warn << "Naive check failed!" << std::endl <<
- "Edge lists are different." << std::endl << std::endl;
-
- // Check if the outputs have the same length
- if (CLI::GetParam<double>("naive/total_squared_length") !=
- CLI::GetParam<double>("naive/total_squared_length")) {
-
- Log::Fatal << "Total lengths are different! "
- << " One algorithm has failed." << std::endl;
-
- return 1;
- }
- else {
- // NOTE: if the edge lists are different, but the total lengths are
- // the same, the algorithm may still be correct. The MST is not
- // uniquely defined for some point sets. For example, an equilateral
- // triangle has three minimum spanning trees. It is possible for
- // naive and DTB to find different spanning trees in this case.
- Log::Info << "Total lengths are the same.";
- Log::Info << "It is possible the point set";
- Log::Info << "has more than one minimum spanning tree." << std::endl;
- }
-
- }
- else {
- Log::Info << "Naive and DualTreeBoruvka produced the same MST." <<
- std::endl << std::endl;
- }
-
- CLI::StopTimer("naive/comparison");
-
- std::string naive_output_filename =
- CLI::GetParam<std::string>("naive/output_file");
-
- data::Save(naive_output_filename.c_str(), naive_results);
- }
-
- //////////////// Output the Results ////////////////
-
- std::string output_filename =
- CLI::GetParam<std::string>("emst/output_file");
-
- data::Save(output_filename.c_str(), results);
-
- }// end else (if using_thor)
-
- return 0;
-
-}
Copied: mlpack/trunk/src/mlpack/methods/emst/emst_main.cpp (from rev 10030, mlpack/trunk/src/mlpack/methods/emst/emst_main.cc)
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/emst_main.cpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/emst/emst_main.cpp 2011-10-26 16:44:59 UTC (rev 10039)
@@ -0,0 +1,160 @@
+/**
+* @file emst_main.cpp
+ *
+ * Calls the DualTreeBoruvka algorithm from dtb.hpp
+ * Can optionally call Naive Boruvka's method
+ * See README for command line options.
+ *
+ * @author Bill March (march at gatech.edu)
+*/
+
+#include "dtb.hpp"
+
+#include <mlpack/core.h>
+
+PARAM_FLAG("using_thor", "For when an implementation of thor is around",
+ "emst");
+PARAM_STRING_REQ("input_file", "Data input file.", "emst");
+PARAM_STRING("output_file", "Data output file.", "emst", "emst_output.csv");
+
+PARAM_FLAG("do_naive", "Check against naive.", "naive");
+PARAM_STRING("output_file", "Naive data output file.", "naive",
+ "naive_output.csv");
+
+PARAM(double, "total_squared_length", "Calculation result.", "dtb", 0.0, false);
+
+using namespace mlpack;
+using namespace mlpack::emst;
+
+ /**
+ * Loads the input file, computes its MST with the dual-tree algorithm and
+ * saves it; with "naive/do_naive" set, also runs and saves the naive
+ * computation. Returns 1 only if the total-length comparison fails.
+ */
+ int main(int argc, char* argv[]) {
+ CLI::ParseCommandLine(argc, argv);
+
+ // For when I implement a thor version
+ bool using_thor = CLI::GetParam<bool>("emst/using_thor");
+
+ if (using_thor) {
+ Log::Warn << "thor is not yet supported" << std::endl;
+ }
+ else {
+
+ ///////////////// READ IN DATA //////////////////////////////////
+ std::string data_file_name = CLI::GetParam<std::string>("emst/input_file");
+
+ arma::mat data_points;
+ data::Load(data_file_name.c_str(), data_points);
+
+ /////////////// Initialize DTB //////////////////////
+ // Init reads "naive/do_naive"; clear it so this instance builds a tree.
+ const bool check_naive = CLI::GetParam<bool>("naive/do_naive");
+ CLI::GetParam<bool>("naive/do_naive") = false;
+ DualTreeBoruvka dtb;
+ dtb.Init(data_points);
+
+ ////////////// Run DTB /////////////////////
+ arma::mat results;
+ dtb.ComputeMST(results);
+
+ //////////////// Check against naive //////////////////////////
+ if (check_naive) {
+ DualTreeBoruvka naive;
+ CLI::GetParam<bool>("naive/do_naive") = true;
+ naive.Init(data_points);
+
+ arma::mat naive_results;
+ naive.ComputeMST(naive_results);
+
+ /* Compare the naive output to the DTB output */
+ CLI::StartTimer("naive/comparison");
+
+ // Check if the edge lists are the same
+ // Loop over the naive edge list
+ int is_correct = 1;
+ /*
+ for (size_t naive_index = 0; naive_index < results.size();
+ naive_index++) {
+
+ int this_loop_correct = 0;
+ size_t naive_lesser_index = results[naive_index].lesser_index();
+ size_t naive_greater_index = results[naive_index].greater_index();
+ double naive_distance = results[naive_index].distance();
+
+ // Loop over the DTB edge list and compare against naive
+ // Break when an edge is found that matches the current naive edge
+ for (size_t dual_index = 0; dual_index < naive_results.size();
+ dual_index++) {
+
+ size_t dual_lesser_index = results[dual_index].lesser_index();
+ size_t dual_greater_index = results[dual_index].greater_index();
+ double dual_distance = results[dual_index].distance();
+
+ if (naive_lesser_index == dual_lesser_index) {
+ if (naive_greater_index == dual_greater_index) {
+ DEBUG_ASSERT(naive_distance == dual_distance);
+ this_loop_correct = 1;
+ break;
+ }
+ }
+
+ }
+
+ if (this_loop_correct == 0) {
+ is_correct = 0;
+ break;
+ }
+
+ }
+ */
+ if (is_correct == 0) {
+
+ Log::Warn << "Naive check failed!" << std::endl <<
+ "Edge lists are different." << std::endl << std::endl;
+
+ // Check if the outputs have the same length (row 2 holds edge lengths)
+ if (arma::accu(results.row(2)) !=
+ arma::accu(naive_results.row(2))) {
+
+ Log::Fatal << "Total lengths are different! "
+ << " One algorithm has failed." << std::endl;
+
+ return 1;
+ }
+ else {
+ // NOTE: if the edge lists are different, but the total lengths are
+ // the same, the algorithm may still be correct. The MST is not
+ // uniquely defined for some point sets. For example, an equilateral
+ // triangle has three minimum spanning trees. It is possible for
+ // naive and DTB to find different spanning trees in this case.
+ Log::Info << "Total lengths are the same.";
+ Log::Info << " It is possible the point set";
+ Log::Info << " has more than one minimum spanning tree." << std::endl;
+ }
+
+ }
+ else {
+ Log::Info << "Naive and DualTreeBoruvka produced the same MST." <<
+ std::endl << std::endl;
+ }
+
+ CLI::StopTimer("naive/comparison");
+
+ std::string naive_output_filename =
+ CLI::GetParam<std::string>("naive/output_file");
+
+ data::Save(naive_output_filename.c_str(), naive_results);
+ }
+
+ //////////////// Output the Results ////////////////
+ std::string output_filename =
+ CLI::GetParam<std::string>("emst/output_file");
+
+ data::Save(output_filename.c_str(), results);
+
+ }// end else (if using_thor)
+
+ return 0;
+ }
Deleted: mlpack/trunk/src/mlpack/methods/emst/union_find.h
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/union_find.h 2011-10-26 16:28:23 UTC (rev 10038)
+++ mlpack/trunk/src/mlpack/methods/emst/union_find.h 2011-10-26 16:44:59 UTC (rev 10039)
@@ -1,111 +0,0 @@
-/**
- * @file union_find.h
- *
- * @author Bill March (march at gatech.edu)
- *
- * Implements a union-find data structure. This structure tracks the components
- * of a graph. Each point in the graph is initially in its own component.
- * Calling unionfind.Union(x, y) unites the components indexed by x and y.
- * unionfind.Find(x) returns the index of the component containing point x.
- */
-
-#ifndef UNCLIN_FIND_H
-#define UNCLIN_FIND_H
-
-#include <mlpack/core.h>
-
-namespace mlpack {
-namespace emst {
-
-/**
- * @class UnionFind
- *
- *A Union-Find data structure. See Cormen, Rivest, & Stein for details.
- */
-class UnionFind {
- friend class TestUnionFind;
-private:
-
- arma::Col<size_t> parent_;
- arma::ivec rank_;
- size_t number_of_elements_;
-
-public:
-
- UnionFind() {}
-
- ~UnionFind() {}
-
- /**
- * Initializes the structure. This implementation assumes
- * that the size is known advance and fixed
- *
- * @param size The number of elements to be tracked.
- */
-
- void Init(size_t size) {
-
- number_of_elements_ = size;
- parent_.set_size(number_of_elements_);
- rank_.set_size(number_of_elements_);
- for (size_t i = 0; i < number_of_elements_; i++) {
- parent_[i] = i;
- rank_[i] = 0;
- }
-
- }
-
- /**
- * Returns the component containing an element
- *
- * @param x the component to be found
- * @return The index of the component containing x
- */
- size_t Find(size_t x) {
-
- if (parent_[x] == x) {
- return x;
- }
- else {
- // This ensures that the tree has a small depth
- parent_[x] = Find(parent_[x]);
- return parent_[x];
- }
-
- }
-
- /**
- * @function Union
- *
- * Union the components containing x and y
- *
- * @param x one component
- * @param y the other component
- */
- void Union(size_t x, size_t y) {
-
- size_t x_root = Find(x);
- size_t y_root = Find(y);
-
- if (x_root == y_root) {
- return;
- }
- else if (rank_[x_root] == rank_[y_root]) {
- parent_[y_root] = parent_[x_root];
- rank_[x_root] = rank_[x_root] + 1;
- }
- else if (rank_[x_root] > rank_[y_root]) {
- parent_[y_root] = x_root;
- }
- else {
- parent_[x_root] = y_root;
- }
-
- }
-
-}; //class UnionFind
-
-}; // namespace emst
-}; // namespace mlpack
-
-#endif
Copied: mlpack/trunk/src/mlpack/methods/emst/union_find.hpp (from rev 10030, mlpack/trunk/src/mlpack/methods/emst/union_find.h)
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/union_find.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/emst/union_find.hpp 2011-10-26 16:44:59 UTC (rev 10039)
@@ -0,0 +1,111 @@
+/**
+ * @file union_find.hpp
+ *
+ * @author Bill March (march at gatech.edu)
+ *
+ * Implements a union-find data structure. This structure tracks the components
+ * of a graph. Each point in the graph is initially in its own component.
+ * Calling unionfind.Union(x, y) unites the components indexed by x and y.
+ * unionfind.Find(x) returns the index of the component containing point x.
+ */
+
+#ifndef __MLPACK_METHODS_EMST_UNION_FIND_HPP
+#define __MLPACK_METHODS_EMST_UNION_FIND_HPP
+
+#include <mlpack/core.h>
+
+namespace mlpack {
+namespace emst {
+
+/**
+ * @class UnionFind
+ *
+ *A Union-Find data structure. See Cormen, Rivest, & Stein for details.
+ */
+class UnionFind {
+ friend class TestUnionFind;
+private:
+
+ arma::Col<size_t> parent_;
+ arma::ivec rank_;
+ size_t number_of_elements_;
+
+public:
+
+ UnionFind() {}
+
+ ~UnionFind() {}
+
+ /**
+ * Initializes the structure. This implementation assumes
+ * that the size is known advance and fixed
+ *
+ * @param size The number of elements to be tracked.
+ */
+
+ void Init(size_t size) {
+
+ number_of_elements_ = size;
+ parent_.set_size(number_of_elements_);
+ rank_.set_size(number_of_elements_);
+ for (size_t i = 0; i < number_of_elements_; i++) {
+ parent_[i] = i;
+ rank_[i] = 0;
+ }
+
+ }
+
+ /**
+ * Returns the component containing an element
+ *
+ * @param x the component to be found
+ * @return The index of the component containing x
+ */
+ size_t Find(size_t x) {
+
+ if (parent_[x] == x) {
+ return x;
+ }
+ else {
+ // This ensures that the tree has a small depth
+ parent_[x] = Find(parent_[x]);
+ return parent_[x];
+ }
+
+ }
+
+ /**
+ * @function Union
+ *
+ * Union the components containing x and y
+ *
+ * @param x one component
+ * @param y the other component
+ */
+ void Union(size_t x, size_t y) {
+
+ size_t x_root = Find(x);
+ size_t y_root = Find(y);
+
+ if (x_root == y_root) {
+ return;
+ }
+ else if (rank_[x_root] == rank_[y_root]) {
+ parent_[y_root] = parent_[x_root];
+ rank_[x_root] = rank_[x_root] + 1;
+ }
+ else if (rank_[x_root] > rank_[y_root]) {
+ parent_[y_root] = x_root;
+ }
+ else {
+ parent_[x_root] = y_root;
+ }
+
+ }
+
+}; //class UnionFind
+
+}; // namespace emst
+}; // namespace mlpack
+
+#endif // __MLPACK_METHODS_EMST_UNION_FIND_HPP
Deleted: mlpack/trunk/src/mlpack/methods/emst/union_find_test.cc
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/union_find_test.cc 2011-10-26 16:28:23 UTC (rev 10038)
+++ mlpack/trunk/src/mlpack/methods/emst/union_find_test.cc 2011-10-26 16:44:59 UTC (rev 10039)
@@ -1,50 +0,0 @@
-/**
- * @file union_find_test.cc
- *
- * @author Bill March (march at gatech.edu)
- *
- * Unit tests for the Union-Find data structure.
- */
-
-#include "union_find.h"
-
-#include <mlpack/core.h>
-
-#define BOOST_TEST_MODULE UnionFindTest
-#include <boost/test/unit_test.hpp>
-
-using namespace mlpack;
-using namespace mlpack::emst;
-
-BOOST_AUTO_TEST_CASE(TestFind) {
- static const size_t test_size_ = 10;
- UnionFind test_union_find_;
- test_union_find_.Init(test_size_);
-
- for (size_t i = 0; i < test_size_; i++) {
- BOOST_REQUIRE(test_union_find_.Find(i) == i);
- }
- test_union_find_.Union(0,1);
- test_union_find_.Union(1, 2);
-
- BOOST_REQUIRE(test_union_find_.Find(2) == test_union_find_.Find(0));
-
-}
-
-BOOST_AUTO_TEST_CASE(TestUnion) {
- static const size_t test_size_ = 10;
- UnionFind test_union_find_;
- test_union_find_.Init(test_size_);
-
- test_union_find_.Union(0, 1);
- test_union_find_.Union(2, 3);
- test_union_find_.Union(0, 2);
- test_union_find_.Union(5, 0);
- test_union_find_.Union(0, 6);
-
- BOOST_REQUIRE(test_union_find_.Find(0) == test_union_find_.Find(1));
- BOOST_REQUIRE(test_union_find_.Find(2) == test_union_find_.Find(3));
- BOOST_REQUIRE(test_union_find_.Find(1) == test_union_find_.Find(5));
- BOOST_REQUIRE(test_union_find_.Find(6) == test_union_find_.Find(3));
-}
-
Copied: mlpack/trunk/src/mlpack/methods/emst/union_find_test.cpp (from rev 10030, mlpack/trunk/src/mlpack/methods/emst/union_find_test.cc)
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/union_find_test.cpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/emst/union_find_test.cpp 2011-10-26 16:44:59 UTC (rev 10039)
@@ -0,0 +1,50 @@
+/**
+ * @file union_find_test.cpp
+ *
+ * @author Bill March (march at gatech.edu)
+ *
+ * Unit tests for the Union-Find data structure.
+ */
+
+#include "union_find.hpp"
+
+#include <mlpack/core.h>
+
+#define BOOST_TEST_MODULE UnionFindTest
+#include <boost/test/unit_test.hpp>
+
+using namespace mlpack;
+using namespace mlpack::emst;
+
+ BOOST_AUTO_TEST_CASE(TestFind) {
+ const size_t element_count = 10;
+ UnionFind sets;
+ sets.Init(element_count);
+ // Right after Init every element is its own representative.
+ for (size_t element = 0; element != element_count; ++element) {
+ BOOST_REQUIRE(sets.Find(element) == element);
+ }
+
+ // Chaining 0-1 and 1-2 must put 0 and 2 in the same component.
+ sets.Union(0, 1);
+ sets.Union(1, 2);
+ BOOST_REQUIRE(sets.Find(2) == sets.Find(0));
+ }
+
+ BOOST_AUTO_TEST_CASE(TestUnion) {
+ const size_t element_count = 10;
+ UnionFind sets;
+ sets.Init(element_count);
+ // Merge {0,1} and {2,3}, join the two, then attach 5 and 6 to the result.
+ sets.Union(0, 1);
+ sets.Union(2, 3);
+ sets.Union(0, 2);
+ sets.Union(5, 0);
+ sets.Union(0, 6);
+
+ BOOST_REQUIRE(sets.Find(0) == sets.Find(1));
+ BOOST_REQUIRE(sets.Find(2) == sets.Find(3));
+ BOOST_REQUIRE(sets.Find(1) == sets.Find(5));
+ BOOST_REQUIRE(sets.Find(6) == sets.Find(3));
+ }
+
More information about the mlpack-svn
mailing list