[mlpack-svn] r13262 - mlpack/trunk/src/mlpack/methods/det
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Fri Jul 20 13:18:49 EDT 2012
Author: rcurtin
Date: 2012-07-20 13:18:48 -0400 (Fri, 20 Jul 2012)
New Revision: 13262
Modified:
mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp
mlpack/trunk/src/mlpack/methods/det/dtree.hpp
mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
Log:
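Change API for ComputeVariableImportance(): it now takes an arma::vec& (which it resizes and zeroes itself) and is const. Also change FindBucket() to take a const arma::vec& and be const, and mark the DTree getters const.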
Modified: mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp 2012-07-19 19:02:56 UTC (rev 13261)
+++ mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp 2012-07-20 17:18:48 UTC (rev 13262)
@@ -32,7 +32,7 @@
for (size_t i = 0; i < data.n_cols; i++)
{
arma::Col<eT> test_p = data.unsafe_col(i);
- int leaf_tag = dtree->FindBucket(&test_p);
+ int leaf_tag = dtree->FindBucket(test_p);
int label = labels[i];
table(leaf_tag, label) += 1;
}
@@ -70,20 +70,19 @@
size_t num_dims,
string vi_file = "")
{
- arma::Col<double> *imps = new arma::Col<double>(num_dims);
- imps->zeros();
-
+ arma::vec imps;
dtree->ComputeVariableImportance(imps);
+
double max = 0.0;
- for (size_t i = 0; i < imps->n_elem; ++i)
- if ((*imps)[i] > max)
- max = (*imps)[i];
+ for (size_t i = 0; i < imps.n_elem; ++i)
+ if (imps[i] > max)
+ max = imps[i];
Log::Warn << "Max. variable importance: " << max << "." << std::endl;
if (vi_file == "")
{
- Log::Warn << "Variable importance: " << std::endl << imps->t() << std::endl;
+ Log::Warn << "Variable importance: " << std::endl << imps.t() << std::endl;
}
else
{
@@ -92,7 +91,7 @@
{
Log::Warn << "Variable importance printed in '" << vi_file << "'."
<< endl;
- outfile << *imps;
+ outfile << imps;
} else {
Log::Warn << "Can't open '" << vi_file
<< "'" << endl;
@@ -100,8 +99,6 @@
outfile.close();
}
- delete[] imps;
-
return;
} // PrintVariableImportance
Modified: mlpack/trunk/src/mlpack/methods/det/dtree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree.hpp 2012-07-19 19:02:56 UTC (rev 13261)
+++ mlpack/trunk/src/mlpack/methods/det/dtree.hpp 2012-07-20 17:18:48 UTC (rev 13262)
@@ -117,30 +117,30 @@
public:
////////////////////// Getters and Setters //////////////////////////////////
- size_t start() { return start_; }
+ size_t start() const { return start_; }
- size_t end() { return end_; }
+ size_t end() const { return end_; }
- size_t split_dim() { return split_dim_; }
+ size_t split_dim() const { return split_dim_; }
- eT split_value() { return split_value_; }
+ eT split_value() const { return split_value_; }
- cT error() { return error_; }
+ cT error() const { return error_; }
- cT subtree_leaves_error() { return subtree_leaves_error_; }
+ cT subtree_leaves_error() const { return subtree_leaves_error_; }
- size_t subtree_leaves() { return subtree_leaves_; }
+ size_t subtree_leaves() const { return subtree_leaves_; }
- cT ratio() { return ratio_; }
+ cT ratio() const { return ratio_; }
- cT v_t_inv() { return v_t_inv_; }
+ cT v_t_inv() const { return v_t_inv_; }
- cT subtree_leaves_v_t_inv() { return subtree_leaves_v_t_inv_; }
+ cT subtree_leaves_v_t_inv() const { return subtree_leaves_v_t_inv_; }
- DTree<eT, cT>* left() { return left_; }
- DTree<eT, cT>* right() { return right_; }
+ DTree<eT, cT>* left() const { return left_; }
+ DTree<eT, cT>* right() const { return right_; }
- bool root() { return root_; }
+ bool root() const { return root_; }
////////////////////// Private Functions ////////////////////////////////////
private:
@@ -214,11 +214,11 @@
// This is used to generate the class membership
// of a learned tree.
- int FindBucket(VecType* query);
+ int FindBucket(const arma::vec& query) const;
// This computes the variable importance list
// for the learned tree.
- void ComputeVariableImportance(arma::Col<double> *imps);
+ void ComputeVariableImportance(arma::vec& importances) const;
}; // Class DTree
Modified: mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp 2012-07-19 19:02:56 UTC (rev 13261)
+++ mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp 2012-07-20 17:18:48 UTC (rev 13262)
@@ -10,6 +10,7 @@
#define __MLPACK_METHODS_DET_DTREE_IMPL_HPP
#include "dtree.hpp"
+#include <stack>
namespace mlpack {
namespace det {
@@ -567,15 +568,15 @@
template<typename eT, typename cT>
-int DTree<eT, cT>::FindBucket(VecType* query)
+int DTree<eT, cT>::FindBucket(const arma::vec& query) const
{
- assert(query->n_elem == maxVals.n_elem);
+ Log::Assert(query.n_elem == maxVals.n_elem);
if (subtree_leaves_ == 1) // If we are a leaf...
{
return bucket_tag_;
}
- else if ((*query)[split_dim_] <= split_value_)
+ else if (query[split_dim_] <= split_value_)
{
// If left subtree, go to left child.
return left_->FindBucket(query);
@@ -588,23 +589,29 @@
template<typename eT, typename cT>
-void DTree<eT, cT>::ComputeVariableImportance(arma::Col<double> *imps)
+void DTree<eT, cT>::ComputeVariableImportance(arma::vec& importances)
+ const
{
- if (subtree_leaves_ == 1)
+ // Clear and set to right size.
+ importances.zeros(maxVals.n_elem);
+
+ std::stack<const DTree*> nodes;
+ nodes.push(this);
+
+ while(!nodes.empty())
{
- // If we are a leaf, do nothing.
- return;
+ const DTree& curNode = *nodes.top();
+ nodes.pop();
+
+ if (curNode.subtree_leaves_ == 1)
+ continue; // Do nothing for leaves.
+
+ importances[curNode.split_dim()] += (double) (curNode.error() -
+ (curNode.left()->error() + curNode.right()->error()));
+
+ nodes.push(curNode.left());
+ nodes.push(curNode.right());
}
- else
- {
- // Compute the improvement in error because of the split.
- double error_improv = (double)
- (error_ - (left_->error() + right_->error()));
- (*imps)[split_dim_] += error_improv;
- left_->ComputeVariableImportance(imps);
- right_->ComputeVariableImportance(imps);
- return;
- }
}
}; // namespace det
More information about the mlpack-svn mailing list