[mlpack-svn] r13307 - mlpack/trunk/src/mlpack/methods/det
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Aug 1 16:19:44 EDT 2012
Author: rcurtin
Date: 2012-08-01 16:19:43 -0400 (Wed, 01 Aug 2012)
New Revision: 13307
Modified:
mlpack/trunk/src/mlpack/methods/det/dt_main.cpp
mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp
mlpack/trunk/src/mlpack/methods/det/dtree.hpp
mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
Log:
Convert all of DTree (with the exception of ComputeVariableImportance()) to work
in the log domain. This was fairly difficult...
Modified: mlpack/trunk/src/mlpack/methods/det/dt_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dt_main.cpp 2012-08-01 02:18:15 UTC (rev 13306)
+++ mlpack/trunk/src/mlpack/methods/det/dt_main.cpp 2012-08-01 20:19:43 UTC (rev 13307)
@@ -158,13 +158,13 @@
if (fp != NULL)
{
- dtreeOpt->WriteTree(0, fp);
+ dtreeOpt->WriteTree(fp);
fclose(fp);
}
}
else
{
- dtreeOpt->WriteTree(0, stdout);
+ dtreeOpt->WriteTree(stdout);
printf("\n");
}
}
Modified: mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp 2012-08-01 02:18:15 UTC (rev 13306)
+++ mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp 2012-08-01 20:19:43 UTC (rev 13307)
@@ -24,7 +24,7 @@
string leaf_class_membership_file = "")
{
// Tag the leaves with numbers.
- int num_leaves = dtree->TagTree(0);
+ int num_leaves = dtree->TagTree();
arma::Mat<size_t> table(num_leaves, num_classes);
table.zeros();
@@ -160,20 +160,20 @@
while (dtree->SubtreeLeaves() > 1)
{
std::pair<double, double> tree_seq(old_alpha,
- -1.0 * dtree->SubtreeLeavesError());
+ dtree->SubtreeLeavesLogNegError());
pruned_sequence.push_back(tree_seq);
old_alpha = alpha;
- alpha = dtree->PruneAndUpdate(old_alpha, useVolumeReg);
+ alpha = dtree->PruneAndUpdate(old_alpha, dataset->n_cols, useVolumeReg);
// Some sanity checks.
assert((alpha < std::numeric_limits<double>::max()) ||
(dtree->SubtreeLeaves() == 1));
assert(alpha > old_alpha);
- assert(dtree->SubtreeLeavesError() >= -1.0 * tree_seq.second);
+ assert(dtree->SubtreeLeavesLogNegError() < tree_seq.second);
}
std::pair<double, double> tree_seq(old_alpha,
- -1.0 * dtree->SubtreeLeavesError());
+ dtree->SubtreeLeavesLogNegError());
pruned_sequence.push_back(tree_seq);
Log::Info << pruned_sequence.size() << " trees in the sequence; max_alpha: "
@@ -242,12 +242,15 @@
val_cv += dtree_cv->ComputeValue(test_point);
}
- // Update the cv error value.
- it->second -= 2.0 * val_cv / (double) dataset->n_cols;
+ // Update the cv error value by mapping out of log-space then back into
+ // it, using long doubles.
+ long double notLogVal = -std::exp((long double) it->second) -
+ 2.0 * val_cv / (double) dataset->n_cols;
+ it->second = (double) std::log(-notLogVal);
// Determine the new alpha value and prune accordingly.
old_alpha = sqrt(((it + 1)->first) * ((it + 2)->first));
- alpha = dtree_cv->PruneAndUpdate(old_alpha, useVolumeReg);
+ alpha = dtree_cv->PruneAndUpdate(old_alpha, train->n_cols, useVolumeReg);
}
// Compute test values for this state of the tree.
@@ -259,7 +262,9 @@
}
// Update the cv error value.
- it->second -= 2.0 * val_cv / (double) dataset->n_cols;
+ long double notLogVal = -std::exp((long double) it->second) -
+ 2.0 * val_cv / (double) dataset->n_cols;
+ it->second = (double) std::log(-notLogVal);
test.reset();
delete train;
@@ -300,15 +305,16 @@
minLeafSize);
// Prune with optimal alpha.
- while ((old_alpha < optimal_alpha) && (dtree_opt->SubtreeLeaves() > 1))
+ while ((old_alpha > optimal_alpha) && (dtree_opt->SubtreeLeaves() > 1))
{
old_alpha = alpha;
- alpha = dtree_opt->PruneAndUpdate(old_alpha, useVolumeReg);
+ alpha = dtree_opt->PruneAndUpdate(old_alpha, new_dataset->n_cols,
+ useVolumeReg);
// Some sanity checks.
assert((alpha < numeric_limits<double>::max()) ||
(dtree_opt->SubtreeLeaves() == 1));
- assert(alpha > old_alpha);
+ assert(alpha < old_alpha);
}
Log::Info << dtree_opt->SubtreeLeaves()
Modified: mlpack/trunk/src/mlpack/methods/det/dtree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree.hpp 2012-08-01 02:18:15 UTC (rev 13306)
+++ mlpack/trunk/src/mlpack/methods/det/dtree.hpp 2012-08-01 20:19:43 UTC (rev 13307)
@@ -58,91 +58,248 @@
typename cT = long double>
class DTree
{
+ public:
+ /**
+ * Create an empty density estimation tree.
+ */
+ DTree();
+
+ /**
+ * Create a density estimation tree with the given bounds and the given number
+ * of total points. Children will not be created.
+ *
+ * @param maxVals Maximum values of the bounding box.
+ * @param minVals Minimum values of the bounding box.
+ * @param totalPoints Total number of points in the dataset.
+ */
+ DTree(const arma::vec& maxVals,
+ const arma::vec& minVals,
+ const size_t totalPoints);
+
+ /**
+ * Create a density estimation tree on the given data. The bounding box is
+ * computed from the minimum and maximum value of each dimension of the
+ * data; children will not be created (Grow() builds the tree). This
+ * constructor only reads the data and does not reorder it.
+ *
+ * @param data Dataset to build tree on.
+ */
+ DTree(arma::mat& data);
+
+ /**
+ * Create a child node of a density estimation tree given the bounding box
+ * specified by maxVals and minVals, using the size given in start and end and
+ * the specified error. Children of this node will not be created
+ * recursively.
+ *
+ * @param maxVals Upper bound of bounding box.
+ * @param minVals Lower bound of bounding box.
+ * @param start Start of points represented by this node in the data matrix.
+ * @param end Index one past the last point of this node in the data matrix.
+ * @param logNegError Log-negative error of this node.
+ */
+ DTree(const arma::vec& maxVals,
+ const arma::vec& minVals,
+ const size_t start,
+ const size_t end,
+ const double logNegError);
+
+ /**
+ * Create a child node of a density estimation tree given the bounding box
+ * specified by maxVals and minVals, using the size given in start and end,
+ * and calculating the error with the total number of points given. Children
+ * of this node will not be created recursively.
+ *
+ * @param maxVals Upper bound of bounding box.
+ * @param minVals Lower bound of bounding box.
+ * @param totalPoints Total number of points in the dataset.
+ * @param start Start of points represented by this node in the data matrix.
+ * @param end Index one past the last point of this node in the data matrix.
+ */
+ DTree(const arma::vec& maxVals,
+ const arma::vec& minVals,
+ const size_t totalPoints,
+ const size_t start,
+ const size_t end);
+
+ //! Clean up memory allocated by the tree.
+ ~DTree();
+
+ /**
+ * Greedily expand the tree. The points in the dataset will be reordered
+ * during tree growth.
+ *
+ * @param data Dataset to build tree on.
+ * @param oldFromNew Mappings from old points to new points.
+ * @param useVolReg If true, volume regularization is used.
+ * @param maxLeafSize Maximum size of a leaf.
+ * @param minLeafSize Minimum size of a leaf.
+ * @return The minimum (log-domain) g(t) value found while growing the tree.
+ */
+ double Grow(arma::mat& data,
+ arma::Col<size_t>& oldFromNew,
+ const bool useVolReg = false,
+ const size_t maxLeafSize = 10,
+ const size_t minLeafSize = 5);
+
+ /**
+ * Perform alpha pruning on a tree. Returns the new value of alpha.
+ *
+ * @param oldAlpha Old value of alpha.
+ * @param points Total number of points in dataset.
+ * @param useVolReg If true, volume regularization is used.
+ * @return New value of alpha.
+ */
+ double PruneAndUpdate(const double oldAlpha,
+ const size_t points,
+ const bool useVolReg = false);
+
+ /**
+ * Compute the density estimate of a given query point.
+ *
+ * @param query Point to estimate density of.
+ */
+ double ComputeValue(const arma::vec& query) const;
+
+ /**
+ * Print the tree in a depth-first manner (this function is called
+ * recursively).
+ *
+ * @param fp File to write the tree to.
+ * @param level Level of the tree (should start at 0).
+ */
+ void WriteTree(FILE *fp, const size_t level = 0) const;
+
+ /**
+ * Index the buckets for possible usage later; this results in every leaf in
+ * the tree having a specific tag (accessible with BucketTag()). This
+ * function calls itself recursively.
+ *
+ * @param tag Tag for the next leaf; leave at 0 for the initial call.
+ * @return The tag for the next leaf (tag plus the leaves in this subtree).
+ */
+ int TagTree(const int tag = 0);
+
+ /**
+ * Return the tag of the leaf containing the query. This is useful for
+ * generating class memberships.
+ *
+ * @param query Query to search for.
+ */
+ int FindBucket(const arma::vec& query) const;
+
+ /**
+ * Compute the variable importance of each dimension in the learned tree.
+ *
+ * @param importances Vector to store the calculated importances in.
+ */
+ void ComputeVariableImportance(arma::vec& importances) const;
+
+ /**
+ * Compute the log-negative-error for this node, given the total number of
+ * points in the dataset.
+ *
+ * @param totalPoints Total number of points in the dataset.
+ */
+ inline double LogNegativeError(const size_t totalPoints) const;
+
private:
// The indices in the complete set of points
// (after all forms of swapping in the original data
// matrix to align all the points in a node
// consecutively in the matrix. The 'old_from_new' array
// maps the points back to their original indices.
+
+ //! The index of the first point in the dataset contained in this node (and
+ //! its children).
size_t start;
+ //! The index one past the last point in the dataset contained in this node
+ //! (and its children).
size_t end;
- // since we are using uniform density, we need
- // the max and min of every dimension for every node
+ //! Upper bound of the bounding box for this node (maximum per dimension).
arma::vec maxVals;
+ //! Lower bound of the bounding box for this node (minimum per dimension).
arma::vec minVals;
- // The split dim for this node
+ //! The splitting dimension for this node.
size_t splitDim;
- // The split val on that dim
+ //! The split value on the splitting dimension for this node.
double splitValue;
- // L2-error of the node
- double error;
+ //! log-negative-L2-error of the node.
+ double logNegError;
- // sum of the error of the leaves of the subtree
- double subtreeLeavesError;
+ //! Log of the negated sum of the errors of the leaves of the subtree.
+ double subtreeLeavesLogNegError;
- // number of leaves of the subtree
+ //! Number of leaves of the subtree.
size_t subtreeLeaves;
- // flag to indicate if this is the root node
- // used to check whether the query point is
- // within the range
+ //! If true, this node is the root of the tree.
bool root;
- // ratio of number of points in the node to the
- // total number of points (|t| / N)
+ //! Ratio of the number of points in the node to the total number of points.
double ratio;
- // the inverse of volume of the node
- double vTInv;
+ //! The logarithm of the volume of the node.
+ double logVolume;
- // sum of the reciprocal of the inverse v_ts
- // the leaves of this subtree
- double subtreeLeavesVTInv;
-
- // the tag for the leaf used for hashing points
+ //! The tag for the leaf, used for hashing points.
int bucketTag;
- // The children
+ //! Log of the upper part of the alpha sum; used for pruning.
+ double alphaUpper;
+
+ //! The left child.
DTree<eT, cT> *left;
+ //! The right child.
DTree<eT, cT> *right;
-public:
-
- ////////////////////// Getters and Setters //////////////////////////////////
+ public:
+ //! Return the starting index of points contained in this node.
size_t Start() const { return start; }
-
+ //! Return the first index of a point not contained in this node.
size_t End() const { return end; }
-
+ //! Return the split dimension of this node.
size_t SplitDim() const { return splitDim; }
-
+ //! Return the split value of this node.
double SplitValue() const { return splitValue; }
-
- double Error() const { return error; }
-
- double SubtreeLeavesError() const { return subtreeLeavesError; }
-
+ //! Return the log negative error of this node.
+ double LogNegError() const { return logNegError; }
+ //! Return the log of the negated summed error of this subtree's leaves.
+ double SubtreeLeavesLogNegError() const { return subtreeLeavesLogNegError; }
+ //! Return the number of leaves which are descendants of this node.
size_t SubtreeLeaves() const { return subtreeLeaves; }
-
+ //! Return the ratio of points in this node to the points in the whole
+ //! dataset.
double Ratio() const { return ratio; }
-
- double VTInv() const { return vTInv; }
-
- double SubtreeLeavesVTInv() const { return subtreeLeavesVTInv; }
-
+ //! Return the logarithm of the volume of this node.
+ double LogVolume() const { return logVolume; }
+ //! Return the left child.
DTree<eT, cT>* Left() const { return left; }
+ //! Return the right child.
DTree<eT, cT>* Right() const { return right; }
-
+ //! Return whether or not this is the root of the tree.
bool Root() const { return root; }
+ //! Return the log of the upper part of the alpha sum.
+ double AlphaUpper() const { return alphaUpper; }
- ////////////////////// Private Functions ////////////////////////////////////
+ //! Return the maximum values.
+ const arma::vec& MaxVals() const { return maxVals; }
+ //! Modify the maximum values.
+ arma::vec& MaxVals() { return maxVals; }
+
+ //! Return the minimum values.
+ const arma::vec& MinVals() const { return minVals; }
+ //! Modify the minimum values.
+ arma::vec& MinVals() { return minVals; }
+
private:
- inline double LogNegativeError(const size_t total_points) const;
+ // Utility methods.
+ /**
+ * Find the dimension to split on.
+ */
bool FindSplit(const arma::mat& data,
size_t& splitDim,
double& splitValue,
@@ -151,71 +308,18 @@
const size_t maxLeafSize = 10,
const size_t minLeafSize = 5) const;
+ /**
+ * Split the data, returning the number of points left of the split.
+ */
size_t SplitData(arma::mat& data,
const size_t splitDim,
const double splitValue,
arma::Col<size_t>& oldFromNew) const;
+ /**
+ * Return whether a query point is within the range of this node.
+ */
inline bool WithinRange(const arma::vec& query) const;
-
- ///////////////////// Public Functions //////////////////////////////////////
- public:
-
- DTree();
-
- // Root node initializer
- // with the bounding box of the data
- // it contains instead of just the data.
- DTree(const arma::vec& maxVals,
- const arma::vec& minVals,
- const size_t totalPoints);
-
- // Root node initializer
- // with the data, no bounding box.
- DTree(arma::mat& data);
-
- // Non-root node initializers
- DTree(const arma::vec& maxVals,
- const arma::vec& minVals,
- const size_t start,
- const size_t end,
- const double error);
-
- DTree(const arma::vec& maxVals,
- const arma::vec& minVals,
- const size_t totalPoints,
- const size_t start,
- const size_t end);
-
- ~DTree();
-
- // Greedily expand the tree
- double Grow(arma::mat& data,
- arma::Col<size_t>& oldFromNew,
- const bool useVolReg = false,
- const size_t maxLeafSize = 10,
- const size_t minLeafSize = 5);
-
- // perform alpha pruning on the tree
- double PruneAndUpdate(const double old_alpha, const bool useVolReg = false);
-
- // compute the density at a given point
- double ComputeValue(const arma::vec& query) const;
-
- // print the tree (in a DFS manner)
- void WriteTree(size_t level, FILE *fp);
-
- // indexing the buckets for possible usage later
- int TagTree(int tag);
-
- // This is used to generate the class membership
- // of a learned tree.
- int FindBucket(const arma::vec& query) const;
-
- // This computes the variable importance list
- // for the learned tree.
- void ComputeVariableImportance(arma::vec& importances) const;
-
}; // Class DTree
}; // namespace det
Modified: mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp 2012-08-01 02:18:15 UTC (rev 13306)
+++ mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp 2012-08-01 20:19:43 UTC (rev 13307)
@@ -15,6 +15,112 @@
namespace mlpack {
namespace det {
+template<typename eT, typename cT>
+DTree<eT, cT>::DTree() :
+ start(0),
+ end(0),
+ logNegError(-DBL_MAX),
+ root(true),
+ bucketTag(-1),
+ left(NULL),
+ right(NULL)
+{ /* Nothing to do. */ }
+
+
+// Root node initializers
+template<typename eT, typename cT>
+DTree<eT, cT>::DTree(const arma::vec& maxVals,
+ const arma::vec& minVals,
+ const size_t totalPoints) :
+ start(0),
+ end(totalPoints),
+ maxVals(maxVals),
+ minVals(minVals),
+ logNegError(LogNegativeError(totalPoints)),
+ root(true),
+ bucketTag(-1),
+ left(NULL),
+ right(NULL)
+{ /* Nothing to do. */ }
+
+template<typename eT, typename cT>
+DTree<eT, cT>::DTree(arma::mat& data) :
+ start(0),
+ end(data.n_cols),
+ left(NULL),
+ right(NULL)
+{
+ maxVals.set_size(data.n_rows);
+ minVals.set_size(data.n_rows);
+
+ // Initialize to first column; values will be overwritten if necessary.
+ maxVals = data.col(0);
+ minVals = data.col(0);
+
+ // Loop over data to extract maximum and minimum values in each dimension.
+ for (size_t i = 1; i < data.n_cols; ++i)
+ {
+ for (size_t j = 0; j < data.n_rows; ++j)
+ {
+ if (data(j, i) > maxVals[j])
+ maxVals[j] = data(j, i);
+ if (data(j, i) < minVals[j])
+ minVals[j] = data(j, i);
+ }
+ }
+
+ logNegError = LogNegativeError(data.n_cols);
+
+ bucketTag = -1;
+ root = true;
+}
+
+
+// Non-root node initializers
+template<typename eT, typename cT>
+DTree<eT, cT>::DTree(const arma::vec& maxVals,
+ const arma::vec& minVals,
+ const size_t start,
+ const size_t end,
+ const double logNegError) :
+ start(start),
+ end(end),
+ maxVals(maxVals),
+ minVals(minVals),
+ logNegError(logNegError),
+ root(false),
+ bucketTag(-1),
+ left(NULL),
+ right(NULL)
+{ /* Nothing to do. */ }
+
+template<typename eT, typename cT>
+DTree<eT, cT>::DTree(const arma::vec& maxVals,
+ const arma::vec& minVals,
+ const size_t totalPoints,
+ const size_t start,
+ const size_t end) :
+ start(start),
+ end(end),
+ maxVals(maxVals),
+ minVals(minVals),
+ logNegError(LogNegativeError(totalPoints)),
+ root(false),
+ bucketTag(-1),
+ left(NULL),
+ right(NULL)
+{ /* Nothing to do. */ }
+
+template<typename eT, typename cT>
+DTree<eT, cT>::~DTree()
+{
+ if (left != NULL)
+ delete left;
+
+ if (right != NULL)
+ delete right;
+}
+
// This function computes the log-l2-negative-error of a given node from the
// formula R(t) = log(|t|^2 / (N^2 V_t)).
template<typename eT, typename cT>
@@ -45,7 +151,7 @@
const size_t points = end - start;
- double minError = std::log(-error);
+ double minError = logNegError;
bool splitFound = false;
// Loop through each dimension.
@@ -63,20 +169,13 @@
// Initializing all the stuff for this dimension.
bool dimSplitFound = false;
// Take an error estimate for this dimension.
- double minDimError = points / (max - min);
+ double minDimError = std::pow(points, 2.0) / (max - min);
double dimLeftError;
double dimRightError;
double dimSplitValue;
// Find the log volume of all the other dimensions.
- double volumeWithoutDim = 0;
- for (size_t i = 0; i < maxVals.n_elem; ++i)
- {
- if ((maxVals[i] - minVals[i] > 0.0) && (i != dim))
- {
- volumeWithoutDim += std::log(maxVals[i] - minVals[i]);
- }
- }
+ double volumeWithoutDim = logVolume - std::log(max - min);
// Get the values for the dimension.
arma::rowvec dimVec = data.row(dim).subvec(start, end - 1);
@@ -145,11 +244,6 @@
} // end if better split found in this dimension.
}
- // Map out of logspace.
- minError = -std::exp(minError);
- leftError = -std::exp(leftError);
- rightError = -std::exp(rightError);
-
return splitFound;
}
@@ -187,111 +281,6 @@
return left;
}
-
-template<typename eT, typename cT>
-DTree<eT, cT>::DTree() :
- start(0),
- end(0),
- left(NULL),
- right(NULL)
-{ /* Nothing to do. */ }
-
-
-// Root node initializers
-template<typename eT, typename cT>
-DTree<eT, cT>::DTree(const arma::vec& maxVals,
- const arma::vec& minVals,
- const size_t totalPoints) :
- start(0),
- end(totalPoints),
- maxVals(maxVals),
- minVals(minVals),
- error(-std::exp(LogNegativeError(totalPoints))),
- root(true),
- bucketTag(-1),
- left(NULL),
- right(NULL)
-{ /* Nothing to do. */ }
-
-template<typename eT, typename cT>
-DTree<eT, cT>::DTree(arma::mat& data) :
- start(0),
- end(data.n_cols),
- left(NULL),
- right(NULL)
-{
- maxVals.set_size(data.n_rows);
- minVals.set_size(data.n_rows);
-
- // Initialize to first column; values will be overwritten if necessary.
- maxVals = data.col(0);
- minVals = data.col(0);
-
- // Loop over data to extract maximum and minimum values in each dimension.
- for (size_t i = 1; i < data.n_cols; ++i)
- {
- for (size_t j = 0; j < data.n_rows; ++j)
- {
- if (data(j, i) > maxVals[j])
- maxVals[j] = data(j, i);
- if (data(j, i) < minVals[j])
- minVals[j] = data(j, i);
- }
- }
-
- error = -std::exp(LogNegativeError(data.n_cols));
-
- bucketTag = -1;
- root = true;
-}
-
-
-// Non-root node initializers
-template<typename eT, typename cT>
-DTree<eT, cT>::DTree(const arma::vec& maxVals,
- const arma::vec& minVals,
- const size_t start,
- const size_t end,
- const double error) :
- start(start),
- end(end),
- maxVals(maxVals),
- minVals(minVals),
- error(error),
- root(false),
- bucketTag(-1),
- left(NULL),
- right(NULL)
-{ /* Nothing to do. */ }
-
-template<typename eT, typename cT>
-DTree<eT, cT>::DTree(const arma::vec& maxVals,
- const arma::vec& minVals,
- const size_t totalPoints,
- const size_t start,
- const size_t end) :
- start(start),
- end(end),
- maxVals(maxVals),
- minVals(minVals),
- error(-std::exp(LogNegativeError(totalPoints))),
- root(false),
- bucketTag(-1),
- left(NULL),
- right(NULL)
-{ /* Nothing to do. */ }
-
-template<typename eT, typename cT>
-DTree<eT, cT>::~DTree()
-{
- if (left != NULL)
- delete left;
-
- if (right != NULL)
- delete right;
-}
-
-
// Greedily expand the tree
template<typename eT, typename cT>
double DTree<eT, cT>::Grow(arma::mat& data,
@@ -308,17 +297,12 @@
// Compute points ratio.
ratio = (double) (end - start) / (double) oldFromNew.n_elem;
- // Compute the v_t_inv: the inverse of the volume of the node. We use log to
- // prevent overflow.
- double logVol = 0;
+ // Compute the log of the volume of the node.
+ logVolume = 0;
for (size_t i = 0; i < maxVals.n_elem; ++i)
if (maxVals[i] - minVals[i] > 0.0)
- logVol += std::log(maxVals[i] - minVals[i]);
+ logVolume += std::log(maxVals[i] - minVals[i]);
- // Check for overflow.
- assert(std::exp(logVol) > 0.0);
- vTInv = 1.0 / std::exp(logVol);
-
// Check if node is large enough to split.
if ((size_t) (end - start) > maxLeafSize) {
@@ -357,19 +341,24 @@
// Store values of R(T~) and |T~|.
subtreeLeaves = left->SubtreeLeaves() + right->SubtreeLeaves();
- subtreeLeavesError = left->SubtreeLeavesError() +
- right->SubtreeLeavesError();
- // Store subtreeLeavesVTInv.
- subtreeLeavesVTInv = left->SubtreeLeavesVTInv() +
- right->SubtreeLeavesVTInv();
+ // Find the log negative error of the subtree leaves. This is kind of an
+ // odd one because we don't want to represent the error in non-log-space,
+ // but we have to calculate log(E_l + E_r). So we multiply E_l and E_r by
+ // V_t (remember E_l has an inverse relationship to the volume of the
+ // nodes) and then subtract log(V_t) at the end of the whole expression.
+ // As a result we do leave log-space, but the largest quantity we
+ // represent is on the order of (V_t / V_i) where V_i is the smallest leaf
+ // node below this node, which depends heavily on the depth of the tree.
+ subtreeLeavesLogNegError = std::log(std::exp(logVolume +
+ left->SubtreeLeavesLogNegError()) + std::exp(logVolume +
+ right->SubtreeLeavesLogNegError())) - logVolume;
}
else
{
// No split found so make a leaf out of it.
subtreeLeaves = 1;
- subtreeLeavesError = error;
- subtreeLeavesVTInv = vTInv;
+ subtreeLeavesLogNegError = logNegError;
}
}
else
@@ -377,8 +366,7 @@
// We can make this a leaf node.
assert((size_t) (end - start) >= minLeafSize);
subtreeLeaves = 1;
- subtreeLeavesError = error;
- subtreeLeavesVTInv = vTInv;
+ subtreeLeavesLogNegError = logNegError;
}
// If this is a leaf, do not compute g_k(t); otherwise compute, store, and
@@ -390,13 +378,47 @@
}
else
{
+ const double range = maxVals[splitDim] - minVals[splitDim];
+ const double leftRatio = (splitValue - minVals[splitDim]) / range;
+ const double rightRatio = (maxVals[splitDim] - splitValue) / range;
+
+ const size_t leftPow = std::pow(left->End() - left->Start(), 2);
+ const size_t rightPow = std::pow(right->End() - right->Start(), 2);
+ const size_t thisPow = std::pow(end - start, 2);
+
+ double tmpAlphaSum = leftPow / leftRatio + rightPow / rightRatio - thisPow;
+
+ if (left->SubtreeLeaves() > 1)
+ {
+ const double exponent = 2 * std::log(data.n_cols) + logVolume +
+ left->AlphaUpper();
+
+ // Whether or not this will overflow is highly dependent on the depth of
+ // the tree.
+ tmpAlphaSum += std::exp(exponent);
+ }
+
+ if (right->SubtreeLeaves() > 1)
+ {
+ const double exponent = 2 * std::log(data.n_cols) + logVolume +
+ right->AlphaUpper();
+
+ tmpAlphaSum += std::exp(exponent);
+ }
+
+ alphaUpper = std::log(tmpAlphaSum) - 2 * std::log(data.n_cols) - logVolume;
+
double gT;
if (useVolReg)
- gT = (error - subtreeLeavesError) / (subtreeLeavesVTInv - vTInv);
+ {
+ // This is wrong for now!
+ gT = alphaUpper;// / (subtreeLeavesVTInv - vTInv);
+ }
else
- gT = (error - subtreeLeavesError) / (subtreeLeaves - 1);
+ {
+ gT = alphaUpper - std::log(subtreeLeaves - 1);
+ }
- assert(gT > 0.0);
return min(gT, min(leftG, rightG));
}
@@ -408,23 +430,25 @@
template<typename eT, typename cT>
double DTree<eT, cT>::PruneAndUpdate(const double oldAlpha,
+ const size_t points,
const bool useVolReg)
{
// Compute gT.
if (subtreeLeaves == 1) // If we are a leaf...
{
- return std::numeric_limits<double>::max();
+ return 0;
}
else
{
// Compute gT value for node t.
double gT;
if (useVolReg)
- gT = (error - subtreeLeavesError) / (subtreeLeavesVTInv - vTInv);
+ gT = alphaUpper;// - std::log(subtreeLeavesVTInv - vTInv);
else
- gT = (error - subtreeLeavesError) / (subtreeLeaves - 1);
+ gT = alphaUpper - std::log(subtreeLeaves - 1);
- if (gT > oldAlpha)
+ if (gT < oldAlpha)
{
// Go down the tree and update accordingly. Traverse the children.
- double leftG = left->PruneAndUpdate(oldAlpha, useVolReg);
+ double leftG = left->PruneAndUpdate(oldAlpha, points, useVolReg);
@@ -432,35 +456,72 @@
// Update values.
subtreeLeaves = left->SubtreeLeaves() + right->SubtreeLeaves();
- subtreeLeavesError = left->SubtreeLeavesError() +
- right->SubtreeLeavesError();
- subtreeLeavesVTInv = left->SubtreeLeavesVTInv() +
- right->SubtreeLeavesVTInv();
+ // Find the log negative error of the subtree leaves. This is kind of an
+ // odd one because we don't want to represent the error in non-log-space,
+ // but we have to calculate log(E_l + E_r). So we multiply E_l and E_r by
+ // V_t (remember E_l has an inverse relationship to the volume of the
+ // nodes) and then subtract log(V_t) at the end of the whole expression.
+ // As a result we do leave log-space, but the largest quantity we
+ // represent is on the order of (V_t / V_i) where V_i is the smallest leaf
+ // node below this node, which depends heavily on the depth of the tree.
+ subtreeLeavesLogNegError = std::log(std::exp(logVolume +
+ left->SubtreeLeavesLogNegError()) + std::exp(logVolume +
+ right->SubtreeLeavesLogNegError())) - logVolume;
+
+ // Recalculate upper alpha.
+ const double range = maxVals[splitDim] - minVals[splitDim];
+ const double leftRatio = (splitValue - minVals[splitDim]) / range;
+ const double rightRatio = (maxVals[splitDim] - splitValue) / range;
+
+ const size_t leftPow = std::pow(left->End() - left->Start(), 2);
+ const size_t rightPow = std::pow(right->End() - right->Start(), 2);
+ const size_t thisPow = std::pow(end - start, 2);
+
+ double tmpAlphaSum = leftPow / leftRatio + rightPow / rightRatio -
+ thisPow;
+
+ if (left->SubtreeLeaves() > 1)
+ {
+ const double exponent = 2 * std::log(points) + logVolume +
+ left->AlphaUpper();
+
+ // Whether or not this will overflow is highly dependent on the depth of
+ // the tree.
+ tmpAlphaSum += std::exp(exponent);
+ }
+
+ if (right->SubtreeLeaves() > 1)
+ {
+ const double exponent = 2 * std::log(points) + logVolume +
+ right->AlphaUpper();
+
+ tmpAlphaSum += std::exp(exponent);
+ }
+
+ alphaUpper = std::log(tmpAlphaSum) - 2 * std::log(points) - logVolume;
+
// Update gT value.
if (useVolReg)
- gT = (error - subtreeLeavesError) / (subtreeLeavesVTInv - vTInv);
+ {
+ // This is incorrect.
+ gT = alphaUpper; // / (subtreeLeavesVTInv - vTInv);
+ }
else
- gT = (error - subtreeLeavesError) / (subtreeLeaves - 1);
+ {
+ gT = alphaUpper - std::log(subtreeLeaves - 1);
+ }
assert(gT < std::numeric_limits<double>::max());
- if (left->SubtreeLeaves() == 1 && right->SubtreeLeaves() == 1)
- return gT;
- else if (left->SubtreeLeaves() == 1)
- return min(gT, rightG);
- else if (right->SubtreeLeaves() == 1)
- return min(gT, leftG);
- else
- return min(gT, min(leftG, rightG));
+ return min(gT, min(leftG, rightG));
}
else
{
// Prune this subtree.
// First, make this node a leaf node.
subtreeLeaves = 1;
- subtreeLeavesError = error;
- subtreeLeavesVTInv = vTInv;
+ subtreeLeavesLogNegError = logNegError;
delete left;
left = NULL;
@@ -503,7 +564,7 @@
if (subtreeLeaves == 1) // If we are a leaf...
{
- return ratio * vTInv;
+ return std::exp(std::log(ratio) - logVolume);
}
else
{
@@ -517,11 +578,13 @@
return right->ComputeValue(query);
}
}
+
+ return 0.0;
}
template<typename eT, typename cT>
-void DTree<eT, cT>::WriteTree(size_t level, FILE *fp)
+void DTree<eT, cT>::WriteTree(FILE *fp, const size_t level) const
{
if (subtreeLeaves > 1)
{
@@ -530,18 +593,18 @@
fprintf(fp, "|\t");
fprintf(fp, "Var. %zu > %lg", splitDim, splitValue);
- right->WriteTree(level + 1, fp);
+ right->WriteTree(fp, level + 1);
fprintf(fp, "\n");
for (size_t i = 0; i < level; ++i)
fprintf(fp, "|\t");
fprintf(fp, "Var. %zu <= %lg ", splitDim, splitValue);
- left->WriteTree(level + 1, fp);
+ left->WriteTree(fp, level + 1);
}
else // If we are a leaf...
{
- fprintf(fp, ": f(x)=%lg", ratio * vTInv);
+ fprintf(fp, ": f(x)=%lg", std::exp(std::log(ratio) - logVolume));
if (bucketTag != -1)
fprintf(fp, " BT:%d", bucketTag);
}
@@ -550,10 +613,11 @@
// Index the buckets for possible usage later.
template<typename eT, typename cT>
-int DTree<eT, cT>::TagTree(int tag)
+int DTree<eT, cT>::TagTree(const int tag)
{
if (subtreeLeaves == 1)
{
+ // Only label leaves.
bucketTag = tag;
return (tag + 1);
}
@@ -561,7 +625,7 @@
{
return right->TagTree(left->TagTree(tag));
}
-} // TagTree
+}
template<typename eT, typename cT>
@@ -578,8 +642,9 @@
// If left subtree, go to left child.
return left->FindBucket(query);
}
- else // If right subtree, go to right child.
+ else
{
+ // If right subtree, go to right child.
return right->FindBucket(query);
}
}
@@ -603,8 +668,11 @@
if (curNode.subtreeLeaves == 1)
continue; // Do nothing for leaves.
- importances[curNode.SplitDim()] += (double) (curNode.Error() -
- (curNode.Left()->Error() + curNode.Right()->Error()));
+ // The way to do this entirely in log-space is (at this time) somewhat
+ // unclear. So this risks overflow.
+ importances[curNode.SplitDim()] += (-std::exp(curNode.LogNegError()) -
+ (-std::exp(curNode.Left()->LogNegError()) +
+ -std::exp(curNode.Right()->LogNegError())));
nodes.push(curNode.Left());
nodes.push(curNode.Right());
More information about the mlpack-svn
mailing list