[mlpack-svn] r10081 - mlpack/trunk/src/contrib/nslagle/myKDE

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Sun Oct 30 21:58:09 EDT 2011


Author: nslagle
Date: 2011-10-30 21:58:09 -0400 (Sun, 30 Oct 2011)
New Revision: 10081

Modified:
   mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree.hpp
   mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree_impl.hpp
Log:
contrib/nslagle: newest, not yet working KDE code; multi-bandwidth, dual-tree

Modified: mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree.hpp
===================================================================
--- mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree.hpp	2011-10-30 23:25:16 UTC (rev 10080)
+++ mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree.hpp	2011-10-31 01:58:09 UTC (rev 10081)
@@ -1,36 +1,111 @@
 #ifndef KDE_DUAL_TREE_HPP
 #define KDE_DUAL_TREE_HPP
 
-#include <mlpack/core.h>
-#include <mlpack/core/tree/spacetree.h>
-#include <mlpack/core/tree/hrectbound.h>
 #include <iostream>
+#include <queue>  /* provides std::priority_queue */
 
+#include <mlpack/core.h>
+#include <mlpack/core/kernels/gaussian_kernel.hpp>
+#include <mlpack/core/tree/binary_space_tree.hpp>
+#include <mlpack/core/tree/hrectbound.hpp>
+#include <mlpack/core/math/range.hpp>
+
 namespace mlpack
 {
-  namespace kde
+namespace kde
+{
+/* entry of KdeDualTree::nodePriorityQueue; QueueNodeCompare orders it by `priority` */
+struct queueNode
+{
+  TTree* T;  /* NOTE(review): TTree is not declared in the code shown -- confirm intended type */
+  TTree* Q;  /* presumably the query-side node paired with T -- TODO confirm */
+  arma::vec du;  /* meaning not visible in this revision -- TODO document */
+  arma::vec dl;
+  double priority;  /* the only field read by QueueNodeCompare */
+  double bLower;
+  double bUpper;
+};
+class QueueNodeCompare  /* comparison functor over queueNode::priority */
+{
+  bool reverse;  /* true: order with '>' instead of '<' */
+ public:
+  QueueNodeCompare(const bool& revparam=false) : reverse(revparam) {}
+  bool operator() (const struct queueNode& lhs,
+                   const struct queueNode& rhs) const
   {
-    template <typename TKernel>
-    class KdeDualTree
-    {
-      private:
-        /* possibly, these refer to the same object */
-        tree::BinarySpaceTree<bound::HRectBound<2> >* trainingRoot;
-        tree::BinarySpaceTree<bound::HRectBound<2> >* queryRoot;
-        arma::mat trainingData;
-        arma::mat queryData;
-        double bandwidth;
-      public:
-        /* the two data sets are different */
-        KdeDualTree (arma::mat& trainingData, arma::mat& queryData);
-        /* the training data is also the query data */
-        KdeDualTree (arma::mat& trainingData);
-        /* find a suitable bandwidth */
-        double optimizeBandwidth (double lower, double upper, size_t attempts);
-    };
-  };
+    /* used by std::priority_queue: with '<' the largest priority is on top */
+    const bool below = (lhs.priority<rhs.priority);
+    const bool above = (lhs.priority>rhs.priority);
+    return reverse ? above : below;
+  }
 };
 
+
+template <typename TKernel = kernel::GaussianKernel,
+          typename TTree = tree::BinarySpaceTree<bound::HRectBound<2> > >
+class KdeDualTree  /* multi-bandwidth, dual-tree KDE driver (work in progress) */
+{
+ private:
+  /* tree roots; both point at the same tree when only one data set is given */
+  TTree* referenceRoot;
+  TTree* queryRoot;
+  std::vector<size_t> referenceShuffledIndices;  /* handed to the TTree constructor */
+  std::vector<size_t> queryShuffledIndices;
+  arma::mat referenceData;
+  arma::mat queryData;
+  arma::mat upperBoundLevelByBandwidth;
+  arma::mat lowerBoundLevelByBandwidth;
+  arma::mat upperBoundQPointByBandwidth;
+  arma::mat lowerBoundQPointByBandwidth;
+  arma::mat upperBoundQNodeByBandwidth;
+  arma::mat lowerBoundQNodeByBandwidth;  /* was a second lowerBoundQPointByBandwidth */
+  /* relative estimate to limit bandwidth calculations */
+  double delta;
+  /* relative error with respect to the density estimate */
+  double epsilon;
+  math::Range bandwidthRange;  /* read and written through BandwidthRange() */
+  std::priority_queue<struct queueNode,
+                      std::vector<struct queueNode>,
+                      QueueNodeCompare> nodePriorityQueue;
+  size_t bandwidthCount;
+  std::set<double> bandwidths;
+  size_t levelsInTree;  /* set from queryRoot->levelsBelow() in the constructors */
+  size_t queryTreeSize;  /* set from queryRoot->treeSize() in the constructors */
+  /* NOTE(review): no destructor is declared for the trees created with `new` */
+  void SetDefaults();
+  void MultiBandwidthDualTree();
+  void MultiBandwidthDualTreeBase(TTree* Q,
+                                  TTree* T,
+                                  std::set<double> remainingBandwidths);
+  double GetPriority(TTree* nodeQ, TTree* nodeT)
+  {
+    return nodeQ->bound().MinDistance(nodeT->bound());  /* bound to bound, not bound to node */
+  }
+  size_t GetLevelOfNode(TTree* node)
+  {
+    return levelsInTree - node->levelsBelow();
+  }
+ public:
+  /* the two data sets are different */
+  KdeDualTree (arma::mat& referenceData, arma::mat& queryData);
+  /* the reference data is also the query data */
+  KdeDualTree (arma::mat& referenceData);
+  /* read-only accessors */
+  const math::Range& BandwidthRange() const { return bandwidthRange; }
+  const size_t& BandwidthCount() const { return bandwidthCount; }
+  const double& Delta() const { return delta; }
+  const double& Epsilon() const { return epsilon; }
+  /* modifiers; the reference-returning overloads allow direct assignment */
+  void BandwidthRange(double l, double u) { bandwidthRange = math::Range(l,u); }
+  size_t& BandwidthCount() { return bandwidthCount; }
+  double& Delta() { return delta; }
+  double& Epsilon() { return epsilon; }
+};
+}; /* end namespace kde */
+}; /* end namespace mlpack */
+
+#define USE_KDE_DUAL_TREE_IMPL_HPP
 #include "kde_dual_tree_impl.hpp"
+#undef USE_KDE_DUAL_TREE_IMPL_HPP
 
 #endif

Modified: mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree_impl.hpp
===================================================================
--- mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree_impl.hpp	2011-10-30 23:25:16 UTC (rev 10080)
+++ mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree_impl.hpp	2011-10-31 01:58:09 UTC (rev 10081)
@@ -3,8 +3,10 @@
 #define KDE_DUAL_TREE_IMPL_HPP
 
 #ifndef KDE_DUAL_TREE_HPP
+#ifndef USE_KDE_DUAL_TREE_IMPL_HPP
 #error "Do not include this header directly."
 #endif
+#endif
 
 using namespace mlpack;
 using namespace mlpack::kde;
@@ -14,24 +16,54 @@
 namespace kde
 {
 
-template<typename TKernel>
-KdeDualTree<TKernel>::KdeDualTree (arma::mat& train,
-                                   arma::mat& query) :
-  trainingRoot (new tree::BinarySpaceTree<bound::HRectBound<2> > (train)),
-  queryRoot (new tree::BinarySpaceTree<bound::HRectBound<2> > (query))
+template<typename TKernel, typename TTree>  /* ctor: distinct reference and query sets */
+KdeDualTree<TKernel, TTree>::KdeDualTree (arma::mat& reference,
+                                   arma::mat& query)
 {
-  trainingData = train;
+  referenceRoot = new TTree (reference);  /* one tree per data set */
+  queryRoot = new TTree (query);
+  referenceData = reference;
   queryData = query;
+  levelsInTree = queryRoot->levelsBelow();  /* both sizes come from the query tree */
+  queryTreeSize = queryRoot->treeSize();
+  SetDefaults();
 }
-template<typename TKernel>
-KdeDualTree<TKernel>::KdeDualTree (arma::mat& train) :
-  trainingRoot (new tree::BinarySpaceTree<bound::HRectBound<2> > (train)),
-  queryRoot (trainingRoot)
+
+template<typename TKernel, typename TTree>  /* ctor: the reference set doubles as the query set */
+KdeDualTree<TKernel, TTree>::KdeDualTree (arma::mat& reference)
 {
-  trainingData = train;
-  queryData = train;
+  referenceData = reference;  /* both members are copies of the same input matrix */
+  queryData = reference;
+  /* NOTE(review): copies are taken before the tree is built -- confirm whether TTree reorders `reference` */
+  referenceRoot = new TTree (reference, referenceShuffledIndices);
+  queryRoot = referenceRoot;  /* a single tree serves as both roots */
+  levelsInTree = queryRoot->levelsBelow();
+  queryTreeSize = queryRoot->treeSize();
+  SetDefaults();
 }
 
+template<typename TKernel, typename TTree>  /* resets the tunable parameters */
+void KdeDualTree<TKernel, TTree>::SetDefaults()
+{
+  BandwidthRange(0.01, 100.0);  /* default bandwidth range */
+  bandwidthCount = 10;
+  delta = epsilon = 0.05;  /* both tolerances share one default */
+}
+
+template<typename TKernel, typename TTree>
+void KdeDualTree<TKernel, TTree>::MultiBandwidthDualTreeBase(
+    TTree* Q,
+    TTree* T,
+    std::set<double> remainingBandwidths)
+{
+  /* walks every index pair in [Q->begin(), Q->end()) x [T->begin(), T->end()) */
+  for (size_t qIndex = Q->begin(); qIndex < Q->end(); ++qIndex)
+    for (size_t tIndex = T->begin(); tIndex < T->end(); ++tIndex)
+    {
+      /* the per-pair work is still empty in this revision */
+    }
+}
+
+
 };
 };
 




More information about the mlpack-svn mailing list