[mlpack-svn] r12421 - mlpack/trunk/src/mlpack/methods/det
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Apr 16 18:12:11 EDT 2012
Author: rcurtin
Date: 2012-04-16 18:12:10 -0400 (Mon, 16 Apr 2012)
New Revision: 12421
Modified:
mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp
mlpack/trunk/src/mlpack/methods/det/dtree.hpp
mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
Log:
First pass: style. Fix a memory leak or two too.
Modified: mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp 2012-04-16 21:44:50 UTC (rev 12420)
+++ mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp 2012-04-16 22:12:10 UTC (rev 12421)
@@ -1,11 +1,10 @@
/**
* @file dt_utils.hpp
- * @ Parikshit Ram (pram at cc.gatech.edu)
+ * @author Parikshit Ram (pram at cc.gatech.edu)
*
- * This file implements functions to perform
- * different tasks with the Density Tree class.
+ * This file implements functions to perform different tasks with the Density
+ * Tree class.
*/
-
#ifndef __MLPACK_METHODS_DET_DT_UTILS_HPP
#define __MLPACK_METHODS_DET_DT_UTILS_HPP
@@ -14,319 +13,317 @@
#include <mlpack/core.hpp>
#include "dtree.hpp"
-using namespace std;
-
namespace mlpack {
namespace det {
- template<typename eT>
- void PrintLeafMembership(DTree<eT> *dtree,
- const arma::Mat<eT>& data,
- const arma::Mat<int>& labels,
- size_t num_classes,
- string leaf_class_membership_file = "")
- {
- // tag the leaves with numbers
- int num_leaves = dtree->TagTree(0);
-
- arma::Mat<size_t> table(num_leaves, num_classes);
- table.zeros();
+template<typename eT>
+void PrintLeafMembership(DTree<eT> *dtree,
+ const arma::Mat<eT>& data,
+ const arma::Mat<int>& labels,
+ size_t num_classes,
+ string leaf_class_membership_file = "")
+{
+ // Tag the leaves with numbers.
+ int num_leaves = dtree->TagTree(0);
- for (size_t i = 0; i < data.n_cols; i++) {
- arma::Col<eT> test_p = data.unsafe_col(i);
- int leaf_tag = dtree->FindBucket(&test_p);
- int label = labels[i];
- table(leaf_tag, label) += 1;
- } // end for
+ arma::Mat<size_t> table(num_leaves, num_classes);
+ table.zeros();
- if (leaf_class_membership_file == "") {
- Log::Warn << "Leaf Membership: Classes in each leaf" << endl
- << table << endl;
- } else {
- // create a stream for the file
- ofstream outfile(leaf_class_membership_file.c_str());
- if (outfile.good()) {
- Log::Warn << "Leaf Membership: Classes in each leaf"
- << " printed in '" << leaf_class_membership_file
- << "'" << endl;
- outfile << table;
- } else {
- Log::Warn << "Can't open '" << leaf_class_membership_file
- << "'" << endl;
- }
- outfile.close();
- }
+ for (size_t i = 0; i < data.n_cols; i++)
+ {
+ arma::Col<eT> test_p = data.unsafe_col(i);
+ int leaf_tag = dtree->FindBucket(&test_p);
+ int label = labels[i];
+ table(leaf_tag, label) += 1;
+ }
- return;
- // maybe print some more statistics if these work out well
- } // PrintLeafMembership
-
-
- template<typename eT>
- void PrintVariableImportance(DTree<eT> *dtree,
- size_t num_dims,
- string vi_file = "")
+ if (leaf_class_membership_file == "")
{
- arma::Col<double> *imps
- = new arma::Col<double>(num_dims);
+ Log::Warn << "Leaf Membership: Classes in each leaf" << std::endl
+ << table << std::endl;
+ }
+ else
+ {
+ // Create a stream for the file.
+ ofstream outfile(leaf_class_membership_file.c_str());
+ if (outfile.good())
+ {
+ outfile << table;
+ Log::Warn << "Leaf Membership: Classes in each leaf"
+ << " printed in '" << leaf_class_membership_file << "'."
+ << std::endl;
+ }
+ else
+ {
+ Log::Warn << "Can't open '" << leaf_class_membership_file << "'."
+ << std::endl;
+ }
+ outfile.close();
+ }
- for (size_t i = 0; i < imps->n_elem; i++)
- (*imps)[i] = 0.0;
-
- dtree->ComputeVariableImportance(imps);
- double max = 0.0;
- for (size_t i = 0; i < imps->n_elem; i++)
- if ((*imps)[i] > max)
- max = (*imps)[i];
- Log::Warn << "Max. variable importance: " << max << endl;
+ return;
+}
- if (vi_file == "") {
- Log::Warn << "Variable importance: " << endl
- << imps->t();
- } else {
- ofstream outfile(vi_file.c_str());
- if (outfile.good()) {
- Log::Warn << "Variable importance printed in '"
- << vi_file << "'" << endl;
- outfile << *imps;
- } else {
- Log::Warn << "Can't open '" << vi_file
- << "'" << endl;
- }
- outfile.close();
- }
+// Compute the variable importance of the given density tree with
+// DTree::ComputeVariableImportance() and report it. The maximum importance
+// is always written to Log::Warn; the full importance vector (num_dims
+// entries) is written to Log::Warn if vi_file is empty, or to the file named
+// by vi_file otherwise.
+template<typename eT>
+void PrintVariableImportance(DTree<eT> *dtree,
+ size_t num_dims,
+ std::string vi_file = "")
+{
+ arma::Col<double> *imps = new arma::Col<double>(num_dims);
+ imps->zeros();
- return;
- } // PrintVariableImportance
+ dtree->ComputeVariableImportance(imps);
+ double max = 0.0;
+ for (size_t i = 0; i < imps->n_elem; ++i)
+ if ((*imps)[i] > max)
+ max = (*imps)[i];
+ Log::Warn << "Max. variable importance: " << max << "." << std::endl;
- // This function trains the optimal decision tree
- // using the given number of folds
- template<typename eT>
- DTree<eT> *Trainer(arma::Mat<eT>* dataset,
- size_t folds,
- bool useVolumeReg = false,
- size_t maxLeafSize = 10,
- size_t minLeafSize = 5,
- string unprunedTreeOutput = "")
+ if (vi_file == "")
{
- // Initializing the tree
- DTree<eT> *dtree = new DTree<eT>(dataset);
-
- // Getting ready to grow the tree
- arma::Col<size_t> old_from_new(dataset->n_cols);
- for (size_t i = 0; i < old_from_new.n_elem; i++) {
- old_from_new[i] = i;
+ Log::Warn << "Variable importance: " << std::endl << imps->t() << std::endl;
+ }
+ else
+ {
+ std::ofstream outfile(vi_file.c_str());
+ if (outfile.good())
+ {
+ Log::Warn << "Variable importance printed in '" << vi_file << "'."
+ << std::endl;
+ outfile << *imps;
+ } else {
+ Log::Warn << "Can't open '" << vi_file
+ << "'" << std::endl;
}
+ outfile.close();
+ }
- // Saving the dataset since it would be modified
- // while growing the tree
- arma::Mat<eT>* new_dataset = new arma::Mat<eT>(*dataset);
+ // imps was allocated with scalar new, so it must be released with scalar
+ // delete; delete[] on it is undefined behavior.
+ delete imps;
- // Growing the tree
- long double old_alpha = 0.0;
- long double alpha = dtree->Grow(new_dataset, &old_from_new,
- useVolumeReg, maxLeafSize,
- minLeafSize);
- // clear the data set
- delete new_dataset;
+ return;
+} // PrintVariableImportance
- Log::Info << dtree->subtree_leaves()
- << " leaf nodes in the tree with full data, min_alpha: "
- << alpha << endl;
- // computing densities for the train points in the
- // full tree if asked for.
- if (unprunedTreeOutput != "") {
+// This function trains the optimal decision tree using the given number of
+// folds.
+template<typename eT>
+DTree<eT> *Trainer(arma::Mat<eT>* dataset,
+ size_t folds,
+ bool useVolumeReg = false,
+ size_t maxLeafSize = 10,
+ size_t minLeafSize = 5,
+ string unprunedTreeOutput = "")
+{
+ // Initialize the tree.
+ DTree<eT>* dtree = new DTree<eT>(dataset);
- ofstream outfile(unprunedTreeOutput.c_str());
- if (outfile.good()) {
- for (size_t i = 0; i < dataset->n_cols; i++) {
- arma::Col<eT> test_p = dataset->unsafe_col(i);
- outfile << dtree->ComputeValue(&test_p) << endl;
- } // end for
- } else {
- Log::Warn << "Can't open '" << unprunedTreeOutput
- << "'" << endl;
- }
+ // Getting ready to grow the tree...
+ arma::Col<size_t> old_from_new(dataset->n_cols);
+ for (size_t i = 0; i < old_from_new.n_elem; i++)
+ old_from_new[i] = i;
- outfile.close();
+ // Saving the dataset since it would be modified while growing the tree
+ arma::Mat<eT>* new_dataset = new arma::Mat<eT>(*dataset);
- } // if unprunedTreeOutput
+ // Growing the tree
+ long double old_alpha = 0.0;
+ long double alpha = dtree->Grow(new_dataset, &old_from_new, useVolumeReg,
+ maxLeafSize, minLeafSize);
- // sequential pruning and saving the alpha vals and the
- // values of c_t^2*r_t
- std::vector<std::pair<long double, long double> > pruned_sequence;
- while (dtree->subtree_leaves() > 1) {
+ delete new_dataset;
- std::pair<long double, long double> tree_seq
- (old_alpha, -1.0 * dtree->subtree_leaves_error());
- pruned_sequence.push_back(tree_seq);
- old_alpha = alpha;
- alpha = dtree->PruneAndUpdate(old_alpha, useVolumeReg);
+ Log::Info << dtree->subtree_leaves()
+ << " leaf nodes in the tree with full data; min_alpha: " << alpha << "."
+ << std::endl;
- // some checks
- assert((alpha < std::numeric_limits<long double>::max())
- ||(dtree->subtree_leaves() == 1));
- assert(alpha > old_alpha);
- assert(dtree->subtree_leaves_error() >= -1.0 * tree_seq.second);
+ // Compute densities for the training points in the full tree, if we were
+ // asked for this.
+ if (unprunedTreeOutput != "")
+ {
+ ofstream outfile(unprunedTreeOutput.c_str());
+ if (outfile.good())
+ {
+ for (size_t i = 0; i < dataset->n_cols; ++i)
+ {
+ arma::Col<eT> test_p = dataset->unsafe_col(i);
+ outfile << dtree->ComputeValue(&test_p) << endl;
+ }
+ }
+ else
+ {
+ Log::Warn << "Can't open '" << unprunedTreeOutput << "'." << std::endl;
+ }
- } // end while
-
- std::pair<long double, long double> tree_seq
- (old_alpha, -1.0 * dtree->subtree_leaves_error());
+ outfile.close();
+ }
+
+ // Sequentially prune and save the alpha values and the values of c_t^2 * r_t.
+ std::vector<std::pair<long double, long double> > pruned_sequence;
+ while (dtree->subtree_leaves() > 1)
+ {
+ std::pair<long double, long double> tree_seq(old_alpha,
+ -1.0 * dtree->subtree_leaves_error());
pruned_sequence.push_back(tree_seq);
+ old_alpha = alpha;
+ alpha = dtree->PruneAndUpdate(old_alpha, useVolumeReg);
- Log::Info << pruned_sequence.size()
- << " trees in the sequence, max_alpha: "
- << old_alpha << endl;
+ // Some sanity checks.
+ assert((alpha < std::numeric_limits<long double>::max()) ||
+ (dtree->subtree_leaves() == 1));
+ assert(alpha > old_alpha);
+ assert(dtree->subtree_leaves_error() >= -1.0 * tree_seq.second);
+ }
- delete dtree;
+ std::pair<long double, long double> tree_seq(old_alpha,
+ -1.0 * dtree->subtree_leaves_error());
+ pruned_sequence.push_back(tree_seq);
- arma::Mat<eT>* cvdata = new arma::Mat<eT>(*dataset);
+ Log::Info << pruned_sequence.size() << " trees in the sequence; max_alpha: "
+ << old_alpha << "." << std::endl;
- size_t test_size = dataset->n_cols / folds;
+ delete dtree;
- // Go through each fold
- for (size_t fold = 0; fold < folds; fold++) {
-
- // break up data into train and test set
- size_t start = fold * test_size,
- end = std::min((fold + 1) * test_size, (size_t) cvdata->n_cols);
- arma::Mat<eT> test = cvdata->cols(start, end - 1);
- arma::Mat<eT>* train
- = new arma::Mat<eT>(cvdata->n_rows,
- cvdata->n_cols - test.n_cols);
+ arma::Mat<eT>* cvdata = new arma::Mat<eT>(*dataset);
+ size_t test_size = dataset->n_cols / folds;
- if (start == 0 && end < cvdata->n_cols) {
- assert(train->n_cols == cvdata->n_cols - end);
- train->cols(0, train->n_cols - 1)
- = cvdata->cols(end, cvdata->n_cols - 1);
+ // Go through each fold.
+ for (size_t fold = 0; fold < folds; fold++)
+ {
+ // Break up data into train and test sets.
+ size_t start = fold * test_size;
+ size_t end = std::min((fold + 1) * test_size, (size_t) cvdata->n_cols);
+ arma::Mat<eT> test = cvdata->cols(start, end - 1);
+ arma::Mat<eT>* train = new arma::Mat<eT>(cvdata->n_rows,
+ cvdata->n_cols - test.n_cols);
- } else if (start > 0 && end == cvdata->n_cols) {
- assert(train->n_cols == start);
- train->cols(0, train->n_cols - 1) = cvdata->cols(0, start - 1);
+ if (start == 0 && end < cvdata->n_cols)
+ {
+ assert(train->n_cols == cvdata->n_cols - end);
+ train->cols(0, train->n_cols - 1) = cvdata->cols(end, cvdata->n_cols - 1);
+ }
+ else if (start > 0 && end == cvdata->n_cols)
+ {
+ assert(train->n_cols == start);
+ train->cols(0, train->n_cols - 1) = cvdata->cols(0, start - 1);
+ }
+ else
+ {
+ assert(train->n_cols == start + cvdata->n_cols - end);
- } else {
- assert(train->n_cols == start + cvdata->n_cols - end);
+ train->cols(0, start - 1) = cvdata->cols(0, start - 1);
+ train->cols(start, train->n_cols - 1) =
+ cvdata->cols(end, cvdata->n_cols - 1);
+ }
- train->cols(0, start - 1) = cvdata->cols(0, start - 1);
- train->cols(start, train->n_cols - 1)
- = cvdata->cols(end, cvdata->n_cols - 1);
- }
+ assert(train->n_cols + test.n_cols == cvdata->n_cols);
- assert(train->n_cols + test.n_cols == cvdata->n_cols);
+ // Initialize the tree.
+ DTree<eT>* dtree_cv = new DTree<eT>(train);
- // Initializing the tree
- DTree<eT> *dtree_cv = new DTree<eT>(train);
+ // Getting ready to grow the tree...
+ arma::Col<size_t> old_from_new_cv(train->n_cols);
+ for (size_t i = 0; i < old_from_new_cv.n_elem; i++)
+ old_from_new_cv[i] = i;
- // Getting ready to grow the tree
- arma::Col<size_t> old_from_new_cv(train->n_cols);
- for (size_t i = 0; i < old_from_new_cv.n_elem; i++) {
- old_from_new_cv[i] = i;
- }
+ // Grow the tree.
+ old_alpha = 0.0;
+ alpha = dtree_cv->Grow(train, &old_from_new_cv, useVolumeReg, maxLeafSize,
+ minLeafSize);
- // Growing the tree
- old_alpha = 0.0;
- alpha = dtree_cv->Grow(train, &old_from_new_cv,
- useVolumeReg, maxLeafSize,
- minLeafSize);
-
- // sequential pruning with all the values of available
- // alphas and adding values for test values
- std::vector<std::pair<long double, long double> >::iterator it;
- for (it = pruned_sequence.begin();
- it < pruned_sequence.end() -2; ++it) {
-
- // compute test values for this state of the tree
- long double val_cv = 0.0;
- for (size_t i = 0; i < test.n_cols; i++) {
- arma::Col<eT> test_point = test.unsafe_col(i);
- val_cv += dtree_cv->ComputeValue(&test_point);
- }
-
- // update the cv error value
- it->second -= 2.0 * val_cv / (long double) dataset->n_cols;
-
- // getting the new alpha value and pruning accordingly
- old_alpha = sqrt(((it+1)->first) * ((it+2)->first));
- alpha = dtree_cv->PruneAndUpdate(old_alpha, useVolumeReg);
- } // end for
-
- // compute test values for this state of the tree
+ // Sequentially prune with all the values of available alphas and adding
+ // values for test values.
+ std::vector<std::pair<long double, long double> >::iterator it;
+ for (it = pruned_sequence.begin(); it < pruned_sequence.end() -2; ++it)
+ {
+ // Compute test values for this state of the tree.
long double val_cv = 0.0;
- for (size_t i = 0; i < test.n_cols; i++) {
- arma::Col<eT> test_point = test.unsafe_col(i);
- val_cv += dtree_cv->ComputeValue(&test_point);
+ for (size_t i = 0; i < test.n_cols; i++)
+ {
+ arma::Col<eT> test_point = test.unsafe_col(i);
+ val_cv += dtree_cv->ComputeValue(&test_point);
}
- // update the cv error value
+
+ // Update the cv error value.
it->second -= 2.0 * val_cv / (long double) dataset->n_cols;
- test.reset();
- delete train;
+ // Determine the new alpha value and prune accordingly.
+ old_alpha = sqrt(((it + 1)->first) * ((it + 2)->first));
+ alpha = dtree_cv->PruneAndUpdate(old_alpha, useVolumeReg);
+ }
- delete dtree_cv;
+ // Compute test values for this state of the tree.
+ long double val_cv = 0.0;
+ for (size_t i = 0; i < test.n_cols; ++i)
+ {
+ arma::Col<eT> test_point = test.unsafe_col(i);
+ val_cv += dtree_cv->ComputeValue(&test_point);
+ }
- } // end for loop for number of cv-folds
+ // Update the cv error value.
+ it->second -= 2.0 * val_cv / (long double) dataset->n_cols;
- delete cvdata;
+ test.reset();
+ delete train;
- long double optimal_alpha = -1.0,
- best_cv_error = numeric_limits<long double>::max();
- std::vector<std::pair<long double, long double> >::iterator it;
+ delete dtree_cv;
+ }
- for (it = pruned_sequence.begin();
- it < pruned_sequence.end() -1; ++it) {
+ delete cvdata;
- if (it->second < best_cv_error) {
- best_cv_error = it->second;
- optimal_alpha = it->first;
- } // end if
- } // end for
+ long double optimal_alpha = -1.0;
+ long double best_cv_error = numeric_limits<long double>::max();
+ std::vector<std::pair<long double, long double> >::iterator it;
- Log::Info << "Optimal alpha: " << optimal_alpha << endl;
-
- // Initializing the tree
- DTree<eT> *dtree_opt = new DTree<eT>(dataset);
- // Getting ready to grow the tree
- for (size_t i = 0; i < old_from_new.n_elem; i++) {
- old_from_new[i] = i;
+ for (it = pruned_sequence.begin(); it < pruned_sequence.end() -1; ++it)
+ {
+ if (it->second < best_cv_error)
+ {
+ best_cv_error = it->second;
+ optimal_alpha = it->first;
}
+ }
- // Saving the dataset since it would be modified
- // while growing the tree
- new_dataset = new arma::Mat<eT>(*dataset);
+ Log::Info << "Optimal alpha: " << optimal_alpha << "." << std::endl;
- // Growing the tree
- old_alpha = 0.0;
- alpha = dtree_opt->Grow(new_dataset, &old_from_new,
- useVolumeReg, maxLeafSize,
- minLeafSize);
+ // Initialize the tree.
+ DTree<eT>* dtree_opt = new DTree<eT>(dataset);
- // Pruning with optimal alpha
- while (old_alpha < optimal_alpha
- && dtree_opt->subtree_leaves() > 1) {
- old_alpha = alpha;
- alpha = dtree_opt->PruneAndUpdate(old_alpha, useVolumeReg);
+ // Getting ready to grow the tree...
+ for (size_t i = 0; i < old_from_new.n_elem; i++)
+ old_from_new[i] = i;
- // some checks
- assert((alpha < numeric_limits<long double>::max())
- ||(dtree_opt->subtree_leaves() == 1));
- assert(alpha > old_alpha);
- } // end while
+ // Save the dataset since it would be modified while growing the tree.
+ new_dataset = new arma::Mat<eT>(*dataset);
- Log::Info << dtree_opt->subtree_leaves()
- << " leaf nodes in the optimally pruned tree,"
- << " optimal alpha: "
- << old_alpha << endl;
+ // Grow the tree.
+ old_alpha = 0.0;
+ alpha = dtree_opt->Grow(new_dataset, &old_from_new, useVolumeReg, maxLeafSize,
+ minLeafSize);
- delete new_dataset;
+ // Prune with optimal alpha.
+ while ((old_alpha < optimal_alpha) && (dtree_opt->subtree_leaves() > 1))
+ {
+ old_alpha = alpha;
+ alpha = dtree_opt->PruneAndUpdate(old_alpha, useVolumeReg);
- return dtree_opt;
- } // Trainer
+ // Some sanity checks.
+ assert((alpha < numeric_limits<long double>::max()) ||
+ (dtree_opt->subtree_leaves() == 1));
+ assert(alpha > old_alpha);
+ }
+ Log::Info << dtree_opt->subtree_leaves()
+ << " leaf nodes in the optimally pruned tree; optimal alpha: "
+ << old_alpha << "." << std::endl;
+
+ delete new_dataset;
+
+ return dtree_opt;
+}
+
}; // namespace det
}; // namespace mlpack
Modified: mlpack/trunk/src/mlpack/methods/det/dtree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree.hpp 2012-04-16 21:44:50 UTC (rev 12420)
+++ mlpack/trunk/src/mlpack/methods/det/dtree.hpp 2012-04-16 22:12:10 UTC (rev 12421)
@@ -19,37 +19,36 @@
namespace mlpack {
namespace det /** Density Estimation Trees */ {
-// This two types in the template are used
-// for two purposes:
-// eT - the type to store the data in (for most practical
-// purposes, storing the data as a float suffices).
-// cT - the type to perform computations in (computations
-// like computing the error, the volume of the node etc.).
-// For high dimensional data, it might be possible that the
-// computation might overflow, so you should use either
-// normalize your data in the (-1, 1) hypercube or use
-// long double or modify this code to perform computations
-// using logarithms.
+/**
+ * These two types in the template are used for two purposes:
+ *
+ * eT - the type to store the data in (for most practical purposes, storing
+ * the data as a float suffices).
+ * cT - the type to perform computations in (computations like computing the
+ * error, the volume of the node etc.).
+ *
+ * For high dimensional data, it is possible that the computation might
+ * overflow, so you should either normalize your data in the (-1, 1)
+ * hypercube, use long double, or modify this code to perform computations
+ * using logarithms.
+ */
template<typename eT = float,
- typename cT = long double>
-class DTree{
-
- ////////////////////// Member Variables /////////////////////////////////////
-
+ typename cT = long double>
+class DTree
+{
private:
typedef arma::Mat<eT> MatType;
typedef arma::Col<eT> VecType;
typedef arma::Row<eT> RowVecType;
-
// The indices in the complete set of points
// (after all forms of swapping in the original data
- // matrix to align all the points in a node
- // consecutively in the matrix. The 'old_from_new' array
+ // matrix to align all the points in a node
+ // consecutively in the matrix. The 'old_from_new' array
// maps the points back to their original indices.
size_t start_, end_;
-
+
// The split dim for this node
size_t split_dim_;
@@ -66,11 +65,11 @@
size_t subtree_leaves_;
// flag to indicate if this is the root node
- // used to check whether the query point is
+ // used to check whether the query point is
// within the range
bool root_;
- // ratio of number of points in the node to the
+ // ratio of number of points in the node to the
// total number of points (|t| / N)
cT ratio_;
@@ -93,10 +92,8 @@
DTree<eT, cT> *left_;
DTree<eT, cT> *right_;
- ////////////////////// Constructors /////////////////////////////////////////
+public:
-public:
-
////////////////////// Getters and Setters //////////////////////////////////
size_t start() { return start_; }
@@ -127,70 +124,69 @@
private:
cT ComputeNodeError_(size_t total_points);
-
+
bool FindSplit_(MatType* data,
- size_t* split_dim,
- size_t* split_ind,
- cT* left_error,
- cT* right_error,
- size_t maxLeafSize = 10,
- size_t minLeafSize = 5);
+ size_t* split_dim,
+ size_t* split_ind,
+ cT* left_error,
+ cT* right_error,
+ size_t maxLeafSize = 10,
+ size_t minLeafSize = 5);
void SplitData_(MatType* data,
- size_t split_dim,
- size_t split_ind,
- arma::Col<size_t>* old_from_new,
- eT* split_val,
- eT* lsplit_val,
- eT* rsplit_val);
+ size_t split_dim,
+ size_t split_ind,
+ arma::Col<size_t>* old_from_new,
+ eT* split_val,
+ eT* lsplit_val,
+ eT* rsplit_val);
void GetMaxMinVals_(MatType* data,
- VecType* max_vals,
- VecType* min_vals);
+ VecType* max_vals,
+ VecType* min_vals);
bool WithinRange_(VecType* query);
///////////////////// Public Functions //////////////////////////////////////
public:
-
+
DTree();
// Root node initializer
// with the bounding box of the data
// it contains instead of just the data.
- DTree(VecType* max_vals,
- VecType* min_vals,
- size_t total_points);
+ DTree(VecType* max_vals,
+ VecType* min_vals,
+ size_t total_points);
// Root node initializer
// with the data, no bounding box.
DTree(MatType* data);
// Non-root node initializers
- DTree(VecType* max_vals,
- VecType* min_vals,
- size_t start,
- size_t end,
- cT error);
+ DTree(VecType* max_vals,
+ VecType* min_vals,
+ size_t start,
+ size_t end,
+ cT error);
- DTree(VecType* max_vals,
- VecType* min_vals,
- size_t total_points,
- size_t start,
- size_t end);
+ DTree(VecType* max_vals,
+ VecType* min_vals,
+ size_t total_points,
+ size_t start,
+ size_t end);
~DTree();
// Greedily expand the tree
- cT Grow(MatType* data,
- arma::Col<size_t> *old_from_new,
- bool useVolReg = false,
- size_t maxLeafSize = 10,
- size_t minLeafSize = 5);
+ cT Grow(MatType* data,
+ arma::Col<size_t> *old_from_new,
+ bool useVolReg = false,
+ size_t maxLeafSize = 10,
+ size_t minLeafSize = 5);
// perform alpha pruning on the tree
- cT PruneAndUpdate(cT old_alpha,
- bool useVolReg = false);
+ cT PruneAndUpdate(cT old_alpha, bool useVolReg = false);
// compute the density at a given point
cT ComputeValue(VecType* query);
@@ -205,7 +201,7 @@
// of a learned tree.
int FindBucket(VecType* query);
- // This computes the variable importance list
+ // This computes the variable importance list
// for the learned tree.
void ComputeVariableImportance(arma::Col<double> *imps);
Modified: mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp 2012-04-16 21:44:50 UTC (rev 12420)
+++ mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp 2012-04-16 22:12:10 UTC (rev 12421)
@@ -2,7 +2,7 @@
* @file dtree_impl.hpp
* @author Parikshit Ram (pram at cc.gatech.edu)
*
- * Implementations of some declared functions in
+ * Implementations of some declared functions in
* the Density Estimation Tree class.
*
*/
@@ -15,50 +15,46 @@
namespace mlpack{
namespace det {
-// This function computes the l2-error of a given node
-// from the formula - R(t) = -|t|^2 / (N^2 V_t)
+// This function computes the l2-error of a given node from the formula
+// R(t) = -|t|^2 / (N^2 V_t).
template<typename eT, typename cT>
-cT DTree<eT, cT>::
-ComputeNodeError_(size_t total_points)
+cT DTree<eT, cT>::ComputeNodeError_(size_t total_points)
{
size_t node_size = end_ - start_;
cT log_vol_t = 0;
for (size_t i = 0; i < max_vals_->n_elem; i++)
if ((*max_vals_)[i] - (*min_vals_)[i] > 0.0)
- // using log to prevent overflow
+ // Use log to prevent overflow.
log_vol_t += (cT) std::log((*max_vals_)[i] - (*min_vals_)[i]);
- // check for overflow -- if it doesn't work, try higher precision
- // by default cT = long double, so if you can't work with that
- // there is nothing else you can do - except computing error using
- // log and dealing with everything in log form.
+ // Check for overflow -- if it doesn't work, try higher precision. By default
+ // cT = long double, so if you can't work with that there is nothing else you
+ // can do - except computing error using log and dealing with everything in
+ // log form.
assert(std::exp(log_vol_t) > 0.0);
- cT log_neg_error = 2 * std::log((cT) node_size / (cT) total_points)
+ cT log_neg_error = 2 * std::log((cT) node_size / (cT) total_points)
- log_vol_t;
assert(std::exp(log_neg_error) > 0.0);
- cT error = -1.0 * std::exp(log_neg_error);
+ cT error = -1.0 * std::exp(log_neg_error);
return error;
} // ComputeNodeError
-
-// This function find the best split with respect to the L2-error
-// but trying all possible splits.
-// The dataset is the full data set but the start_ and end_
-// are used to obtain the point in this node.
+// This function finds the best split with respect to the L2-error
+// by trying all possible splits. The dataset is the full data set but the
+// start_ and end_ are used to obtain the points in this node.
template<typename eT, typename cT>
-bool DTree<eT, cT>::
-FindSplit_(MatType* data,
- size_t *split_dim,
- size_t *split_ind,
- cT *left_error,
- cT *right_error,
- size_t maxLeafSize,
- size_t minLeafSize)
+bool DTree<eT, cT>::FindSplit_(MatType* data,
+ size_t *split_dim,
+ size_t *split_ind,
+ cT *left_error,
+ cT *right_error,
+ size_t maxLeafSize,
+ size_t minLeafSize)
{
assert(data->n_rows == max_vals_->n_elem);
assert(data->n_rows == min_vals_->n_elem);
@@ -69,161 +65,160 @@
bool some_split_found = false;
size_t point_mass_in_dim = 0;
- // loop through each dimension
- for (size_t dim = 0; dim < max_vals_->n_elem; dim++) {
- // have to deal with REAL, INTEGER, NOMINAL data
- // differently so have to think of how to do that.
+ // Loop through each dimension.
+ for (size_t dim = 0; dim < max_vals_->n_elem; dim++)
+ {
+ // Have to deal with REAL, INTEGER, NOMINAL data differently, so we have to
+ // think of how to do that...
eT min = (*min_vals_)[dim], max = (*max_vals_)[dim];
- // checking if there is any scope of splitting in this dim
+ // Check if there is any scope of splitting in this dimension.
if (max - min > 0.0) {
- // initializing all the stuff for this dimension
+ // Initializing all the stuff for this dimension.
bool dim_split_found = false;
- cT min_dim_error = min_error,
- temp_lval = 0.0, temp_rval = 0.0;
+ cT min_dim_error = min_error, temp_lval = 0.0, temp_rval = 0.0;
size_t dim_split_ind = -1;
cT log_range_all_not_dim = 0;
- for (size_t i = 0; i < max_vals_->n_elem; i++) {
-
- if ((*max_vals_)[i] -(*min_vals_)[i] > 0.0 && i != dim) {
- log_range_all_not_dim
- += (cT) std::log((*max_vals_)[i] - (*min_vals_)[i]);
- }
+ for (size_t i = 0; i < max_vals_->n_elem; i++)
+ {
+ if ((*max_vals_)[i] -(*min_vals_)[i] > 0.0 && i != dim)
+ {
+ log_range_all_not_dim +=
+ (cT) std::log((*max_vals_)[i] - (*min_vals_)[i]);
+ }
}
assert(std::exp(log_range_all_not_dim) > 0);
- // get the values for the dimension
+ // Get the values for the dimension.
RowVecType dim_val_vec = data->row(dim).subvec(start_, end_ - 1);
- // sort the values in ascending order
+ // Sort the values in ascending order.
dim_val_vec = arma::sort(dim_val_vec);
- // get ready to go through the sorted list and compute error
+ // Get ready to go through the sorted list and compute error.
assert(dim_val_vec.n_elem > maxLeafSize);
- // enforcing the leaves to have a minimum
- // number of points to avoid spikes
- // one way of doing it is only considering splits resulting
- // in sizes > some constant (minLeafSize)
+
+ // Enforce that the leaves have a minimum number of points to avoid
+ // spikes. One way of doing this is to only consider splits resulting in
+ // sizes > some constant (minLeafSize).
size_t left_child_size = minLeafSize - 1, right_child_size;
- // finding the best split for this dimension
- // need to figure out why there are spikes if
- // this min_leaf_size is enforced here
- for (size_t i = minLeafSize -1;
- i < dim_val_vec.n_elem - minLeafSize;
- i++, left_child_size++) {
+ // Find the best split for this dimension. We need to figure out why
+ // there are spikes if this min_leaf_size is enforced here...
+ for (size_t i = minLeafSize -1; i < dim_val_vec.n_elem - minLeafSize;
+ ++i, ++left_child_size)
+ {
+ eT split, lsplit = dim_val_vec[i], rsplit = dim_val_vec[i + 1];
- eT split, lsplit = dim_val_vec[i],
- rsplit = dim_val_vec[i + 1];
+ if (lsplit < rsplit)
+ {
+ // This makes sense for real continuous data. This kinda corrupts the
+ // data and estimation if the data is ordinal
+ split = (lsplit + rsplit) / 2;
- if (lsplit < rsplit) {
+ // Another way of picking split is using this:
+ // split = left_split;
- // this makes sense for real continuous data
- // This kinda corrupts the data and estimation
- // if the data is ordinal
- split = (lsplit + rsplit) / 2;
+ if (split - min > 0.0 && max - split > 0.0)
+ {
+ assert(std::exp(log_range_all_not_dim +
+ (cT) std::log(split - min)) > 0);
+ assert(std::exp(log_range_all_not_dim +
+ (cT) std::log(max - split)) > 0);
- // Another way of picking split is using
- // split = left_split;
+ cT temp_log_neg_l_error = 2 * std::log((cT) (i + 1) / (cT) total_n)
+ - (log_range_all_not_dim + (cT) std::log(split - min));
- if (split - min > 0.0 && max - split > 0.0) {
+ assert(std::exp(temp_log_neg_l_error) > 0.0);
- assert(std::exp(log_range_all_not_dim
- + (cT) std::log(split - min)) > 0);
- assert(std::exp(log_range_all_not_dim
- + (cT) std::log(max - split)) > 0);
+ cT temp_l_error = -1.0 * std::exp(temp_log_neg_l_error);
- cT temp_log_neg_l_error = 2 * std::log((cT) (i + 1) / (cT) total_n)
- - (log_range_all_not_dim + (cT) std::log(split - min));
+ assert(std::abs(temp_l_error) < std::numeric_limits<cT>::max());
- assert(std::exp(temp_log_neg_l_error) > 0.0);
+ cT temp_log_neg_r_error = 2 * std::log((cT) (n_t - i - 1) /
+ (cT) total_n) - (log_range_all_not_dim +
+ (cT) std::log(max - split));
- cT temp_l_error = -1.0 * std::exp(temp_log_neg_l_error);
+ assert(std::exp(temp_log_neg_r_error) > 0.0);
- assert(std::abs(temp_l_error)
- < std::numeric_limits<cT>::max());
+ right_child_size = n_t - i - 1;
+ assert(right_child_size >= minLeafSize);
- cT temp_log_neg_r_error
- = 2 * std::log((cT) (n_t - i - 1) / (cT) total_n)
- - (log_range_all_not_dim + (cT) std::log(max - split));
+ cT temp_r_error = -1.0 * std::exp(temp_log_neg_r_error);
- assert(std::exp(temp_log_neg_r_error) > 0.0);
+ assert(std::abs(temp_r_error) < std::numeric_limits<cT>::max());
- right_child_size = n_t - i - 1;
- assert(right_child_size >= minLeafSize);
-
- cT temp_r_error = -1.0 * std::exp(temp_log_neg_r_error);
+ //if (temp_l + temp_r <= min_dim_error) {
+ // Why not just less than?
+ if (temp_l_error + temp_r_error < min_dim_error)
+ {
+ min_dim_error = temp_l_error + temp_r_error;
+ temp_lval = temp_l_error;
+ temp_rval = temp_r_error;
+ dim_split_ind = i;
+ dim_split_found = true;
+ } // end if improvement.
+ } // end if split - min > 0 & max - split > 0.
+ } // end if lsplit < rsplit instead of being equal.
+ } // end for loop over all splits in this dimension.
- assert(std::abs(temp_r_error)
- < std::numeric_limits<cT>::max());
-
- //if (temp_l + temp_r <= min_dim_error) {
- // why not just less than
- if (temp_l_error + temp_r_error < min_dim_error) {
- min_dim_error = temp_l_error + temp_r_error;
- temp_lval = temp_l_error;
- temp_rval = temp_r_error;
- dim_split_ind = i;
- dim_split_found = true;
- } // end if improvement
- } // end if split - min > 0 & max - split > 0
- } // end if lsplit < rsplit instead of being equal
- } // end for loop over all splits in this dim
-
dim_val_vec.clear();
- if ((min_dim_error < min_error) && dim_split_found) {
- min_error = min_dim_error;
- *split_dim = dim;
- *split_ind = dim_split_ind;
- *left_error = temp_lval;
- *right_error = temp_rval;
- some_split_found = true;
- } // end if better split found in this dim
- } else {
+ if ((min_dim_error < min_error) && dim_split_found)
+ {
+ min_error = min_dim_error;
+ *split_dim = dim;
+ *split_ind = dim_split_ind;
+ *left_error = temp_lval;
+ *right_error = temp_rval;
+ some_split_found = true;
+ } // end if better split found in this dimension.
+ }
+ else
+ {
point_mass_in_dim++;
- } // end if
- } // end for each dimension
+ }
+ }
return some_split_found;
} // end FindSplit_
template<typename eT, typename cT>
-void DTree<eT, cT>::
-SplitData_(MatType* data,
- size_t split_dim,
- size_t split_ind,
- arma::Col<size_t> *old_from_new,
- eT *split_val,
- eT *lsplit_val,
- eT *rsplit_val)
+void DTree<eT, cT>::SplitData_(MatType* data,
+ size_t split_dim,
+ size_t split_ind,
+ arma::Col<size_t> *old_from_new,
+ eT *split_val,
+ eT *lsplit_val,
+ eT *rsplit_val)
{
- // get the values for the split dim
- RowVecType dim_val_vec
- = data->row(split_dim).subvec(start_, end_ - 1);
+ // Get the values for the split dimension.
+ RowVecType dim_val_vec = data->row(split_dim).subvec(start_, end_ - 1);
- // sort the values
+ // Sort the values.
dim_val_vec = arma::sort(dim_val_vec);
- *lsplit_val = dim_val_vec[split_ind];
- *rsplit_val = dim_val_vec[split_ind + 1];
+ *lsplit_val = dim_val_vec[split_ind];
+ *rsplit_val = dim_val_vec[split_ind + 1];
*split_val = (*lsplit_val + *rsplit_val) / 2 ;
std::vector<bool> left_membership;
left_membership.reserve(end_ - start_);
- for (size_t i = start_; i < end_; i++) {
+ for (size_t i = start_; i < end_; ++i)
+ {
if ((*data)(split_dim, i) > *split_val)
left_membership[i - start_] = false;
- else
+ else
left_membership[i - start_] = true;
}
size_t left_ind = start_, right_ind = end_ - 1;
- for(;;) {
+ for (;;)
+ {
while (left_membership[left_ind - start_] && (left_ind <= right_ind))
left_ind++;
@@ -242,57 +237,53 @@
size_t t = (*old_from_new)[left_ind];
(*old_from_new)[left_ind] = (*old_from_new)[right_ind];
(*old_from_new)[right_ind] = t;
+ }
- } // swap for loop
-
assert(left_ind == right_ind + 1);
-} // end SplitData_
+}
template<typename eT, typename cT>
-void DTree<eT, cT>::
-GetMaxMinVals_(MatType* data,
- VecType *max_vals,
- VecType *min_vals) {
-
+void DTree<eT, cT>::GetMaxMinVals_(MatType* data,
+ VecType *max_vals,
+ VecType *min_vals)
+{
max_vals->set_size(data->n_rows);
min_vals->set_size(data->n_rows);
MatType temp_d = arma::trans(*data);
- for (size_t i = 0; i < temp_d.n_cols; i++) {
-
+ for (size_t i = 0; i < temp_d.n_cols; ++i)
+ {
VecType dim_vals = arma::sort(temp_d.col(i));
(*min_vals)[i] = dim_vals[0];
(*max_vals)[i] = dim_vals[dim_vals.n_elem - 1];
}
-} // end GetMaxMinVals_
+}
template<typename eT, typename cT>
-DTree<eT, cT>::
-DTree() :
- start_(0),
- end_(0),
- max_vals_(NULL),
- min_vals_(NULL),
- left_(NULL),
- right_(NULL)
-{}
+DTree<eT, cT>::DTree() :
+ start_(0),
+ end_(0),
+ max_vals_(NULL),
+ min_vals_(NULL),
+ left_(NULL),
+ right_(NULL)
+{ /* Nothing to do. */ }
// Root node initializers
template<typename eT, typename cT>
-DTree<eT, cT>::
-DTree(VecType* max_vals,
- VecType* min_vals,
- size_t total_points) :
- start_(0),
- end_(total_points),
- max_vals_(max_vals),
- min_vals_(min_vals),
- left_(NULL),
- right_(NULL)
+DTree<eT, cT>::DTree(VecType* max_vals,
+ VecType* min_vals,
+ size_t total_points) :
+ start_(0),
+ end_(total_points),
+ max_vals_(max_vals),
+ min_vals_(min_vals),
+ left_(NULL),
+ right_(NULL)
{
error_ = ComputeNodeError_(total_points);
@@ -302,12 +293,11 @@
template<typename eT, typename cT>
-DTree<eT, cT>::
-DTree(MatType* data) :
- start_(0),
- end_(data->n_cols),
- left_(NULL),
- right_(NULL)
+DTree<eT, cT>::DTree(MatType* data) :
+ start_(0),
+ end_(data->n_cols),
+ left_(NULL),
+ right_(NULL)
{
max_vals_ = new VecType();
min_vals_ = new VecType();
@@ -323,19 +313,18 @@
// Non-root node initializers
template<typename eT, typename cT>
-DTree<eT, cT>::
-DTree(VecType* max_vals,
- VecType* min_vals,
- size_t start,
- size_t end,
- cT error) :
- start_(start),
- end_(end),
- error_(error),
- max_vals_(max_vals),
- min_vals_(min_vals),
- left_(NULL),
- right_(NULL)
+DTree<eT, cT>::DTree(VecType* max_vals,
+ VecType* min_vals,
+ size_t start,
+ size_t end,
+ cT error) :
+ start_(start),
+ end_(end),
+ error_(error),
+ max_vals_(max_vals),
+ min_vals_(min_vals),
+ left_(NULL),
+ right_(NULL)
{
bucket_tag_ = -1;
root_ = false;
@@ -343,20 +332,18 @@
template<typename eT, typename cT>
-DTree<eT, cT>::
-DTree(VecType* max_vals,
- VecType* min_vals,
- size_t total_points,
- size_t start,
- size_t end) :
- start_(start),
- end_(end),
- max_vals_(max_vals),
- min_vals_(min_vals),
- left_(NULL),
- right_(NULL)
+DTree<eT, cT>::DTree(VecType* max_vals,
+ VecType* min_vals,
+ size_t total_points,
+ size_t start,
+ size_t end) :
+ start_(start),
+ end_(end),
+ max_vals_(max_vals),
+ min_vals_(min_vals),
+ left_(NULL),
+ right_(NULL)
{
-
error_ = ComputeNodeError_(total_points);
bucket_tag_ = -1;
@@ -365,15 +352,14 @@
template<typename eT, typename cT>
-DTree<eT, cT>::
-~DTree()
+DTree<eT, cT>::~DTree()
{
if (left_ != NULL)
delete left_;
-
+
if (right_ != NULL)
delete right_;
-
+
if (min_vals_ != NULL)
delete min_vals_;
@@ -384,53 +370,47 @@
// Greedily expand the tree
template<typename eT, typename cT>
-cT DTree<eT, cT>::
-Grow(MatType* data,
- arma::Col<size_t> *old_from_new,
- bool useVolReg,
- size_t maxLeafSize,
- size_t minLeafSize)
-{
+cT DTree<eT, cT>::Grow(MatType* data,
+ arma::Col<size_t> *old_from_new,
+ bool useVolReg,
+ size_t maxLeafSize,
+ size_t minLeafSize)
+{
assert(data->n_rows == max_vals_->n_elem);
assert(data->n_rows == min_vals_->n_elem);
-
+
cT left_g, right_g;
- // computing points ratio
- ratio_ = (cT) (end_ - start_)
- / (cT) old_from_new->n_elem;
+ // Compute points ratio.
+ ratio_ = (cT) (end_ - start_) / (cT) old_from_new->n_elem;
- // computing the v_t_inv:
- // the inverse of the volume of the node
+ // Compute the v_t_inv: the inverse of the volume of the node.
cT log_vol_t = 0;
- for (size_t i = 0; i < max_vals_->n_elem; i++)
+ for (size_t i = 0; i < max_vals_->n_elem; ++i)
if ((*max_vals_)[i] - (*min_vals_)[i] > 0.0)
- // using log to prevent overflow
+ // Use log to prevent overflow.
log_vol_t += (cT) std::log((*max_vals_)[i] - (*min_vals_)[i]);
- // check for overflow
+ // Check for overflow.
assert(std::exp(log_vol_t) > 0.0);
v_t_inv_ = 1.0 / std::exp(log_vol_t);
- // Checking if node is large enough
+ // Check if node is large enough.
if ((size_t) (end_ - start_) > maxLeafSize) {
- // find the split
+ // Find the split.
size_t dim, split_ind;
cT left_error, right_error;
- if (FindSplit_(data, &dim, &split_ind,
- &left_error, &right_error,
- maxLeafSize, minLeafSize)) {
-
- // Move the data around for the children
- // to have points in a node lie contiguously
- // (to increase efficiency during the training).
+ if (FindSplit_(data, &dim, &split_ind, &left_error, &right_error,
+ maxLeafSize, minLeafSize))
+ {
+ // Move the data around for the children to have points in a node lie
+ // contiguously (to increase efficiency during the training).
eT split_val, lsplit_val, rsplit_val;
- SplitData_(data, dim, split_ind,
- old_from_new, &split_val,
- &lsplit_val, &rsplit_val);
+ SplitData_(data, dim, split_ind, old_from_new, &split_val, &lsplit_val,
+ &rsplit_val);
- // make max and min vals for the children
+ // Make max and min vals for the children.
VecType* max_vals_l = new VecType(*max_vals_);
VecType* max_vals_r = new VecType(*max_vals_);
VecType* min_vals_l = new VecType(*min_vals_);
@@ -439,172 +419,168 @@
(*max_vals_l)[dim] = split_val;
(*min_vals_r)[dim] = split_val;
- // store split dim and split val in the node
+ // Store split dim and split val in the node.
split_value_ = split_val;
split_dim_ = dim;
+ // Recursively grow the children.
+ left_ = new DTree(max_vals_l, min_vals_l, start_, start_ + split_ind + 1,
+ left_error);
+ right_ = new DTree(max_vals_r, min_vals_r, start_ + split_ind + 1, end_,
+ right_error);
- // Recursively growing the children
- left_ = new DTree(max_vals_l, min_vals_l,
- start_, start_ + split_ind + 1,
- left_error);
- right_ = new DTree(max_vals_r, min_vals_r,
- start_ + split_ind + 1, end_,
- right_error);
+ left_g = left_->Grow(data, old_from_new, useVolReg, maxLeafSize,
+ minLeafSize);
+ right_g = right_->Grow(data, old_from_new, useVolReg, maxLeafSize,
+ minLeafSize);
- left_g = left_->Grow(data, old_from_new, useVolReg,
- maxLeafSize, minLeafSize);
- right_g = right_->Grow(data, old_from_new, useVolReg,
- maxLeafSize, minLeafSize);
-
- // storing values of R(T~) and |T~|
+ // Store values of R(T~) and |T~|.
subtree_leaves_ = left_->subtree_leaves() + right_->subtree_leaves();
- subtree_leaves_error_ = left_->subtree_leaves_error()
- + right_->subtree_leaves_error();
+ subtree_leaves_error_ = left_->subtree_leaves_error() +
+ right_->subtree_leaves_error();
- // storing the subtree_leaves_v_t_inv
- subtree_leaves_v_t_inv_ = left_->subtree_leaves_v_t_inv()
- + right_->subtree_leaves_v_t_inv();
+ // Store the subtree_leaves_v_t_inv.
+ subtree_leaves_v_t_inv_ = left_->subtree_leaves_v_t_inv() +
+ right_->subtree_leaves_v_t_inv();
- // Forming T1 by removing leaves for which
- // R(t) = R(t_L) + R(t_R)
- if ((left_->subtree_leaves() == 1)
- && (right_->subtree_leaves() == 1)) {
- if (left_->error() + right_->error() == error_) {
- delete left_;
- left_ = NULL;
- delete right_;
- right_ = NULL;
- subtree_leaves_ = 1;
- subtree_leaves_error_ = error_;
- subtree_leaves_v_t_inv_ = v_t_inv_;
- } // end if
- } // end if
- } else {
- // no split found so make a leaf out of it
+ // Form T1 by removing leaves for which R(t) = R(t_L) + R(t_R).
+ if ((left_->subtree_leaves() == 1) && (right_->subtree_leaves() == 1))
+ {
+ if (left_->error() + right_->error() == error_)
+ {
+ delete left_;
+ left_ = NULL;
+
+ delete right_;
+ right_ = NULL;
+
+ subtree_leaves_ = 1;
+ subtree_leaves_error_ = error_;
+ subtree_leaves_v_t_inv_ = v_t_inv_;
+ }
+ }
+ }
+ else
+ {
+ // No split found so make a leaf out of it.
subtree_leaves_ = 1;
subtree_leaves_error_ = error_;
subtree_leaves_v_t_inv_ = v_t_inv_;
- } // end if-else
- } else {
- // We can make this a leaf node
+ }
+ }
+ else
+ {
+ // We can make this a leaf node.
assert((size_t) (end_ - start_) >= minLeafSize);
subtree_leaves_ = 1;
subtree_leaves_error_ = error_;
subtree_leaves_v_t_inv_ = v_t_inv_;
+ }
- } // end if-else
-
- // if leaf do not compute g_k(t), else compute, store,
- // and propagate min(g_k(t_L),g_k(t_R),g_k(t)),
- // unless t_L and/or t_R are leaves
- if (subtree_leaves_ == 1) {
+ // If this is a leaf, do not compute g_k(t); otherwise compute, store, and
+ // propagate min(g_k(t_L),g_k(t_R),g_k(t)), unless t_L and/or t_R are leaves.
+ if (subtree_leaves_ == 1)
+ {
return std::numeric_limits<cT>::max();
- } else {
+ }
+ else
+ {
cT g_t;
- if (useVolReg) {
- g_t = (error_ - subtree_leaves_error_)
- / (subtree_leaves_v_t_inv_ - v_t_inv_);
- } else {
- g_t = (error_ - subtree_leaves_error_)
- / (subtree_leaves_ - 1);
- }
+ if (useVolReg)
+ g_t = (error_ - subtree_leaves_error_) /
+ (subtree_leaves_v_t_inv_ - v_t_inv_);
+ else
+ g_t = (error_ - subtree_leaves_error_) / (subtree_leaves_ - 1);
assert(g_t > 0.0);
return min(g_t, min(left_g, right_g));
- } // end if-else
+ }
- // need to compute (c_t^2)*r_t for all subtree leaves
- // this is equal to n_t^2/r_t*n^2 = -error_ !!
- // therefore the value we need is actually
- // -1.0*subtree_leaves_error_
-} // Grow
+ // We need to compute (c_t^2)*r_t for all subtree leaves; this is equal to
+ // n_t^2 / (r_t * n^2) = -error_. Therefore the value we need is actually
+ // -1.0 * subtree_leaves_error_.
+}
template<typename eT, typename cT>
-cT DTree<eT, cT>::
-PruneAndUpdate(cT old_alpha,
- bool useVolReg)
+cT DTree<eT, cT>::PruneAndUpdate(cT old_alpha, bool useVolReg)
{
- // compute g_t
- if (subtree_leaves_ == 1) { // if leaf
+ // Compute g_t.
+ if (subtree_leaves_ == 1) // If we are a leaf...
+ {
return std::numeric_limits<cT>::max();
- } else {
-
- // compute g_t value for node t
+ }
+ else
+ {
+ // Compute g_t value for node t.
cT g_t;
- if (useVolReg) {
- g_t = (error_ - subtree_leaves_error_)
- / (subtree_leaves_v_t_inv_ - v_t_inv_);
- } else {
- g_t = (error_ - subtree_leaves_error_)
- / (subtree_leaves_ - 1);
- }
+ if (useVolReg)
+ g_t = (error_ - subtree_leaves_error_) /
+ (subtree_leaves_v_t_inv_ - v_t_inv_);
+ else
+ g_t = (error_ - subtree_leaves_error_) / (subtree_leaves_ - 1);
- if (g_t > old_alpha) { // go down the tree and update accordingly
- // traverse the children
+ if (g_t > old_alpha)
+ {
+ // Go down the tree and update accordingly. Traverse the children.
cT left_g = left_->PruneAndUpdate(old_alpha, useVolReg);
cT right_g = right_->PruneAndUpdate(old_alpha, useVolReg);
- // update values
- subtree_leaves_ = left_->subtree_leaves()
- + right_->subtree_leaves();
- subtree_leaves_error_ = left_->subtree_leaves_error()
- + right_->subtree_leaves_error();
- subtree_leaves_v_t_inv_ = left_->subtree_leaves_v_t_inv()
- + right_->subtree_leaves_v_t_inv();
+ // Update values.
+ subtree_leaves_ = left_->subtree_leaves() + right_->subtree_leaves();
+ subtree_leaves_error_ = left_->subtree_leaves_error() +
+ right_->subtree_leaves_error();
+ subtree_leaves_v_t_inv_ = left_->subtree_leaves_v_t_inv() + right_->subtree_leaves_v_t_inv();
- // update g_t value
- if (useVolReg) {
- g_t = (error_ - subtree_leaves_error_)
- / (subtree_leaves_v_t_inv_ - v_t_inv_);
- } else {
- g_t = (error_ - subtree_leaves_error_)
- / (subtree_leaves_ - 1);
- }
+ // Update g_t value.
+ if (useVolReg)
+ g_t = (error_ - subtree_leaves_error_)
+ / (subtree_leaves_v_t_inv_ - v_t_inv_);
+ else
+ g_t = (error_ - subtree_leaves_error_) / (subtree_leaves_ - 1);
assert(g_t < std::numeric_limits<cT>::max());
- if (left_->subtree_leaves() == 1
- && right_->subtree_leaves() == 1) {
- return g_t;
- } else if (left_->subtree_leaves() == 1) {
- return min(g_t, right_g);
- } else if (right_->subtree_leaves() == 1) {
- return min(g_t, left_g);
- } else {
- return min(g_t, min(left_g, right_g));
- }
- } else { // prune this subtree
+ if (left_->subtree_leaves() == 1 && right_->subtree_leaves() == 1)
+ return g_t;
+ else if (left_->subtree_leaves() == 1)
+ return min(g_t, right_g);
+ else if (right_->subtree_leaves() == 1)
+ return min(g_t, left_g);
+ else
+ return min(g_t, min(left_g, right_g));
- // making this node a leaf node
+ }
+ else
+ {
+ // Prune this subtree.
+ // First, make this node a leaf node.
subtree_leaves_ = 1;
subtree_leaves_error_ = error_;
subtree_leaves_v_t_inv_ = v_t_inv_;
+
delete left_;
left_ = NULL;
delete right_;
right_ = NULL;
- // passing information upward
+
+ // Pass information upward.
return std::numeric_limits<cT>::max();
- } // end if-else for pruning subtree
- } /// if-else for leaf or non-leaf
-} // PruneAndUpdate
+ }
+ }
+}
-// Checking whether a given point is within the
-// bounding box of this node (check generally done
-// at the root, so its the bounding box of the data)
+// Check whether a given point is within the bounding box of this node (check
+// generally done at the root, so it's the bounding box of the data).
//
-// Future improvement: To open up the range with epsilons on
-// both sides where epsilon depends on the density near the boundary.
+// Future improvement: Open up the range with epsilons on both sides where
+// epsilon depends on the density near the boundary.
template<typename eT, typename cT>
-bool DTree<eT, cT>::
-WithinRange_(VecType* query)
+bool DTree<eT, cT>::WithinRange_(VecType* query)
{
- for (size_t i = 0; i < query->n_elem; i++)
- if (((*query)[i] < (*min_vals_)[i])
- || ((*query)[i] > (*max_vals_)[i]))
+ for (size_t i = 0; i < query->n_elem; ++i)
+ if (((*query)[i] < (*min_vals_)[i]) || ((*query)[i] > (*max_vals_)[i]))
return false;
return true;
@@ -612,109 +588,112 @@
template<typename eT, typename cT>
-cT DTree<eT, cT>::
-ComputeValue(VecType* query)
+cT DTree<eT, cT>::ComputeValue(VecType* query)
{
assert(query->n_elem == max_vals_->n_elem);
- if (root_ == 1) // if root
- // check if query is within range
+ if (root_ == 1) // If we are the root...
+ // Check if the query is within range.
if (!WithinRange_(query))
return 0.0;
- // end WithinRange_ if-else
- // end root if
- if (subtree_leaves_ == 1) // if leaf
+ if (subtree_leaves_ == 1) // If we are a leaf...
return ratio_ * v_t_inv_;
else
- if ((*query)[split_dim_] <= split_value_) // if left subtree
- // go to left child
+ {
+ if ((*query)[split_dim_] <= split_value_)
+ // If left subtree, go to left child.
return left_->ComputeValue(query);
- else // if right subtree
- // go to right child
+ else // If right subtree, go to right child
return right_->ComputeValue(query);
- // end if-else
-} // ComputeValue
+ }
+}
template<typename eT, typename cT>
-void DTree<eT, cT>::
-WriteTree(size_t level, FILE *fp)
+void DTree<eT, cT>::WriteTree(size_t level, FILE *fp)
{
- if (subtree_leaves_ > 1){
+ if (subtree_leaves_ > 1)
+ {
fprintf(fp, "\n");
- for (size_t i = 0; i < level; i++){
+ for (size_t i = 0; i < level; ++i)
fprintf(fp, "|\t");
- }
- fprintf(fp, "Var. %zu > %lg",
- split_dim_, split_value_);
- right_->WriteTree(level+1, fp);
+ fprintf(fp, "Var. %zu > %lg", split_dim_, split_value_);
+
+ right_->WriteTree(level + 1, fp);
+
fprintf(fp, "\n");
- for (size_t i = 0; i < level; i++){
+ for (size_t i = 0; i < level; ++i)
fprintf(fp, "|\t");
- }
- fprintf(fp, "Var. %zu <= %lg ",
- split_dim_, split_value_);
- left_->WriteTree(level+1, fp);
+ fprintf(fp, "Var. %zu <= %lg ", split_dim_, split_value_);
- } else { // if leaf
+ left_->WriteTree(level + 1, fp);
+ }
+ else // If we are a leaf...
+ {
fprintf(fp, ": f(x)=%Lg", (cT) ratio_ * v_t_inv_);
- if (bucket_tag_ != -1)
+ if (bucket_tag_ != -1)
fprintf(fp, " BT:%d", bucket_tag_);
}
-} // WriteTree
+}
-// indexing the buckets for possible usage later
+// Index the buckets for possible usage later.
template<typename eT, typename cT>
-int DTree<eT, cT>::
-TagTree(int tag)
+int DTree<eT, cT>::TagTree(int tag)
{
- if (subtree_leaves_ == 1) {
+ if (subtree_leaves_ == 1)
+ {
bucket_tag_ = tag;
- return (tag+1);
- } else {
+ return (tag + 1);
+ }
+ else
+ {
return right_->TagTree(left_->TagTree(tag));
}
} // TagTree
template<typename eT, typename cT>
-int DTree<eT, cT>::
-FindBucket(VecType* query)
+int DTree<eT, cT>::FindBucket(VecType* query)
{
assert(query->n_elem == max_vals_->n_elem);
- if (subtree_leaves_ == 1) { // if leaf
+ if (subtree_leaves_ == 1) // If we are a leaf...
+ {
return bucket_tag_;
- } else if ((*query)[split_dim_] <= split_value_) { // if left subtree
- // go to left child
+ }
+ else if ((*query)[split_dim_] <= split_value_)
+ {
+ // If left subtree, go to left child.
return left_->FindBucket(query);
- } else { // if right subtree
- // go to right child
+ }
+ else // If right subtree, go to right child.
+ {
return right_->FindBucket(query);
- } // end if-else
-} // FindBucket
+ }
+}
template<typename eT, typename cT>
-void DTree<eT, cT>::
-ComputeVariableImportance(arma::Col<double> *imps)
+void DTree<eT, cT>::ComputeVariableImportance(arma::Col<double> *imps)
{
- if (subtree_leaves_ == 1) {
- // if leaf, do nothing
+ if (subtree_leaves_ == 1)
+ {
+ // If we are a leaf, do nothing.
return;
- } else {
- // compute the improvement in error because of the
- // split
- double error_improv
- = (double) (error_ - (left_->error() + right_->error()));
+ }
+ else
+ {
+ // Compute the improvement in error because of the split.
+ double error_improv = (double)
+ (error_ - (left_->error() + right_->error()));
(*imps)[split_dim_] += error_improv;
left_->ComputeVariableImportance(imps);
right_->ComputeVariableImportance(imps);
return;
}
-} // ComputeVariableImportance
+}
}; // namespace det
}; // namespace mlpack
More information about the mlpack-svn
mailing list