[mlpack-svn] r10133 - mlpack/trunk/src/contrib/nslagle/myKDE
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Nov 3 16:23:20 EDT 2011
Author: nslagle
Date: 2011-11-03 16:23:20 -0400 (Thu, 03 Nov 2011)
New Revision: 10133
Modified:
mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree.hpp
mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree_impl.hpp
Log:
mlpack/contrib/nslagle: add the multi-dualtree function
Modified: mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree.hpp
===================================================================
--- mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree.hpp 2011-11-03 19:05:10 UTC (rev 10132)
+++ mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree.hpp 2011-11-03 20:23:20 UTC (rev 10133)
@@ -10,6 +10,8 @@
#include <mlpack/core/tree/hrectbound.hpp>
#include <mlpack/core/math/range.hpp>
+#define PRIORITY_MAX DBL_MAX
+
namespace mlpack
{
namespace kde
@@ -19,11 +21,12 @@
{
TTree* T;
TTree* Q;
- arma::vec du;
- arma::vec dl;
+ size_t QIndex;
+ arma::vec deltaLower;
+ arma::vec deltaUpper;
double priority;
- double bLower;
- double bUpper;
+ size_t bLowerIndex;
+ size_t bUpperIndex;
};
class QueueNodeCompare
{
@@ -40,7 +43,6 @@
}
};
-
template <typename TKernel = kernel::GaussianKernel,
typename TTree = tree::BinarySpaceTree<bound::HRectBound<2> > >
class KdeDualTree
@@ -69,7 +71,7 @@
std::vector<struct queueNode>,
QueueNodeCompare> nodePriorityQueue;
size_t bandwidthCount;
- std::set<double> bandwidths;
+ std::vector<double> bandwidths;
size_t levelsInTree;
size_t queryTreeSize;
@@ -77,7 +79,7 @@
void MultiBandwidthDualTree();
void MultiBandwidthDualTreeBase(TTree* Q,
TTree* T, size_t QIndex,
- std::set<double> remainingBandwidths);
+ size_t lowerBIndex, size_t upperBIndex);
double GetPriority(TTree* nodeQ, TTree* nodeT)
{
return nodeQ->bound().MinDistance(*nodeT);
@@ -86,6 +88,7 @@
{
return levelsInTree - node->levelsBelow();
}
+ void Winnow(size_t bLower, size_t bUpper, size_t* newLower, size_t* newUpper);
public:
/* the two data sets are different */
KdeDualTree (arma::mat& referenceData, arma::mat& queryData);
Modified: mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree_impl.hpp
===================================================================
--- mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree_impl.hpp 2011-11-03 19:05:10 UTC (rev 10132)
+++ mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree_impl.hpp 2011-11-03 20:23:20 UTC (rev 10133)
@@ -52,9 +52,254 @@
}
+/**
+ * Process (Q, T) node pairs from nodePriorityQueue until the queue is empty
+ * or the epsilon condition holds for the pair just removed.
+ *
+ * For each pair taken from the queue:
+ *  - If every bandwidth index in [bLowerIndex, bUpperIndex] has a level
+ *    lower bound above DBL_EPSILON and a relative gap
+ *    (upper - lower) / lower no larger than epsilon, the function returns.
+ *  - Otherwise Winnow() narrows the bandwidth index range for Q's level.
+ *  - When the pair's priority is below PRIORITY_MAX, per-bandwidth deltas
+ *    are computed from the kernel at the maximum and minimum node
+ *    distances and applied to the point, node and level bounds wherever
+ *    the delta condition holds.  If it holds for every bandwidth, the pair
+ *    is pushed back with its priority raised by PRIORITY_MAX and the loop
+ *    moves on; otherwise the index range is trimmed at both ends.
+ *  - When the priority is PRIORITY_MAX or more, the stored deltas are
+ *    negated and applied to the same bounds (presumably backing out the
+ *    earlier estimate -- TODO confirm).
+ *  - A leaf/leaf pair goes to MultiBandwidthDualTreeBase(); a pair of two
+ *    non-leaf nodes pushes its four child pairs.
+ */
template<typename TKernel, typename TTree>
+void KdeDualTree<TKernel, TTree>::MultiBandwidthDualTree()
+{
+  while (!nodePriorityQueue.empty())
+  {
+    /* get the first structure in the queue; pop() returns nothing, so the
+     * element has to be copied from top() before it is removed */
+    struct queueNode queueCurrent = nodePriorityQueue.top();
+    nodePriorityQueue.pop();
+    TTree* Q = queueCurrent.Q;
+    TTree* T = queueCurrent.T;
+    size_t sizeOfTNode = T->count();
+    size_t sizeOfQNode = Q->count();
+    size_t QIndex = queueCurrent.QIndex;
+    arma::vec deltaLower = queueCurrent.deltaLower;
+    arma::vec deltaUpper = queueCurrent.deltaUpper;
+    /* v is the level of the Q node */
+    size_t v = GetLevelOfNode(Q);
+    size_t bUpper = queueCurrent.bUpperIndex;
+    size_t bLower = queueCurrent.bLowerIndex;
+    /* check to see whether we've reached the epsilon condition */
+    bool epsilonCondition = true;
+    for (size_t bIndex = queueCurrent.bLowerIndex;
+         bIndex <= queueCurrent.bUpperIndex;
+         ++bIndex)
+    {
+      if (lowerBoundLevelByBandwidth(v,bIndex) > DBL_EPSILON)
+      {
+        double constraint = (upperBoundLevelByBandwidth(v,bIndex) -
+                             lowerBoundLevelByBandwidth(v,bIndex)) /
+                             lowerBoundLevelByBandwidth(v,bIndex);
+        if (constraint > epsilon)
+        {
+          epsilonCondition = false;
+          break;
+        }
+      }
+      else
+      {
+        /* we haven't set this lower bound */
+        epsilonCondition = false;
+        break;
+      }
+    }
+    /* return */
+    if (epsilonCondition)
+    {
+      return;
+    }
+    /* we didn't meet the criteria; let's narrow the bandwidth range */
+    Winnow(v, &bLower, &bUpper);
+    if (queueCurrent.priority < PRIORITY_MAX)
+    {
+      double dMin = pow(Q->bound().MinDistance(T->bound()), 0.5);
+      double dMax = pow(Q->bound().MaxDistance(T->bound()), 0.5);
+      /* iterate through the remaining bandwidths */
+      bool meetsDeltaCondition = true;
+      /* one entry per index in [bLower, bUpper], in order */
+      std::vector<bool> deltaCondition;
+      for (size_t bIndex = bLower; bIndex <= bUpper; ++bIndex)
+      {
+        double bandwidth = bandwidths[bIndex];
+        double dl = sizeOfTNode * kernel(dMax / bandwidth);
+        double du = sizeOfTNode * kernel(dMin / bandwidth);
+        deltaLower(bIndex) = dl;
+        deltaUpper(bIndex) = du - sizeOfTNode;
+        if ((du - dl)/(lowerBoundQNodeByBandwidth(QIndex, bIndex) + dl) < delta)
+        {
+          for (size_t q = Q->begin(); q < Q->end(); ++q)
+          {
+            lowerBoundQPointByBandwidth(q,bIndex) += deltaLower(bIndex);
+            upperBoundQPointByBandwidth(q,bIndex) += deltaUpper(bIndex);
+          }
+          /* subtract the current log-likelihood */
+          upperBoundLevelByBandwidth(v, bIndex) -=
+              sizeOfQNode * log(upperBoundQNodeByBandwidth(QIndex, bIndex));
+          lowerBoundLevelByBandwidth(v, bIndex) -=
+              sizeOfQNode * log(lowerBoundQNodeByBandwidth(QIndex, bIndex));
+          /* adjust the current inner portion */
+          lowerBoundQNodeByBandwidth(QIndex, bIndex) += deltaLower(bIndex);
+          upperBoundQNodeByBandwidth(QIndex, bIndex) += deltaUpper(bIndex);
+          /* add the corrected log-likelihood */
+          upperBoundLevelByBandwidth(v, bIndex) +=
+              sizeOfQNode * log(upperBoundQNodeByBandwidth(QIndex, bIndex));
+          lowerBoundLevelByBandwidth(v, bIndex) +=
+              sizeOfQNode * log(lowerBoundQNodeByBandwidth(QIndex, bIndex));
+        }
+        /* check the delta condition with the new values */
+        if ((du - dl)/(lowerBoundQNodeByBandwidth(QIndex, bIndex) + dl) >= delta)
+        {
+          deltaCondition.push_back(false);
+          meetsDeltaCondition = false;
+        }
+        else
+        {
+          deltaCondition.push_back(true);
+        }
+      }
+      /* check whether we met the delta condition for
+       * all applicable bandwidths */
+      if (meetsDeltaCondition)
+      {
+        /* adjust the current structure, then reinsert it into the queue;
+         * the struct members are named deltaLower / deltaUpper */
+        queueCurrent.deltaLower = deltaLower;
+        queueCurrent.deltaUpper = deltaUpper;
+        queueCurrent.bUpperIndex = bUpper;
+        queueCurrent.bLowerIndex = bLower;
+        queueCurrent.priority += PRIORITY_MAX;
+        nodePriorityQueue.push(queueCurrent);
+        continue;
+      }
+      else
+      {
+        /* winnow according to the delta conditions; the iterator is
+         * compared against end() before it is dereferenced */
+        std::vector<bool>::iterator bIt = deltaCondition.begin();
+        while (bIt != deltaCondition.end() && *bIt)
+        {
+          ++bIt;
+          ++bLower;
+        }
+        bIt = deltaCondition.end();
+        --bIt;
+        while (*bIt && bIt != deltaCondition.begin())
+        {
+          --bIt;
+          --bUpper;
+        }
+      }
+    }
+    else /* the priority exceeds the maximum available */
+    {
+      deltaLower = -deltaLower;
+      deltaUpper = -deltaUpper;
+      for (size_t bIndex = bLower; bIndex <= bUpper; ++bIndex)
+      {
+        for (size_t q = Q->begin(); q < Q->end(); ++q)
+        {
+          lowerBoundQPointByBandwidth(q,bIndex) += deltaLower(bIndex);
+          upperBoundQPointByBandwidth(q,bIndex) += deltaUpper(bIndex);
+        }
+        /* subtract the current log-likelihood */
+        upperBoundLevelByBandwidth(v, bIndex) -=
+            sizeOfQNode * log(upperBoundQNodeByBandwidth(QIndex, bIndex));
+        lowerBoundLevelByBandwidth(v, bIndex) -=
+            sizeOfQNode * log(lowerBoundQNodeByBandwidth(QIndex, bIndex));
+        /* adjust the current inner portion */
+        lowerBoundQNodeByBandwidth(QIndex, bIndex) += deltaLower(bIndex);
+        upperBoundQNodeByBandwidth(QIndex, bIndex) += deltaUpper(bIndex);
+        /* add the corrected log-likelihood */
+        upperBoundLevelByBandwidth(v, bIndex) +=
+            sizeOfQNode * log(upperBoundQNodeByBandwidth(QIndex, bIndex));
+        lowerBoundLevelByBandwidth(v, bIndex) +=
+            sizeOfQNode * log(lowerBoundQNodeByBandwidth(QIndex, bIndex));
+      }
+    }
+    if (Q->is_leaf() && T->is_leaf())
+    {
+      MultiBandwidthDualTreeBase(Q, T, QIndex, bLower, bUpper);
+    }
+    double priority = pow(Q->bound().MinDistance(T->bound()), 0.5);
+    /* only a pair of two non-leaf nodes is expanded into child pairs */
+    if (!Q->is_leaf() && !T->is_leaf())
+    {
+      /* initializer order follows the queueNode members:
+       * T, Q, QIndex, deltaLower, deltaUpper, priority,
+       * bLowerIndex, bUpperIndex */
+      struct queueNode leftLeft =
+          {T->left(), Q->left(), 2*QIndex + 1, arma::vec(deltaLower),
+           arma::vec(deltaUpper), priority, bLower, bUpper};
+      struct queueNode leftRight =
+          {T->left(), Q->right(), 2*QIndex + 2, arma::vec(deltaLower),
+           arma::vec(deltaUpper), priority, bLower, bUpper};
+      struct queueNode rightLeft =
+          {T->right(), Q->left(), 2*QIndex + 1, arma::vec(deltaLower),
+           arma::vec(deltaUpper), priority, bLower, bUpper};
+      struct queueNode rightRight =
+          {T->right(), Q->right(), 2*QIndex + 2, arma::vec(deltaLower),
+           arma::vec(deltaUpper), priority, bLower, bUpper};
+      nodePriorityQueue.push(leftLeft);
+      nodePriorityQueue.push(leftRight);
+      nodePriorityQueue.push(rightLeft);
+      nodePriorityQueue.push(rightRight);
+    }
+  }
+}
+
+/**
+ * Narrow the bandwidth index range [*bLower, *bUpper] for one tree level.
+ *
+ * A bandwidth index "passes" when its level lower bound exceeds
+ * DBL_EPSILON and the relative gap (upper - lower) / lower is below delta.
+ * Scanning upward from *bLower, *bLower is moved to the last index of the
+ * leading run of passing indices; scanning downward from *bUpper, *bUpper
+ * is moved to the lowest index of the trailing run of passing indices.
+ * Either end is left untouched when the index it starts on does not pass.
+ *
+ * Neither scan reads outside [*bLower, *bUpper], so the unsigned index
+ * cannot wrap when *bLower is 0.
+ *
+ * NOTE(review): the in-class declaration in kde_dual_tree.hpp lists four
+ * parameters (bLower, bUpper, newLower, newUpper); it has to be reconciled
+ * with this three-parameter definition, which is the form the caller uses.
+ *
+ * @param level  tree level whose level bounds are inspected
+ * @param bLower in/out: lowest bandwidth index still under consideration
+ * @param bUpper in/out: highest bandwidth index still under consideration
+ */
+template<typename TKernel, typename TTree>
+void KdeDualTree<TKernel, TTree>::Winnow(size_t level,
+                                         size_t* bLower,
+                                         size_t* bUpper)
+{
+  /* bring the lower up */
+  size_t bIndex = *bLower;
+  size_t newLower = *bLower;
+  double constraint = delta;
+  if (lowerBoundLevelByBandwidth(level,bIndex) > DBL_EPSILON)
+  {
+    constraint = (upperBoundLevelByBandwidth(level,bIndex) -
+                  lowerBoundLevelByBandwidth(level,bIndex)) /
+                  lowerBoundLevelByBandwidth(level,bIndex);
+  }
+  while (constraint < delta && bIndex <= *bUpper)
+  {
+    /* bIndex passed; remember it as the new lower end */
+    newLower = bIndex;
+    if (bIndex == *bUpper)
+    {
+      /* the whole range passed; do not read past the upper end */
+      break;
+    }
+    ++bIndex;
+    if (lowerBoundLevelByBandwidth(level,bIndex) > DBL_EPSILON)
+    {
+      constraint = (upperBoundLevelByBandwidth(level,bIndex) -
+                    lowerBoundLevelByBandwidth(level,bIndex)) /
+                    lowerBoundLevelByBandwidth(level,bIndex);
+    }
+    else
+    {
+      break;
+    }
+  }
+  *bLower = newLower;
+
+  /* bring the upper down */
+  bIndex = *bUpper;
+  size_t newUpper = *bUpper;
+  constraint = delta;
+  if (lowerBoundLevelByBandwidth(level,bIndex) > DBL_EPSILON)
+  {
+    constraint = (upperBoundLevelByBandwidth(level,bIndex) -
+                  lowerBoundLevelByBandwidth(level,bIndex)) /
+                  lowerBoundLevelByBandwidth(level,bIndex);
+  }
+  while (constraint < delta && bIndex >= *bLower)
+  {
+    /* bIndex passed; remember it as the new upper end */
+    newUpper = bIndex;
+    if (bIndex == *bLower)
+    {
+      /* reached the lower end; decrementing further would leave the
+       * range (and wrap around when *bLower is 0) */
+      break;
+    }
+    --bIndex;
+    if (lowerBoundLevelByBandwidth(level,bIndex) > DBL_EPSILON)
+    {
+      constraint = (upperBoundLevelByBandwidth(level,bIndex) -
+                    lowerBoundLevelByBandwidth(level,bIndex)) /
+                    lowerBoundLevelByBandwidth(level,bIndex);
+    }
+    else
+    {
+      break;
+    }
+  }
+  *bUpper = newUpper;
+}
+
+
+template<typename TKernel, typename TTree>
void KdeDualTree<TKernel, TTree>::MultiBandwidthDualTreeBase(TTree* Q,
TTree* T, size_t QIndex,
- std::set<double> remainingBandwidths)
+ size_t lowerBIndex, size_t upperBIndex)
{
size_t sizeOfTNode = T->count();
size_t sizeOfQNode = Q->count();
@@ -65,14 +310,12 @@
{
arma::vec diff = queryPoint - referenceData.unsafe_col(t);
double distSquared = arma::dot(diff, diff);
- std::set<double>::iterator bIt = remainingBandwidths.end();
- size_t bandwidthIndex = bandwidthCount;
- while (bIt != remainingBandwidths.begin())
+ size_t bandwidthIndex = upperBIndex;
+ while (bandwidthIndex > lowerBIndex)
{
- --bIt;
--bandwidthIndex;
- double bandwidth = *bIt;
- double scaledProduct = distSquared / (bandwidth * bandwidth);
+ double bandwidth = bandwidths[bandwidthIndex];
+ double scaledProduct = pow(distSquared, 0.5) / bandwidth;
/* TODO: determine the power of the incoming argument */
double contribution = kernel(scaledProduct);
if (contribution > DBL_EPSILON)
@@ -86,13 +329,13 @@
}
}
}
- for (size_t bIndex = bandwidthCount - remainingBandwidths.size(); bIndex < remainingBandwidths.size(); ++bIndex)
+ for (size_t bIndex = lowerBIndex; bIndex <= upperBIndex; ++bIndex)
{
upperBoundQPointByBandwidth(q, bIndex) -= sizeOfTNode;
}
}
size_t levelOfQ = GetLevelOfNode(Q);
- for (size_t bIndex = bandwidthCount - remainingBandwidths.size(); bIndex < remainingBandwidths.size(); ++bIndex)
+ for (size_t bIndex = lowerBIndex; bIndex <= upperBIndex; ++bIndex)
{
/* subtract out the current log-likelihood amount for this Q node so we can readjust
* the Q node bounds by current bandwidth */
More information about the mlpack-svn
mailing list