[mlpack-svn] r16685 - mlpack/trunk/src/mlpack/methods/emst
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Jun 12 16:15:49 EDT 2014
Author: rcurtin
Date: Thu Jun 12 16:15:48 2014
New Revision: 16685
Log:
Remove leafSize parameter from DTB constructor.
Modified:
mlpack/trunk/src/mlpack/methods/emst/dtb.hpp
mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp
mlpack/trunk/src/mlpack/methods/emst/dtb_rules.hpp
mlpack/trunk/src/mlpack/methods/emst/dtb_rules_impl.hpp
mlpack/trunk/src/mlpack/methods/emst/emst_main.cpp
Modified: mlpack/trunk/src/mlpack/methods/emst/dtb.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb.hpp Thu Jun 12 16:15:48 2014
@@ -79,7 +79,7 @@
//! Copy of the data (if necessary).
typename TreeType::Mat dataCopy;
//! Reference to the data (this is what should be used for accessing data).
- typename TreeType::Mat& data;
+ const typename TreeType::Mat& data;
//! Pointer to the root of the tree.
TreeType* tree;
@@ -130,7 +130,6 @@
*/
DualTreeBoruvka(const typename TreeType::Mat& dataset,
const bool naive = false,
- const size_t leafSize = 1,
const MetricType metric = MetricType());
/**
@@ -150,7 +149,8 @@
* @param tree Pre-built tree.
* @param dataset Dataset corresponding to the pre-built tree.
*/
- DualTreeBoruvka(TreeType* tree, const typename TreeType::Mat& dataset,
+ DualTreeBoruvka(TreeType* tree,
+ const typename TreeType::Mat& dataset,
const MetricType metric = MetricType());
/**
Modified: mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb_impl.hpp Thu Jun 12 16:15:48 2014
@@ -50,11 +50,10 @@
DualTreeBoruvka<MetricType, TreeType>::DualTreeBoruvka(
const typename TreeType::Mat& dataset,
const bool naive,
- const size_t leafSize,
const MetricType metric) :
dataCopy(dataset),
data(dataCopy), // The reference points to our copy of the data.
- ownTree(true),
+ ownTree(!naive),
naive(naive),
connections(data.n_cols),
totalDist(0.0),
@@ -62,17 +61,10 @@
{
Timer::Start("emst/tree_building");
+ // Default leaf size is 1; this gives the best pruning, empirically. Use
+ // leaf_size = 1 unless space is a big concern.
if (!naive)
- {
- // Default leaf size is 1; this gives the best pruning, empirically. Use
- // leaf_size = 1 unless space is a big concern.
- tree = new TreeType(data, oldFromNew, leafSize);
- }
- else
- {
- // Naive tree holds all data in one leaf.
- tree = new TreeType(data, oldFromNew, data.n_cols);
- }
+ tree = new TreeType(dataCopy, oldFromNew);
Timer::Stop("emst/tree_building");
@@ -91,7 +83,7 @@
const MetricType metric) :
data(dataset),
tree(tree),
- ownTree(true),
+ ownTree(false),
naive(false),
connections(data.n_cols),
totalDist(0.0),
@@ -126,19 +118,32 @@
typedef DTBRules<MetricType, TreeType> RuleType;
RuleType rules(data, connections, neighborsDistances, neighborsInComponent,
neighborsOutComponent, metric);
-
while (edges.size() < (data.n_cols - 1))
{
- typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
-
- traverser.Traverse(*tree, *tree);
+ if (naive)
+ {
+ // Full O(N^2) traversal.
+ for (size_t i = 0; i < data.n_cols; ++i)
+ for (size_t j = 0; j < data.n_cols; ++j)
+ rules.BaseCase(i, j);
+ }
+ else
+ {
+ typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
+ traverser.Traverse(*tree, *tree);
+ }
AddAllEdges();
Cleanup();
Log::Info << edges.size() << " edges found so far." << std::endl;
- Log::Info << traverser.NumPrunes() << " nodes pruned." << std::endl;
+ if (!naive)
+ {
+ Log::Info << rules.BaseCases() << " cumulative base cases." << std::endl;
+ Log::Info << rules.Scores() << " cumulative node combinations scored."
+ << std::endl;
+ }
}
Timer::Stop("emst/mst_computation");
@@ -146,7 +151,7 @@
EmitResults(results);
Log::Info << "Total spanning tree length: " << totalDist << std::endl;
-} // ComputeMST
+}
/**
* Adds a single edge to the edge list
Modified: mlpack/trunk/src/mlpack/methods/emst/dtb_rules.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb_rules.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb_rules.hpp Thu Jun 12 16:15:48 2014
@@ -74,7 +74,7 @@
* @param queryNode Candidate query node to recurse into.
* @param referenceNode Candidate reference node to recurse into.
*/
- double Score(TreeType& queryNode, TreeType& referenceNode) const;
+ double Score(TreeType& queryNode, TreeType& referenceNode);
/**
* Get the score for recursion order, passing the base case result (in the
@@ -88,7 +88,7 @@
*/
double Score(TreeType& queryNode,
TreeType& referenceNode,
- const double baseCaseResult) const;
+ const double baseCaseResult);
/**
* Re-evaluate the score for recursion order. A low score indicates priority
@@ -110,6 +110,16 @@
const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
TraversalInfoType& TraversalInfo() { return traversalInfo; }
+ //! Get the number of base cases performed.
+ size_t BaseCases() const { return baseCases; }
+ //! Modify the number of base cases performed.
+ size_t& BaseCases() { return baseCases; }
+
+ //! Get the number of node combinations that have been scored.
+ size_t Scores() const { return scores; }
+ //! Modify the number of node combinations that have been scored.
+ size_t& Scores() { return scores; }
+
private:
//! The data points.
const arma::mat& dataSet;
@@ -138,6 +148,11 @@
TraversalInfoType traversalInfo;
+ //! The number of base cases calculated.
+ size_t baseCases;
+ //! The number of node combinations that have been scored.
+ size_t scores;
+
}; // class DTBRules
} // emst namespace
Modified: mlpack/trunk/src/mlpack/methods/emst/dtb_rules_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/emst/dtb_rules_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/emst/dtb_rules_impl.hpp Thu Jun 12 16:15:48 2014
@@ -24,7 +24,9 @@
neighborsDistances(neighborsDistances),
neighborsInComponent(neighborsInComponent),
neighborsOutComponent(neighborsOutComponent),
- metric(metric)
+ metric(metric),
+ baseCases(0),
+ scores(0)
{
// Nothing else to do.
}
@@ -46,6 +48,7 @@
if (queryComponentIndex != referenceComponentIndex)
{
+ ++baseCases;
double distance = metric.Evaluate(dataSet.col(queryIndex),
dataSet.col(referenceIndex));
@@ -127,7 +130,7 @@
template<typename MetricType, typename TreeType>
double DTBRules<MetricType, TreeType>::Score(TreeType& queryNode,
- TreeType& referenceNode) const
+ TreeType& referenceNode)
{
// If all the queries belong to the same component as all the references
// then we prune.
@@ -136,6 +139,7 @@
referenceNode.Stat().ComponentMembership()))
return DBL_MAX;
+ ++scores;
const double distance = queryNode.MinDistance(&referenceNode);
const double bound = CalculateBound(queryNode);
@@ -147,7 +151,7 @@
template<typename MetricType, typename TreeType>
double DTBRules<MetricType, TreeType>::Score(TreeType& queryNode,
TreeType& referenceNode,
- const double baseCaseResult) const
+ const double baseCaseResult)
{
// If all the queries belong to the same component as all the references
// then we prune.
@@ -156,6 +160,7 @@
referenceNode.Stat().ComponentMembership()))
return DBL_MAX;
+ ++scores;
const double distance = queryNode.MinDistance(referenceNode, baseCaseResult);
const double bound = CalculateBound(queryNode);
Modified: mlpack/trunk/src/mlpack/methods/emst/emst_main.cpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/emst/emst_main.cpp (original)
+++ mlpack/trunk/src/mlpack/methods/emst/emst_main.cpp Thu Jun 12 16:15:48 2014
@@ -80,18 +80,48 @@
<< ")! Must be greater than or equal to 1." << std::endl;
}
- // Initialize the tree and get ready to compute the MST.
+ // Initialize the tree and get ready to compute the MST. Compute the tree
+ // by hand.
const size_t leafSize = (size_t) CLI::GetParam<int>("leaf_size");
- DualTreeBoruvka<> dtb(dataPoints, false, leafSize);
+
+ Timer::Start("tree_building");
+ std::vector<size_t> oldFromNew;
+ tree::BinarySpaceTree<bound::HRectBound<2>, DTBStat> tree(dataPoints,
+ oldFromNew, leafSize);
+ metric::LMetric<2, true> metric;
+ Timer::Stop("tree_building");
+
+ DualTreeBoruvka<> dtb(&tree, dataPoints, metric);
// Run the DTB algorithm.
Log::Info << "Calculating minimum spanning tree." << endl;
arma::mat results;
dtb.ComputeMST(results);
+ // Unmap the results.
+ arma::mat unmappedResults(results.n_rows, results.n_cols);
+ for (size_t i = 0; i < results.n_cols; ++i)
+ {
+ const size_t indexA = oldFromNew[size_t(results(0, i))];
+ const size_t indexB = oldFromNew[size_t(results(1, i))];
+
+ if (indexA < indexB)
+ {
+ unmappedResults(0, i) = indexA;
+ unmappedResults(1, i) = indexB;
+ }
+ else
+ {
+ unmappedResults(0, i) = indexB;
+ unmappedResults(1, i) = indexA;
+ }
+
+ unmappedResults(2, i) = results(2, i);
+ }
+
// Output the results.
const string outputFilename = CLI::GetParam<string>("output_file");
- data::Save(outputFilename, results, true);
+ data::Save(outputFilename, unmappedResults, true);
}
}
More information about the mlpack-svn mailing list