[mlpack-svn] r10131 - mlpack/trunk/src/contrib/nslagle/myKDE

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Thu Nov 3 13:56:35 EDT 2011


Author: nslagle
Date: 2011-11-03 13:56:35 -0400 (Thu, 03 Nov 2011)
New Revision: 10131

Modified:
   mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree.hpp
   mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree_impl.hpp
Log:
mlpack/contrib/nslagle: add the multi-tree base function

Modified: mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree.hpp
===================================================================
--- mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree.hpp	2011-11-03 17:43:28 UTC (rev 10130)
+++ mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree.hpp	2011-11-03 17:56:35 UTC (rev 10131)
@@ -46,6 +46,7 @@
 class KdeDualTree
 {
  private:
+  TKernel kernel;
   /* possibly, these refer to the same object */
   TTree* referenceRoot;
   TTree* queryRoot;
@@ -58,7 +59,7 @@
   arma::mat upperBoundQPointByBandwidth;
   arma::mat lowerBoundQPointByBandwidth;
   arma::mat upperBoundQNodeByBandwidth;
-  arma::mat lowerBoundQPointByBandwidth;
+  arma::mat lowerBoundQNodeByBandwidth;
   /* relative estimate to limit bandwidth calculations */
   double delta;
   /* relative error with respect to the density estimate */
@@ -75,7 +76,7 @@
   void SetDefaults();
   void MultiBandwidthDualTree();
   void MultiBandwidthDualTreeBase(TTree* Q,
-                                  TTree* T,
+                                  TTree* T, size_t QIndex,
                                   std::set<double> remainingBandwidths);
   double GetPriority(TTree* nodeQ, TTree* nodeT)
   {

Modified: mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree_impl.hpp
===================================================================
--- mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree_impl.hpp	2011-11-03 17:43:28 UTC (rev 10130)
+++ mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree_impl.hpp	2011-11-03 17:56:35 UTC (rev 10131)
@@ -48,22 +48,104 @@
   BandwidthRange(0.01, 100.0);
   bandwidthCount = 10;
   delta = epsilon = 0.05;
+  kernel = TKernel(1.0);
 }

 template<typename TKernel, typename TTree>
 void KdeDualTree<TKernel, TTree>::MultiBandwidthDualTreeBase(TTree* Q,
-                                TTree* T,
+                                TTree* T, size_t QIndex,
                                  std::set<double> remainingBandwidths)
 {
+  /* Base case: directly evaluate every (query point, reference point) pair
+   *   of the nodes Q and T for each bandwidth still under consideration,
+   *   then refresh the bounds of node Q (row QIndex of the node bound
+   *   matrices) and the log-likelihood bounds of Q's level.
+   *
+   * The k = remainingBandwidths.size() remaining bandwidths are mapped, in
+   *   ascending order, onto the last k bandwidth columns
+   *   [bandwidthCount - k, bandwidthCount); presumably they are the k
+   *   largest bandwidths of an ascending grid -- confirm against callers. */
+  const size_t sizeOfTNode = T->count();
+  const size_t sizeOfQNode = Q->count();
+  const size_t firstRemainingIndex =
+      bandwidthCount - remainingBandwidths.size();
   for (size_t q = Q->begin(); q < Q->end(); ++q)
   {
+    arma::vec queryPoint = queryData.unsafe_col(q);
     for (size_t t = T->begin(); t < T->end(); ++t)
     {
+      arma::vec diff = queryPoint - referenceData.unsafe_col(t);
+      double distSquared = arma::dot(diff, diff);
+      /* walk the bandwidths from largest to smallest; the early break assumes
+       *   the kernel is non-increasing in its argument, so once one
+       *   contribution is negligible, those of all smaller bandwidths are too */
+      std::set<double>::iterator bIt = remainingBandwidths.end();
+      size_t bandwidthIndex = bandwidthCount;
+      while (bIt != remainingBandwidths.begin())
+      {
+        --bIt;
+        --bandwidthIndex;
+        double bandwidth = *bIt;
+        double scaledProduct = distSquared / (bandwidth * bandwidth);
+        /* TODO: determine the power of the incoming argument */
+        double contribution = kernel(scaledProduct);
+        if (contribution > DBL_EPSILON)
+        {
+          upperBoundQPointByBandwidth(q, bandwidthIndex) += contribution;
+          lowerBoundQPointByBandwidth(q, bandwidthIndex) += contribution;
+        }
+        else
+        {
+          break;
+        }
+      }
     }
+    /* T's points are now accounted for exactly; presumably the upper bounds
+     *   had counted each of them at a contribution of 1 -- confirm against
+     *   the initialization */
+    for (size_t bIndex = firstRemainingIndex; bIndex < bandwidthCount; ++bIndex)
+    {
+      upperBoundQPointByBandwidth(q, bIndex) -= sizeOfTNode;
+    }
   }
+  size_t levelOfQ = GetLevelOfNode(Q);
+  for (size_t bIndex = firstRemainingIndex; bIndex < bandwidthCount; ++bIndex)
+  {
+    /* subtract out the current log-likelihood amount for this Q node so we can readjust
+     *   the Q node bounds by current bandwidth.
+     * NOTE(review): log() of a non-positive node bound yields -inf/NaN, which
+     *   would poison the level bounds; confirm the node bounds are always
+     *   initialized to positive values */
+    upperBoundLevelByBandwidth(levelOfQ, bIndex) -=
+        sizeOfQNode * log(upperBoundQNodeByBandwidth(QIndex, bIndex));
+    lowerBoundLevelByBandwidth(levelOfQ, bIndex) -=
+        sizeOfQNode * log(lowerBoundQNodeByBandwidth(QIndex, bIndex));
+    /* the node's bounds are the loosest of its points' bounds */
+    double minimumLower = lowerBoundQPointByBandwidth(Q->begin(), bIndex);
+    double maximumUpper = upperBoundQPointByBandwidth(Q->begin(), bIndex);
+    for (size_t q = Q->begin(); q < Q->end(); ++q)
+    {
+      if (lowerBoundQPointByBandwidth(q,bIndex) < minimumLower)
+      {
+        minimumLower = lowerBoundQPointByBandwidth(q,bIndex);
+      }
+      if (upperBoundQPointByBandwidth(q,bIndex) > maximumUpper)
+      {
+        maximumUpper = upperBoundQPointByBandwidth(q,bIndex);
+      }
+    }
+    /* adjust Q node bounds, then add the new quantities to the level by bandwidth
+     *   log-likelihood bounds; the point upper bounds already had sizeOfTNode
+     *   removed above, so it is not subtracted a second time here */
+    lowerBoundQNodeByBandwidth(QIndex, bIndex) = minimumLower;
+    upperBoundQNodeByBandwidth(QIndex, bIndex) = maximumUpper;
+    upperBoundLevelByBandwidth(levelOfQ, bIndex) +=
+        sizeOfQNode * log(upperBoundQNodeByBandwidth(QIndex, bIndex));
+    lowerBoundLevelByBandwidth(levelOfQ, bIndex) +=
+        sizeOfQNode * log(lowerBoundQNodeByBandwidth(QIndex, bIndex));
+  }
 }

-
 };
 };
 




More information about the mlpack-svn mailing list