[mlpack-svn] r10605 - mlpack/trunk/src/mlpack/methods/emst
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Dec 6 18:51:56 EST 2011
Author: march
Date: 2011-12-06 18:51:55 -0500 (Tue, 06 Dec 2011)
New Revision: 10605
Added:
mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp
mlpack/trunk/src/mlpack/methods/emst/edge_pair.hpp
Removed:
mlpack/trunk/src/mlpack/methods/emst/emst.hpp
Modified:
mlpack/trunk/src/mlpack/methods/emst/CMakeLists.txt
mlpack/trunk/src/mlpack/methods/emst/dtb.hpp
Log:
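Split dtb.hpp into declarations (dtb.hpp) and definitions (dtb_impl.hpp), and move EdgePair from emst.hpp into edge_pair.hpp; this should finish off ticket 117.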
Modified: mlpack/trunk/src/mlpack/methods/emst/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/CMakeLists.txt 2011-12-06 23:10:05 UTC (rev 10604)
+++ mlpack/trunk/src/mlpack/methods/emst/CMakeLists.txt 2011-12-06 23:51:55 UTC (rev 10605)
@@ -7,7 +7,8 @@
union_find.hpp
# dtb
dtb.hpp
- emst.hpp
+ dtb_impl.hpp
+ edge_pair.hpp
)
# Add directory name to sources.
Modified: mlpack/trunk/src/mlpack/methods/emst/dtb.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb.hpp 2011-12-06 23:10:05 UTC (rev 10604)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb.hpp 2011-12-06 23:51:55 UTC (rev 10605)
@@ -13,7 +13,7 @@
#ifndef __MLPACK_METHODS_EMST_DTB_HPP
#define __MLPACK_METHODS_EMST_DTB_HPP
-#include "emst.hpp"
+#include "edge_pair.hpp"
#include <mlpack/core.hpp>
#include <mlpack/core/tree/bounds.hpp>
@@ -41,69 +41,30 @@
int component_membership_;
public:
- void set_max_neighbor_distance(double distance)
- {
- max_neighbor_distance_ = distance;
- }
+ void set_max_neighbor_distance(double distance);
- double max_neighbor_distance()
- {
- return max_neighbor_distance_;
- }
+ double max_neighbor_distance();
- void set_component_membership(int membership)
- {
- component_membership_ = membership;
- }
+ void set_component_membership(int membership);
- int component_membership()
- {
- return component_membership_;
- }
+ int component_membership();
/**
* A generic initializer.
*/
- DTBStat()
- {
- set_max_neighbor_distance(DBL_MAX);
- set_component_membership(-1);
- }
+ DTBStat();
/**
* An initializer for leaves.
*/
- DTBStat(const arma::mat& dataset, size_t start, size_t count)
- {
- if (count == 1)
- {
- set_component_membership(start);
- set_max_neighbor_distance(DBL_MAX);
- }
- else
- {
- set_max_neighbor_distance(DBL_MAX);
- set_component_membership(-1);
- }
- }
+ DTBStat(const arma::mat& dataset, size_t start, size_t count);
/**
* An initializer for non-leaves. Simply calls the leaf initializer.
*/
DTBStat(const arma::mat& dataset, size_t start, size_t count,
- const DTBStat& left_stat, const DTBStat& right_stat)
- {
- if (count == 1)
- {
- set_component_membership(start);
- set_max_neighbor_distance(DBL_MAX);
- }
- else
- {
- set_max_neighbor_distance(DBL_MAX);
- set_component_membership(-1);
- }
- }
+ const DTBStat& left_stat, const DTBStat& right_stat);
+
}; // class DTBStat
/**
@@ -148,495 +109,98 @@
DTBTree* tree_;
+ // for sorting the edge list after the computation
+  // Strict weak ordering of edges by increasing distance; used by
+  // SortEdges_() to sort the edge list after the computation.  The call
+  // operator is const because std::sort may copy and invoke the comparator
+  // without expecting it to change.
+  struct SortEdgesHelper_
+  {
+    bool operator() (const EdgePair& pairA, const EdgePair& pairB) const
+    {
+      return (pairA.distance() < pairB.distance());
+    }
+  } SortFun;
+
+
////////////////// Constructors ////////////////////////
public:
DualTreeBoruvka() { }
- ~DualTreeBoruvka()
- {
- if (tree_ != NULL)
- delete tree_;
- }
+ ~DualTreeBoruvka();
////////////////////////// Private Functions ////////////////////
private:
/**
* Adds a single edge to the edge list
*/
- void AddEdge_(size_t e1, size_t e2, double distance)
- {
- //EdgePair edge;
- mlpack::Log::Assert((e1 != e2),
- "Indices are equal in DualTreeBoruvka.add_edge(...)");
-
- mlpack::Log::Assert((distance >= 0.0),
- "Negative distance input in DualTreeBoruvka.add_edge(...)");
-
- if (e1 < e2)
- edges_[number_of_edges_].Init(e1, e2, distance);
- else
- edges_[number_of_edges_].Init(e2, e1, distance);
-
- number_of_edges_++;
-
- } // AddEdge_
-
+ void AddEdge_(size_t e1, size_t e2, double distance);
+
/**
* Adds all the edges found in one iteration to the list of neighbors.
*/
- void AddAllEdges_()
- {
- for (size_t i = 0; i < number_of_points_; i++)
- {
- size_t component_i = connections_.Find(i);
- size_t in_edge_i = neighbors_in_component_[component_i];
- size_t out_edge_i = neighbors_out_component_[component_i];
- if (connections_.Find(in_edge_i) != connections_.Find(out_edge_i))
- {
- double dist = neighbors_distances_[component_i];
- //total_dist_ = total_dist_ + dist;
- // changed to make this agree with the cover tree code
- total_dist_ = total_dist_ + sqrt(dist);
- AddEdge_(in_edge_i, out_edge_i, dist);
- connections_.Union(in_edge_i, out_edge_i);
- }
- }
- } // AddAllEdges_
-
-
+ void AddAllEdges_();
+
/**
* Handles the base case computation. Also called by naive.
*/
double ComputeBaseCase_(size_t query_start, size_t query_end,
- size_t reference_start, size_t reference_end)
- {
- number_leaf_computations_++;
-
- double new_upper_bound = -1.0;
-
- for (size_t query_index = query_start; query_index < query_end;
- query_index++)
- {
- // Find the index of the component the query is in
- size_t query_component_index = connections_.Find(query_index);
-
- arma::vec query_point = data_points_.col(query_index);
-
- for (size_t reference_index = reference_start;
- reference_index < reference_end; reference_index++)
- {
- size_t reference_component_index = connections_.Find(reference_index);
-
- if (query_component_index != reference_component_index)
- {
- arma::vec reference_point = data_points_.col(reference_index);
-
- double distance = mlpack::metric::LMetric<2>::Evaluate(query_point,
- reference_point);
-
- if (distance < neighbors_distances_[query_component_index])
- {
- mlpack::Log::Assert(query_index != reference_index);
-
- neighbors_distances_[query_component_index] = distance;
- neighbors_in_component_[query_component_index] = query_index;
- neighbors_out_component_[query_component_index] = reference_index;
- } // if distance
- } // if indices not equal
- } // for reference_index
-
- if (new_upper_bound < neighbors_distances_[query_component_index])
- new_upper_bound = neighbors_distances_[query_component_index];
-
- } // for query_index
-
- mlpack::Log::Assert(new_upper_bound >= 0.0);
- return new_upper_bound;
-
- } // ComputeBaseCase_
-
-
+ size_t reference_start, size_t reference_end);
+
/**
* Handles the recursive calls to find the nearest neighbors in an iteration
*/
void ComputeNeighborsRecursion_(DTBTree *query_node, DTBTree *reference_node,
- double incoming_distance)
- {
- // Check for a distance prune.
- if (query_node->Stat().max_neighbor_distance() < incoming_distance)
- {
- // Pruned by distance.
- number_distance_prunes_++;
- }
- // Check for a component prune.
- else if ((query_node->Stat().component_membership() >= 0)
- && (query_node->Stat().component_membership() ==
- reference_node->Stat().component_membership()))
- {
- // Pruned by component membership.
- mlpack::Log::Assert(reference_node->Stat().component_membership() >= 0);
- mlpack::Log::Info << query_node->Stat().component_membership()
- << "q mem\n";
- mlpack::Log::Info << reference_node->Stat().component_membership()
- << "r mem\n";
-
- number_component_prunes_++;
- }
- else if (query_node->IsLeaf() && reference_node->IsLeaf()) // Base case.
- {
- double new_bound = ComputeBaseCase_(query_node->Begin(),
- query_node->End(), reference_node->Begin(), reference_node->End());
-
- query_node->Stat().set_max_neighbor_distance(new_bound);
- }
- else if (query_node->IsLeaf()) // Other recursive calls.
- {
- // Recurse on reference_node only.
- number_r_recursions_++;
-
- double left_dist =
- query_node->Bound().MinDistance(reference_node->Left()->Bound());
- double right_dist =
- query_node->Bound().MinDistance(reference_node->Right()->Bound());
- mlpack::Log::Assert(left_dist >= 0.0);
- mlpack::Log::Assert(right_dist >= 0.0);
-
- if (left_dist < right_dist)
- {
- ComputeNeighborsRecursion_(query_node, reference_node->Left(),
- left_dist);
- ComputeNeighborsRecursion_(query_node, reference_node->Right(),
- right_dist);
- }
- else
- {
- ComputeNeighborsRecursion_(query_node, reference_node->Right(),
- right_dist);
- ComputeNeighborsRecursion_(query_node, reference_node->Left(),
- left_dist);
- }
- }
- else if (reference_node->IsLeaf())
- {
- // Recurse on query_node only.
- number_q_recursions_++;
-
- double left_dist =
- query_node->Left()->Bound().MinDistance(reference_node->Bound());
- double right_dist =
- query_node->Right()->Bound().MinDistance(reference_node->Bound());
-
- ComputeNeighborsRecursion_(query_node->Left(), reference_node, left_dist);
- ComputeNeighborsRecursion_(query_node->Right(), reference_node,
- right_dist);
-
- // Update query_node's stat.
- query_node->Stat().set_max_neighbor_distance(
- std::max(query_node->Left()->Stat().max_neighbor_distance(),
- query_node->Right()->Stat().max_neighbor_distance()));
-
- }
- else
- {
- // Recurse on both.
- number_both_recursions_++;
-
- double left_dist = query_node->Left()->Bound().MinDistance(
- reference_node->Left()->Bound());
- double right_dist = query_node->Left()->Bound().MinDistance(
- reference_node->Right()->Bound());
-
- if (left_dist < right_dist)
- {
- ComputeNeighborsRecursion_(query_node->Left(), reference_node->Left(),
- left_dist);
- ComputeNeighborsRecursion_(query_node->Left(), reference_node->Right(),
- right_dist);
- }
- else
- {
- ComputeNeighborsRecursion_(query_node->Left(), reference_node->Right(),
- right_dist);
- ComputeNeighborsRecursion_(query_node->Left(), reference_node->Left(),
- left_dist);
- }
-
- left_dist = query_node->Right()->Bound().MinDistance(
- reference_node->Left()->Bound());
- right_dist = query_node->Right()->Bound().MinDistance(
- reference_node->Right()->Bound());
-
- if (left_dist < right_dist)
- {
- ComputeNeighborsRecursion_(query_node->Right(), reference_node->Left(),
- left_dist);
- ComputeNeighborsRecursion_(query_node->Right(), reference_node->Right(),
- right_dist);
- }
- else
- {
- ComputeNeighborsRecursion_(query_node->Right(), reference_node->Right(),
- right_dist);
- ComputeNeighborsRecursion_(query_node->Right(), reference_node->Left(),
- left_dist);
- }
-
- query_node->Stat().set_max_neighbor_distance(
- std::max(query_node->Left()->Stat().max_neighbor_distance(),
- query_node->Right()->Stat().max_neighbor_distance()));
- }
- } // ComputeNeighborsRecursion_
-
+ double incoming_distance);
+
/**
* Computes the nearest neighbor of each point in each iteration
* of the algorithm
*/
- void ComputeNeighbors_()
- {
- if (do_naive_)
- {
- ComputeBaseCase_(0, number_of_points_, 0, number_of_points_);
- }
- else
- {
- ComputeNeighborsRecursion_(tree_, tree_, DBL_MAX);
- }
- } // ComputeNeighbors_
+ void ComputeNeighbors_();
- struct SortEdgesHelper_
- {
- bool operator() (const EdgePair& pairA, const EdgePair& pairB)
- {
- return (pairA.distance() < pairB.distance());
- }
- } SortFun;
-
- void SortEdges_()
- {
- std::sort(edges_.begin(), edges_.end(), SortFun);
- } // SortEdges_()
-
+
+ void SortEdges_();
+
/**
* Unpermute the edge list and output it to results
*
- * TODO: Make this sort the edge list by distance as well for hierarchical
- * clusterings.
*/
- void EmitResults_(arma::mat& results)
- {
- SortEdges_();
+ void EmitResults_(arma::mat& results);
- mlpack::Log::Assert(number_of_edges_ == number_of_points_ - 1);
- results.set_size(number_of_edges_, 3);
-
- // Need to unpermute the point labels.
- if (!do_naive_)
- {
- for (size_t i = 0; i < (number_of_points_ - 1); i++)
- {
- // Make sure the edge list stores the smaller index first to
- // make checking correctness easier
- size_t ind1, ind2;
- ind1 = old_from_new_permutation_[edges_[i].lesser_index()];
- ind2 = old_from_new_permutation_[edges_[i].greater_index()];
-
- edges_[i].set_lesser_index(std::min(ind1, ind2));
- edges_[i].set_greater_index(std::max(ind1, ind2));
-
- results(i, 0) = edges_[i].lesser_index();
- results(i, 1) = edges_[i].greater_index();
- results(i, 2) = sqrt(edges_[i].distance());
- }
- }
- else
- {
- for (size_t i = 0; i < number_of_edges_; i++)
- {
- results(i, 0) = edges_[i].lesser_index();
- results(i, 1) = edges_[i].greater_index();
- results(i, 2) = sqrt(edges_[i].distance());
- }
- }
- } // EmitResults_
-
/**
* This function resets the values in the nodes of the tree nearest neighbor
* distance, check for fully connected nodes
*/
- void CleanupHelper_(DTBTree* tree)
- {
- tree->Stat().set_max_neighbor_distance(DBL_MAX);
+ void CleanupHelper_(DTBTree* tree);
- if (!tree->IsLeaf())
- {
- CleanupHelper_(tree->Left());
- CleanupHelper_(tree->Right());
-
- if ((tree->Left()->Stat().component_membership() >= 0)
- && (tree->Left()->Stat().component_membership() ==
- tree->Right()->Stat().component_membership()))
- {
- tree->Stat().set_component_membership(tree->Left()->Stat().
- component_membership());
- }
- }
- else
- {
- size_t new_membership = connections_.Find(tree->Begin());
-
- for (size_t i = tree->Begin(); i < tree->End(); i++)
- {
- if (new_membership != connections_.Find(i))
- {
- new_membership = -1;
- mlpack::Log::Assert(tree->Stat().component_membership() < 0);
- return;
- }
- }
- tree->Stat().set_component_membership(new_membership);
- }
- } // CleanupHelper_
-
/**
* The values stored in the tree must be reset on each iteration.
*/
- void Cleanup_()
- {
- for (size_t i = 0; i < number_of_points_; i++)
- {
- neighbors_distances_[i] = DBL_MAX;
- }
- number_of_loops_++;
-
- if (!do_naive_)
- {
- CleanupHelper_(tree_);
- }
- }
-
+ void Cleanup_();
+
/**
* Format and output the results
*/
- void OutputResults_()
- {
- /* fx_result_double(module_, "total_squared_length", total_dist_);
- fx_result_int(module_, "number_of_points", number_of_points_);
- fx_result_int(module_, "dimension", data_points_.n_rows);
- fx_result_int(module_, "number_of_loops", number_of_loops_);
- fx_result_int(module_, "number_distance_prunes", number_distance_prunes_);
- fx_result_int(module_, "number_component_prunes", number_component_prunes_);
- fx_result_int(module_, "number_leaf_computations", number_leaf_computations_);
- fx_result_int(module_, "number_q_recursions", number_q_recursions_);
- fx_result_int(module_, "number_r_recursions", number_r_recursions_);
- fx_result_int(module_, "number_both_recursions", number_both_recursions_);*/
- // TODO, not sure how I missed this last time.
- mlpack::Log::Info << "Total squared length: " << total_dist_ << std::endl;
- mlpack::Log::Info << "Number of points: " << number_of_points_ << std::endl;
- mlpack::Log::Info << "Dimension: " << data_points_.n_rows << std::endl;
- /*
- mlpack::Log::Info << "number_of_loops" << std::endl;
- mlpack::Log::Info << "number_distance_prunes" << std::endl;
- mlpack::Log::Info << "number_component_prunes" << std::endl;
- mlpack::Log::Info << "number_leaf_computations" << std::endl;
- mlpack::Log::Info << "number_q_recursions" << std::endl;
- mlpack::Log::Info << "number_r_recursions" << std::endl;
- mlpack::Log::Info << "number_both_recursions" << std::endl;
- */
-
- } // OutputResults_
-
+ void OutputResults_();
+
/////////// Public Functions ///////////////////
public:
- size_t number_of_edges()
- {
- return number_of_edges_;
- }
+ size_t number_of_edges();
/**
* Takes in a reference to the data set. Copies the data, builds the tree,
* and initializes all of the member variables.
*/
- void Init(const arma::mat& data, bool naive = false, size_t leafSize = 1)
- {
- number_of_edges_ = 0;
- data_points_ = data; // copy
-
- do_naive_ = naive;
-
- if (!do_naive_)
- {
- // Default leaf size is 1
- // This gives best pruning empirically
- // Use leaf_size=1 unless space is a big concern
- Timers::StartTimer("emst/tree_building");
-
- tree_ = new DTBTree(data_points_, old_from_new_permutation_, leafSize);
-
- Timers::StopTimer("emst/tree_building");
- }
- else
- {
- tree_ = NULL;
- old_from_new_permutation_.resize(0);
- }
-
- number_of_points_ = data_points_.n_cols;
- edges_.resize(number_of_points_ - 1, EdgePair()); // fill with EdgePairs
- connections_.Init(number_of_points_);
-
- neighbors_in_component_.set_size(number_of_points_);
- neighbors_out_component_.set_size(number_of_points_);
- neighbors_distances_.set_size(number_of_points_);
- neighbors_distances_.fill(DBL_MAX);
-
- total_dist_ = 0.0;
- number_of_loops_ = 0;
- number_distance_prunes_ = 0;
- number_component_prunes_ = 0;
- number_leaf_computations_ = 0;
- number_q_recursions_ = 0;
- number_r_recursions_ = 0;
- number_both_recursions_ = 0;
- } // Init
-
+ void Init(const arma::mat& data, bool naive, size_t leafSize);
+
/**
* Call this function after Init. It will iteratively find the nearest
* neighbor of each component until the MST is complete.
*/
- void ComputeMST(arma::mat& results)
- {
- Timers::StartTimer("emst/MST_computation");
-
- while (number_of_edges_ < (number_of_points_ - 1))
- {
- ComputeNeighbors_();
-
- AddAllEdges_();
-
- Cleanup_();
-
- Log::Info << "Finished loop number: " << number_of_loops_ << std::endl;
- Log::Info << number_of_edges_ << " edges found so far.\n\n";
- /*
- Log::Info << number_leaf_computations_ << " base cases.\n";
- Log::Info << number_distance_prunes_ << " distance prunes.\n";
- Log::Info << number_component_prunes_ << " component prunes.\n";
- Log::Info << number_r_recursions_ << " reference recursions.\n";
- Log::Info << number_q_recursions_ << " query recursions.\n";
- Log::Info << number_both_recursions_ << " dual recursions.\n\n";
- */
- }
-
- Timers::StopTimer("emst/MST_computation");
-
- EmitResults_(results);
-
- OutputResults_();
- } // ComputeMST
-
+ void ComputeMST(arma::mat& results);
+
}; // class DualTreeBoruvka
}; // namespace emst
}; // namespace mlpack
+#include "dtb_impl.hpp"
+
#endif // __MLPACK_METHODS_EMST_DTB_HPP
Added: mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp 2011-12-06 23:51:55 UTC (rev 10605)
@@ -0,0 +1,558 @@
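+/**
+ * @file dtb_impl.hpp
+ * @author Bill March (march at gatech.edu)
+ *
+ * Implementation of the DTBStat and DualTreeBoruvka member functions declared
+ * in dtb.hpp, which includes this file at its end.  Because these definitions
+ * live in a header, they must be inline to avoid multiple-definition errors
+ * when dtb.hpp is included from more than one translation unit.
+ */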
+
+#ifndef __MLPACK_METHODS_EMST_DTB_IMPL_HPP
+#define __MLPACK_METHODS_EMST_DTB_IMPL_HPP
+
+#include <mlpack/core.hpp>
+
+
+
+using namespace mlpack::emst;
+
+// Set the current upper bound on the nearest-neighbor distance for this node.
+inline void DTBStat::set_max_neighbor_distance(double distance)
+{
+  max_neighbor_distance_ = distance;
+}
+
+// Get the current upper bound on the nearest-neighbor distance for this node.
+inline double DTBStat::max_neighbor_distance()
+{
+  return max_neighbor_distance_;
+}
+
+// Set the component shared by every point in this node; -1 means the node is
+// not known to lie entirely within a single component.
+inline void DTBStat::set_component_membership(int membership)
+{
+  component_membership_ = membership;
+}
+
+// Get the component shared by every point in this node, or -1 if unknown.
+inline int DTBStat::component_membership()
+{
+  return component_membership_;
+}
+
+/**
+ * A generic initializer: no neighbor bound yet (DBL_MAX) and no known
+ * component (-1).
+ */
+inline DTBStat::DTBStat()
+{
+  set_max_neighbor_distance(DBL_MAX);
+  set_component_membership(-1);
+}
+
+/**
+ * An initializer for leaves.  A node holding exactly one point starts out
+ * labelled with that point's index as its component; any larger node has no
+ * known component yet.  The dataset itself is not inspected.
+ */
+inline DTBStat::DTBStat(const arma::mat& /* dataset */, size_t start,
+                        size_t count)
+{
+  set_max_neighbor_distance(DBL_MAX);
+  set_component_membership((count == 1) ? static_cast<int>(start) : -1);
+}
+
+/**
+ * An initializer for non-leaves.  The child statistics and the dataset are
+ * not used; this applies exactly the same rule as the leaf initializer (the
+ * logic is duplicated here rather than delegated): a single-point node is
+ * labelled with that point's index, anything larger gets -1.
+ */
+inline DTBStat::DTBStat(const arma::mat& /* dataset */, size_t start,
+                        size_t count, const DTBStat& /* left_stat */,
+                        const DTBStat& /* right_stat */)
+{
+  set_max_neighbor_distance(DBL_MAX);
+  set_component_membership((count == 1) ? static_cast<int>(start) : -1);
+}
+
+/**
+ * Destructor: frees the tree built by Init() (tree_ is NULL in naive mode).
+ *
+ * NOTE(review): the default constructor does not initialize tree_, so
+ * destroying an object on which Init() was never called is undefined.
+ */
+inline DualTreeBoruvka::~DualTreeBoruvka()
+{
+  // delete on a NULL pointer is a no-op, so no explicit check is needed.
+  delete tree_;
+}
+
+/**
+ * Adds a single edge to the preallocated edge list, storing the smaller
+ * endpoint index first.
+ *
+ * @param e1 Index of one endpoint.
+ * @param e2 Index of the other endpoint; must differ from e1.
+ * @param distance Distance between the endpoints as computed by the search;
+ *     must be non-negative.
+ */
+inline void DualTreeBoruvka::AddEdge_(size_t e1, size_t e2, double distance)
+{
+  mlpack::Log::Assert((e1 != e2),
+      "Indices are equal in DualTreeBoruvka.add_edge(...)");
+
+  mlpack::Log::Assert((distance >= 0.0),
+      "Negative distance input in DualTreeBoruvka.add_edge(...)");
+
+  // edges_ is sized once in Init(); never write past its end.
+  mlpack::Log::Assert((number_of_edges_ < edges_.size()),
+      "Too many edges in DualTreeBoruvka.add_edge(...)");
+
+  edges_[number_of_edges_].Init(std::min(e1, e2), std::max(e1, e2), distance);
+  number_of_edges_++;
+} // AddEdge_
+
+/**
+ * Adds all the edges found in one iteration to the list of neighbors.  For
+ * each point's component, the candidate edge recorded for it in
+ * neighbors_in_component_ / neighbors_out_component_ is added unless its
+ * endpoints have already been joined earlier in this pass; the two
+ * components are then merged in the union-find structure.
+ */
+inline void DualTreeBoruvka::AddAllEdges_()
+{
+  for (size_t i = 0; i < number_of_points_; i++)
+  {
+    const size_t component_i = connections_.Find(i);
+    const size_t in_edge_i = neighbors_in_component_[component_i];
+    const size_t out_edge_i = neighbors_out_component_[component_i];
+
+    // Skip candidates whose endpoints are already connected.
+    if (connections_.Find(in_edge_i) != connections_.Find(out_edge_i))
+    {
+      const double dist = neighbors_distances_[component_i];
+      // total_dist_ accumulates sqrt(dist), the same per-edge value that
+      // EmitResults_ reports (changed to agree with the cover tree code).
+      total_dist_ = total_dist_ + sqrt(dist);
+      AddEdge_(in_edge_i, out_edge_i, dist);
+      connections_.Union(in_edge_i, out_edge_i);
+    }
+  }
+} // AddAllEdges_
+
+
+/**
+ * Handles the base case computation.  Every query point in
+ * [query_start, query_end) is compared against every reference point in
+ * [reference_start, reference_end) that lies in a different component, and
+ * any pair closer than the current candidate for the query's component
+ * replaces it.  Also called by naive mode over the whole dataset.
+ *
+ * @return The largest candidate distance among the query points' components
+ *     after the update (DBL_MAX if some component still has no candidate).
+ */
+inline double DualTreeBoruvka::ComputeBaseCase_(size_t query_start,
+                                                size_t query_end,
+                                                size_t reference_start,
+                                                size_t reference_end)
+{
+  number_leaf_computations_++;
+
+  double new_upper_bound = -1.0;
+
+  for (size_t query_index = query_start; query_index < query_end;
+       query_index++)
+  {
+    // Find the index of the component the query is in
+    size_t query_component_index = connections_.Find(query_index);
+
+    arma::vec query_point = data_points_.col(query_index);
+
+    for (size_t reference_index = reference_start;
+         reference_index < reference_end; reference_index++)
+    {
+      size_t reference_component_index = connections_.Find(reference_index);
+
+      // Only pairs that span two components are MST edge candidates.
+      if (query_component_index != reference_component_index)
+      {
+        arma::vec reference_point = data_points_.col(reference_index);
+
+        double distance = mlpack::metric::LMetric<2>::Evaluate(query_point,
+            reference_point);
+
+        if (distance < neighbors_distances_[query_component_index])
+        {
+          mlpack::Log::Assert(query_index != reference_index);
+
+          neighbors_distances_[query_component_index] = distance;
+          neighbors_in_component_[query_component_index] = query_index;
+          neighbors_out_component_[query_component_index] = reference_index;
+        } // if distance
+      } // if indices not equal
+    } // for reference_index
+
+    if (new_upper_bound < neighbors_distances_[query_component_index])
+      new_upper_bound = neighbors_distances_[query_component_index];
+
+  } // for query_index
+
+  mlpack::Log::Assert(new_upper_bound >= 0.0);
+  return new_upper_bound;
+
+} // ComputeBaseCase_
+
+
+/**
+ * Handles the recursive calls to find the nearest neighbors in an iteration.
+ *
+ * A (query, reference) node pair is skipped when the query node's current
+ * bound is below incoming_distance (distance prune) or when both nodes are
+ * labelled with the same component (component prune).  Pairs of leaves go to
+ * ComputeBaseCase_.  Otherwise the non-leaf node(s) are split; when the
+ * reference node is split, its closer child is visited first, and when the
+ * query node is split, its bound is refreshed as the max of its children's.
+ *
+ * @param incoming_distance Minimum distance between the two nodes' bounds
+ *     (DBL_MAX for the initial root-vs-root call).
+ */
+inline void DualTreeBoruvka::ComputeNeighborsRecursion_(DTBTree *query_node,
+                                                        DTBTree *reference_node,
+                                                        double incoming_distance)
+{
+  // Check for a distance prune.
+  if (query_node->Stat().max_neighbor_distance() < incoming_distance)
+  {
+    // Pruned by distance.
+    number_distance_prunes_++;
+  }
+  // Check for a component prune.
+  else if ((query_node->Stat().component_membership() >= 0)
+           && (query_node->Stat().component_membership() ==
+               reference_node->Stat().component_membership()))
+  {
+    // Pruned by component membership.
+    mlpack::Log::Assert(reference_node->Stat().component_membership() >= 0);
+    mlpack::Log::Info << query_node->Stat().component_membership()
+        << "q mem\n";
+    mlpack::Log::Info << reference_node->Stat().component_membership()
+        << "r mem\n";
+
+    number_component_prunes_++;
+  }
+  else if (query_node->IsLeaf() && reference_node->IsLeaf()) // Base case.
+  {
+    double new_bound = ComputeBaseCase_(query_node->Begin(),
+        query_node->End(), reference_node->Begin(), reference_node->End());
+
+    query_node->Stat().set_max_neighbor_distance(new_bound);
+  }
+  else if (query_node->IsLeaf()) // Other recursive calls.
+  {
+    // Recurse on reference_node only.
+    number_r_recursions_++;
+
+    double left_dist =
+        query_node->Bound().MinDistance(reference_node->Left()->Bound());
+    double right_dist =
+        query_node->Bound().MinDistance(reference_node->Right()->Bound());
+    mlpack::Log::Assert(left_dist >= 0.0);
+    mlpack::Log::Assert(right_dist >= 0.0);
+
+    // Visit the closer reference child first to tighten the bound sooner.
+    if (left_dist < right_dist)
+    {
+      ComputeNeighborsRecursion_(query_node, reference_node->Left(),
+          left_dist);
+      ComputeNeighborsRecursion_(query_node, reference_node->Right(),
+          right_dist);
+    }
+    else
+    {
+      ComputeNeighborsRecursion_(query_node, reference_node->Right(),
+          right_dist);
+      ComputeNeighborsRecursion_(query_node, reference_node->Left(),
+          left_dist);
+    }
+  }
+  else if (reference_node->IsLeaf())
+  {
+    // Recurse on query_node only.
+    number_q_recursions_++;
+
+    double left_dist =
+        query_node->Left()->Bound().MinDistance(reference_node->Bound());
+    double right_dist =
+        query_node->Right()->Bound().MinDistance(reference_node->Bound());
+
+    ComputeNeighborsRecursion_(query_node->Left(), reference_node, left_dist);
+    ComputeNeighborsRecursion_(query_node->Right(), reference_node,
+        right_dist);
+
+    // Update query_node's stat.
+    query_node->Stat().set_max_neighbor_distance(
+        std::max(query_node->Left()->Stat().max_neighbor_distance(),
+                 query_node->Right()->Stat().max_neighbor_distance()));
+
+  }
+  else
+  {
+    // Recurse on both.
+    number_both_recursions_++;
+
+    double left_dist = query_node->Left()->Bound().MinDistance(
+        reference_node->Left()->Bound());
+    double right_dist = query_node->Left()->Bound().MinDistance(
+        reference_node->Right()->Bound());
+
+    if (left_dist < right_dist)
+    {
+      ComputeNeighborsRecursion_(query_node->Left(), reference_node->Left(),
+          left_dist);
+      ComputeNeighborsRecursion_(query_node->Left(), reference_node->Right(),
+          right_dist);
+    }
+    else
+    {
+      ComputeNeighborsRecursion_(query_node->Left(), reference_node->Right(),
+          right_dist);
+      ComputeNeighborsRecursion_(query_node->Left(), reference_node->Left(),
+          left_dist);
+    }
+
+    left_dist = query_node->Right()->Bound().MinDistance(
+        reference_node->Left()->Bound());
+    right_dist = query_node->Right()->Bound().MinDistance(
+        reference_node->Right()->Bound());
+
+    if (left_dist < right_dist)
+    {
+      ComputeNeighborsRecursion_(query_node->Right(), reference_node->Left(),
+          left_dist);
+      ComputeNeighborsRecursion_(query_node->Right(), reference_node->Right(),
+          right_dist);
+    }
+    else
+    {
+      ComputeNeighborsRecursion_(query_node->Right(), reference_node->Right(),
+          right_dist);
+      ComputeNeighborsRecursion_(query_node->Right(), reference_node->Left(),
+          left_dist);
+    }
+
+    query_node->Stat().set_max_neighbor_distance(
+        std::max(query_node->Left()->Stat().max_neighbor_distance(),
+                 query_node->Right()->Stat().max_neighbor_distance()));
+  }
+} // ComputeNeighborsRecursion_
+
+/**
+ * Computes the nearest neighbor of each component in one iteration of the
+ * algorithm: either brute force over all pairs (naive mode) or the dual-tree
+ * recursion starting with the root against itself.
+ */
+inline void DualTreeBoruvka::ComputeNeighbors_()
+{
+  if (do_naive_)
+    ComputeBaseCase_(0, number_of_points_, 0, number_of_points_);
+  else
+    ComputeNeighborsRecursion_(tree_, tree_, DBL_MAX);
+} // ComputeNeighbors_
+
+/**
+ * Sorts the whole edge list by increasing distance using SortFun.
+ */
+inline void DualTreeBoruvka::SortEdges_()
+{
+  std::sort(edges_.begin(), edges_.end(), SortFun);
+} // SortEdges_()
+
+/**
+ * Sorts the edge list by distance, unpermutes the point labels (tree mode
+ * only), and writes the edges to results.
+ *
+ * @param results Resized to number_of_edges_ x 3; each row holds the smaller
+ *     original point index, the larger original point index, and
+ *     sqrt(distance) for one MST edge.
+ */
+inline void DualTreeBoruvka::EmitResults_(arma::mat& results)
+{
+  SortEdges_();
+
+  mlpack::Log::Assert(number_of_edges_ == number_of_points_ - 1);
+  results.set_size(number_of_edges_, 3);
+
+  // Both modes use the same bound, so the result size and the loop agree.
+  for (size_t i = 0; i < number_of_edges_; i++)
+  {
+    // Building the tree permutes the points, so map indices back to the
+    // original ordering; the naive path never permutes.
+    if (!do_naive_)
+    {
+      const size_t ind1 = old_from_new_permutation_[edges_[i].lesser_index()];
+      const size_t ind2 = old_from_new_permutation_[edges_[i].greater_index()];
+
+      // Store the smaller index first to make checking correctness easier.
+      edges_[i].set_lesser_index(std::min(ind1, ind2));
+      edges_[i].set_greater_index(std::max(ind1, ind2));
+    }
+
+    results(i, 0) = edges_[i].lesser_index();
+    results(i, 1) = edges_[i].greater_index();
+    results(i, 2) = sqrt(edges_[i].distance());
+  }
+} // EmitResults_
+
+/**
+ * Resets the per-iteration values stored in the subtree rooted at `tree`.
+ * Every node's nearest-neighbor bound goes back to DBL_MAX.  Any node whose
+ * points now all share one component is labelled with it, so that later
+ * iterations can apply the component prune.
+ */
+inline void DualTreeBoruvka::CleanupHelper_(DTBTree* tree)
+{
+  tree->Stat().set_max_neighbor_distance(DBL_MAX);
+
+  if (!tree->IsLeaf())
+  {
+    CleanupHelper_(tree->Left());
+    CleanupHelper_(tree->Right());
+
+    // A non-leaf is fully connected when both children carry the same
+    // (valid) component label.
+    const int left_membership = tree->Left()->Stat().component_membership();
+    if ((left_membership >= 0) &&
+        (left_membership == tree->Right()->Stat().component_membership()))
+    {
+      tree->Stat().set_component_membership(left_membership);
+    }
+  }
+  else
+  {
+    // A leaf is fully connected when every point in it has the same
+    // union-find representative.
+    const size_t new_membership = connections_.Find(tree->Begin());
+
+    for (size_t i = tree->Begin(); i < tree->End(); i++)
+    {
+      if (new_membership != connections_.Find(i))
+      {
+        // Not fully connected, so no label can have been set yet.
+        mlpack::Log::Assert(tree->Stat().component_membership() < 0);
+        return;
+      }
+    }
+    tree->Stat().set_component_membership(static_cast<int>(new_membership));
+  }
+} // CleanupHelper_
+
+/**
+ * The values stored in the tree must be reset on each iteration.  This clears
+ * every component's candidate distance, counts the finished loop, and (in
+ * tree mode) refreshes the node statistics.
+ */
+inline void DualTreeBoruvka::Cleanup_()
+{
+  // Same reset as in Init(); the vector has number_of_points_ entries.
+  neighbors_distances_.fill(DBL_MAX);
+  number_of_loops_++;
+
+  if (!do_naive_)
+    CleanupHelper_(tree_);
+}
+
+/**
+ * Format and output the results: logs summary statistics at Info level.
+ */
+inline void DualTreeBoruvka::OutputResults_()
+{
+  // total_dist_ is accumulated from sqrt(dist) per edge (see AddAllEdges_),
+  // i.e. the sum of the lengths reported by EmitResults_, not a squared
+  // length.
+  mlpack::Log::Info << "Total length: " << total_dist_ << std::endl;
+  mlpack::Log::Info << "Number of points: " << number_of_points_ << std::endl;
+  mlpack::Log::Info << "Dimension: " << data_points_.n_rows << std::endl;
+} // OutputResults_
+
+// Returns the number of MST edges found so far.
+inline size_t DualTreeBoruvka::number_of_edges()
+{
+  return number_of_edges_;
+}
+
+/**
+ * Takes in a reference to the data set.  Copies the data, builds the tree
+ * (unless naive), and initializes all of the member variables.
+ *
+ * The default arguments are given on this out-of-class definition because
+ * the in-class declaration in dtb.hpp has none.
+ *
+ * @param data Dataset with one point per column; it is copied.
+ * @param naive If true, skip tree building and compare all pairs of points.
+ * @param leafSize Leaf size passed to the tree constructor (default 1).
+ *
+ * NOTE(review): calling Init() a second time overwrites tree_ without
+ * freeing the previous tree, and an empty dataset makes
+ * number_of_points_ - 1 below wrap around.
+ */
+inline void DualTreeBoruvka::Init(const arma::mat& data, bool naive = false,
+                                  size_t leafSize = 1)
+{
+  number_of_edges_ = 0;
+  data_points_ = data; // copy
+
+  do_naive_ = naive;
+
+  if (!do_naive_)
+  {
+    // Default leaf size is 1
+    // This gives best pruning empirically
+    // Use leaf_size=1 unless space is a big concern
+    Timers::StartTimer("emst/tree_building");
+
+    // Tree building reorders data_points_ and records the mapping back to
+    // the original indices in old_from_new_permutation_.
+    tree_ = new DTBTree(data_points_, old_from_new_permutation_, leafSize);
+
+    Timers::StopTimer("emst/tree_building");
+  }
+  else
+  {
+    tree_ = NULL;
+    old_from_new_permutation_.resize(0);
+  }
+
+  number_of_points_ = data_points_.n_cols;
+  edges_.resize(number_of_points_ - 1, EdgePair()); // fill with EdgePairs
+  connections_.Init(number_of_points_);
+
+  neighbors_in_component_.set_size(number_of_points_);
+  neighbors_out_component_.set_size(number_of_points_);
+  neighbors_distances_.set_size(number_of_points_);
+  neighbors_distances_.fill(DBL_MAX);
+
+  total_dist_ = 0.0;
+  number_of_loops_ = 0;
+  number_distance_prunes_ = 0;
+  number_component_prunes_ = 0;
+  number_leaf_computations_ = 0;
+  number_q_recursions_ = 0;
+  number_r_recursions_ = 0;
+  number_both_recursions_ = 0;
+} // Init
+
+/**
+ * Call this function after Init.  It will iteratively find the nearest
+ * neighbor of each component until the MST is complete, then write the edge
+ * list to results and log summary statistics.
+ *
+ * @param results Output matrix with one row per MST edge: (lesser index,
+ *     greater index, sqrt(distance)).
+ */
+inline void DualTreeBoruvka::ComputeMST(arma::mat& results)
+{
+  Timers::StartTimer("emst/MST_computation");
+
+  // A spanning tree on n points has n - 1 edges.  Written as edges + 1 < n
+  // so that n == 0 cannot wrap n - 1 around and loop forever.
+  while (number_of_edges_ + 1 < number_of_points_)
+  {
+    ComputeNeighbors_();
+
+    AddAllEdges_();
+
+    Cleanup_();
+
+    Log::Info << "Finished loop number: " << number_of_loops_ << std::endl;
+    Log::Info << number_of_edges_ << " edges found so far.\n\n";
+    /*
+    Log::Info << number_leaf_computations_ << " base cases.\n";
+    Log::Info << number_distance_prunes_ << " distance prunes.\n";
+    Log::Info << number_component_prunes_ << " component prunes.\n";
+    Log::Info << number_r_recursions_ << " reference recursions.\n";
+    Log::Info << number_q_recursions_ << " query recursions.\n";
+    Log::Info << number_both_recursions_ << " dual recursions.\n\n";
+    */
+  }
+
+  Timers::StopTimer("emst/MST_computation");
+
+  EmitResults_(results);
+
+  OutputResults_();
+} // ComputeMST
+
+
+
+
+
+#endif
\ No newline at end of file
Copied: mlpack/trunk/src/mlpack/methods/emst/edge_pair.hpp (from rev 10604, mlpack/trunk/src/mlpack/methods/emst/emst.hpp)
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/edge_pair.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/emst/edge_pair.hpp 2011-12-06 23:51:55 UTC (rev 10605)
@@ -0,0 +1,72 @@
+/**
+ * @file edge_pair.hpp
+ *
+ * @author Bill March (march at gatech.edu)
+ *
+ * This file contains the EdgePair class, the basic element of the edge lists
+ * used by the minimum spanning tree algorithms.
+ */
+#ifndef __MLPACK_METHODS_EMST_EDGE_PAIR_HPP
+#define __MLPACK_METHODS_EMST_EDGE_PAIR_HPP
+
+#include <mlpack/core.hpp>
+
+#include "union_find.hpp"
+
+namespace mlpack {
+namespace emst {
+
+/**
+ * An edge pair is simply two indices and a distance. It is used as the
+ * basic element of an edge list when computing a minimum spanning tree.
+ */
+class EdgePair
+{
+ private:
+ //! Index of the first endpoint (by convention, the smaller one).
+ size_t lesser_index_;
+ //! Index of the second endpoint (by convention, the larger one).
+ size_t greater_index_;
+ //! Distance between the two endpoints.
+ double distance_;
+
+ public:
+ /**
+ * Initialize an EdgePair with two indices and a distance. The indices are
+ * called lesser and greater, implying that they should be sorted before
+ * calling Init. However, this is not necessary for functionality; it is
+ * just a way to keep the edge list organized in other code.
+ *
+ * @param lesser Index of one endpoint; must differ from greater.
+ * @param greater Index of the other endpoint.
+ * @param dist Distance between the two endpoints.
+ */
+ void Init(size_t lesser, size_t greater, double dist)
+ {
+ mlpack::Log::Assert(lesser != greater,
+ "indices equal when creating EdgePair, lesser == greater");
+ lesser_index_ = lesser;
+ greater_index_ = greater;
+ distance_ = dist;
+ }
+
+ //! Get the lesser index. Const so it can be read through a const EdgePair.
+ size_t lesser_index() const { return lesser_index_; }
+ //! Modify the lesser index.
+ void set_lesser_index(size_t index) { lesser_index_ = index; }
+
+ //! Get the greater index. Const so it can be read through a const EdgePair.
+ size_t greater_index() const { return greater_index_; }
+ //! Modify the greater index.
+ void set_greater_index(size_t index) { greater_index_ = index; }
+
+ //! Get the distance between the two endpoints.
+ double distance() const { return distance_; }
+ //! Modify the distance between the two endpoints.
+ void set_distance(double new_dist) { distance_ = new_dist; }
+}; // class EdgePair
+
+} // namespace emst
+} // namespace mlpack
+
+#endif // __MLPACK_METHODS_EMST_EDGE_PAIR_HPP
Deleted: mlpack/trunk/src/mlpack/methods/emst/emst.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/emst/emst.hpp 2011-12-06 23:10:05 UTC (rev 10604)
+++ mlpack/trunk/src/mlpack/methods/emst/emst.hpp 2011-12-06 23:51:55 UTC (rev 10605)
@@ -1,80 +0,0 @@
-/**
- * @file emst.hpp
- *
- * @author Bill March (march at gatech.edu)
- *
- * This file contains utilities necessary for all of the minimum spanning tree
- * algorithms.
- */
-#ifndef __MLPACK_METHODS_EMST_EMST_HPP
-#define __MLPACK_METHODS_EMST_EMST_HPP
-
-#include <mlpack/core.hpp>
-
-#include "union_find.hpp"
-
-namespace mlpack {
-namespace emst {
-
-/**
- * An edge pair is simply two indices and a distance. It is used as the
- * basic element of an edge list when computing a minimum spanning tree.
- */
-class EdgePair
-{
- private:
- size_t lesser_index_;
- size_t greater_index_;
- double distance_;
-
- public:
- /**
- * Initialize an EdgePair with two indices and a distance. The indices are
- * called lesser and greater, implying that they should be sorted before
- * calling Init. However, this is not necessary for functionality; it is
- * just a way to keep the edge list organized in other code.
- */
- void Init(size_t lesser, size_t greater, double dist)
- {
- mlpack::Log::Assert(lesser != greater,
- "indices equal when creating EdgePair, lesser == greater");
- lesser_index_ = lesser;
- greater_index_ = greater;
- distance_ = dist;
- }
- //! Get the lesser index.
- size_t lesser_index()
- {
- return lesser_index_;
- }
- //! Set the lesser index.
- void set_lesser_index(size_t index)
- {
- lesser_index_ = index;
- }
- //! Get the greater index.
- size_t greater_index()
- {
- return greater_index_;
- }
- //! Set the greater index.
- void set_greater_index(size_t index)
- {
- greater_index_ = index;
- }
- //! Get the distance between the two endpoints.
- double distance() const
- {
- return distance_;
- }
- //! Set the distance between the two endpoints.
- void set_distance(double new_dist)
- {
- distance_ = new_dist;
- }
-}; // class EdgePair
-
-}; // namespace emst
-}; // namespace mlpack
-
-#endif // __MLPACK_METHODS_EMST_EMST_HPP
More information about the mlpack-svn
mailing list