[mlpack-svn] r13262 - mlpack/trunk/src/mlpack/methods/det

fastlab-svn at coffeetalk-1.cc.gatech.edu fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Jul 20 13:18:49 EDT 2012


Author: rcurtin
Date: 2012-07-20 13:18:48 -0400 (Fri, 20 Jul 2012)
New Revision: 13262

Modified:
   mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp
   mlpack/trunk/src/mlpack/methods/det/dtree.hpp
   mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
Log:
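Change API for ComputeVariableImportance() and FindBucket() to take references
instead of pointers; ComputeVariableImportance() now sizes and zeroes its output
itself and traverses the tree iteratively.  Make the DTree accessors const.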


Modified: mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp	2012-07-19 19:02:56 UTC (rev 13261)
+++ mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp	2012-07-20 17:18:48 UTC (rev 13262)
@@ -32,7 +32,7 @@
   for (size_t i = 0; i < data.n_cols; i++)
   {
     arma::Col<eT> test_p = data.unsafe_col(i);
-    int leaf_tag = dtree->FindBucket(&test_p);
+    int leaf_tag = dtree->FindBucket(test_p);
     int label = labels[i];
     table(leaf_tag, label) += 1;
   }
@@ -70,20 +70,19 @@
                              size_t num_dims,
                              string vi_file = "")
 {
-  arma::Col<double> *imps = new arma::Col<double>(num_dims);
-  imps->zeros();
-
+  arma::vec imps;
   dtree->ComputeVariableImportance(imps);
+
   double max = 0.0;
-  for (size_t i = 0; i < imps->n_elem; ++i)
-    if ((*imps)[i] > max)
-      max = (*imps)[i];
+  for (size_t i = 0; i < imps.n_elem; ++i)
+    if (imps[i] > max)
+      max = imps[i];
 
   Log::Warn << "Max. variable importance: " << max << "." << std::endl;
 
   if (vi_file == "")
   {
-    Log::Warn << "Variable importance: " << std::endl << imps->t() << std::endl;
+    Log::Warn << "Variable importance: " << std::endl << imps.t() << std::endl;
   }
   else
   {
@@ -92,7 +91,7 @@
     {
       Log::Warn << "Variable importance printed in '" << vi_file << "'."
           << endl;
-      outfile << *imps;
+      outfile << imps;
     } else {
       Log::Warn << "Can't open '" << vi_file
         << "'" << endl;
@@ -100,8 +99,6 @@
     outfile.close();
   }
 
-  delete[] imps;
-
   return;
 } // PrintVariableImportance
 

Modified: mlpack/trunk/src/mlpack/methods/det/dtree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree.hpp	2012-07-19 19:02:56 UTC (rev 13261)
+++ mlpack/trunk/src/mlpack/methods/det/dtree.hpp	2012-07-20 17:18:48 UTC (rev 13262)
@@ -117,30 +117,30 @@
 public:
 
   ////////////////////// Getters and Setters //////////////////////////////////
-  size_t start() { return start_; }
+  size_t start() const { return start_; }
 
-  size_t end() { return end_; }
+  size_t end() const { return end_; }
 
-  size_t split_dim() { return split_dim_; }
+  size_t split_dim() const { return split_dim_; }
 
-  eT split_value() { return split_value_; }
+  eT split_value() const { return split_value_; }
 
-  cT error() { return error_; }
+  cT error() const { return error_; }
 
-  cT subtree_leaves_error() { return subtree_leaves_error_; }
+  cT subtree_leaves_error() const { return subtree_leaves_error_; }
 
-  size_t subtree_leaves() { return subtree_leaves_; }
+  size_t subtree_leaves() const { return subtree_leaves_; }
 
-  cT ratio() { return ratio_; }
+  cT ratio() const { return ratio_; }
 
-  cT v_t_inv() { return v_t_inv_; }
+  cT v_t_inv() const { return v_t_inv_; }
 
-  cT subtree_leaves_v_t_inv() { return subtree_leaves_v_t_inv_; }
+  cT subtree_leaves_v_t_inv() const { return subtree_leaves_v_t_inv_; }
 
-  DTree<eT, cT>* left() { return left_; }
-  DTree<eT, cT>* right() { return right_; }
+  DTree<eT, cT>* left() const { return left_; }
+  DTree<eT, cT>* right() const { return right_; }
 
-  bool root() { return root_; }
+  bool root() const { return root_; }
 
   ////////////////////// Private Functions ////////////////////////////////////
  private:
@@ -214,11 +214,11 @@
 
   // This is used to generate the class membership
   // of a learned tree.
-  int FindBucket(VecType* query);
+  int FindBucket(const arma::vec& query) const;
 
   // This computes the variable importance list
   // for the learned tree.
-  void ComputeVariableImportance(arma::Col<double> *imps);
+  void ComputeVariableImportance(arma::vec& importances) const;
 
 }; // Class DTree
 

Modified: mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp	2012-07-19 19:02:56 UTC (rev 13261)
+++ mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp	2012-07-20 17:18:48 UTC (rev 13262)
@@ -10,6 +10,7 @@
 #define __MLPACK_METHODS_DET_DTREE_IMPL_HPP
 
 #include "dtree.hpp"
+#include <stack>
 
 namespace mlpack {
 namespace det {
@@ -567,15 +568,15 @@
 
 
 template<typename eT, typename cT>
-int DTree<eT, cT>::FindBucket(VecType* query)
+int DTree<eT, cT>::FindBucket(const arma::vec& query) const
 {
-  assert(query->n_elem == maxVals.n_elem);
+  Log::Assert(query.n_elem == maxVals.n_elem);
 
   if (subtree_leaves_ == 1) // If we are a leaf...
   {
     return bucket_tag_;
   }
-  else if ((*query)[split_dim_] <= split_value_)
+  else if (query[split_dim_] <= split_value_)
   {
     // If left subtree, go to left child.
     return left_->FindBucket(query);
@@ -588,23 +589,29 @@
 
 
 template<typename eT, typename cT>
-void DTree<eT, cT>::ComputeVariableImportance(arma::Col<double> *imps)
+void DTree<eT, cT>::ComputeVariableImportance(arma::vec& importances)
+    const
 {
-  if (subtree_leaves_ == 1)
+  // Clear and set to right size.
+  importances.zeros(maxVals.n_elem);
+
+  std::stack<const DTree*> nodes;
+  nodes.push(this);
+
+  while(!nodes.empty())
   {
-    // If we are a leaf, do nothing.
-    return;
+    const DTree& curNode = *nodes.top();
+    nodes.pop();
+
+    if (curNode.subtree_leaves_ == 1)
+      continue; // Do nothing for leaves.
+
+    importances[curNode.split_dim()] += (double) (curNode.error() -
+        (curNode.left()->error() + curNode.right()->error()));
+
+    nodes.push(curNode.left());
+    nodes.push(curNode.right());
   }
-  else
-  {
-    // Compute the improvement in error because of the split.
-    double error_improv = (double)
-        (error_ - (left_->error() + right_->error()));
-    (*imps)[split_dim_] += error_improv;
-    left_->ComputeVariableImportance(imps);
-    right_->ComputeVariableImportance(imps);
-    return;
-  }
 }
 
 }; // namespace det




More information about the mlpack-svn mailing list