[mlpack-svn] r13307 - mlpack/trunk/src/mlpack/methods/det

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Aug 1 16:19:44 EDT 2012


Author: rcurtin
Date: 2012-08-01 16:19:43 -0400 (Wed, 01 Aug 2012)
New Revision: 13307

Modified:
   mlpack/trunk/src/mlpack/methods/det/dt_main.cpp
   mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp
   mlpack/trunk/src/mlpack/methods/det/dtree.hpp
   mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
Log:
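Convert all of DTree (with the exception of ComputeVariableImportance()) to work
in the log domain: node errors and volumes are now stored as logarithms
(logNegError, logVolume) instead of raw errors and inverse volumes.  This was
fairly difficult...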


Modified: mlpack/trunk/src/mlpack/methods/det/dt_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dt_main.cpp	2012-08-01 02:18:15 UTC (rev 13306)
+++ mlpack/trunk/src/mlpack/methods/det/dt_main.cpp	2012-08-01 20:19:43 UTC (rev 13307)
@@ -158,13 +158,13 @@
 
       if (fp != NULL)
       {
-        dtreeOpt->WriteTree(0, fp);
+        dtreeOpt->WriteTree(fp);
         fclose(fp);
       }
     }
     else
     {
-      dtreeOpt->WriteTree(0, stdout);
+      dtreeOpt->WriteTree(stdout);
       printf("\n");
     }
   }

Modified: mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp	2012-08-01 02:18:15 UTC (rev 13306)
+++ mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp	2012-08-01 20:19:43 UTC (rev 13307)
@@ -24,7 +24,7 @@
                          string leaf_class_membership_file = "")
 {
   // Tag the leaves with numbers.
-  int num_leaves = dtree->TagTree(0);
+  int num_leaves = dtree->TagTree();
 
   arma::Mat<size_t> table(num_leaves, num_classes);
   table.zeros();
@@ -160,20 +160,20 @@
   while (dtree->SubtreeLeaves() > 1)
   {
     std::pair<double, double> tree_seq(old_alpha,
-        -1.0 * dtree->SubtreeLeavesError());
+        dtree->SubtreeLeavesLogNegError());
     pruned_sequence.push_back(tree_seq);
     old_alpha = alpha;
-    alpha = dtree->PruneAndUpdate(old_alpha, useVolumeReg);
+    alpha = dtree->PruneAndUpdate(old_alpha, dataset->n_cols, useVolumeReg);
 
     // Some sanity checks.
     assert((alpha < std::numeric_limits<double>::max()) ||
         (dtree->SubtreeLeaves() == 1));
     assert(alpha > old_alpha);
-    assert(dtree->SubtreeLeavesError() >= -1.0 * tree_seq.second);
+    assert(dtree->SubtreeLeavesLogNegError() <= tree_seq.second);
   }
 
   std::pair<double, double> tree_seq(old_alpha,
-      -1.0 * dtree->SubtreeLeavesError());
+      dtree->SubtreeLeavesLogNegError());
   pruned_sequence.push_back(tree_seq);
 
   Log::Info << pruned_sequence.size() << " trees in the sequence; max_alpha: "
@@ -242,12 +242,15 @@
         val_cv += dtree_cv->ComputeValue(test_point);
       }
 
-      // Update the cv error value.
-      it->second -= 2.0 * val_cv / (double) dataset->n_cols;
+      // Update the cv error by mapping out of log-space, applying the update in
+      // long double, and mapping back (log() is only finite if notLogVal < 0).
+      long double notLogVal = -std::exp((long double) it->second) -
+          2.0 * val_cv / (double) dataset->n_cols;
+      it->second = (double) std::log(-notLogVal);
 
       // Determine the new alpha value and prune accordingly.
       old_alpha = sqrt(((it + 1)->first) * ((it + 2)->first));
-      alpha = dtree_cv->PruneAndUpdate(old_alpha, useVolumeReg);
+      alpha = dtree_cv->PruneAndUpdate(old_alpha, train->n_cols, useVolumeReg);
     }
 
     // Compute test values for this state of the tree.
@@ -259,7 +262,9 @@
     }
 
     // Update the cv error value.
-    it->second -= 2.0 * val_cv / (double) dataset->n_cols;
+    long double notLogVal = -std::exp((long double) it->second) -
+        2.0 * val_cv / (double) dataset->n_cols;
+    it->second = (double) std::log(-notLogVal);
 
     test.reset();
     delete train;
@@ -300,15 +305,16 @@
       minLeafSize);
 
   // Prune with optimal alpha.
-  while ((old_alpha < optimal_alpha) && (dtree_opt->SubtreeLeaves() > 1))
+  while ((old_alpha > optimal_alpha) && (dtree_opt->SubtreeLeaves() > 1))
   {
     old_alpha = alpha;
-    alpha = dtree_opt->PruneAndUpdate(old_alpha, useVolumeReg);
+    alpha = dtree_opt->PruneAndUpdate(old_alpha, new_dataset->n_cols,
+        useVolumeReg);
 
     // Some sanity checks.
     assert((alpha < numeric_limits<double>::max()) ||
         (dtree_opt->SubtreeLeaves() == 1));
-    assert(alpha > old_alpha);
+    assert(alpha < old_alpha);
   }
 
   Log::Info << dtree_opt->SubtreeLeaves()

Modified: mlpack/trunk/src/mlpack/methods/det/dtree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree.hpp	2012-08-01 02:18:15 UTC (rev 13306)
+++ mlpack/trunk/src/mlpack/methods/det/dtree.hpp	2012-08-01 20:19:43 UTC (rev 13307)
@@ -58,91 +58,248 @@
          typename cT = long double>
 class DTree
 {
+ public:
+  /**
+   * Create an empty density estimation tree.
+   */
+  DTree();
+
+  /**
+   * Create a density estimation tree with the given bounds and the given number
+   * of total points.  Children will not be created.
+   *
+   * @param maxVals Maximum values of the bounding box.
+   * @param minVals Minimum values of the bounding box.
+   * @param totalPoints Total number of points in the dataset.
+   */
+  DTree(const arma::vec& maxVals,
+        const arma::vec& minVals,
+        const size_t totalPoints);
+
+  /**
+   * Create a density estimation tree root on the given data, computing its
+   * bounding box and error.  Children will not be created (see Grow(), which
+   * builds the tree and reorders the dataset similar to the way
+   * BinarySpaceTree modifies datasets).
+   *
+   * @param data Dataset to build tree on.
+   */
+  DTree(arma::mat& data);
+
+  /**
+   * Create a child node of a density estimation tree given the bounding box
+   * specified by maxVals and minVals, using the size given in start and end and
+   * the specified error.  Children of this node will not be created
+   * recursively.
+   *
+   * @param maxVals Upper bound of bounding box.
+   * @param minVals Lower bound of bounding box.
+   * @param start Start of points represented by this node in the data matrix.
+   * @param end Index one past the last point of this node in the data matrix.
+   * @param logNegError log-negative error of this node.
+   */
+  DTree(const arma::vec& maxVals,
+        const arma::vec& minVals,
+        const size_t start,
+        const size_t end,
+        const double logNegError);
+
+  /**
+   * Create a child node of a density estimation tree with the given bounding
+   * box and point range, computing its error from the total number of points.
+   * Children of this node will not be created recursively.
+   *
+   * @param maxVals Upper bound of bounding box.
+   * @param minVals Lower bound of bounding box.
+   * @param totalPoints Total number of points in the dataset.
+   * @param start Start of points represented by this node in the data matrix.
+   * @param end Index one past the last point of this node in the data matrix.
+   */
+  DTree(const arma::vec& maxVals,
+        const arma::vec& minVals,
+        const size_t totalPoints,
+        const size_t start,
+        const size_t end);
+
+  //! Clean up memory allocated by the tree.
+  ~DTree();
+
+  /**
+   * Greedily expand the tree.  The points in the dataset will be reordered
+   * during tree growth.
+   *
+   * @param data Dataset to build tree on.
+   * @param oldFromNew Mappings from old points to new points.
+   * @param useVolReg If true, volume regularization is used.
+   * @param maxLeafSize Maximum size of a leaf.
+   * @param minLeafSize Minimum size of a leaf.
+   */
+  double Grow(arma::mat& data,
+              arma::Col<size_t>& oldFromNew,
+              const bool useVolReg = false,
+              const size_t maxLeafSize = 10,
+              const size_t minLeafSize = 5);
+
+  /**
+   * Perform alpha pruning on a tree.  Returns the new value of alpha.
+   *
+   * @param oldAlpha Old value of alpha.
+   * @param points Total number of points in dataset.
+   * @param useVolReg If true, volume regularization is used.
+   * @return New value of alpha.
+   */
+  double PruneAndUpdate(const double oldAlpha,
+                        const size_t points,
+                        const bool useVolReg = false);
+
+  /**
+   * Compute the density estimate (not its logarithm) of a given query point.
+   *
+   * @param query Point to estimate density of.
+   */
+  double ComputeValue(const arma::vec& query) const;
+
+  /**
+   * Print the tree in a depth-first manner (this function is called
+   * recursively).
+   *
+   * @param fp File to write the tree to.
+   * @param level Level of the tree (should start at 0).
+   */
+  void WriteTree(FILE *fp, const size_t level = 0) const;
+
+  /**
+   * Index the buckets for possible usage later; this results in every leaf in
+   * the tree having a specific tag (retrievable via FindBucket()).  Returns
+   * the next unused tag, i.e. the number of leaves when starting from 0.
+   *
+   * @param tag Tag for the next leaf; leave at 0 for the initial call.
+   */
+  int TagTree(const int tag = 0);
+
+  /**
+   * Return the tag of the leaf containing the query.  This is useful for
+   * generating class memberships.
+   *
+   * @param query Query to search for.
+   */
+  int FindBucket(const arma::vec& query) const;
+
+  /**
+   * Compute the variable importance of each dimension in the learned tree.
+   *
+   * @param importances Vector to store the calculated importances in.
+   */
+  void ComputeVariableImportance(arma::vec& importances) const;
+
+  /**
+   * Compute the log-negative-error for this node, given the total number of
+   * points in the dataset.
+   *
+   * @param totalPoints Total number of points in the dataset.
+   */
+  inline double LogNegativeError(const size_t totalPoints) const;
+
  private:
   // The indices in the complete set of points
   // (after all forms of swapping in the original data
   // matrix to align all the points in a node
   // consecutively in the matrix. The 'old_from_new' array
   // maps the points back to their original indices.
+
+  //! The index of the first point in the dataset contained in this node (and
+  //! its children).
   size_t start;
+  //! One past the index of the last point in the dataset contained in this
+  //! node (and its children).
   size_t end;
 
-  // since we are using uniform density, we need
-  // the max and min of every dimension for every node
+  //! Maximum value of the bounding box in each dimension for this node.
   arma::vec maxVals;
+  //! Minimum value of the bounding box in each dimension for this node.
   arma::vec minVals;
 
-  // The split dim for this node
+  //! The splitting dimension for this node.
   size_t splitDim;
 
-  // The split val on that dim
+  //! The split value on the splitting dimension for this node.
   double splitValue;
 
-  // L2-error of the node
-  double error;
+  //! log-negative-L2-error of the node.
+  double logNegError;
 
-  // sum of the error of the leaves of the subtree
-  double subtreeLeavesError;
+  //! log-negative of the summed error of the leaves of the subtree.
+  double subtreeLeavesLogNegError;
 
-  // number of leaves of the subtree
+  //! Number of leaves of the subtree.
   size_t subtreeLeaves;
 
-  // flag to indicate if this is the root node
-  // used to check whether the query point is
-  // within the range
+  //! If true, this node is the root of the tree.
   bool root;
 
-  // ratio of number of points in the node to the
-  // total number of points (|t| / N)
+  //! Ratio of the number of points in the node to the total number of points.
   double ratio;
 
-  // the inverse of  volume of the node
-  double vTInv;
+  //! The logarithm of the volume of the node.
+  double logVolume;
 
-  // sum of the reciprocal of the inverse v_ts
-  // the leaves of this subtree
-  double subtreeLeavesVTInv;
-
-  // the tag for the leaf used for hashing points
+  //! The tag for the leaf, used for hashing points.
   int bucketTag;
 
-  // The children
+  //! Upper part of alpha sum; used for pruning.
+  double alphaUpper;
+
+  //! The left child.
   DTree<eT, cT> *left;
+  //! The right child.
   DTree<eT, cT> *right;
 
-public:
-
-  ////////////////////// Getters and Setters //////////////////////////////////
+ public:
+  //! Return the starting index of points contained in this node.
   size_t Start() const { return start; }
-
+  //! Return the first index of a point not contained in this node.
   size_t End() const { return end; }
-
+  //! Return the split dimension of this node.
   size_t SplitDim() const { return splitDim; }
-
+  //! Return the split value of this node.
   double SplitValue() const { return splitValue; }
-
-  double Error() const { return error; }
-
-  double SubtreeLeavesError() const { return subtreeLeavesError; }
-
+  //! Return the log negative error of this node.
+  double LogNegError() const { return logNegError; }
+  //! Return the log negative error of the leaves of this node's subtree.
+  double SubtreeLeavesLogNegError() const { return subtreeLeavesLogNegError; }
+  //! Return the number of leaves which are descendants of this node.
   size_t SubtreeLeaves() const { return subtreeLeaves; }
-
+  //! Return the ratio of points in this node to the points in the whole
+  //! dataset.
   double Ratio() const { return ratio; }
-
-  double VTInv() const { return vTInv; }
-
-  double SubtreeLeavesVTInv() const { return subtreeLeavesVTInv; }
-
+  //! Return the logarithm of the volume of this node.
+  double LogVolume() const { return logVolume; }
+  //! Return the left child.
   DTree<eT, cT>* Left() const { return left; }
+  //! Return the right child.
   DTree<eT, cT>* Right() const { return right; }
-
+  //! Return whether or not this is the root of the tree.
   bool Root() const { return root; }
+  //! Return the upper part of the alpha sum.
+  double AlphaUpper() const { return alphaUpper; }
 
-  ////////////////////// Private Functions ////////////////////////////////////
+  //! Return the maximum values.
+  const arma::vec& MaxVals() const { return maxVals; }
+  //! Modify the maximum values.
+  arma::vec& MaxVals() { return maxVals; }
+
+  //! Return the minimum values.
+  const arma::vec& MinVals() const { return minVals; }
+  //! Modify the minimum values.
+  arma::vec& MinVals() { return minVals; }
+
  private:
 
-  inline double LogNegativeError(const size_t total_points) const;
+  // Utility methods.
 
+  /**
+   * Find the dimension and value to split on; return true if a split is found.
+   */
   bool FindSplit(const arma::mat& data,
                  size_t& splitDim,
                  double& splitValue,
@@ -151,71 +308,18 @@
                  const size_t maxLeafSize = 10,
                  const size_t minLeafSize = 5) const;
 
+  /**
+   * Split the data, returning the number of points left of the split.
+   */
   size_t SplitData(arma::mat& data,
                    const size_t splitDim,
                    const double splitValue,
                    arma::Col<size_t>& oldFromNew) const;
 
+  /**
+   * Return whether a query point is within the range of this node.
+   */
   inline bool WithinRange(const arma::vec& query) const;
-
-  ///////////////////// Public Functions //////////////////////////////////////
- public:
-
-  DTree();
-
-  // Root node initializer
-  // with the bounding box of the data
-  // it contains instead of just the data.
-  DTree(const arma::vec& maxVals,
-        const arma::vec& minVals,
-        const size_t totalPoints);
-
-  // Root node initializer
-  // with the data, no bounding box.
-  DTree(arma::mat& data);
-
-  // Non-root node initializers
-  DTree(const arma::vec& maxVals,
-        const arma::vec& minVals,
-        const size_t start,
-        const size_t end,
-        const double error);
-
-  DTree(const arma::vec& maxVals,
-        const arma::vec& minVals,
-        const size_t totalPoints,
-        const size_t start,
-        const size_t end);
-
-  ~DTree();
-
-  // Greedily expand the tree
-  double Grow(arma::mat& data,
-              arma::Col<size_t>& oldFromNew,
-              const bool useVolReg = false,
-              const size_t maxLeafSize = 10,
-              const size_t minLeafSize = 5);
-
-  // perform alpha pruning on the tree
-  double PruneAndUpdate(const double old_alpha, const bool useVolReg = false);
-
-  // compute the density at a given point
-  double ComputeValue(const arma::vec& query) const;
-
-  // print the tree (in a DFS manner)
-  void WriteTree(size_t level, FILE *fp);
-
-  // indexing the buckets for possible usage later
-  int TagTree(int tag);
-
-  // This is used to generate the class membership
-  // of a learned tree.
-  int FindBucket(const arma::vec& query) const;
-
-  // This computes the variable importance list
-  // for the learned tree.
-  void ComputeVariableImportance(arma::vec& importances) const;
-
 }; // Class DTree
 
 }; // namespace det

Modified: mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp	2012-08-01 02:18:15 UTC (rev 13306)
+++ mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp	2012-08-01 20:19:43 UTC (rev 13307)
@@ -15,6 +15,112 @@
 namespace mlpack {
 namespace det {
 
+template<typename eT, typename cT>
+DTree<eT, cT>::DTree() :
+    start(0),
+    end(0),
+    logNegError(-DBL_MAX),
+    root(true),
+    bucketTag(-1),
+    left(NULL),
+    right(NULL)
+{ /* Nothing to do. */ }
+
+
+// Root node initializers
+template<typename eT, typename cT>
+DTree<eT, cT>::DTree(const arma::vec& maxVals,
+                     const arma::vec& minVals,
+                     const size_t totalPoints) :
+    start(0),
+    end(totalPoints),
+    maxVals(maxVals),
+    minVals(minVals),
+    logNegError(LogNegativeError(totalPoints)),
+    root(true),
+    bucketTag(-1),
+    left(NULL),
+    right(NULL)
+{ /* Nothing to do. */ }
+
+template<typename eT, typename cT>
+DTree<eT, cT>::DTree(arma::mat& data) :
+    start(0),
+    end(data.n_cols),
+    left(NULL),
+    right(NULL)
+{
+  maxVals.set_size(data.n_rows);
+  minVals.set_size(data.n_rows);
+
+  // Initialize to first column; values will be overwritten if necessary.
+  maxVals = data.col(0);
+  minVals = data.col(0);
+
+  // Loop over data to extract maximum and minimum values in each dimension.
+  for (size_t i = 1; i < data.n_cols; ++i)
+  {
+    for (size_t j = 0; j < data.n_rows; ++j)
+    {
+      if (data(j, i) > maxVals[j])
+        maxVals[j] = data(j, i);
+      if (data(j, i) < minVals[j])
+        minVals[j] = data(j, i);
+    }
+  }
+
+  logNegError = LogNegativeError(data.n_cols);
+
+  bucketTag = -1;
+  root = true;
+}
+
+
+// Non-root node initializers
+template<typename eT, typename cT>
+DTree<eT, cT>::DTree(const arma::vec& maxVals,
+                     const arma::vec& minVals,
+                     const size_t start,
+                     const size_t end,
+                     const double logNegError) :
+    start(start),
+    end(end),
+    maxVals(maxVals),
+    minVals(minVals),
+    logNegError(logNegError),
+    root(false),
+    bucketTag(-1),
+    left(NULL),
+    right(NULL)
+{ /* Nothing to do. */ }
+
+template<typename eT, typename cT>
+DTree<eT, cT>::DTree(const arma::vec& maxVals,
+                     const arma::vec& minVals,
+                     const size_t totalPoints,
+                     const size_t start,
+                     const size_t end) :
+    start(start),
+    end(end),
+    maxVals(maxVals),
+    minVals(minVals),
+    logNegError(LogNegativeError(totalPoints)),
+    root(false),
+    bucketTag(-1),
+    left(NULL),
+    right(NULL)
+{ /* Nothing to do. */ }
+
+template<typename eT, typename cT>
+DTree<eT, cT>::~DTree()
+{
+  if (left != NULL)
+    delete left;
+
+  if (right != NULL)
+    delete right;
+}
+
 // This function computes the log-l2-negative-error of a given node from the
 // formula R(t) = log(|t|^2 / (N^2 V_t)).
 template<typename eT, typename cT>
@@ -45,7 +151,7 @@
 
   const size_t points = end - start;
 
-  double minError = std::log(-error);
+  double minError = logNegError;
   bool splitFound = false;
 
   // Loop through each dimension.
@@ -63,20 +169,13 @@
     // Initializing all the stuff for this dimension.
     bool dimSplitFound = false;
     // Take an error estimate for this dimension.
-    double minDimError = points / (max - min);
+    double minDimError = std::pow(points, 2.0) / (max - min);
     double dimLeftError;
     double dimRightError;
     double dimSplitValue;
 
     // Find the log volume of all the other dimensions.
-    double volumeWithoutDim = 0;
-    for (size_t i = 0; i < maxVals.n_elem; ++i)
-    {
-      if ((maxVals[i] - minVals[i] > 0.0) && (i != dim))
-      {
-        volumeWithoutDim += std::log(maxVals[i] - minVals[i]);
-      }
-    }
+    double volumeWithoutDim = logVolume - std::log(max - min);
 
     // Get the values for the dimension.
     arma::rowvec dimVec = data.row(dim).subvec(start, end - 1);
@@ -145,11 +244,6 @@
     } // end if better split found in this dimension.
   }
 
-  // Map out of logspace.
-  minError = -std::exp(minError);
-  leftError = -std::exp(leftError);
-  rightError = -std::exp(rightError);
-
   return splitFound;
 }
 
@@ -187,111 +281,6 @@
   return left;
 }
 
-
-template<typename eT, typename cT>
-DTree<eT, cT>::DTree() :
-    start(0),
-    end(0),
-    left(NULL),
-    right(NULL)
-{ /* Nothing to do. */ }
-
-
-// Root node initializers
-template<typename eT, typename cT>
-DTree<eT, cT>::DTree(const arma::vec& maxVals,
-                     const arma::vec& minVals,
-                     const size_t totalPoints) :
-    start(0),
-    end(totalPoints),
-    maxVals(maxVals),
-    minVals(minVals),
-    error(-std::exp(LogNegativeError(totalPoints))),
-    root(true),
-    bucketTag(-1),
-    left(NULL),
-    right(NULL)
-{ /* Nothing to do. */ }
-
-template<typename eT, typename cT>
-DTree<eT, cT>::DTree(arma::mat& data) :
-    start(0),
-    end(data.n_cols),
-    left(NULL),
-    right(NULL)
-{
-  maxVals.set_size(data.n_rows);
-  minVals.set_size(data.n_rows);
-
-  // Initialize to first column; values will be overwritten if necessary.
-  maxVals = data.col(0);
-  minVals = data.col(0);
-
-  // Loop over data to extract maximum and minimum values in each dimension.
-  for (size_t i = 1; i < data.n_cols; ++i)
-  {
-    for (size_t j = 0; j < data.n_rows; ++j)
-    {
-      if (data(j, i) > maxVals[j])
-        maxVals[j] = data(j, i);
-      if (data(j, i) < minVals[j])
-        minVals[j] = data(j, i);
-    }
-  }
-
-  error = -std::exp(LogNegativeError(data.n_cols));
-
-  bucketTag = -1;
-  root = true;
-}
-
-
-// Non-root node initializers
-template<typename eT, typename cT>
-DTree<eT, cT>::DTree(const arma::vec& maxVals,
-                     const arma::vec& minVals,
-                     const size_t start,
-                     const size_t end,
-                     const double error) :
-    start(start),
-    end(end),
-    maxVals(maxVals),
-    minVals(minVals),
-    error(error),
-    root(false),
-    bucketTag(-1),
-    left(NULL),
-    right(NULL)
-{ /* Nothing to do. */ }
-
-template<typename eT, typename cT>
-DTree<eT, cT>::DTree(const arma::vec& maxVals,
-                     const arma::vec& minVals,
-                     const size_t totalPoints,
-                     const size_t start,
-                     const size_t end) :
-    start(start),
-    end(end),
-    maxVals(maxVals),
-    minVals(minVals),
-    error(-std::exp(LogNegativeError(totalPoints))),
-    root(false),
-    bucketTag(-1),
-    left(NULL),
-    right(NULL)
-{ /* Nothing to do. */ }
-
-template<typename eT, typename cT>
-DTree<eT, cT>::~DTree()
-{
-  if (left != NULL)
-    delete left;
-
-  if (right != NULL)
-    delete right;
-}
-
-
 // Greedily expand the tree
 template<typename eT, typename cT>
 double DTree<eT, cT>::Grow(arma::mat& data,
@@ -308,17 +297,12 @@
   // Compute points ratio.
   ratio = (double) (end - start) / (double) oldFromNew.n_elem;
 
-  // Compute the v_t_inv: the inverse of the volume of the node.  We use log to
-  // prevent overflow.
-  double logVol = 0;
+  // Compute the log of the volume of the node.
+  logVolume = 0;
   for (size_t i = 0; i < maxVals.n_elem; ++i)
     if (maxVals[i] - minVals[i] > 0.0)
-      logVol += std::log(maxVals[i] - minVals[i]);
+      logVolume += std::log(maxVals[i] - minVals[i]);
 
-  // Check for overflow.
-  assert(std::exp(logVol) > 0.0);
-  vTInv = 1.0 / std::exp(logVol);
-
   // Check if node is large enough to split.
   if ((size_t) (end - start) > maxLeafSize) {
 
@@ -357,19 +341,24 @@
 
       // Store values of R(T~) and |T~|.
       subtreeLeaves = left->SubtreeLeaves() + right->SubtreeLeaves();
-      subtreeLeavesError = left->SubtreeLeavesError() +
-          right->SubtreeLeavesError();
 
-      // Store subtreeLeavesVTInv.
-      subtreeLeavesVTInv = left->SubtreeLeavesVTInv() +
-          right->SubtreeLeavesVTInv();
+      // Find the log negative error of the subtree leaves.  This is kind of an
+      // odd one because we don't want to represent the error in non-log-space,
+      // but we have to calculate log(E_l + E_r).  So we multiply E_l and E_r by
+      // V_t (remember E_l has an inverse relationship to the volume of the
+      // nodes) and then subtract log(V_t) at the end of the whole expression.
+      // As a result we do leave log-space, but the largest quantity we
+      // represent is on the order of (V_t / V_i) where V_i is the smallest leaf
+      // node below this node, which depends heavily on the depth of the tree.
+      subtreeLeavesLogNegError = std::log(
+          std::exp(logVolume + left->SubtreeLeavesLogNegError()) +
+          std::exp(logVolume + right->SubtreeLeavesLogNegError())) - logVolume;
     }
     else
     {
       // No split found so make a leaf out of it.
       subtreeLeaves = 1;
-      subtreeLeavesError = error;
-      subtreeLeavesVTInv = vTInv;
+      subtreeLeavesLogNegError = logNegError;
     }
   }
   else
@@ -377,8 +366,7 @@
     // We can make this a leaf node.
     assert((size_t) (end - start) >= minLeafSize);
     subtreeLeaves = 1;
-    subtreeLeavesError = error;
-    subtreeLeavesVTInv = vTInv;
+    subtreeLeavesLogNegError = logNegError;
   }
 
   // If this is a leaf, do not compute g_k(t); otherwise compute, store, and
@@ -390,13 +378,47 @@
   }
   else
   {
+    const double range = maxVals[splitDim] - minVals[splitDim];
+    const double leftRatio = (splitValue - minVals[splitDim]) / range;
+    const double rightRatio = (maxVals[splitDim] - splitValue) / range;
+
+    const size_t leftPow = std::pow(left->End() - left->Start(), 2);
+    const size_t rightPow = std::pow(right->End() - right->Start(), 2);
+    const size_t thisPow = std::pow(end - start, 2);
+
+    double tmpAlphaSum = leftPow / leftRatio + rightPow / rightRatio - thisPow;
+
+    if (left->SubtreeLeaves() > 1)
+    {
+      const double exponent = 2 * std::log(data.n_cols) + logVolume +
+          left->AlphaUpper();
+
+      // Whether or not this will overflow is highly dependent on the depth of
+      // the tree.
+      tmpAlphaSum += std::exp(exponent);
+    }
+
+    if (right->SubtreeLeaves() > 1)
+    {
+      const double exponent = 2 * std::log(data.n_cols) + logVolume +
+          right->AlphaUpper();
+
+      tmpAlphaSum += std::exp(exponent);
+    }
+
+    alphaUpper = std::log(tmpAlphaSum) - 2 * std::log(data.n_cols) - logVolume;
+
     double gT;
     if (useVolReg)
-      gT = (error - subtreeLeavesError) / (subtreeLeavesVTInv - vTInv);
+    {
+      // FIXME: the volume-regularized gT is not yet computed in log-space.
+      gT = alphaUpper;// / (subtreeLeavesVTInv - vTInv);
+    }
     else
-      gT = (error - subtreeLeavesError) / (subtreeLeaves - 1);
+    {
+      gT = alphaUpper - std::log(subtreeLeaves - 1);
+    }
 
-    assert(gT > 0.0);
     return min(gT, min(leftG, rightG));
   }
 
@@ -408,23 +430,25 @@
 
 template<typename eT, typename cT>
 double DTree<eT, cT>::PruneAndUpdate(const double oldAlpha,
+                                     const size_t points,
                                      const bool useVolReg)
+
 {
   // Compute gT.
   if (subtreeLeaves == 1) // If we are a leaf...
   {
-    return std::numeric_limits<double>::max();
+    return 0;
   }
   else
   {
     // Compute gT value for node t.
     double gT;
     if (useVolReg)
-      gT = (error - subtreeLeavesError) / (subtreeLeavesVTInv - vTInv);
+      gT = alphaUpper;// - std::log(subtreeLeavesVTInv - vTInv);
     else
-      gT = (error - subtreeLeavesError) / (subtreeLeaves - 1);
+      gT = alphaUpper - std::log(subtreeLeaves - 1);
 
-    if (gT > oldAlpha)
+    if (gT < oldAlpha)
     {
-      // Go down the tree and update accordingly.  Traverse the children.
-      double leftG = left->PruneAndUpdate(oldAlpha, useVolReg);
+      // Traverse the children.  NOTE(review): rightG must also pass points.
+      double leftG = left->PruneAndUpdate(oldAlpha, points, useVolReg);
@@ -432,35 +456,72 @@
 
       // Update values.
       subtreeLeaves = left->SubtreeLeaves() + right->SubtreeLeaves();
-      subtreeLeavesError = left->SubtreeLeavesError() +
-          right->SubtreeLeavesError();
-      subtreeLeavesVTInv = left->SubtreeLeavesVTInv() +
-          right->SubtreeLeavesVTInv();
 
+      // Find the log negative error of the subtree leaves.  This is kind of an
+      // odd one because we don't want to represent the error in non-log-space,
+      // but we have to calculate log(E_l + E_r).  So we multiply E_l and E_r by
+      // V_t (remember E_l has an inverse relationship to the volume of the
+      // nodes) and then subtract log(V_t) at the end of the whole expression.
+      // As a result we do leave log-space, but the largest quantity we
+      // represent is on the order of (V_t / V_i) where V_i is the smallest leaf
+      // node below this node, which depends heavily on the depth of the tree.
+      subtreeLeavesLogNegError = std::log(
+          std::exp(logVolume + left->SubtreeLeavesLogNegError()) +
+          std::exp(logVolume + right->SubtreeLeavesLogNegError())) - logVolume;
+
+      // Recalculate upper alpha.
+      const double range = maxVals[splitDim] - minVals[splitDim];
+      const double leftRatio = (splitValue - minVals[splitDim]) / range;
+      const double rightRatio = (maxVals[splitDim] - splitValue) / range;
+
+      const size_t leftPow = std::pow(left->End() - left->Start(), 2);
+      const size_t rightPow = std::pow(right->End() - right->Start(), 2);
+      const size_t thisPow = std::pow(end - start, 2);
+
+      double tmpAlphaSum = leftPow / leftRatio + rightPow / rightRatio -
+          thisPow;
+
+      if (left->SubtreeLeaves() > 1)
+      {
+        const double exponent = 2 * std::log(points) + logVolume +
+            left->AlphaUpper();
+
+        // Whether or not this will overflow is highly dependent on the depth of
+        // the tree.
+        tmpAlphaSum += std::exp(exponent);
+      }
+
+      if (right->SubtreeLeaves() > 1)
+      {
+        const double exponent = 2 * std::log(points) + logVolume +
+            right->AlphaUpper();
+
+        tmpAlphaSum += std::exp(exponent);
+      }
+
+      alphaUpper = std::log(tmpAlphaSum) - 2 * std::log(points) - logVolume;
+
       // Update gT value.
       if (useVolReg)
-        gT = (error - subtreeLeavesError) / (subtreeLeavesVTInv - vTInv);
+      {
+        // FIXME: the volume-regularized gT is not yet computed in log-space.
+        gT = alphaUpper; // / (subtreeLeavesVTInv - vTInv);
+      }
       else
-        gT = (error - subtreeLeavesError) / (subtreeLeaves - 1);
+      {
+        gT = alphaUpper - std::log(subtreeLeaves - 1);
+      }
 
       assert(gT < std::numeric_limits<double>::max());
 
-      if (left->SubtreeLeaves() == 1 && right->SubtreeLeaves() == 1)
-        return gT;
-      else if (left->SubtreeLeaves() == 1)
-        return min(gT, rightG);
-      else if (right->SubtreeLeaves() == 1)
-        return min(gT, leftG);
-      else
-        return min(gT, min(leftG, rightG));
+      return min(gT, min(leftG, rightG));
     }
     else
     {
       // Prune this subtree.
       // First, make this node a leaf node.
       subtreeLeaves = 1;
-      subtreeLeavesError = error;
-      subtreeLeavesVTInv = vTInv;
+      subtreeLeavesLogNegError = logNegError;
 
       delete left;
       left = NULL;
@@ -503,7 +564,7 @@
 
   if (subtreeLeaves == 1)  // If we are a leaf...
   {
-    return ratio * vTInv;
+    return std::exp(std::log(ratio) - logVolume);
   }
   else
   {
@@ -517,11 +578,13 @@
       return right->ComputeValue(query);
     }
   }
+
+  return 0.0;
 }
 
 
 template<typename eT, typename cT>
-void DTree<eT, cT>::WriteTree(size_t level, FILE *fp)
+void DTree<eT, cT>::WriteTree(FILE *fp, const size_t level) const
 {
   if (subtreeLeaves > 1)
   {
@@ -530,18 +593,18 @@
       fprintf(fp, "|\t");
     fprintf(fp, "Var. %zu > %lg", splitDim, splitValue);
 
-    right->WriteTree(level + 1, fp);
+    right->WriteTree(fp, level + 1);
 
     fprintf(fp, "\n");
     for (size_t i = 0; i < level; ++i)
       fprintf(fp, "|\t");
     fprintf(fp, "Var. %zu <= %lg ", splitDim, splitValue);
 
-    left->WriteTree(level + 1, fp);
+    left->WriteTree(fp, level + 1);
   }
   else // If we are a leaf...
   {
-    fprintf(fp, ": f(x)=%lg", ratio * vTInv);
+    fprintf(fp, ": f(x)=%lg", std::exp(std::log(ratio) - logVolume));
     if (bucketTag != -1)
       fprintf(fp, " BT:%d", bucketTag);
   }
@@ -550,10 +613,11 @@
 
 // Index the buckets for possible usage later.
 template<typename eT, typename cT>
-int DTree<eT, cT>::TagTree(int tag)
+int DTree<eT, cT>::TagTree(const int tag)
 {
   if (subtreeLeaves == 1)
   {
+    // Only label leaves.
     bucketTag = tag;
     return (tag + 1);
   }
@@ -561,7 +625,7 @@
   {
     return right->TagTree(left->TagTree(tag));
   }
-} // TagTree
+}
 
 
 template<typename eT, typename cT>
@@ -578,8 +642,9 @@
     // If left subtree, go to left child.
     return left->FindBucket(query);
   }
-  else // If right subtree, go to right child.
+  else
   {
+    // If right subtree, go to right child.
     return right->FindBucket(query);
   }
 }
@@ -603,8 +668,11 @@
     if (curNode.subtreeLeaves == 1)
       continue; // Do nothing for leaves.
 
-    importances[curNode.SplitDim()] += (double) (curNode.Error() -
-        (curNode.Left()->Error() + curNode.Right()->Error()));
+    // The way to do this entirely in log-space is (at this time) somewhat
+    // unclear.  So this risks overflow.
+    importances[curNode.SplitDim()] += (-std::exp(curNode.LogNegError()) -
+        (-std::exp(curNode.Left()->LogNegError()) +
+         -std::exp(curNode.Right()->LogNegError())));
 
     nodes.push(curNode.Left());
     nodes.push(curNode.Right());




More information about the mlpack-svn mailing list