[mlpack-svn] r12189 - in mlpack/trunk/src/mlpack/methods: . det
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Apr 4 11:54:33 EDT 2012
Author: pram
Date: 2012-04-04 11:54:33 -0400 (Wed, 04 Apr 2012)
New Revision: 12189
Added:
mlpack/trunk/src/mlpack/methods/det/
mlpack/trunk/src/mlpack/methods/det/dt_main.cpp
mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp
mlpack/trunk/src/mlpack/methods/det/dtree.hpp
mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
Log:
Density Estimation Trees (DET) added to mlpack
Added: mlpack/trunk/src/mlpack/methods/det/dt_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dt_main.cpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/det/dt_main.cpp 2012-04-04 15:54:33 UTC (rev 12189)
@@ -0,0 +1,239 @@
+/**
+ * @file dt_main.cpp
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ *
+ * This file provides an example use of the DET
+ */
+
+#include <mlpack/core.hpp>
+#include "dt_utils.hpp"
+
+using namespace mlpack;
+using namespace std;
+
+// Program name and description.
+PROGRAM_INFO("Density estimation with DET", "This program "
+    "provides an example use of the Density Estimation "
+    "Tree for density estimation. For more details, "
+    "please look at the paper titled "
+    "'Density Estimation Trees'.");
+
+// Input data files.
+PARAM_STRING_REQ("input/training_set", "The data set on which to "
+    "perform density estimation.", "S");
+PARAM_STRING("input/test_set", "An extra set of test points on "
+    "which to estimate the density given the estimator.",
+    "T", "");
+PARAM_STRING("input/labels", "The labels for the given training data to "
+    "generate the class membership of each leaf (as an "
+    "extra statistic)", "L", "");
+
+// Output data files.
+PARAM_STRING("output/unpruned_tree_estimates", "The file "
+    "in which to output the estimates on the "
+    "training set from the large unpruned tree.", "u", "");
+PARAM_STRING("output/training_set_estimates", "The file "
+    "in which to output the estimates on the "
+    "training set from the final optimally pruned"
+    " tree.", "s", "");
+PARAM_STRING("output/test_set_estimates", "The file "
+    "in which to output the estimates on the "
+    "test set from the final optimally pruned"
+    " tree.", "t", "");
+PARAM_STRING("output/leaf_class_table", "The file "
+    "in which to output the leaf class membership "
+    "table.", "l", "leaf_class_membership.txt");
+PARAM_STRING("output/tree", "The file in which to print "
+    "the final optimally pruned tree.", "p", "");
+PARAM_STRING("output/vi", "The file to output the "
+    "variable importance values for each feature.",
+    "i", "");
+
+// Parameters for the algorithm.
+PARAM_INT("param/number_of_classes", "The number of classes present "
+    "in the 'labels' set provided", "C", 0);
+PARAM_INT("param/folds", "The number of folds of cross-validation"
+    " to perform for the estimation (enter 0 for LOOCV)",
+    "F", 10);
+PARAM_INT("DET/min_leaf_size", "The minimum size of a leaf"
+    " in the unpruned fully grown DET.", "N", 5);
+PARAM_INT("DET/max_leaf_size", "The maximum size of a leaf"
+    " in the unpruned fully grown DET.", "M", 10);
+PARAM_FLAG("DET/use_volume_reg", "This flag gives the user the "
+    "option to use a form of regularization similar to "
+    "the usual alpha-pruning in decision trees. But "
+    "instead of regularizing on the number of leaves, "
+    "you regularize on the sum of the inverse of the volume "
+    "of the leaves (meaning you penalize "
+    "low volume leaves).", "R");
+
+// Flags controlling what information about the tree is printed.
+PARAM_FLAG("flag/print_tree", "If you just wish to print the tree "
+    "out on the command line.", "P");
+PARAM_FLAG("flag/print_vi", "If you just wish to print the "
+    "variable importance of each feature "
+    "out on the command line.", "I");
+
+/**
+ * Computes the density estimate of every column of a data set with the given
+ * tree and, if a file name is supplied, writes one estimate per line to it.
+ *
+ * @param tree Tree used to evaluate the density.
+ * @param data Points to evaluate, one per column.
+ * @param output_file File to write the estimates to; nothing is written when
+ *     this is empty or the file cannot be opened (a warning is logged in the
+ *     latter case).
+ * @param timer_name Name of the timer that measures the evaluation loop.
+ */
+static void WriteDensityEstimates(DTree<float>* tree,
+                                  arma::Mat<float>& data,
+                                  const string& output_file,
+                                  const char* timer_name)
+{
+  FILE *fp = NULL;
+  if (output_file != "") {
+    fp = fopen(output_file.c_str(), "w");
+    if (fp == NULL)
+      Log::Warn << "Can't open '" << output_file << "'" << endl;
+  }
+
+  // Computation timing is more accurate when you do not
+  // perform the printing.
+  Timer::Start(timer_name);
+  for (size_t i = 0; i < data.n_cols; i++) {
+    arma::Col<float> test_p = data.unsafe_col(i);
+    long double f = tree->ComputeValue(&test_p);
+    if (fp != NULL)
+      fprintf(fp, "%Lg\n", f);
+  } // end for
+  Timer::Stop(timer_name);
+
+  if (fp != NULL)
+    fclose(fp);
+}
+
+/**
+ * Trains a cross-validated, pruned density estimation tree on the training
+ * set and produces whichever outputs were requested on the command line:
+ * density estimates for the training and test sets, the tree itself, the
+ * leaf class membership table and the variable importances.
+ *
+ * @param argc Number of command-line arguments.
+ * @param argv Command-line arguments, handed to CLI::ParseCommandLine().
+ * @return 0 after all requested outputs have been produced.
+ */
+int main(int argc, char *argv[]) {
+  CLI::ParseCommandLine(argc, argv);
+
+  // Self-test of the private DTree helpers.  The outcome is only reported;
+  // the run continues afterwards (an unconditional exit(0) used to follow
+  // this test and made the rest of the program unreachable).
+  {
+    DTree<> test_pvt;
+    if (test_pvt.TestPrivateFunctions()) {
+      Log::Warn << "Private functions tests successful." << endl;
+    } else {
+      Log::Warn << "Private functions tests failed." << endl;
+    }
+  }
+
+  string train_set_file = CLI::GetParam<string>("S");
+  arma::Mat<float> training_data;
+
+  Log::Info << "Loading training set..." << endl;
+  if (!data::Load(train_set_file, training_data))
+    Log::Fatal << "Training set file "<< train_set_file
+               << " can't be loaded." << endl;
+
+  Log::Info << "Training set (" << training_data.n_rows
+            << ", " << training_data.n_cols
+            << ")" << endl;
+
+  // cross-validation here
+  const int folds_param = CLI::GetParam<int>("F");
+  if (folds_param < 0)
+    Log::Fatal << "The number of folds must not be negative." << endl;
+
+  size_t folds = folds_param;
+  if (folds == 0) {
+    folds = training_data.n_cols;
+    Log::Info << "Starting Leave-One-Out Cross validation" << endl;
+  } else
+    Log::Info << "Starting " << folds
+              << "-fold Cross validation" << endl;
+
+  // obtaining the optimal tree
+  string unpruned_tree_estimate_file
+    = CLI::GetParam<string>("u");
+
+  Timer::Start("DET/Training");
+  DTree<float> *dtree_opt = dt_utils::Trainer<float>
+    (&training_data, folds, CLI::HasParam("R"), CLI::GetParam<int>("M"),
+     CLI::GetParam<int>("N"), unpruned_tree_estimate_file);
+  Timer::Stop("DET/Training");
+
+  // computing densities for the train points in the
+  // optimal tree
+  WriteDensityEstimates(dtree_opt, training_data,
+                        CLI::GetParam<string>("s"),
+                        "DET/EstimationTime");
+
+  // computing the density at the provided test points
+  // and outputting the density in the given file.
+  if (CLI::GetParam<string>("T") != "") {
+    string test_file = CLI::GetParam<string>("T");
+    arma::Mat<float> test_data;
+    Log::Info << "Loading test set..." << endl;
+    if (!data::Load(test_file, test_data))
+      Log::Fatal << "Test set file "<< test_file
+                 << " can't be loaded." << endl;
+
+    Log::Info << "Test set (" << test_data.n_rows
+              << ", " << test_data.n_cols
+              << ")" << endl;
+
+    WriteDensityEstimates(dtree_opt, test_data,
+                          CLI::GetParam<string>("t"),
+                          "DET/TestSetEstimation");
+  } // Test set estimation
+
+  // printing the final tree
+  if (CLI::HasParam("P")) {
+    if (CLI::GetParam<string>("p") != "") {
+      string print_tree_file = CLI::GetParam<string>("p");
+      FILE *fp = fopen(print_tree_file.c_str(), "w");
+
+      if (fp != NULL) {
+        dtree_opt->WriteTree(0, fp);
+        fclose(fp);
+      } else {
+        Log::Warn << "Can't open '" << print_tree_file << "'" << endl;
+      }
+    } else {
+      dtree_opt->WriteTree(0, stdout);
+      printf("\n");
+    }
+  } // Printing the tree
+
+  // print the leaf memberships for the optimal tree
+  if (CLI::GetParam<string>("L") != "") {
+    std::string labels_file = CLI::GetParam<string>("L");
+    arma::Mat<int> labels;
+
+    Log::Info << "Loading label file..." << endl;
+    if (!data::Load(labels_file, labels))
+      Log::Fatal << "Label file "<< labels_file
+                 << " can't be loaded." << endl;
+
+    Log::Info << "Labels (" << labels.n_rows
+              << ", " << labels.n_cols
+              << ")" << endl;
+
+    size_t num_classes = CLI::GetParam<int>("C");
+    if (num_classes == 0)
+      Log::Fatal << "Please provide the number of classes"
+                 << " present in the label file" << endl;
+
+    // These are properties of user-supplied files, so they are checked
+    // unconditionally instead of with assert().
+    if (training_data.n_cols != labels.n_cols || labels.n_rows != 1)
+      Log::Fatal << "The label file must contain a single row with one"
+                 << " label per training point" << endl;
+
+    dt_utils::PrintLeafMembership<float>
+      (dtree_opt, training_data, labels, num_classes,
+       (string) CLI::GetParam<string>("l"));
+  } // leaf class membership
+
+  if (CLI::HasParam("I")) {
+    dt_utils::PrintVariableImportance<float>
+      (dtree_opt, training_data.n_rows,
+       (string) CLI::GetParam<string>("i"));
+  } // print variable importance
+
+  delete dtree_opt;
+  return 0;
+} // end main
Added: mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp 2012-04-04 15:54:33 UTC (rev 12189)
@@ -0,0 +1,333 @@
+/**
+ * @file dt_utils.hpp
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ *
+ * This file implements functions to perform
+ * different tasks with the Density Tree class.
+ */
+
+#ifndef DT_UTILS_HPP
+#define DT_UTILS_HPP
+
+#include <string>
+
+#include <mlpack/core.hpp>
+#include "dtree.hpp"
+
+using namespace mlpack;
+using namespace std;
+
+
+namespace dt_utils {
+
+ // Builds a (leaf x class) table counting how many points of each class
+ // fall into each leaf of the tree, then prints it through Log::Warn or,
+ // when a file name is given, writes it to that file.
+ //
+ // dtree       - tree whose leaves are tagged and queried
+ // data        - points, one per column
+ // labels      - label of point i is read as labels[i]
+ // num_classes - number of columns of the table
+ // leaf_class_membership_file - output file; empty means log output
+ template<typename eT>
+ void PrintLeafMembership(DTree<eT> *dtree,
+                          const arma::Mat<eT>& data,
+                          const arma::Mat<int>& labels,
+                          size_t num_classes,
+                          string leaf_class_membership_file = "")
+ {
+   // Tagging starts at 0; the returned value is used as the number of rows.
+   const int leaf_count = dtree->TagTree(0);
+
+   arma::Mat<size_t> counts(leaf_count, num_classes);
+   counts.zeros();
+
+   for (size_t col = 0; col != data.n_cols; ++col) {
+     arma::Col<eT> point = data.unsafe_col(col);
+     const int leaf = dtree->FindBucket(&point);
+     const int point_class = labels[col];
+     counts(leaf, point_class) += 1;
+   }
+
+   // No file requested: report the table through the log and finish.
+   if (leaf_class_membership_file == "") {
+     Log::Warn << "Leaf Membership: Classes in each leaf" << endl
+               << counts << endl;
+     return;
+   }
+
+   ofstream table_file(leaf_class_membership_file.c_str());
+   if (!table_file.good()) {
+     Log::Warn << "Can't open '" << leaf_class_membership_file
+               << "'" << endl;
+   } else {
+     Log::Warn << "Leaf Membership: Classes in each leaf"
+               << " printed in '" << leaf_class_membership_file
+               << "'" << endl;
+     table_file << counts;
+   }
+   table_file.close();
+ } // PrintLeafMembership
+
+
+ // Computes the variable importance of every dimension with
+ // DTree::ComputeVariableImportance(), logs the largest value and then
+ // prints the whole vector through Log::Warn or, when a file name is
+ // given, writes it to that file.
+ //
+ // dtree    - tree to compute the importances for
+ // num_dims - length of the importance vector
+ // vi_file  - output file; empty means log output
+ template<typename eT>
+ void PrintVariableImportance(DTree<eT> *dtree,
+                              size_t num_dims,
+                              string vi_file = "")
+ {
+   // Automatic storage: the vector used to be allocated with new and was
+   // never released.
+   arma::Col<double> imps(num_dims);
+   imps.zeros();
+
+   dtree->ComputeVariableImportance(&imps);
+
+   // The running maximum starts at 0, so it never reports a negative value.
+   double max = 0.0;
+   for (size_t i = 0; i < imps.n_elem; i++)
+     if (imps[i] > max)
+       max = imps[i];
+   Log::Warn << "Max. variable importance: " << max << endl;
+
+   if (vi_file == "") {
+     Log::Warn << "Variable importance: " << endl
+               << imps.t();
+   } else {
+     ofstream outfile(vi_file.c_str());
+     if (outfile.good()) {
+       Log::Warn << "Variable importance printed in '"
+                 << vi_file << "'" << endl;
+       outfile << imps;
+     } else {
+       Log::Warn << "Can't open '" << vi_file
+                 << "'" << endl;
+     }
+     outfile.close();
+   }
+ } // PrintVariableImportance
+
+
+ // This function trains the optimal decision tree
+ // using the given number of folds
+ /**
+  * Trains a density estimation tree on the data set and prunes it with the
+  * regularization value (alpha) selected by cross-validation.
+  *
+  * The procedure is:
+  *  1. Grow a tree on all the data and prune it step by step, recording for
+  *     every tree in that sequence its alpha and the negated sum of the
+  *     errors of its leaves.
+  *  2. For every fold, grow a tree on the training part, prune it along the
+  *     geometric means of consecutive recorded alphas, and subtract
+  *     2 * (sum of the densities of the held-out points) / N from the
+  *     matching entry of the sequence.
+  *  3. Pick the alpha with the smallest accumulated value, grow a tree on
+  *     all the data again and prune it until that alpha is reached.
+  *
+  * @param dataset Points, one per column.  The trees are grown on copies.
+  * @param folds Number of cross-validation folds.  0, or a value larger than
+  *     the number of points, selects leave-one-out cross-validation.  Fewer
+  *     than two folds (after that adjustment) is a fatal error.
+  * @param useVolumeReg Forwarded to DTree::Grow() and PruneAndUpdate().
+  * @param maxLeafSize Forwarded to DTree::Grow().
+  * @param minLeafSize Forwarded to DTree::Grow().
+  * @param unprunedTreeOutput If not empty, file receiving the density of
+  *     every training point under the unpruned tree, one value per line.
+  * @return The pruned tree, allocated with new; the caller deletes it.
+  */
+ template<typename eT>
+ DTree<eT> *Trainer(arma::Mat<eT>* dataset,
+                    size_t folds,
+                    bool useVolumeReg = false,
+                    size_t maxLeafSize = 10,
+                    size_t minLeafSize = 5,
+                    string unprunedTreeOutput = "")
+ {
+   typedef std::vector<std::pair<long double, long double> > PrunedSequence;
+
+   // Make the fold count usable before any work is done: 0 would divide by
+   // zero below, more folds than points would produce empty test sets, and
+   // a single fold would leave no training data.
+   if (folds > dataset->n_cols)
+     Log::Warn << "Requested " << folds << " folds for "
+               << dataset->n_cols << " points; using leave-one-out"
+               << " cross-validation instead." << endl;
+   if (folds == 0 || folds > dataset->n_cols)
+     folds = dataset->n_cols;
+   if (folds < 2)
+     Log::Fatal << "Cross-validation needs at least two folds and"
+                << " at least two points." << endl;
+
+   // Initializing the tree
+   DTree<eT> *dtree = new DTree<eT>(dataset);
+
+   // Getting ready to grow the tree
+   arma::Col<size_t> old_from_new(dataset->n_cols);
+   for (size_t i = 0; i < old_from_new.n_elem; i++) {
+     old_from_new[i] = i;
+   }
+
+   // Growing the tree.  The data set would be modified while growing the
+   // tree, so the tree is grown on a scratch copy that is released as soon
+   // as growing is done.
+   long double old_alpha = 0.0;
+   long double alpha;
+   {
+     arma::Mat<eT> new_dataset(*dataset);
+     alpha = dtree->Grow(&new_dataset, &old_from_new,
+                         useVolumeReg, maxLeafSize,
+                         minLeafSize);
+   }
+
+   Log::Info << dtree->subtree_leaves()
+             << " leaf nodes in the tree with full data, min_alpha: "
+             << alpha << endl;
+
+   // computing densities for the train points in the
+   // full tree if asked for.
+   if (unprunedTreeOutput != "") {
+     ofstream outfile(unprunedTreeOutput.c_str());
+     if (outfile.good()) {
+       for (size_t i = 0; i < dataset->n_cols; i++) {
+         arma::Col<eT> test_p = dataset->unsafe_col(i);
+         outfile << dtree->ComputeValue(&test_p) << endl;
+       } // end for
+     } else {
+       Log::Warn << "Can't open '" << unprunedTreeOutput
+                 << "'" << endl;
+     }
+
+     outfile.close();
+   } // if unprunedTreeOutput
+
+   // sequential pruning and saving the alpha vals and the
+   // values of c_t^2*r_t
+   PrunedSequence pruned_sequence;
+   while (dtree->subtree_leaves() > 1) {
+     std::pair<long double, long double> tree_seq
+       (old_alpha, -1.0 * dtree->subtree_leaves_error());
+     pruned_sequence.push_back(tree_seq);
+     old_alpha = alpha;
+     alpha = dtree->PruneAndUpdate(old_alpha, useVolumeReg);
+
+     // some checks
+     assert((alpha < std::numeric_limits<long double>::max())
+            ||(dtree->subtree_leaves() == 1));
+     assert(alpha > old_alpha);
+     assert(dtree->subtree_leaves_error() >= -1.0 * tree_seq.second);
+   } // end while
+
+   std::pair<long double, long double> tree_seq
+     (old_alpha, -1.0 * dtree->subtree_leaves_error());
+   pruned_sequence.push_back(tree_seq);
+
+   Log::Info << pruned_sequence.size()
+             << " trees in the sequence, max_alpha: "
+             << old_alpha << endl;
+
+   delete dtree;
+
+   // With a single tree in the sequence there is nothing to choose between
+   // (and end() - 2 below would not be a valid iterator), so the
+   // cross-validation loop is skipped entirely in that case.
+   const size_t cv_folds = (pruned_sequence.size() >= 2) ? folds : 0;
+   const size_t test_size = dataset->n_cols / folds;
+
+   // Go through each fold
+   for (size_t fold = 0; fold < cv_folds; fold++) {
+     // break up data into train and test set; [start, end) is held out
+     const size_t start = fold * test_size;
+     const size_t end = std::min((fold + 1) * test_size,
+                                 (size_t) dataset->n_cols);
+     arma::Mat<eT> test = dataset->cols(start, end - 1);
+     arma::Mat<eT> train(dataset->n_rows,
+                         dataset->n_cols - test.n_cols);
+
+     if (start == 0 && end < dataset->n_cols) {
+       // held-out block at the front
+       assert(train.n_cols == dataset->n_cols - end);
+       train.cols(0, train.n_cols - 1)
+         = dataset->cols(end, dataset->n_cols - 1);
+     } else if (start > 0 && end == dataset->n_cols) {
+       // held-out block at the back
+       assert(train.n_cols == start);
+       train.cols(0, train.n_cols - 1) = dataset->cols(0, start - 1);
+     } else {
+       // held-out block in the middle: stitch both sides together
+       assert(train.n_cols == start + dataset->n_cols - end);
+
+       train.cols(0, start - 1) = dataset->cols(0, start - 1);
+       train.cols(start, train.n_cols - 1)
+         = dataset->cols(end, dataset->n_cols - 1);
+     }
+
+     assert(train.n_cols + test.n_cols == dataset->n_cols);
+
+     // Initializing the tree
+     DTree<eT> *dtree_cv = new DTree<eT>(&train);
+
+     // Getting ready to grow the tree
+     arma::Col<size_t> old_from_new_cv(train.n_cols);
+     for (size_t i = 0; i < old_from_new_cv.n_elem; i++) {
+       old_from_new_cv[i] = i;
+     }
+
+     // Growing the tree
+     old_alpha = 0.0;
+     alpha = dtree_cv->Grow(&train, &old_from_new_cv,
+                            useVolumeReg, maxLeafSize,
+                            minLeafSize);
+
+     // sequential pruning with all the values of available
+     // alphas and adding values for test values
+     typename PrunedSequence::iterator it;
+     for (it = pruned_sequence.begin();
+          it < pruned_sequence.end() - 2; ++it) {
+       // compute test values for this state of the tree
+       long double val_cv = 0.0;
+       for (size_t i = 0; i < test.n_cols; i++) {
+         arma::Col<eT> test_point = test.unsafe_col(i);
+         val_cv += dtree_cv->ComputeValue(&test_point);
+       }
+
+       // update the cv error value
+       it->second -= 2.0 * val_cv / (long double) dataset->n_cols;
+
+       // getting the new alpha value and pruning accordingly
+       old_alpha = std::sqrt(((it + 1)->first) * ((it + 2)->first));
+       alpha = dtree_cv->PruneAndUpdate(old_alpha, useVolumeReg);
+     } // end for
+
+     // 'it' now refers to the second-to-last entry of the sequence;
+     // compute test values for this state of the tree
+     long double val_cv = 0.0;
+     for (size_t i = 0; i < test.n_cols; i++) {
+       arma::Col<eT> test_point = test.unsafe_col(i);
+       val_cv += dtree_cv->ComputeValue(&test_point);
+     }
+     // update the cv error value
+     it->second -= 2.0 * val_cv / (long double) dataset->n_cols;
+
+     delete dtree_cv;
+   } // end for loop for number of cv-folds
+
+   // Select the alpha with the smallest cross-validated value; the last
+   // entry of the sequence is not a candidate.
+   long double optimal_alpha = -1.0,
+     best_cv_error = numeric_limits<long double>::max();
+   typename PrunedSequence::iterator it;
+
+   for (it = pruned_sequence.begin();
+        it < pruned_sequence.end() - 1; ++it) {
+     if (it->second < best_cv_error) {
+       best_cv_error = it->second;
+       optimal_alpha = it->first;
+     } // end if
+   } // end for
+
+   Log::Info << "Optimal alpha: " << optimal_alpha << endl;
+
+   // Initializing the tree
+   DTree<eT> *dtree_opt = new DTree<eT>(dataset);
+   // Getting ready to grow the tree
+   for (size_t i = 0; i < old_from_new.n_elem; i++) {
+     old_from_new[i] = i;
+   }
+
+   // Again grow on a scratch copy, since the data set would be modified
+   // while growing the tree
+   arma::Mat<eT> new_dataset(*dataset);
+
+   // Growing the tree
+   old_alpha = 0.0;
+   alpha = dtree_opt->Grow(&new_dataset, &old_from_new,
+                           useVolumeReg, maxLeafSize,
+                           minLeafSize);
+
+   // Pruning with optimal alpha
+   while (old_alpha < optimal_alpha
+          && dtree_opt->subtree_leaves() > 1) {
+     old_alpha = alpha;
+     alpha = dtree_opt->PruneAndUpdate(old_alpha, useVolumeReg);
+
+     // some checks
+     assert((alpha < numeric_limits<long double>::max())
+            ||(dtree_opt->subtree_leaves() == 1));
+     assert(alpha > old_alpha);
+   } // end while
+
+   Log::Info << dtree_opt->subtree_leaves()
+             << " leaf nodes in the optimally pruned tree,"
+             << " optimal alpha: "
+             << old_alpha << endl;
+
+   return dtree_opt;
+ } // Trainer
+
+}; // namespace dt_utils
+
+#endif
Added: mlpack/trunk/src/mlpack/methods/det/dtree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/det/dtree.hpp 2012-04-04 15:54:33 UTC (rev 12189)
@@ -0,0 +1,412 @@
+/**
+ * @file dtree.hpp
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ *
+ * Density Tree class
+ *
+ */
+
+#ifndef DTREE_HPP
+#define DTREE_HPP
+
+#include <assert.h>
+#include <vector>
+
+#include <mlpack/core.hpp>
+
+using namespace mlpack;
+using namespace std;
+
+
+// The two types in the template are used
+// for two purposes:
+// eT - the type to store the data in (for most practical
+// purposes, storing the data as a float suffices).
+// cT - the type to perform computations in (computations
+// like computing the error, the volume of the node etc.).
+// For high dimensional data, it is possible that the
+// computation might overflow, so you should either
+// normalize your data to the (-1, 1) hypercube, use
+// long double, or modify this code to perform computations
+// using logarithms.
+/**
+ * Density tree node.  A node stores the range of points it covers, the
+ * split it applies, its error terms, the per-dimension bounds of its points
+ * and its two children.
+ *
+ * NOTE(review): the class holds raw pointers and declares a destructor but
+ * no copy constructor or assignment operator; whether copying an instance is
+ * safe depends on the destructor in dtree_impl.hpp -- confirm before copying.
+ *
+ * @tparam eT Type the data is stored in.
+ * @tparam cT Type the computations (error, volume, ...) are performed in.
+ */
+template<typename eT = float,
+         typename cT = long double>
+class DTree{
+
+  ////////////////////// Member Variables /////////////////////////////////////
+
+ private:
+
+  typedef arma::Mat<eT> MatType;
+  typedef arma::Col<eT> VecType;
+  typedef arma::Row<eT> RowVecType;
+
+  // The indices in the complete set of points
+  // (after all forms of swapping in the original data
+  // matrix to align all the points in a node
+  // consecutively in the matrix).  The 'old_from_new' array
+  // maps the points back to their original indices.
+  size_t start_, end_;
+
+  // The split dim for this node
+  size_t split_dim_;
+
+  // The split val on that dim
+  eT split_value_;
+
+  // L2-error of the node
+  cT error_;
+
+  // sum of the error of the leaves of the subtree
+  cT subtree_leaves_error_;
+
+  // number of leaves of the subtree
+  size_t subtree_leaves_;
+
+  // flag to indicate if this is the root node
+  // used to check whether the query point is
+  // within the range
+  bool root_;
+
+  // ratio of number of points in the node to the
+  // total number of points (|t| / N)
+  cT ratio_;
+
+  // the inverse of volume of the node
+  cT v_t_inv_;
+
+  // sum of the reciprocal of the inverse v_ts
+  // the leaves of this subtree
+  cT subtree_leaves_v_t_inv_;
+
+  // since we are using uniform density, we need
+  // the max and min of every dimension for every node
+  VecType* max_vals_;
+  VecType* min_vals_;
+
+  // the tag for the leaf used for hashing points
+  int bucket_tag_;
+
+  // The children
+  DTree<eT, cT> *left_;
+  DTree<eT, cT> *right_;
+
+ public:
+
+  ////////////////////// Getters //////////////////////////////////////////////
+  // All getters are read-only and therefore const-qualified, so they can
+  // also be called on const trees.
+  size_t start() const { return start_; }
+
+  size_t end() const { return end_; }
+
+  size_t split_dim() const { return split_dim_; }
+
+  eT split_value() const { return split_value_; }
+
+  cT error() const { return error_; }
+
+  cT subtree_leaves_error() const { return subtree_leaves_error_; }
+
+  size_t subtree_leaves() const { return subtree_leaves_; }
+
+  cT ratio() const { return ratio_; }
+
+  cT v_t_inv() const { return v_t_inv_; }
+
+  cT subtree_leaves_v_t_inv() const { return subtree_leaves_v_t_inv_; }
+
+  DTree<eT, cT>* left() const { return left_; }
+  DTree<eT, cT>* right() const { return right_; }
+
+  bool root() const { return root_; }
+
+  ////////////////////// Private Functions ////////////////////////////////////
+ private:
+
+  // L2-error of this node given the total number of points.
+  cT ComputeNodeError_(size_t total_points);
+
+  // Looks for a split of this node; reports the dimension, the index and
+  // the errors of the two resulting sides through the output pointers.
+  bool FindSplit_(MatType* data,
+                  size_t* split_dim,
+                  size_t* split_ind,
+                  cT* left_error,
+                  cT* right_error,
+                  size_t maxLeafSize = 10,
+                  size_t minLeafSize = 5);
+
+  // Applies a split to the data, updating 'old_from_new' and reporting the
+  // split value and the values on either side of it.
+  void SplitData_(MatType* data,
+                  size_t split_dim,
+                  size_t split_ind,
+                  arma::Col<size_t>* old_from_new,
+                  eT* split_val,
+                  eT* lsplit_val,
+                  eT* rsplit_val);
+
+  // Fills max_vals / min_vals with the per-dimension extremes of the data.
+  void GetMaxMinVals_(MatType* data,
+                      VecType* max_vals,
+                      VecType* min_vals);
+
+  // Whether the query lies within the bounds of this node.
+  bool WithinRange_(VecType* query);
+
+  ///////////////////// Public Functions //////////////////////////////////////
+ public:
+
+  DTree();
+
+  // Root node initializer
+  // with the bounding box of the data
+  // it contains instead of just the data.
+  DTree(VecType* max_vals,
+        VecType* min_vals,
+        size_t total_points);
+
+  // Root node initializer
+  // with the data, no bounding box.
+  DTree(MatType* data);
+
+  // Non-root node initializers
+  DTree(VecType* max_vals,
+        VecType* min_vals,
+        size_t start,
+        size_t end,
+        cT error);
+
+  DTree(VecType* max_vals,
+        VecType* min_vals,
+        size_t total_points,
+        size_t start,
+        size_t end);
+
+  ~DTree();
+
+  // Greedily expand the tree
+  cT Grow(MatType* data,
+          arma::Col<size_t> *old_from_new,
+          bool useVolReg = false,
+          size_t maxLeafSize = 10,
+          size_t minLeafSize = 5);
+
+  // perform alpha pruning on the tree
+  cT PruneAndUpdate(cT old_alpha,
+                    bool useVolReg = false);
+
+  // compute the density at a given point
+  cT ComputeValue(VecType* query);
+
+  // print the tree (in a DFS manner)
+  void WriteTree(size_t level, FILE *fp);
+
+  // indexing the buckets for possible usage later
+  int TagTree(int tag);
+
+  // This is used to generate the class membership
+  // of a learned tree.
+  int FindBucket(VecType* query);
+
+  // This computes the variable importance list
+  // for the learned tree.
+  void ComputeVariableImportance(arma::Col<double> *imps);
+
+  // A public function to test the private functions.
+  // It temporarily replaces start_, end_, error_ and the bounds of this
+  // node with values derived from a fixed 3 x 5 data set, checks
+  // GetMaxMinVals_, ComputeNodeError_, WithinRange_, FindSplit_ and
+  // SplitData_ against hand-computed results, and restores the saved
+  // members before returning.  Every mismatch is reported through
+  // Log::Warn.  Returns true if all checks passed.
+  bool TestPrivateFunctions() {
+
+    bool return_flag = true;
+
+    // Create data
+    MatType test_data(3,5);
+
+    test_data << 4 << 5 << 7 << 3 << 5 << arma::endr
+              << 5 << 0 << 1 << 7 << 1 << arma::endr
+              << 5 << 6 << 7 << 1 << 8 << arma::endr;
+
+    // save current data
+    size_t true_start = start_, true_end = end_;
+    VecType* true_max_vals = max_vals_;
+    VecType* true_min_vals = min_vals_;
+    cT true_error = error_;
+
+    // Test GetMaxMinVals_
+    max_vals_ = new VecType();
+    min_vals_ = new VecType();
+
+    GetMaxMinVals_(&test_data, max_vals_, min_vals_);
+
+    if ((*max_vals_)[0] != 7 || (*min_vals_)[0] != 3) {
+      Log::Warn << "Test: GetMaxMinVals_ failed." << endl;
+      return_flag = false;
+    }
+
+    if ((*max_vals_)[1] != 7 || (*min_vals_)[1] != 0) {
+      Log::Warn << "Test: GetMaxMinVals_ failed." << endl;
+      return_flag = false;
+    }
+
+    if ((*max_vals_)[2] != 8 || (*min_vals_)[2] != 1) {
+      Log::Warn << "Test: GetMaxMinVals_ failed." << endl;
+      return_flag = false;
+    }
+
+    // Test ComputeNodeError_
+    start_ = 0;
+    end_ = 5;
+    cT node_error = ComputeNodeError_(5);
+    // Widths of the bounding box are 4, 7 and 7.  The logs are taken in cT,
+    // like the expected values of the FindSplit_ test below.
+    cT log_vol = std::log((cT) 4) + std::log((cT) 7) + std::log((cT) 7);
+    cT true_node_error = -1.0 * std::exp(-log_vol);
+
+    if (std::abs(node_error - true_node_error) > 1e-7) {
+      Log::Warn << "Test: True error : " << true_node_error
+                << ", Computed error: " << node_error
+                << ", diff: " << std::abs(node_error - true_node_error)
+                << endl;
+      return_flag = false;
+    }
+
+    start_ = 3;
+    end_ = 5;
+    node_error = ComputeNodeError_(5);
+    true_node_error = -1.0 * std::exp(2 * std::log((cT) 2 / (cT) 5) - log_vol);
+
+    if (std::abs(node_error - true_node_error) > 1e-7) {
+      Log::Warn << "Test: True error : " << true_node_error
+                << ", Computed error: " << node_error
+                << ", diff: " << std::abs(node_error - true_node_error)
+                << endl;
+      return_flag = false;
+    }
+
+    // Test WithinRange_
+    VecType test_query(3);
+    test_query << 4.5 << 2.5 << 2;
+
+    if (!WithinRange_(&test_query)) {
+      Log::Warn << "Test: WithinRange_ failed" << endl;
+      return_flag = false;
+    }
+
+    test_query << 8.5 << 2.5 << 2;
+
+    if (WithinRange_(&test_query)) {
+      Log::Warn << "Test: WithinRange_ failed" << endl;
+      return_flag = false;
+    }
+
+    // Test FindSplit_
+    start_ = 0;
+    end_ = 5;
+    error_ = ComputeNodeError_(5);
+
+    size_t ob_dim, true_dim, ob_ind, true_ind;
+    cT true_left_error, ob_left_error, true_right_error, ob_right_error;
+
+    true_dim = 2;
+    true_ind = 1;
+    true_left_error = -1.0 * std::exp(2 * std::log((cT) 2 / (cT) 5)
+                                      - (std::log((cT) 7) + std::log((cT) 4)
+                                         + std::log((cT) 4.5)));
+    true_right_error = -1.0 * std::exp(2 * std::log((cT) 3 / (cT) 5)
+                                       - (std::log((cT) 7) + std::log((cT) 4)
+                                          + std::log((cT) 2.5)));
+
+    if (!FindSplit_(&test_data, &ob_dim, &ob_ind,
+                    &ob_left_error, &ob_right_error, 2, 1)) {
+      Log::Warn << "Test: FindSplit_ returns false." << endl;
+      return_flag = false;
+    }
+
+    if (true_dim != ob_dim) {
+      Log::Warn << "Test: FindSplit_ - True dim: " << true_dim
+                << ", Obtained dim: " << ob_dim << endl;
+      return_flag = false;
+    }
+
+    if (true_ind != ob_ind) {
+      Log::Warn << "Test: FindSplit_ - True ind: " << true_ind
+                << ", Obtained ind: " << ob_ind << endl;
+      return_flag = false;
+    }
+
+    if (std::abs(true_left_error - ob_left_error) > 1e-7) {
+      Log::Warn << "Test: FindSplit_ - True left_error: " << true_left_error
+                << ", Obtained left_error: " << ob_left_error
+                << ", diff: " << std::abs(true_left_error - ob_left_error)
+                << endl;
+      return_flag = false;
+    }
+
+    if (std::abs(true_right_error - ob_right_error) > 1e-7) {
+      Log::Warn << "Test: FindSplit_ - True right_error: " << true_right_error
+                << ", Obtained right_error: " << ob_right_error
+                << ", diff: " << std::abs(true_right_error - ob_right_error)
+                << endl;
+      return_flag = false;
+    }
+
+    // Test SplitData_
+    MatType split_test_data(test_data);
+    arma::Col<size_t> o_test(5);
+    o_test << 1 << 2 << 3 << 4 << 5;
+
+    start_ = 0;
+    end_ = 5;
+    size_t split_dim = 2, split_ind = 1;
+    eT true_split_val, ob_split_val, true_lsplit_val, ob_lsplit_val,
+      true_rsplit_val, ob_rsplit_val;
+
+    true_lsplit_val = 5;
+    true_rsplit_val = 6;
+    true_split_val = (true_lsplit_val + true_rsplit_val) / 2;
+
+    SplitData_(&split_test_data, split_dim, split_ind,
+               &o_test, &ob_split_val,
+               &ob_lsplit_val, &ob_rsplit_val);
+
+    if (o_test[0] != 1 || o_test[1] != 4 || o_test[2] != 3
+        || o_test[3] != 2 || o_test[4] != 5) {
+      Log::Warn << "Test: SplitData_ - OFW should be 1,4,3,2,5"
+                << ", is " << o_test.t();
+      return_flag = false;
+    }
+
+    if (true_split_val != ob_split_val) {
+      Log::Warn << "Test: SplitData_ - True split val: " << true_split_val
+                << ", Ob split val: " << ob_split_val << endl;
+      return_flag = false;
+    }
+
+    if (true_lsplit_val != ob_lsplit_val) {
+      Log::Warn << "Test: SplitData_ - True lsplit val: " << true_lsplit_val
+                << ", Ob lsplit val: " << ob_lsplit_val << endl;
+      return_flag = false;
+    }
+
+    if (true_rsplit_val != ob_rsplit_val) {
+      Log::Warn << "Test: SplitData_ - True rsplit val: " << true_rsplit_val
+                << ", Ob rsplit val: " << ob_rsplit_val << endl;
+      return_flag = false;
+    }
+
+    // restore original values
+    delete max_vals_;
+    delete min_vals_;
+    max_vals_ = true_max_vals;
+    min_vals_ = true_min_vals;
+    start_ = true_start;
+    end_ = true_end;
+    error_ = true_error;
+
+    return return_flag;
+
+  } // TestPrivateFunctions
+
+}; // Class DTree
+
+#include "dtree_impl.hpp"
+
+#endif
Added: mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp (rev 0)
+++ mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp 2012-04-04 15:54:33 UTC (rev 12189)
@@ -0,0 +1,770 @@
+ /**
+ * @file dtree_impl.hpp
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ *
+ * Implementations of some declared functions in
+ * the Density Tree class.
+ *
+ */
+
+#ifndef DTREE_IMPL_HPP
+#define DTREE_IMPL_HPP
+
+#include "dtree.hpp"
+
+
+
+// This function computes the l2-error of a given node
+// from the formula - R(t) = -|t|^2 / (N^2 V_t)
+template<typename eT, typename cT>
+cT DTree<eT, cT>::
+ComputeNodeError_(size_t total_points)
+{
+ size_t node_size = end_ - start_;
+
+ cT log_vol_t = 0;
+ for (size_t i = 0; i < max_vals_->n_elem; i++)
+ if ((*max_vals_)[i] - (*min_vals_)[i] > 0.0)
+ // using log to prevent overflow
+ log_vol_t += (cT) std::log((*max_vals_)[i] - (*min_vals_)[i]);
+
+ // check for overflow -- if it doesn't work, try higher precision
+ // by default cT = long double, so if you can't work with that
+ // there is nothing else you can do - except computing error using
+ // log and dealing with everything in log form.
+ assert(std::exp(log_vol_t) > 0.0);
+
+ cT log_neg_error = 2 * std::log((cT) node_size / (cT) total_points)
+ - log_vol_t;
+
+ assert(std::exp(log_neg_error) > 0.0);
+
+ cT error = -1.0 * std::exp(log_neg_error);
+
+ return error;
+} // ComputeNodeError
+
+
+// Finds the best axis-aligned split of this node with respect to the
+// L2-error by trying every candidate split in every dimension.
+//
+// 'data' is the full data set; only the columns [start_, end_) belong
+// to this node.
+//
+// Parameters:
+//   data        - full data matrix (one point per column).
+//   split_dim   - [out] dimension of the best split.
+//   split_ind   - [out] index, within this node's values sorted along
+//                 split_dim, of the last point that goes to the left
+//                 child.
+//   left_error  - [out] error of the left child for the best split.
+//   right_error - [out] error of the right child for the best split.
+//   maxLeafSize - only used to assert that the node is large enough
+//                 to be split.
+//   minLeafSize - minimum number of points each child must keep.  A
+//                 value of 0 is treated as 1, since a split always
+//                 needs one point on either side.
+//
+// Returns true when a split is found whose summed child error is
+// strictly lower than this node's error_.  When false is returned the
+// output arguments are left untouched.
+template<typename eT, typename cT>
+bool DTree<eT, cT>::
+FindSplit_(MatType* data,
+           size_t *split_dim,
+           size_t *split_ind,
+           cT *left_error,
+           cT *right_error,
+           size_t maxLeafSize,
+           size_t minLeafSize)
+{
+  assert(data->n_rows == max_vals_->n_elem);
+  assert(data->n_rows == min_vals_->n_elem);
+
+  const size_t total_n = data->n_cols;
+  const size_t n_t = end_ - start_;
+
+  // Smallest child size actually enforced.  Clamping to 1 keeps the
+  // start index below from wrapping around when minLeafSize is 0.
+  const size_t min_child = (minLeafSize > 0) ? minLeafSize : 1;
+
+  cT min_error = error_;
+  bool some_split_found = false;
+
+  // loop through each dimension
+  for (size_t dim = 0; dim < max_vals_->n_elem; dim++) {
+    // have to deal with REAL, INTEGER, NOMINAL data
+    // differently so have to think of how to do that.
+    const eT dim_min = (*min_vals_)[dim];
+    const eT dim_max = (*max_vals_)[dim];
+
+    // no scope for splitting in a dimension with zero extent
+    if (!(dim_max - dim_min > 0.0))
+      continue;
+
+    // best split seen so far in this dimension
+    bool dim_split_found = false;
+    cT min_dim_error = min_error;
+    cT temp_lval = 0.0, temp_rval = 0.0;
+    size_t dim_split_ind = 0;
+
+    // log of the box volume over all the *other* dimensions
+    cT log_range_all_not_dim = 0;
+    for (size_t i = 0; i < max_vals_->n_elem; i++) {
+      if ((*max_vals_)[i] - (*min_vals_)[i] > 0.0 && i != dim) {
+        log_range_all_not_dim
+            += (cT) std::log((*max_vals_)[i] - (*min_vals_)[i]);
+      }
+    }
+
+    assert(std::exp(log_range_all_not_dim) > 0);
+
+    // values of this node's points along 'dim', in ascending order
+    RowVecType dim_val_vec = data->row(dim).subvec(start_, end_ - 1);
+    dim_val_vec = arma::sort(dim_val_vec);
+
+    assert(dim_val_vec.n_elem > maxLeafSize);
+
+    // Only splits leaving at least min_child points on each side are
+    // considered (this avoids spikes in the estimate).  The bound is
+    // written as 'i + min_child < n' rather than 'i < n - min_child'
+    // so that it cannot underflow when the node is smaller than
+    // min_child.
+    for (size_t i = min_child - 1;
+         i + min_child < dim_val_vec.n_elem;
+         i++) {
+
+      const eT lsplit = dim_val_vec[i];
+      const eT rsplit = dim_val_vec[i + 1];
+
+      // cannot split between two equal values
+      if (!(lsplit < rsplit))
+        continue;
+
+      // The midpoint makes sense for real continuous data; it
+      // kinda corrupts the data and estimation if the data is
+      // ordinal.  (Another way of picking the split would be to
+      // use lsplit directly.)
+      const eT split = (lsplit + rsplit) / 2;
+
+      if (!(split - dim_min > 0.0 && dim_max - split > 0.0))
+        continue;
+
+      assert(std::exp(log_range_all_not_dim
+                      + (cT) std::log(split - dim_min)) > 0);
+      assert(std::exp(log_range_all_not_dim
+                      + (cT) std::log(dim_max - split)) > 0);
+
+      // left child: (i + 1) points in the box [dim_min, split]
+      const cT temp_log_neg_l_error
+          = 2 * std::log((cT) (i + 1) / (cT) total_n)
+          - (log_range_all_not_dim + (cT) std::log(split - dim_min));
+
+      assert(std::exp(temp_log_neg_l_error) > 0.0);
+
+      const cT temp_l_error = -1.0 * std::exp(temp_log_neg_l_error);
+
+      assert(std::abs(temp_l_error)
+             < std::numeric_limits<cT>::max());
+
+      // right child: (n_t - i - 1) points in the box [split, dim_max]
+      const cT temp_log_neg_r_error
+          = 2 * std::log((cT) (n_t - i - 1) / (cT) total_n)
+          - (log_range_all_not_dim + (cT) std::log(dim_max - split));
+
+      assert(std::exp(temp_log_neg_r_error) > 0.0);
+      assert(n_t - i - 1 >= minLeafSize);
+
+      const cT temp_r_error = -1.0 * std::exp(temp_log_neg_r_error);
+
+      assert(std::abs(temp_r_error)
+             < std::numeric_limits<cT>::max());
+
+      // strict improvement only
+      if (temp_l_error + temp_r_error < min_dim_error) {
+        min_dim_error = temp_l_error + temp_r_error;
+        temp_lval = temp_l_error;
+        temp_rval = temp_r_error;
+        dim_split_ind = i;
+        dim_split_found = true;
+      } // end if improvement
+    } // end for loop over all splits in this dim
+
+    if ((min_dim_error < min_error) && dim_split_found) {
+      min_error = min_dim_error;
+      *split_dim = dim;
+      *split_ind = dim_split_ind;
+      *left_error = temp_lval;
+      *right_error = temp_rval;
+      some_split_found = true;
+    } // end if better split found in this dim
+  } // end for each dimension
+
+  return some_split_found;
+} // end FindSplit_
+
+
+// Partitions the columns [start_, end_) of 'data' in place so that
+// every point whose value along split_dim is <= the split value comes
+// before every point whose value is greater.
+//
+// Parameters:
+//   data         - full data matrix (one point per column); columns
+//                  inside this node's range are swapped in place.
+//   split_dim    - dimension to split on.
+//   split_ind    - index, within this node's values sorted along
+//                  split_dim, of the largest value kept on the left.
+//   old_from_new - index vector permuted in step with the column
+//                  swaps applied to 'data'.
+//   split_val    - [out] midpoint between lsplit_val and rsplit_val.
+//   lsplit_val   - [out] sorted value at split_ind.
+//   rsplit_val   - [out] sorted value at split_ind + 1.
+template<typename eT, typename cT>
+void DTree<eT, cT>::
+SplitData_(MatType* data,
+           size_t split_dim,
+           size_t split_ind,
+           arma::Col<size_t> *old_from_new,
+           eT *split_val,
+           eT *lsplit_val,
+           eT *rsplit_val)
+{
+  // get the values for the split dim, in ascending order
+  RowVecType dim_val_vec
+      = data->row(split_dim).subvec(start_, end_ - 1);
+  dim_val_vec = arma::sort(dim_val_vec);
+
+  *lsplit_val = dim_val_vec[split_ind];
+  *rsplit_val = dim_val_vec[split_ind + 1];
+  *split_val = (*lsplit_val + *rsplit_val) / 2;
+
+  // goes_left[k] tells whether the point in column start_ + k belongs
+  // to the left child.  The vector has to be *sized* here: reserve()
+  // only allocates capacity and indexing into it would be undefined.
+  std::vector<bool> goes_left(end_ - start_);
+  for (size_t i = start_; i < end_; i++)
+    goes_left[i - start_] = !((*data)(split_dim, i) > *split_val);
+
+  // Two-pointer partition over the half-open range [lo, hi).  Using a
+  // half-open upper bound keeps 'hi' from wrapping below zero, and
+  // the range check always precedes the membership lookup.
+  size_t lo = start_, hi = end_;
+  for (;;) {
+    // skip the points that are already on the correct side
+    while ((lo < hi) && goes_left[lo - start_])
+      lo++;
+
+    while ((lo < hi) && !goes_left[hi - 1 - start_])
+      hi--;
+
+    if (lo >= hi)
+      break;
+
+    // 'lo' is a right-point and 'last' a left-point: exchange them
+    const size_t last = hi - 1;
+    data->swap_cols(lo, last);
+    goes_left[lo - start_] = true;
+    goes_left[last - start_] = false;
+
+    const size_t t = (*old_from_new)[lo];
+    (*old_from_new)[lo] = (*old_from_new)[last];
+    (*old_from_new)[last] = t;
+  } // swap for loop
+
+  assert(lo == hi);
+} // end SplitData_
+
+
+// Computes the per-dimension bounding box of 'data'.
+//
+// Parameters:
+//   data     - data matrix with one dimension per row.
+//   max_vals - [out] resized to data->n_rows; entry d receives the
+//              largest value found in row d.
+//   min_vals - [out] resized to data->n_rows; entry d receives the
+//              smallest value found in row d.
+//
+// Each row is scanned once for its extremes; no transposed copy of
+// the data is made and nothing is sorted.
+template<typename eT, typename cT>
+void DTree<eT, cT>::
+GetMaxMinVals_(MatType* data,
+               VecType *max_vals,
+               VecType *min_vals) {
+
+  const size_t num_dims = data->n_rows;
+
+  max_vals->set_size(num_dims);
+  min_vals->set_size(num_dims);
+
+  for (size_t d = 0; d < num_dims; d++) {
+    (*min_vals)[d] = arma::min(data->row(d));
+    (*max_vals)[d] = arma::max(data->row(d));
+  }
+} // end GetMaxMinVals_
+
+
+// Default constructor: an empty node with no points, no bounding box
+// and no children.
+template<typename eT, typename cT>
+DTree<eT, cT>::
+DTree() :
+    start_(0),
+    end_(0),
+    max_vals_(NULL),
+    min_vals_(NULL),
+    left_(NULL),
+    right_(NULL)
+{
+  // Every other constructor sets these three members; give them
+  // definite values here as well so they are never read while
+  // indeterminate.
+  bucket_tag_ = -1;
+  // NOTE(review): treated as a root on the assumption that a
+  // default-constructed node has no parent -- confirm with callers.
+  root_ = true;
+  // NOTE(review): there is no bounding box to compute an error from,
+  // so 0 is used as a neutral placeholder -- confirm.
+  error_ = 0;
+}
+
+
+// Root node constructor from a precomputed bounding box.
+//
+// The node covers the points [0, total_points); the destructor
+// deletes both 'max_vals' and 'min_vals'.
+template<typename eT, typename cT>
+DTree<eT, cT>::
+DTree(VecType* max_vals,
+      VecType* min_vals,
+      size_t total_points) :
+    start_(0),
+    end_(total_points),
+    max_vals_(max_vals),
+    min_vals_(min_vals),
+    left_(NULL),
+    right_(NULL)
+{
+  root_ = true;
+  bucket_tag_ = -1;
+
+  // Needs start_, end_ and the bounding box, all set above.
+  error_ = ComputeNodeError_(total_points);
+
+  // A failure here means the error does not fit in 'cT': a type with
+  // a higher precision (or a higher range) is required.
+  assert(std::abs(error_) < std::numeric_limits<cT>::max());
+}
+
+
+// Root node constructor from the data itself: the bounding box is
+// computed from 'data' and the node covers all of its columns.
+template<typename eT, typename cT>
+DTree<eT, cT>::
+DTree(MatType* data) :
+    start_(0),
+    end_(data->n_cols),
+    left_(NULL),
+    right_(NULL)
+{
+  root_ = true;
+  bucket_tag_ = -1;
+
+  // The bounding-box vectors are allocated here and released by the
+  // destructor.
+  max_vals_ = new VecType();
+  min_vals_ = new VecType();
+  GetMaxMinVals_(data, max_vals_, min_vals_);
+
+  // Must come after the bounding box has been filled in.
+  error_ = ComputeNodeError_(data->n_cols);
+}
+
+
+// Non-root node initializers
+template<typename eT, typename cT>
+DTree<eT, cT>::
+DTree(VecType* max_vals,
+ VecType* min_vals,
+ size_t start,
+ size_t end,
+ cT error) :
+ start_(start),
+ end_(end),
+ error_(error),
+ max_vals_(max_vals),
+ min_vals_(min_vals),
+ left_(NULL),
+ right_(NULL)
+{
+ bucket_tag_ = -1;
+ root_ = false;
+}
+
+
+template<typename eT, typename cT>
+DTree<eT, cT>::
+DTree(VecType* max_vals,
+ VecType* min_vals,
+ size_t total_points,
+ size_t start,
+ size_t end) :
+ start_(start),
+ end_(end),
+ max_vals_(max_vals),
+ min_vals_(min_vals),
+ left_(NULL),
+ right_(NULL)
+{
+
+ error_ = ComputeNodeError_(total_points);
+
+ bucket_tag_ = -1;
+ root_ = false;
+}
+
+
+// Releases the two children (and, through their destructors, the
+// subtrees below them) followed by this node's bounding-box vectors.
+// 'delete' on a null pointer does nothing, so no checks are needed.
+template<typename eT, typename cT>
+DTree<eT, cT>::
+~DTree()
+{
+  delete left_;
+  delete right_;
+
+  delete min_vals_;
+  delete max_vals_;
+}
+
+
+// Greedily expands the tree below this node.
+//
+// Parameters:
+//   data         - full data matrix (one point per column); columns
+//                  are reordered in place as nodes are split.
+//   old_from_new - index vector permuted in step with 'data'; its
+//                  length is used as the total number of points.
+//   useVolReg    - when true g(t) is normalised by the change in the
+//                  summed inverse leaf volumes, otherwise by the
+//                  number of leaves below the node minus one.
+//   maxLeafSize  - nodes with more points than this are candidates
+//                  for splitting.
+//   minLeafSize  - minimum number of points each child must keep.
+//
+// Returns numeric_limits<cT>::max() when this node ends up as a leaf,
+// otherwise the minimum of g(t) for this node and of the values
+// returned by its two children.
+template<typename eT, typename cT>
+cT DTree<eT, cT>::
+Grow(MatType* data,
+     arma::Col<size_t> *old_from_new,
+     bool useVolReg,
+     size_t maxLeafSize,
+     size_t minLeafSize)
+{
+  assert(data->n_rows == max_vals_->n_elem);
+  assert(data->n_rows == min_vals_->n_elem);
+
+  // Only read when the node was split, in which case both have been
+  // overwritten by the children; initialised so that they are never
+  // indeterminate.
+  cT left_g = std::numeric_limits<cT>::max();
+  cT right_g = std::numeric_limits<cT>::max();
+
+  // computing points ratio
+  ratio_ = (cT) (end_ - start_)
+      / (cT) old_from_new->n_elem;
+
+  // computing the v_t_inv:
+  // the inverse of the volume of the node
+  cT log_vol_t = 0;
+  for (size_t i = 0; i < max_vals_->n_elem; i++)
+    if ((*max_vals_)[i] - (*min_vals_)[i] > 0.0)
+      // using log to prevent overflow
+      log_vol_t += (cT) std::log((*max_vals_)[i] - (*min_vals_)[i]);
+
+  // check for overflow
+  assert(std::exp(log_vol_t) > 0.0);
+  v_t_inv_ = 1.0 / std::exp(log_vol_t);
+
+  // Checking if node is large enough
+  if ((size_t) (end_ - start_) > maxLeafSize) {
+
+    // find the split
+    size_t dim, split_ind;
+    cT left_error, right_error;
+    if (FindSplit_(data, &dim, &split_ind,
+                   &left_error, &right_error,
+                   maxLeafSize, minLeafSize)) {
+
+      // Move the points of the left child in front of those of the
+      // right child
+      eT split_val, lsplit_val, rsplit_val;
+      SplitData_(data, dim, split_ind,
+                 old_from_new, &split_val,
+                 &lsplit_val, &rsplit_val);
+
+      // make max and min vals for the children; each child takes
+      // ownership of its pair of vectors
+      VecType* max_vals_l = new VecType(*max_vals_);
+      VecType* max_vals_r = new VecType(*max_vals_);
+      VecType* min_vals_l = new VecType(*min_vals_);
+      VecType* min_vals_r = new VecType(*min_vals_);
+
+      (*max_vals_l)[dim] = split_val; // changed from just lsplit_val
+      (*min_vals_r)[dim] = split_val; // changed from just rsplit_val
+
+      // store split dim and split val in the node
+      split_value_ = split_val;
+      split_dim_ = dim;
+
+      // Recursively growing the children
+      left_ = new DTree(max_vals_l, min_vals_l,
+                        start_, start_ + split_ind + 1,
+                        left_error);
+      right_ = new DTree(max_vals_r, min_vals_r,
+                         start_ + split_ind + 1, end_,
+                         right_error);
+
+      left_g = left_->Grow(data, old_from_new, useVolReg,
+                           maxLeafSize, minLeafSize);
+      right_g = right_->Grow(data, old_from_new, useVolReg,
+                             maxLeafSize, minLeafSize);
+
+      // storing values of R(T~) and |T~|
+      subtree_leaves_ = left_->subtree_leaves() + right_->subtree_leaves();
+      subtree_leaves_error_ = left_->subtree_leaves_error()
+          + right_->subtree_leaves_error();
+
+      // storing the subtree_leaves_v_t_inv
+      subtree_leaves_v_t_inv_ = left_->subtree_leaves_v_t_inv()
+          + right_->subtree_leaves_v_t_inv();
+
+      // Forming T1 by removing leaves for which
+      // R(t) = R(t_L) + R(t_R)
+      if ((left_->subtree_leaves() == 1)
+          && (right_->subtree_leaves() == 1)) {
+        if (left_->error() + right_->error() == error_) {
+          delete left_;
+          left_ = NULL;
+          delete right_;
+          right_ = NULL;
+          subtree_leaves_ = 1;
+          subtree_leaves_error_ = error_;
+          subtree_leaves_v_t_inv_ = v_t_inv_;
+        } // end if
+      } // end if
+    } else {
+      // no split found so make a leaf out of it
+      subtree_leaves_ = 1;
+      subtree_leaves_error_ = error_;
+      subtree_leaves_v_t_inv_ = v_t_inv_;
+    } // end if-else
+  } else {
+    // We can make this a leaf node
+    assert((size_t) (end_ - start_) >= minLeafSize);
+    subtree_leaves_ = 1;
+    subtree_leaves_error_ = error_;
+    subtree_leaves_v_t_inv_ = v_t_inv_;
+  } // end if-else
+
+  // if leaf do not compute g_k(t), else compute, store,
+  // and propagate min(g_k(t_L),g_k(t_R),g_k(t)),
+  // unless t_L and/or t_R are leaves
+  if (subtree_leaves_ == 1) {
+    return std::numeric_limits<cT>::max();
+  } else {
+    cT g_t;
+    if (useVolReg) {
+      g_t = (error_ - subtree_leaves_error_)
+          / (subtree_leaves_v_t_inv_ - v_t_inv_);
+    } else {
+      g_t = (error_ - subtree_leaves_error_)
+          / (subtree_leaves_ - 1);
+    }
+
+    assert(g_t > 0.0);
+    // Qualified so that this header does not depend on a
+    // using-directive being in effect wherever it is included.
+    return std::min(g_t, std::min(left_g, right_g));
+  } // end if-else
+
+  // need to compute (c_t^2)*r_t for all subtree leaves
+  // this is equal to n_t^2/r_t*n^2 = -error_ !!
+  // therefore the value we need is actually
+  // -1.0*subtree_leaves_error_
+} // Grow
+
+
+// Prunes every subtree whose g(t) does not exceed 'old_alpha' and
+// refreshes the cached subtree statistics of the nodes that remain.
+//
+// Parameters:
+//   old_alpha - pruning threshold; a node with g(t) <= old_alpha has
+//               its children deleted and becomes a leaf.
+//   useVolReg - when true g(t) is normalised by the change in the
+//               summed inverse leaf volumes, otherwise by the number
+//               of leaves below the node minus one.
+//
+// Returns numeric_limits<cT>::max() when this node is (or has just
+// become) a leaf; otherwise the smallest g(t) among this node and
+// those of its children that are not leaves.
+template<typename eT, typename cT>
+cT DTree<eT, cT>::
+PruneAndUpdate(cT old_alpha,
+               bool useVolReg)
+{
+  // compute g_t
+  if (subtree_leaves_ == 1) { // if leaf
+    return std::numeric_limits<cT>::max();
+  } else {
+
+    // compute g_t value for node t
+    cT g_t;
+    if (useVolReg) {
+      g_t = (error_ - subtree_leaves_error_)
+          / (subtree_leaves_v_t_inv_ - v_t_inv_);
+    } else {
+      g_t = (error_ - subtree_leaves_error_)
+          / (subtree_leaves_ - 1);
+    }
+
+    if (g_t > old_alpha) { // go down the tree and update accordingly
+      // traverse the children
+      cT left_g = left_->PruneAndUpdate(old_alpha, useVolReg);
+      cT right_g = right_->PruneAndUpdate(old_alpha, useVolReg);
+
+      // update values: the children may have lost leaves
+      subtree_leaves_ = left_->subtree_leaves()
+          + right_->subtree_leaves();
+      subtree_leaves_error_ = left_->subtree_leaves_error()
+          + right_->subtree_leaves_error();
+      subtree_leaves_v_t_inv_ = left_->subtree_leaves_v_t_inv()
+          + right_->subtree_leaves_v_t_inv();
+
+      // update g_t value from the refreshed statistics
+      if (useVolReg) {
+        g_t = (error_ - subtree_leaves_error_)
+            / (subtree_leaves_v_t_inv_ - v_t_inv_);
+      } else {
+        g_t = (error_ - subtree_leaves_error_)
+            / (subtree_leaves_ - 1);
+      }
+
+      assert(g_t < std::numeric_limits<cT>::max());
+
+      // std::min is qualified so that this header does not depend on
+      // a using-directive being in effect wherever it is included.
+      if (left_->subtree_leaves() == 1
+          && right_->subtree_leaves() == 1) {
+        return g_t;
+      } else if (left_->subtree_leaves() == 1) {
+        return std::min(g_t, right_g);
+      } else if (right_->subtree_leaves() == 1) {
+        return std::min(g_t, left_g);
+      } else {
+        return std::min(g_t, std::min(left_g, right_g));
+      }
+    } else { // prune this subtree
+
+      // making this node a leaf node
+      subtree_leaves_ = 1;
+      subtree_leaves_error_ = error_;
+      subtree_leaves_v_t_inv_ = v_t_inv_;
+      delete left_;
+      left_ = NULL;
+      delete right_;
+      right_ = NULL;
+      // passing information upward
+      return std::numeric_limits<cT>::max();
+    } // end if-else for pruning subtree
+  } // if-else for leaf or non-leaf
+} // PruneAndUpdate
+
+
+// Tells whether 'query' lies inside the bounding box of this node,
+// boundaries included.  The check is generally done at the root, in
+// which case the box is the bounding box of the data.
+//
+// Possible improvement: open up the range by an epsilon on both
+// sides, with epsilon depending on the density near the boundary.
+template<typename eT, typename cT>
+bool DTree<eT, cT>::
+WithinRange_(VecType* query)
+{
+  const size_t num_dims = query->n_elem;
+
+  for (size_t d = 0; d < num_dims; d++) {
+    const bool not_below = !((*query)[d] < (*min_vals_)[d]);
+    const bool not_above = !((*query)[d] > (*max_vals_)[d]);
+
+    // one coordinate outside its interval is enough to reject
+    if (!(not_below && not_above))
+      return false;
+  }
+
+  return true;
+}
+
+
+// Returns the density estimate at 'query'.
+//
+// At the root a query outside the bounding box gets 0.  Otherwise
+// the query is routed down the tree on the stored split dimension
+// and value until a leaf is reached, and that leaf's
+// ratio_ * v_t_inv_ is returned.
+template<typename eT, typename cT>
+cT DTree<eT, cT>::
+ComputeValue(VecType* query)
+{
+  assert(query->n_elem == max_vals_->n_elem);
+
+  // the range check is only made at the root
+  if (root_ == 1) {
+    if (!WithinRange_(query)) {
+      return 0.0;
+    }
+  }
+
+  // leaf: the estimate is stored in this node
+  if (subtree_leaves_ == 1) {
+    return ratio_ * v_t_inv_;
+  }
+
+  // internal node: values up to and including the split go left
+  if ((*query)[split_dim_] <= split_value_) {
+    return left_->ComputeValue(query);
+  }
+
+  return right_->ComputeValue(query);
+} // ComputeValue
+
+
+// Writes a textual rendering of the subtree rooted at this node.
+//
+// Parameters:
+//   level - depth of this node; one "|\t" is printed per level in
+//           front of every split line.
+//   fp    - output stream.
+//
+// An internal node prints its "greater than" branch followed by its
+// "less than or equal" branch; a leaf prints its density estimate
+// and, if it has been tagged, its bucket tag.
+//
+// The types of the printed values follow the template parameters, so
+// every argument is cast to exactly the type its printf conversion
+// expects; passing a mismatching type through varargs is undefined.
+template<typename eT, typename cT>
+void DTree<eT, cT>::
+WriteTree(size_t level, FILE *fp)
+{
+  if (subtree_leaves_ > 1){
+    fprintf(fp, "\n");
+    for (size_t i = 0; i < level; i++){
+      fprintf(fp, "|\t");
+    }
+    fprintf(fp, "Var. %zu > %lg",
+            (size_t) split_dim_, (double) split_value_);
+    right_->WriteTree(level+1, fp);
+    fprintf(fp, "\n");
+    for (size_t i = 0; i < level; i++){
+      fprintf(fp, "|\t");
+    }
+    fprintf(fp, "Var. %zu <= %lg ",
+            (size_t) split_dim_, (double) split_value_);
+    left_->WriteTree(level+1, fp);
+
+  } else { // if leaf
+    // the product is formed as before and only then widened for %Lg
+    fprintf(fp, ": f(x)=%Lg",
+            (long double) ((cT) ratio_ * v_t_inv_));
+    if (bucket_tag_ != -1)
+      fprintf(fp, " BT:%d", bucket_tag_);
+  }
+} // WriteTree
+
+
+// Numbers the leaves (buckets) of the subtree for possible later
+// use.  Leaves below the left child are tagged before those below
+// the right child.
+//
+// 'tag' is the first tag to hand out; the return value is the next
+// unused tag.
+template<typename eT, typename cT>
+int DTree<eT, cT>::
+TagTree(int tag)
+{
+  if (subtree_leaves_ != 1) {
+    // internal node: the right subtree continues where the left
+    // subtree stopped
+    const int next_tag = left_->TagTree(tag);
+    return right_->TagTree(next_tag);
+  }
+
+  // leaf: take the tag and pass on the following one
+  bucket_tag_ = tag;
+  return (tag + 1);
+} // TagTree
+
+
+// Returns the bucket tag of the leaf that 'query' is routed to.
+// Routing uses the stored split dimension and value: values up to
+// and including the split value go to the left child.
+template<typename eT, typename cT>
+int DTree<eT, cT>::
+FindBucket(VecType* query)
+{
+  assert(query->n_elem == max_vals_->n_elem);
+
+  // leaf: this is the bucket
+  if (subtree_leaves_ == 1)
+    return bucket_tag_;
+
+  // internal node: descend into the matching child
+  const bool go_left = ((*query)[split_dim_] <= split_value_);
+  if (go_left)
+    return left_->FindBucket(query);
+
+  return right_->FindBucket(query);
+} // FindBucket
+
+
+template<typename eT, typename cT>
+void DTree<eT, cT>::
+ComputeVariableImportance(arma::Col<double> *imps)
+{
+ if (subtree_leaves_ == 1) {
+ // if leaf, do nothing
+ return;
+ } else {
+ // compute the improvement in error because of the
+ // split
+ double error_improv
+ = (double) (error_ - (left_->error() + right_->error()));
+ (*imps)[split_dim_] += error_improv;
+ left_->ComputeVariableImportance(imps);
+ right_->ComputeVariableImportance(imps);
+ return;
+ }
+} // ComputeVariableImportance
+
+
+#endif
More information about the mlpack-svn
mailing list