[mlpack-git] master: Add Serialize() and other accessors. (f0c32e8)
gitdub at big.cc.gt.atl.ga.us
gitdub at big.cc.gt.atl.ga.us
Tue Dec 8 11:11:04 EST 2015
Repository : https://github.com/mlpack/mlpack
On branch : master
Link : https://github.com/mlpack/mlpack/compare/157595c68e3d26679e90152f07e1ee28e5e563c2...fc50782ab165567b0f04b11534b4ddc499262330
>---------------------------------------------------------------
commit f0c32e8adc0f944304abfb4606bd955b4b55d32a
Author: Ryan Curtin <ryan at ratml.org>
Date: Mon Dec 7 16:12:44 2015 +0000
Add Serialize() and other accessors.
>---------------------------------------------------------------
f0c32e8adc0f944304abfb4606bd955b4b55d32a
src/mlpack/methods/adaboost/adaboost.hpp | 44 ++++++++++++++++++----
src/mlpack/methods/adaboost/adaboost_impl.hpp | 53 ++++++++++++++++++++++-----
2 files changed, 80 insertions(+), 17 deletions(-)
diff --git a/src/mlpack/methods/adaboost/adaboost.hpp b/src/mlpack/methods/adaboost/adaboost.hpp
index ef6bfa6..678831b 100644
--- a/src/mlpack/methods/adaboost/adaboost.hpp
+++ b/src/mlpack/methods/adaboost/adaboost.hpp
@@ -69,10 +69,10 @@ namespace adaboost {
* perceptron::Perceptron<> and decision_stump::DecisionStump<>.
*
* @tparam MatType Data matrix type (i.e. arma::mat or arma::sp_mat).
- * @tparam WeakLearner Type of weak learner to use.
+ * @tparam WeakLearnerType Type of weak learner to use.
*/
template<typename MatType = arma::mat,
- typename WeakLearner = mlpack::perceptron::Perceptron<> >
+ typename WeakLearnerType = mlpack::perceptron::Perceptron<> >
class AdaBoost
{
public:
@@ -90,21 +90,43 @@ class AdaBoost
*/
AdaBoost(const MatType& data,
const arma::Row<size_t>& labels,
- const WeakLearner& other,
+ const WeakLearnerType& other,
const size_t iterations = 100,
const double tolerance = 1e-6);
+ /**
+ * Create the AdaBoost object without training. Be sure to call Train()
+ * before calling Classify()!
+ */
+ AdaBoost(const double tolerance = 1e-6);
+
// Return the value of ztProduct.
- double GetztProduct() { return ztProduct; }
+ double ZtProduct() { return ztProduct; }
//! Get the tolerance for stopping the optimization during training.
double Tolerance() const { return tolerance; }
//! Modify the tolerance for stopping the optimization during training.
double& Tolerance() { return tolerance; }
+ //! Get the number of classes this model is trained on.
+ size_t Classes() const { return classes; }
+
+ //! Get the number of weak learners in the model.
+ size_t WeakLearners() const { return alpha.size(); }
+
+ //! Get the weight for the given weak learner.
+ double Alpha(const size_t i) const { return alpha[i]; }
+ //! Modify the weight for the given weak learner (be careful!).
+ double& Alpha(const size_t i) { return alpha[i]; }
+
+ //! Get the given weak learner.
+ const WeakLearnerType& WeakLearner(const size_t i) const { return wl[i]; }
+ //! Modify the given weak learner (be careful!).
+ WeakLearnerType& WeakLearner(const size_t i) { return wl[i]; }
+
/**
* Train AdaBoost on the given dataset. This method takes an initialized
- * WeakLearner; the parameters for this weak learner will be used to train
+ * WeakLearnerType; the parameters for this weak learner will be used to train
* each of the weak learners during AdaBoost training. Note that this will
* completely overwrite any model that has already been trained with this
* object.
@@ -115,7 +137,7 @@ class AdaBoost
*/
void Train(const MatType& data,
const arma::Row<size_t>& labels,
- const WeakLearner& learner,
+ const WeakLearnerType& learner,
const size_t iterations = 100,
const double tolerance = 1e-6);
@@ -128,6 +150,12 @@ class AdaBoost
*/
void Classify(const MatType& test, arma::Row<size_t>& predictedLabels);
+ /**
+ * Serialize the AdaBoost model.
+ */
+ template<typename Archive>
+ void Serialize(Archive& ar, const unsigned int /* version */);
+
private:
//! The number of classes in the model.
size_t classes;
@@ -135,11 +163,11 @@ private:
double tolerance;
//! The vector of weak learners.
- std::vector<WeakLearner> wl;
+ std::vector<WeakLearnerType> wl;
//! The weights corresponding to each weak learner.
std::vector<double> alpha;
- // To check for the bound for the hammingLoss.
+ //! To check for the bound for the Hamming loss.
double ztProduct;
}; // class AdaBoost
diff --git a/src/mlpack/methods/adaboost/adaboost_impl.hpp b/src/mlpack/methods/adaboost/adaboost_impl.hpp
index 3172f06..ee400e0 100644
--- a/src/mlpack/methods/adaboost/adaboost_impl.hpp
+++ b/src/mlpack/methods/adaboost/adaboost_impl.hpp
@@ -35,23 +35,31 @@ namespace adaboost {
* @param tol Tolerance for termination of Adaboost.MH.
* @param other Weak Learner, which has been initialized already.
*/
-template<typename MatType, typename WeakLearner>
-AdaBoost<MatType, WeakLearner>::AdaBoost(
+template<typename MatType, typename WeakLearnerType>
+AdaBoost<MatType, WeakLearnerType>::AdaBoost(
const MatType& data,
const arma::Row<size_t>& labels,
- const WeakLearner& other,
+ const WeakLearnerType& other,
const size_t iterations,
const double tol)
{
Train(data, labels, other, iterations, tol);
}
+// Constructor that only sets the tolerance; no training is performed.
+template<typename MatType, typename WeakLearnerType>
+AdaBoost<MatType, WeakLearnerType>::AdaBoost(const double tolerance) :
+ tolerance(tolerance)
+{
+ // Nothing to do.
+}
+
// Train AdaBoost.
-template<typename MatType, typename WeakLearner>
-void AdaBoost<MatType, WeakLearner>::Train(
+template<typename MatType, typename WeakLearnerType>
+void AdaBoost<MatType, WeakLearnerType>::Train(
const MatType& data,
const arma::Row<size_t>& labels,
- const WeakLearner& other,
+ const WeakLearnerType& other,
const size_t iterations,
const double tolerance)
{
@@ -104,7 +112,7 @@ void AdaBoost<MatType, WeakLearner>::Train(
weights = arma::sum(D);
// Use the existing weak learner to train a new one with new weights.
- WeakLearner w(other, tempData, labels, weights);
+ WeakLearnerType w(other, tempData, labels, weights);
w.Classify(tempData, predictedLabels);
// Now from predictedLabels, build ht, the weak hypothesis
@@ -180,8 +188,8 @@ void AdaBoost<MatType, WeakLearner>::Train(
/**
* Classify the given test points.
*/
-template <typename MatType, typename WeakLearner>
-void AdaBoost<MatType, WeakLearner>::Classify(
+template<typename MatType, typename WeakLearnerType>
+void AdaBoost<MatType, WeakLearnerType>::Classify(
const MatType& test,
arma::Row<size_t>& predictedLabels)
{
@@ -210,6 +218,33 @@ void AdaBoost<MatType, WeakLearner>::Classify(
}
}
+/**
+ * Serialize the AdaBoost model.
+ */
+template<typename MatType, typename WeakLearnerType>
+template<typename Archive>
+void AdaBoost<MatType, WeakLearnerType>::Serialize(Archive& ar,
+ const unsigned int /* version */)
+{
+ ar & data::CreateNVP(classes, "classes");
+ ar & data::CreateNVP(tolerance, "tolerance");
+ ar & data::CreateNVP(ztProduct, "ztProduct");
+ ar & data::CreateNVP(alpha, "alpha");
+
+ // Now serialize each weak learner.
+ if (Archive::is_loading::value)
+ {
+ wl.clear();
+ wl.resize(alpha.size());
+ }
+ for (size_t i = 0; i < wl.size(); ++i)
+ {
+ std::ostringstream oss;
+ oss << "weakLearner" << i;
+ ar & data::CreateNVP(wl[i], oss.str());
+ }
+}
+
} // namespace adaboost
} // namespace mlpack
More information about the mlpack-git mailing list