[mlpack-svn] r13195 - mlpack/trunk/src/mlpack/methods/det
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Tue Jul 10 16:37:13 EDT 2012
Author: rcurtin
Date: 2012-07-10 16:37:13 -0400 (Tue, 10 Jul 2012)
New Revision: 13195
Modified:
mlpack/trunk/src/mlpack/methods/det/dtree.hpp
mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
Log:
Use log negative error to prevent calculation overflows (a first step towards
making cT and eT unnecessary). The algorithm has not yet been changed... so
LogNegativeError is always used as '-std::exp(LogNegativeError(.))' at the
moment.
Modified: mlpack/trunk/src/mlpack/methods/det/dtree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree.hpp 2012-07-10 18:58:14 UTC (rev 13194)
+++ mlpack/trunk/src/mlpack/methods/det/dtree.hpp 2012-07-10 20:37:13 UTC (rev 13195)
@@ -31,6 +31,28 @@
 * overflow, so you should either normalize your data in the (-1, 1)
 * hypercube, use long double, or modify this code to perform computations
 * using logarithms.
+ *
+ * A density estimation tree is similar to both a decision tree and a space
+ * partitioning tree (like a kd-tree). Each leaf represents a constant-density
+ * hyper-rectangle. The tree is constructed in such a way as to minimize the
+ * integrated square error between the probability distribution of the tree and
+ * the observed probability distribution of the data. Because the tree is
+ * similar to a decision tree, the density estimation tree can provide very fast
+ * density estimates for a given point.
+ *
+ * For more information, see the following paper:
+ *
+ * @code
+ * @incollection{ram2011,
+ * author = {Ram, Parikshit and Gray, Alexander G.},
+ * title = {Density estimation trees},
+ * booktitle = {{Proceedings of the 17th ACM SIGKDD International Conference
+ * on Knowledge Discovery and Data Mining}},
+ * series = {KDD '11},
+ * year = {2011},
+ * pages = {627--635}
+ * }
+ * @endcode
*/
template<typename eT = float,
typename cT = long double>
@@ -123,7 +145,7 @@
////////////////////// Private Functions ////////////////////////////////////
private:
- cT ComputeNodeError_(size_t total_points);
+ double LogNegativeError(size_t total_points);
bool FindSplit_(MatType* data,
size_t* split_dim,
Modified: mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp 2012-07-10 18:58:14 UTC (rev 13194)
+++ mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp 2012-07-10 20:37:13 UTC (rev 13195)
@@ -15,34 +15,19 @@
namespace mlpack{
namespace det {
-// This function computes the l2-error of a given node from the formula
-// R(t) = -|t|^2 / (N^2 V_t).
+// This function computes the log-l2-negative-error of a given node from the
+// formula log(-R(t)) = log(|t|^2 / (N^2 V_t)), where R(t) = -|t|^2 / (N^2 V_t).
template<typename eT, typename cT>
-cT DTree<eT, cT>::ComputeNodeError_(size_t total_points)
+double DTree<eT, cT>::LogNegativeError(size_t total_points)
{
- size_t node_size = end_ - start_;
+ // log(-R(t)) = log(|t|^2 / (N^2 V_t)) = 2 log(|t|) - 2 log(N) - log(V_t).
+ double error = 2 * std::log((double) (end_ - start_)) -
+ 2 * std::log((double) total_points);
+ for (size_t i = 0; i < max_vals_->n_elem; ++i)
+ error -= std::log((*max_vals_)[i] - (*min_vals_)[i]);
- cT log_vol_t = 0;
- for (size_t i = 0; i < max_vals_->n_elem; i++)
- if ((*max_vals_)[i] - (*min_vals_)[i] > 0.0)
- // Use log to prevent overflow.
- log_vol_t += (cT) std::log((*max_vals_)[i] - (*min_vals_)[i]);
-
- // Check for overflow -- if it doesn't work, try higher precision by default
- // cT = long double, so if you can't work with that there is nothing else you
- // can do - except computing error using log and dealing with everything in
- // log form.
- assert(std::exp(log_vol_t) > 0.0);
-
- cT log_neg_error = 2 * std::log((cT) node_size / (cT) total_points)
- - log_vol_t;
-
- assert(std::exp(log_neg_error) > 0.0);
-
- cT error = -1.0 * std::exp(log_neg_error);
-
return error;
-} // ComputeNodeError
+}
// This function finds the best split with respect to the L2-error,
// but trying all possible splits. The dataset is the full data set but the
@@ -285,7 +270,7 @@
left_(NULL),
right_(NULL)
{
- error_ = ComputeNodeError_(total_points);
+ error_ = -std::exp(LogNegativeError(total_points));
bucket_tag_ = -1;
root_ = true;
@@ -304,7 +289,7 @@
GetMaxMinVals_(data, max_vals_, min_vals_);
- error_ = ComputeNodeError_(data->n_cols);
+ error_ = -std::exp(LogNegativeError(data->n_cols));
bucket_tag_ = -1;
root_ = true;
@@ -344,7 +329,7 @@
left_(NULL),
right_(NULL)
{
- error_ = ComputeNodeError_(total_points);
+ error_ = -std::exp(LogNegativeError(total_points));
bucket_tag_ = -1;
root_ = false;
More information about the mlpack-svn
mailing list