[mlpack-svn] r13267 - mlpack/trunk/src/mlpack/methods/det
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Jul 20 14:53:05 EDT 2012
Author: rcurtin
Date: 2012-07-20 14:53:05 -0400 (Fri, 20 Jul 2012)
New Revision: 13267
Modified:
mlpack/trunk/src/mlpack/methods/det/dt_main.cpp
mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp
mlpack/trunk/src/mlpack/methods/det/dtree.hpp
mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
Log:
Stop using eT and cT for good. Change all variable names to adhere to naming
guidelines.
Modified: mlpack/trunk/src/mlpack/methods/det/dt_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dt_main.cpp 2012-07-20 18:25:20 UTC (rev 13266)
+++ mlpack/trunk/src/mlpack/methods/det/dt_main.cpp 2012-07-20 18:53:05 UTC (rev 13267)
@@ -216,10 +216,9 @@
} // leaf class membership
- if(CLI::HasParam("I")) {
- PrintVariableImportance<double>
- (dtree_opt, training_data.n_rows,
- (string) CLI::GetParam<string>("i"));
+ if (CLI::HasParam("I"))
+ {
+ PrintVariableImportance<double>(dtree_opt, CLI::GetParam<string>("i"));
} // print variable importance
Modified: mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp 2012-07-20 18:25:20 UTC (rev 13266)
+++ mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp 2012-07-20 18:53:05 UTC (rev 13267)
@@ -67,8 +67,7 @@
template<typename eT>
void PrintVariableImportance(DTree<eT> *dtree,
- size_t num_dims,
- string vi_file = "")
+ const string vi_file = "")
{
arma::vec imps;
dtree->ComputeVariableImportance(imps);
@@ -131,7 +130,7 @@
delete new_dataset;
- Log::Info << dtree->subtree_leaves()
+ Log::Info << dtree->SubtreeLeaves()
<< " leaf nodes in the tree with full data; min_alpha: " << alpha << "."
<< std::endl;
@@ -157,24 +156,24 @@
}
// Sequentially prune and save the alpha values and the values of c_t^2 * r_t.
- std::vector<std::pair<double, long double> > pruned_sequence;
- while (dtree->subtree_leaves() > 1)
+ std::vector<std::pair<double, double> > pruned_sequence;
+ while (dtree->SubtreeLeaves() > 1)
{
- std::pair<double, long double> tree_seq(old_alpha,
- -1.0 * dtree->subtree_leaves_error());
+ std::pair<double, double> tree_seq(old_alpha,
+ -1.0 * dtree->SubtreeLeavesError());
pruned_sequence.push_back(tree_seq);
old_alpha = alpha;
alpha = dtree->PruneAndUpdate(old_alpha, useVolumeReg);
// Some sanity checks.
- assert((alpha < std::numeric_limits<long double>::max()) ||
- (dtree->subtree_leaves() == 1));
+ assert((alpha < std::numeric_limits<double>::max()) ||
+ (dtree->SubtreeLeaves() == 1));
assert(alpha > old_alpha);
- assert(dtree->subtree_leaves_error() >= -1.0 * tree_seq.second);
+ assert(dtree->SubtreeLeavesError() >= -1.0 * tree_seq.second);
}
- std::pair<long double, long double> tree_seq(old_alpha,
- -1.0 * dtree->subtree_leaves_error());
+ std::pair<double, double> tree_seq(old_alpha,
+ -1.0 * dtree->SubtreeLeavesError());
pruned_sequence.push_back(tree_seq);
Log::Info << pruned_sequence.size() << " trees in the sequence; max_alpha: "
@@ -232,11 +231,11 @@
// Sequentially prune with all the values of available alphas and adding
// values for test values.
- std::vector<std::pair<double, long double> >::iterator it;
+ std::vector<std::pair<double, double> >::iterator it;
for (it = pruned_sequence.begin(); it < pruned_sequence.end() -2; ++it)
{
// Compute test values for this state of the tree.
- long double val_cv = 0.0;
+ double val_cv = 0.0;
for (size_t i = 0; i < test.n_cols; i++)
{
arma::Col<eT> test_point = test.unsafe_col(i);
@@ -244,7 +243,7 @@
}
// Update the cv error value.
- it->second -= 2.0 * val_cv / (long double) dataset->n_cols;
+ it->second -= 2.0 * val_cv / (double) dataset->n_cols;
// Determine the new alpha value and prune accordingly.
old_alpha = sqrt(((it + 1)->first) * ((it + 2)->first));
@@ -252,7 +251,7 @@
}
// Compute test values for this state of the tree.
- long double val_cv = 0.0;
+ double val_cv = 0.0;
for (size_t i = 0; i < test.n_cols; ++i)
{
arma::Col<eT> test_point = test.unsafe_col(i);
@@ -260,7 +259,7 @@
}
// Update the cv error value.
- it->second -= 2.0 * val_cv / (long double) dataset->n_cols;
+ it->second -= 2.0 * val_cv / (double) dataset->n_cols;
test.reset();
delete train;
@@ -270,9 +269,9 @@
delete cvdata;
- long double optimal_alpha = -1.0;
- long double best_cv_error = numeric_limits<long double>::max();
- std::vector<std::pair<double, long double> >::iterator it;
+ double optimal_alpha = -1.0;
+ double best_cv_error = numeric_limits<double>::max();
+ std::vector<std::pair<double, double> >::iterator it;
for (it = pruned_sequence.begin(); it < pruned_sequence.end() -1; ++it)
{
@@ -301,18 +300,18 @@
minLeafSize);
// Prune with optimal alpha.
- while ((old_alpha < optimal_alpha) && (dtree_opt->subtree_leaves() > 1))
+ while ((old_alpha < optimal_alpha) && (dtree_opt->SubtreeLeaves() > 1))
{
old_alpha = alpha;
alpha = dtree_opt->PruneAndUpdate(old_alpha, useVolumeReg);
// Some sanity checks.
- assert((alpha < numeric_limits<long double>::max()) ||
- (dtree_opt->subtree_leaves() == 1));
+ assert((alpha < numeric_limits<double>::max()) ||
+ (dtree_opt->SubtreeLeaves() == 1));
assert(alpha > old_alpha);
}
- Log::Info << dtree_opt->subtree_leaves()
+ Log::Info << dtree_opt->SubtreeLeaves()
<< " leaf nodes in the optimally pruned tree; optimal alpha: "
<< old_alpha << "." << std::endl;
Modified: mlpack/trunk/src/mlpack/methods/det/dtree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree.hpp 2012-07-20 18:25:20 UTC (rev 13266)
+++ mlpack/trunk/src/mlpack/methods/det/dtree.hpp 2012-07-20 18:53:05 UTC (rev 13267)
@@ -59,17 +59,13 @@
class DTree
{
private:
-
- typedef arma::Mat<eT> MatType;
- typedef arma::Col<eT> VecType;
- typedef arma::Row<eT> RowVecType;
-
// The indices in the complete set of points
// (after all forms of swapping in the original data
// matrix to align all the points in a node
// consecutively in the matrix). The 'old_from_new' array
// maps the points back to their original indices.
- size_t start_, end_;
+ size_t start;
+ size_t end;
// since we are using uniform density, we need
// the max and min of every dimension for every node
@@ -77,70 +73,70 @@
arma::vec minVals;
// The split dim for this node
- size_t split_dim_;
+ size_t splitDim;
// The split val on that dim
- eT split_value_;
+ double splitValue;
// L2-error of the node
- cT error_;
+ double error;
// sum of the error of the leaves of the subtree
- cT subtree_leaves_error_;
+ double subtreeLeavesError;
// number of leaves of the subtree
- size_t subtree_leaves_;
+ size_t subtreeLeaves;
// flag to indicate if this is the root node
// used to check whether the query point is
// within the range
- bool root_;
+ bool root;
// ratio of number of points in the node to the
// total number of points (|t| / N)
- cT ratio_;
+ double ratio;
// the inverse of volume of the node
- cT v_t_inv_;
+ double vTInv;
// sum of the inverse volumes (1 / v_t) of
// the leaves of this subtree
- cT subtree_leaves_v_t_inv_;
+ double subtreeLeavesVTInv;
// the tag for the leaf used for hashing points
- int bucket_tag_;
+ int bucketTag;
// The children
- DTree<eT, cT> *left_;
- DTree<eT, cT> *right_;
+ DTree<eT, cT> *left;
+ DTree<eT, cT> *right;
public:
////////////////////// Getters and Setters //////////////////////////////////
- size_t start() const { return start_; }
+ size_t Start() const { return start; }
- size_t end() const { return end_; }
+ size_t End() const { return end; }
- size_t split_dim() const { return split_dim_; }
+ size_t SplitDim() const { return splitDim; }
- eT split_value() const { return split_value_; }
+ double SplitValue() const { return splitValue; }
- cT error() const { return error_; }
+ double Error() const { return error; }
- cT subtree_leaves_error() const { return subtree_leaves_error_; }
+ double SubtreeLeavesError() const { return subtreeLeavesError; }
- size_t subtree_leaves() const { return subtree_leaves_; }
+ size_t SubtreeLeaves() const { return subtreeLeaves; }
- cT ratio() const { return ratio_; }
+ double Ratio() const { return ratio; }
- cT v_t_inv() const { return v_t_inv_; }
+ double VTInv() const { return vTInv; }
- cT subtree_leaves_v_t_inv() const { return subtree_leaves_v_t_inv_; }
+ double SubtreeLeavesVTInv() const { return subtreeLeavesVTInv; }
- DTree<eT, cT>* left() const { return left_; }
- DTree<eT, cT>* right() const { return right_; }
+ DTree<eT, cT>* Left() const { return left; }
+ DTree<eT, cT>* Right() const { return right; }
- bool root() const { return root_; }
+ bool Root() const { return root; }
////////////////////// Private Functions ////////////////////////////////////
private:
Modified: mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp 2012-07-20 18:25:20 UTC (rev 13266)
+++ mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp 2012-07-20 18:53:05 UTC (rev 13267)
@@ -21,14 +21,14 @@
inline double DTree<eT, cT>::LogNegativeError(const size_t totalPoints) const
{
// log(-|t|^2 / (N^2 V_t)) = log(-1) + 2 log(|t|) - 2 log(N) - log(V_t).
- return 2 * std::log((double) (end_ - start_)) -
+ return 2 * std::log((double) (end - start)) -
2 * std::log((double) totalPoints) -
arma::accu(arma::log(maxVals - minVals));
}
// This function finds the best split with respect to the L2-error, by trying
-// all possible splits. The dataset is the full data set but the start_ and
-// end_ are used to obtain the point in this node.
+// all possible splits. The dataset is the full data set but the start and
+// end are used to obtain the point in this node.
template<typename eT, typename cT>
bool DTree<eT, cT>::FindSplit(const arma::mat& data,
size_t& splitDim,
@@ -43,9 +43,9 @@
assert(data.n_rows == maxVals.n_elem);
assert(data.n_rows == minVals.n_elem);
- const size_t points = end_ - start_;
+ const size_t points = end - start;
- double minError = std::log(-error_);
+ double minError = std::log(-error);
bool splitFound = false;
// Loop through each dimension.
@@ -79,7 +79,7 @@
}
// Get the values for the dimension.
- arma::rowvec dimVec = data.row(dim).subvec(start_, end_ - 1);
+ arma::rowvec dimVec = data.row(dim).subvec(start, end - 1);
// Sort the values in ascending order.
dimVec = arma::sort(dimVec);
@@ -88,7 +88,7 @@
assert(dimVec.n_elem > maxLeafSize);
// Find the best split for this dimension. We need to figure out why
- // there are spikes if this min_leaf_size is enforced here...
+ // there are spikes if this minLeafSize is enforced here...
for (size_t i = minLeafSize - 1; i < dimVec.n_elem - minLeafSize; ++i)
{
// This makes sense for real continuous data. This kinda corrupts the
@@ -99,7 +99,7 @@
continue; // We can't split here (two points are the same).
// Another way of picking split is using this:
- // split = left_split;
+ // split = leftsplit;
if ((split - min > 0.0) && (max - split > 0.0))
{
// Ensure that the right node will have at least the minimum number of
@@ -163,8 +163,8 @@
// less than or equal to splitValue are on the left side, and all others are
// on the right side. A similar sort to this is also performed in
// BinarySpaceTree construction (its comments are more detailed).
- size_t left = start_;
- size_t right = end_ - 1;
+ size_t left = start;
+ size_t right = end - 1;
for (;;)
{
while (data(splitDim, left) <= splitValue)
@@ -190,10 +190,10 @@
template<typename eT, typename cT>
DTree<eT, cT>::DTree() :
- start_(0),
- end_(0),
- left_(NULL),
- right_(NULL)
+ start(0),
+ end(0),
+ left(NULL),
+ right(NULL)
{ /* Nothing to do. */ }
@@ -202,23 +202,23 @@
DTree<eT, cT>::DTree(const arma::vec& maxVals,
const arma::vec& minVals,
const size_t totalPoints) :
- start_(0),
- end_(totalPoints),
+ start(0),
+ end(totalPoints),
maxVals(maxVals),
minVals(minVals),
- error_(-std::exp(LogNegativeError(totalPoints))),
- root_(true),
- bucket_tag_(-1),
- left_(NULL),
- right_(NULL)
+ error(-std::exp(LogNegativeError(totalPoints))),
+ root(true),
+ bucketTag(-1),
+ left(NULL),
+ right(NULL)
{ /* Nothing to do. */ }
template<typename eT, typename cT>
DTree<eT, cT>::DTree(arma::mat& data) :
- start_(0),
- end_(data.n_cols),
- left_(NULL),
- right_(NULL)
+ start(0),
+ end(data.n_cols),
+ left(NULL),
+ right(NULL)
{
maxVals.set_size(data.n_rows);
minVals.set_size(data.n_rows);
@@ -239,10 +239,10 @@
}
}
- error_ = -std::exp(LogNegativeError(data.n_cols));
+ error = -std::exp(LogNegativeError(data.n_cols));
- bucket_tag_ = -1;
- root_ = true;
+ bucketTag = -1;
+ root = true;
}
@@ -253,15 +253,15 @@
const size_t start,
const size_t end,
const double error) :
- start_(start),
- end_(end),
+ start(start),
+ end(end),
maxVals(maxVals),
minVals(minVals),
- error_(error),
- root_(false),
- bucket_tag_(-1),
- left_(NULL),
- right_(NULL)
+ error(error),
+ root(false),
+ bucketTag(-1),
+ left(NULL),
+ right(NULL)
{ /* Nothing to do. */ }
template<typename eT, typename cT>
@@ -270,25 +270,25 @@
const size_t totalPoints,
const size_t start,
const size_t end) :
- start_(start),
- end_(end),
+ start(start),
+ end(end),
maxVals(maxVals),
minVals(minVals),
- error_(-std::exp(LogNegativeError(totalPoints))),
- root_(false),
- bucket_tag_(-1),
- left_(NULL),
- right_(NULL)
+ error(-std::exp(LogNegativeError(totalPoints))),
+ root(false),
+ bucketTag(-1),
+ left(NULL),
+ right(NULL)
{ /* Nothing to do. */ }
template<typename eT, typename cT>
DTree<eT, cT>::~DTree()
{
- if (left_ != NULL)
- delete left_;
+ if (left != NULL)
+ delete left;
- if (right_ != NULL)
- delete right_;
+ if (right != NULL)
+ delete right;
}
@@ -306,7 +306,7 @@
double leftG, rightG;
// Compute points ratio.
- ratio_ = (double) (end_ - start_) / (double) oldFromNew.n_elem;
+ ratio = (double) (end - start) / (double) oldFromNew.n_elem;
// Compute the v_t_inv: the inverse of the volume of the node. We use log to
// prevent overflow.
@@ -317,93 +317,92 @@
// Check for overflow.
assert(std::exp(logVol) > 0.0);
- v_t_inv_ = 1.0 / std::exp(logVol);
+ vTInv = 1.0 / std::exp(logVol);
// Check if node is large enough to split.
- if ((size_t) (end_ - start_) > maxLeafSize) {
+ if ((size_t) (end - start) > maxLeafSize) {
// Find the split.
size_t dim;
- double splitValue;
+ double splitValueTmp;
double leftError, rightError;
- if (FindSplit(data, dim, splitValue, leftError, rightError, maxLeafSize,
+ if (FindSplit(data, dim, splitValueTmp, leftError, rightError, maxLeafSize,
minLeafSize))
{
// Move the data around for the children to have points in a node lie
// contiguously (to increase efficiency during the training).
- const size_t splitIndex = SplitData(data, dim, splitValue, oldFromNew);
+ const size_t splitIndex = SplitData(data, dim, splitValueTmp, oldFromNew);
// Make max and min vals for the children.
- arma::vec max_vals_l(maxVals);
- arma::vec max_vals_r(maxVals);
- arma::vec min_vals_l(minVals);
- arma::vec min_vals_r(minVals);
+ arma::vec maxValsL(maxVals);
+ arma::vec maxValsR(maxVals);
+ arma::vec minValsL(minVals);
+ arma::vec minValsR(minVals);
- max_vals_l[dim] = splitValue;
- min_vals_r[dim] = splitValue;
+ maxValsL[dim] = splitValueTmp;
+ minValsR[dim] = splitValueTmp;
// Store split dim and split val in the node.
- split_value_ = splitValue;
- split_dim_ = dim;
+ splitValue = splitValueTmp;
+ splitDim = dim;
// Recursively grow the children.
- left_ = new DTree(max_vals_l, min_vals_l, start_, splitIndex, leftError);
- right_ = new DTree(max_vals_r, min_vals_r, splitIndex, end_, rightError);
+ left = new DTree(maxValsL, minValsL, start, splitIndex, leftError);
+ right = new DTree(maxValsR, minValsR, splitIndex, end, rightError);
- leftG = left_->Grow(data, oldFromNew, useVolReg, maxLeafSize,
+ leftG = left->Grow(data, oldFromNew, useVolReg, maxLeafSize,
minLeafSize);
- rightG = right_->Grow(data, oldFromNew, useVolReg, maxLeafSize,
+ rightG = right->Grow(data, oldFromNew, useVolReg, maxLeafSize,
minLeafSize);
// Store values of R(T~) and |T~|.
- subtree_leaves_ = left_->subtree_leaves() + right_->subtree_leaves();
- subtree_leaves_error_ = left_->subtree_leaves_error() +
- right_->subtree_leaves_error();
+ subtreeLeaves = left->SubtreeLeaves() + right->SubtreeLeaves();
+ subtreeLeavesError = left->SubtreeLeavesError() +
+ right->SubtreeLeavesError();
- // Store the subtree_leaves_v_t_inv.
- subtree_leaves_v_t_inv_ = left_->subtree_leaves_v_t_inv() +
- right_->subtree_leaves_v_t_inv();
+ // Store subtreeLeavesVTInv.
+ subtreeLeavesVTInv = left->SubtreeLeavesVTInv() +
+ right->SubtreeLeavesVTInv();
}
else
{
// No split found so make a leaf out of it.
- subtree_leaves_ = 1;
- subtree_leaves_error_ = error_;
- subtree_leaves_v_t_inv_ = v_t_inv_;
+ subtreeLeaves = 1;
+ subtreeLeavesError = error;
+ subtreeLeavesVTInv = vTInv;
}
}
else
{
// We can make this a leaf node.
- assert((size_t) (end_ - start_) >= minLeafSize);
- subtree_leaves_ = 1;
- subtree_leaves_error_ = error_;
- subtree_leaves_v_t_inv_ = v_t_inv_;
+ assert((size_t) (end - start) >= minLeafSize);
+ subtreeLeaves = 1;
+ subtreeLeavesError = error;
+ subtreeLeavesVTInv = vTInv;
}
// If this is a leaf, do not compute g_k(t); otherwise compute, store, and
// propagate min(g_k(t_L), g_k(t_R), g_k(t)), unless t_L and/or t_R are
// leaves.
- if (subtree_leaves_ == 1)
+ if (subtreeLeaves == 1)
{
return std::numeric_limits<double>::max();
}
else
{
- double g_t;
+ double gT;
if (useVolReg)
- g_t = (error_ - subtree_leaves_error_) /
- (subtree_leaves_v_t_inv_ - v_t_inv_);
+ gT = (error - subtreeLeavesError) / (subtreeLeavesVTInv - vTInv);
else
- g_t = (error_ - subtree_leaves_error_) / (subtree_leaves_ - 1);
+ gT = (error - subtreeLeavesError) / (subtreeLeaves - 1);
- assert(g_t > 0.0);
- return min(g_t, min(leftG, rightG));
+ assert(gT > 0.0);
+ return min(gT, min(leftG, rightG));
}
// We need to compute (c_t^2) * r_t for all subtree leaves; this is equal to
- // n_t ^ 2 / r_t * n ^ 2 = -error_. Therefore the value we need is actually
- // -1.0 * subtree_leaves_error_.
+ // n_t ^ 2 / r_t * n ^ 2 = -error. Therefore the value we need is actually
+ // -1.0 * subtreeLeavesError.
}
@@ -411,64 +410,62 @@
double DTree<eT, cT>::PruneAndUpdate(const double oldAlpha,
const bool useVolReg)
{
- // Compute g_t.
- if (subtree_leaves_ == 1) // If we are a leaf...
+ // Compute gT.
+ if (subtreeLeaves == 1) // If we are a leaf...
{
return std::numeric_limits<double>::max();
}
else
{
- // Compute g_t value for node t.
- double g_t;
+ // Compute gT value for node t.
+ double gT;
if (useVolReg)
- g_t = (error_ - subtree_leaves_error_) /
- (subtree_leaves_v_t_inv_ - v_t_inv_);
+ gT = (error - subtreeLeavesError) / (subtreeLeavesVTInv - vTInv);
else
- g_t = (error_ - subtree_leaves_error_) / (subtree_leaves_ - 1);
+ gT = (error - subtreeLeavesError) / (subtreeLeaves - 1);
- if (g_t > oldAlpha)
+ if (gT > oldAlpha)
{
// Go down the tree and update accordingly. Traverse the children.
- double left_g = left_->PruneAndUpdate(oldAlpha, useVolReg);
- double right_g = right_->PruneAndUpdate(oldAlpha, useVolReg);
+ double leftG = left->PruneAndUpdate(oldAlpha, useVolReg);
+ double rightG = right->PruneAndUpdate(oldAlpha, useVolReg);
// Update values.
- subtree_leaves_ = left_->subtree_leaves() + right_->subtree_leaves();
- subtree_leaves_error_ = left_->subtree_leaves_error() +
- right_->subtree_leaves_error();
- subtree_leaves_v_t_inv_ = left_->subtree_leaves_v_t_inv() +
- right_->subtree_leaves_v_t_inv();
+ subtreeLeaves = left->SubtreeLeaves() + right->SubtreeLeaves();
+ subtreeLeavesError = left->SubtreeLeavesError() +
+ right->SubtreeLeavesError();
+ subtreeLeavesVTInv = left->SubtreeLeavesVTInv() +
+ right->SubtreeLeavesVTInv();
- // Update g_t value.
+ // Update gT value.
if (useVolReg)
- g_t = (error_ - subtree_leaves_error_)
- / (subtree_leaves_v_t_inv_ - v_t_inv_);
+ gT = (error - subtreeLeavesError) / (subtreeLeavesVTInv - vTInv);
else
- g_t = (error_ - subtree_leaves_error_) / (subtree_leaves_ - 1);
+ gT = (error - subtreeLeavesError) / (subtreeLeaves - 1);
- assert(g_t < std::numeric_limits<double>::max());
+ assert(gT < std::numeric_limits<double>::max());
- if (left_->subtree_leaves() == 1 && right_->subtree_leaves() == 1)
- return g_t;
- else if (left_->subtree_leaves() == 1)
- return min(g_t, right_g);
- else if (right_->subtree_leaves() == 1)
- return min(g_t, left_g);
+ if (left->SubtreeLeaves() == 1 && right->SubtreeLeaves() == 1)
+ return gT;
+ else if (left->SubtreeLeaves() == 1)
+ return min(gT, rightG);
+ else if (right->SubtreeLeaves() == 1)
+ return min(gT, leftG);
else
- return min(g_t, min(left_g, right_g));
+ return min(gT, min(leftG, rightG));
}
else
{
// Prune this subtree.
// First, make this node a leaf node.
- subtree_leaves_ = 1;
- subtree_leaves_error_ = error_;
- subtree_leaves_v_t_inv_ = v_t_inv_;
+ subtreeLeaves = 1;
+ subtreeLeavesError = error;
+ subtreeLeavesVTInv = vTInv;
- delete left_;
- left_ = NULL;
- delete right_;
- right_ = NULL;
+ delete left;
+ left = NULL;
+ delete right;
+ right = NULL;
// Pass information upward.
return std::numeric_limits<double>::max();
@@ -497,27 +494,27 @@
{
Log::Assert(query.n_elem == maxVals.n_elem);
- if (root_ == 1) // If we are the root...
+ if (root == 1) // If we are the root...
{
// Check if the query is within range.
if (!WithinRange(query))
return 0.0;
}
- if (subtree_leaves_ == 1) // If we are a leaf...
+ if (subtreeLeaves == 1) // If we are a leaf...
{
- return ratio_ * v_t_inv_;
+ return ratio * vTInv;
}
else
{
- if (query[split_dim_] <= split_value_)
+ if (query[splitDim] <= splitValue)
{
// If left subtree, go to left child.
- return left_->ComputeValue(query);
+ return left->ComputeValue(query);
}
else // If right subtree, go to right child
{
- return right_->ComputeValue(query);
+ return right->ComputeValue(query);
}
}
}
@@ -526,27 +523,27 @@
template<typename eT, typename cT>
void DTree<eT, cT>::WriteTree(size_t level, FILE *fp)
{
- if (subtree_leaves_ > 1)
+ if (subtreeLeaves > 1)
{
fprintf(fp, "\n");
for (size_t i = 0; i < level; ++i)
fprintf(fp, "|\t");
- fprintf(fp, "Var. %zu > %lg", split_dim_, split_value_);
+ fprintf(fp, "Var. %zu > %lg", splitDim, splitValue);
- right_->WriteTree(level + 1, fp);
+ right->WriteTree(level + 1, fp);
fprintf(fp, "\n");
for (size_t i = 0; i < level; ++i)
fprintf(fp, "|\t");
- fprintf(fp, "Var. %zu <= %lg ", split_dim_, split_value_);
+ fprintf(fp, "Var. %zu <= %lg ", splitDim, splitValue);
- left_->WriteTree(level + 1, fp);
+ left->WriteTree(level + 1, fp);
}
else // If we are a leaf...
{
- fprintf(fp, ": f(x)=%Lg", (cT) ratio_ * v_t_inv_);
- if (bucket_tag_ != -1)
- fprintf(fp, " BT:%d", bucket_tag_);
+ fprintf(fp, ": f(x)=%lg", ratio * vTInv);
+ if (bucketTag != -1)
+ fprintf(fp, " BT:%d", bucketTag);
}
}
@@ -555,14 +552,14 @@
template<typename eT, typename cT>
int DTree<eT, cT>::TagTree(int tag)
{
- if (subtree_leaves_ == 1)
+ if (subtreeLeaves == 1)
{
- bucket_tag_ = tag;
+ bucketTag = tag;
return (tag + 1);
}
else
{
- return right_->TagTree(left_->TagTree(tag));
+ return right->TagTree(left->TagTree(tag));
}
} // TagTree
@@ -572,18 +569,18 @@
{
Log::Assert(query.n_elem == maxVals.n_elem);
- if (subtree_leaves_ == 1) // If we are a leaf...
+ if (subtreeLeaves == 1) // If we are a leaf...
{
- return bucket_tag_;
+ return bucketTag;
}
- else if (query[split_dim_] <= split_value_)
+ else if (query[splitDim] <= splitValue)
{
// If left subtree, go to left child.
- return left_->FindBucket(query);
+ return left->FindBucket(query);
}
else // If right subtree, go to right child.
{
- return right_->FindBucket(query);
+ return right->FindBucket(query);
}
}
@@ -603,14 +600,14 @@
const DTree& curNode = *nodes.top();
nodes.pop();
- if (curNode.subtree_leaves_ == 1)
+ if (curNode.subtreeLeaves == 1)
continue; // Do nothing for leaves.
- importances[curNode.split_dim()] += (double) (curNode.error() -
- (curNode.left()->error() + curNode.right()->error()));
+ importances[curNode.SplitDim()] += (double) (curNode.Error() -
+ (curNode.Left()->Error() + curNode.Right()->Error()));
- nodes.push(curNode.left());
- nodes.push(curNode.right());
+ nodes.push(curNode.Left());
+ nodes.push(curNode.Right());
}
}
More information about the mlpack-svn
mailing list