[mlpack-git] master: Add Serialize() and other accessors. (f0c32e8)

gitdub at big.cc.gt.atl.ga.us
Tue Dec 8 11:11:04 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/157595c68e3d26679e90152f07e1ee28e5e563c2...fc50782ab165567b0f04b11534b4ddc499262330

>---------------------------------------------------------------

commit f0c32e8adc0f944304abfb4606bd955b4b55d32a
Author: Ryan Curtin <ryan at ratml.org>
Date:   Mon Dec 7 16:12:44 2015 +0000

    Add Serialize() and other accessors.


>---------------------------------------------------------------

f0c32e8adc0f944304abfb4606bd955b4b55d32a
 src/mlpack/methods/adaboost/adaboost.hpp      | 44 ++++++++++++++++++----
 src/mlpack/methods/adaboost/adaboost_impl.hpp | 53 ++++++++++++++++++++++-----
 2 files changed, 80 insertions(+), 17 deletions(-)

diff --git a/src/mlpack/methods/adaboost/adaboost.hpp b/src/mlpack/methods/adaboost/adaboost.hpp
index ef6bfa6..678831b 100644
--- a/src/mlpack/methods/adaboost/adaboost.hpp
+++ b/src/mlpack/methods/adaboost/adaboost.hpp
@@ -69,10 +69,10 @@ namespace adaboost {
  * perceptron::Perceptron<> and decision_stump::DecisionStump<>.
  *
  * @tparam MatType Data matrix type (i.e. arma::mat or arma::sp_mat).
- * @tparam WeakLearner Type of weak learner to use.
+ * @tparam WeakLearnerType Type of weak learner to use.
  */
 template<typename MatType = arma::mat,
-         typename WeakLearner = mlpack::perceptron::Perceptron<> >
+         typename WeakLearnerType = mlpack::perceptron::Perceptron<> >
 class AdaBoost
 {
  public:
@@ -90,21 +90,43 @@ class AdaBoost
    */
   AdaBoost(const MatType& data,
            const arma::Row<size_t>& labels,
-           const WeakLearner& other,
+           const WeakLearnerType& other,
            const size_t iterations = 100,
            const double tolerance = 1e-6);
 
+  /**
+   * Create the AdaBoost object without training.  Be sure to call Train()
+   * before calling Classify()!
+   */
+  AdaBoost(const double tolerance = 1e-6);
+
   // Return the value of ztProduct.
-  double GetztProduct() { return ztProduct; }
+  double ZtProduct() { return ztProduct; }
 
   //! Get the tolerance for stopping the optimization during training.
   double Tolerance() const { return tolerance; }
   //! Modify the tolerance for stopping the optimization during training.
   double& Tolerance() { return tolerance; }
 
+  //! Get the number of classes this model is trained on.
+  size_t Classes() const { return classes; }
+
+  //! Get the number of weak learners in the model.
+  size_t WeakLearners() const { return alpha.size(); }
+
+  //! Get the weights for the given weak learner.
+  double Alpha(const size_t i) const { return alpha[i]; }
+  //! Modify the weight for the given weak learner (be careful!).
+  double& Alpha(const size_t i) { return alpha[i]; }
+
+  //! Get the given weak learner.
+  const WeakLearnerType& WeakLearner(const size_t i) const { return wl[i]; }
+  //! Modify the given weak learner (be careful!).
+  WeakLearnerType& WeakLearner(const size_t i) { return wl[i]; }
+
   /**
    * Train AdaBoost on the given dataset.  This method takes an initialized
-   * WeakLearner; the parameters for this weak learner will be used to train
+   * WeakLearnerType; the parameters for this weak learner will be used to train
    * each of the weak learners during AdaBoost training.  Note that this will
    * completely overwrite any model that has already been trained with this
    * object.
@@ -115,7 +137,7 @@ class AdaBoost
    */
   void Train(const MatType& data,
              const arma::Row<size_t>& labels,
-             const WeakLearner& learner,
+             const WeakLearnerType& learner,
              const size_t iterations = 100,
              const double tolerance = 1e-6);
 
@@ -128,6 +150,12 @@ class AdaBoost
    */
   void Classify(const MatType& test, arma::Row<size_t>& predictedLabels);
 
+  /**
+   * Serialize the AdaBoost model.
+   */
+  template<typename Archive>
+  void Serialize(Archive& ar, const unsigned int /* version */);
+
 private:
   //! The number of classes in the model.
   size_t classes;
@@ -135,11 +163,11 @@ private:
   double tolerance;
 
   //! The vector of weak learners.
-  std::vector<WeakLearner> wl;
+  std::vector<WeakLearnerType> wl;
   //! The weights corresponding to each weak learner.
   std::vector<double> alpha;
 
-  // To check for the bound for the hammingLoss.
+  //! To check for the bound for the Hamming loss.
   double ztProduct;
 
 }; // class AdaBoost
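
For illustration, a minimal sketch of how the new accessors might be used
after training (the Perceptron constructor arguments below are illustrative
only and are not taken from this commit):

#include <iostream>
#include <mlpack/core.hpp>
#include <mlpack/methods/adaboost/adaboost.hpp>

using namespace mlpack;

int main()
{
  // Toy data: one point per column, one label per point.
  arma::mat data(10, 200, arma::fill::randu);
  arma::Row<size_t> labels(200);
  labels.head(100).fill(0);
  labels.tail(100).fill(1);

  // Train with a perceptron weak learner; these constructor arguments are
  // illustrative and may differ at this revision of mlpack.
  perceptron::Perceptron<> p(data, labels, 2 /* classes */, 400);
  adaboost::AdaBoost<> model(data, labels, p, 50 /* iterations */, 1e-6);

  // Inspect the trained ensemble with the accessors added in this commit.
  std::cout << model.WeakLearners() << " weak learners over "
            << model.Classes() << " classes." << std::endl;
  for (size_t i = 0; i < model.WeakLearners(); ++i)
    std::cout << "alpha[" << i << "] = " << model.Alpha(i) << std::endl;

  return 0;
}
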
diff --git a/src/mlpack/methods/adaboost/adaboost_impl.hpp b/src/mlpack/methods/adaboost/adaboost_impl.hpp
index 3172f06..ee400e0 100644
--- a/src/mlpack/methods/adaboost/adaboost_impl.hpp
+++ b/src/mlpack/methods/adaboost/adaboost_impl.hpp
@@ -35,23 +35,31 @@ namespace adaboost {
  * @param tol Tolerance for termination of Adaboost.MH.
  * @param other Weak Learner, which has been initialized already.
  */
-template<typename MatType, typename WeakLearner>
-AdaBoost<MatType, WeakLearner>::AdaBoost(
+template<typename MatType, typename WeakLearnerType>
+AdaBoost<MatType, WeakLearnerType>::AdaBoost(
     const MatType& data,
     const arma::Row<size_t>& labels,
-    const WeakLearner& other,
+    const WeakLearnerType& other,
     const size_t iterations,
     const double tol)
 {
   Train(data, labels, other, iterations, tol);
 }
 
+// Empty constructor.
+template<typename MatType, typename WeakLearnerType>
+AdaBoost<MatType, WeakLearnerType>::AdaBoost(const double tolerance) :
+    tolerance(tolerance)
+{
+  // Nothing to do.
+}
+
 // Train AdaBoost.
-template<typename MatType, typename WeakLearner>
-void AdaBoost<MatType, WeakLearner>::Train(
+template<typename MatType, typename WeakLearnerType>
+void AdaBoost<MatType, WeakLearnerType>::Train(
     const MatType& data,
     const arma::Row<size_t>& labels,
-    const WeakLearner& other,
+    const WeakLearnerType& other,
     const size_t iterations,
     const double tolerance)
 {
@@ -104,7 +112,7 @@ void AdaBoost<MatType, WeakLearner>::Train(
     weights = arma::sum(D);
 
     // Use the existing weak learner to train a new one with new weights.
-    WeakLearner w(other, tempData, labels, weights);
+    WeakLearnerType w(other, tempData, labels, weights);
     w.Classify(tempData, predictedLabels);
 
     // Now from predictedLabels, build ht, the weak hypothesis
@@ -180,8 +188,8 @@ void AdaBoost<MatType, WeakLearner>::Train(
 /**
  * Classify the given test points.
  */
-template <typename MatType, typename WeakLearner>
-void AdaBoost<MatType, WeakLearner>::Classify(
+template<typename MatType, typename WeakLearnerType>
+void AdaBoost<MatType, WeakLearnerType>::Classify(
     const MatType& test,
     arma::Row<size_t>& predictedLabels)
 {
@@ -210,6 +218,33 @@ void AdaBoost<MatType, WeakLearner>::Classify(
   }
 }
 
+/**
+ * Serialize the AdaBoost model.
+ */
+template<typename MatType, typename WeakLearnerType>
+template<typename Archive>
+void AdaBoost<MatType, WeakLearnerType>::Serialize(Archive& ar,
+                                               const unsigned int /* version */)
+{
+  ar & data::CreateNVP(classes, "classes");
+  ar & data::CreateNVP(tolerance, "tolerance");
+  ar & data::CreateNVP(ztProduct, "ztProduct");
+  ar & data::CreateNVP(alpha, "alpha");
+
+  // Now serialize each weak learner.
+  if (Archive::is_loading::value)
+  {
+    wl.clear();
+    wl.resize(alpha.size());
+  }
+  for (size_t i = 0; i < wl.size(); ++i)
+  {
+    std::ostringstream oss;
+    oss << "weakLearner" << i;
+    ar & data::CreateNVP(wl[i], oss.str());
+  }
+}
+
 } // namespace adaboost
 } // namespace mlpack
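
With Serialize() in place, the model can be passed through mlpack's
data::Save() and data::Load() model helpers (those helpers come from
mlpack's serialization framework and are not part of this commit); a
minimal sketch:

#include <iostream>
#include <mlpack/core.hpp>
#include <mlpack/methods/adaboost/adaboost.hpp>

using namespace mlpack;

// Save a trained model, then reload it into an untrained one.
void SaveAndReload(adaboost::AdaBoost<>& trainedModel)
{
  // The filename and object name used here are arbitrary.
  data::Save("adaboost_model.xml", "adaboost_model", trainedModel);

  // The new empty constructor allows building an untrained model that is
  // then filled in by deserialization.
  adaboost::AdaBoost<> loadedModel;
  data::Load("adaboost_model.xml", "adaboost_model", loadedModel);

  std::cout << loadedModel.WeakLearners() << " weak learners after load."
            << std::endl;
}
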
 


