[mlpack-git] master: - First successful build. (8b4c907)
gitdub at mlpack.org
gitdub at mlpack.org
Tue Nov 1 15:22:37 EDT 2016
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/94d14187222231ca29e4f6419c5999c660db4f8a...981ffa2d67d8fe38df6c699589005835fef710ea
>---------------------------------------------------------------
commit 8b4c90729c8a787fc528b0e66aedb5c864e4cf0d
Author: theJonan <ivan at jonan.info>
Date: Thu Oct 13 19:53:50 2016 +0300
- First successful build.
>---------------------------------------------------------------
8b4c90729c8a787fc528b0e66aedb5c864e4cf0d
src/mlpack/methods/det/CMakeLists.txt | 4 -
src/mlpack/methods/det/det_main.cpp | 4 +-
src/mlpack/methods/det/dt_utils.hpp | 23 +--
.../det/{dt_utils.cpp => dt_utils_impl.hpp} | 59 ++++---
src/mlpack/methods/det/dtree.hpp | 6 +-
src/mlpack/methods/det/dtree_impl.hpp | 171 +++++++++++++--------
6 files changed, 152 insertions(+), 115 deletions(-)
diff --git a/src/mlpack/methods/det/CMakeLists.txt b/src/mlpack/methods/det/CMakeLists.txt
index 66edbe6..4dd3bc3 100644
--- a/src/mlpack/methods/det/CMakeLists.txt
+++ b/src/mlpack/methods/det/CMakeLists.txt
@@ -5,10 +5,6 @@ set(SOURCES
# the DET class
dtree.hpp
dtree_impl.hpp
-
- # the util file
- dt_utils.hpp
- dt_utils.cpp
)
# add directory name to sources
diff --git a/src/mlpack/methods/det/det_main.cpp b/src/mlpack/methods/det/det_main.cpp
index f1b33ba..26b394d 100644
--- a/src/mlpack/methods/det/det_main.cpp
+++ b/src/mlpack/methods/det/det_main.cpp
@@ -101,7 +101,7 @@ int main(int argc, char *argv[])
<< "(-T) is not specified." << endl;
// Are we training a DET or loading from file?
- DTree* tree;
+ DTree<arma::mat, arma::vec, int>* tree;
if (CLI::HasParam("training_file"))
{
const string trainSetFile = CLI::GetParam<string>("training_file");
@@ -127,7 +127,7 @@ int main(int argc, char *argv[])
// Obtain the optimal tree.
Timer::Start("det_training");
- tree = Trainer(trainingData, folds, regularization, maxLeafSize,
+ tree = Trainer<arma::mat, arma::vec, int>(trainingData, folds, regularization, maxLeafSize,
minLeafSize, "");
Timer::Stop("det_training");
diff --git a/src/mlpack/methods/det/dt_utils.hpp b/src/mlpack/methods/det/dt_utils.hpp
index c20f78a..067e6fe 100644
--- a/src/mlpack/methods/det/dt_utils.hpp
+++ b/src/mlpack/methods/det/dt_utils.hpp
@@ -25,8 +25,9 @@ namespace det {
* @param numClasses Number of classes in dataset.
* @param leafClassMembershipFile Name of file to print to (optional).
*/
-void PrintLeafMembership(DTree* dtree,
- const arma::mat& data,
+template <typename MatType, typename VecType, typename TagType>
+void PrintLeafMembership(DTree<MatType, VecType, TagType>* dtree,
+ const MatType& data,
const arma::Mat<size_t>& labels,
const size_t numClasses,
const std::string leafClassMembershipFile = "");
@@ -39,7 +40,8 @@ void PrintLeafMembership(DTree* dtree,
* @param dtree Density tree to use.
* @param viFile Name of file to print to (optional).
*/
-void PrintVariableImportance(const DTree* dtree,
+template <typename MatType, typename VecType, typename TagType>
+void PrintVariableImportance(const DTree<MatType, VecType, TagType>* dtree,
const std::string viFile = "");
/**
@@ -54,14 +56,17 @@ void PrintVariableImportance(const DTree* dtree,
* @param minLeafSize Minimum number of points allowed in a leaf.
* @param unprunedTreeOutput Filename to print unpruned tree to (optional).
*/
-DTree* Trainer(arma::mat& dataset,
- const size_t folds,
- const bool useVolumeReg = false,
- const size_t maxLeafSize = 10,
- const size_t minLeafSize = 5,
- const std::string unprunedTreeOutput = "");
+template <typename MatType, typename VecType, typename TagType>
+DTree<MatType, VecType, TagType>* Trainer(MatType& dataset,
+ const size_t folds,
+ const bool useVolumeReg = false,
+ const size_t maxLeafSize = 10,
+ const size_t minLeafSize = 5,
+ const std::string unprunedTreeOutput = "");
} // namespace det
} // namespace mlpack
+#include "dt_utils_impl.hpp"
+
#endif // MLPACK_METHODS_DET_DT_UTILS_HPP
diff --git a/src/mlpack/methods/det/dt_utils.cpp b/src/mlpack/methods/det/dt_utils_impl.hpp
similarity index 83%
rename from src/mlpack/methods/det/dt_utils.cpp
rename to src/mlpack/methods/det/dt_utils_impl.hpp
index b8946a9..4c057f7 100644
--- a/src/mlpack/methods/det/dt_utils.cpp
+++ b/src/mlpack/methods/det/dt_utils_impl.hpp
@@ -10,22 +10,23 @@
using namespace mlpack;
using namespace det;
-void mlpack::det::PrintLeafMembership(DTree* dtree,
- const arma::mat& data,
+template <typename MatType, typename VecType, typename TagType>
+void mlpack::det::PrintLeafMembership(DTree<MatType, VecType, TagType>* dtree,
+ const MatType& data,
const arma::Mat<size_t>& labels,
const size_t numClasses,
const std::string leafClassMembershipFile)
{
// Tag the leaves with numbers.
- int numLeaves = dtree->TagTree();
+ TagType numLeaves = dtree->TagTree();
arma::Mat<size_t> table(numLeaves, (numClasses + 1));
table.zeros();
for (size_t i = 0; i < data.n_cols; i++)
{
- const arma::vec testPoint = data.unsafe_col(i);
- const int leafTag = dtree->FindBucket(testPoint);
+ const VecType testPoint = data.unsafe_col(i);
+ const TagType leafTag = dtree->FindBucket(testPoint);
const size_t label = labels[i];
table(leafTag, label) += 1;
}
@@ -57,8 +58,8 @@ void mlpack::det::PrintLeafMembership(DTree* dtree,
return;
}
-
-void mlpack::det::PrintVariableImportance(const DTree* dtree,
+template <typename MatType, typename VecType, typename TagType>
+void mlpack::det::PrintVariableImportance(const DTree<MatType, VecType, TagType>* dtree,
const std::string viFile)
{
arma::vec imps;
@@ -96,15 +97,16 @@ void mlpack::det::PrintVariableImportance(const DTree* dtree,
// This function trains the optimal decision tree using the given number of
// folds.
-DTree* mlpack::det::Trainer(arma::mat& dataset,
- const size_t folds,
- const bool useVolumeReg,
- const size_t maxLeafSize,
- const size_t minLeafSize,
- const std::string unprunedTreeOutput)
+template <typename MatType, typename VecType, typename TagType>
+DTree<MatType, VecType, TagType>* mlpack::det::Trainer(MatType& dataset,
+ const size_t folds,
+ const bool useVolumeReg,
+ const size_t maxLeafSize,
+ const size_t minLeafSize,
+ const std::string unprunedTreeOutput)
{
// Initialize the tree.
- DTree dtree(dataset);
+ DTree<MatType, VecType, TagType> dtree(dataset);
// Prepare to grow the tree...
arma::Col<size_t> oldFromNew(dataset.n_cols);
@@ -112,7 +114,7 @@ DTree* mlpack::det::Trainer(arma::mat& dataset,
oldFromNew[i] = i;
// Save the dataset since it would be modified while growing the tree.
- arma::mat newDataset(dataset);
+ MatType newDataset(dataset);
// Growing the tree
double oldAlpha = 0.0;
@@ -148,8 +150,7 @@ DTree* mlpack::det::Trainer(arma::mat& dataset,
std::vector<std::pair<double, double> > prunedSequence;
while (dtree.SubtreeLeaves() > 1)
{
- std::pair<double, double> treeSeq(oldAlpha,
- dtree.SubtreeLeavesLogNegError());
+ std::pair<double, double> treeSeq(oldAlpha, dtree.SubtreeLeavesLogNegError());
prunedSequence.push_back(treeSeq);
oldAlpha = alpha;
alpha = dtree.PruneAndUpdate(oldAlpha, dataset.n_cols, useVolumeReg);
@@ -157,20 +158,18 @@ DTree* mlpack::det::Trainer(arma::mat& dataset,
// Some sanity checks. It seems that on some datasets, the error does not
// increase as the tree is pruned but instead stays the same---hence the
// "<=" in the final assert.
- Log::Assert((alpha < std::numeric_limits<double>::max()) ||
- (dtree.SubtreeLeaves() == 1));
+ Log::Assert((alpha < std::numeric_limits<double>::max()) || (dtree.SubtreeLeaves() == 1));
Log::Assert(alpha > oldAlpha);
Log::Assert(dtree.SubtreeLeavesLogNegError() <= treeSeq.second);
}
- std::pair<double, double> treeSeq(oldAlpha,
- dtree.SubtreeLeavesLogNegError());
+ std::pair<double, double> treeSeq(oldAlpha, dtree.SubtreeLeavesLogNegError());
prunedSequence.push_back(treeSeq);
Log::Info << prunedSequence.size() << " trees in the sequence; maximum alpha:"
<< " " << oldAlpha << "." << std::endl;
- arma::mat cvData(dataset);
+ MatType cvData(dataset);
size_t testSize = dataset.n_cols / folds;
arma::vec regularizationConstants(prunedSequence.size());
@@ -194,8 +193,8 @@ DTree* mlpack::det::Trainer(arma::mat& dataset,
size_t start = fold * testSize;
size_t end = std::min((size_t) (fold + 1) * testSize, (size_t) cvData.n_cols);
- arma::mat test = cvData.cols(start, end - 1);
- arma::mat train(cvData.n_rows, cvData.n_cols - test.n_cols);
+ MatType test = cvData.cols(start, end - 1);
+ MatType train(cvData.n_rows, cvData.n_cols - test.n_cols);
if (start == 0 && end < cvData.n_cols)
{
@@ -212,7 +211,7 @@ DTree* mlpack::det::Trainer(arma::mat& dataset,
}
// Initialize the tree.
- DTree cvDTree(train);
+ DTree<MatType, VecType, TagType> cvDTree(train);
// Getting ready to grow the tree...
arma::Col<size_t> cvOldFromNew(train.n_cols);
@@ -252,13 +251,12 @@ DTree* mlpack::det::Trainer(arma::mat& dataset,
double cvVal = 0.0;
for (size_t i = 0; i < test.n_cols; ++i)
{
- arma::vec testPoint = test.unsafe_col(i);
+ VecType testPoint = test.unsafe_col(i);
cvVal += cvDTree.ComputeValue(testPoint);
}
if (prunedSequence.size() > 2)
- cvRegularizationConstants[prunedSequence.size() - 2] += 2.0 * cvVal /
- (double) dataset.n_cols;
+ cvRegularizationConstants[prunedSequence.size() - 2] += 2.0 * cvVal / (double) dataset.n_cols;
#pragma omp critical
regularizationConstants += cvRegularizationConstants;
@@ -285,7 +283,7 @@ DTree* mlpack::det::Trainer(arma::mat& dataset,
Log::Info << "Optimal alpha: " << optimalAlpha << "." << std::endl;
// Initialize the tree.
- DTree* dtreeOpt = new DTree(dataset);
+ DTree<MatType, VecType, TagType>* dtreeOpt = new DTree<MatType, VecType, TagType>(dataset);
// Getting ready to grow the tree...
for (size_t i = 0; i < oldFromNew.n_elem; i++)
@@ -296,8 +294,7 @@ DTree* mlpack::det::Trainer(arma::mat& dataset,
// Grow the tree.
oldAlpha = -DBL_MAX;
- alpha = dtreeOpt->Grow(newDataset, oldFromNew, useVolumeReg, maxLeafSize,
- minLeafSize);
+ alpha = dtreeOpt->Grow(newDataset, oldFromNew, useVolumeReg, maxLeafSize, minLeafSize);
// Prune with optimal alpha.
while ((oldAlpha < optimalAlpha) && (dtreeOpt->SubtreeLeaves() > 1))
diff --git a/src/mlpack/methods/det/dtree.hpp b/src/mlpack/methods/det/dtree.hpp
index f39b5bf..f34750f 100644
--- a/src/mlpack/methods/det/dtree.hpp
+++ b/src/mlpack/methods/det/dtree.hpp
@@ -154,7 +154,7 @@ class DTree
*
* @param tag Tag for the next leaf; leave at 0 for the initial call.
*/
- TagType TagTree(const TagType tag = 0);
+ TagType TagTree(const TagType& tag = 0);
/**
* Return the tag of the leaf containing the query. This is useful for
@@ -271,13 +271,9 @@ class DTree
//! Return the maximum values.
const VecType& MaxVals() const { return maxVals; }
- //! Modify the maximum values.
- VecType& MaxVals() { return maxVals; }
//! Return the minimum values.
const VecType& MinVals() const { return minVals; }
- //! Modify the minimum values.
- VecType& MinVals() { return minVals; }
/**
* Serialize the density estimation tree.
diff --git a/src/mlpack/methods/det/dtree_impl.hpp b/src/mlpack/methods/det/dtree_impl.hpp
index 5ef3758..58131fb 100644
--- a/src/mlpack/methods/det/dtree_impl.hpp
+++ b/src/mlpack/methods/det/dtree_impl.hpp
@@ -1,4 +1,4 @@
- /**
+/**
* @file dtree.cpp
* @author Parikshit Ram (pram at cc.gatech.edu)
*
@@ -8,16 +8,91 @@
*/
#include "dtree.hpp"
#include <stack>
+#include <mlpack/core/tree/perform_split.hpp>
using namespace mlpack;
using namespace det;
+namespace detail
+{
+ template <typename ElemType>
+ class DTreeSplit
+ {
+ public:
+ typedef DTreeSplit<ElemType> SplitInfo;
+
+ template<typename VecType>
+ static bool AssignToLeftNode(const VecType& point,
+ const SplitInfo& splitInfo)
+ {
+ return point[splitInfo.splitDimension] < splitInfo.splitVal;
+ }
+
+ private:
+ ElemType splitVal;
+ size_t splitDimension;
+ };
+
+ /**
+ * We need this function to be able to specialize it for sparse matrices
+ * in a way which is much faster than the usual iteration.
+ */
+ template <typename MatType, typename VecType>
+ void ExtractMinMax(const MatType& data,
+ VecType& minVals,
+ VecType& maxVals)
+ {
+ // Initialize to first column; values will be overwritten if necessary.
+ maxVals = data.col(0);
+ minVals = data.col(0);
+
+ // Loop over data to extract maximum and minimum values in each dimension.
+ for (size_t i = 1; i < data.n_cols; ++i)
+ {
+ for (size_t j = 0; j < data.n_rows; ++j)
+ {
+ if (data(j, i) > maxVals[j])
+ maxVals[j] = data(j, i);
+ if (data(j, i) < minVals[j])
+ minVals[j] = data(j, i);
+ }
+ }
+ }
+
+ /**
+ * Here is the optimized overload for sparse matrices.
+ */
+ template <typename ElemType>
+ void ExtractMinMax(const arma::SpMat<ElemType>& data,
+ arma::SpCol<ElemType>& minVals,
+ arma::SpCol<ElemType>& maxVals)
+ {
+ // Initialize to first column; values will be overwritten if necessary.
+ maxVals = data.col(0);
+ minVals = data.col(0);
+
+ typename arma::sp_mat::iterator dataEnd = data.end();
+
+ // Loop over data to extract maximum and minimum values in each dimension.
+ for (typename arma::sp_mat::iterator i = data.begin(); i != dataEnd; ++i)
+ {
+ size_t j = i.row();
+ if (i.col() == 0)
+ continue; // we've already taken these values.
+ else if (*i > maxVals[j])
+ maxVals[j] = *i;
+ else if (*i < minVals[j])
+ minVals[j] = *i;
+ }
+ }
+};
+
template <typename MatType, typename VecType, typename TagType>
DTree<MatType, VecType, TagType>::DTree() :
start(0),
end(0),
splitDim(size_t(-1)),
- splitValue(DBL_MAX),
+ splitValue(std::numeric_limits<ElemType>::max()),
logNegError(-DBL_MAX),
subtreeLeavesLogNegError(-DBL_MAX),
subtreeLeaves(0),
@@ -42,7 +117,7 @@ DTree<MatType, VecType, TagType>::DTree(const VecType& maxVals,
maxVals(maxVals),
minVals(minVals),
splitDim(size_t(-1)),
- splitValue(DBL_MAX),
+ splitValue(std::numeric_limits<ElemType>::max()),
logNegError(LogNegativeError(totalPoints)),
subtreeLeavesLogNegError(-DBL_MAX),
subtreeLeaves(0),
@@ -60,7 +135,7 @@ DTree<MatType, VecType, TagType>::DTree(MatType & data) :
start(0),
end(data.n_cols),
splitDim(size_t(-1)),
- splitValue(DBL_MAX),
+ splitValue(std::numeric_limits<ElemType>::max()),
subtreeLeavesLogNegError(-DBL_MAX),
subtreeLeaves(0),
root(true),
@@ -71,26 +146,10 @@ DTree<MatType, VecType, TagType>::DTree(MatType & data) :
left(NULL),
right(NULL)
{
- // Initialize to first column; values will be overwritten if necessary.
- maxVals = data.col(0);
- minVals = data.col(0);
-
- // Loop over data to extract maximum and minimum values in each dimension.
- for (size_t i = 1; i < data.n_cols; ++i)
- {
- for (size_t j = 0; j < data.n_rows; ++j)
- {
- if (data(j, i) > maxVals[j])
- maxVals[j] = data(j, i);
- if (data(j, i) < minVals[j])
- minVals[j] = data(j, i);
- }
- }
-
+ detail::ExtractMinMax(data, minVals, maxVals);
logNegError = LogNegativeError(data.n_cols);
}
-
// Non-root node initializers
template <typename MatType, typename VecType, typename TagType>
DTree<MatType, VecType, TagType>::DTree(const VecType& maxVals,
@@ -103,7 +162,7 @@ DTree<MatType, VecType, TagType>::DTree(const VecType& maxVals,
maxVals(maxVals),
minVals(minVals),
splitDim(size_t(-1)),
- splitValue(DBL_MAX),
+ splitValue(std::numeric_limits<ElemType>::max()),
logNegError(logNegError),
subtreeLeavesLogNegError(-DBL_MAX),
subtreeLeaves(0),
@@ -127,7 +186,7 @@ DTree<MatType, VecType, TagType>::DTree(const VecType& maxVals,
maxVals(maxVals),
minVals(minVals),
splitDim(size_t(-1)),
- splitValue(DBL_MAX),
+ splitValue(std::numeric_limits<ElemType>::max()),
logNegError(LogNegativeError(totalPoints)),
subtreeLeavesLogNegError(-DBL_MAX),
subtreeLeaves(0),
@@ -156,12 +215,13 @@ double DTree<MatType, VecType, TagType>::LogNegativeError(const size_t totalPoin
double err = 2 * std::log((double) (end - start)) -
2 * std::log((double) totalPoints);
- arma::vec valDiffs = maxVals - minVals;
- for (size_t i = 0; i < maxVals.n_elem; ++i)
+ VecType valDiffs = maxVals - minVals;
+ typename VecType::iterator valEnd = valDiffs.end();
+ for (typename VecType::iterator i = valDiffs.begin(); i != valEnd; ++i)
{
// Ignore very small dimensions to prevent overflow.
- if (valDiffs[i] > 1e-50)
- err -= std::log(valDiffs[i]);
+ if (*i > 1e-50)
+ err -= std::log(*i);
}
return err;
@@ -191,11 +251,11 @@ bool DTree<MatType, VecType, TagType>::FindSplit(const MatType& data,
// Loop through each dimension.
#ifdef _WIN32
#pragma omp parallel for default(none) \
- shared(testSize, cvData, prunedSequence, regularizationConstants, dataset)
- for (intmax_t dim = 0; fold < (intmax_t) maxVals.n_elem; ++dim)
+ shared(minError, splitFound, points, data, minVals, maxVals, minLeafSize, maxLeafSize)
+ for (intmax_t dim = 0; dim < (intmax_t) maxVals.n_elem; ++dim)
#else
#pragma omp parallel for default(none) \
- shared(testSize, cvData, prunedSequence, regularizationConstants, dataset)
+ shared(minError, splitFound, points, data, minVals, maxVals, minLeafSize, maxLeafSize)
for (size_t dim = 0; dim < maxVals.n_elem; ++dim)
#endif
{
@@ -220,7 +280,7 @@ bool DTree<MatType, VecType, TagType>::FindSplit(const MatType& data,
double volumeWithoutDim = logVolume - std::log(max - min);
// Get the values for the dimension.
- arma::rowvec dimVec = data.row(dim).subvec(start, end - 1);
+ VecType dimVec = data.row(dim).subvec(start, end - 1);
// Sort the values in ascending order.
dimVec = arma::sort(dimVec);
@@ -265,9 +325,9 @@ bool DTree<MatType, VecType, TagType>::FindSplit(const MatType& data,
}
}
- double actualMinDimError = std::log(minDimError)
- - 2 * std::log((double) data.n_cols) - volumeWithoutDim;
+ double actualMinDimError = std::log(minDimError) - 2 * std::log((double) data.n_cols) - volumeWithoutDim;
+#pragma omp critical
if ((actualMinDimError > minError) && dimSplitFound)
{
// Calculate actual error (in logspace) by adding terms back to our
@@ -275,10 +335,8 @@ bool DTree<MatType, VecType, TagType>::FindSplit(const MatType& data,
minError = actualMinDimError;
splitDim = dim;
splitValue = dimSplitValue;
- leftError = std::log(dimLeftError) - 2 * std::log((double) data.n_cols)
- - volumeWithoutDim;
- rightError = std::log(dimRightError) - 2 * std::log((double) data.n_cols)
- - volumeWithoutDim;
+ leftError = std::log(dimLeftError) - 2 * std::log((double) data.n_cols) - volumeWithoutDim;
+ rightError = std::log(dimRightError) - 2 * std::log((double) data.n_cols) - volumeWithoutDim;
splitFound = true;
} // end if better split found in this dimension.
}
@@ -289,7 +347,7 @@ bool DTree<MatType, VecType, TagType>::FindSplit(const MatType& data,
template <typename MatType, typename VecType, typename TagType>
size_t DTree<MatType, VecType, TagType>::SplitData(MatType& data,
const size_t splitDim,
- const double splitValue,
+ const ElemType splitValue,
arma::Col<size_t>& oldFromNew) const
{
// Swap all columns such that any columns with value in dimension splitDim
@@ -310,7 +368,7 @@ size_t DTree<MatType, VecType, TagType>::SplitData(MatType& data,
data.swap_cols(left, right);
- // Store the mapping from old to new.
+ // Store the mapping from old to new. Do not put std::swap here...
const size_t tmp = oldFromNew[left];
oldFromNew[left] = oldFromNew[right];
oldFromNew[right] = tmp;
@@ -353,6 +411,7 @@ double DTree<MatType, VecType, TagType>::Grow(MatType& data,
{
// Move the data around for the children to have points in a node lie
// contiguously (to increase efficiency during the training).
+// const size_t splitIndex = splt::PerformSplit(data, start, end - start, )
const size_t splitIndex = SplitData(data, dim, splitValueTmp, oldFromNew);
// Make max and min vals for the children.
@@ -372,10 +431,8 @@ double DTree<MatType, VecType, TagType>::Grow(MatType& data,
left = new DTree(maxValsL, minValsL, start, splitIndex, leftError);
right = new DTree(maxValsR, minValsR, splitIndex, end, rightError);
- leftG = left->Grow(data, oldFromNew, useVolReg, maxLeafSize,
- minLeafSize);
- rightG = right->Grow(data, oldFromNew, useVolReg, maxLeafSize,
- minLeafSize);
+ leftG = left->Grow(data, oldFromNew, useVolReg, maxLeafSize, minLeafSize);
+ rightG = right->Grow(data, oldFromNew, useVolReg, maxLeafSize, minLeafSize);
// Store values of R(T~) and |T~|.
subtreeLeaves = left->SubtreeLeaves() + right->SubtreeLeaves();
@@ -440,14 +497,12 @@ double DTree<MatType, VecType, TagType>::Grow(MatType& data,
if (right->SubtreeLeaves() > 1)
{
- const double exponent = 2 * std::log((double) data.n_cols) + logVolume +
- right->AlphaUpper();
+ const double exponent = 2 * std::log((double) data.n_cols) + logVolume + right->AlphaUpper();
tmpAlphaSum += std::exp(exponent);
}
- alphaUpper = std::log(tmpAlphaSum) - 2 * std::log((double) data.n_cols)
- - logVolume;
+ alphaUpper = std::log(tmpAlphaSum) - 2 * std::log((double) data.n_cols) - logVolume;
double gT;
if (useVolReg)
@@ -613,15 +668,8 @@ double DTree<MatType, VecType, TagType>::ComputeValue(const VecType& query) cons
}
else
{
- if (query[splitDim] <= splitValue)
- {
- // If left subtree, go to left child.
- return left->ComputeValue(query);
- }
- else // If right subtree, go to right child
- {
- return right->ComputeValue(query);
- }
+ // Return either of the two children - left or right, depending on the splitValue
+ return (query[splitDim] <= splitValue) ? left->ComputeValue(query) : right->ComputeValue(query);
}
return 0.0;
@@ -630,7 +678,7 @@ double DTree<MatType, VecType, TagType>::ComputeValue(const VecType& query) cons
// Index the buckets for possible usage later.
template <typename MatType, typename VecType, typename TagType>
-TagType DTree<MatType, VecType, TagType>::TagTree(const TagType tag)
+TagType DTree<MatType, VecType, TagType>::TagTree(const TagType& tag)
{
if (subtreeLeaves == 1)
{
@@ -654,15 +702,10 @@ TagType DTree<MatType, VecType, TagType>::FindBucket(const VecType& query) const
{
return bucketTag;
}
- else if (query[splitDim] <= splitValue)
- {
- // If left subtree, go to left child.
- return left->FindBucket(query);
- }
else
{
- // If right subtree, go to right child.
- return right->FindBucket(query);
+ // Return the tag from either of the two children - left or right.
+ return (query[splitDim] <= splitValue) ? left->FindBucket(query) : right->FindBucket(query);
}
}
More information about the mlpack-git
mailing list