[mlpack-svn] r10131 - mlpack/trunk/src/contrib/nslagle/myKDE
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Nov 3 13:56:35 EDT 2011
Author: nslagle
Date: 2011-11-03 13:56:35 -0400 (Thu, 03 Nov 2011)
New Revision: 10131
Modified:
mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree.hpp
mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree_impl.hpp
Log:
mlpack/contrib/nslagle: add the multi-tree base function
Modified: mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree.hpp
===================================================================
--- mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree.hpp 2011-11-03 17:43:28 UTC (rev 10130)
+++ mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree.hpp 2011-11-03 17:56:35 UTC (rev 10131)
@@ -46,6 +46,7 @@
class KdeDualTree
{
private:
+ TKernel kernel;
/* possibly, these refer to the same object */
TTree* referenceRoot;
TTree* queryRoot;
@@ -58,7 +59,7 @@
arma::mat upperBoundQPointByBandwidth;
arma::mat lowerBoundQPointByBandwidth;
arma::mat upperBoundQNodeByBandwidth;
- arma::mat lowerBoundQPointByBandwidth;
+ arma::mat lowerBoundQNodeByBandwidth;
/* relative estimate to limit bandwidth calculations */
double delta;
/* relative error with respect to the density estimate */
@@ -75,7 +76,7 @@
void SetDefaults();
void MultiBandwidthDualTree();
void MultiBandwidthDualTreeBase(TTree* Q,
- TTree* T,
+ TTree* T, size_t QIndex,
std::set<double> remainingBandwidths);
double GetPriority(TTree* nodeQ, TTree* nodeT)
{
Modified: mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree_impl.hpp
===================================================================
--- mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree_impl.hpp 2011-11-03 17:43:28 UTC (rev 10130)
+++ mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree_impl.hpp 2011-11-03 17:56:35 UTC (rev 10131)
@@ -48,22 +48,84 @@
BandwidthRange(0.01, 100.0);
bandwidthCount = 10;
delta = epsilon = 0.05;
+ kernel = TKernel(1.0);
}
+ /* Base case of the multi-bandwidth dual-tree computation: evaluates the kernel
+ * exhaustively between every query point of Q and every reference point of T
+ * for each remaining bandwidth, then refreshes the per-point, per-Q-node and
+ * per-level bound tables.
+ * Q, T: nodes whose begin()/end()/count() delimit the point ranges used to index
+ * queryData and referenceData.
+ * QIndex: row of Q in the *QNodeByBandwidth tables.
+ * remainingBandwidths: bandwidths still being refined; they are mapped onto the
+ * trailing remainingBandwidths.size() bandwidth columns of the bound tables
+ * (assumes the caller drops bandwidths from the front -- TODO confirm). */
template<typename TKernel, typename TTree>
void KdeDualTree<TKernel, TTree>::MultiBandwidthDualTreeBase(TTree* Q,
- TTree* T,
+ TTree* T, size_t QIndex,
std::set<double> remainingBandwidths)
{
+ size_t sizeOfTNode = T->count();
+ size_t sizeOfQNode = Q->count();
+ /* first bound-table column that belongs to a remaining bandwidth; the last one
+ * is bandwidthCount - 1, matching the indices the kernel loop below writes */
+ const size_t firstBandwidthIndex = bandwidthCount - remainingBandwidths.size();
for (size_t q = Q->begin(); q < Q->end(); ++q)
{
+ arma::vec queryPoint = queryData.unsafe_col(q);
for (size_t t = T->begin(); t < T->end(); ++t)
{
+ arma::vec diff = queryPoint - referenceData.unsafe_col(t);
+ double distSquared = arma::dot(diff, diff);
+ /* std::set iterates in ascending order, so walking backwards from end()
+ * visits the bandwidths from largest to smallest */
+ std::set<double>::iterator bIt = remainingBandwidths.end();
+ size_t bandwidthIndex = bandwidthCount;
+ while (bIt != remainingBandwidths.begin())
+ {
+ --bIt;
+ --bandwidthIndex;
+ double bandwidth = *bIt;
+ double scaledProduct = distSquared / (bandwidth * bandwidth);
+ /* TODO: determine the power of the incoming argument */
+ double contribution = kernel(scaledProduct);
+ if (contribution > DBL_EPSILON)
+ {
+ upperBoundQPointByBandwidth(q, bandwidthIndex) += contribution;
+ lowerBoundQPointByBandwidth(q, bandwidthIndex) += contribution;
+ }
+ else
+ {
+ /* negligible at this bandwidth: the smaller bandwidths are skipped
+ * (presumably the kernel is non-increasing in its argument -- verify) */
+ break;
+ }
+ }
}
+ /* iterate over exactly the columns of the remaining bandwidths; the former
+ * limit of remainingBandwidths.size() only matched when no bandwidth had
+ * been dropped yet */
+ for (size_t bIndex = firstBandwidthIndex; bIndex < bandwidthCount; ++bIndex)
+ {
+ upperBoundQPointByBandwidth(q, bIndex) -= sizeOfTNode;
+ }
}
+ size_t levelOfQ = GetLevelOfNode(Q);
+ for (size_t bIndex = firstBandwidthIndex; bIndex < bandwidthCount; ++bIndex)
+ {
+ /* subtract out the current log-likelihood amount for this Q node so we can readjust
+ * the Q node bounds by current bandwidth */
+ upperBoundLevelByBandwidth(levelOfQ, bIndex) -=
+ sizeOfQNode * log(upperBoundQNodeByBandwidth(QIndex, bIndex));
+ lowerBoundLevelByBandwidth(levelOfQ, bIndex) -=
+ sizeOfQNode * log(lowerBoundQNodeByBandwidth(QIndex, bIndex));
+ /* extreme per-point bounds over the query points of Q for this bandwidth */
+ double minimumLower = lowerBoundQPointByBandwidth(Q->begin(), bIndex);
+ double maximumUpper = upperBoundQPointByBandwidth(Q->begin(), bIndex);
+ for (size_t q = Q->begin(); q < Q->end(); ++q)
+ {
+ if (lowerBoundQPointByBandwidth(q,bIndex) < minimumLower)
+ {
+ minimumLower = lowerBoundQPointByBandwidth(q,bIndex);
+ }
+ if (upperBoundQPointByBandwidth(q,bIndex) > maximumUpper)
+ {
+ maximumUpper = upperBoundQPointByBandwidth(q,bIndex);
+ }
+ }
+ /* adjust Q node bounds, then add the new quantities to the level by bandwidth
+ * log-likelihood bounds */
+ lowerBoundQNodeByBandwidth(QIndex, bIndex) = minimumLower;
+ /* NOTE(review): the per-point upper bounds were already reduced by sizeOfTNode
+ * above, so this looks like a second subtraction -- confirm it is intended */
+ upperBoundQNodeByBandwidth(QIndex, bIndex) = maximumUpper - sizeOfTNode;
+ upperBoundLevelByBandwidth(levelOfQ, bIndex) +=
+ sizeOfQNode * log(upperBoundQNodeByBandwidth(QIndex, bIndex));
+ lowerBoundLevelByBandwidth(levelOfQ, bIndex) +=
+ sizeOfQNode * log(lowerBoundQNodeByBandwidth(QIndex, bIndex));
+ }
}
-
};
};
More information about the mlpack-svn
mailing list