[mlpack-svn] r15777 - mlpack/trunk/src/mlpack/methods/fastmks
From: fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Sep 13 15:53:57 EDT 2013
Author: rcurtin
Date: Fri Sep 13 15:53:57 2013
New Revision: 15777
Log:
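Print the number of base cases and scores as output even when not debugging. Also, for monochromatic (single-dataset) dual-tree search, build a separate query tree (or copy a user-supplied reference tree) instead of traversing the reference tree against itself, and free that copy in the destructor.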
Modified:
mlpack/trunk/src/mlpack/methods/fastmks/fastmks_impl.hpp
Modified: mlpack/trunk/src/mlpack/methods/fastmks/fastmks_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/fastmks/fastmks_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/fastmks/fastmks_impl.hpp Fri Sep 13 15:53:57 2013
@@ -36,6 +36,9 @@
if (!naive)
referenceTree = new TreeType(referenceSet);
+ if (!naive && !single)
+ queryTree = new TreeType(referenceSet);
+
Timer::Stop("tree_building");
}
@@ -86,6 +89,9 @@
if (!naive)
referenceTree = new TreeType(referenceSet, metric);
+ if (!naive && !single)
+ queryTree = new TreeType(referenceSet, metric);
+
Timer::Stop("tree_building");
}
@@ -132,7 +138,9 @@
naive(naive),
metric(referenceTree->Metric())
{
- // Nothing to do.
+ // The query tree cannot be the same as the reference tree.
+ if (referenceTree)
+ queryTree = new TreeType(*referenceTree);
}
// Two datasets, pre-built trees.
@@ -166,6 +174,12 @@
if (referenceTree)
delete referenceTree;
}
+ else if (&querySet == &referenceSet)
+ {
+ // The user passed in a reference tree which we needed to copy.
+ if (queryTree)
+ delete queryTree;
+ }
}
template<typename KernelType, typename TreeType>
@@ -228,6 +242,9 @@
Log::Info << "Pruned " << numPrunes << " nodes." << std::endl;
+ Log::Info << rules.BaseCases() << " base cases." << std::endl;
+ Log::Info << rules.Scores() << " scores." << std::endl;
+
Timer::Stop("computing_products");
return;
}
@@ -238,14 +255,13 @@
typename TreeType::template DualTreeTraverser<RuleType> traverser(rules);
- if (queryTree)
- traverser.Traverse(*queryTree, *referenceTree);
- else
- traverser.Traverse(*referenceTree, *referenceTree);
+ traverser.Traverse(*queryTree, *referenceTree);
const size_t numPrunes = traverser.NumPrunes();
Log::Info << "Pruned " << numPrunes << " nodes." << std::endl;
+ Log::Info << rules.BaseCases() << " base cases." << std::endl;
+ Log::Info << rules.Scores() << " scores." << std::endl;
Timer::Stop("computing_products");
return;
More information about the mlpack-svn mailing list