[mlpack-git] master: - DET changes propagated to tests. (0b58fd9)
gitdub at mlpack.org
gitdub at mlpack.org
Tue Nov 1 15:22:39 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/94d14187222231ca29e4f6419c5999c660db4f8a...981ffa2d67d8fe38df6c699589005835fef710ea
>---------------------------------------------------------------
commit 0b58fd9a4542b00e12c4525b86d390154631a9ad
Author: theJonan <ivan at jonan.info>
Date: Fri Oct 14 17:42:25 2016 +0300
- DET changes propagated to tests.
>---------------------------------------------------------------
0b58fd9a4542b00e12c4525b86d390154631a9ad
src/mlpack/methods/det/dtree_impl.hpp | 29 ++++++++++++++++-------------
src/mlpack/tests/det_test.cpp | 27 ++++++++++++---------------
src/mlpack/tests/serialization_test.cpp | 15 ++++++++-------
3 files changed, 36 insertions(+), 35 deletions(-)
diff --git a/src/mlpack/methods/det/dtree_impl.hpp b/src/mlpack/methods/det/dtree_impl.hpp
index 58131fb..2261bbf 100644
--- a/src/mlpack/methods/det/dtree_impl.hpp
+++ b/src/mlpack/methods/det/dtree_impl.hpp
@@ -250,12 +250,12 @@ bool DTree<MatType, VecType, TagType>::FindSplit(const MatType& data,
// Loop through each dimension.
#ifdef _WIN32
- #pragma omp parallel for default(none) \
- shared(minError, splitFound, points, data, minVals, maxVals, minLeafSize, maxLeafSize)
+ #pragma omp parallel for default(shared) \
+ shared(splitValue, splitDim, data)
for (intmax_t dim = 0; dim < (intmax_t) maxVals.n_elem; ++dim)
#else
- #pragma omp parallel for default(none) \
- shared(minError, splitFound, points, data, minVals, maxVals, minLeafSize, maxLeafSize)
+ #pragma omp parallel for default(shared) \
+ shared(splitValue, splitDim, data)
for (size_t dim = 0; dim < maxVals.n_elem; ++dim)
#endif
{
@@ -327,17 +327,20 @@ bool DTree<MatType, VecType, TagType>::FindSplit(const MatType& data,
double actualMinDimError = std::log(minDimError) - 2 * std::log((double) data.n_cols) - volumeWithoutDim;
-#pragma omp critical
+#pragma omp atomic
if ((actualMinDimError > minError) && dimSplitFound)
{
- // Calculate actual error (in logspace) by adding terms back to our
- // estimate.
- minError = actualMinDimError;
- splitDim = dim;
- splitValue = dimSplitValue;
- leftError = std::log(dimLeftError) - 2 * std::log((double) data.n_cols) - volumeWithoutDim;
- rightError = std::log(dimRightError) - 2 * std::log((double) data.n_cols) - volumeWithoutDim;
- splitFound = true;
+#pragma omp critical DTreeFindUpdate
+ {
+ // Calculate actual error (in logspace) by adding terms back to our
+ // estimate.
+ minError = actualMinDimError;
+ splitDim = dim;
+ splitValue = dimSplitValue;
+ leftError = std::log(dimLeftError) - 2 * std::log((double) data.n_cols) - volumeWithoutDim;
+ rightError = std::log(dimRightError) - 2 * std::log((double) data.n_cols) - volumeWithoutDim;
+ splitFound = true;
+ }
} // end if better split found in this dimension.
}
diff --git a/src/mlpack/tests/det_test.cpp b/src/mlpack/tests/det_test.cpp
index 7e32def..2b0ef37 100644
--- a/src/mlpack/tests/det_test.cpp
+++ b/src/mlpack/tests/det_test.cpp
@@ -42,7 +42,7 @@ BOOST_AUTO_TEST_CASE(TestGetMaxMinVals)
<< 5 << 0 << 1 << 7 << 1 << arma::endr
<< 5 << 6 << 7 << 1 << 8 << arma::endr;
- DTree tree(testData);
+ DTree<arma::mat, arma::vec> tree(testData);
BOOST_REQUIRE_EQUAL(tree.maxVals[0], 7);
BOOST_REQUIRE_EQUAL(tree.minVals[0], 3);
@@ -57,7 +57,7 @@ BOOST_AUTO_TEST_CASE(TestComputeNodeError)
arma::vec maxVals("7 7 8");
arma::vec minVals("3 0 1");
- DTree testDTree(maxVals, minVals, 5);
+ DTree<arma::mat, arma::vec> testDTree(maxVals, minVals, 5);
double trueNodeError = -log(4.0) - log(7.0) - log(7.0);
BOOST_REQUIRE_CLOSE((double) testDTree.logNegError, trueNodeError, 1e-10);
@@ -75,7 +75,7 @@ BOOST_AUTO_TEST_CASE(TestWithinRange)
arma::vec maxVals("7 7 8");
arma::vec minVals("3 0 1");
- DTree testDTree(maxVals, minVals, 5);
+ DTree<arma::mat, arma::vec> testDTree(maxVals, minVals, 5);
arma::vec testQuery(3);
testQuery << 4.5 << 2.5 << 2;
@@ -95,11 +95,10 @@ BOOST_AUTO_TEST_CASE(TestFindSplit)
<< 5 << 0 << 1 << 7 << 1 << arma::endr
<< 5 << 6 << 7 << 1 << 8 << arma::endr;
- DTree testDTree(testData);
+ DTree<arma::mat, arma::vec> testDTree(testData);
size_t obDim, trueDim;
- double trueLeftError, obLeftError, trueRightError, obRightError,
- obSplit, trueSplit;
+ double trueLeftError, obLeftError, trueRightError, obRightError, obSplit, trueSplit;
trueDim = 2;
trueSplit = 5.5;
@@ -107,8 +106,7 @@ BOOST_AUTO_TEST_CASE(TestFindSplit)
trueRightError = 2 * log(3.0 / 5.0) - (log(7.0) + log(4.0) + log(2.5));
testDTree.logVolume = log(7.0) + log(4.0) + log(7.0);
- BOOST_REQUIRE(testDTree.FindSplit(testData, obDim, obSplit, obLeftError,
- obRightError, 1));
+ BOOST_REQUIRE(testDTree.FindSplit(testData, obDim, obSplit, obLeftError, obRightError, 1));
BOOST_REQUIRE(trueDim == obDim);
BOOST_REQUIRE_CLOSE(trueSplit, obSplit, 1e-10);
@@ -125,7 +123,7 @@ BOOST_AUTO_TEST_CASE(TestSplitData)
<< 5 << 0 << 1 << 7 << 1 << arma::endr
<< 5 << 6 << 7 << 1 << 8 << arma::endr;
- DTree testDTree(testData);
+ DTree<arma::mat, arma::vec> testDTree(testData);
arma::Col<size_t> oTest(5);
oTest << 1 << 2 << 3 << 4 << 5;
@@ -133,8 +131,7 @@ BOOST_AUTO_TEST_CASE(TestSplitData)
size_t splitDim = 2;
double trueSplitVal = 5.5;
- size_t splitInd = testDTree.SplitData(testData, splitDim, trueSplitVal,
- oTest);
+ size_t splitInd = testDTree.SplitData(testData, splitDim, trueSplitVal, oTest);
BOOST_REQUIRE_EQUAL(splitInd, 2); // 2 points on left side.
@@ -169,7 +166,7 @@ BOOST_AUTO_TEST_CASE(TestGrow)
rlError = 2 * log(1.0 / 5.0) - (log(0.5) + log(4.0) + log(2.5));
rrError = 2 * log(2.0 / 5.0) - (log(6.5) + log(4.0) + log(2.5));
- DTree testDTree(testData);
+ DTree<arma::mat, arma::vec> testDTree(testData);
double alpha = testDTree.Grow(testData, oTest, false, 2, 1);
BOOST_REQUIRE_EQUAL(oTest[0], 0);
@@ -222,7 +219,7 @@ BOOST_AUTO_TEST_CASE(TestPruneAndUpdate)
arma::Col<size_t> oTest(5);
oTest << 0 << 1 << 2 << 3 << 4;
- DTree testDTree(testData);
+ DTree<arma::mat, arma::vec> testDTree(testData);
double alpha = testDTree.Grow(testData, oTest, false, 2, 1);
alpha = testDTree.PruneAndUpdate(alpha, testData.n_cols, false);
@@ -255,7 +252,7 @@ BOOST_AUTO_TEST_CASE(TestComputeValue)
arma::Col<size_t> oTest(5);
oTest << 0 << 1 << 2 << 3 << 4;
- DTree testDTree(testData);
+ DTree<arma::mat, arma::vec> testDTree(testData);
double alpha = testDTree.Grow(testData, oTest, false, 2, 1);
double d1 = (2.0 / 5.0) / exp(log(4.0) + log(7.0) + log(4.5));
@@ -298,7 +295,7 @@ BOOST_AUTO_TEST_CASE(TestVariableImportance)
arma::Col<size_t> oTest(5);
oTest << 0 << 1 << 2 << 3 << 4;
- DTree testDTree(testData);
+ DTree<arma::mat, arma::vec> testDTree(testData);
testDTree.Grow(testData, oTest, false, 2, 1);
arma::vec imps;
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index e6aecc7..d7ea76e 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -853,18 +853,19 @@ BOOST_AUTO_TEST_CASE(SoftmaxRegressionTest)
BOOST_AUTO_TEST_CASE(DETTest)
{
using det::DTree;
+ typedef DTree<arma::mat, arma::vec> DTreeX;
// Create a density estimation tree on a random dataset.
arma::mat dataset = arma::randu<arma::mat>(25, 5000);
- DTree tree(dataset);
+ DTreeX tree(dataset);
arma::mat otherDataset = arma::randu<arma::mat>(5, 100);
- DTree xmlTree, binaryTree, textTree(otherDataset);
+ DTreeX xmlTree, binaryTree, textTree(otherDataset);
SerializeObjectAll(tree, xmlTree, binaryTree, textTree);
- std::stack<DTree*> stack, xmlStack, binaryStack, textStack;
+ std::stack<DTreeX*> stack, xmlStack, binaryStack, textStack;
stack.push(&tree);
xmlStack.push(&xmlTree);
binaryStack.push(&binaryTree);
@@ -873,10 +874,10 @@ BOOST_AUTO_TEST_CASE(DETTest)
while (!stack.empty())
{
// Get the top node from the stack.
- DTree* node = stack.top();
- DTree* xmlNode = xmlStack.top();
- DTree* binaryNode = binaryStack.top();
- DTree* textNode = textStack.top();
+ DTreeX* node = stack.top();
+ DTreeX* xmlNode = xmlStack.top();
+ DTreeX* binaryNode = binaryStack.top();
+ DTreeX* textNode = textStack.top();
stack.pop();
xmlStack.pop();
More information about the mlpack-git mailing list