[mlpack-svn] r12219 - in mlpack/trunk/src/mlpack: methods/det tests
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Wed Apr 4 17:20:48 EDT 2012
Author: pram
Date: 2012-04-04 17:20:47 -0400 (Wed, 04 Apr 2012)
New Revision: 12219
Added:
mlpack/trunk/src/mlpack/tests/det_test.cpp
Modified:
mlpack/trunk/src/mlpack/methods/det/dtree.hpp
mlpack/trunk/src/mlpack/tests/CMakeLists.txt
Log:
Added DET tests (almost all of them are done).
Modified: mlpack/trunk/src/mlpack/methods/det/dtree.hpp
===================================================================
--- mlpack/trunk/src/mlpack/methods/det/dtree.hpp 2012-04-04 21:19:18 UTC (rev 12218)
+++ mlpack/trunk/src/mlpack/methods/det/dtree.hpp 2012-04-04 21:20:47 UTC (rev 12219)
@@ -209,203 +209,6 @@
// for the learned tree.
void ComputeVariableImportance(arma::Col<double> *imps);
- // Public self-test hook for the private helpers; returns false if any check fails.
- bool TestPrivateFunctions() {
-
-
- bool return_flag = true;
-
- // Fixed 3x5 data set: 3 dimensions (rows), 5 points (columns).
- MatType test_data(3,5);
-
- test_data << 4 << 5 << 7 << 3 << 5 << arma::endr
- << 5 << 0 << 1 << 7 << 1 << arma::endr
- << 5 << 6 << 7 << 1 << 8 << arma::endr;
-
- // Save the node state the checks below overwrite; it is restored at the end.
- size_t true_start = start_, true_end = end_;
- VecType* true_max_vals = max_vals_;
- VecType* true_min_vals = min_vals_;
- cT true_error = error_;
-
-
- // Test GetMaxMinVals_: expects the per-row max/min over all columns.
- min_vals_ = NULL;
- max_vals_ = NULL;
- max_vals_ = new VecType();
- min_vals_ = new VecType();
-
- GetMaxMinVals_(&test_data, max_vals_, min_vals_);
-
- if ((*max_vals_)[0] != 7 || (*min_vals_)[0] != 3) {
- Log::Warn << "Test: GetMaxMinVals_ failed." << endl;
- return_flag = false;
- }
-
- if ((*max_vals_)[1] != 7 || (*min_vals_)[1] != 0) {
- Log::Warn << "Test: GetMaxMinVals_ failed." << endl;
- return_flag = false;
- }
-
- if ((*max_vals_)[2] != 8 || (*min_vals_)[2] != 1) {
- Log::Warn << "Test: GetMaxMinVals_ failed." << endl;
- return_flag = false;
- }
-
- // Test ComputeNodeError_: expects -(points in node / total)^2 / box volume.
- start_ = 0;
- end_ = 5;
- cT node_error = ComputeNodeError_(5);
- cT log_vol = (cT) std::log(4) + (cT) std::log(7) + (cT) std::log(7);
- cT true_node_error = -1.0 * std::exp(-log_vol);
-
- if (std::abs(node_error - true_node_error) > 1e-7) {
- Log::Warn << "Test: True error : " << true_node_error
- << ", Computed error: " << node_error
- << ", diff: " << std::abs(node_error - true_node_error)
- << endl;
- return_flag = false;
- }
-
- start_ = 3;
- end_ = 5;
- node_error = ComputeNodeError_(5);
- true_node_error = -1.0 * std::exp(2 * std::log((cT) 2 / (cT) 5) - log_vol);
-
- if (std::abs(node_error - true_node_error) > 1e-7) {
- Log::Warn << "Test: True error : " << true_node_error
- << ", Computed error: " << node_error
- << ", diff: " << std::abs(node_error - true_node_error)
- << endl;
- return_flag = false;
- }
-
- // Test WithinRange_: one query inside the box, one outside it in dim 0.
-
- VecType test_query(3);
- test_query << 4.5 << 2.5 << 2;
-
- if (!WithinRange_(&test_query)) {
- Log::Warn << "Test: WithinRange_ failed" << endl;
- return_flag = false;
- }
-
- test_query << 8.5 << 2.5 << 2;
-
- if (WithinRange_(&test_query)) {
- Log::Warn << "Test: WithinRange_ failed" << endl;
- return_flag = false;
- }
-
- // Test FindSplit_: the best split is expected on dim 2 at index 1.
- start_ = 0;
- end_ = 5;
- error_ = ComputeNodeError_(5);
-
- size_t ob_dim, true_dim, ob_ind, true_ind;
- cT true_left_error, ob_left_error, true_right_error, ob_right_error;
-
- true_dim = 2;
- true_ind = 1;
- true_left_error = -1.0 * std::exp(2 * std::log((cT) 2 / (cT) 5)
- - (std::log((cT) 7) + std::log((cT) 4)
- + std::log((cT) 4.5)));
- true_right_error = -1.0 * std::exp(2 * std::log((cT) 3 / (cT) 5)
- - (std::log((cT) 7) + std::log((cT) 4)
- + std::log((cT) 2.5)));
-
- if(!FindSplit_(&test_data, &ob_dim, &ob_ind,
- &ob_left_error, &ob_right_error, 2, 1)) {
- Log::Warn << "Test: FindSplit_ returns false." << endl;
- return_flag = false;
- }
-
- if (true_dim != ob_dim) {
- Log::Warn << "Test: FindSplit_ - True dim: " << true_dim
- << ", Obtained dim: " << ob_dim << endl;
- return_flag = false;
- }
-
- if (true_ind != ob_ind) {
- Log::Warn << "Test: FindSplit_ - True ind: " << true_ind
- << ", Obtained ind: " << ob_ind << endl;
- return_flag = false;
- }
-
- if (std::abs(true_left_error - ob_left_error) > 1e-7) {
- Log::Warn << "Test: FindSplit_ - True left_error: " << true_left_error
- << ", Obtained left_error: " << ob_left_error
- << ", diff: " << std::abs(true_left_error - ob_left_error)
- << endl;
- return_flag = false;
- }
-
- if (std::abs(true_right_error - ob_right_error) > 1e-7) {
- Log::Warn << "Test: FindSplit_ - True right_error: " << true_right_error
- << ", Obtained right_error: " << ob_right_error
- << ", diff: " << std::abs(true_right_error - ob_right_error)
- << endl;
- return_flag = false;
- }
-
- // Test SplitData_: the split should reorder o_test to 1,4,3,2,5.
- MatType split_test_data(test_data);
- arma::Col<size_t> o_test(5);
- o_test << 1 << 2 << 3 << 4 << 5;
-
- start_ = 0;
- end_ = 5;
- size_t split_dim = 2, split_ind = 1;
- eT true_split_val, ob_split_val, true_lsplit_val, ob_lsplit_val,
- true_rsplit_val, ob_rsplit_val;
-
- true_lsplit_val = 5;
- true_rsplit_val = 6;
- true_split_val = (true_lsplit_val + true_rsplit_val) / 2;
-
- SplitData_(&split_test_data, split_dim, split_ind,
- &o_test, &ob_split_val,
- &ob_lsplit_val, &ob_rsplit_val);
-
- if (o_test[0] != 1 || o_test[1] != 4 || o_test[2] != 3
- || o_test[3] != 2 || o_test[4] != 5) {
- Log::Warn << "Test: SplitData_ - OFW should be 1,4,3,2,5"
- << ", is " << o_test.t();
- return_flag = false;
- }
-
- if (true_split_val != ob_split_val) {
- Log::Warn << "Test: SplitData_ - True split val: " << true_split_val
- << ", Ob split val: " << ob_split_val << endl;
- return_flag = false;
- }
-
- if (true_lsplit_val != ob_lsplit_val) {
- Log::Warn << "Test: SplitData_ - True lsplit val: " << true_lsplit_val
- << ", Ob lsplit val: " << ob_lsplit_val << endl;
- return_flag = false;
- }
-
- if (true_rsplit_val != ob_rsplit_val) {
- Log::Warn << "Test: SplitData_ - True rsplit val: " << true_rsplit_val
- << ", Ob rsplit val: " << ob_rsplit_val << endl;
- return_flag = false;
- }
-
-
- // Free the temporary bounds and restore the saved node state.
- delete max_vals_;
- delete min_vals_;
- max_vals_ = true_max_vals;
- min_vals_ = true_min_vals;
- start_ = true_start;
- end_ = true_end;
- error_ = true_error;
-
- return return_flag;
-
- } // TestPrivateFunctions
-
}; // Class DTree
}; // namespace det
Modified: mlpack/trunk/src/mlpack/tests/CMakeLists.txt
===================================================================
--- mlpack/trunk/src/mlpack/tests/CMakeLists.txt 2012-04-04 21:19:18 UTC (rev 12218)
+++ mlpack/trunk/src/mlpack/tests/CMakeLists.txt 2012-04-04 21:20:47 UTC (rev 12219)
@@ -31,6 +31,7 @@
sparse_coding_test.cpp
tree_test.cpp
union_find_test.cpp
+ det_test.cpp
)
# Link dependencies of test executable.
target_link_libraries(mlpack_test
Added: mlpack/trunk/src/mlpack/tests/det_test.cpp
===================================================================
--- mlpack/trunk/src/mlpack/tests/det_test.cpp (rev 0)
+++ mlpack/trunk/src/mlpack/tests/det_test.cpp 2012-04-04 21:20:47 UTC (rev 12219)
@@ -0,0 +1,455 @@
+/**
+ * @file det_test.cpp
+ * @author Parikshit Ram (pram at cc.gatech.edu)
+ *
+ * Unit tests for the functions of the class DTree
+ * and the utility functions using this class.
+ */
+
+#define protected public
+#define private public
+#include <mlpack/methods/det/dtree.hpp>
+#include <mlpack/methods/det/dt_utils.hpp>
+#undef protected
+#undef private
+
+#include <mlpack/core.hpp>
+#include <boost/test/unit_test.hpp>
+
+using namespace mlpack;
+using namespace mlpack::det;
+using namespace std;
+
+BOOST_AUTO_TEST_SUITE(DETTest);
+
+// Tests for the member functions of the DTree class.
+
+// Single-precision matrix/vector types used by every test. Pointers to these
+// are passed straight into DTree<> constructors and methods, so they must be
+// the same types that DTree<> uses with its default template arguments.
+typedef arma::Mat<float> MatType;
+typedef arma::Col<float> VecType;
+
+
+// Tests of the private helpers (reachable only because private/protected are
+// #defined to public before the DTree headers are included).
+
+// GetMaxMinVals_ must fill, for every dimension (row), the largest and the
+// smallest value found over all points (columns) of the data set.
+BOOST_AUTO_TEST_CASE(TestGetMaxMinVals)
+{
+ // Automatic storage: the tree and the bound vectors are released even when
+ // a BOOST_REQUIRE below fails and aborts the test case by throwing.
+ DTree<> testDTree;
+
+ MatType test_data(3,5);
+
+ test_data << 4 << 5 << 7 << 3 << 5 << arma::endr
+ << 5 << 0 << 1 << 7 << 1 << arma::endr
+ << 5 << 6 << 7 << 1 << 8 << arma::endr;
+
+ // Pure outputs of GetMaxMinVals_; they are never handed to a tree, so if
+ // they were heap-allocated nothing would ever free them.
+ VecType max_vals;
+ VecType min_vals;
+
+ testDTree.GetMaxMinVals_(&test_data, &max_vals, &min_vals);
+
+ BOOST_REQUIRE(max_vals[0] == 7);
+ BOOST_REQUIRE(min_vals[0] == 3);
+ BOOST_REQUIRE(max_vals[1] == 7);
+ BOOST_REQUIRE(min_vals[1] == 0);
+ BOOST_REQUIRE(max_vals[2] == 8);
+ BOOST_REQUIRE(min_vals[2] == 1);
+}
+
+// The node error is checked against -(points in node / total points)^2 divided
+// by the volume of the node's bounding box, evaluated in log space. The box
+// widths are 7-3=4, 7-0=7 and 8-1=7. The logs are taken in float precision,
+// presumably to mirror DTree's own arithmetic -- TODO confirm in dtree_impl.
+BOOST_AUTO_TEST_CASE(TestComputeNodeError)
+{
+ VecType* max_vals = new VecType(3);
+ VecType* min_vals = new VecType(3);
+
+ *max_vals << 7 << 7 << 8;
+ *min_vals << 3 << 0 << 1;
+
+ // NOTE(review): max_vals/min_vals are not deleted here, so the tree is
+ // assumed to take ownership of them -- confirm against DTree's destructor.
+ DTree<>* testDTree = new DTree<>(max_vals, min_vals, 5);
+ long double true_node_error = -1.0 * exp(-(long double) log((float) 4.0)
+ - (long double) log((float) 7.0)
+ - (long double) log((float) 7.0));
+
+ // The constructor is expected to have computed the error for all 5 points.
+ // BOOST_REQUIRE_CLOSE's tolerance is a percentage.
+ BOOST_REQUIRE_CLOSE(testDTree->error_, true_node_error, 1e-10);
+
+ // Restrict the node to points [3, 5), i.e. 2 of the 5 points.
+ testDTree->start_ = 3;
+ testDTree->end_ = 5;
+
+ long double node_error = testDTree->ComputeNodeError_(5);
+ true_node_error = -1.0 * exp(2 * log((long double) 2 / (long double) 5)
+ -(long double) log((float) 4.0)
+ - (long double) log((float) 7.0)
+ - (long double) log((float) 7.0));
+ BOOST_REQUIRE_CLOSE(node_error, true_node_error, 1e-10);
+
+ delete testDTree;
+}
+
+// WithinRange_ must accept a query lying inside the node's bounding box
+// [3,7] x [0,7] x [1,8] and reject one lying outside it.
+BOOST_AUTO_TEST_CASE(TestWithinRange)
+{
+ VecType* upper = new VecType(3);
+ VecType* lower = new VecType(3);
+
+ *upper << 7 << 7 << 8;
+ *lower << 3 << 0 << 1;
+
+ DTree<>* tree = new DTree<>(upper, lower, 5);
+
+ VecType query(3);
+
+ // Inside the box in every dimension.
+ query << 4.5 << 2.5 << 2;
+ BOOST_REQUIRE(tree->WithinRange_(&query));
+
+ // Dimension 0 (8.5) lies above its upper bound of 7.
+ query << 8.5 << 2.5 << 2;
+ BOOST_REQUIRE(!tree->WithinRange_(&query));
+
+ delete tree;
+}
+
+// FindSplit_ on the full 3x5 data set is expected to pick dimension 2 at index
+// 1. Sorted dim-2 values are 1,5 | 6,7,8, so the children hold 2 and 3 points
+// and their dim-2 widths are 4.5 and 2.5 (the other widths stay 4 and 7).
+// Each expected error is -(points in child / 5)^2 / child volume.
+BOOST_AUTO_TEST_CASE(TestFindSplit)
+{
+ MatType test_data(3,5);
+
+ test_data << 4 << 5 << 7 << 3 << 5 << arma::endr
+ << 5 << 0 << 1 << 7 << 1 << arma::endr
+ << 5 << 6 << 7 << 1 << 8 << arma::endr;
+
+ DTree<>* testDTree = new DTree<>(&test_data);
+
+ size_t ob_dim, true_dim, ob_ind, true_ind;
+ long double true_left_error, ob_left_error,
+ true_right_error, ob_right_error;
+
+ true_dim = 2;
+ true_ind = 1;
+ true_left_error = -1.0 * exp(2 * log((long double) 2
+ / (long double) 5)
+ - ((long double) log((float) 7)
+ + (long double) log((float) 4)
+ + (long double) log((float) 4.5)));
+ true_right_error = -1.0 * exp(2 * log((long double) 3
+ / (long double) 5)
+ - ((long double) log((float) 7)
+ + (long double) log((float) 4)
+ + (long double) log((float) 2.5)));
+
+ // The trailing 2 and 1 are presumably leaf-size limits -- TODO confirm.
+ BOOST_REQUIRE(testDTree->FindSplit_
+ (&test_data, &ob_dim, &ob_ind, &ob_left_error,
+ &ob_right_error, 2, 1));
+
+ BOOST_REQUIRE(true_dim == ob_dim);
+ BOOST_REQUIRE(true_ind == ob_ind);
+
+ // Tolerances are percentages.
+ BOOST_REQUIRE_CLOSE(true_left_error, ob_left_error, 1e-10);
+ BOOST_REQUIRE_CLOSE(true_right_error, ob_right_error, 1e-10);
+
+ delete testDTree;
+}
+
+// SplitData_ on dimension 2 at sorted index 1 must move the two points with
+// the smallest dim-2 values to the front, permute the label vector the same
+// way, and report the split value midway between the neighbouring values.
+BOOST_AUTO_TEST_CASE(TestSplitData)
+{
+ MatType test_data(3,5);
+
+ test_data << 4 << 5 << 7 << 3 << 5 << arma::endr
+ << 5 << 0 << 1 << 7 << 1 << arma::endr
+ << 5 << 6 << 7 << 1 << 8 << arma::endr;
+
+ DTree<>* tree = new DTree<>(&test_data);
+
+ // 1-based point labels; they must follow their columns as the data moves.
+ arma::Col<size_t> labels(5);
+ labels << 1 << 2 << 3 << 4 << 5;
+
+ size_t dim = 2;
+ size_t index = 1;
+
+ // The split is expected between the dim-2 values 5 and 6.
+ const float expected_left = 5;
+ const float expected_right = 6;
+ const float expected_split = (expected_left + expected_right) / 2;
+
+ float split_val, left_val, right_val;
+
+ tree->SplitData_(&test_data, dim, index,
+ &labels, &split_val,
+ &left_val, &right_val);
+
+ // Points 1 and 4 (dim-2 values 5 and 1) go left; 3, 2 and 5 go right.
+ BOOST_REQUIRE(labels[0] == 1 && labels[1] == 4
+ && labels[2] == 3 && labels[3] == 2
+ && labels[4] == 5);
+
+ BOOST_REQUIRE(split_val == expected_split);
+ BOOST_REQUIRE(left_val == expected_left);
+ BOOST_REQUIRE(right_val == expected_right);
+
+ delete tree;
+}
+
+// Tests of the public interface.
+
+// Grow() on the 3x5 data set (trailing arguments 2 and 1, presumably leaf-size
+// limits -- TODO confirm) is expected to build:
+//   root : split dim 2 at 5.5 -> left leaf {cols 0,3}, right node {cols 1,2,4}
+//   right: split dim 1 at 0.5 -> leaf {col 1}, leaf {cols 2,4}
+// Every expected error is -(points in node / 5)^2 / node volume.
+BOOST_AUTO_TEST_CASE(TestGrow)
+{
+ MatType test_data(3,5);
+
+ test_data << 4 << 5 << 7 << 3 << 5 << arma::endr
+ << 5 << 0 << 1 << 7 << 1 << arma::endr
+ << 5 << 6 << 7 << 1 << 8 << arma::endr;
+
+ arma::Col<size_t> o_test(5);
+ o_test << 0 << 1 << 2 << 3 << 4;
+
+ long double root_error, l_error, r_error, rl_error, rr_error;
+
+ // Root box widths: 4 x 7 x 7.
+ root_error = -1.0 * exp(-(long double) log((float) 4.0)
+ - (long double) log((float) 7.0)
+ - (long double) log((float) 7.0));
+
+ // Left leaf: 2 points, dim-2 width 4.5 ([1, 5.5]).
+ l_error = -1.0 * exp(2 * log((long double) 2
+ / (long double) 5)
+ - ((long double) log((float) 7)
+ + (long double) log((float) 4)
+ + (long double) log((float) 4.5)));
+ // Right node: 3 points, dim-2 width 2.5 ([5.5, 8]).
+ r_error = -1.0 * exp(2 * log((long double) 3
+ / (long double) 5)
+ - ((long double) log((float) 7)
+ + (long double) log((float) 4)
+ + (long double) log((float) 2.5)));
+
+ // Right-left leaf: 1 point, dim-1 width 0.5 ([0, 0.5]).
+ rl_error = -1.0 * exp(2 * log((long double) 1
+ / (long double) 5)
+ - ((long double) log((float) 0.5)
+ + (long double) log((float) 4)
+ + (long double) log((float) 2.5)));
+
+ // Right-right leaf: 2 points, dim-1 width 6.5 ([0.5, 7]).
+ rr_error = -1.0 * exp(2 * log((long double) 2
+ / (long double) 5)
+ - ((long double) log((float) 6.5)
+ + (long double) log((float) 4)
+ + (long double) log((float) 2.5)));
+
+ DTree<>* testDTree = new DTree<>(&test_data);
+ long double alpha = testDTree->Grow(&test_data, &o_test,
+ false, 2, 1);
+
+ // o_test is permuted along with the data columns.
+ BOOST_REQUIRE(o_test[0] == 0 && o_test[1] == 3
+ && o_test[2] == 1 && o_test[3] == 2
+ && o_test[4] == 4);
+
+ // test the structure of the tree
+ BOOST_REQUIRE(testDTree->left()->left() == NULL);
+ BOOST_REQUIRE(testDTree->left()->right() == NULL);
+ BOOST_REQUIRE(testDTree->right()->left()->left() == NULL);
+ BOOST_REQUIRE(testDTree->right()->left()->right() == NULL);
+ BOOST_REQUIRE(testDTree->right()->right()->left() == NULL);
+ BOOST_REQUIRE(testDTree->right()->right()->right() == NULL);
+
+ BOOST_REQUIRE(testDTree->subtree_leaves() == 3);
+
+ BOOST_REQUIRE(testDTree->split_dim() == 2);
+ BOOST_REQUIRE_CLOSE(testDTree->split_value(), (float) 5.5, (float) 1e-5);
+ BOOST_REQUIRE(testDTree->right()->split_dim() == 1);
+ BOOST_REQUIRE_CLOSE(testDTree->right()->split_value(),
+ (float) 0.5, (float) 1e-5);
+
+ // test node errors for every node (tolerances are percentages)
+ BOOST_REQUIRE_CLOSE(testDTree->error_, root_error, 1e-10);
+ BOOST_REQUIRE_CLOSE(testDTree->left()->error_, l_error, 1e-10);
+ BOOST_REQUIRE_CLOSE(testDTree->right()->error_, r_error, 1e-10);
+ BOOST_REQUIRE_CLOSE(testDTree->right()->left()->error_, rl_error, 1e-10);
+ BOOST_REQUIRE_CLOSE(testDTree->right()->right()->error_, rr_error, 1e-10);
+
+
+ // test alpha: the smallest, over internal nodes, of
+ // (node error - sum of leaf errors below it) / (leaves below it - 1).
+ long double root_alpha, r_alpha;
+ root_alpha = (root_error - (l_error + rl_error + rr_error)) / 2;
+ r_alpha = r_error - (rl_error + rr_error);
+
+ BOOST_REQUIRE_CLOSE(alpha, min(root_alpha, r_alpha), 1e-10);
+
+ delete testDTree;
+}
+
+// After Grow(), one PruneAndUpdate() pass at the returned alpha is expected to
+// collapse the whole tree into a single leaf that carries the root error.
+BOOST_AUTO_TEST_CASE(TestPruneAndUpdate)
+{
+ MatType test_data(3,5);
+
+ test_data << 4 << 5 << 7 << 3 << 5 << arma::endr
+ << 5 << 0 << 1 << 7 << 1 << arma::endr
+ << 5 << 6 << 7 << 1 << 8 << arma::endr;
+
+ arma::Col<size_t> o_test(5);
+ o_test << 0 << 1 << 2 << 3 << 4;
+ DTree<>* testDTree = new DTree<>(&test_data);
+ long double alpha = testDTree->Grow(&test_data, &o_test,
+ false, 2, 1);
+ alpha = testDTree->PruneAndUpdate(alpha, false);
+
+ // The returned alpha is expected to be max(), presumably because no
+ // internal node is left to prune.
+ BOOST_REQUIRE_CLOSE(alpha, numeric_limits<long double>::max(), 1e-10);
+ BOOST_REQUIRE(testDTree->subtree_leaves() == 1);
+
+ // Root box widths: 4 x 7 x 7; all 5 points in the node.
+ long double root_error = -1.0 * exp(-(long double) log((float) 4.0)
+ - (long double) log((float) 7.0)
+ - (long double) log((float) 7.0));
+
+ BOOST_REQUIRE_CLOSE(testDTree->error(), root_error, 1e-10);
+ BOOST_REQUIRE_CLOSE(testDTree->subtree_leaves_error(), root_error, 1e-10);
+ BOOST_REQUIRE(testDTree->left() == NULL);
+ BOOST_REQUIRE(testDTree->right() == NULL);
+
+ delete testDTree;
+}
+
+// ComputeValue() must return the density of the leaf that contains the query,
+// (points in leaf / 5) / leaf volume, and 0 for a query outside the root box.
+// The tree layout is the one checked in TestGrow: root split on dim 2 at 5.5,
+// right child split on dim 1 at 0.5.
+BOOST_AUTO_TEST_CASE(TestComputeValue)
+{
+ MatType test_data(3,5);
+
+ test_data << 4 << 5 << 7 << 3 << 5 << arma::endr
+ << 5 << 0 << 1 << 7 << 1 << arma::endr
+ << 5 << 6 << 7 << 1 << 8 << arma::endr;
+
+ VecType q1(3), q2(3), q3(3), q4(3);
+
+ q1 << 4 << 2 << 2; // dim 2 <= 5.5: left leaf
+ q2 << 5 << 0.25 << 6; // dim 2 > 5.5, dim 1 <= 0.5: right-left leaf
+ q3 << 5 << 3 << 7; // dim 2 > 5.5, dim 1 > 0.5: right-right leaf
+ q4 << 2 << 3 << 3; // dim 0 below the lower bound 3: outside the root box
+
+ arma::Col<size_t> o_test(5);
+ o_test << 0 << 1 << 2 << 3 << 4;
+
+ DTree<>* testDTree = new DTree<>(&test_data);
+ long double alpha = testDTree->Grow(&test_data, &o_test,
+ false, 2, 1);
+
+ long double d1, d2, d3;
+ d1 = ((long double) 2 / (long double) 5)
+ / exp((long double) log((float) 4) + (long double) log((float) 7)
+ + (long double) log((float) 4.5));
+
+ d2 = ((long double) 1 / (long double) 5)
+ / exp((long double) log((float) 4) + (long double) log((float) 0.5)
+ + (long double) log((float) 2.5));
+
+ d3 = ((long double) 2 / (long double) 5)
+ / exp((long double) log((float) 4) + (long double) log((float) 6.5)
+ + (long double) log((float) 2.5));
+
+ BOOST_REQUIRE_CLOSE(d1, testDTree->ComputeValue(&q1), 1e-10);
+ BOOST_REQUIRE_CLOSE(d2, testDTree->ComputeValue(&q2), 1e-10);
+ BOOST_REQUIRE_CLOSE(d3, testDTree->ComputeValue(&q3), 1e-10);
+ BOOST_REQUIRE_CLOSE((long double) 0.0, testDTree->ComputeValue(&q4), 1e-10);
+
+ alpha = testDTree->PruneAndUpdate(alpha, false);
+
+ // After pruning to a single leaf, every in-box query gets the uniform
+ // density 1 / (4 * 7 * 7).
+ long double d = 1.0
+ / exp((long double) log((float) 4) + (long double) log((float) 7)
+ + (long double) log((float) 7));
+
+ BOOST_REQUIRE_CLOSE(d, testDTree->ComputeValue(&q1), 1e-10);
+ BOOST_REQUIRE_CLOSE(d, testDTree->ComputeValue(&q2), 1e-10);
+ BOOST_REQUIRE_CLOSE(d, testDTree->ComputeValue(&q3), 1e-10);
+ BOOST_REQUIRE_CLOSE((long double) 0.0, testDTree->ComputeValue(&q4), 1e-10);
+
+ delete testDTree;
+}
+
+// ComputeVariableImportance() is expected to credit each dimension with the
+// error reduction (node error - sum of its children's errors) of the splits
+// made on it. Dim 0 is never split, dim 1 is split once (at the right child)
+// and dim 2 once (at the root). Node errors match those in TestGrow.
+BOOST_AUTO_TEST_CASE(TestVariableImportance)
+{
+ MatType test_data(3,5);
+
+ test_data << 4 << 5 << 7 << 3 << 5 << arma::endr
+ << 5 << 0 << 1 << 7 << 1 << arma::endr
+ << 5 << 6 << 7 << 1 << 8 << arma::endr;
+
+ long double root_error, l_error, r_error, rl_error, rr_error;
+
+ root_error = -1.0 * exp(-(long double) log((float) 4.0)
+ - (long double) log((float) 7.0)
+ - (long double) log((float) 7.0));
+
+ l_error = -1.0 * exp(2 * log((long double) 2
+ / (long double) 5)
+ - ((long double) log((float) 7)
+ + (long double) log((float) 4)
+ + (long double) log((float) 4.5)));
+ r_error = -1.0 * exp(2 * log((long double) 3
+ / (long double) 5)
+ - ((long double) log((float) 7)
+ + (long double) log((float) 4)
+ + (long double) log((float) 2.5)));
+
+ rl_error = -1.0 * exp(2 * log((long double) 1
+ / (long double) 5)
+ - ((long double) log((float) 0.5)
+ + (long double) log((float) 4)
+ + (long double) log((float) 2.5)));
+
+ rr_error = -1.0 * exp(2 * log((long double) 2
+ / (long double) 5)
+ - ((long double) log((float) 6.5)
+ + (long double) log((float) 4)
+ + (long double) log((float) 2.5)));
+
+ arma::Col<size_t> o_test(5);
+ o_test << 0 << 1 << 2 << 3 << 4;
+
+ DTree<>* testDTree = new DTree<>(&test_data);
+ testDTree->Grow(&test_data, &o_test,
+ false, 2, 1);
+
+ // Zeroed first, presumably because the function accumulates into imps.
+ arma::vec imps(3);
+ imps.zeros();
+
+ testDTree->ComputeVariableImportance(&imps);
+
+ BOOST_REQUIRE_CLOSE((double) 0.0, imps[0], 1e-10);
+ BOOST_REQUIRE_CLOSE((double) (r_error - (rl_error + rr_error)),
+ imps[1], 1e-10);
+ BOOST_REQUIRE_CLOSE((double) (root_error - (l_error + r_error)),
+ imps[2], 1e-10);
+
+ delete testDTree;
+}
+
+// Placeholder: so far this only checks that a tree can be built from the data
+// set and destroyed. It emits a Boost.Test warning (shown with
+// --log_level=warning) so the missing coverage is not silently reported as a pass.
+BOOST_AUTO_TEST_CASE(TestTagTree)
+{
+ MatType test_data(3,5);
+
+ test_data << 4 << 5 << 7 << 3 << 5 << arma::endr
+ << 5 << 0 << 1 << 7 << 1 << arma::endr
+ << 5 << 6 << 7 << 1 << 8 << arma::endr;
+
+ DTree<>* testDTree = new DTree<>(&test_data);
+
+ // TODO: add real checks for tree tagging.
+ BOOST_WARN_MESSAGE(false, "TestTagTree is a placeholder; nothing is checked yet.");
+
+ delete testDTree;
+}
+
+// Placeholder: so far this only checks that a tree can be built from the data
+// set and destroyed. It emits a Boost.Test warning (shown with
+// --log_level=warning) so the missing coverage is not silently reported as a pass.
+BOOST_AUTO_TEST_CASE(TestFindBucket)
+{
+ MatType test_data(3,5);
+
+ test_data << 4 << 5 << 7 << 3 << 5 << arma::endr
+ << 5 << 0 << 1 << 7 << 1 << arma::endr
+ << 5 << 6 << 7 << 1 << 8 << arma::endr;
+
+ DTree<>* testDTree = new DTree<>(&test_data);
+
+ // TODO: add real checks for bucket lookup.
+ BOOST_WARN_MESSAGE(false, "TestFindBucket is a placeholder; nothing is checked yet.");
+
+ delete testDTree;
+}
+
+// Tests for the functions in dt_utils.hpp.
+// These are placeholders with no checks yet. Each one emits a Boost.Test
+// warning (shown with --log_level=warning) so the missing coverage is not
+// silently reported as a pass.
+
+BOOST_AUTO_TEST_CASE(TestTrainer)
+{
+ // TODO: add real checks for the dt_utils trainer.
+ BOOST_WARN_MESSAGE(false, "TestTrainer is a placeholder; nothing is checked yet.");
+}
+
+BOOST_AUTO_TEST_CASE(TestPrintVariableImportance)
+{
+ // TODO: add real checks for printing variable importance.
+ BOOST_WARN_MESSAGE(false, "TestPrintVariableImportance is a placeholder; nothing is checked yet.");
+}
+
+BOOST_AUTO_TEST_CASE(TestPrintLeafMembership)
+{
+ // TODO: add real checks for printing leaf membership.
+ BOOST_WARN_MESSAGE(false, "TestPrintLeafMembership is a placeholder; nothing is checked yet.");
+}
+
+BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-svn
mailing list