[mlpack-git] master: Fix transpose issues. (1543ee8)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Dec 11 12:47:04 EST 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/dd7c8b93fe5f299cb534cda70c1c786456f9a78f...3b926fd86ab143eb8af7327b9fb89fead7538df0

>---------------------------------------------------------------

commit 1543ee88072ee4c535ae72b8b3129143420504f1
Author: Ryan Curtin <ryan at ratml.org>
Date:   Fri Dec 11 03:44:20 2015 +0000

    Fix transpose issues.


>---------------------------------------------------------------

1543ee88072ee4c535ae72b8b3129143420504f1
 src/mlpack/methods/adaboost/adaboost_main.cpp | 13 ++++---------
 1 file changed, 4 insertions(+), 9 deletions(-)

diff --git a/src/mlpack/methods/adaboost/adaboost_main.cpp b/src/mlpack/methods/adaboost/adaboost_main.cpp
index 1ba8e0c..ceed0a8 100644
--- a/src/mlpack/methods/adaboost/adaboost_main.cpp
+++ b/src/mlpack/methods/adaboost/adaboost_main.cpp
@@ -161,13 +161,13 @@ class AdaBoostModel
       if (dsBoost)
         delete dsBoost;
 
-      DecisionStump<> ds;
+      DecisionStump<> ds(data, labels, max(labels) + 1);
       dsBoost = new AdaBoost<DecisionStump<>>(data, labels, ds, iterations,
           tolerance);
     }
     else if (weakLearnerType == WeakLearnerTypes::PERCEPTRON)
     {
-      Perceptron<> p;
+      Perceptron<> p(data, labels, max(labels) + 1);
       pBoost = new AdaBoost<Perceptron<>>(data, labels, p, iterations,
           tolerance);
     }
@@ -307,14 +307,9 @@ int main(int argc, char *argv[])
 
     // Helpers for normalizing the labels.
     Row<size_t> labels;
-    vec mappings;
-
-    // Do the labels need to be transposed?
-    if (labelsIn.n_rows == 1)
-      labelsIn = labelsIn.t();
 
     // Normalize the labels.
-    data::NormalizeLabels(labelsIn.row(0), labels, mappings);
+    data::NormalizeLabels(labelsIn.row(0), labels, m.Mappings());
 
     // Get other training parameters.
     const double tolerance = CLI::GetParam<double>("tolerance");
@@ -357,7 +352,7 @@ int main(int argc, char *argv[])
     data::RevertLabels(predictedLabels, m.Mappings(), results);
 
     if (CLI::HasParam("output_file"))
-      data::Save(CLI::GetParam<string>("output_file"), results, true, false);
+      data::Save(CLI::GetParam<string>("output_file"), results, true);
   }
 
   // Should we save the model, too?



More information about the mlpack-git mailing list