[mlpack-svn] r16842 - in mlpack/trunk/src/mlpack/methods: . adaboost
fastlab-svn at coffeetalk-1.cc.gatech.edu
fastlab-svn at coffeetalk-1.cc.gatech.edu
Sun Jul 20 14:02:20 EDT 2014
Author: saxena.udit
Date: Sun Jul 20 14:02:20 2014
New Revision: 16842
Log:
Changed the AdaBoost implementation: implemented AdaBoost.M1.
Modified:
mlpack/trunk/src/mlpack/methods/CMakeLists.txt
mlpack/trunk/src/mlpack/methods/adaboost/adaboost_impl.hpp
Modified: mlpack/trunk/src/mlpack/methods/CMakeLists.txt
==============================================================================
--- mlpack/trunk/src/mlpack/methods/CMakeLists.txt (original)
+++ mlpack/trunk/src/mlpack/methods/CMakeLists.txt Sun Jul 20 14:02:20 2014
@@ -1,6 +1,6 @@
# Recurse into each method mlpack provides.
set(DIRS
-# adaboost
+ adaboost
amf
cf
decision_stump
Modified: mlpack/trunk/src/mlpack/methods/adaboost/adaboost_impl.hpp
==============================================================================
--- mlpack/trunk/src/mlpack/methods/adaboost/adaboost_impl.hpp (original)
+++ mlpack/trunk/src/mlpack/methods/adaboost/adaboost_impl.hpp Sun Jul 20 14:02:20 2014
@@ -14,6 +14,12 @@
namespace mlpack {
namespace adaboost {
+//
+// Currently this is an implementation of adaboost.m1
+// which will be templatized later and adaboost.mh and
+// adaboost.samme will be added.
+//
+
template<typename MatType, typename WeakLearner>
Adaboost<MatType, WeakLearner>::Adaboost(const MatType& data, const arma::Row<size_t>& labels,
int iterations, size_t classes, const WeakLearner& other)
@@ -26,15 +32,25 @@
// load the initial weights
const double initWeight = 1 / (data.n_cols * classes);
- arma::Row<double> D(data.n_cols);
+ arma::rowvec D(data.n_cols);
D.fill(initWeight);
- double rt, alphat = 0.0, zt;
+ size_t countMP; // for counting mispredictions.
+ double rt, alphat = 0.0, zt, et;
arma::Row<size_t> predictedLabels(labels.n_cols);
MatType tempData(data);
+
+ // This behaves as ht(x)
+ arma::rowvec mispredict(predictedLabels.n_cols);
+
+ arma::mat sumFinalH(data.n_cols, classes);
+ sumFinalH.fill(0.0);
+
+ arma::rowvec finalH(labels.n_cols);
// now start the boosting rounds
for (i = 0; i < iterations; i++)
{
+ countMP = 0;
rt = 0.0;
zt = 0.0;
@@ -47,37 +63,63 @@
// building a helper rowvector, mispredict to help in calculations.
// this stores the value of Yi(l)*ht(xi,l)
- arma::Row<double> mispredict(predictedLabels.n_cols);
-
+ // first calculate error:
for(j = 0;j < predictedLabels.n_cols; j++)
{
if (predictedLabels(j) != labels(j))
+ {
mispredict(j) = -predictedLabels(j);
+ countMP++;
+ }
else
mispredict(j) = predictedLabels(j);
}
+ et = ((double) countMP / predictedLabels.n_cols);
- // begin calculation of rt
+ if (et < 0.5)
+ {
+ // begin calculation of rt
- for (j = 0;j < predictedLabels.n_cols; j++)
- rt +=(D(j) * mispredict(j));
+ // for (j = 0;j < predictedLabels.n_cols; j++)
+ // rt +=(D(j) * mispredict(j));
- // end calculation of rt
+ // end calculation of rt
- alphat = 0.5 * log((1 + rt) / (1 - rt));
+ // alphat = 0.5 * log((1 + rt) / (1 - rt));
- // end calculation of alphat
-
- for (j = 0;j < mispredict.n_cols; j++)
- {
- zt += D(i) * exp(-1 * alphat * mispredict(i));
- D(i) = D(i) * exp(-1 * alphat * mispredict(i));
- }
+ alphat = 0.5 * log((1 - et) / et);
+
+ // end calculation of alphat
+
+ // now start modifying weights
- D = D / zt;
+ for (j = 0;j < mispredict.n_cols; j++)
+ {
+ // we calculate zt, the normalization constant
+ zt += D(j) * exp(-1 * alphat * (mispredict(j) / predictedLabels(j)));
+ D(j) = D(j) * exp(-1 * alphat * (mispredict(j) / predictedLabels(j)));
+ // adding to the matrix of FinalHypothesis
+ if (mispredict(j) == predictedLabels(j)) // if correct prediction
+ sumFinalH(j, mispredict(j)) += alphat;
+ }
+ // normalization of D
+
+ D = D / zt;
+ }
}
+ // build a strong hypothesis from a weighted combination of these weak hypotheses.
+
+ // This step of storing it in a temporary row vector can be improved upon.
+ arma::rowvec tempSumFinalH;
+
+ for (i = 0;i < sumFinalH.n_rows; i++)
+ {
+ tempSumFinalH = sumFinalH.row(i);
+ tempSumFinalH.max(max_index);
+ finalH(i) = max_index;
+ }
}
} // namespace adaboost
More information about the mlpack-svn mailing list