[mlpack-git] master: Merge branch 'master' of https://github.com/mlpack/mlpack (de88672)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Fri Oct 2 19:20:49 EDT 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/7a8b0e1292677b71888fad313772c63bcf0e7b80...de88672879a1893ebfc131538c64e7755251337c
>---------------------------------------------------------------
commit de88672879a1893ebfc131538c64e7755251337c
Merge: 0e910b1 7a8b0e1
Author: Ryan Curtin <ryan at ratml.org>
Date: Fri Oct 2 23:20:25 2015 +0000
Merge branch 'master' of https://github.com/mlpack/mlpack
>---------------------------------------------------------------
de88672879a1893ebfc131538c64e7755251337c
COPYRIGHT.txt | 1 +
src/mlpack/core.hpp | 1 +
.../softmax_regression/softmax_regression.hpp | 73 ++++++++++++---
.../softmax_regression_function.cpp | 62 ++++++++-----
.../softmax_regression_function.hpp | 55 ++++++++----
.../softmax_regression/softmax_regression_impl.hpp | 84 +++++++++++------
src/mlpack/tests/serialization_test.cpp | 25 +++++-
src/mlpack/tests/softmax_regression_test.cpp | 100 +++++++++++++++++----
8 files changed, 303 insertions(+), 98 deletions(-)
diff --cc src/mlpack/tests/serialization_test.cpp
index 424126c,18b8e81..316bdf6
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@@ -25,7 -25,7 +25,8 @@@
#include <mlpack/methods/perceptron/perceptron.hpp>
#include <mlpack/methods/logistic_regression/logistic_regression.hpp>
#include <mlpack/methods/neighbor_search/neighbor_search.hpp>
+ #include <mlpack/methods/softmax_regression/softmax_regression.hpp>
+#include <mlpack/methods/det/dtree.hpp>
using namespace mlpack;
using namespace mlpack::distribution;
@@@ -800,217 -800,27 +801,239 @@@ BOOST_AUTO_TEST_CASE(AllkNNTest
CheckMatrices(neighbors, xmlNeighbors, textNeighbors, binaryNeighbors);
}
+ BOOST_AUTO_TEST_CASE(SoftmaxRegressionTest)  // Serialization round-trip test for SoftmaxRegression.
+ {  // Trains one model, loads it into three untrained models (one per archive format), and checks the parameters survive unchanged.
+ using regression::SoftmaxRegression;
+
+ arma::mat dataset = arma::randu<arma::mat>(5, 1000);  // 1000 uniform-random points in 5 dimensions.
+ arma::Row<size_t> labels(1000);  // First 500 points get class 0, the rest class 1; labels ignore the features, which is fine since only serialization is tested.
+ for (size_t i = 0; i < 500; ++i)
+ labels[i] = 0;
+ for (size_t i = 500; i < 1000; ++i)
+ labels[i] = 1;
+
+ SoftmaxRegression<> sr(dataset, labels, 2);  // Reference model trained on the data with 2 classes.
+
+ SoftmaxRegression<> srXml(dataset.n_rows, 2);  // Untrained models of matching size; presumably SerializeObjectAll overwrites them with sr's state -- confirm.
+ SoftmaxRegression<> srText(dataset.n_rows, 2);
+ SoftmaxRegression<> srBinary(dataset.n_rows, 2);
+
+ SerializeObjectAll(sr, srXml, srText, srBinary);  // NOTE(review): names suggest the parameter order (original, xml, text, binary); SerializeObjectAll is defined outside this hunk -- confirm.
+
+ CheckMatrices(sr.Parameters(), srXml.Parameters(), srText.Parameters(),  // Each loaded model's parameters must match the trained model's (CheckMatrices is defined outside this hunk).
+ srBinary.Parameters());
+ }
+
+BOOST_AUTO_TEST_CASE(DETTest)  // Serialization round-trip test for det::DTree.
+{  // Builds a tree, loads it into three other trees, then compares every node of all four trees field-by-field.
+  using det::DTree;
+
+  // Create a density estimation tree on a random dataset.
+  arma::mat dataset = arma::randu<arma::mat>(25, 5000);  // 5000 uniform-random points in 25 dimensions.
+
+  DTree tree(dataset);  // Reference tree built from `dataset`.
+
+  arma::mat otherDataset = arma::randu<arma::mat>(5, 100);  // Shape (5x100) differs from `dataset` (25x5000), so textTree does not start out equal to `tree`.
+  DTree xmlTree, binaryTree, textTree(otherDataset);  // xmlTree/binaryTree are default-constructed; textTree is pre-built, presumably to check that loading replaces existing state -- confirm.
+
+  SerializeObjectAll(tree, xmlTree, binaryTree, textTree);  // NOTE(review): SoftmaxRegressionTest passes (srXml, srText, srBinary), but here binaryTree precedes textTree; if the order is (xml, text, binary) these two names are swapped relative to their formats. The checks below are unaffected since every tree is compared with `tree`.
+
+  std::stack<DTree*> stack, xmlStack, binaryStack, textStack;  // One stack per tree so all four are walked depth-first in lockstep. NOTE(review): needs <stack>, not added in this hunk -- confirm it is included elsewhere.
+  stack.push(&tree);
+  xmlStack.push(&xmlTree);
+  binaryStack.push(&binaryTree);
+  textStack.push(&textTree);
+
+  while (!stack.empty())  // Each iteration pops one corresponding node per tree and compares them.
+  {
+    // Get the top node from the stack.
+    DTree* node = stack.top();
+    DTree* xmlNode = xmlStack.top();
+    DTree* binaryNode = binaryStack.top();
+    DTree* textNode = textStack.top();
+
+    stack.pop();
+    xmlStack.pop();
+    binaryStack.pop();
+    textStack.pop();
+
+    // Check that all the members are the same.
+    BOOST_REQUIRE_EQUAL(node->Start(), xmlNode->Start());
+    BOOST_REQUIRE_EQUAL(node->Start(), binaryNode->Start());
+    BOOST_REQUIRE_EQUAL(node->Start(), textNode->Start());
+
+    BOOST_REQUIRE_EQUAL(node->End(), xmlNode->End());
+    BOOST_REQUIRE_EQUAL(node->End(), binaryNode->End());
+    BOOST_REQUIRE_EQUAL(node->End(), textNode->End());
+
+    BOOST_REQUIRE_EQUAL(node->SplitDim(), xmlNode->SplitDim());
+    BOOST_REQUIRE_EQUAL(node->SplitDim(), binaryNode->SplitDim());
+    BOOST_REQUIRE_EQUAL(node->SplitDim(), textNode->SplitDim());
+
+    if (std::abs(node->SplitValue()) < 1e-5)  // Near-zero values get an absolute bound because BOOST_REQUIRE_CLOSE's tolerance is a relative percentage.
+    {
+      BOOST_REQUIRE_SMALL(xmlNode->SplitValue(), 1e-5);
+      BOOST_REQUIRE_SMALL(binaryNode->SplitValue(), 1e-5);
+      BOOST_REQUIRE_SMALL(textNode->SplitValue(), 1e-5);
+    }
+    else
+    {
+      BOOST_REQUIRE_CLOSE(node->SplitValue(), xmlNode->SplitValue(), 1e-5);
+      BOOST_REQUIRE_CLOSE(node->SplitValue(), binaryNode->SplitValue(), 1e-5);
+      BOOST_REQUIRE_CLOSE(node->SplitValue(), textNode->SplitValue(), 1e-5);
+    }
+
+    if (std::abs(node->LogNegError()) < 1e-5)
+    {
+      BOOST_REQUIRE_SMALL(xmlNode->LogNegError(), 1e-5);
+      BOOST_REQUIRE_SMALL(binaryNode->LogNegError(), 1e-5);
+      BOOST_REQUIRE_SMALL(textNode->LogNegError(), 1e-5);
+    }
+    else
+    {
+      BOOST_REQUIRE_CLOSE(node->LogNegError(), xmlNode->LogNegError(), 1e-5);
+      BOOST_REQUIRE_CLOSE(node->LogNegError(), binaryNode->LogNegError(), 1e-5);
+      BOOST_REQUIRE_CLOSE(node->LogNegError(), textNode->LogNegError(), 1e-5);
+    }
+
+    if (std::abs(node->SubtreeLeavesLogNegError()) < 1e-5)
+    {
+      BOOST_REQUIRE_SMALL(xmlNode->SubtreeLeavesLogNegError(), 1e-5);
+      BOOST_REQUIRE_SMALL(binaryNode->SubtreeLeavesLogNegError(), 1e-5);
+      BOOST_REQUIRE_SMALL(textNode->SubtreeLeavesLogNegError(), 1e-5);
+    }
+    else
+    {
+      BOOST_REQUIRE_CLOSE(node->SubtreeLeavesLogNegError(),
+          xmlNode->SubtreeLeavesLogNegError(), 1e-5);
+      BOOST_REQUIRE_CLOSE(node->SubtreeLeavesLogNegError(),
+          binaryNode->SubtreeLeavesLogNegError(), 1e-5);
+      BOOST_REQUIRE_CLOSE(node->SubtreeLeavesLogNegError(),
+          textNode->SubtreeLeavesLogNegError(), 1e-5);
+    }
+
+    BOOST_REQUIRE_EQUAL(node->SubtreeLeaves(), xmlNode->SubtreeLeaves());
+    BOOST_REQUIRE_EQUAL(node->SubtreeLeaves(), binaryNode->SubtreeLeaves());
+    BOOST_REQUIRE_EQUAL(node->SubtreeLeaves(), textNode->SubtreeLeaves());
+
+    if (std::abs(node->Ratio()) < 1e-5)
+    {
+      BOOST_REQUIRE_SMALL(xmlNode->Ratio(), 1e-5);
+      BOOST_REQUIRE_SMALL(binaryNode->Ratio(), 1e-5);
+      BOOST_REQUIRE_SMALL(textNode->Ratio(), 1e-5);
+    }
+    else
+    {
+      BOOST_REQUIRE_CLOSE(node->Ratio(), xmlNode->Ratio(), 1e-5);
+      BOOST_REQUIRE_CLOSE(node->Ratio(), binaryNode->Ratio(), 1e-5);
+      BOOST_REQUIRE_CLOSE(node->Ratio(), textNode->Ratio(), 1e-5);
+    }
+
+    if (std::abs(node->LogVolume()) < 1e-5)
+    {
+      BOOST_REQUIRE_SMALL(xmlNode->LogVolume(), 1e-5);
+      BOOST_REQUIRE_SMALL(binaryNode->LogVolume(), 1e-5);
+      BOOST_REQUIRE_SMALL(textNode->LogVolume(), 1e-5);
+    }
+    else
+    {
+      BOOST_REQUIRE_CLOSE(node->LogVolume(), xmlNode->LogVolume(), 1e-5);
+      BOOST_REQUIRE_CLOSE(node->LogVolume(), binaryNode->LogVolume(), 1e-5);
+      BOOST_REQUIRE_CLOSE(node->LogVolume(), textNode->LogVolume(), 1e-5);
+    }
+
+    if (node->Left() == NULL)  // Left child must be null in every copy or in none; non-null children are queued for comparison.
+    {
+      BOOST_REQUIRE(xmlNode->Left() == NULL);
+      BOOST_REQUIRE(binaryNode->Left() == NULL);
+      BOOST_REQUIRE(textNode->Left() == NULL);
+    }
+    else
+    {
+      BOOST_REQUIRE(xmlNode->Left() != NULL);
+      BOOST_REQUIRE(binaryNode->Left() != NULL);
+      BOOST_REQUIRE(textNode->Left() != NULL);
+
+      // Push children onto stack.
+      stack.push(node->Left());
+      xmlStack.push(xmlNode->Left());
+      binaryStack.push(binaryNode->Left());
+      textStack.push(textNode->Left());
+    }
+
+    if (node->Right() == NULL)  // Same null/non-null agreement check for the right child.
+    {
+      BOOST_REQUIRE(xmlNode->Right() == NULL);
+      BOOST_REQUIRE(binaryNode->Right() == NULL);
+      BOOST_REQUIRE(textNode->Right() == NULL);
+    }
+    else
+    {
+      BOOST_REQUIRE(xmlNode->Right() != NULL);
+      BOOST_REQUIRE(binaryNode->Right() != NULL);
+      BOOST_REQUIRE(textNode->Right() != NULL);
+
+      // Push children onto stack.
+      stack.push(node->Right());
+      xmlStack.push(xmlNode->Right());
+      binaryStack.push(binaryNode->Right());
+      textStack.push(textNode->Right());
+    }
+
+    BOOST_REQUIRE_EQUAL(node->Root(), xmlNode->Root());  // Root() values are compared exactly.
+    BOOST_REQUIRE_EQUAL(node->Root(), binaryNode->Root());
+    BOOST_REQUIRE_EQUAL(node->Root(), textNode->Root());
+
+    if (std::abs(node->AlphaUpper()) < 1e-5)
+    {
+      BOOST_REQUIRE_SMALL(xmlNode->AlphaUpper(), 1e-5);
+      BOOST_REQUIRE_SMALL(binaryNode->AlphaUpper(), 1e-5);
+      BOOST_REQUIRE_SMALL(textNode->AlphaUpper(), 1e-5);
+    }
+    else
+    {
+      BOOST_REQUIRE_CLOSE(node->AlphaUpper(), xmlNode->AlphaUpper(), 1e-5);
+      BOOST_REQUIRE_CLOSE(node->AlphaUpper(), binaryNode->AlphaUpper(), 1e-5);
+      BOOST_REQUIRE_CLOSE(node->AlphaUpper(), textNode->AlphaUpper(), 1e-5);
+    }
+
+    BOOST_REQUIRE_EQUAL(node->MaxVals().n_elem, xmlNode->MaxVals().n_elem);  // MaxVals() must have equal length and matching elements; near-zero entries get an absolute check.
+    BOOST_REQUIRE_EQUAL(node->MaxVals().n_elem, binaryNode->MaxVals().n_elem);
+    BOOST_REQUIRE_EQUAL(node->MaxVals().n_elem, textNode->MaxVals().n_elem);
+    for (size_t i = 0; i < node->MaxVals().n_elem; ++i)
+    {
+      if (std::abs(node->MaxVals()[i]) < 1e-5)
+      {
+        BOOST_REQUIRE_SMALL(xmlNode->MaxVals()[i], 1e-5);
+        BOOST_REQUIRE_SMALL(binaryNode->MaxVals()[i], 1e-5);
+        BOOST_REQUIRE_SMALL(textNode->MaxVals()[i], 1e-5);
+      }
+      else
+      {
+        BOOST_REQUIRE_CLOSE(node->MaxVals()[i], xmlNode->MaxVals()[i], 1e-5);
+        BOOST_REQUIRE_CLOSE(node->MaxVals()[i], binaryNode->MaxVals()[i], 1e-5);
+        BOOST_REQUIRE_CLOSE(node->MaxVals()[i], textNode->MaxVals()[i], 1e-5);
+      }
+    }
+
+    BOOST_REQUIRE_EQUAL(node->MinVals().n_elem, xmlNode->MinVals().n_elem);  // Same length and element-wise checks for MinVals().
+    BOOST_REQUIRE_EQUAL(node->MinVals().n_elem, binaryNode->MinVals().n_elem);
+    BOOST_REQUIRE_EQUAL(node->MinVals().n_elem, textNode->MinVals().n_elem);
+    for (size_t i = 0; i < node->MinVals().n_elem; ++i)
+    {
+      if (std::abs(node->MinVals()[i]) < 1e-5)
+      {
+        BOOST_REQUIRE_SMALL(xmlNode->MinVals()[i], 1e-5);
+        BOOST_REQUIRE_SMALL(binaryNode->MinVals()[i], 1e-5);
+        BOOST_REQUIRE_SMALL(textNode->MinVals()[i], 1e-5);
+      }
+      else
+      {
+        BOOST_REQUIRE_CLOSE(node->MinVals()[i], xmlNode->MinVals()[i], 1e-5);
+        BOOST_REQUIRE_CLOSE(node->MinVals()[i], binaryNode->MinVals()[i], 1e-5);
+        BOOST_REQUIRE_CLOSE(node->MinVals()[i], textNode->MinVals()[i], 1e-5);
+      }
+    }
+  }
- }
+
BOOST_AUTO_TEST_SUITE_END();
More information about the mlpack-git
mailing list