[mlpack-svn] r12189 - in mlpack/trunk/src/mlpack/methods: . det

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Apr 4 11:54:33 EDT 2012


Author: pram
Date: 2012-04-04 11:54:33 -0400 (Wed, 04 Apr 2012)
New Revision: 12189

Added:
   mlpack/trunk/src/mlpack/methods/det/
   mlpack/trunk/src/mlpack/methods/det/dt_main.cpp
   mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp
   mlpack/trunk/src/mlpack/methods/det/dtree.hpp
   mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
Log:
Density Estimation Trees (DET) added to mlpack

Added: mlpack/trunk/src/mlpack/methods/det/dt_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dt_main.cpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/det/dt_main.cpp	2012-04-04 15:54:33 UTC (rev 12189)
@@ -0,0 +1,239 @@
+/**
+ * @file dt_main.cpp
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ *
+ * This file provides an example use of the Density Estimation Tree (DET).
+ */
+
+#include <mlpack/core.hpp>
+#include "dt_utils.hpp"
+
+using namespace mlpack;
+using namespace std;
+
+PROGRAM_INFO("Density estimation with DET", "This program "
+             "provides an example use of the Density Estimation "
+             "Tree for density estimation. For more details, "
+             "please look at the paper titled "
+             "'Density Estimation Trees'.");
+
+// input data files
+PARAM_STRING_REQ("input/training_set", "The data set on which to "
+                 "perform density estimation.", "S");
+PARAM_STRING("input/test_set", "An extra set of test points on "
+             "which to estimate the density given the estimator.",
+             "T", "");
+PARAM_STRING("input/labels", "The labels for the given training data to "
+             "generate the class membership of each leaf (as an "
+             "extra statistic).", "L", "");
+
+// output data files
+PARAM_STRING("output/unpruned_tree_estimates", "The file "
+             "in which to output the estimates on the "
+             "training set from the large unpruned tree.", "u", "");
+PARAM_STRING("output/training_set_estimates", "The file "
+             "in which to output the estimates on the "
+             "training set from the final optimally pruned"
+             " tree.", "s", "");
+PARAM_STRING("output/test_set_estimates", "The file "
+             "in which to output the estimates on the "
+             "test set from the final optimally pruned"
+             " tree.", "t", "");
+PARAM_STRING("output/leaf_class_table", "The file "
+             "in which to output the leaf class membership "
+             "table.", "l", "leaf_class_membership.txt");
+PARAM_STRING("output/tree", "The file in which to print "
+             "the final optimally pruned tree (used together with "
+             "the print_tree flag).", "p", "");
+PARAM_STRING("output/vi", "The file to output the "
+             "variable importance values for each feature "
+             "(used together with the print_vi flag).", "i", "");
+
+// parameters for the algorithm
+PARAM_INT("param/number_of_classes", "The number of classes present "
+          "in the 'labels' set provided.", "C", 0);
+PARAM_INT("param/folds", "The number of folds of cross-validation"
+          " to perform for the estimation (enter 0 for LOOCV).",
+          "F", 10);
+PARAM_INT("DET/min_leaf_size", "The minimum size of a leaf"
+          " in the unpruned fully grown DET.", "N", 5);
+PARAM_INT("DET/max_leaf_size", "The maximum size of a leaf"
+          " in the unpruned fully grown DET.", "M", 10);
+PARAM_FLAG("DET/use_volume_reg", "This flag gives the user the "
+           "option to use a form of regularization similar to "
+           "the usual alpha-pruning in decision trees. But "
+           "instead of regularizing on the number of leaves, "
+           "you regularize on the sum of the inverse of the volume "
+           "of the leaves (meaning you penalize "
+           "low volume leaves).", "R");
+
+// some flags for output of some information about the tree
+PARAM_FLAG("flag/print_tree", "Print the final optimally pruned tree "
+           "(to the 'output/tree' file if given, otherwise to the "
+           "command line).", "P");
+PARAM_FLAG("flag/print_vi", "Print the variable importance of each "
+           "feature (to the 'output/vi' file if given, otherwise to "
+           "the command line).", "I");
+
+/**
+ * Trains a cross-validated, optimally pruned DET on the training set and,
+ * depending on the options given, writes density estimates for the
+ * training and test sets, the tree itself, the leaf class membership
+ * table and the variable importance of each feature.
+ */
+int main(int argc, char *argv[])
+{
+  CLI::ParseCommandLine(argc, argv);
+
+  const string train_set_file = CLI::GetParam<string>("S");
+  arma::Mat<float> training_data;
+
+  Log::Info << "Loading training set..." << endl;
+  if (!data::Load(train_set_file, training_data))
+    Log::Fatal << "Training set file "<< train_set_file
+               << " can't be loaded." << endl;
+
+  Log::Info << "Training set (" << training_data.n_rows
+            << ", " << training_data.n_cols
+            << ")" << endl;
+
+  // Validate the integer options before they are converted to size_t;
+  // a negative value would otherwise silently wrap to a huge number.
+  const int folds_param = CLI::GetParam<int>("F");
+  const int max_leaf_size = CLI::GetParam<int>("M");
+  const int min_leaf_size = CLI::GetParam<int>("N");
+  if (folds_param < 0)
+    Log::Fatal << "The number of folds must be non-negative." << endl;
+  if (max_leaf_size <= 0 || min_leaf_size <= 0)
+    Log::Fatal << "The leaf sizes must be positive." << endl;
+
+  // cross-validation here (0 folds means leave-one-out)
+  size_t folds = static_cast<size_t>(folds_param);
+  if (folds == 0) {
+    folds = training_data.n_cols;
+    Log::Info << "Starting Leave-One-Out Cross validation" << endl;
+  } else {
+    Log::Info << "Starting " << folds
+              << "-fold Cross validation" << endl;
+  }
+
+  // obtaining the optimal tree
+  const string unpruned_tree_estimate_file = CLI::GetParam<string>("u");
+
+  Timer::Start("DET/Training");
+  DTree<float> *dtree_opt = dt_utils::Trainer<float>
+    (&training_data, folds, CLI::HasParam("R"),
+     static_cast<size_t>(max_leaf_size),
+     static_cast<size_t>(min_leaf_size),
+     unpruned_tree_estimate_file);
+  Timer::Stop("DET/Training");
+
+  // computing densities for the train points in the optimal tree
+  FILE *fp = NULL;
+  const string optimal_estimates_file = CLI::GetParam<string>("s");
+  if (optimal_estimates_file != "") {
+    fp = fopen(optimal_estimates_file.c_str(), "w");
+    if (fp == NULL)
+      Log::Warn << "Can't open '" << optimal_estimates_file << "'" << endl;
+  }
+
+  // Computation timing is more accurate when you do not
+  // perform the printing.
+  Timer::Start("DET/EstimationTime");
+  for (size_t i = 0; i < training_data.n_cols; i++) {
+    arma::Col<float> test_p = training_data.unsafe_col(i);
+    long double f = dtree_opt->ComputeValue(&test_p);
+    if (fp != NULL)
+      fprintf(fp, "%Lg\n", f);
+  } // end for
+  Timer::Stop("DET/EstimationTime");
+
+  if (fp != NULL)
+    fclose(fp);
+
+  // computing the density at the provided test points
+  // and outputting the density in the given file.
+  const string test_file = CLI::GetParam<string>("T");
+  if (test_file != "") {
+    arma::Mat<float> test_data;
+    Log::Info << "Loading test set..." << endl;
+    if (!data::Load(test_file, test_data))
+      Log::Fatal << "Test set file "<< test_file
+                 << " can't be loaded." << endl;
+
+    Log::Info << "Test set (" << test_data.n_rows
+              << ", " << test_data.n_cols
+              << ")" << endl;
+
+    // The tree's bounding boxes have one entry per training dimension.
+    if (test_data.n_rows != training_data.n_rows)
+      Log::Fatal << "Test set dimensionality (" << test_data.n_rows
+                 << ") does not match the training set ("
+                 << training_data.n_rows << ")." << endl;
+
+    fp = NULL;
+    const string test_density_file = CLI::GetParam<string>("t");
+    if (test_density_file != "") {
+      fp = fopen(test_density_file.c_str(), "w");
+      if (fp == NULL)
+        Log::Warn << "Can't open '" << test_density_file << "'" << endl;
+    }
+
+    Timer::Start("DET/TestSetEstimation");
+    for (size_t i = 0; i < test_data.n_cols; i++) {
+      arma::Col<float> test_p = test_data.unsafe_col(i);
+      long double f = dtree_opt->ComputeValue(&test_p);
+      if (fp != NULL)
+        fprintf(fp, "%Lg\n", f);
+    } // end for
+    Timer::Stop("DET/TestSetEstimation");
+
+    if (fp != NULL)
+      fclose(fp);
+  } // Test set estimation
+
+  // printing the final tree
+  if (CLI::HasParam("P")) {
+    const string print_tree_file = CLI::GetParam<string>("p");
+    if (print_tree_file != "") {
+      fp = fopen(print_tree_file.c_str(), "w");
+      if (fp != NULL) {
+        dtree_opt->WriteTree(0, fp);
+        fclose(fp);
+      } else {
+        Log::Warn << "Can't open '" << print_tree_file << "'" << endl;
+      }
+    } else {
+      dtree_opt->WriteTree(0, stdout);
+      printf("\n");
+    }
+  } // Printing the tree
+
+  // print the leaf memberships for the optimal tree
+  const string labels_file = CLI::GetParam<string>("L");
+  if (labels_file != "") {
+    arma::Mat<int> labels;
+
+    Log::Info << "Loading label file..." << endl;
+    if (!data::Load(labels_file, labels))
+      Log::Fatal << "Label file "<< labels_file
+                 << " can't be loaded." << endl;
+
+    Log::Info << "Labels (" << labels.n_rows
+              << ", " << labels.n_cols
+              << ")" << endl;
+
+    const int num_classes = CLI::GetParam<int>("C");
+    if (num_classes <= 0)
+      Log::Fatal << "Please provide the number of classes"
+                 << " present in the label file" << endl;
+
+    // Checked at runtime rather than with assert() (compiled out under
+    // NDEBUG): a mismatch would make the label lookup read out of bounds.
+    if (labels.n_rows != 1 || labels.n_cols != training_data.n_cols)
+      Log::Fatal << "The label file must hold exactly one label per "
+                 << "training point." << endl;
+
+    dt_utils::PrintLeafMembership<float>
+      (dtree_opt, training_data, labels,
+       static_cast<size_t>(num_classes), CLI::GetParam<string>("l"));
+  } // leaf class membership
+
+  if (CLI::HasParam("I")) {
+    dt_utils::PrintVariableImportance<float>
+      (dtree_opt, training_data.n_rows, CLI::GetParam<string>("i"));
+  } // print variable importance
+
+  delete dtree_opt;
+  return 0;
+} // end main

Added: mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp	2012-04-04 15:54:33 UTC (rev 12189)
@@ -0,0 +1,333 @@
+/**
+ * @file dt_utils.hpp
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ *
+ * This file implements functions to perform
+ * different tasks with the Density Tree class.
+ */
+
+#ifndef DT_UTILS_HPP
+#define DT_UTILS_HPP
+
+#include <string>
+
+#include <mlpack/core.hpp>
+#include "dtree.hpp"
+
+using namespace mlpack;
+using namespace std;
+
+
+namespace dt_utils {
+
+  /**
+   * Prints, for every leaf of the tree, how many points of each class
+   * fall into that leaf.
+   *
+   * @param dtree Trained density tree; its leaves are (re)tagged here.
+   * @param data Points, one per column.
+   * @param labels Class label of every point; labels[i] (linear index) is
+   *     the label of column i of data.
+   * @param num_classes Number of classes; valid labels are
+   *     0 .. num_classes - 1.
+   * @param leaf_class_membership_file Output file; if empty, the table is
+   *     written to Log::Warn instead.
+   */
+  template<typename eT>
+  void PrintLeafMembership(DTree<eT> *dtree,
+                           const arma::Mat<eT>& data,
+                           const arma::Mat<int>& labels,
+                           size_t num_classes,
+                           string leaf_class_membership_file = "")
+  {
+    if (labels.n_elem < data.n_cols)
+      Log::Fatal << "Fewer labels (" << labels.n_elem << ") than points ("
+                 << data.n_cols << ")." << endl;
+
+    // tag the leaves with numbers
+    const int num_leaves = dtree->TagTree(0);
+
+    arma::Mat<size_t> table(num_leaves, num_classes);
+    table.zeros();
+
+    size_t skipped = 0;
+    for (size_t i = 0; i < data.n_cols; i++) {
+      arma::Col<eT> test_p = data.unsafe_col(i);
+      const int leaf_tag = dtree->FindBucket(&test_p);
+      const int label = labels[i];
+
+      // An out-of-range tag or label would index outside 'table'
+      // (Armadillo's bounds checks vanish when its debugging is disabled).
+      if (leaf_tag < 0 || leaf_tag >= num_leaves || label < 0
+          || static_cast<size_t>(label) >= num_classes) {
+        ++skipped;
+        continue;
+      }
+      table(static_cast<size_t>(leaf_tag), static_cast<size_t>(label)) += 1;
+    } // end for
+
+    if (skipped > 0)
+      Log::Warn << skipped << " point(s) had an out-of-range leaf tag or "
+                << "label and were not counted." << endl;
+
+    if (leaf_class_membership_file == "") {
+      Log::Warn << "Leaf Membership: Classes in each leaf" << endl
+                << table << endl;
+    } else {
+      // create a stream for the file
+      ofstream outfile(leaf_class_membership_file.c_str());
+      if (outfile.good()) {
+        Log::Warn << "Leaf Membership: Classes in each leaf"
+                  << " printed in '" << leaf_class_membership_file
+                  << "'" << endl;
+        outfile << table;
+      } else {
+        Log::Warn << "Can't open '" << leaf_class_membership_file
+                  << "'" << endl;
+      }
+      outfile.close();
+    }
+    // maybe print some more statistics if these work out well
+  } // PrintLeafMembership
+
+
+  /**
+   * Computes the variable importance of every dimension with
+   * DTree::ComputeVariableImportance() and prints it, together with the
+   * largest value (floored at 0).
+   *
+   * @param dtree Trained density tree.
+   * @param num_dims Number of dimensions of the data.
+   * @param vi_file Output file; if empty, the values are written to
+   *     Log::Warn instead.
+   */
+  template<typename eT>
+  void PrintVariableImportance(DTree<eT> *dtree,
+                               size_t num_dims,
+                               string vi_file = "")
+  {
+    // Importance accumulator, one entry per dimension, starting at zero.
+    // Kept on the stack so it is always released.
+    arma::Col<double> imps(num_dims);
+    imps.zeros();
+
+    dtree->ComputeVariableImportance(&imps);
+    double max = 0.0;
+    for (size_t i = 0; i < imps.n_elem; i++)
+      if (imps[i] > max)
+        max = imps[i];
+    Log::Warn << "Max. variable importance: " << max << endl;
+
+
+    if (vi_file == "") {
+      Log::Warn << "Variable importance: " << endl
+                << imps.t();
+    } else {
+      ofstream outfile(vi_file.c_str());
+      if (outfile.good()) {
+        Log::Warn << "Variable importance printed in '"
+                  << vi_file << "'" << endl;
+        outfile << imps;
+      } else {
+        Log::Warn << "Can't open '" << vi_file
+                  << "'" << endl;
+      }
+      outfile.close();
+    }
+  } // PrintVariableImportance
+
+
+  // Returns the sum of the tree's density estimates over the columns
+  // of 'points'.
+  template<typename eT>
+  long double SumEstimates(DTree<eT>& tree, const arma::Mat<eT>& points)
+  {
+    long double sum = 0.0;
+    for (size_t i = 0; i < points.n_cols; i++) {
+      arma::Col<eT> point = points.unsafe_col(i);
+      sum += tree.ComputeValue(&point);
+    }
+    return sum;
+  } // SumEstimates
+
+
+  /**
+   * Trains a density estimation tree whose amount of pruning is chosen
+   * by cross-validation.
+   *
+   * The steps are:
+   * 1. Grow a full tree on the data and prune it sequentially, recording
+   *    the sequence of (alpha, -subtree leaf error) pairs.
+   * 2. For every fold, grow a tree on the remaining points, prune it along
+   *    the same alpha sequence, and subtract the held-out estimate term
+   *    from each entry's CV error.
+   * 3. Regrow a tree on the full data and prune it up to the alpha with the
+   *    smallest CV error.
+   *
+   * Trees are grown on copies of the data because growing reorders the
+   * columns it is given.
+   *
+   * @param dataset Training points, one per column.
+   * @param folds Number of CV folds.  0, or more folds than points, means
+   *     leave-one-out.  At least two folds (and points) are required.
+   * @param useVolumeReg Regularize on the inverse leaf volumes instead of
+   *     the number of leaves.
+   * @param maxLeafSize Passed to DTree::Grow().
+   * @param minLeafSize Passed to DTree::Grow().
+   * @param unprunedTreeOutput If non-empty, the estimates of the unpruned
+   *     tree at the training points are written to this file.
+   * @return The optimally pruned tree, allocated with new.  The caller owns
+   *     it.
+   */
+  template<typename eT>
+  DTree<eT> *Trainer(arma::Mat<eT>* dataset,
+                     size_t folds,
+                     bool useVolumeReg = false,
+                     size_t maxLeafSize = 10,
+                     size_t minLeafSize = 5,
+                     string unprunedTreeOutput = "")
+  {
+    const size_t n_points = dataset->n_cols;
+
+    // Sanitize the number of folds.  A fold count of 0 would divide by zero
+    // below, and more folds than points would give empty test sets.  A
+    // single fold would leave no training points.
+    if (folds == 0) {
+      folds = n_points;
+    } else if (folds > n_points) {
+      Log::Warn << "Requested " << folds << " folds but only " << n_points
+                << " points are available; using leave-one-out." << endl;
+      folds = n_points;
+    }
+    if (folds < 2)
+      Log::Fatal << "Cross-validation needs at least two folds and two "
+                 << "points (got " << folds << " fold(s), " << n_points
+                 << " point(s))." << endl;
+
+    // Getting ready to grow the tree
+    arma::Col<size_t> old_from_new(n_points);
+    for (size_t i = 0; i < old_from_new.n_elem; i++)
+      old_from_new[i] = i;
+
+    // (alpha, CV error) pairs; the error starts at -subtree_leaves_error
+    // (the values of c_t^2*r_t) and is completed by the CV folds below.
+    std::vector<std::pair<long double, long double> > pruned_sequence;
+    long double old_alpha = 0.0;
+    long double alpha = 0.0;
+
+    {
+      // Initializing the tree on the full data
+      DTree<eT> dtree(dataset);
+
+      // Growing the tree on a copy, since growing modifies the data
+      arma::Mat<eT> new_dataset(*dataset);
+      alpha = dtree.Grow(&new_dataset, &old_from_new, useVolumeReg,
+                         maxLeafSize, minLeafSize);
+
+      Log::Info << dtree.subtree_leaves()
+                << " leaf nodes in the tree with full data, min_alpha: "
+                << alpha << endl;
+
+      // computing densities for the train points in the
+      // full tree if asked for.
+      if (unprunedTreeOutput != "") {
+        ofstream outfile(unprunedTreeOutput.c_str());
+        if (outfile.good()) {
+          for (size_t i = 0; i < n_points; i++) {
+            arma::Col<eT> test_p = dataset->unsafe_col(i);
+            outfile << dtree.ComputeValue(&test_p) << endl;
+          } // end for
+        } else {
+          Log::Warn << "Can't open '" << unprunedTreeOutput
+                    << "'" << endl;
+        }
+        outfile.close();
+      } // if unprunedTreeOutput
+
+      // sequential pruning, saving the alpha values and the
+      // values of c_t^2*r_t
+      while (dtree.subtree_leaves() > 1) {
+        const std::pair<long double, long double> tree_seq
+          (old_alpha, -1.0 * dtree.subtree_leaves_error());
+        pruned_sequence.push_back(tree_seq);
+        old_alpha = alpha;
+        alpha = dtree.PruneAndUpdate(old_alpha, useVolumeReg);
+
+        // some checks
+        assert((alpha < std::numeric_limits<long double>::max())
+               || (dtree.subtree_leaves() == 1));
+        assert(alpha > old_alpha);
+        assert(dtree.subtree_leaves_error() >= -1.0 * tree_seq.second);
+      } // end while
+
+      const std::pair<long double, long double> last_seq
+        (old_alpha, -1.0 * dtree.subtree_leaves_error());
+      pruned_sequence.push_back(last_seq);
+
+      Log::Info << pruned_sequence.size()
+                << " trees in the sequence, max_alpha: "
+                << old_alpha << endl;
+    } // the full tree is released here
+
+    // The CV pass looks two entries ahead in the sequence; with a single
+    // tree (the data could not be split) there is nothing to select.
+    if (pruned_sequence.size() > 1) {
+      const size_t test_size = n_points / folds;
+
+      // Go through each fold
+      for (size_t fold = 0; fold < folds; fold++) {
+        // Break up data into train and test set.  The last fold also takes
+        // the n_points % folds leftover points, so every point is tested.
+        const size_t start = fold * test_size;
+        const size_t end = (fold + 1 == folds) ? n_points
+                                               : (fold + 1) * test_size;
+        arma::Mat<eT> test = dataset->cols(start, end - 1);
+        arma::Mat<eT> train(*dataset);
+        train.shed_cols(start, end - 1);
+        assert(train.n_cols + test.n_cols == n_points);
+
+        // Initializing the tree
+        DTree<eT> dtree_cv(&train);
+
+        // Getting ready to grow the tree
+        arma::Col<size_t> old_from_new_cv(train.n_cols);
+        for (size_t i = 0; i < old_from_new_cv.n_elem; i++)
+          old_from_new_cv[i] = i;
+
+        // Growing the tree (this reorders the columns of 'train')
+        dtree_cv.Grow(&train, &old_from_new_cv, useVolumeReg,
+                      maxLeafSize, minLeafSize);
+
+        // sequential pruning with all the available alphas, adding the
+        // held-out term for each state of the tree
+        std::vector<std::pair<long double, long double> >::iterator it
+          = pruned_sequence.begin();
+        for (; it < pruned_sequence.end() - 2; ++it) {
+          // update the cv error value
+          it->second -= 2.0 * SumEstimates(dtree_cv, test)
+            / static_cast<long double>(n_points);
+
+          // prune at the geometric mean of the next two alphas
+          dtree_cv.PruneAndUpdate(sqrt(((it + 1)->first) * ((it + 2)->first)),
+                                  useVolumeReg);
+        } // end for
+
+        // update the cv error value for the final state of the tree
+        it->second -= 2.0 * SumEstimates(dtree_cv, test)
+          / static_cast<long double>(n_points);
+      } // end for loop for number of cv-folds
+    }
+
+    // The last entry of the sequence (the fully pruned tree) never
+    // receives CV updates, so it is excluded from the search.
+    long double optimal_alpha = -1.0;
+    long double best_cv_error = std::numeric_limits<long double>::max();
+    for (size_t i = 0; i + 1 < pruned_sequence.size(); i++) {
+      if (pruned_sequence[i].second < best_cv_error) {
+        best_cv_error = pruned_sequence[i].second;
+        optimal_alpha = pruned_sequence[i].first;
+      } // end if
+    } // end for
+
+    Log::Info << "Optimal alpha: " << optimal_alpha << endl;
+
+    // Initializing the tree
+    DTree<eT> *dtree_opt = new DTree<eT>(dataset);
+    // Getting ready to grow the tree
+    for (size_t i = 0; i < old_from_new.n_elem; i++)
+      old_from_new[i] = i;
+
+    // Growing the tree on a copy, since growing modifies the data
+    arma::Mat<eT> final_dataset(*dataset);
+    old_alpha = 0.0;
+    alpha = dtree_opt->Grow(&final_dataset, &old_from_new,
+                            useVolumeReg, maxLeafSize,
+                            minLeafSize);
+
+    // Pruning with optimal alpha
+    while (old_alpha < optimal_alpha
+           && dtree_opt->subtree_leaves() > 1) {
+      old_alpha = alpha;
+      alpha = dtree_opt->PruneAndUpdate(old_alpha, useVolumeReg);
+
+      // some checks
+      assert((alpha < std::numeric_limits<long double>::max())
+             || (dtree_opt->subtree_leaves() == 1));
+      assert(alpha > old_alpha);
+    } // end while
+
+    Log::Info << dtree_opt->subtree_leaves()
+              << " leaf nodes in the optimally pruned tree,"
+              << " optimal alpha: "
+              << old_alpha << endl;
+
+    return dtree_opt;
+  } // Trainer
+
+} // namespace dt_utils
+
+#endif

Added: mlpack/trunk/src/mlpack/methods/det/dtree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/det/dtree.hpp	2012-04-04 15:54:33 UTC (rev 12189)
@@ -0,0 +1,412 @@
+/**
+ * @file dtree.hpp
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ *
+ * Density Tree class
+ *
+ */
+
+#ifndef DTREE_HPP
+#define DTREE_HPP
+
+#include <assert.h>
+#include <vector>
+
+#include <mlpack/core.hpp>
+
+using namespace mlpack;
+using namespace std;
+
+
+// The two template parameters serve different purposes:
+// eT - the type to store the data in (for most practical
+// purposes, storing the data as a float suffices).
+// cT - the type to perform computations in (computations
+// like computing the error, the volume of the node etc.).
+// For high dimensional data these computations might overflow,
+// so you should either normalize your data into the (-1, 1)
+// hypercube, use long double, or modify this code to perform
+// the computations using logarithms.
+template<typename eT = float,
+         typename cT = long double>
+class DTree{
+
+  ////////////////////// Member Variables /////////////////////////////////////
+
+ private:
+
+  typedef arma::Mat<eT> MatType;
+  typedef arma::Col<eT> VecType;
+  typedef arma::Row<eT> RowVecType;
+
+
+  // The columns [start_, end_) of the data matrix that belong to this
+  // node.  Growing the tree swaps points so that the points of every node
+  // are stored consecutively; the 'old_from_new' array maps the points
+  // back to their original indices.
+  size_t start_, end_;
+
+  // The split dim for this node
+  size_t split_dim_;
+
+  // The split val on that dim
+  eT split_value_;
+
+  // L2-error of the node
+  cT error_;
+
+  // sum of the error of the leaves of the subtree
+  cT subtree_leaves_error_;
+
+  // number of leaves of the subtree
+  size_t subtree_leaves_;
+
+  // flag to indicate if this is the root node
+  // used to check whether the query point is
+  // within the range
+  bool root_;
+
+  // ratio of number of points in the node to the
+  // total number of points (|t| / N)
+  cT ratio_;
+
+  // the inverse of the volume of the node
+  cT v_t_inv_;
+
+  // sum of the inverse volumes (v_t_inv) of
+  // the leaves of this subtree
+  cT subtree_leaves_v_t_inv_;
+
+  // since we are using uniform density, we need
+  // the max and min of every dimension for every node
+  VecType* max_vals_;
+  VecType* min_vals_;
+
+  // the tag for the leaf used for hashing points
+  int bucket_tag_;
+
+  // The children
+  DTree<eT, cT> *left_;
+  DTree<eT, cT> *right_;
+
+  // The node holds raw pointers to heap objects (TestPrivateFunctions()
+  // allocates and deletes max_vals_/min_vals_), so a memberwise copy would
+  // share them between two nodes.  If the destructor (defined elsewhere)
+  // frees them, as assumed, a copy would cause a double delete.  Copying is
+  // therefore disabled: these are declared but intentionally not defined.
+  DTree(const DTree<eT, cT>&);
+  DTree<eT, cT>& operator=(const DTree<eT, cT>&);
+
+ public:
+
+  ////////////////////// Getters and Setters //////////////////////////////////
+  size_t start() const { return start_; }
+
+  size_t end() const { return end_; }
+
+  size_t split_dim() const { return split_dim_; }
+
+  eT split_value() const { return split_value_; }
+
+  cT error() const { return error_; }
+
+  cT subtree_leaves_error() const { return subtree_leaves_error_; }
+
+  size_t subtree_leaves() const { return subtree_leaves_; }
+
+  cT ratio() const { return ratio_; }
+
+  cT v_t_inv() const { return v_t_inv_; }
+
+  cT subtree_leaves_v_t_inv() const { return subtree_leaves_v_t_inv_; }
+
+  DTree<eT, cT>* left() const { return left_; }
+  DTree<eT, cT>* right() const { return right_; }
+
+  bool root() const { return root_; }
+
+  ////////////////////// Private Functions ////////////////////////////////////
+ private:
+
+  cT ComputeNodeError_(size_t total_points);
+
+  bool FindSplit_(MatType* data,
+                  size_t* split_dim,
+                  size_t* split_ind,
+                  cT* left_error,
+                  cT* right_error,
+                  size_t maxLeafSize = 10,
+                  size_t minLeafSize = 5);
+
+  void SplitData_(MatType* data,
+                  size_t split_dim,
+                  size_t split_ind,
+                  arma::Col<size_t>* old_from_new,
+                  eT* split_val,
+                  eT* lsplit_val,
+                  eT* rsplit_val);
+
+  void GetMaxMinVals_(MatType* data,
+                      VecType* max_vals,
+                      VecType* min_vals);
+
+  bool WithinRange_(VecType* query);
+
+  ///////////////////// Public Functions //////////////////////////////////////
+ public:
+
+  DTree();
+
+  // Root node initializer
+  // with the bounding box of the data
+  // it contains instead of just the data.
+  DTree(VecType* max_vals,
+        VecType* min_vals,
+        size_t total_points);
+
+  // Root node initializer
+  // with the data, no bounding box.
+  DTree(MatType* data);
+
+  // Non-root node initializers
+  DTree(VecType* max_vals,
+        VecType* min_vals,
+        size_t start,
+        size_t end,
+        cT error);
+
+  DTree(VecType* max_vals,
+        VecType* min_vals,
+        size_t total_points,
+        size_t start,
+        size_t end);
+
+  ~DTree();
+
+  // Greedily expand the tree
+  cT Grow(MatType* data,
+          arma::Col<size_t> *old_from_new,
+          bool useVolReg = false,
+          size_t maxLeafSize = 10,
+          size_t minLeafSize = 5);
+
+  // perform alpha pruning on the tree
+  cT PruneAndUpdate(cT old_alpha,
+                    bool useVolReg = false);
+
+  // compute the density at a given point
+  cT ComputeValue(VecType* query);
+
+  // print the tree (in a DFS manner)
+  void WriteTree(size_t level, FILE *fp);
+
+  // indexing the buckets for possible usage later
+  int TagTree(int tag);
+
+  // This is used to generate the class membership
+  // of a learned tree.
+  int FindBucket(VecType* query);
+
+  // This computes the variable importance list
+  // for the learned tree.
+  void ComputeVariableImportance(arma::Col<double> *imps);
+
+  /**
+   * Self-test of the private helpers on a small hard-coded 3 x 5 data set.
+   *
+   * Temporarily overwrites start_, end_, error_, max_vals_ and min_vals_;
+   * they are restored before returning.  Every failed check is reported via
+   * Log::Warn.
+   *
+   * @return true if all checks passed.
+   */
+  bool TestPrivateFunctions() {
+
+    bool return_flag = true;
+
+    // Create data
+    MatType test_data(3,5);
+
+    test_data << 4 << 5 << 7 << 3 << 5 << arma::endr
+              << 5 << 0 << 1 << 7 << 1 << arma::endr
+              << 5 << 6 << 7 << 1 << 8 << arma::endr;
+
+    // save current data
+    size_t true_start = start_, true_end = end_;
+    VecType* true_max_vals = max_vals_;
+    VecType* true_min_vals = min_vals_;
+    cT true_error = error_;
+
+
+    // Test GetMaxMinVals_
+    max_vals_ = new VecType();
+    min_vals_ = new VecType();
+
+    GetMaxMinVals_(&test_data, max_vals_, min_vals_);
+
+    if ((*max_vals_)[0] != 7 || (*min_vals_)[0] != 3) {
+      Log::Warn << "Test: GetMaxMinVals_ failed." << endl;
+      return_flag =  false;
+    }
+
+    if ((*max_vals_)[1] != 7 || (*min_vals_)[1] != 0) {
+      Log::Warn << "Test: GetMaxMinVals_ failed." << endl;
+      return_flag =  false;
+    }
+
+    if ((*max_vals_)[2] != 8 || (*min_vals_)[2] != 1) {
+      Log::Warn << "Test: GetMaxMinVals_ failed." << endl;
+      return_flag =  false;
+    }
+
+    // Test ComputeNodeError_
+    start_ = 0;
+    end_ = 5;
+    cT node_error = ComputeNodeError_(5);
+    // (cT) arguments: std::log(int) is ambiguous in C++03.
+    cT log_vol = std::log((cT) 4) + std::log((cT) 7) + std::log((cT) 7);
+    cT true_node_error = -1.0 * std::exp(-log_vol);
+
+    if (std::abs(node_error - true_node_error) > 1e-7) {
+      Log::Warn << "Test: True error : " << true_node_error
+                << ", Computed error: " << node_error
+                << ", diff: " << std::abs(node_error - true_node_error)
+                << endl;
+      return_flag =  false;
+    }
+
+    start_ = 3;
+    end_ = 5;
+    node_error = ComputeNodeError_(5);
+    true_node_error = -1.0 * std::exp(2 * std::log((cT) 2 / (cT) 5) - log_vol);
+
+    if (std::abs(node_error - true_node_error) > 1e-7) {
+      Log::Warn << "Test: True error : " << true_node_error
+                << ", Computed error: " << node_error
+                << ", diff: " << std::abs(node_error - true_node_error)
+                << endl;
+      return_flag =  false;
+    }
+
+    // Test WithinRange_
+
+    VecType test_query(3);
+    test_query << 4.5 << 2.5 << 2;
+
+    if (!WithinRange_(&test_query)) {
+      Log::Warn << "Test: WithinRange_ failed" << endl;
+      return_flag =  false;
+    }
+
+    test_query << 8.5 << 2.5 << 2;
+
+    if (WithinRange_(&test_query)) {
+      Log::Warn << "Test: WithinRange_ failed" << endl;
+      return_flag =  false;
+    }
+
+    // Test FindSplit_
+    start_ = 0;
+    end_ = 5;
+    error_ = ComputeNodeError_(5);
+
+    size_t ob_dim, true_dim, ob_ind, true_ind;
+    cT true_left_error, ob_left_error, true_right_error, ob_right_error;
+
+    true_dim = 2;
+    true_ind = 1;
+    true_left_error = -1.0 * std::exp(2 * std::log((cT) 2 / (cT) 5)
+                                      - (std::log((cT) 7) + std::log((cT) 4)
+                                         + std::log((cT) 4.5)));
+    true_right_error =  -1.0 * std::exp(2 * std::log((cT) 3 / (cT) 5)
+                                        - (std::log((cT) 7) + std::log((cT) 4)
+                                           + std::log((cT) 2.5)));
+
+    if(!FindSplit_(&test_data, &ob_dim, &ob_ind,
+                   &ob_left_error, &ob_right_error, 2, 1)) {
+      Log::Warn << "Test: FindSplit_ returns false." << endl;
+      return_flag =  false;
+    }
+
+    if (true_dim != ob_dim) {
+      Log::Warn << "Test: FindSplit_ - True dim: " << true_dim
+                << ", Obtained dim: " << ob_dim << endl;
+      return_flag =  false;
+    }
+
+    if (true_ind != ob_ind) {
+      Log::Warn << "Test: FindSplit_ - True ind: " << true_ind
+                << ", Obtained ind: " << ob_ind << endl;
+      return_flag =  false;
+    }
+
+    if (std::abs(true_left_error - ob_left_error) > 1e-7) {
+      Log::Warn << "Test: FindSplit_ - True left_error: " << true_left_error
+                << ", Obtained left_error: " << ob_left_error
+                << ", diff: " << std::abs(true_left_error - ob_left_error)
+                << endl;
+      return_flag =  false;
+    }
+
+    if (std::abs(true_right_error - ob_right_error) > 1e-7) {
+      Log::Warn << "Test: FindSplit_ - True right_error: " << true_right_error
+                << ", Obtained right_error: " << ob_right_error
+                << ", diff: " << std::abs(true_right_error - ob_right_error)
+                << endl;
+      return_flag =  false;
+    }
+
+    // Test SplitData_
+    MatType split_test_data(test_data);
+    arma::Col<size_t> o_test(5);
+    o_test << 1 << 2 << 3 << 4 << 5;
+
+    start_ = 0;
+    end_ = 5;
+    size_t split_dim = 2, split_ind = 1;
+    eT true_split_val, ob_split_val, true_lsplit_val, ob_lsplit_val,
+      true_rsplit_val, ob_rsplit_val;
+
+    true_lsplit_val = 5;
+    true_rsplit_val = 6;
+    true_split_val = (true_lsplit_val + true_rsplit_val) / 2;
+
+    SplitData_(&split_test_data, split_dim, split_ind,
+               &o_test, &ob_split_val,
+               &ob_lsplit_val, &ob_rsplit_val);
+
+    if (o_test[0] != 1 || o_test[1] != 4 || o_test[2] != 3
+        || o_test[3] != 2 || o_test[4] != 5) {
+      Log::Warn << "Test: SplitData_ - OFW should be 1,4,3,2,5"
+                << ", is " << o_test.t();
+      return_flag =  false;
+    }
+
+    if (true_split_val != ob_split_val) {
+      Log::Warn << "Test: SplitData_ - True split val: " << true_split_val
+                << ", Ob split val: " << ob_split_val << endl;
+      return_flag =  false;
+    }
+
+    if (true_lsplit_val != ob_lsplit_val) {
+      Log::Warn << "Test: SplitData_ - True lsplit val: " << true_lsplit_val
+                << ", Ob lsplit val: " << ob_lsplit_val << endl;
+      return_flag =  false;
+    }
+
+    if (true_rsplit_val != ob_rsplit_val) {
+      Log::Warn << "Test: SplitData_ - True rsplit val: " << true_rsplit_val
+                << ", Ob rsplit val: " << ob_rsplit_val << endl;
+      return_flag =  false;
+    }
+
+
+    // restore original values
+    delete max_vals_;
+    delete min_vals_;
+    max_vals_ = true_max_vals;
+    min_vals_ = true_min_vals;
+    start_ = true_start;
+    end_ = true_end;
+    error_ = true_error;
+
+    return return_flag;
+
+  } // TestPrivateFunctions
+
+}; // Class DTree
+
+#include "dtree_impl.hpp"
+
+#endif

Added: mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp	2012-04-04 15:54:33 UTC (rev 12189)
@@ -0,0 +1,770 @@
+/**
+ * @file dtree_impl.hpp
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ *
+ * Implementations of member functions declared in the density
+ * estimation tree class, DTree.
+ */
+
+#ifndef DTREE_IMPL_HPP
+#define DTREE_IMPL_HPP
+
+#include "dtree.hpp"
+
+#include <algorithm>
+
+
+
+// This function computes the l2-error of a given node
+// from the formula - R(t) = -|t|^2 / (N^2 V_t)
+template<typename eT, typename cT>
+cT DTree<eT, cT>::
+ComputeNodeError_(size_t total_points) 
+{
+  size_t node_size = end_ - start_;
+
+  cT log_vol_t = 0;
+  for (size_t i = 0; i < max_vals_->n_elem; i++)
+    if ((*max_vals_)[i] - (*min_vals_)[i] > 0.0)
+      // using log to prevent overflow
+      log_vol_t += (cT) std::log((*max_vals_)[i] - (*min_vals_)[i]);
+
+  // check for overflow -- if it doesn't work, try higher precision
+  // by default cT = long double, so if you can't work with that
+  // there is nothing else you can do - except computing error using 
+  // log and dealing with everything in log form.
+  assert(std::exp(log_vol_t) > 0.0);
+
+  cT log_neg_error = 2 * std::log((cT) node_size / (cT) total_points) 
+    - log_vol_t;
+
+  assert(std::exp(log_neg_error) > 0.0);
+
+  cT error = -1.0 * std::exp(log_neg_error); 
+
+  return error;
+} // ComputeNodeError
+
+
+// Finds the best axis-aligned split of this node with respect to the L2
+// error, by trying every candidate split in every dimension.
+//
+// 'data' is the full data set; only columns [start_, end_) belong to
+// this node.  A candidate split lies at the midpoint between two
+// consecutive, distinct sorted values of one dimension, and is only
+// considered if both children would receive at least 'minLeafSize'
+// points.  A 'minLeafSize' of 0 is treated as 1: every child needs at
+// least one point anyway, and 0 used to wrap 'minLeafSize - 1' around so
+// that no split was ever tried.
+//
+// On success returns true and sets:
+//   *split_dim   - the dimension to split on,
+//   *split_ind   - index, within this node's sorted values of that
+//                  dimension, of the last point that goes left,
+//   *left_error / *right_error - the L2 errors of the two children.
+// Returns false, leaving the outputs untouched, if no split lowers the
+// node's error.  'maxLeafSize' is only used in a sanity assertion.
+template<typename eT, typename cT>
+bool DTree<eT, cT>::
+FindSplit_(MatType* data, 
+           size_t *split_dim, 
+           size_t *split_ind,
+           cT *left_error,
+           cT *right_error,
+           size_t maxLeafSize,
+           size_t minLeafSize)
+{
+  assert(data->n_rows == max_vals_->n_elem);
+  assert(data->n_rows == min_vals_->n_elem);
+
+  const size_t total_n = data->n_cols;
+  const size_t n_t = end_ - start_;
+
+  // Effective minimum leaf size (see above); also keeps 'min_leaf - 1'
+  // below from wrapping around.
+  const size_t min_leaf = (minLeafSize > 0) ? minLeafSize : 1;
+
+  cT min_error = error_;
+  bool some_split_found = false;
+
+  for (size_t dim = 0; dim < max_vals_->n_elem; dim++) {
+    // TODO: REAL, INTEGER and NOMINAL data should eventually be handled
+    // differently; every dimension is currently treated as real-valued.
+    const eT dim_min = (*min_vals_)[dim];
+    const eT dim_max = (*max_vals_)[dim];
+
+    // a dimension with zero extent cannot be split
+    if (!(dim_max - dim_min > 0.0))
+      continue;
+
+    bool dim_split_found = false;
+    cT min_dim_error = min_error;
+    cT temp_lval = 0.0, temp_rval = 0.0;
+    size_t dim_split_ind = (size_t) -1; // only read if dim_split_found
+
+    // log of the node's volume with this dimension left out
+    // (zero-width dimensions are skipped, as in ComputeNodeError_)
+    cT log_range_all_not_dim = 0;
+    for (size_t i = 0; i < max_vals_->n_elem; i++) {
+      if ((*max_vals_)[i] - (*min_vals_)[i] > 0.0 && i != dim)
+        log_range_all_not_dim
+          += (cT) std::log((*max_vals_)[i] - (*min_vals_)[i]);
+    }
+    assert(std::exp(log_range_all_not_dim) > 0);
+
+    // this node's values along 'dim', in ascending order
+    RowVecType dim_val_vec = data->row(dim).subvec(start_, end_ - 1);
+    dim_val_vec = arma::sort(dim_val_vec);
+    assert(dim_val_vec.n_elem > maxLeafSize);
+
+    // Splitting after sorted index i gives i + 1 points on the left and
+    // n_t - i - 1 on the right; both must be >= min_leaf.  The bound is
+    // written as 'i + min_leaf < n' so it cannot wrap when n < min_leaf.
+    for (size_t i = min_leaf - 1; i + min_leaf < dim_val_vec.n_elem; i++) {
+      const eT lsplit = dim_val_vec[i];
+      const eT rsplit = dim_val_vec[i + 1];
+
+      // equal neighbouring values cannot be separated
+      if (!(lsplit < rsplit))
+        continue;
+
+      // Midpoint split: natural for continuous data, although it places
+      // the boundary between observed values for ordinal data.
+      const eT split = (lsplit + rsplit) / 2;
+
+      if (!(split - dim_min > 0.0 && dim_max - split > 0.0))
+        continue;
+
+      assert(std::exp(log_range_all_not_dim
+                      + (cT) std::log(split - dim_min)) > 0);
+      assert(std::exp(log_range_all_not_dim
+                      + (cT) std::log(dim_max - split)) > 0);
+
+      // left child: i + 1 points in the box [dim_min, split] along dim
+      const cT temp_log_neg_l_error
+        = 2 * std::log((cT) (i + 1) / (cT) total_n)
+        - (log_range_all_not_dim + (cT) std::log(split - dim_min));
+      assert(std::exp(temp_log_neg_l_error) > 0.0);
+      const cT temp_l_error = -1.0 * std::exp(temp_log_neg_l_error);
+      assert(std::abs(temp_l_error) < std::numeric_limits<cT>::max());
+
+      // right child: n_t - i - 1 points in the box [split, dim_max]
+      const cT temp_log_neg_r_error
+        = 2 * std::log((cT) (n_t - i - 1) / (cT) total_n)
+        - (log_range_all_not_dim + (cT) std::log(dim_max - split));
+      assert(std::exp(temp_log_neg_r_error) > 0.0);
+      assert(n_t - i - 1 >= min_leaf);
+      const cT temp_r_error = -1.0 * std::exp(temp_log_neg_r_error);
+      assert(std::abs(temp_r_error) < std::numeric_limits<cT>::max());
+
+      // strict '<' keeps the first of several equally good splits
+      if (temp_l_error + temp_r_error < min_dim_error) {
+        min_dim_error = temp_l_error + temp_r_error;
+        temp_lval = temp_l_error;
+        temp_rval = temp_r_error;
+        dim_split_ind = i;
+        dim_split_found = true;
+      }
+    } // end for loop over all splits in this dim
+
+    if ((min_dim_error < min_error) && dim_split_found) {
+      min_error = min_dim_error;
+      *split_dim = dim;
+      *split_ind = dim_split_ind;
+      *left_error = temp_lval;
+      *right_error = temp_rval;
+      some_split_found = true;
+    }
+  } // end for each dimension
+
+  return some_split_found;
+} // end FindSplit_
+
+
+// Partitions columns [start_, end_) of 'data' in place so that the points
+// going to the left child (value in 'split_dim' <= split value) come
+// first, followed by the points going to the right child.
+//
+// The split value is the midpoint of the (split_ind)-th and
+// (split_ind + 1)-th smallest values of 'split_dim' within this node.
+// The two neighbouring values and the midpoint are returned through
+// 'lsplit_val', 'rsplit_val' and 'split_val'.  Every column swap is
+// mirrored in 'old_from_new', so that the reordered columns can still be
+// mapped back to their original positions.
+//
+// NOTE: Grow() hands the first split_ind + 1 columns of the node to the
+// left child.  That matches this partition when the midpoint falls
+// strictly between the two neighbouring values.
+template<typename eT, typename cT>
+void DTree<eT, cT>::
+SplitData_(MatType* data, 
+           size_t split_dim,
+           size_t split_ind,
+           arma::Col<size_t> *old_from_new, 
+           eT *split_val,
+           eT *lsplit_val,
+           eT *rsplit_val) 
+{
+  // sorted values of the split dimension within this node
+  RowVecType dim_val_vec 
+    = data->row(split_dim).subvec(start_, end_ - 1);
+  dim_val_vec = arma::sort(dim_val_vec);
+
+  *lsplit_val = dim_val_vec[split_ind];
+  *rsplit_val = dim_val_vec[split_ind + 1];
+  *split_val = (*lsplit_val + *rsplit_val) / 2;
+
+  // Side of the split for every point of the node.  The vector must be
+  // given its size up front: it is written through operator[], and
+  // reserve() alone does not create any elements.
+  const size_t n_points = end_ - start_;
+  std::vector<bool> left_membership(n_points);
+  for (size_t i = start_; i < end_; i++)
+    left_membership[i - start_] = !((*data)(split_dim, i) > *split_val);
+
+  // Hoare-style partition.  'left' is the first column not yet known to
+  // be on the left; 'right' is one past the last column not yet known to
+  // be on the right.  Using an exclusive right bound means it can never
+  // wrap below start_, and each bound is checked before indexing.
+  size_t left = start_, right = end_;
+  for (;;) {
+    while (left < right && left_membership[left - start_])
+      left++;
+    while (left < right && !left_membership[right - 1 - start_])
+      right--;
+
+    if (left >= right)
+      break;
+
+    // column 'left' belongs on the right and column 'right - 1' on the
+    // left: swap them together with their bookkeeping
+    data->swap_cols(left, right - 1);
+
+    const bool tmp = left_membership[left - start_];
+    left_membership[left - start_] = left_membership[right - 1 - start_];
+    left_membership[right - 1 - start_] = tmp;
+
+    const size_t t = (*old_from_new)[left];
+    (*old_from_new)[left] = (*old_from_new)[right - 1];
+    (*old_from_new)[right - 1] = t;
+  } // swap for loop
+
+  assert(left == right);
+} // end SplitData_
+
+
+// Computes the bounding box of 'data' (one point per column).  For every
+// dimension (row) it finds the smallest and largest value over all
+// points and writes them to 'min_vals' / 'max_vals', which end up with
+// data->n_rows entries.
+//
+// Uses Armadillo's row-wise min/max reductions, a single O(d * n) pass
+// over the data.  The previous version transposed the whole matrix and
+// fully sorted every dimension just to read off its first and last
+// element.
+template<typename eT, typename cT>
+void DTree<eT, cT>::
+GetMaxMinVals_(MatType* data, 
+               VecType *max_vals,
+               VecType *min_vals) {
+
+  // an empty data set has no bounding box
+  assert(data->n_cols > 0);
+
+  // dim = 1: reduce across the columns of each row
+  *max_vals = arma::max(*data, 1);
+  *min_vals = arma::min(*data, 1);
+} // end GetMaxMinVals_
+
+
+// Default constructor: an empty node [0, 0) with no bounding box and no
+// children.  Members not listed in the initializer list (error, split
+// information, tags, ...) are left uninitialized here.
+template<typename eT, typename cT>
+DTree<eT, cT>::
+DTree()
+  : start_(0), end_(0),
+    max_vals_(NULL), min_vals_(NULL),
+    left_(NULL), right_(NULL)
+{
+}
+
+
+// Root-node constructor from a precomputed bounding box.
+//
+// The node covers points [0, total_points) and takes ownership of
+// 'max_vals' and 'min_vals' (both are deleted in ~DTree).  Its L2 error
+// is computed right away from the bounding box.
+template<typename eT, typename cT>
+DTree<eT, cT>::
+DTree(VecType* max_vals, 
+      VecType* min_vals,
+      size_t total_points) :
+  start_(0),
+  end_(total_points),
+  max_vals_(max_vals),
+  min_vals_(min_vals),
+  left_(NULL),
+  right_(NULL)
+{
+  // a fresh root: not yet tagged by TagTree
+  root_ = true;
+  bucket_tag_ = -1;
+
+  error_ = ComputeNodeError_(total_points);
+
+  // An unbounded error means cT lacks the precision/range needed for
+  // this data; a wider cT would be required.
+  assert(std::abs(error_) < std::numeric_limits<cT>::max());
+}
+
+
+// Root-node constructor directly from a data set (one point per column).
+// The node covers all data->n_cols points.  Its bounding box is computed
+// from 'data' into two freshly allocated vectors, which the node owns.
+template<typename eT, typename cT>
+DTree<eT, cT>::
+DTree(MatType* data) : 
+  start_(0),
+  end_(data->n_cols),
+  max_vals_(new VecType()),
+  min_vals_(new VecType()),
+  left_(NULL),
+  right_(NULL)
+{
+  GetMaxMinVals_(data, max_vals_, min_vals_);
+  error_ = ComputeNodeError_(data->n_cols);
+
+  root_ = true;
+  bucket_tag_ = -1;
+}
+
+
+// Non-root node initializers
+template<typename eT, typename cT>
+DTree<eT, cT>::
+DTree(VecType* max_vals, 
+      VecType* min_vals,
+      size_t start, 
+      size_t end,
+      cT error) : 
+  start_(start),
+  end_(end),
+  error_(error),
+  max_vals_(max_vals),
+  min_vals_(min_vals),
+  left_(NULL),
+  right_(NULL)
+{
+  bucket_tag_ = -1;
+  root_ = false;
+}
+
+
+template<typename eT, typename cT>
+DTree<eT, cT>::
+DTree(VecType* max_vals, 
+      VecType* min_vals,
+      size_t total_points,
+      size_t start,
+      size_t end) : 
+  start_(start),
+  end_(end),
+  max_vals_(max_vals),
+  min_vals_(min_vals),
+  left_(NULL),
+  right_(NULL)
+{
+    
+  error_ = ComputeNodeError_(total_points);
+
+  bucket_tag_ = -1;
+  root_ = false;
+}
+
+
+// Destructor: recursively frees both subtrees, then the bounding-box
+// vectors owned by this node.  Deleting a NULL pointer is a no-op, so
+// leaves and default-constructed nodes need no special casing.
+template<typename eT, typename cT>
+DTree<eT, cT>::
+~DTree() 
+{
+  delete left_;
+  delete right_;
+  delete min_vals_;
+  delete max_vals_;
+}
+
+
+// Greedily grows the tree below this node and returns the smallest
+// cost-complexity pruning threshold ("alpha") found in the subtree.
+//
+// For every node this sets ratio_ (the fraction of all points in the
+// node) and v_t_inv_ (the inverse volume of the node's bounding box).  A
+// node with more than 'maxLeafSize' points is split if FindSplit_ finds
+// a split that lowers its error, and both children are then grown
+// recursively.  Splitting reorders the columns of 'data' in place, and
+// 'old_from_new' records the same permutation.  'minLeafSize' is
+// forwarded to FindSplit_.
+//
+// For an internal node t the returned alpha is the minimum, over t and
+// its internal descendants, of
+//   g(t) = (R(t) - R(T_t)) / (|T_t| - 1)                    (default), or
+//   g(t) = (R(t) - R(T_t)) / (sum_leaves 1/V_leaf - 1/V_t)  (useVolReg),
+// where R(T_t) is the summed leaf error and |T_t| the number of leaves
+// of the subtree.  A leaf returns std::numeric_limits<cT>::max().
+template<typename eT, typename cT>
+cT DTree<eT, cT>::
+Grow(MatType* data, 
+     arma::Col<size_t> *old_from_new,
+     bool useVolReg,
+     size_t maxLeafSize,
+     size_t minLeafSize) 
+{    
+  assert(data->n_rows == max_vals_->n_elem);
+  assert(data->n_rows == min_vals_->n_elem);
+
+  // alphas reported by the children (only meaningful if this node splits)
+  cT left_g = std::numeric_limits<cT>::max();
+  cT right_g = std::numeric_limits<cT>::max();
+
+  // fraction of all points that fall in this node
+  ratio_ = (cT) (end_ - start_)
+    / (cT) old_from_new->n_elem;
+
+  // Inverse volume of the node, accumulated in log space to avoid
+  // overflow; zero-width dimensions are skipped.
+  cT log_vol_t = 0;
+  for (size_t i = 0; i < max_vals_->n_elem; i++)
+    if ((*max_vals_)[i] - (*min_vals_)[i] > 0.0)
+      log_vol_t += (cT) std::log((*max_vals_)[i] - (*min_vals_)[i]);
+
+  assert(std::exp(log_vol_t) > 0.0);
+  v_t_inv_ = 1.0 / std::exp(log_vol_t);
+
+  if ((size_t) (end_ - start_) > maxLeafSize) {
+    size_t dim, split_ind;
+    cT left_error, right_error;
+    if (FindSplit_(data, &dim, &split_ind,
+                   &left_error, &right_error,
+                   maxLeafSize, minLeafSize)) {
+
+      // reorder the node's points so that the left child's come first
+      eT split_val, lsplit_val, rsplit_val;
+      SplitData_(data, dim, split_ind,
+                 old_from_new, &split_val,
+                 &lsplit_val, &rsplit_val);
+
+      // The children's bounding boxes are copies of this node's box, cut
+      // at the split value along the split dimension.  Each child owns
+      // its copies.
+      VecType* max_vals_l = new VecType(*max_vals_);
+      VecType* max_vals_r = new VecType(*max_vals_);
+      VecType* min_vals_l = new VecType(*min_vals_);
+      VecType* min_vals_r = new VecType(*min_vals_);
+
+      (*max_vals_l)[dim] = split_val;
+      (*min_vals_r)[dim] = split_val;
+
+      split_value_ = split_val;
+      split_dim_ = dim;
+
+      // the left child gets the first split_ind + 1 points of this node
+      left_ = new DTree(max_vals_l, min_vals_l,
+                        start_, start_ + split_ind + 1,
+                        left_error);
+      right_ = new DTree(max_vals_r, min_vals_r,
+                         start_ + split_ind + 1, end_, 
+                         right_error);
+
+      left_g = left_->Grow(data, old_from_new, useVolReg, 
+                           maxLeafSize, minLeafSize);
+      right_g = right_->Grow(data, old_from_new, useVolReg,
+                             maxLeafSize, minLeafSize);
+
+      // subtree statistics: |T~|, R(T~) and the summed leaf 1/V
+      subtree_leaves_ = left_->subtree_leaves() + right_->subtree_leaves();
+      subtree_leaves_error_ = left_->subtree_leaves_error()
+        + right_->subtree_leaves_error();
+      subtree_leaves_v_t_inv_ = left_->subtree_leaves_v_t_inv()
+        + right_->subtree_leaves_v_t_inv();
+
+      // Forming T1: a split into two leaves that leaves the error
+      // unchanged, R(t) == R(t_L) + R(t_R), is undone right away.
+      if ((left_->subtree_leaves() == 1)
+          && (right_->subtree_leaves() == 1)
+          && (left_->error() + right_->error() == error_)) {
+        delete left_;
+        left_ = NULL;
+        delete right_;
+        right_ = NULL;
+        subtree_leaves_ = 1;
+        subtree_leaves_error_ = error_;
+        subtree_leaves_v_t_inv_ = v_t_inv_;
+      }
+    } else {
+      // no split lowers the error, so this node stays a leaf
+      subtree_leaves_ = 1;
+      subtree_leaves_error_ = error_;
+      subtree_leaves_v_t_inv_ = v_t_inv_;
+    }
+  } else {
+    // small enough to be a leaf
+    assert((size_t) (end_ - start_) >= minLeafSize);
+    subtree_leaves_ = 1;
+    subtree_leaves_error_ = error_;
+    subtree_leaves_v_t_inv_ = v_t_inv_;
+  }
+
+  // Leaves have no g(t).  Internal nodes propagate the smallest g(t) in
+  // their subtree (leaf children report the maximum cT value).
+  if (subtree_leaves_ == 1)
+    return std::numeric_limits<cT>::max();
+
+  cT g_t;
+  if (useVolReg)
+    g_t = (error_ - subtree_leaves_error_) 
+      / (subtree_leaves_v_t_inv_ - v_t_inv_);
+  else
+    g_t = (error_ - subtree_leaves_error_) 
+      / (subtree_leaves_ - 1);
+
+  assert(g_t > 0.0);
+  return std::min(g_t, std::min(left_g, right_g));
+} // Grow
+
+
+// One step of minimal cost-complexity pruning.
+//
+// Every internal node t whose g(t) (see Grow) is not greater than
+// 'old_alpha' is collapsed into a leaf, and its subtrees are deleted.
+// For the remaining internal nodes, the subtree statistics are
+// recomputed: number of leaves, summed leaf error and summed leaf inverse
+// volume.  The smallest g(t) left in the subtree is returned as the next
+// alpha.  Leaves, including nodes just collapsed, return
+// std::numeric_limits<cT>::max().
+template<typename eT, typename cT>
+cT DTree<eT, cT>::
+PruneAndUpdate(cT old_alpha,
+               bool useVolReg)
+{
+  // a leaf has nothing to prune
+  if (subtree_leaves_ == 1)
+    return std::numeric_limits<cT>::max();
+
+  // g(t) for this node, from the subtree statistics before pruning
+  cT g_t;
+  if (useVolReg)
+    g_t = (error_ - subtree_leaves_error_) 
+      / (subtree_leaves_v_t_inv_ - v_t_inv_);
+  else
+    g_t = (error_ - subtree_leaves_error_) 
+      / (subtree_leaves_ - 1);
+
+  if (!(g_t > old_alpha)) {
+    // g(t) has reached alpha (it is expected to equal old_alpha here):
+    // collapse this subtree into a single leaf.
+    subtree_leaves_ = 1;
+    subtree_leaves_error_ = error_;
+    subtree_leaves_v_t_inv_ = v_t_inv_;
+    delete left_;
+    left_ = NULL;
+    delete right_;
+    right_ = NULL;
+    return std::numeric_limits<cT>::max();
+  }
+
+  // Keep this node: prune the children first, then refresh the
+  // statistics that depend on them.
+  const cT left_g = left_->PruneAndUpdate(old_alpha, useVolReg);
+  const cT right_g = right_->PruneAndUpdate(old_alpha, useVolReg);
+
+  subtree_leaves_ = left_->subtree_leaves()
+    + right_->subtree_leaves();
+  subtree_leaves_error_ = left_->subtree_leaves_error()
+    + right_->subtree_leaves_error();
+  subtree_leaves_v_t_inv_ = left_->subtree_leaves_v_t_inv()
+    + right_->subtree_leaves_v_t_inv();
+
+  if (useVolReg)
+    g_t = (error_ - subtree_leaves_error_) 
+      / (subtree_leaves_v_t_inv_ - v_t_inv_);
+  else
+    g_t = (error_ - subtree_leaves_error_) 
+      / (subtree_leaves_ - 1);
+
+  assert(g_t < std::numeric_limits<cT>::max());
+
+  // A child that is, or has just become, a leaf returns the maximum cT
+  // value.  So this is the minimum over g(t) and the internal children.
+  return std::min(g_t, std::min(left_g, right_g));
+} // PruneAndUpdate
+
+
+// Returns true if 'query' lies inside this node's bounding box, with
+// the boundaries included.  ComputeValue calls it on the root, whose box
+// is the bounding box of the training data.
+//
+// Possible improvement: widen the box by a small epsilon on each side so
+// that points just outside the data range get a nonzero density.
+template<typename eT, typename cT>
+bool DTree<eT, cT>::
+WithinRange_(VecType* query) 
+{
+  size_t i = 0;
+  while (i < query->n_elem) {
+    const bool below = (*query)[i] < (*min_vals_)[i];
+    const bool above = (*query)[i] > (*max_vals_)[i];
+    if (below || above)
+      return false;
+    ++i;
+  }
+  return true;
+}
+
+
+// Returns the density estimate at 'query'.  This is ratio_ * v_t_inv_ of
+// the leaf containing the query: the fraction of training points in the
+// leaf divided by the leaf's volume.  When called on the root, queries
+// outside the root's bounding box get 0.  A value exactly equal to a
+// split value goes to the left child.
+template<typename eT, typename cT>
+cT DTree<eT, cT>::
+ComputeValue(VecType* query)
+{
+  assert(query->n_elem == max_vals_->n_elem);
+
+  // only the root checks the overall bounding box
+  if (root_ == 1 && !WithinRange_(query))
+    return 0.0;
+
+  // internal node: descend into the child whose half contains the query
+  if (subtree_leaves_ != 1) {
+    DTree* child = ((*query)[split_dim_] <= split_value_) ? left_ : right_;
+    return child->ComputeValue(query);
+  }
+
+  return ratio_ * v_t_inv_;
+} // ComputeValue  
+
+
+// Writes a human-readable rendering of the subtree to 'fp'.  Each split
+// goes on its own line, prefixed by 'level' "|<tab>" markers.  The right
+// branch ("Var. d > v") is written before the left one
+// ("Var. d <= v").  Each leaf writes its density estimate and, if it has
+// been tagged, its bucket tag.
+//
+// eT and cT are template parameters and need not be double / long
+// double.  So every argument is cast to exactly the type its printf
+// conversion expects (%zu, %lg, %Lg, %d).
+template<typename eT, typename cT>
+void DTree<eT, cT>::
+WriteTree(size_t level, FILE *fp) 
+{
+  if (subtree_leaves_ > 1){
+    fprintf(fp, "\n");
+    for (size_t i = 0; i < level; i++){
+      fprintf(fp, "|\t");
+    }
+    fprintf(fp, "Var. %zu > %lg",
+            (size_t) split_dim_, (double) split_value_);
+    right_->WriteTree(level+1, fp);
+    fprintf(fp, "\n");
+    for (size_t i = 0; i < level; i++){
+      fprintf(fp, "|\t");
+    }
+    fprintf(fp, "Var. %zu <= %lg ", 
+            (size_t) split_dim_, (double) split_value_);
+    left_->WriteTree(level+1, fp);
+
+  } else { // if leaf
+    fprintf(fp, ": f(x)=%Lg", (long double) (ratio_ * v_t_inv_));
+    if (bucket_tag_ != -1) 
+      fprintf(fp, " BT:%d", (int) bucket_tag_);
+  }
+} // WriteTree
+
+
+// Numbers the leaves ("buckets") of this subtree from left to right,
+// starting at 'tag', and returns the next unused tag.  The tags can
+// later be looked up with FindBucket.
+template<typename eT, typename cT>
+int DTree<eT, cT>::
+TagTree(int tag) 
+{
+  // internal node: tag the left subtree, then continue into the right
+  if (subtree_leaves_ != 1) {
+    const int next_tag = left_->TagTree(tag);
+    return right_->TagTree(next_tag);
+  }
+
+  bucket_tag_ = tag;
+  return tag + 1;
+} // TagTree
+
+
+// Returns the bucket tag of the leaf containing 'query'.  The tag was
+// set by TagTree; it is -1, as set by the constructors, if TagTree has
+// not been run.  A value exactly equal to a split value goes to the left
+// child.
+template<typename eT, typename cT>
+int DTree<eT, cT>::
+FindBucket(VecType* query) 
+{
+  assert(query->n_elem == max_vals_->n_elem);
+
+  if (subtree_leaves_ == 1)
+    return bucket_tag_;
+
+  return ((*query)[split_dim_] <= split_value_)
+    ? left_->FindBucket(query)
+    : right_->FindBucket(query);
+} // FindBucket
+
+
+template<typename eT, typename cT>
+void DTree<eT, cT>::
+ComputeVariableImportance(arma::Col<double> *imps)
+{
+  if (subtree_leaves_ == 1) {
+    // if leaf, do nothing
+    return;
+  } else {
+    // compute the improvement in error because of the 
+    // split
+    double error_improv
+      = (double) (error_ - (left_->error() + right_->error()));
+    (*imps)[split_dim_] += error_improv;
+    left_->ComputeVariableImportance(imps);
+    right_->ComputeVariableImportance(imps);
+    return;
+  }
+} // ComputeVariableImportance
+
+
+#endif




More information about the mlpack-svn mailing list