[mlpack-svn] r13244 - mlpack/trunk/src/mlpack/methods/det
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Jul 18 16:41:52 EDT 2012
Author: rcurtin
Date: 2012-07-18 16:41:52 -0400 (Wed, 18 Jul 2012)
New Revision: 13244
Modified:
mlpack/trunk/src/mlpack/methods/det/dt_main.cpp
mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp
mlpack/trunk/src/mlpack/methods/det/dtree.hpp
mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
Log:
Use double as the return type of ComputeValue(), take the query point by const
reference instead of by pointer, and make the member function const.
Modified: mlpack/trunk/src/mlpack/methods/det/dt_main.cpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dt_main.cpp 2012-07-16 19:36:33 UTC (rev 13243)
+++ mlpack/trunk/src/mlpack/methods/det/dt_main.cpp 2012-07-18 20:41:52 UTC (rev 13244)
@@ -125,7 +125,7 @@
Timer::Start("DET/EstimationTime");
for (size_t i = 0; i < training_data.n_cols; i++) {
arma::Col<double> test_p = training_data.unsafe_col(i);
- long double f = dtree_opt->ComputeValue(&test_p);
+ long double f = dtree_opt->ComputeValue(test_p);
if (fp != NULL)
fprintf(fp, "%Lg\n", f);
} // end for
@@ -160,7 +160,7 @@
Timer::Start("DET/TestSetEstimation");
for (size_t i = 0; i < test_data.n_cols; i++) {
arma::Col<double> test_p = test_data.unsafe_col(i);
- long double f = dtree_opt->ComputeValue(&test_p);
+ long double f = dtree_opt->ComputeValue(test_p);
if (fp != NULL)
fprintf(fp, "%Lg\n", f);
} // end for
Modified: mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp 2012-07-16 19:36:33 UTC (rev 13243)
+++ mlpack/trunk/src/mlpack/methods/det/dt_utils.hpp 2012-07-18 20:41:52 UTC (rev 13244)
@@ -148,7 +148,7 @@
for (size_t i = 0; i < dataset->n_cols; ++i)
{
arma::Col<eT> test_p = dataset->unsafe_col(i);
- outfile << dtree->ComputeValue(&test_p) << endl;
+ outfile << dtree->ComputeValue(test_p) << endl;
}
}
else
@@ -243,7 +243,7 @@
for (size_t i = 0; i < test.n_cols; i++)
{
arma::Col<eT> test_point = test.unsafe_col(i);
- val_cv += dtree_cv->ComputeValue(&test_point);
+ val_cv += dtree_cv->ComputeValue(test_point);
}
// Update the cv error value.
@@ -259,7 +259,7 @@
for (size_t i = 0; i < test.n_cols; ++i)
{
arma::Col<eT> test_point = test.unsafe_col(i);
- val_cv += dtree_cv->ComputeValue(&test_point);
+ val_cv += dtree_cv->ComputeValue(test_point);
}
// Update the cv error value.
Modified: mlpack/trunk/src/mlpack/methods/det/dtree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree.hpp 2012-07-16 19:36:33 UTC (rev 13243)
+++ mlpack/trunk/src/mlpack/methods/det/dtree.hpp 2012-07-18 20:41:52 UTC (rev 13244)
@@ -204,7 +204,7 @@
double PruneAndUpdate(const double old_alpha, const bool useVolReg = false);
// compute the density at a given point
- cT ComputeValue(VecType* query);
+ double ComputeValue(const arma::vec& query) const;
// print the tree (in a DFS manner)
void WriteTree(size_t level, FILE *fp);
Modified: mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp 2012-07-16 19:36:33 UTC (rev 13243)
+++ mlpack/trunk/src/mlpack/methods/det/dtree_impl.hpp 2012-07-18 20:41:52 UTC (rev 13244)
@@ -492,24 +492,32 @@
template<typename eT, typename cT>
-cT DTree<eT, cT>::ComputeValue(VecType* query)
+double DTree<eT, cT>::ComputeValue(const arma::vec& query) const
{
- assert(query->n_elem == maxVals.n_elem);
+ Log::Assert(query.n_elem == maxVals.n_elem);
if (root_ == 1) // If we are the root...
+ {
// Check if the query is within range.
- if (!WithinRange(*query))
+ if (!WithinRange(query))
return 0.0;
+ }
if (subtree_leaves_ == 1) // If we are a leaf...
+ {
return ratio_ * v_t_inv_;
+ }
else
{
- if ((*query)[split_dim_] <= split_value_)
+ if (query[split_dim_] <= split_value_)
+ {
// If left subtree, go to left child.
return left_->ComputeValue(query);
+ }
else // If right subtree, go to right child
+ {
return right_->ComputeValue(query);
+ }
}
}
More information about the mlpack-svn
mailing list