[mlpack-svn] r13235 - mlpack/trunk/src/mlpack/methods/det
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Mon Jul 16 11:41:51 EDT 2012
Author: rcurtin
Date: 2012-07-16 11:41:51 -0400 (Mon, 16 Jul 2012)
New Revision: 13235
Modified:
mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp
mlpack/trunk/src/mlpack/methods/det/dtree.hpp
mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
Log:
Clean up constructors. Don't hand-allocate memory for minVals and maxVals;
remove GetMaxMinVals_() because it was only ever used once; inline it directly
into the constructor where it was used.
Modified: mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp 2012-07-16 02:59:49 UTC (rev 13234)
+++ mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp 2012-07-16 15:41:51 UTC (rev 13235)
@@ -117,7 +117,7 @@
string unprunedTreeOutput = "")
{
// Initialize the tree.
- DTree<eT>* dtree = new DTree<eT>(dataset);
+ DTree<eT>* dtree = new DTree<eT>(*dataset);
// Getting ready to grow the tree...
arma::Col<size_t> old_from_new(dataset->n_cols);
@@ -221,7 +221,7 @@
assert(train->n_cols + test.n_cols == cvdata->n_cols);
// Initialize the tree.
- DTree<eT>* dtree_cv = new DTree<eT>(train);
+ DTree<eT>* dtree_cv = new DTree<eT>(*train);
// Getting ready to grow the tree...
arma::Col<size_t> old_from_new_cv(train->n_cols);
@@ -289,7 +289,7 @@
Log::Info << "Optimal alpha: " << optimal_alpha << "." << std::endl;
// Initialize the tree.
- DTree<eT>* dtree_opt = new DTree<eT>(dataset);
+ DTree<eT>* dtree_opt = new DTree<eT>(*dataset);
// Getting ready to grow the tree...
for (size_t i = 0; i < old_from_new.n_elem; i++)
Modified: mlpack/trunk/src/mlpack/methods/det/dtree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree.hpp 2012-07-16 02:59:49 UTC (rev 13234)
+++ mlpack/trunk/src/mlpack/methods/det/dtree.hpp 2012-07-16 15:41:51 UTC (rev 13235)
@@ -104,8 +104,8 @@
// since we are using uniform density, we need
// the max and min of every dimension for every node
- VecType* max_vals_;
- VecType* min_vals_;
+ arma::vec maxVals;
+ arma::vec minVals;
// the tag for the leaf used for hashing points
int bucket_tag_;
@@ -160,10 +160,6 @@
const double splitValue,
arma::Col<size_t>& oldFromNew) const;
- void GetMaxMinVals_(MatType* data,
- VecType* max_vals,
- VecType* min_vals);
-
bool WithinRange_(VecType* query);
///////////////////// Public Functions //////////////////////////////////////
@@ -174,23 +170,23 @@
// Root node initializer
// with the bounding box of the data
// it contains instead of just the data.
- DTree(VecType* max_vals,
- VecType* min_vals,
+ DTree(const arma::vec& maxVals,
+ const arma::vec& minVals,
size_t total_points);
// Root node initializer
// with the data, no bounding box.
- DTree(MatType* data);
+ DTree(arma::mat& data);
// Non-root node initializers
- DTree(VecType* max_vals,
- VecType* min_vals,
+ DTree(const arma::vec& max_vals,
+ const arma::vec& min_vals,
size_t start,
size_t end,
cT error);
- DTree(VecType* max_vals,
- VecType* min_vals,
+ DTree(const arma::vec& max_vals,
+ const arma::vec& min_vals,
size_t total_points,
size_t start,
size_t end);
Modified: mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp 2012-07-16 02:59:49 UTC (rev 13234)
+++ mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp 2012-07-16 15:41:51 UTC (rev 13235)
@@ -22,7 +22,7 @@
// log(-|t|^2 / (N^2 V_t)) = log(-1) + 2 log(|t|) - 2 log(N) - log(V_t).
return 2 * std::log((double) (end_ - start_)) -
2 * std::log((double) totalPoints) -
- arma::accu(arma::log((*max_vals_) - (*min_vals_)));
+ arma::accu(arma::log(maxVals - minVals));
}
// This function finds the best split with respect to the L2-error, by trying
@@ -39,8 +39,8 @@
{
// Ensure the dimensionality of the data is the same as the dimensionality of
// the bounding rectangle.
- assert(data.n_rows == max_vals_->n_elem);
- assert(data.n_rows == min_vals_->n_elem);
+ assert(data.n_rows == maxVals.n_elem);
+ assert(data.n_rows == minVals.n_elem);
const size_t points = end_ - start_;
@@ -48,12 +48,12 @@
bool splitFound = false;
// Loop through each dimension.
- for (size_t dim = 0; dim < max_vals_->n_elem; dim++)
+ for (size_t dim = 0; dim < maxVals.n_elem; dim++)
{
// Have to deal with REAL, INTEGER, NOMINAL data differently, so we have to
// think of how to do that...
- const double min = (*min_vals_)[dim];
- const double max = (*max_vals_)[dim];
+ const double min = minVals[dim];
+ const double max = maxVals[dim];
// If there is nothing to split in this dimension, move on.
if (max - min == 0.0)
@@ -69,11 +69,11 @@
// Find the log volume of all the other dimensions.
double volumeWithoutDim = 0;
- for (size_t i = 0; i < max_vals_->n_elem; ++i)
+ for (size_t i = 0; i < maxVals.n_elem; ++i)
{
- if (((*max_vals_)[i] - (*min_vals_)[i] > 0.0) && (i != dim))
+ if ((maxVals[i] - minVals[i] > 0.0) && (i != dim))
{
- volumeWithoutDim += std::log((*max_vals_)[i] - (*min_vals_)[i]);
+ volumeWithoutDim += std::log(maxVals[i] - minVals[i]);
}
}
@@ -188,30 +188,9 @@
template<typename eT, typename cT>
-void DTree<eT, cT>::GetMaxMinVals_(MatType* data,
- VecType *max_vals,
- VecType *min_vals)
-{
- max_vals->set_size(data->n_rows);
- min_vals->set_size(data->n_rows);
-
- MatType temp_d = arma::trans(*data);
-
- for (size_t i = 0; i < temp_d.n_cols; ++i)
- {
- VecType dim_vals = arma::sort(temp_d.col(i));
- (*min_vals)[i] = dim_vals[0];
- (*max_vals)[i] = dim_vals[dim_vals.n_elem - 1];
- }
-}
-
-
-template<typename eT, typename cT>
DTree<eT, cT>::DTree() :
start_(0),
end_(0),
- max_vals_(NULL),
- min_vals_(NULL),
left_(NULL),
right_(NULL)
{ /* Nothing to do. */ }
@@ -219,13 +198,13 @@
// Root node initializers
template<typename eT, typename cT>
-DTree<eT, cT>::DTree(VecType* max_vals,
- VecType* min_vals,
+DTree<eT, cT>::DTree(const arma::vec& maxVals,
+ const arma::vec& minVals,
size_t total_points) :
start_(0),
end_(total_points),
- max_vals_(max_vals),
- min_vals_(min_vals),
+ maxVals(maxVals),
+ minVals(minVals),
left_(NULL),
right_(NULL)
{
@@ -237,19 +216,33 @@
template<typename eT, typename cT>
-DTree<eT, cT>::DTree(MatType* data) :
+DTree<eT, cT>::DTree(arma::mat& data) :
start_(0),
- end_(data->n_cols),
+ end_(data.n_cols),
left_(NULL),
right_(NULL)
{
- max_vals_ = new VecType();
- min_vals_ = new VecType();
+ maxVals.set_size(data.n_rows);
+ minVals.set_size(data.n_cols);
- GetMaxMinVals_(data, max_vals_, min_vals_);
+ // Initialize to first column; values will be overwritten if necessary.
+ maxVals = data.col(0);
+ minVals = data.col(0);
- error_ = -std::exp(LogNegativeError(data->n_cols));
+ // Loop over data to extract maximum and minimum values in each dimension.
+ for (size_t i = 1; i < data.n_cols; ++i)
+ {
+ for (size_t j = 0; j < data.n_rows; ++j)
+ {
+ if (data(j, i) > maxVals[j])
+ maxVals[j] = data(j, i);
+ if (data(j, i) < minVals[j])
+ minVals[j] = data(j, i);
+ }
+ }
+ error_ = -std::exp(LogNegativeError(data.n_cols));
+
bucket_tag_ = -1;
root_ = true;
}
@@ -257,16 +250,16 @@
// Non-root node initializers
template<typename eT, typename cT>
-DTree<eT, cT>::DTree(VecType* max_vals,
- VecType* min_vals,
+DTree<eT, cT>::DTree(const arma::vec& maxVals,
+ const arma::vec& minVals,
size_t start,
size_t end,
cT error) :
start_(start),
end_(end),
error_(error),
- max_vals_(max_vals),
- min_vals_(min_vals),
+ maxVals(maxVals),
+ minVals(minVals),
left_(NULL),
right_(NULL)
{
@@ -276,15 +269,15 @@
template<typename eT, typename cT>
-DTree<eT, cT>::DTree(VecType* max_vals,
- VecType* min_vals,
+DTree<eT, cT>::DTree(const arma::vec& maxVals,
+ const arma::vec& minVals,
size_t total_points,
size_t start,
size_t end) :
start_(start),
end_(end),
- max_vals_(max_vals),
- min_vals_(min_vals),
+ maxVals(maxVals),
+ minVals(minVals),
left_(NULL),
right_(NULL)
{
@@ -303,12 +296,6 @@
if (right_ != NULL)
delete right_;
-
- if (min_vals_ != NULL)
- delete min_vals_;
-
- if (max_vals_ != NULL)
- delete max_vals_;
}
@@ -320,8 +307,8 @@
size_t maxLeafSize,
size_t minLeafSize)
{
- assert(data->n_rows == max_vals_->n_elem);
- assert(data->n_rows == min_vals_->n_elem);
+ assert(data->n_rows == maxVals.n_elem);
+ assert(data->n_rows == minVals.n_elem);
cT left_g, right_g;
@@ -330,10 +317,10 @@
// Compute the v_t_inv: the inverse of the volume of the node.
cT log_vol_t = 0;
- for (size_t i = 0; i < max_vals_->n_elem; ++i)
- if ((*max_vals_)[i] - (*min_vals_)[i] > 0.0)
+ for (size_t i = 0; i < maxVals.n_elem; ++i)
+ if (maxVals[i] - minVals[i] > 0.0)
// Use log to prevent overflow.
- log_vol_t += (cT) std::log((*max_vals_)[i] - (*min_vals_)[i]);
+ log_vol_t += (cT) std::log(maxVals[i] - minVals[i]);
// Check for overflow.
assert(std::exp(log_vol_t) > 0.0);
@@ -355,13 +342,13 @@
*old_from_new);
// Make max and min vals for the children.
- VecType* max_vals_l = new VecType(*max_vals_);
- VecType* max_vals_r = new VecType(*max_vals_);
- VecType* min_vals_l = new VecType(*min_vals_);
- VecType* min_vals_r = new VecType(*min_vals_);
+ arma::vec max_vals_l(maxVals);
+ arma::vec max_vals_r(maxVals);
+ arma::vec min_vals_l(minVals);
+ arma::vec min_vals_r(minVals);
- (*max_vals_l)[dim] = splitValue;
- (*min_vals_r)[dim] = splitValue;
+ max_vals_l[dim] = splitValue;
+ min_vals_r[dim] = splitValue;
// Store split dim and split val in the node.
split_value_ = splitValue;
@@ -524,7 +511,7 @@
bool DTree<eT, cT>::WithinRange_(VecType* query)
{
for (size_t i = 0; i < query->n_elem; ++i)
- if (((*query)[i] < (*min_vals_)[i]) || ((*query)[i] > (*max_vals_)[i]))
+ if (((*query)[i] < minVals[i]) || ((*query)[i] > maxVals[i]))
return false;
return true;
@@ -534,7 +521,7 @@
template<typename eT, typename cT>
cT DTree<eT, cT>::ComputeValue(VecType* query)
{
- assert(query->n_elem == max_vals_->n_elem);
+ assert(query->n_elem == maxVals.n_elem);
if (root_ == 1) // If we are the root...
// Check if the query is within range.
@@ -601,7 +588,7 @@
template<typename eT, typename cT>
int DTree<eT, cT>::FindBucket(VecType* query)
{
- assert(query->n_elem == max_vals_->n_elem);
+ assert(query->n_elem == maxVals.n_elem);
if (subtree_leaves_ == 1) // If we are a leaf...
{
More information about the mlpack-svn
mailing list