[mlpack-svn] r10081 - mlpack/trunk/src/contrib/nslagle/myKDE
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Sun Oct 30 21:58:09 EDT 2011
Author: nslagle
Date: 2011-10-30 21:58:09 -0400 (Sun, 30 Oct 2011)
New Revision: 10081
Modified:
mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree.hpp
mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree_impl.hpp
Log:
contrib/nslagle: newest, not yet working KDE code; multi-bandwidth, dual-tree
Modified: mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree.hpp
===================================================================
--- mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree.hpp 2011-10-30 23:25:16 UTC (rev 10080)
+++ mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree.hpp 2011-10-31 01:58:09 UTC (rev 10081)
@@ -1,36 +1,111 @@
#ifndef KDE_DUAL_TREE_HPP
#define KDE_DUAL_TREE_HPP
-#include <mlpack/core.h>
-#include <mlpack/core/tree/spacetree.h>
-#include <mlpack/core/tree/hrectbound.h>
#include <iostream>
+#include <queue>
+#include <set>
+#include <vector>
+#include <mlpack/core.h>
+#include <mlpack/core/kernels/gaussian_kernel.hpp>
+#include <mlpack/core/tree/binary_space_tree.hpp>
+#include <mlpack/core/tree/hrectbound.hpp>
+#include <mlpack/core/math/range.hpp>
+
namespace mlpack
{
- namespace kde
+namespace kde
+{
+/**
+ * One entry of the node-pair priority queue: a (reference node T, query
+ * node Q) pair plus the per-pair bookkeeping stored with it.
+ *
+ * Templated on the tree type because this struct lives outside KdeDualTree,
+ * where TTree would otherwise not be in scope.
+ */
+template <typename TTree>
+struct queueNode
+{
+ TTree* T;
+ TTree* Q;
+ arma::vec du;
+ arma::vec dl;
+ double priority;
+ double bLower;
+ double bUpper;
+};
+/**
+ * Strict-weak-ordering comparator on queueNode::priority for use with
+ * std::priority_queue. With reverse == false this is "less", so the queue's
+ * top() is the entry with the LARGEST priority; reverse == true makes top()
+ * the entry with the SMALLEST priority.
+ *
+ * NOTE(review): nodePriorityQueue in KdeDualTree default-constructs this
+ * comparator (reverse == false). If priority is filled from GetPriority
+ * (a minimum distance), the farthest pair is served first -- confirm that
+ * is the intended order.
+ */
+class QueueNodeCompare
+{
+ bool reverse;
+ public:
+ QueueNodeCompare(const bool& revparam=false) : reverse(revparam) {}
+ template <typename TTree>
+ bool operator() (const queueNode<TTree>& lhs,
+ const queueNode<TTree>& rhs) const
{
- template <typename TKernel>
- class KdeDualTree
- {
- private:
- /* possibly, these refer to the same object */
- tree::BinarySpaceTree<bound::HRectBound<2> >* trainingRoot;
- tree::BinarySpaceTree<bound::HRectBound<2> >* queryRoot;
- arma::mat trainingData;
- arma::mat queryData;
- double bandwidth;
- public:
- /* the two data sets are different */
- KdeDualTree (arma::mat& trainingData, arma::mat& queryData);
- /* the training data is also the query data */
- KdeDualTree (arma::mat& trainingData);
- /* find a suitable bandwidth */
- double optimizeBandwidth (double lower, double upper, size_t attempts);
- };
- };
+ if (reverse)
+ return (lhs.priority>rhs.priority);
+ else
+ return (lhs.priority<rhs.priority);
+ }
};
+
+/**
+ * Dual-tree kernel density estimation evaluated over a set of bandwidths
+ * (work in progress).
+ *
+ * @tparam TKernel kernel type (defaults to kernel::GaussianKernel).
+ * @tparam TTree   space tree type built over the reference and query data.
+ *
+ * The object owns the trees it allocates; in the single-data-set case
+ * queryRoot and referenceRoot point to the same tree.
+ */
+template <typename TKernel = kernel::GaussianKernel,
+ typename TTree = tree::BinarySpaceTree<bound::HRectBound<2> > >
+class KdeDualTree
+{
+ private:
+ /* possibly, these refer to the same object */
+ TTree* referenceRoot;
+ TTree* queryRoot;
+ std::vector<size_t> referenceShuffledIndices;
+ std::vector<size_t> queryShuffledIndices;
+ arma::mat referenceData;
+ arma::mat queryData;
+ arma::mat upperBoundLevelByBandwidth;
+ arma::mat lowerBoundLevelByBandwidth;
+ arma::mat upperBoundQPointByBandwidth;
+ arma::mat lowerBoundQPointByBandwidth;
+ arma::mat upperBoundQNodeByBandwidth;
+ arma::mat lowerBoundQNodeByBandwidth;
+ /* relative estimate to limit bandwidth calculations */
+ double delta;
+ /* relative error with respect to the density estimate */
+ double epsilon;
+ /* [lower, upper] range of candidate bandwidths */
+ math::Range bandwidthRange;
+ std::priority_queue<queueNode<TTree>,
+ std::vector<queueNode<TTree> >,
+ QueueNodeCompare> nodePriorityQueue;
+ size_t bandwidthCount;
+ std::set<double> bandwidths;
+ size_t levelsInTree;
+ size_t queryTreeSize;
+
+ /* not copyable: the object owns raw tree pointers (declared, not defined) */
+ KdeDualTree(const KdeDualTree&);
+ KdeDualTree& operator=(const KdeDualTree&);
+
+ void SetDefaults();
+ void MultiBandwidthDualTree();
+ void MultiBandwidthDualTreeBase(TTree* Q,
+ TTree* T,
+ std::set<double> remainingBandwidths);
+ /* priority of a node pair: minimum distance between the two node bounds */
+ double GetPriority(TTree* nodeQ, TTree* nodeT)
+ {
+ return nodeQ->bound().MinDistance(nodeT->bound());
+ }
+ size_t GetLevelOfNode(TTree* node)
+ {
+ return levelsInTree - node->levelsBelow();
+ }
+ public:
+ /* the two data sets are different */
+ KdeDualTree (arma::mat& referenceData, arma::mat& queryData);
+ /* the reference data is also the query data */
+ KdeDualTree (arma::mat& referenceData);
+ /* frees the trees; queryRoot may alias referenceRoot, so free it only once */
+ ~KdeDualTree()
+ {
+ if (queryRoot != referenceRoot)
+ delete queryRoot;
+ delete referenceRoot;
+ }
+ /* setters and getters */
+ const math::Range& BandwidthRange() const { return bandwidthRange; }
+ const size_t& BandwidthCount() const { return bandwidthCount; }
+ const double& Delta() const { return delta; }
+ const double& Epsilon() const { return epsilon; }
+
+ void BandwidthRange(double l, double u) { bandwidthRange = math::Range(l,u); }
+ size_t& BandwidthCount() { return bandwidthCount; }
+ double& Delta() { return delta; }
+ double& Epsilon() { return epsilon; }
+};
+} /* end namespace kde */
+} /* end namespace mlpack */
+
+#define USE_KDE_DUAL_TREE_IMPL_HPP
#include "kde_dual_tree_impl.hpp"
+#undef USE_KDE_DUAL_TREE_IMPL_HPP
#endif
Modified: mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree_impl.hpp
===================================================================
--- mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree_impl.hpp 2011-10-30 23:25:16 UTC (rev 10080)
+++ mlpack/trunk/src/contrib/nslagle/myKDE/kde_dual_tree_impl.hpp 2011-10-31 01:58:09 UTC (rev 10081)
@@ -3,8 +3,10 @@
#define KDE_DUAL_TREE_IMPL_HPP
#ifndef KDE_DUAL_TREE_HPP
+#ifndef USE_KDE_DUAL_TREE_IMPL_HPP
#error "Do not include this header directly."
#endif
+#endif
using namespace mlpack;
using namespace mlpack::kde;
@@ -14,24 +16,54 @@
namespace kde
{
-template<typename TKernel>
-KdeDualTree<TKernel>::KdeDualTree (arma::mat& train,
- arma::mat& query) :
- trainingRoot (new tree::BinarySpaceTree<bound::HRectBound<2> > (train)),
- queryRoot (new tree::BinarySpaceTree<bound::HRectBound<2> > (query))
+/**
+ * Build separate reference and query trees over two distinct data sets.
+ *
+ * @param reference reference points; passed to the TTree constructor by
+ *                  non-const reference (it may be reordered -- confirm).
+ * @param query     query points; same caveat as reference.
+ */
+template<typename TKernel, typename TTree>
+KdeDualTree<TKernel, TTree>::KdeDualTree (arma::mat& reference,
+ arma::mat& query)
{
- trainingData = train;
+ /* Allocate in the body rather than an initializer list: the index vectors
+ are declared after the root pointers, so they would not yet exist there. */
+ referenceRoot = new TTree (reference, referenceShuffledIndices);
+ queryRoot = new TTree (query, queryShuffledIndices);
+ /* copy after the trees are built so the stored data matches their order */
+ referenceData = reference;
queryData = query;
+ levelsInTree = queryRoot->levelsBelow();
+ queryTreeSize = queryRoot->treeSize();
+ SetDefaults();
}
-template<typename TKernel>
-KdeDualTree<TKernel>::KdeDualTree (arma::mat& train) :
- trainingRoot (new tree::BinarySpaceTree<bound::HRectBound<2> > (train)),
- queryRoot (trainingRoot)
+
+/**
+ * Build a single tree used as both the reference and the query tree
+ * (monochromatic case).
+ *
+ * @param reference points used as both reference and query set; passed to
+ *                  the TTree constructor by non-const reference (it may be
+ *                  reordered -- confirm).
+ */
+template<typename TKernel, typename TTree>
+KdeDualTree<TKernel, TTree>::KdeDualTree (arma::mat& reference)
{
- trainingData = train;
- queryData = train;
+ referenceRoot = new TTree (reference, referenceShuffledIndices);
+ /* one shared tree: the destructor must not free it twice */
+ queryRoot = referenceRoot;
+ queryShuffledIndices = referenceShuffledIndices;
+ /* copy after the tree is built, matching the two-data-set constructor */
+ referenceData = reference;
+ queryData = reference;
+
+ levelsInTree = queryRoot->levelsBelow();
+ queryTreeSize = queryRoot->treeSize();
+ SetDefaults();
}
+/**
+ * Initialise the tunable parameters: bandwidth range [0.01, 100.0],
+ * 10 bandwidths, and delta = epsilon = 0.05.
+ */
+template<typename TKernel, typename TTree>
+void KdeDualTree<TKernel, TTree>::SetDefaults()
+{
+ BandwidthRange(0.01, 100.0);
+ bandwidthCount = 10;
+ delta = epsilon = 0.05;
+}
+
+/**
+ * Base case of the dual-tree recursion: visits every (query point,
+ * reference point) index pair drawn from nodes Q and T.
+ *
+ * TODO: the pairwise kernel contributions for remainingBandwidths are not
+ * computed yet; the loop body is intentionally empty.
+ */
+template<typename TKernel, typename TTree>
+void KdeDualTree<TKernel, TTree>::MultiBandwidthDualTreeBase(TTree* Q,
+ TTree* T,
+ std::set<double> remainingBandwidths)
+{
+ for (size_t queryIndex = Q->begin(); queryIndex < Q->end(); ++queryIndex)
+ {
+ for (size_t referenceIndex = T->begin(); referenceIndex < T->end();
+ ++referenceIndex)
+ {
+ /* per-pair work goes here */
+ }
+ }
+}
+
+
};
};
More information about the mlpack-svn mailing list