[mlpack-git] master: Refactor to make instance weights optional. (2bd5e78)

gitdub at big.cc.gt.atl.ga.us gitdub at big.cc.gt.atl.ga.us
Fri Sep 4 14:07:05 EDT 2015


Repository : https://github.com/mlpack/mlpack

On branch  : master
Link       : https://github.com/mlpack/mlpack/compare/f3bd5e8853a795f4ff41849dd8ef844d53199412...2bd5e7889d810e86f7cdd586485f4b4e7a3b9cf0

>---------------------------------------------------------------

commit 2bd5e7889d810e86f7cdd586485f4b4e7a3b9cf0
Author: Ryan Curtin <ryan at ratml.org>
Date:   Fri Sep 4 18:06:49 2015 +0000

    Refactor to make instance weights optional.


>---------------------------------------------------------------

2bd5e7889d810e86f7cdd586485f4b4e7a3b9cf0
 src/mlpack/methods/perceptron/perceptron.hpp      |  2 +-
 src/mlpack/methods/perceptron/perceptron_impl.hpp | 16 ++++++++++------
 2 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/src/mlpack/methods/perceptron/perceptron.hpp b/src/mlpack/methods/perceptron/perceptron.hpp
index 50eaa4d..fd674c6 100644
--- a/src/mlpack/methods/perceptron/perceptron.hpp
+++ b/src/mlpack/methods/perceptron/perceptron.hpp
@@ -98,7 +98,7 @@ private:
    */
   void Train(const MatType& data,
              const arma::Row<size_t>& labels,
-             const arma::rowvec& D);
+             const arma::rowvec& D = arma::rowvec());
 };
 
 } // namespace perceptron
diff --git a/src/mlpack/methods/perceptron/perceptron_impl.hpp b/src/mlpack/methods/perceptron/perceptron_impl.hpp
index 5b5cbc2..72b720a 100644
--- a/src/mlpack/methods/perceptron/perceptron_impl.hpp
+++ b/src/mlpack/methods/perceptron/perceptron_impl.hpp
@@ -37,10 +37,7 @@ Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Perceptron(
 
   // Start training.
   iter = iterations;
-  arma::rowvec D(data.n_cols);
-  D.fill(1.0);// giving equal weight to all the points.
-
-  Train(data, labels, D);
+  Train(data, labels);
 }
 
 
@@ -138,6 +135,8 @@ void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Train(
 
   LearnPolicy LP;
 
+  const bool hasWeights = (D.n_elem > 0);
+
   while ((i < iter) && (!converged))
   {
     // This outer loop is for each iteration, and we use the 'converged'
@@ -160,11 +159,16 @@ void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Train(
         // Due to incorrect prediction, convergence set to false.
         converged = false;
         tempLabel = labels(0, j);
+
         // Send maxIndexRow for knowing which weight to update, send j to know
         // the value of the vector to update it with.  Send tempLabel to know
         // the correct class.
-        LP.UpdateWeights(data.col(j), weights, biases, maxIndexRow, tempLabel,
-            D(j));
+        if (hasWeights)
+          LP.UpdateWeights(data.col(j), weights, biases, maxIndexRow, tempLabel,
+              D(j));
+        else
+          LP.UpdateWeights(data.col(j), weights, biases, maxIndexRow,
+              tempLabel);
       }
     }
   }



More information about the mlpack-git mailing list