[mlpack-svn] r13310 - mlpack/trunk/src/mlpack/methods/det

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Aug 1 20:52:39 EDT 2012


Author: rcurtin
Date: 2012-08-01 20:52:39 -0400 (Wed, 01 Aug 2012)
New Revision: 13310

Added:
   mlpack/trunk/src/mlpack/methods/det/dt_utils.cpp
   mlpack/trunk/src/mlpack/methods/det/dtree.cpp
Removed:
   mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
Modified:
   mlpack/trunk/src/mlpack/methods/det/CMakeLists.txt
   mlpack/trunk/src/mlpack/methods/det/dt_main.cpp
   mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp
   mlpack/trunk/src/mlpack/methods/det/dtree.hpp
Log:
Remove templatization of density estimation trees (it wasn't necessary).  It
will be brought back, but not at this time.


Modified: mlpack/trunk/src/mlpack/methods/det/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/CMakeLists.txt	2012-08-01 20:47:08 UTC (rev 13309)
+++ mlpack/trunk/src/mlpack/methods/det/CMakeLists.txt	2012-08-02 00:52:39 UTC (rev 13310)
@@ -4,13 +4,13 @@
 # Anything not in this list will not be compiled into the output library
 # Do not include test programs here
 set(SOURCES
-
   # the DET class
   dtree.hpp
-  dtree_impl.hpp
+  dtree.cpp
 
   # the util file
   dt_utils.hpp
+  dt_utils.cpp
 )
 
 # add directory name to sources

Modified: mlpack/trunk/src/mlpack/methods/det/dt_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dt_main.cpp	2012-08-01 20:47:08 UTC (rev 13309)
+++ mlpack/trunk/src/mlpack/methods/det/dt_main.cpp	2012-08-02 00:52:39 UTC (rev 13310)
@@ -89,8 +89,8 @@
 
   // Obtain the optimal tree.
   Timer::Start("det_training");
-  DTree<double> *dtreeOpt = Trainer<double>(trainingData, folds,
-      regularization, maxLeafSize, minLeafSize, unprunedTreeEstimateFile);
+  DTree *dtreeOpt = Trainer(trainingData, folds, regularization, maxLeafSize,
+      minLeafSize, unprunedTreeEstimateFile);
   Timer::Stop("det_training");
 
   // Compute densities for the training points in the optimal tree.
@@ -187,15 +187,14 @@
     Log::Assert(trainingData.n_cols == labels.n_cols);
     Log::Assert(labels.n_rows == 1);
 
-    PrintLeafMembership<double>(dtreeOpt, trainingData, labels, num_classes,
+    PrintLeafMembership(dtreeOpt, trainingData, labels, num_classes,
        CLI::GetParam<string>("output/leaf_class_table"));
   }
 
   // Print variable importance.
   if (CLI::HasParam("flag/print_vi"))
   {
-    PrintVariableImportance<double>(dtreeOpt,
-        CLI::GetParam<string>("output/vi"));
+    PrintVariableImportance(dtreeOpt, CLI::GetParam<string>("output/vi"));
   }
 
   delete dtreeOpt;

Copied: mlpack/trunk/src/mlpack/methods/det/dt_utils.cpp (from rev 13309, mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp)
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dt_utils.cpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/det/dt_utils.cpp	2012-08-02 00:52:39 UTC (rev 13310)
@@ -0,0 +1,300 @@
+/**
+ * @file dt_utils.cpp
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ *
+ * This file implements functions to perform different tasks with the Density
+ * Tree class.
+ */
+#include "dt_utils.hpp"
+
+using namespace mlpack;
+using namespace det;
+
+/**
+ * Count how many points of each class fall into each leaf of the tree, and
+ * print the resulting table to the info log or to the given file.
+ */
+void mlpack::det::PrintLeafMembership(DTree* dtree,
+                                      const arma::mat& data,
+                                      const arma::Mat<size_t>& labels,
+                                      const size_t numClasses,
+                                      const std::string leafClassMembershipFile)
+{
+  // Number the leaves; the returned count is the number of table rows.
+  const int leafCount = dtree->TagTree();
+
+  // counts(leaf, class) holds the number of points of 'class' in 'leaf'.
+  arma::Mat<size_t> counts(leafCount, numClasses);
+  counts.zeros();
+
+  for (size_t col = 0; col < data.n_cols; ++col)
+  {
+    const arma::vec point = data.unsafe_col(col);
+    const int leaf = dtree->FindBucket(point);
+    counts(leaf, labels[col]) += 1;
+  }
+
+  // Without a filename the table goes to the info log.
+  if (leafClassMembershipFile == "")
+  {
+    Log::Info << "Leaf membership; row represents leaf id, column represents "
+        << "class id; value represents number of points in leaf in class."
+        << std::endl << counts;
+    return;
+  }
+
+  // Otherwise write it to the file, warning if the file cannot be opened.
+  // The stream closes itself when it goes out of scope.
+  std::ofstream outfile(leafClassMembershipFile.c_str());
+  if (!outfile.good())
+  {
+    Log::Warn << "Can't open '" << leafClassMembershipFile << "' to write "
+        << "leaf membership to." << std::endl;
+    return;
+  }
+
+  outfile << counts;
+  Log::Info << "Leaf membership printed to '" << leafClassMembershipFile
+      << "'." << std::endl;
+}
+
+
+/**
+ * Compute the variable importance of the tree, log its maximum, and print the
+ * importance vector to the info log or to the given file.
+ */
+void mlpack::det::PrintVariableImportance(const DTree* dtree,
+                                          const std::string viFile)
+{
+  arma::vec importances;
+  dtree->ComputeVariableImportance(importances);
+
+  // Running maximum; it starts at 0.0, so the reported value is never negative.
+  double largest = 0.0;
+  for (size_t d = 0; d < importances.n_elem; ++d)
+    largest = (importances[d] > largest) ? importances[d] : largest;
+
+  Log::Info << "Maximum variable importance: " << largest << "." << std::endl;
+
+  // Without a filename the vector goes to the info log (as a row).
+  if (viFile == "")
+  {
+    Log::Info << "Variable importance: " << std::endl << importances.t()
+        << std::endl;
+    return;
+  }
+
+  // Otherwise write it to the file, warning if the file cannot be opened.
+  // The stream closes itself when it goes out of scope.
+  std::ofstream outfile(viFile.c_str());
+  if (!outfile.good())
+  {
+    Log::Warn << "Can't open '" << viFile << "' to write variable importance "
+        << "to." << std::endl;
+    return;
+  }
+
+  outfile << importances;
+  Log::Info << "Variable importance printed to '" << viFile << "'."
+      << std::endl;
+}
+
+
+// Sum the density estimates of the given tree over all columns of 'points'.
+static double SumDensities(DTree* tree, const arma::mat& points)
+{
+  double total = 0.0;
+  for (size_t i = 0; i < points.n_cols; ++i)
+  {
+    arma::vec point = points.unsafe_col(i);
+    total += tree->ComputeValue(point);
+  }
+
+  return total;
+}
+
+/**
+ * Train a density estimation tree whose pruning level (alpha) is chosen by
+ * cross-validation.
+ *
+ * The procedure is:
+ *  1. Grow a tree on the full dataset and prune it step by step, recording an
+ *     (alpha, log-negative-error) pair for every tree in the sequence.
+ *  2. For each fold, grow a tree on the training part and, for each recorded
+ *     alpha, fold the held-out density estimates into the recorded value.
+ *  3. Pick the alpha with the smallest recorded value, grow a fresh tree on
+ *     the full dataset, and prune it until that alpha is reached.
+ *
+ * @param dataset Dataset to build the tree on (one point per column).
+ * @param folds Number of cross-validation folds.  Cross-validation needs
+ *     2 <= folds <= dataset.n_cols; otherwise it is skipped with a warning.
+ * @param useVolumeReg Passed through to DTree::Grow() and PruneAndUpdate().
+ * @param maxLeafSize Passed through to DTree::Grow().
+ * @param minLeafSize Passed through to DTree::Grow().
+ * @param unprunedTreeOutput If non-empty, name of a file to which the unpruned
+ *     tree's density estimate of every training point is written, one per
+ *     line.
+ * @return Newly allocated pruned tree; the caller is responsible for deleting
+ *     it.
+ */
+DTree* mlpack::det::Trainer(arma::mat& dataset,
+                            const size_t folds,
+                            const bool useVolumeReg,
+                            const size_t maxLeafSize,
+                            const size_t minLeafSize,
+                            const std::string unprunedTreeOutput)
+{
+  // Initialize the tree.
+  DTree* dtree = new DTree(dataset);
+
+  // Prepare to grow the tree: start from the identity mapping.
+  arma::Col<size_t> oldFromNew(dataset.n_cols);
+  for (size_t i = 0; i < oldFromNew.n_elem; i++)
+    oldFromNew[i] = i;
+
+  // Grow on a copy, since growing the tree reorders the matrix it is given.
+  arma::mat newDataset(dataset);
+
+  // Growing the tree.
+  double oldAlpha = 0.0;
+  double alpha = dtree->Grow(newDataset, oldFromNew, useVolumeReg, maxLeafSize,
+      minLeafSize);
+
+  Log::Info << dtree->SubtreeLeaves() << " leaf nodes in the tree using full "
+      << "dataset; minimum alpha: " << alpha << "." << std::endl;
+
+  // Compute densities for the training points in the full tree, if we were
+  // asked for this.
+  if (unprunedTreeOutput != "")
+  {
+    // The stream closes itself when it goes out of scope.
+    std::ofstream outfile(unprunedTreeOutput.c_str());
+    if (outfile.good())
+    {
+      for (size_t i = 0; i < dataset.n_cols; ++i)
+      {
+        arma::vec testPoint = dataset.unsafe_col(i);
+        outfile << dtree->ComputeValue(testPoint) << std::endl;
+      }
+    }
+    else
+    {
+      Log::Warn << "Can't open '" << unprunedTreeOutput << "' to write computed"
+          << " densities to." << std::endl;
+    }
+  }
+
+  // Sequentially prune and save the alpha values and the values of c_t^2 * r_t.
+  std::vector<std::pair<double, double> > prunedSequence;
+  while (dtree->SubtreeLeaves() > 1)
+  {
+    std::pair<double, double> treeSeq(oldAlpha,
+        dtree->SubtreeLeavesLogNegError());
+    prunedSequence.push_back(treeSeq);
+    oldAlpha = alpha;
+    alpha = dtree->PruneAndUpdate(oldAlpha, dataset.n_cols, useVolumeReg);
+
+    // Some sanity checks.
+    Log::Assert((alpha < std::numeric_limits<double>::max()) ||
+        (dtree->SubtreeLeaves() == 1));
+    Log::Assert(alpha > oldAlpha);
+    Log::Assert(dtree->SubtreeLeavesLogNegError() < treeSeq.second);
+  }
+
+  // The fully pruned tree is the last entry, so the sequence is never empty.
+  std::pair<double, double> treeSeq(oldAlpha,
+      dtree->SubtreeLeavesLogNegError());
+  prunedSequence.push_back(treeSeq);
+
+  Log::Info << prunedSequence.size() << " trees in the sequence; maximum alpha:"
+      << " " << oldAlpha << "." << std::endl;
+
+  delete dtree;
+
+  // Every fold needs at least one test point and one training point, which
+  // holds exactly when 2 <= folds <= number of points.  Otherwise skip
+  // cross-validation rather than divide by zero or slice an empty range.
+  const bool canCrossValidate = (folds >= 2) && (folds <= dataset.n_cols);
+  if (!canCrossValidate)
+  {
+    Log::Warn << "Cannot cross-validate " << dataset.n_cols << " points with "
+        << folds << " folds; skipping cross-validation." << std::endl;
+  }
+
+  const size_t testSize = canCrossValidate ? (dataset.n_cols / folds) : 0;
+
+  // Go through each fold.
+  for (size_t fold = 0; canCrossValidate && (fold < folds); fold++)
+  {
+    // Break up data into train and test sets.
+    const size_t start = fold * testSize;
+    const size_t end = std::min((fold + 1) * testSize,
+        (size_t) dataset.n_cols);
+
+    arma::mat test = dataset.cols(start, end - 1);
+    arma::mat train(dataset.n_rows, dataset.n_cols - test.n_cols);
+
+    // The training set is everything before the test range followed by
+    // everything after it; either part may be empty.
+    if (start > 0)
+    {
+      train.cols(0, start - 1) = dataset.cols(0, start - 1);
+    }
+    if (end < dataset.n_cols)
+    {
+      train.cols(start, train.n_cols - 1) =
+          dataset.cols(end, dataset.n_cols - 1);
+    }
+
+    // Initialize the tree.
+    DTree* cvDTree = new DTree(train);
+
+    // Getting ready to grow the tree...
+    arma::Col<size_t> cvOldFromNew(train.n_cols);
+    for (size_t i = 0; i < cvOldFromNew.n_elem; i++)
+      cvOldFromNew[i] = i;
+
+    // Grow the tree.
+    cvDTree->Grow(train, cvOldFromNew, useVolumeReg, maxLeafSize, minLeafSize);
+
+    // Sequentially prune with all the values of available alphas and adding
+    // values for test values.  Indices are used instead of iterators so that
+    // a one-entry sequence never forms an iterator before begin().
+    size_t seq = 0;
+    for (; seq + 2 < prunedSequence.size(); ++seq)
+    {
+      // Compute test values for this state of the tree.
+      const double cvVal = SumDensities(cvDTree, test);
+
+      // Update the cv error value by mapping out of log-space then back into
+      // it, using long doubles.
+      const long double notLogVal =
+          -std::exp((long double) prunedSequence[seq].second) -
+          2.0 * cvVal / (double) dataset.n_cols;
+      prunedSequence[seq].second = (double) std::log(-notLogVal);
+
+      // Determine the new alpha value (geometric mean of the next two recorded
+      // alphas) and prune accordingly.
+      const double cvAlpha = std::sqrt(prunedSequence[seq + 1].first *
+          prunedSequence[seq + 2].first);
+      cvDTree->PruneAndUpdate(cvAlpha, train.n_cols, useVolumeReg);
+    }
+
+    // Compute test values for the final state of the tree and apply the same
+    // update as inside the loop.
+    const double cvVal = SumDensities(cvDTree, test);
+    const long double notLogVal =
+        -std::exp((long double) prunedSequence[seq].second) -
+        2.0 * cvVal / (double) dataset.n_cols;
+    prunedSequence[seq].second = (double) std::log(-notLogVal);
+
+    delete cvDTree;
+  }
+
+  double optimalAlpha = -1.0;
+  double cvBestError = std::numeric_limits<double>::max();
+  std::vector<std::pair<double, double> >::iterator it;
+
+  // The last entry (the fully pruned tree) is not a candidate.
+  for (it = prunedSequence.begin(); it < prunedSequence.end() - 1; ++it)
+  {
+    if (it->second < cvBestError)
+    {
+      cvBestError = it->second;
+      optimalAlpha = it->first;
+    }
+  }
+
+  Log::Info << "Optimal alpha: " << optimalAlpha << "." << std::endl;
+
+  // Initialize the tree.
+  DTree* dtreeOpt = new DTree(dataset);
+
+  // Getting ready to grow the tree...
+  for (size_t i = 0; i < oldFromNew.n_elem; i++)
+    oldFromNew[i] = i;
+
+  // Work on a fresh copy again, since growing reorders the matrix.
+  newDataset = dataset;
+
+  // Grow the tree.
+  oldAlpha = 0.0;
+  alpha = dtreeOpt->Grow(newDataset, oldFromNew, useVolumeReg, maxLeafSize,
+      minLeafSize);
+
+  // Prune with optimal alpha.  Alpha grows with every pruning step, so keep
+  // pruning until the alpha last pruned with reaches the chosen value.
+  while ((oldAlpha < optimalAlpha) && (dtreeOpt->SubtreeLeaves() > 1))
+  {
+    oldAlpha = alpha;
+    alpha = dtreeOpt->PruneAndUpdate(oldAlpha, newDataset.n_cols, useVolumeReg);
+
+    // Some sanity checks.
+    Log::Assert((alpha < std::numeric_limits<double>::max()) ||
+        (dtreeOpt->SubtreeLeaves() == 1));
+    Log::Assert(alpha > oldAlpha);
+  }
+
+  Log::Info << dtreeOpt->SubtreeLeaves() << " leaf nodes in the optimally "
+      << "pruned tree; optimal alpha: " << oldAlpha << "." << std::endl;
+
+  return dtreeOpt;
+}

Modified: mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp	2012-08-01 20:47:08 UTC (rev 13309)
+++ mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp	2012-08-02 00:52:39 UTC (rev 13310)
@@ -16,298 +16,52 @@
 namespace mlpack {
 namespace det {
 
-template<typename eT>
-void PrintLeafMembership(DTree<eT>* dtree,
-                         const arma::Mat<eT>& data,
+/**
+ * Print the membership of leaves of a density estimation tree given the labels
+ * and number of classes.  Optionally, pass the name of a file to print this
+ * information to (otherwise stdout is used).
+ *
+ * @param dtree Tree to print membership of.
+ * @param data Dataset tree is built upon.
+ * @param labels Class labels of dataset.
+ * @param numClasses Number of classes in dataset.
+ * @param leafClassMembershipFile Name of file to print to (optional).
+ */
+void PrintLeafMembership(DTree* dtree,
+                         const arma::mat& data,
                          const arma::Mat<size_t>& labels,
                          const size_t numClasses,
-                         const std::string leafClassMembershipFile = "")
-{
-  // Tag the leaves with numbers.
-  int numLeaves = dtree->TagTree();
+                         const std::string leafClassMembershipFile = "");
 
-  arma::Mat<size_t> table(numLeaves, numClasses);
-  table.zeros();
+/**
+ * Print the variable importance of each dimension of a density estimation tree.
+ * Optionally, pass the name of a file to print this information to (otherwise
+ * stdout is used).
+ *
+ * @param dtree Density tree to use.
+ * @param viFile Name of file to print to (optional).
+ */
+void PrintVariableImportance(const DTree* dtree,
+                             const std::string viFile = "");
 
-  for (size_t i = 0; i < data.n_cols; i++)
-  {
-    const arma::Col<eT> test_p = data.unsafe_col(i);
-    const int leafTag = dtree->FindBucket(test_p);
-    const size_t label = labels[i];
-    table(leafTag, label) += 1;
-  }
+/**
+ * Train the optimal decision tree using cross-validation with the given number
+ * of folds.  Optionally, write the unpruned tree's density estimates to a file.
+ *
+ * @param dataset Dataset for the tree to use.
+ * @param folds Number of folds to use for cross-validation.
+ * @param useVolumeReg If true, use volume regularization.
+ * @param maxLeafSize Maximum number of points allowed in a leaf.
+ * @param minLeafSize Minimum number of points allowed in a leaf.
+ * @param unprunedTreeOutput File for unpruned-tree density output (optional).
+ */
+DTree* Trainer(arma::mat& dataset,
+               const size_t folds,
+               const bool useVolumeReg = false,
+               const size_t maxLeafSize = 10,
+               const size_t minLeafSize = 5,
+               const std::string unprunedTreeOutput = "");
 
-  if (leafClassMembershipFile == "")
-  {
-    Log::Info << "Leaf membership; row represents leaf id, column represents "
-        << "class id; value represents number of points in leaf in class."
-        << std::endl << table;
-  }
-  else
-  {
-    // Create a stream for the file.
-    std::ofstream outfile(leafClassMembershipFile.c_str());
-    if (outfile.good())
-    {
-      outfile << table;
-      Log::Info << "Leaf membership printed to '" << leafClassMembershipFile
-          << "'." << std::endl;
-    }
-    else
-    {
-      Log::Warn << "Can't open '" << leafClassMembershipFile << "' to write "
-          << "leaf membership to." << std::endl;
-    }
-    outfile.close();
-  }
-
-  return;
-}
-
-
-template<typename eT>
-void PrintVariableImportance(const DTree<eT>* dtree,
-                             const std::string viFile = "")
-{
-  arma::vec imps;
-  dtree->ComputeVariableImportance(imps);
-
-  double max = 0.0;
-  for (size_t i = 0; i < imps.n_elem; ++i)
-    if (imps[i] > max)
-      max = imps[i];
-
-  Log::Info << "Maximum variable importance: " << max << "." << std::endl;
-
-  if (viFile == "")
-  {
-    Log::Info << "Variable importance: " << std::endl << imps.t() << std::endl;
-  }
-  else
-  {
-    std::ofstream outfile(viFile.c_str());
-    if (outfile.good())
-    {
-      outfile << imps;
-      Log::Info << "Variable importance printed to '" << viFile << "'."
-          << std::endl;
-    }
-    else
-    {
-      Log::Warn << "Can't open '" << viFile << "' to write variable importance "
-          << "to." << std::endl;
-    }
-    outfile.close();
-  }
-}
-
-
-// This function trains the optimal decision tree using the given number of
-// folds.
-template<typename eT>
-DTree<eT>* Trainer(arma::Mat<eT>& dataset,
-                   const size_t folds,
-                   const bool useVolumeReg = false,
-                   const size_t maxLeafSize = 10,
-                   const size_t minLeafSize = 5,
-                   const std::string unprunedTreeOutput = "")
-{
-  // Initialize the tree.
-  DTree<eT>* dtree = new DTree<eT>(dataset);
-
-  // Prepare to grow the tree...
-  arma::Col<size_t> oldFromNew(dataset.n_cols);
-  for (size_t i = 0; i < oldFromNew.n_elem; i++)
-    oldFromNew[i] = i;
-
-  // Save the dataset since it would be modified while growing the tree.
-  arma::Mat<eT> newDataset(dataset);
-
-  // Growing the tree
-  double oldAlpha = 0.0;
-  double alpha = dtree->Grow(newDataset, oldFromNew, useVolumeReg, maxLeafSize,
-      minLeafSize);
-
-  Log::Info << dtree->SubtreeLeaves() << " leaf nodes in the tree using full "
-      << "dataset; minimum alpha: " << alpha << "." << std::endl;
-
-  // Compute densities for the training points in the full tree, if we were
-  // asked for this.
-  if (unprunedTreeOutput != "")
-  {
-    std::ofstream outfile(unprunedTreeOutput.c_str());
-    if (outfile.good())
-    {
-      for (size_t i = 0; i < dataset.n_cols; ++i)
-      {
-        arma::Col<eT> test_p = dataset.unsafe_col(i);
-        outfile << dtree->ComputeValue(test_p) << std::endl;
-      }
-    }
-    else
-    {
-      Log::Warn << "Can't open '" << unprunedTreeOutput << "' to write computed"
-          << " densities to." << std::endl;
-    }
-
-    outfile.close();
-  }
-
-  // Sequentially prune and save the alpha values and the values of c_t^2 * r_t.
-  std::vector<std::pair<double, double> > prunedSequence;
-  while (dtree->SubtreeLeaves() > 1)
-  {
-    std::pair<double, double> treeSeq(oldAlpha,
-        dtree->SubtreeLeavesLogNegError());
-    prunedSequence.push_back(treeSeq);
-    oldAlpha = alpha;
-    alpha = dtree->PruneAndUpdate(oldAlpha, dataset.n_cols, useVolumeReg);
-
-    // Some sanity checks.
-    Log::Assert((alpha < std::numeric_limits<double>::max()) ||
-        (dtree->SubtreeLeaves() == 1));
-    Log::Assert(alpha > oldAlpha);
-    Log::Assert(dtree->SubtreeLeavesLogNegError() < treeSeq.second);
-  }
-
-  std::pair<double, double> treeSeq(oldAlpha,
-      dtree->SubtreeLeavesLogNegError());
-  prunedSequence.push_back(treeSeq);
-
-  Log::Info << prunedSequence.size() << " trees in the sequence; maximum alpha:"
-      << " " << oldAlpha << "." << std::endl;
-
-  delete dtree;
-
-  arma::Mat<eT> cvdata(dataset);
-  size_t testSize = dataset.n_cols / folds;
-
-  // Go through each fold.
-  for (size_t fold = 0; fold < folds; fold++)
-  {
-    // Break up data into train and test sets.
-    size_t start = fold * testSize;
-    size_t end = std::min((fold + 1) * testSize, (size_t) cvdata.n_cols);
-
-    arma::Mat<eT> test = cvdata.cols(start, end - 1);
-    arma::Mat<eT> train(cvdata.n_rows, cvdata.n_cols - test.n_cols);
-
-    if (start == 0 && end < cvdata.n_cols)
-    {
-      train.cols(0, train.n_cols - 1) = cvdata.cols(end, cvdata.n_cols - 1);
-    }
-    else if (start > 0 && end == cvdata.n_cols)
-    {
-      train.cols(0, train.n_cols - 1) = cvdata.cols(0, start - 1);
-    }
-    else
-    {
-      train.cols(0, start - 1) = cvdata.cols(0, start - 1);
-      train.cols(start, train.n_cols - 1) = cvdata.cols(end, cvdata.n_cols - 1);
-    }
-
-    // Initialize the tree.
-    DTree<eT>* cvDTree = new DTree<eT>(train);
-
-    // Getting ready to grow the tree...
-    arma::Col<size_t> cvOldFromNew(train.n_cols);
-    for (size_t i = 0; i < cvOldFromNew.n_elem; i++)
-      cvOldFromNew[i] = i;
-
-    // Grow the tree.
-    oldAlpha = 0.0;
-    alpha = cvDTree->Grow(train, cvOldFromNew, useVolumeReg, maxLeafSize,
-        minLeafSize);
-
-    // Sequentially prune with all the values of available alphas and adding
-    // values for test values.
-    std::vector<std::pair<double, double> >::iterator it;
-    for (it = prunedSequence.begin(); it < prunedSequence.end() - 2; ++it)
-    {
-      // Compute test values for this state of the tree.
-      double cvVal = 0.0;
-      for (size_t i = 0; i < test.n_cols; i++)
-      {
-        arma::Col<eT> testPoint = test.unsafe_col(i);
-        cvVal += cvDTree->ComputeValue(testPoint);
-      }
-
-      // Update the cv error value by mapping out of log-space then back into
-      // it, using long doubles.
-      long double notLogVal = -std::exp((long double) it->second) -
-          2.0 * cvVal / (double) dataset.n_cols;
-      it->second = (double) std::log(-notLogVal);
-
-      // Determine the new alpha value and prune accordingly.
-      oldAlpha = sqrt(((it + 1)->first) * ((it + 2)->first));
-      alpha = cvDTree->PruneAndUpdate(oldAlpha, train.n_cols, useVolumeReg);
-    }
-
-    // Compute test values for this state of the tree.
-    double cvVal = 0.0;
-    for (size_t i = 0; i < test.n_cols; ++i)
-    {
-      arma::Col<eT> testPoint = test.unsafe_col(i);
-      cvVal += cvDTree->ComputeValue(testPoint);
-    }
-
-    // Update the cv error value.
-    long double notLogVal = -std::exp((long double) it->second) -
-        2.0 * cvVal / (double) dataset.n_cols;
-    it->second -= (double) std::log(-notLogVal);
-
-    test.reset();
-    delete cvDTree;
-  }
-
-  double optimalAlpha = -1.0;
-  double cvBestError = std::numeric_limits<double>::max();
-  std::vector<std::pair<double, double> >::iterator it;
-
-  for (it = prunedSequence.begin(); it < prunedSequence.end() -1; ++it)
-  {
-    if (it->second < cvBestError)
-    {
-      cvBestError = it->second;
-      optimalAlpha = it->first;
-    }
-  }
-
-  Log::Info << "Optimal alpha: " << optimalAlpha << "." << std::endl;
-
-  // Initialize the tree.
-  DTree<eT>* dtreeOpt = new DTree<eT>(dataset);
-
-  // Getting ready to grow the tree...
-  for (size_t i = 0; i < oldFromNew.n_elem; i++)
-    oldFromNew[i] = i;
-
-  // Save the dataset since it would be modified while growing the tree.
-  newDataset = dataset;
-
-  // Grow the tree.
-  oldAlpha = 0.0;
-  alpha = dtreeOpt->Grow(newDataset, oldFromNew, useVolumeReg, maxLeafSize,
-      minLeafSize);
-
-  // Prune with optimal alpha.
-  while ((oldAlpha > optimalAlpha) && (dtreeOpt->SubtreeLeaves() > 1))
-  {
-    oldAlpha = alpha;
-    alpha = dtreeOpt->PruneAndUpdate(oldAlpha, newDataset.n_cols, useVolumeReg);
-
-    // Some sanity checks.
-    Log::Assert((alpha < std::numeric_limits<double>::max()) ||
-        (dtreeOpt->SubtreeLeaves() == 1));
-    Log::Assert(alpha < oldAlpha);
-  }
-
-  Log::Info << dtreeOpt->SubtreeLeaves() << " leaf nodes in the optimally "
-      << "pruned tree; optimal alpha: " << oldAlpha << "." << std::endl;
-
-  return dtreeOpt;
-}
-
 }; // namespace det
 }; // namespace mlpack
 

Copied: mlpack/trunk/src/mlpack/methods/det/dtree.cpp (from rev 13309, mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp)
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree.cpp	                        (rev 0)
+++ mlpack/trunk/src/mlpack/methods/det/dtree.cpp	2012-08-02 00:52:39 UTC (rev 13310)
@@ -0,0 +1,660 @@
+ /**
+ * @file dtree.cpp
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ *
+ * Implementations of some declared functions in
+ * the Density Estimation Tree class.
+ *
+ */
+#include "dtree.hpp"
+#include <stack>
+
+using namespace mlpack;
+using namespace det;
+
+DTree::DTree() : // Default constructor: a childless root node with no points.
+    start(0),
+    end(0), // start == end, so the node covers an empty range of points.
+    logNegError(-DBL_MAX), // Most negative finite double.
+    root(true),
+    bucketTag(-1), // NOTE(review): -1 presumably means "no tag yet"; confirm.
+    left(NULL),
+    right(NULL)
+{ /* Nothing to do. */ }
+
+
+// Root node initializer: the bounding box is given explicitly by the caller.
+DTree::DTree(const arma::vec& maxVals,
+             const arma::vec& minVals,
+             const size_t totalPoints) :
+    start(0),
+    end(totalPoints), // The root covers all points: [0, totalPoints).
+    maxVals(maxVals),
+    minVals(minVals),
+    logNegError(LogNegativeError(totalPoints)), // Reads the members above.
+    root(true),
+    bucketTag(-1), // NOTE(review): -1 presumably means "no tag yet"; confirm.
+    left(NULL),
+    right(NULL)
+{ /* Nothing to do. */ }
+
+// Root node initializer that derives the bounding box from the data itself.
+DTree::DTree(arma::mat& data) :
+    start(0),
+    end(data.n_cols),
+    root(true),
+    bucketTag(-1),
+    left(NULL),
+    right(NULL)
+{
+  // Seed both bounds with the first point; assignment sizes the vectors.
+  maxVals = data.col(0);
+  minVals = data.col(0);
+
+  // Widen the bounds, point by point, over the remaining columns.
+  for (size_t point = 1; point < data.n_cols; ++point)
+  {
+    for (size_t dim = 0; dim < data.n_rows; ++dim)
+    {
+      const double value = data(dim, point);
+
+      if (value > maxVals[dim])
+        maxVals[dim] = value;
+      if (value < minVals[dim])
+        minVals[dim] = value;
+    }
+  }
+
+  // The error depends on the bounds, so it is computed last.
+  logNegError = LogNegativeError(data.n_cols);
+}
+
+
+// Non-root node initializer: the node's error is supplied by the caller.
+DTree::DTree(const arma::vec& maxVals,
+             const arma::vec& minVals,
+             const size_t start,
+             const size_t end,
+             const double logNegError) :
+    start(start),
+    end(end), // The node covers the points in [start, end).
+    maxVals(maxVals),
+    minVals(minVals),
+    logNegError(logNegError),
+    root(false),
+    bucketTag(-1), // NOTE(review): -1 presumably means "no tag yet"; confirm.
+    left(NULL),
+    right(NULL)
+{ /* Nothing to do. */ }
+
+DTree::DTree(const arma::vec& maxVals, // Non-root node initializer.
+             const arma::vec& minVals,
+             const size_t totalPoints, // Passed to LogNegativeError() below.
+             const size_t start,
+             const size_t end) :
+    start(start),
+    end(end), // The node covers the points in [start, end).
+    maxVals(maxVals),
+    minVals(minVals),
+    logNegError(LogNegativeError(totalPoints)), // Reads the members above.
+    root(false),
+    bucketTag(-1), // NOTE(review): -1 presumably means "no tag yet"; confirm.
+    left(NULL),
+    right(NULL)
+{ /* Nothing to do. */ }
+
+// Destroy this node and, recursively, both of its subtrees.
+DTree::~DTree()
+{
+  // Deleting a null pointer does nothing, so no explicit checks are needed.
+  delete left;
+  delete right;
+}
+
+// This function computes the log-l2-negative-error of a given node from the
+// formula R(t) = log(|t|^2 / (N^2 V_t)).
+double DTree::LogNegativeError(const size_t totalPoints) const
+{
+  // log(-|t|^2 / (N^2 V_t)) = log(-1) + 2 log(|t|) - 2 log(N) - log(V_t).
+  return 2 * std::log((double) (end - start)) -
+         2 * std::log((double) totalPoints) -
+         arma::accu(arma::log(maxVals - minVals));
+}
+
+/**
+ * Find the best split of this node with respect to the L2 error, by trying
+ * the midpoint between every pair of consecutive sorted values in every
+ * dimension.  The dataset is the full data set, but only the columns in
+ * [start, end) belong to this node.
+ *
+ * @param data Full dataset (one point per column).
+ * @param splitDim Set to the dimension of the best split.
+ * @param splitValue Set to the value to split on in that dimension.
+ * @param leftError Set to the log-negative-error of the left child.
+ * @param rightError Set to the log-negative-error of the right child.
+ * @param maxLeafSize Maximum leaf size; the node is asserted to hold more
+ *     points than this.
+ * @param minLeafSize Minimum number of points each child must receive.  A
+ *     value of 0 is treated as 1, since each child needs at least one point.
+ * @return True if a split improving on this node's error was found; the output
+ *     parameters are only written in that case.
+ */
+bool DTree::FindSplit(const arma::mat& data,
+                      size_t& splitDim,
+                      double& splitValue,
+                      double& leftError,
+                      double& rightError,
+                      const size_t maxLeafSize,
+                      const size_t minLeafSize) const
+{
+  // Ensure the dimensionality of the data is the same as the dimensionality of
+  // the bounding rectangle.
+  assert(data.n_rows == maxVals.n_elem);
+  assert(data.n_rows == minVals.n_elem);
+
+  const size_t points = end - start;
+
+  // Each child needs at least one point.  The loop bounds below are written
+  // with additions only, so unsigned arithmetic cannot wrap around when
+  // minLeafSize is 0 or exceeds the number of points in this node.
+  const size_t minChildSize = (minLeafSize > 0) ? minLeafSize : 1;
+
+  // Errors are stored negated, so despite the name this tracks the largest
+  // log-negative-error seen so far; a split must beat this node's own value.
+  double minError = logNegError;
+  bool splitFound = false;
+
+  // Loop through each dimension.
+  for (size_t dim = 0; dim < maxVals.n_elem; dim++)
+  {
+    // Have to deal with REAL, INTEGER, NOMINAL data differently, so we have to
+    // think of how to do that...
+    const double min = minVals[dim];
+    const double max = maxVals[dim];
+
+    // If there is nothing to split in this dimension, move on.
+    if (max - min == 0.0)
+      continue; // Skip to next dimension.
+
+    // Initializing all the stuff for this dimension.
+    bool dimSplitFound = false;
+    // Take an error estimate for this dimension.
+    double minDimError = std::pow(points, 2.0) / (max - min);
+    // These are only read when dimSplitFound is true.
+    double dimLeftError = 0.0;
+    double dimRightError = 0.0;
+    double dimSplitValue = 0.0;
+
+    // Find the log volume of all the other dimensions.
+    double volumeWithoutDim = logVolume - std::log(max - min);
+
+    // Get the values for the dimension.
+    arma::rowvec dimVec = data.row(dim).subvec(start, end - 1);
+
+    // Sort the values in ascending order.
+    dimVec = arma::sort(dimVec);
+
+    // Get ready to go through the sorted list and compute error.
+    assert(dimVec.n_elem > maxLeafSize);
+
+    // Find the best split for this dimension.  We need to figure out why
+    // there are spikes if this minLeafSize is enforced here...  Splitting
+    // after index i puts i + 1 points on the left, so both children get at
+    // least minChildSize points and dimVec[i + 1] is always in range.
+    for (size_t i = minChildSize - 1; i + minChildSize < dimVec.n_elem; ++i)
+    {
+      // This makes sense for real continuous data.  This kinda corrupts the
+      // data and estimation if the data is ordinal.
+      const double split = (dimVec[i] + dimVec[i + 1]) / 2.0;
+
+      if (split == dimVec[i])
+        continue; // We can't split here (two points are the same).
+
+      // Another way of picking split is using this:
+      //   split = leftsplit;
+      if ((split - min > 0.0) && (max - split > 0.0))
+      {
+        // Ensure that the right node will have at least the minimum number of
+        // points.
+        Log::Assert((points - i - 1) >= minLeafSize);
+
+        // Now we have to see if the error will be reduced.  Simple manipulation
+        // of the error function gives us the condition we must satisfy:
+        //   |t_l|^2 / V_l + |t_r|^2 / V_r  >= |t|^2 / (V_l + V_r)
+        // and because the volume is only dependent on the dimension we are
+        // splitting, we can assume V_l is just the range of the left and V_r is
+        // just the range of the right.
+        double negLeftError = std::pow(i + 1, 2.0) / (split - min);
+        double negRightError = std::pow(points - i - 1, 2.0) / (max - split);
+
+        // If this is better, take it.
+        if ((negLeftError + negRightError) >= minDimError)
+        {
+          minDimError = negLeftError + negRightError;
+          dimLeftError = negLeftError;
+          dimRightError = negRightError;
+          dimSplitValue = split;
+          dimSplitFound = true;
+        }
+      }
+    }
+
+    // Calculate the actual error (in logspace) by adding the terms that were
+    // left out of the per-dimension estimate back in.
+    double actualMinDimError = std::log(minDimError) - 2 * std::log(data.n_cols)
+        - volumeWithoutDim;
+
+    if ((actualMinDimError > minError) && dimSplitFound)
+    {
+      minError = actualMinDimError;
+      splitDim = dim;
+      splitValue = dimSplitValue;
+      leftError = std::log(dimLeftError) - 2 * std::log(data.n_cols) -
+          volumeWithoutDim;
+      rightError = std::log(dimRightError) - 2 * std::log(data.n_cols) -
+          volumeWithoutDim;
+      splitFound = true;
+    } // end if better split found in this dimension.
+  }
+
+  return splitFound;
+}
+
+// Rearrange the columns in [start, end) of the data so that every column whose
+// value in dimension splitDim is less than or equal to splitValue comes before
+// every column whose value is greater, keeping oldFromNew in step with each
+// swap.  Returns the index of the first column of the "greater" group.  A
+// similar sort is also performed in BinarySpaceTree construction (its comments
+// are more detailed).
+size_t DTree::SplitData(arma::mat& data,
+                        const size_t splitDim,
+                        const double splitValue,
+                        arma::Col<size_t>& oldFromNew) const
+{
+  // Two cursors moving towards each other from the ends of the node's range.
+  size_t lo = start;
+  size_t hi = end - 1;
+
+  while (true)
+  {
+    // Skip over the columns that are already on the correct side.
+    while (data(splitDim, lo) <= splitValue)
+      ++lo;
+    while (data(splitDim, hi) > splitValue)
+      --hi;
+
+    // Once the cursors have crossed, lo is the first index of the right side.
+    if (lo > hi)
+      return lo;
+
+    data.swap_cols(lo, hi);
+
+    // Apply the same swap to the index mapping.
+    const size_t displaced = oldFromNew[lo];
+    oldFromNew[lo] = oldFromNew[hi];
+    oldFromNew[hi] = displaced;
+  }
+}
+
+// Greedily expand the tree rooted at this node.
+//
+// Records the fraction of points held by this node (ratio) and the log of its
+// volume (logVolume).  If the node holds more than maxLeafSize points and
+// FindSplit() reports a split, the points are partitioned with SplitData(),
+// the two children are created and grown recursively, and the subtree
+// statistics of this node are derived from them.  Otherwise the node becomes
+// a leaf.
+//
+// @param data Dataset; columns in [start, end) are reordered by SplitData().
+// @param oldFromNew Index mapping permuted together with the columns.
+// @param useVolReg Selects the volume-regularized g value (flagged below as
+//     not yet correct) instead of the leaf-count-based one.
+// @param maxLeafSize Nodes holding more points than this are split if
+//     FindSplit() finds a split.
+// @param minLeafSize Forwarded to FindSplit(); a node that is not large
+//     enough to split is asserted to hold at least this many points.
+// @return numeric_limits<double>::max() for a leaf; otherwise the minimum of
+//     this node's g value and the values returned for the two children.
+double DTree::Grow(arma::mat& data,
+                   arma::Col<size_t>& oldFromNew,
+                   const bool useVolReg,
+                   const size_t maxLeafSize,
+                   const size_t minLeafSize)
+{
+  assert(data.n_rows == maxVals.n_elem);
+  assert(data.n_rows == minVals.n_elem);
+
+  // Only assigned (and only read) when this node is split.
+  double leftG, rightG;
+
+  // Compute points ratio.
+  ratio = (double) (end - start) / (double) oldFromNew.n_elem;
+
+  // Compute the log of the volume of the node.  Dimensions with zero extent
+  // are skipped so that they do not contribute log(0).
+  logVolume = 0;
+  for (size_t i = 0; i < maxVals.n_elem; ++i)
+    if (maxVals[i] - minVals[i] > 0.0)
+      logVolume += std::log(maxVals[i] - minVals[i]);
+
+  // Check if node is large enough to split.
+  if ((size_t) (end - start) > maxLeafSize) {
+
+    // Find the split.
+    size_t dim;
+    double splitValueTmp;
+    double leftError, rightError;
+    if (FindSplit(data, dim, splitValueTmp, leftError, rightError, maxLeafSize,
+        minLeafSize))
+    {
+      // Move the data around for the children to have points in a node lie
+      // contiguously (to increase efficiency during the training).
+      const size_t splitIndex = SplitData(data, dim, splitValueTmp, oldFromNew);
+
+      // Make max and min vals for the children.  They only differ from this
+      // node's bounds in the split dimension.
+      arma::vec maxValsL(maxVals);
+      arma::vec maxValsR(maxVals);
+      arma::vec minValsL(minVals);
+      arma::vec minValsR(minVals);
+
+      maxValsL[dim] = splitValueTmp;
+      minValsR[dim] = splitValueTmp;
+
+      // Store split dim and split val in the node.
+      splitValue = splitValueTmp;
+      splitDim = dim;
+
+      // Recursively grow the children.
+      left = new DTree(maxValsL, minValsL, start, splitIndex, leftError);
+      right = new DTree(maxValsR, minValsR, splitIndex, end, rightError);
+
+      leftG = left->Grow(data, oldFromNew, useVolReg, maxLeafSize,
+          minLeafSize);
+      rightG = right->Grow(data, oldFromNew, useVolReg, maxLeafSize,
+          minLeafSize);
+
+      // Store values of R(T~) and |T~|.
+      subtreeLeaves = left->SubtreeLeaves() + right->SubtreeLeaves();
+
+      // Find the log negative error of the subtree leaves.  This is kind of an
+      // odd one because we don't want to represent the error in non-log-space,
+      // but we have to calculate log(E_l + E_r).  So we multiply E_l and E_r by
+      // V_t (remember E_l has an inverse relationship to the volume of the
+      // nodes) and then subtract log(V_t) at the end of the whole expression.
+      // As a result we do leave log-space, but the largest quantity we
+      // represent is on the order of (V_t / V_i) where V_i is the smallest leaf
+      // node below this node, which depends heavily on the depth of the tree.
+      //
+      // Each child term is exponentiated separately and the results are
+      // summed; exponentiating the sum of the two logs would instead yield the
+      // log of the product E_l * E_r.
+      subtreeLeavesLogNegError = std::log(
+          std::exp(logVolume + left->SubtreeLeavesLogNegError()) +
+          std::exp(logVolume + right->SubtreeLeavesLogNegError())) - logVolume;
+    }
+    else
+    {
+      // No split found so make a leaf out of it.
+      subtreeLeaves = 1;
+      subtreeLeavesLogNegError = logNegError;
+    }
+  }
+  else
+  {
+    // We can make this a leaf node.
+    assert((size_t) (end - start) >= minLeafSize);
+    subtreeLeaves = 1;
+    subtreeLeavesLogNegError = logNegError;
+  }
+
+  // If this is a leaf, do not compute g_k(t); otherwise compute, store, and
+  // propagate min(g_k(t_L), g_k(t_R), g_k(t)), unless t_L and/or t_R are
+  // leaves.
+  if (subtreeLeaves == 1)
+  {
+    return std::numeric_limits<double>::max();
+  }
+  else
+  {
+    // Fractions of this node's extent (in the split dimension) that go to the
+    // left and right children.
+    const double range = maxVals[splitDim] - minVals[splitDim];
+    const double leftRatio = (splitValue - minVals[splitDim]) / range;
+    const double rightRatio = (maxVals[splitDim] - splitValue) / range;
+
+    // Squared point counts of the children and of this node.
+    const size_t leftPow = std::pow(left->End() - left->Start(), 2);
+    const size_t rightPow = std::pow(right->End() - right->Start(), 2);
+    const size_t thisPow = std::pow(end - start, 2);
+
+    double tmpAlphaSum = leftPow / leftRatio + rightPow / rightRatio - thisPow;
+
+    // Children that are not leaves contribute their own alphaUpper, rescaled
+    // by N^2 and this node's volume so that it can be added outside log-space.
+    if (left->SubtreeLeaves() > 1)
+    {
+      const double exponent = 2 * std::log(data.n_cols) + logVolume +
+          left->AlphaUpper();
+
+      // Whether or not this will overflow is highly dependent on the depth of
+      // the tree.
+      tmpAlphaSum += std::exp(exponent);
+    }
+
+    if (right->SubtreeLeaves() > 1)
+    {
+      const double exponent = 2 * std::log(data.n_cols) + logVolume +
+          right->AlphaUpper();
+
+      tmpAlphaSum += std::exp(exponent);
+    }
+
+    // Back to log-space, undoing the rescaling applied above.
+    alphaUpper = std::log(tmpAlphaSum) - 2 * std::log(data.n_cols) - logVolume;
+
+    double gT;
+    if (useVolReg)
+    {
+      // This is wrong for now!
+      gT = alphaUpper;// / (subtreeLeavesVTInv - vTInv);
+    }
+    else
+    {
+      gT = alphaUpper - std::log(subtreeLeaves - 1);
+    }
+
+    return std::min(gT, std::min(leftG, rightG));
+  }
+
+  // We need to compute (c_t^2) * r_t for all subtree leaves; this is equal to
+  // n_t ^ 2 / r_t * n ^ 2 = -error.  Therefore the value we need is actually
+  // -1.0 * subtreeLeavesError.
+}
+
+
+// Prune the subtree rooted at this node against oldAlpha and refresh the
+// statistics of the nodes that remain.
+//
+// For an internal node the g value is computed from alphaUpper.  If it is
+// below oldAlpha the children are processed recursively and subtreeLeaves,
+// subtreeLeavesLogNegError and alphaUpper are recomputed from them; otherwise
+// both children are deleted and this node becomes a leaf.
+//
+// @param oldAlpha Threshold the g value of each node is compared against.
+// @param points Number of points used in the 2 * log(points) scaling terms;
+//     forwarded unchanged to the children.
+// @param useVolReg Selects the volume-regularized g value (flagged below as
+//     not yet correct) instead of the leaf-count-based one.
+// @return 0 for a leaf, numeric_limits<double>::max() if this node was just
+//     pruned, otherwise the minimum of this node's updated g value and the
+//     values returned for the two children.
+double DTree::PruneAndUpdate(const double oldAlpha,
+                             const size_t points,
+                             const bool useVolReg)
+
+{
+  // Compute gT.
+  if (subtreeLeaves == 1) // If we are a leaf...
+  {
+    // NOTE(review): Grow() returns numeric_limits<double>::max() for a leaf,
+    // whereas 0 is returned here -- confirm which value is intended.
+    return 0;
+  }
+  else
+  {
+    // Compute gT value for node t.
+    double gT;
+    if (useVolReg)
+      gT = alphaUpper;// - std::log(subtreeLeavesVTInv - vTInv);
+    else
+      gT = alphaUpper - std::log(subtreeLeaves - 1);
+
+    if (gT < oldAlpha)
+    {
+      // Go down the tree and update accordingly.  Traverse the children.  The
+      // point count has to be passed explicitly; leaving it out would bind
+      // useVolReg to the points parameter.
+      double leftG = left->PruneAndUpdate(oldAlpha, points, useVolReg);
+      double rightG = right->PruneAndUpdate(oldAlpha, points, useVolReg);
+
+      // Update values.
+      subtreeLeaves = left->SubtreeLeaves() + right->SubtreeLeaves();
+
+      // Find the log negative error of the subtree leaves.  This is kind of an
+      // odd one because we don't want to represent the error in non-log-space,
+      // but we have to calculate log(E_l + E_r).  So we multiply E_l and E_r by
+      // V_t (remember E_l has an inverse relationship to the volume of the
+      // nodes) and then subtract log(V_t) at the end of the whole expression.
+      // As a result we do leave log-space, but the largest quantity we
+      // represent is on the order of (V_t / V_i) where V_i is the smallest leaf
+      // node below this node, which depends heavily on the depth of the tree.
+      //
+      // Each child term is exponentiated separately and the results are
+      // summed; exponentiating the sum of the two logs would instead yield the
+      // log of the product E_l * E_r.
+      subtreeLeavesLogNegError = std::log(
+          std::exp(logVolume + left->SubtreeLeavesLogNegError()) +
+          std::exp(logVolume + right->SubtreeLeavesLogNegError())) - logVolume;
+
+      // Recalculate upper alpha.
+      const double range = maxVals[splitDim] - minVals[splitDim];
+      const double leftRatio = (splitValue - minVals[splitDim]) / range;
+      const double rightRatio = (maxVals[splitDim] - splitValue) / range;
+
+      // Squared point counts of the children and of this node.
+      const size_t leftPow = std::pow(left->End() - left->Start(), 2);
+      const size_t rightPow = std::pow(right->End() - right->Start(), 2);
+      const size_t thisPow = std::pow(end - start, 2);
+
+      double tmpAlphaSum = leftPow / leftRatio + rightPow / rightRatio -
+          thisPow;
+
+      // Children that are still internal nodes contribute their alphaUpper,
+      // rescaled so that it can be added outside log-space.
+      if (left->SubtreeLeaves() > 1)
+      {
+        const double exponent = 2 * std::log(points) + logVolume +
+            left->AlphaUpper();
+
+        // Whether or not this will overflow is highly dependent on the depth of
+        // the tree.
+        tmpAlphaSum += std::exp(exponent);
+      }
+
+      if (right->SubtreeLeaves() > 1)
+      {
+        const double exponent = 2 * std::log(points) + logVolume +
+            right->AlphaUpper();
+
+        tmpAlphaSum += std::exp(exponent);
+      }
+
+      // Back to log-space, undoing the rescaling applied above.
+      alphaUpper = std::log(tmpAlphaSum) - 2 * std::log(points) - logVolume;
+
+      // Update gT value.
+      if (useVolReg)
+      {
+        // This is incorrect.
+        gT = alphaUpper; // / (subtreeLeavesVTInv - vTInv);
+      }
+      else
+      {
+        gT = alphaUpper - std::log(subtreeLeaves - 1);
+      }
+
+      assert(gT < std::numeric_limits<double>::max());
+
+      return std::min(gT, std::min(leftG, rightG));
+    }
+    else
+    {
+      // Prune this subtree.
+      // First, make this node a leaf node.
+      subtreeLeaves = 1;
+      subtreeLeavesLogNegError = logNegError;
+
+      delete left;
+      delete right;
+
+      // Reset the child pointers so they do not dangle.
+      left = NULL;
+      right = NULL;
+
+      // Pass information upward.
+      return std::numeric_limits<double>::max();
+    }
+  }
+}
+
+// Report whether the query point lies inside this node's bounding box, i.e.
+// whether every coordinate is between minVals and maxVals (bounds included).
+// The check is generally done at the root, where the box is the bounding box
+// of the data.
+//
+// Future improvement: Open up the range with epsilons on both sides where
+// epsilon depends on the density near the boundary.
+bool DTree::WithinRange(const arma::vec& query) const
+{
+  for (size_t d = 0; d < query.n_elem; ++d)
+  {
+    // A single coordinate outside its interval puts the point outside the box.
+    if (query[d] < minVals[d])
+      return false;
+    if (query[d] > maxVals[d])
+      return false;
+  }
+
+  return true;
+}
+
+
+// Evaluate the estimate stored in the tree at the given query point.  Returns
+// 0 when called on the root with a point outside the root's bounding box;
+// otherwise descends to the leaf containing the point and returns
+// exp(log(ratio) - logVolume) for that leaf.
+double DTree::ComputeValue(const arma::vec& query) const
+{
+  Log::Assert(query.n_elem == maxVals.n_elem);
+
+  // The range check is only performed at the root.
+  if ((root == 1) && !WithinRange(query))
+    return 0.0;
+
+  // A leaf holds the value directly.
+  if (subtreeLeaves == 1)
+    return std::exp(std::log(ratio) - logVolume);
+
+  // Values up to and including the split value belong to the left child.
+  const DTree* child = (query[splitDim] <= splitValue) ? left : right;
+  return child->ComputeValue(query);
+}
+
+
+// Write a textual description of the subtree rooted at this node to fp.  A
+// leaf prints its value (and its bucket tag, if one was assigned); an internal
+// node prints its split condition for the right branch followed by the right
+// subtree, then the condition for the left branch followed by the left
+// subtree.  Each condition line is indented with one "|\t" per level.
+void DTree::WriteTree(FILE *fp, const size_t level) const
+{
+  if (subtreeLeaves <= 1) // If we are a leaf...
+  {
+    fprintf(fp, ": f(x)=%lg", std::exp(std::log(ratio) - logVolume));
+    if (bucketTag != -1)
+      fprintf(fp, " BT:%d", bucketTag);
+
+    return;
+  }
+
+  // Right branch: points whose split-dimension value exceeds the split value.
+  fprintf(fp, "\n");
+  for (size_t pad = level; pad > 0; --pad)
+    fprintf(fp, "|\t");
+  fprintf(fp, "Var. %zu > %lg", splitDim, splitValue);
+
+  right->WriteTree(fp, level + 1);
+
+  // Left branch: the remaining points.
+  fprintf(fp, "\n");
+  for (size_t pad = level; pad > 0; --pad)
+    fprintf(fp, "|\t");
+  fprintf(fp, "Var. %zu <= %lg ", splitDim, splitValue);
+
+  left->WriteTree(fp, level);
+}
+
+
+// Index the buckets for possible usage later.  Leaves are numbered
+// consecutively, starting at the given tag, visiting the left subtree before
+// the right one.  Returns the first tag that has not been used.
+int DTree::TagTree(const int tag)
+{
+  if (subtreeLeaves != 1)
+  {
+    // Internal nodes carry no tag; the right subtree continues where the left
+    // subtree stopped.
+    const int nextTag = left->TagTree(tag);
+    return right->TagTree(nextTag);
+  }
+
+  // Only label leaves.
+  bucketTag = tag;
+  return tag + 1;
+}
+
+
+// Return the bucket tag of the leaf that the query point falls into.
+int DTree::FindBucket(const arma::vec& query) const
+{
+  Log::Assert(query.n_elem == maxVals.n_elem);
+
+  // A leaf answers with its own tag.
+  if (subtreeLeaves == 1)
+    return bucketTag;
+
+  // Otherwise descend: values up to and including the split value belong to
+  // the left child, everything else to the right child.
+  const DTree* next = (query[splitDim] <= splitValue) ? left : right;
+  return next->FindBucket(query);
+}
+
+
+void DTree::ComputeVariableImportance(arma::vec& importances) const
+{
+  // Clear and set to right size.
+  importances.zeros(maxVals.n_elem);
+
+  std::stack<const DTree*> nodes;
+  nodes.push(this);
+
+  while(!nodes.empty())
+  {
+    const DTree& curNode = *nodes.top();
+    nodes.pop();
+
+    if (curNode.subtreeLeaves == 1)
+      continue; // Do nothing for leaves.
+
+    // The way to do this entirely in log-space is (at this time) somewhat
+    // unclear.  So this risks overflow.
+    importances[curNode.SplitDim()] += (-std::exp(curNode.LogNegError()) -
+        (-std::exp(curNode.Left()->LogNegError()) +
+         -std::exp(curNode.Right()->LogNegError())));
+
+    nodes.push(curNode.Left());
+    nodes.push(curNode.Right());
+  }
+}

Modified: mlpack/trunk/src/mlpack/methods/det/dtree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree.hpp	2012-08-01 20:47:08 UTC (rev 13309)
+++ mlpack/trunk/src/mlpack/methods/det/dtree.hpp	2012-08-02 00:52:39 UTC (rev 13310)
@@ -16,18 +16,6 @@
 namespace det /** Density Estimation Trees */ {
 
 /**
- * This two types in the template are used for two purposes:
- *
- *   eT - the type to store the data in (for most practical purposes, storing
- *       the data as a float suffices).
- *   cT - the type to perform computations in (computations like computing the
- *       error, the volume of the node etc.).
- *
- * For high dimensional data, it might be possible that the computation might
- * overflow, so you should use either normalize your data in the (-1, 1)
- * hypercube or use long double or modify this code to perform computations
- * using logarithms.
- *
  * A density estimation tree is similar to both a decision tree and a space
  * partitioning tree (like a kd-tree).  Each leaf represents a constant-density
  * hyper-rectangle.  The tree is constructed in such a way as to minimize the
@@ -50,8 +38,6 @@
  * }
  * @endcode
  */
-template<typename eT = double,
-         typename cT = long double>
 class DTree
 {
  public:
@@ -194,8 +180,13 @@
    *
    * @param totalPoints Total number of points in the dataset.
    */
-  inline double LogNegativeError(const size_t totalPoints) const;
+  double LogNegativeError(const size_t totalPoints) const;
 
+  /**
+   * Return whether a query point is within the range of this node.
+   */
+  bool WithinRange(const arma::vec& query) const;
+
  private:
   // The indices in the complete set of points
   // (after all forms of swapping in the original data
@@ -246,9 +237,9 @@
   double alphaUpper;
 
   //! The left child.
-  DTree<eT, cT> *left;
+  DTree* left;
   //! The right child.
-  DTree<eT, cT> *right;
+  DTree* right;
 
  public:
   //! Return the starting index of points contained in this node.
@@ -271,9 +262,9 @@
   //! Return the inverse of the volume of this node.
   double LogVolume() const { return logVolume; }
   //! Return the left child.
-  DTree<eT, cT>* Left() const { return left; }
+  DTree* Left() const { return left; }
   //! Return the right child.
-  DTree<eT, cT>* Right() const { return right; }
+  DTree* Right() const { return right; }
   //! Return whether or not this is the root of the tree.
   bool Root() const { return root; }
   //! Return the upper part of the alpha sum.
@@ -312,15 +303,9 @@
                    const double splitValue,
                    arma::Col<size_t>& oldFromNew) const;
 
-  /**
-   * Return whether a query point is within the range of this node.
-   */
-  inline bool WithinRange(const arma::vec& query) const;
-}; // Class DTree
+};
 
 }; // namespace det
 }; // namespace mlpack
 
-#include "dtree_impl.hpp"
-
 #endif // __MLPACK_METHODS_DET_DTREE_HPP

Deleted: mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp	2012-08-01 20:47:08 UTC (rev 13309)
+++ mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp	2012-08-02 00:52:39 UTC (rev 13310)
@@ -1,686 +0,0 @@
- /**
- * @file dtree_impl.hpp
- * @author Parikshit Ram (pram at cc.gatech.edu)
- *
- * Implementations of some declared functions in
- * the Density Estimation Tree class.
- *
- */
-#ifndef __MLPACK_METHODS_DET_DTREE_IMPL_HPP
-#define __MLPACK_METHODS_DET_DTREE_IMPL_HPP
-
-#include "dtree.hpp"
-#include <stack>
-
-namespace mlpack {
-namespace det {
-
-template<typename eT, typename cT>
-DTree<eT, cT>::DTree() :
-    start(0),
-    end(0),
-    logNegError(-DBL_MAX),
-    root(true),
-    bucketTag(-1),
-    left(NULL),
-    right(NULL)
-{ /* Nothing to do. */ }
-
-
-// Root node initializers
-template<typename eT, typename cT>
-DTree<eT, cT>::DTree(const arma::vec& maxVals,
-                     const arma::vec& minVals,
-                     const size_t totalPoints) :
-    start(0),
-    end(totalPoints),
-    maxVals(maxVals),
-    minVals(minVals),
-    logNegError(LogNegativeError(totalPoints)),
-    root(true),
-    bucketTag(-1),
-    left(NULL),
-    right(NULL)
-{ /* Nothing to do. */ }
-
-template<typename eT, typename cT>
-DTree<eT, cT>::DTree(arma::mat& data) :
-    start(0),
-    end(data.n_cols),
-    left(NULL),
-    right(NULL)
-{
-  maxVals.set_size(data.n_rows);
-  minVals.set_size(data.n_rows);
-
-  // Initialize to first column; values will be overwritten if necessary.
-  maxVals = data.col(0);
-  minVals = data.col(0);
-
-  // Loop over data to extract maximum and minimum values in each dimension.
-  for (size_t i = 1; i < data.n_cols; ++i)
-  {
-    for (size_t j = 0; j < data.n_rows; ++j)
-    {
-      if (data(j, i) > maxVals[j])
-        maxVals[j] = data(j, i);
-      if (data(j, i) < minVals[j])
-        minVals[j] = data(j, i);
-    }
-  }
-
-  logNegError = LogNegativeError(data.n_cols);
-
-  bucketTag = -1;
-  root = true;
-}
-
-
-// Non-root node initializers
-template<typename eT, typename cT>
-DTree<eT, cT>::DTree(const arma::vec& maxVals,
-                     const arma::vec& minVals,
-                     const size_t start,
-                     const size_t end,
-                     const double logNegError) :
-    start(start),
-    end(end),
-    maxVals(maxVals),
-    minVals(minVals),
-    logNegError(logNegError),
-    root(false),
-    bucketTag(-1),
-    left(NULL),
-    right(NULL)
-{ /* Nothing to do. */ }
-
-template<typename eT, typename cT>
-DTree<eT, cT>::DTree(const arma::vec& maxVals,
-                     const arma::vec& minVals,
-                     const size_t totalPoints,
-                     const size_t start,
-                     const size_t end) :
-    start(start),
-    end(end),
-    maxVals(maxVals),
-    minVals(minVals),
-    logNegError(LogNegativeError(totalPoints)),
-    root(false),
-    bucketTag(-1),
-    left(NULL),
-    right(NULL)
-{ /* Nothing to do. */ }
-
-template<typename eT, typename cT>
-DTree<eT, cT>::~DTree()
-{
-  if (left != NULL)
-    delete left;
-
-  if (right != NULL)
-    delete right;
-}
-
-// This function computes the log-l2-negative-error of a given node from the
-// formula R(t) = log(|t|^2 / (N^2 V_t)).
-template<typename eT, typename cT>
-inline double DTree<eT, cT>::LogNegativeError(const size_t totalPoints) const
-{
-  // log(-|t|^2 / (N^2 V_t)) = log(-1) + 2 log(|t|) - 2 log(N) - log(V_t).
-  return 2 * std::log((double) (end - start)) -
-         2 * std::log((double) totalPoints) -
-         arma::accu(arma::log(maxVals - minVals));
-}
-
-// This function finds the best split with respect to the L2-error, by trying
-// all possible splits.  The dataset is the full data set but the start and
-// end are used to obtain the point in this node.
-template<typename eT, typename cT>
-bool DTree<eT, cT>::FindSplit(const arma::mat& data,
-                              size_t& splitDim,
-                              double& splitValue,
-                              double& leftError,
-                              double& rightError,
-                              const size_t maxLeafSize,
-                              const size_t minLeafSize) const
-{
-  // Ensure the dimensionality of the data is the same as the dimensionality of
-  // the bounding rectangle.
-  assert(data.n_rows == maxVals.n_elem);
-  assert(data.n_rows == minVals.n_elem);
-
-  const size_t points = end - start;
-
-  double minError = logNegError;
-  bool splitFound = false;
-
-  // Loop through each dimension.
-  for (size_t dim = 0; dim < maxVals.n_elem; dim++)
-  {
-    // Have to deal with REAL, INTEGER, NOMINAL data differently, so we have to
-    // think of how to do that...
-    const double min = minVals[dim];
-    const double max = maxVals[dim];
-
-    // If there is nothing to split in this dimension, move on.
-    if (max - min == 0.0)
-      continue; // Skip to next dimension.
-
-    // Initializing all the stuff for this dimension.
-    bool dimSplitFound = false;
-    // Take an error estimate for this dimension.
-    double minDimError = std::pow(points, 2.0) / (max - min);
-    double dimLeftError;
-    double dimRightError;
-    double dimSplitValue;
-
-    // Find the log volume of all the other dimensions.
-    double volumeWithoutDim = logVolume - std::log(max - min);
-
-    // Get the values for the dimension.
-    arma::rowvec dimVec = data.row(dim).subvec(start, end - 1);
-
-    // Sort the values in ascending order.
-    dimVec = arma::sort(dimVec);
-
-    // Get ready to go through the sorted list and compute error.
-    assert(dimVec.n_elem > maxLeafSize);
-
-    // Find the best split for this dimension.  We need to figure out why
-    // there are spikes if this minLeafSize is enforced here...
-    for (size_t i = minLeafSize - 1; i < dimVec.n_elem - minLeafSize; ++i)
-    {
-      // This makes sense for real continuous data.  This kinda corrupts the
-      // data and estimation if the data is ordinal.
-      const double split = (dimVec[i] + dimVec[i + 1]) / 2.0;
-
-      if (split == dimVec[i])
-        continue; // We can't split here (two points are the same).
-
-      // Another way of picking split is using this:
-      //   split = leftsplit;
-      if ((split - min > 0.0) && (max - split > 0.0))
-      {
-        // Ensure that the right node will have at least the minimum number of
-        // points.
-        Log::Assert((points - i - 1) >= minLeafSize);
-
-        // Now we have to see if the error will be reduced.  Simple manipulation
-        // of the error function gives us the condition we must satisfy:
-        //   |t_l|^2 / V_l + |t_r|^2 / V_r  >= |t|^2 / (V_l + V_r)
-        // and because the volume is only dependent on the dimension we are
-        // splitting, we can assume V_l is just the range of the left and V_r is
-        // just the range of the right.
-        double negLeftError = std::pow(i + 1, 2.0) / (split - min);
-        double negRightError = std::pow(points - i - 1, 2.0) / (max - split);
-
-        // If this is better, take it.
-        if ((negLeftError + negRightError) >= minDimError)
-        {
-          minDimError = negLeftError + negRightError;
-          dimLeftError = negLeftError;
-          dimRightError = negRightError;
-          dimSplitValue = split;
-          dimSplitFound = true;
-        }
-      }
-    }
-
-    double actualMinDimError = std::log(minDimError) - 2 * std::log(data.n_cols)
-        - volumeWithoutDim;
-
-    if ((actualMinDimError > minError) && dimSplitFound)
-    {
-      // Calculate actual error (in logspace) by adding terms back to our
-      // estimate.
-      minError = actualMinDimError;
-      splitDim = dim;
-      splitValue = dimSplitValue;
-      leftError = std::log(dimLeftError) - 2 * std::log(data.n_cols) -
-          volumeWithoutDim;
-      rightError = std::log(dimRightError) - 2 * std::log(data.n_cols) -
-          volumeWithoutDim;
-      splitFound = true;
-    } // end if better split found in this dimension.
-  }
-
-  return splitFound;
-}
-
-template<typename eT, typename cT>
-size_t DTree<eT, cT>::SplitData(arma::mat& data,
-                                const size_t splitDim,
-                                const double splitValue,
-                                arma::Col<size_t>& oldFromNew) const
-{
-  // Swap all columns such that any columns with value in dimension splitDim
-  // less than or equal to splitValue are on the left side, and all others are
-  // on the right side.  A similar sort to this is also performed in
-  // BinarySpaceTree construction (its comments are more detailed).
-  size_t left = start;
-  size_t right = end - 1;
-  for (;;)
-  {
-    while (data(splitDim, left) <= splitValue)
-      ++left;
-    while (data(splitDim, right) > splitValue)
-      --right;
-
-    if (left > right)
-      break;
-
-    data.swap_cols(left, right);
-
-    // Store the mapping from old to new.
-    const size_t tmp = oldFromNew[left];
-    oldFromNew[left] = oldFromNew[right];
-    oldFromNew[right] = tmp;
-  }
-
-  // This now refers to the first index of the "right" side.
-  return left;
-}
-
-// Greedily expand the tree
-template<typename eT, typename cT>
-double DTree<eT, cT>::Grow(arma::mat& data,
-                           arma::Col<size_t>& oldFromNew,
-                           const bool useVolReg,
-                           const size_t maxLeafSize,
-                           const size_t minLeafSize)
-{
-  assert(data.n_rows == maxVals.n_elem);
-  assert(data.n_rows == minVals.n_elem);
-
-  double leftG, rightG;
-
-  // Compute points ratio.
-  ratio = (double) (end - start) / (double) oldFromNew.n_elem;
-
-  // Compute the log of the volume of the node.
-  logVolume = 0;
-  for (size_t i = 0; i < maxVals.n_elem; ++i)
-    if (maxVals[i] - minVals[i] > 0.0)
-      logVolume += std::log(maxVals[i] - minVals[i]);
-
-  // Check if node is large enough to split.
-  if ((size_t) (end - start) > maxLeafSize) {
-
-    // Find the split.
-    size_t dim;
-    double splitValueTmp;
-    double leftError, rightError;
-    if (FindSplit(data, dim, splitValueTmp, leftError, rightError, maxLeafSize,
-        minLeafSize))
-    {
-      // Move the data around for the children to have points in a node lie
-      // contiguously (to increase efficiency during the training).
-      const size_t splitIndex = SplitData(data, dim, splitValueTmp, oldFromNew);
-
-      // Make max and min vals for the children.
-      arma::vec maxValsL(maxVals);
-      arma::vec maxValsR(maxVals);
-      arma::vec minValsL(minVals);
-      arma::vec minValsR(minVals);
-
-      maxValsL[dim] = splitValueTmp;
-      minValsR[dim] = splitValueTmp;
-
-      // Store split dim and split val in the node.
-      splitValue = splitValueTmp;
-      splitDim = dim;
-
-      // Recursively grow the children.
-      left = new DTree(maxValsL, minValsL, start, splitIndex, leftError);
-      right = new DTree(maxValsR, minValsR, splitIndex, end, rightError);
-
-      leftG = left->Grow(data, oldFromNew, useVolReg, maxLeafSize,
-          minLeafSize);
-      rightG = right->Grow(data, oldFromNew, useVolReg, maxLeafSize,
-          minLeafSize);
-
-      // Store values of R(T~) and |T~|.
-      subtreeLeaves = left->SubtreeLeaves() + right->SubtreeLeaves();
-
-      // Find the log negative error of the subtree leaves.  This is kind of an
-      // odd one because we don't want to represent the error in non-log-space,
-      // but we have to calculate log(E_l + E_r).  So we multiply E_l and E_r by
-      // V_t (remember E_l has an inverse relationship to the volume of the
-      // nodes) and then subtract log(V_t) at the end of the whole expression.
-      // As a result we do leave log-space, but the largest quantity we
-      // represent is on the order of (V_t / V_i) where V_i is the smallest leaf
-      // node below this node, which depends heavily on the depth of the tree.
-      subtreeLeavesLogNegError = std::log(std::exp(logVolume +
-          left->SubtreeLeavesLogNegError() + right->SubtreeLeavesLogNegError()))
-          - logVolume;
-    }
-    else
-    {
-      // No split found so make a leaf out of it.
-      subtreeLeaves = 1;
-      subtreeLeavesLogNegError = logNegError;
-    }
-  }
-  else
-  {
-    // We can make this a leaf node.
-    assert((size_t) (end - start) >= minLeafSize);
-    subtreeLeaves = 1;
-    subtreeLeavesLogNegError = logNegError;
-  }
-
-  // If this is a leaf, do not compute g_k(t); otherwise compute, store, and
-  // propagate min(g_k(t_L), g_k(t_R), g_k(t)), unless t_L and/or t_R are
-  // leaves.
-  if (subtreeLeaves == 1)
-  {
-    return std::numeric_limits<double>::max();
-  }
-  else
-  {
-    const double range = maxVals[splitDim] - minVals[splitDim];
-    const double leftRatio = (splitValue - minVals[splitDim]) / range;
-    const double rightRatio = (maxVals[splitDim] - splitValue) / range;
-
-    const size_t leftPow = std::pow(left->End() - left->Start(), 2);
-    const size_t rightPow = std::pow(right->End() - right->Start(), 2);
-    const size_t thisPow = std::pow(end - start, 2);
-
-    double tmpAlphaSum = leftPow / leftRatio + rightPow / rightRatio - thisPow;
-
-    if (left->SubtreeLeaves() > 1)
-    {
-      const double exponent = 2 * std::log(data.n_cols) + logVolume +
-          left->AlphaUpper();
-
-      // Whether or not this will overflow is highly dependent on the depth of
-      // the tree.
-      tmpAlphaSum += std::exp(exponent);
-    }
-
-    if (right->SubtreeLeaves() > 1)
-    {
-      const double exponent = 2 * std::log(data.n_cols) + logVolume +
-          right->AlphaUpper();
-
-      tmpAlphaSum += std::exp(exponent);
-    }
-
-    alphaUpper = std::log(tmpAlphaSum) - 2 * std::log(data.n_cols) - logVolume;
-
-    double gT;
-    if (useVolReg)
-    {
-      // This is wrong for now!
-      gT = alphaUpper;// / (subtreeLeavesVTInv - vTInv);
-    }
-    else
-    {
-      gT = alphaUpper - std::log(subtreeLeaves - 1);
-    }
-
-    return std::min(gT, std::min(leftG, rightG));
-  }
-
-  // We need to compute (c_t^2) * r_t for all subtree leaves; this is equal to
-  // n_t ^ 2 / r_t * n ^ 2 = -error.  Therefore the value we need is actually
-  // -1.0 * subtreeLeavesError.
-}
-
-
-template<typename eT, typename cT>
-double DTree<eT, cT>::PruneAndUpdate(const double oldAlpha,
-                                     const size_t points,
-                                     const bool useVolReg)
-
-{
-  // Compute gT.
-  if (subtreeLeaves == 1) // If we are a leaf...
-  {
-    return 0;
-  }
-  else
-  {
-    // Compute gT value for node t.
-    double gT;
-    if (useVolReg)
-      gT = alphaUpper;// - std::log(subtreeLeavesVTInv - vTInv);
-    else
-      gT = alphaUpper - std::log(subtreeLeaves - 1);
-
-    if (gT < oldAlpha)
-    {
-      // Go down the tree and update accordingly.  Traverse the children.
-      double leftG = left->PruneAndUpdate(oldAlpha, useVolReg);
-      double rightG = right->PruneAndUpdate(oldAlpha, useVolReg);
-
-      // Update values.
-      subtreeLeaves = left->SubtreeLeaves() + right->SubtreeLeaves();
-
-      // Find the log negative error of the subtree leaves.  This is kind of an
-      // odd one because we don't want to represent the error in non-log-space,
-      // but we have to calculate log(E_l + E_r).  So we multiply E_l and E_r by
-      // V_t (remember E_l has an inverse relationship to the volume of the
-      // nodes) and then subtract log(V_t) at the end of the whole expression.
-      // As a result we do leave log-space, but the largest quantity we
-      // represent is on the order of (V_t / V_i) where V_i is the smallest leaf
-      // node below this node, which depends heavily on the depth of the tree.
-      subtreeLeavesLogNegError = std::log(std::exp(logVolume +
-          left->SubtreeLeavesLogNegError() + right->SubtreeLeavesLogNegError()))
-          - logVolume;
-
-      // Recalculate upper alpha.
-      const double range = maxVals[splitDim] - minVals[splitDim];
-      const double leftRatio = (splitValue - minVals[splitDim]) / range;
-      const double rightRatio = (maxVals[splitDim] - splitValue) / range;
-
-      const size_t leftPow = std::pow(left->End() - left->Start(), 2);
-      const size_t rightPow = std::pow(right->End() - right->Start(), 2);
-      const size_t thisPow = std::pow(end - start, 2);
-
-      double tmpAlphaSum = leftPow / leftRatio + rightPow / rightRatio -
-          thisPow;
-
-      if (left->SubtreeLeaves() > 1)
-      {
-        const double exponent = 2 * std::log(points) + logVolume +
-            left->AlphaUpper();
-
-        // Whether or not this will overflow is highly dependent on the depth of
-        // the tree.
-        tmpAlphaSum += std::exp(exponent);
-      }
-
-      if (right->SubtreeLeaves() > 1)
-      {
-        const double exponent = 2 * std::log(points) + logVolume +
-            right->AlphaUpper();
-
-        tmpAlphaSum += std::exp(exponent);
-      }
-
-      alphaUpper = std::log(tmpAlphaSum) - 2 * std::log(points) - logVolume;
-
-      // Update gT value.
-      if (useVolReg)
-      {
-        // This is incorrect.
-        gT = alphaUpper; // / (subtreeLeavesVTInv - vTInv);
-      }
-      else
-      {
-        gT = alphaUpper - std::log(subtreeLeaves - 1);
-      }
-
-      assert(gT < std::numeric_limits<double>::max());
-
-      return std::min(gT, std::min(leftG, rightG));
-    }
-    else
-    {
-      // Prune this subtree.
-      // First, make this node a leaf node.
-      subtreeLeaves = 1;
-      subtreeLeavesLogNegError = logNegError;
-
-      delete left;
-      delete right;
-
-      left = NULL;
-      right = NULL;
-
-      // Pass information upward.
-      return std::numeric_limits<double>::max();
-    }
-  }
-}
-
-// Report whether a query point lies inside the bounding box of this node: the
-// answer is true unless some coordinate of the query is strictly below the
-// matching entry of minVals or strictly above the matching entry of maxVals.
-// This check is generally done at the root, where the box is the bounding box
-// of the data.
-//
-// Future improvement: Open up the range with epsilons on both sides where
-// epsilon depends on the density near the boundary.
-template<typename eT, typename cT>
-inline bool DTree<eT, cT>::WithinRange(const arma::vec& query) const
-{
-  size_t dim = 0;
-  while (dim < query.n_elem)
-  {
-    // One coordinate outside either bound is enough to reject the point.
-    const bool belowMin = query[dim] < minVals[dim];
-    const bool aboveMax = query[dim] > maxVals[dim];
-    if (belowMin || aboveMax)
-      return false;
-
-    ++dim;
-  }
-
-  return true;
-}
-
-
-// Evaluate this subtree at a query point.  The query must have as many
-// elements as maxVals.  A root node answers 0.0 for a query outside its
-// bounding box; a leaf answers exp(log(ratio) - logVolume); any other node
-// hands the query to the child on the matching side of its split.
-template<typename eT, typename cT>
-double DTree<eT, cT>::ComputeValue(const arma::vec& query) const
-{
-  Log::Assert(query.n_elem == maxVals.n_elem);
-
-  // Only the root checks whether the query is within range.
-  if ((root == 1) && !WithinRange(query))
-    return 0.0;
-
-  // Leaf: the value is assembled in log-space and exponentiated at the end.
-  if (subtreeLeaves == 1)
-    return std::exp(std::log(ratio) - logVolume);
-
-  // Internal node: values up to and including the split value go left,
-  // everything else goes right.
-  return (query[splitDim] <= splitValue) ? left->ComputeValue(query)
-                                         : right->ComputeValue(query);
-}
-
-
-// Write a text rendering of this subtree to fp.
-//
-// A leaf prints ": f(x)=<value>", followed by " BT:<tag>" when bucketTag is not
-// -1.  An internal node prints two lines, each started by a newline and `level`
-// copies of "|\t": first "Var. <dim> > <split>" followed by the right subtree
-// (written at level + 1), then "Var. <dim> <= <split> " followed by the left
-// subtree (written at the same level).
-template<typename eT, typename cT>
-void DTree<eT, cT>::WriteTree(FILE *fp, const size_t level) const
-{
-  // Leaves: print the value and, if one has been assigned, the bucket tag.
-  if (subtreeLeaves <= 1)
-  {
-    fprintf(fp, ": f(x)=%lg", std::exp(std::log(ratio) - logVolume));
-    if (bucketTag != -1)
-      fprintf(fp, " BT:%d", bucketTag);
-
-    return;
-  }
-
-  // Right branch first.
-  fprintf(fp, "\n");
-  for (size_t depth = 0; depth != level; ++depth)
-    fprintf(fp, "|\t");
-  fprintf(fp, "Var. %zu > %lg", splitDim, splitValue);
-  right->WriteTree(fp, level + 1);
-
-  // Then the left branch, which is written at this node's own level.
-  fprintf(fp, "\n");
-  for (size_t depth = 0; depth != level; ++depth)
-    fprintf(fp, "|\t");
-  fprintf(fp, "Var. %zu <= %lg ", splitDim, splitValue);
-  left->WriteTree(fp, level);
-}
-
-
-// Index the buckets for possible usage later.  Starting from `tag`, every leaf
-// below this node receives its own bucketTag, left subtree before right
-// subtree; the return value is the first tag that was not handed out.
-template<typename eT, typename cT>
-int DTree<eT, cT>::TagTree(const int tag)
-{
-  // Internal nodes are not labelled: number the left subtree, then carry on
-  // numbering in the right subtree from wherever the left one stopped.
-  if (subtreeLeaves != 1)
-  {
-    const int nextTag = left->TagTree(tag);
-    return right->TagTree(nextTag);
-  }
-
-  // Only label leaves.
-  bucketTag = tag;
-  return tag + 1;
-}
-
-
-// Return the bucketTag of the leaf that a query point falls into.  The query
-// must have as many elements as maxVals.  At each internal node the search
-// goes left when query[splitDim] <= splitValue and right otherwise.
-template<typename eT, typename cT>
-int DTree<eT, cT>::FindBucket(const arma::vec& query) const
-{
-  Log::Assert(query.n_elem == maxVals.n_elem);
-
-  // A leaf answers with its own tag.
-  if (subtreeLeaves == 1)
-    return bucketTag;
-
-  // Otherwise keep descending on the side of the split the query lies on.
-  const bool goLeft = (query[splitDim] <= splitValue);
-  return goLeft ? left->FindBucket(query) : right->FindBucket(query);
-}
-
-
-// Fill `importances` with one entry per dimension (as many as maxVals has).
-// Every non-leaf node in this subtree adds, to the entry of its split
-// dimension, its own error (-exp(LogNegError())) minus the summed errors of
-// its two children.
-template<typename eT, typename cT>
-void DTree<eT, cT>::ComputeVariableImportance(arma::vec& importances)
-    const
-{
-  // Clear and set to right size.
-  importances.zeros(maxVals.n_elem);
-
-  // Walk the subtree with an explicit stack instead of recursion.
-  std::stack<const DTree*> pending;
-  pending.push(this);
-
-  while (!pending.empty())
-  {
-    const DTree* node = pending.top();
-    pending.pop();
-
-    // Leaves contribute nothing and have no children to visit.
-    if (node->subtreeLeaves != 1)
-    {
-      // The way to do this entirely in log-space is (at this time) somewhat
-      // unclear.  So this risks overflow.
-      importances[node->SplitDim()] += (-std::exp(node->LogNegError()) -
-          (-std::exp(node->Left()->LogNegError()) +
-           -std::exp(node->Right()->LogNegError())));
-
-      pending.push(node->Left());
-      pending.push(node->Right());
-    }
-  }
-}
-
-}; // namespace det
-}; // namespace mlpack
-
-#endif // __MLPACK_METHODS_DET_DTREE_IMPL_HPP




More information about the mlpack-svn mailing list